Compare commits

...

3 Commits

Author SHA1 Message Date
Danny Avila
5f7dc13c30 feat: Implement state parameter handling for OpenID authentication 2025-05-25 18:22:45 -04:00
Danny Avila
ac2e1b1586 feat: Enhance OpenID flow with state parameter handling 2025-05-25 16:33:34 -04:00
Danny Avila
45e4e70986 refactor: debounce setUserContext to avoid race condition 2025-05-25 16:32:29 -04:00
3 changed files with 105 additions and 34 deletions

View File

@@ -1,6 +1,8 @@
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware // file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
const express = require('express'); const express = require('express');
const jwt = require('jsonwebtoken');
const passport = require('passport'); const passport = require('passport');
const client = require('openid-client');
const { const {
checkBan, checkBan,
logHeaders, logHeaders,
@@ -19,6 +21,8 @@ const domains = {
server: process.env.DOMAIN_SERVER, server: process.env.DOMAIN_SERVER,
}; };
const JWT_SECRET = process.env.JWT_SECRET || process.env.OPENID_SESSION_SECRET;
router.use(logHeaders); router.use(logHeaders);
router.use(loginLimiter); router.use(loginLimiter);
@@ -103,20 +107,71 @@ router.get(
/** /**
* OpenID Routes * OpenID Routes
*/ */
router.get( router.get('/openid', (req, res, next) => {
'/openid', const state = client.randomState();
passport.authenticate('openid', {
session: false, try {
}), const stateToken = jwt.sign(
); {
state: state,
timestamp: Date.now(),
},
JWT_SECRET,
{ expiresIn: '10m' },
);
res.cookie('oauth_state', stateToken, {
httpOnly: true,
secure: process.env.NODE_ENV === 'production',
signed: false,
maxAge: 10 * 60 * 1000,
sameSite: 'lax',
});
passport.authenticate('openid', {
session: false,
state: state,
})(req, res, next);
} catch (error) {
logger.error('Error creating state token for OpenID authentication', error);
return res.redirect(`${domains.client}/oauth/error`);
}
});
router.get( router.get(
'/openid/callback', '/openid/callback',
passport.authenticate('openid', { (req, res, next) => {
failureRedirect: `${domains.client}/oauth/error`, if (!req.query.state) {
failureMessage: true, logger.error('Missing state parameter in OpenID callback');
session: false, return res.redirect(`${domains.client}/oauth/error`);
}), }
const stateToken = req.cookies.oauth_state;
if (!stateToken) {
logger.error('No state cookie found for OpenID callback');
return res.redirect(`${domains.client}/oauth/error`);
}
try {
const decodedState = jwt.verify(stateToken, JWT_SECRET);
if (req.query.state !== decodedState.state) {
logger.error('Invalid state parameter in OpenID callback', {
received: req.query.state,
expected: decodedState.state,
});
return res.redirect(`${domains.client}/oauth/error`);
}
res.clearCookie('oauth_state');
passport.authenticate('openid', {
failureRedirect: `${domains.client}/oauth/error`,
failureMessage: true,
session: false,
})(req, res, next);
} catch (error) {
logger.error('Invalid or expired state token in OpenID callback', error);
res.clearCookie('oauth_state');
return res.redirect(`${domains.client}/oauth/error`);
}
},
setBalanceConfig, setBalanceConfig,
oauthHandler, oauthHandler,
); );

View File

@@ -28,6 +28,17 @@ class CustomOpenIDStrategy extends OpenIDStrategy {
const hostAndProtocol = process.env.DOMAIN_SERVER; const hostAndProtocol = process.env.DOMAIN_SERVER;
return new URL(`${hostAndProtocol}${req.originalUrl ?? req.url}`); return new URL(`${hostAndProtocol}${req.originalUrl ?? req.url}`);
} }
/**
* Override to ensure proper authorization request parameters
*/
authorizationRequestParams(req, options) {
const params = super.authorizationRequestParams?.(req, options) || {};
if (options?.state != null && options.state && !params.has('state')) {
params.set('state', options.state);
}
return params;
}
} }
/** /**

View File

@@ -1,4 +1,5 @@
import { import {
useRef,
useMemo, useMemo,
useState, useState,
useEffect, useEffect,
@@ -6,10 +7,10 @@ import {
useContext, useContext,
useCallback, useCallback,
createContext, createContext,
useRef,
} from 'react'; } from 'react';
import { useNavigate } from 'react-router-dom'; import { debounce } from 'lodash';
import { useRecoilState } from 'recoil'; import { useRecoilState } from 'recoil';
import { useNavigate } from 'react-router-dom';
import { setTokenHeader, SystemRoles } from 'librechat-data-provider'; import { setTokenHeader, SystemRoles } from 'librechat-data-provider';
import type * as t from 'librechat-data-provider'; import type * as t from 'librechat-data-provider';
import { import {
@@ -47,27 +48,31 @@ const AuthContextProvider = ({
const navigate = useNavigate(); const navigate = useNavigate();
const setUserContext = useCallback( const setUserContext = useMemo(
(userContext: TUserContext) => { () =>
const { token, isAuthenticated, user, redirect } = userContext; debounce((userContext: TUserContext) => {
setUser(user); const { token, isAuthenticated, user, redirect } = userContext;
setToken(token); setUser(user);
//@ts-ignore - ok for token to be undefined initially setToken(token);
setTokenHeader(token); //@ts-ignore - ok for token to be undefined initially
setIsAuthenticated(isAuthenticated); setTokenHeader(token);
// Use a custom redirect if set setIsAuthenticated(isAuthenticated);
const finalRedirect = logoutRedirectRef.current || redirect;
// Clear the stored redirect // Use a custom redirect if set
logoutRedirectRef.current = undefined; const finalRedirect = logoutRedirectRef.current || redirect;
if (finalRedirect == null) { // Clear the stored redirect
return; logoutRedirectRef.current = undefined;
}
if (finalRedirect.startsWith('http://') || finalRedirect.startsWith('https://')) { if (finalRedirect == null) {
window.location.href = finalRedirect; return;
} else { }
navigate(finalRedirect, { replace: true });
} if (finalRedirect.startsWith('http://') || finalRedirect.startsWith('https://')) {
}, window.location.href = finalRedirect;
} else {
navigate(finalRedirect, { replace: true });
}
}, 50),
[navigate, setUser], [navigate, setUser],
); );
const doSetError = useTimeout({ callback: (error) => setError(error as string | undefined) }); const doSetError = useTimeout({ callback: (error) => setError(error as string | undefined) });