Compare commits

...

3 Commits

Author SHA1 Message Date
Danny Avila
5f7dc13c30 feat: Implement state parameter handling for OpenID authentication 2025-05-25 18:22:45 -04:00
Danny Avila
ac2e1b1586 feat: Enhance OpenID flow with state parameter handling 2025-05-25 16:33:34 -04:00
Danny Avila
45e4e70986 refactor: debounce setUserContext to avoid race condition 2025-05-25 16:32:29 -04:00
3 changed files with 105 additions and 34 deletions

View File

@@ -1,6 +1,8 @@
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
const express = require('express');
const jwt = require('jsonwebtoken');
const passport = require('passport');
const client = require('openid-client');
const {
checkBan,
logHeaders,
@@ -19,6 +21,8 @@ const domains = {
server: process.env.DOMAIN_SERVER,
};
const JWT_SECRET = process.env.JWT_SECRET || process.env.OPENID_SESSION_SECRET;
router.use(logHeaders);
router.use(loginLimiter);
@@ -103,20 +107,71 @@ router.get(
/**
* OpenID Routes
*/
router.get(
'/openid',
passport.authenticate('openid', {
session: false,
}),
);
router.get('/openid', (req, res, next) => {
const state = client.randomState();
try {
const stateToken = jwt.sign(
{
state: state,
timestamp: Date.now(),
},
JWT_SECRET,
{ expiresIn: '10m' },
);
res.cookie('oauth_state', stateToken, {
httpOnly: true,
secure: process.env.NODE_ENV === 'production',
signed: false,
maxAge: 10 * 60 * 1000,
sameSite: 'lax',
});
passport.authenticate('openid', {
session: false,
state: state,
})(req, res, next);
} catch (error) {
logger.error('Error creating state token for OpenID authentication', error);
return res.redirect(`${domains.client}/oauth/error`);
}
});
router.get(
'/openid/callback',
passport.authenticate('openid', {
failureRedirect: `${domains.client}/oauth/error`,
failureMessage: true,
session: false,
}),
(req, res, next) => {
if (!req.query.state) {
logger.error('Missing state parameter in OpenID callback');
return res.redirect(`${domains.client}/oauth/error`);
}
const stateToken = req.cookies.oauth_state;
if (!stateToken) {
logger.error('No state cookie found for OpenID callback');
return res.redirect(`${domains.client}/oauth/error`);
}
try {
const decodedState = jwt.verify(stateToken, JWT_SECRET);
if (req.query.state !== decodedState.state) {
logger.error('Invalid state parameter in OpenID callback', {
received: req.query.state,
expected: decodedState.state,
});
return res.redirect(`${domains.client}/oauth/error`);
}
res.clearCookie('oauth_state');
passport.authenticate('openid', {
failureRedirect: `${domains.client}/oauth/error`,
failureMessage: true,
session: false,
})(req, res, next);
} catch (error) {
logger.error('Invalid or expired state token in OpenID callback', error);
res.clearCookie('oauth_state');
return res.redirect(`${domains.client}/oauth/error`);
}
},
setBalanceConfig,
oauthHandler,
);

View File

@@ -28,6 +28,17 @@ class CustomOpenIDStrategy extends OpenIDStrategy {
const hostAndProtocol = process.env.DOMAIN_SERVER;
return new URL(`${hostAndProtocol}${req.originalUrl ?? req.url}`);
}
/**
* Override to ensure proper authorization request parameters
*/
authorizationRequestParams(req, options) {
const params = super.authorizationRequestParams?.(req, options) || {};
if (options?.state != null && options.state && !params.has('state')) {
params.set('state', options.state);
}
return params;
}
}
/**

View File

@@ -1,4 +1,5 @@
import {
useRef,
useMemo,
useState,
useEffect,
@@ -6,10 +7,10 @@ import {
useContext,
useCallback,
createContext,
useRef,
} from 'react';
import { useNavigate } from 'react-router-dom';
import { debounce } from 'lodash';
import { useRecoilState } from 'recoil';
import { useNavigate } from 'react-router-dom';
import { setTokenHeader, SystemRoles } from 'librechat-data-provider';
import type * as t from 'librechat-data-provider';
import {
@@ -47,27 +48,31 @@ const AuthContextProvider = ({
const navigate = useNavigate();
const setUserContext = useCallback(
(userContext: TUserContext) => {
const { token, isAuthenticated, user, redirect } = userContext;
setUser(user);
setToken(token);
//@ts-ignore - ok for token to be undefined initially
setTokenHeader(token);
setIsAuthenticated(isAuthenticated);
// Use a custom redirect if set
const finalRedirect = logoutRedirectRef.current || redirect;
// Clear the stored redirect
logoutRedirectRef.current = undefined;
if (finalRedirect == null) {
return;
}
if (finalRedirect.startsWith('http://') || finalRedirect.startsWith('https://')) {
window.location.href = finalRedirect;
} else {
navigate(finalRedirect, { replace: true });
}
},
const setUserContext = useMemo(
() =>
debounce((userContext: TUserContext) => {
const { token, isAuthenticated, user, redirect } = userContext;
setUser(user);
setToken(token);
//@ts-ignore - ok for token to be undefined initially
setTokenHeader(token);
setIsAuthenticated(isAuthenticated);
// Use a custom redirect if set
const finalRedirect = logoutRedirectRef.current || redirect;
// Clear the stored redirect
logoutRedirectRef.current = undefined;
if (finalRedirect == null) {
return;
}
if (finalRedirect.startsWith('http://') || finalRedirect.startsWith('https://')) {
window.location.href = finalRedirect;
} else {
navigate(finalRedirect, { replace: true });
}
}, 50),
[navigate, setUser],
);
const doSetError = useTimeout({ callback: (error) => setError(error as string | undefined) });