Compare commits

..

6 Commits

8 changed files with 388 additions and 404 deletions

View File

@@ -20,8 +20,8 @@ DOMAIN_CLIENT=http://localhost:3080
DOMAIN_SERVER=http://localhost:3080
NO_INDEX=true
# Use the address that is at most n number of hops away from the Express application.
# req.socket.remoteAddress is the first hop, and the rest are looked for in the X-Forwarded-For header from right to left.
# Use the address that is at most n number of hops away from the Express application.
# req.socket.remoteAddress is the first hop, and the rest are looked for in the X-Forwarded-For header from right to left.
# A value of 0 means that the first untrusted address would be req.socket.remoteAddress, i.e. there is no reverse proxy.
# Defaulted to 1.
TRUST_PROXY=1
@@ -428,12 +428,9 @@ OPENID_CLIENT_ID=
OPENID_CLIENT_SECRET=
OPENID_ISSUER=
OPENID_SESSION_SECRET=
# OPENID_USE_PKCE=
OPENID_SCOPE="openid profile email"
OPENID_CALLBACK_URL=/oauth/openid/callback
OPENID_REQUIRED_ROLE=
# Set to 'userinfo' or 'token' to determine witch role source to use, Default is 'token'
OPENID_REQUIRED_ROLE_SOURCE=
OPENID_REQUIRED_ROLE_TOKEN_KIND=
OPENID_REQUIRED_ROLE_PARAMETER_PATH=
# Set to determine which user info property returned from OpenID Provider to store as the User's username

View File

@@ -1,6 +1,49 @@
const { matchModelName } = require('../utils');
const defaultRate = 6;
const customTokenOverrides = {};
const customCacheOverrides = {};
/**
* Allows overriding the default token multipliers.
*
* @param {Object} overrides - An object mapping model keys to their custom token multipliers.
* @param {Object} overrides.<model> - An object containing custom multipliers for the model.
* @param {number} overrides.<model>.prompt - The custom prompt multiplier for the model.
* @param {number} overrides.<model>.completion - The custom completion multiplier for the model.
*
* @example
* // Override the multipliers for "gpt-4o-mini" and "gpt-3.5":
* setCustomTokenOverrides({
* "gpt-4o-mini": { prompt: 0.2, completion: 0.5 },
* "gpt-3.5": { prompt: 1.0, completion: 2.0 }
* });
*/
const setCustomTokenOverrides = (overrides) => {
Object.assign(customTokenOverrides, overrides);
};
/**
* Allows overriding the default cache multipliers.
* The override values should be nested under a key named "Cache".
*
* @param {Object} overrides - An object mapping model keys to their custom cache multipliers.
* @param {Object} overrides.<model> - An object that must include a "Cache" property.
* @param {Object} overrides.<model>.Cache - An object containing custom cache multipliers for the model.
* @param {number} overrides.<model>.Cache.write - The custom cache write multiplier for the model.
* @param {number} overrides.<model>.Cache.read - The custom cache read multiplier for the model.
*
* @example
* // Override the cache multipliers for "gpt-4o-mini" and "gpt-3.5":
* setCustomCacheOverrides({
* "gpt-4o-mini": { cache: { write: 0.2, read: 0.5 } },
* "gpt-3.5": { cache: { write: 1.0, read: 1.5 } }
* });
*/
const setCustomCacheOverrides = (overrides) => {
Object.assign(customCacheOverrides, overrides);
};
/**
* AWS Bedrock pricing
* source: https://aws.amazon.com/bedrock/pricing/
@@ -283,20 +326,23 @@ const getCacheMultiplier = ({ valueKey, cacheType, model, endpoint, endpointToke
return endpointTokenConfig?.[model]?.[cacheType] ?? null;
}
if (valueKey && cacheType) {
return cacheTokenValues[valueKey]?.[cacheType] ?? null;
if (!valueKey && model) {
valueKey = getValueKey(model, endpoint);
}
if (!cacheType || !model) {
return null;
}
valueKey = getValueKey(model, endpoint);
if (!valueKey) {
return null;
}
// If we got this far, and values[cacheType] is undefined somehow, return a rough average of default multipliers
// Check for custom cache overrides under the "cache" property.
if (
customCacheOverrides[valueKey] &&
customCacheOverrides[valueKey].cache &&
customCacheOverrides[valueKey].cache[cacheType] != null
) {
return customCacheOverrides[valueKey].cache[cacheType];
}
// Fallback to the default cacheTokenValues.
return cacheTokenValues[valueKey]?.[cacheType] ?? null;
};
@@ -307,4 +353,6 @@ module.exports = {
getCacheMultiplier,
defaultRate,
cacheTokenValues,
setCustomTokenOverrides,
setCustomCacheOverrides,
};

View File

@@ -21,6 +21,7 @@ const { initializeRoles } = require('~/models/Role');
const { isEnabled } = require('~/server/utils');
const { getMCPManager } = require('~/config');
const paths = require('~/config/paths');
const { loadTokenRatesConfig } = require('./Config/loadTokenRatesConfig');
/**
*
@@ -33,6 +34,7 @@ const AppService = async (app) => {
/** @type {TCustomConfig} */
const config = (await loadCustomConfig()) ?? {};
const configDefaults = getConfigDefaults();
loadTokenRatesConfig(config, configDefaults);
const ocr = loadOCRConfig(config.ocr);
const filteredTools = config.filteredTools;

View File

@@ -0,0 +1,71 @@
const { removeNullishValues } = require('librechat-data-provider');
const { logger } = require('~/config');
const { setCustomTokenOverrides, setCustomCacheOverrides } = require('~/models/tx');
/**
* Loads token rates from the user's configuration, merging with default token rates if available.
*
* @param {TCustomConfig | undefined} config - The loaded custom configuration.
* @param {TConfigDefaults} [configDefaults] - Optional default configuration values.
* @returns {TCustomConfig['tokenRates']} - The final token rates configuration.
*/
function loadTokenRatesConfig(config, configDefaults) {
const userTokenRates = removeNullishValues(config?.tokenRates ?? {});
if (!configDefaults?.tokenRates) {
logger.info(`User tokenRates configuration:\n${JSON.stringify(userTokenRates, null, 2)}`);
// Apply custom token rates even if there are no defaults
applyCustomTokenRates(userTokenRates);
return userTokenRates;
}
/** @type {TCustomConfig['tokenRates']} */
const defaultTokenRates = removeNullishValues(configDefaults.tokenRates);
const merged = { ...defaultTokenRates, ...userTokenRates };
// Apply custom token rates configuration
applyCustomTokenRates(merged);
logger.info(`Merged tokenRates configuration:\n${JSON.stringify(merged, null, 2)}`);
return merged;
}
/**
* Processes the token rates configuration to set up custom overrides for each model.
*
* The configuration is expected to be specified per model:
*
* For each model in the tokenRates configuration, this function will call the tx.js
* override functions to apply the custom token and cache multipliers.
*
* @param {TModelTokenRates} tokenRates - The token rates configuration mapping models to token costs.
*/
function applyCustomTokenRates(tokenRates) {
// Iterate over each model in the tokenRates configuration.
Object.keys(tokenRates).forEach((model) => {
const rate = tokenRates[model];
// If token multipliers are provided, set custom token overrides.
if (rate.prompt != null || rate.completion != null) {
setCustomTokenOverrides({
[model]: {
prompt: rate.prompt,
completion: rate.completion,
},
});
}
// Check for cache overrides.
const cacheOverrides = rate.cache;
if (cacheOverrides && (cacheOverrides.write != null || cacheOverrides.read != null)) {
setCustomCacheOverrides({
[model]: {
cache: {
write: cacheOverrides.write,
read: cacheOverrides.read,
},
},
});
}
});
}
module.exports = { loadTokenRatesConfig };

View File

@@ -17,15 +17,12 @@ try {
}
/**
* Downloads an image from a URL using an access token, returning a Buffer.
*
* @async
* @function downloadImage
* @param {string} url - The image URL
* @param {string} accessToken - The OAuth2 access token, if required by the server
* @returns {Promise<Buffer|string>} A Buffer if successful, or an empty string on failure
* Downloads an image from a URL using an access token.
* @param {string} url
* @param {string} accessToken
* @returns {Promise<Buffer>}
*/
async function downloadImage(url, accessToken) {
const downloadImage = async (url, accessToken) => {
if (!url) {
return '';
}
@@ -33,33 +30,34 @@ async function downloadImage(url, accessToken) {
try {
const options = {
method: 'GET',
headers: { Authorization: `Bearer ${accessToken}` },
headers: {
Authorization: `Bearer ${accessToken}`,
},
};
if (process.env.PROXY) {
options.agent = new HttpsProxyAgent(process.env.PROXY);
}
const response = await fetch(url, options);
if (!response.ok) {
if (response.ok) {
const buffer = await response.buffer();
return buffer;
} else {
throw new Error(`${response.statusText} (HTTP ${response.status})`);
}
return await response.buffer();
} catch (error) {
logger.error(`[openidStrategy] downloadImage: Failed to fetch "${url}": ${error}`);
logger.error(
`[openidStrategy] downloadImage: Error downloading image at URL "${url}": ${error}`,
);
return '';
}
}
};
/**
* Derives a user's "full name" from userinfo or environment-specified claim.
* Determines the full name of a user based on OpenID userinfo and environment configuration.
*
* Priority:
* 1) process.env.OPENID_NAME_CLAIM
* 2) userinfo.given_name + userinfo.family_name
* 3) userinfo.given_name OR userinfo.family_name
* 4) userinfo.username or userinfo.email
*
* @function getFullName
* @param {Object} userinfo - The user information object from OpenID Connect
* @param {string} [userinfo.given_name] - The user's first name
* @param {string} [userinfo.family_name] - The user's last name
@@ -68,252 +66,153 @@ async function downloadImage(url, accessToken) {
* @returns {string} The determined full name of the user
*/
function getFullName(userinfo) {
if (process.env.OPENID_NAME_CLAIM && userinfo[process.env.OPENID_NAME_CLAIM]) {
if (process.env.OPENID_NAME_CLAIM) {
return userinfo[process.env.OPENID_NAME_CLAIM];
}
if (userinfo.given_name && userinfo.family_name) {
return `${userinfo.given_name} ${userinfo.family_name}`;
}
if (userinfo.given_name) {
return userinfo.given_name;
}
if (userinfo.family_name) {
return userinfo.family_name;
}
return userinfo.username || userinfo.email || '';
return userinfo.username || userinfo.email;
}
/**
* Converts an input into a string suitable for a username.
* If the input is a string, it will be returned as is.
* If the input is an array, elements will be joined with underscores.
* In case of undefined or other falsy values, a default value will be returned.
*
* @function convertToUsername
* @param {string|string[]|undefined} input - Could be a string or array of strings
* @param {string} [defaultValue=''] - Fallback if input is invalid or not provided
* @returns {string} A processed username string
* @param {string | string[] | undefined} input - The input value to be converted into a username.
* @param {string} [defaultValue=''] - The default value to return if the input is falsy.
* @returns {string} The processed input as a string suitable for a username.
*/
function convertToUsername(input, defaultValue = '') {
if (typeof input === 'string') {
return input;
}
if (Array.isArray(input)) {
} else if (Array.isArray(input)) {
return input.join('_');
}
return defaultValue;
}
/**
* Safely extracts an array of roles from an object using dot notation (e.g. realm_access.roles).
*
* @function extractRolesFrom
* @param {Object} obj
* @param {string} path
* @returns {string[]} Array of roles, or empty array if not found
*/
function extractRolesFrom(obj, path) {
try {
let current = obj;
for (const part of path.split('.')) {
if (!current || typeof current !== 'object') {
return [];
}
current = current[part];
}
return Array.isArray(current) ? current : [];
} catch {
return [];
}
}
/**
* Retrieves user roles from either a token, the userinfo object, or both.
*
* Supports three strategies based on the roleSource:
* - 'token': Extract roles from the token (access or id token), fallback to userinfo if extraction fails.
* - 'userinfo': Extract roles solely from the userinfo object.
* - 'both': Extract roles from both token and userinfo and merge them.
*
* Also supports encrypted tokens by falling back to userinfo if the token is not JWT-decodable.
*
* @function getUserRoles
* @param {import('openid-client').TokenSet} tokenSet
* @param {Object} userinfo
* @param {string} rolePath - Dot-notation path to where roles are stored
* @param {'access'|'id'} tokenKind - Which token to parse for roles
* @param {'token'|'userinfo'|'both'} roleSource - Source of roles for extraction
* @returns {string[]} Array of roles, possibly empty
*/
function getUserRoles(tokenSet, userinfo, rolePath, tokenKind, roleSource) {
if (!tokenSet) {
return extractRolesFrom(userinfo, rolePath);
}
if (roleSource === 'userinfo') {
const roles = extractRolesFrom(userinfo, rolePath);
if (!roles.length) {
logger.warn(`[openidStrategy] Key '${rolePath}' not found in userinfo.`);
}
return roles;
} else if (roleSource === 'both') {
let tokenRoles = [];
try {
let tokenToDecode = tokenKind === 'access' ? tokenSet.access_token : tokenSet.id_token;
if (tokenToDecode && tokenToDecode.includes('.')) {
const tokenData = jwtDecode(tokenToDecode);
tokenRoles = extractRolesFrom(tokenData, rolePath);
} else {
logger.warn(
'[openidStrategy] Token is not a valid JWT for decoding, skipping token roles extraction.',
);
}
} catch (err) {
logger.error(`[openidStrategy] Failed to decode ${tokenKind} token: ${err}.`);
}
const userinfoRoles = extractRolesFrom(userinfo, rolePath);
const combinedRoles = Array.from(new Set([...tokenRoles, ...userinfoRoles]));
if (!combinedRoles.length) {
logger.warn(`[openidStrategy] Key '${rolePath}' not found in both token and userinfo.`);
}
return combinedRoles;
} else {
// default 'token' strategy
try {
let tokenToDecode = tokenKind === 'access' ? tokenSet.access_token : tokenSet.id_token;
if (!tokenToDecode || !tokenToDecode.includes('.')) {
throw new Error('Token is not a valid JWT for decoding.');
}
const tokenData = jwtDecode(tokenToDecode);
const roles = extractRolesFrom(tokenData, rolePath);
if (!roles.length) {
logger.warn(
`[openidStrategy] Key '${rolePath}' not found in ${tokenKind} token. Falling back to userinfo.`,
);
return extractRolesFrom(userinfo, rolePath);
}
return roles;
} catch (err) {
logger.error(`[openidStrategy] ${err}. Falling back to userinfo for role extraction.`);
return extractRolesFrom(userinfo, rolePath);
}
}
}
/**
* Registers and configures the OpenID Connect strategy with Passport, enabling PKCE when toggled.
*
* @async
* @function setupOpenId
* @returns {Promise<void>}
*/
async function setupOpenId() {
try {
// Set up a proxy if specified
if (process.env.PROXY) {
const proxyAgent = new HttpsProxyAgent(process.env.PROXY);
custom.setHttpOptionsDefaults({ agent: proxyAgent });
logger.info(`[openidStrategy] Using proxy: ${process.env.PROXY}`);
custom.setHttpOptionsDefaults({
agent: proxyAgent,
});
logger.info(`[openidStrategy] proxy agent added: ${process.env.PROXY}`);
}
// Discover issuer configuration
const issuer = await Issuer.discover(process.env.OPENID_ISSUER);
logger.info(`[openidStrategy] Discovered issuer: ${issuer.issuer}`);
/**
* Supported Algorithms, openid-client v5 doesn't set it automatically as discovered from server.
* - id_token_signed_response_alg // defaults to 'RS256'
* - request_object_signing_alg // defaults to 'RS256'
* - userinfo_signed_response_alg // not in v5
* - introspection_signed_response_alg // not in v5
* - authorization_signed_response_alg // not in v5
*/
/* Supported Algorithms, openid-client v5 doesn't set it automatically as discovered from server.
- id_token_signed_response_alg // defaults to 'RS256'
- request_object_signing_alg // defaults to 'RS256'
- userinfo_signed_response_alg // not in v5
- introspection_signed_response_alg // not in v5
- authorization_signed_response_alg // not in v5
*/
/** @type {import('openid-client').ClientMetadata} */
const clientMetadata = {
client_id: process.env.OPENID_CLIENT_ID,
client_secret: process.env.OPENID_CLIENT_SECRET || '',
client_secret: process.env.OPENID_CLIENT_SECRET,
redirect_uris: [process.env.DOMAIN_SERVER + process.env.OPENID_CALLBACK_URL],
};
// Optionally force the first supported signing algorithm
if (isEnabled(process.env.OPENID_SET_FIRST_SUPPORTED_ALGORITHM)) {
clientMetadata.id_token_signed_response_alg =
issuer.id_token_signing_alg_values_supported?.[0] || 'RS256';
}
const client = new issuer.Client(clientMetadata);
// Determine whether to enable PKCE
const usePKCE = process.env.OPENID_USE_PKCE === 'true';
// Set up authorization parameters. Include code_challenge_method if PKCE is enabled.
const openidScope = process.env.OPENID_SCOPE || 'openid profile email';
/** @type {import('openid-client').AuthorizationParameters} */
const params = {
scope: openidScope,
response_type: 'code',
};
if (usePKCE) {
params.code_challenge_method = 'S256'; // Enable PKCE by specifying the code challenge method
}
// Role-based config
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
const rolePath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
const tokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND || 'id'; // 'id'|'access'
const roleSource = process.env.OPENID_REQUIRED_ROLE_SOURCE || 'both'; // 'token'|'userinfo'|'both'
// Create the Passport strategy using the new type-correct instantiation and toggle for PKCE
const openidStrategy = new OpenIDStrategy(
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
const openidLogin = new OpenIDStrategy(
{
client,
params,
usePKCE,
params: {
scope: process.env.OPENID_SCOPE,
},
},
async (tokenSet, userinfo, done) => {
async (tokenset, userinfo, done) => {
try {
logger.info(`[openidStrategy] Verifying login for sub=${userinfo.sub}`);
logger.info(`[openidStrategy] verify login openidId: ${userinfo.sub}`);
logger.debug('[openidStrategy] very login tokenset and userinfo', { tokenset, userinfo });
// Find user by openidId or fallback to email
let user = await findUser({ openidId: userinfo.sub });
if (!user && userinfo.email) {
logger.info(
`[openidStrategy] user ${user ? 'found' : 'not found'} with openidId: ${userinfo.sub}`,
);
if (!user) {
user = await findUser({ email: userinfo.email });
logger.info(
`[openidStrategy] User ${user ? 'found' : 'not found'} by email=${userinfo.email}.`,
`[openidStrategy] user ${user ? 'found' : 'not found'} with email: ${
userinfo.email
} for openidId: ${userinfo.sub}`,
);
}
// If a role is required, check user roles
if (requiredRole && rolePath) {
const roles = getUserRoles(tokenSet, userinfo, rolePath, tokenKind, roleSource);
if (!roles.includes(requiredRole)) {
logger.warn(
`[openidStrategy] Missing required role "${requiredRole}". Roles: [${roles.join(', ')}]`,
const fullName = getFullName(userinfo);
if (requiredRole) {
let decodedToken = '';
if (requiredRoleTokenKind === 'access') {
decodedToken = jwtDecode(tokenset.access_token);
} else if (requiredRoleTokenKind === 'id') {
decodedToken = jwtDecode(tokenset.id_token);
}
const pathParts = requiredRoleParameterPath.split('.');
let found = true;
let roles = pathParts.reduce((o, key) => {
if (o === null || o === undefined || !(key in o)) {
found = false;
return [];
}
return o[key];
}, decodedToken);
if (!found) {
logger.error(
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
);
}
if (!roles.includes(requiredRole)) {
return done(null, false, {
message: `You must have the "${requiredRole}" role to log in.`,
});
}
}
// Derive name and username
const fullName = getFullName(userinfo);
const username = process.env.OPENID_USERNAME_CLAIM
? convertToUsername(userinfo[process.env.OPENID_USERNAME_CLAIM])
: convertToUsername(userinfo.username || userinfo.given_name || userinfo.email);
// Create or update user
if (!user) {
logger.info(`[openidStrategy] Creating a new user for sub=${userinfo.sub}`);
user = await createUser(
{
provider: 'openid',
openidId: userinfo.sub,
username,
email: userinfo.email || '',
emailVerified: Boolean(userinfo.email_verified) || false,
name: fullName,
},
true,
true,
let username = '';
if (process.env.OPENID_USERNAME_CLAIM) {
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
} else {
username = convertToUsername(
userinfo.username || userinfo.given_name || userinfo.email,
);
}
if (!user) {
user = {
provider: 'openid',
openidId: userinfo.sub,
username,
email: userinfo.email || '',
emailVerified: userinfo.email_verified || false,
name: fullName,
};
user = await createUser(user, true, true);
} else {
user.provider = 'openid';
user.openidId = userinfo.sub;
@@ -321,44 +220,54 @@ async function setupOpenId() {
user.name = fullName;
}
// Fetch avatar if not manually overridden
if (userinfo.picture && !String(user.avatar || '').includes('manual=true')) {
const imageBuffer = await downloadImage(userinfo.picture, tokenSet.access_token);
if (userinfo.picture && !user.avatar?.includes('manual=true')) {
/** @type {string | undefined} */
const imageUrl = userinfo.picture;
let fileName;
if (crypto) {
fileName = (await hashToken(userinfo.sub)) + '.png';
} else {
fileName = userinfo.sub + '.png';
}
const imageBuffer = await downloadImage(imageUrl, tokenset.access_token);
if (imageBuffer) {
const { saveBuffer } = getStrategyFunctions(process.env.CDN_PROVIDER);
const fileHash = crypto ? await hashToken(userinfo.sub) : userinfo.sub;
const fileName = `${fileHash}.png`;
const imagePath = await saveBuffer({
fileName,
userId: user._id.toString(),
buffer: imageBuffer,
});
if (imagePath) {
user.avatar = imagePath;
}
user.avatar = imagePath ?? '';
}
}
// Persist user changes
user = await updateUser(user._id, user);
// Success
logger.info(
`[openidStrategy] Login success for sub=${user.openidId}, email=${user.email}, username=${user.username}`,
`[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `,
{
user: {
openidId: user.openidId,
username: user.username,
email: user.email,
name: user.name,
},
},
);
return done(null, user);
done(null, user);
} catch (err) {
logger.error('[openidStrategy] Login verification failed:', err);
return done(err);
logger.error('[openidStrategy] login failed', err);
done(err);
}
},
);
// Register the strategy under the 'openid' name
passport.use('openid', openidStrategy);
passport.use('openid', openidLogin);
} catch (err) {
logger.error('[openidStrategy] Error setting up OpenID strategy:', err);
logger.error('[openidStrategy]', err);
}
}

View File

@@ -10,6 +10,7 @@ jest.mock('openid-client');
jest.mock('jsonwebtoken/decode');
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(() => ({
// You can modify this mock as needed (here returning a dummy function)
saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'),
})),
}));
@@ -22,20 +23,18 @@ jest.mock('~/server/utils/crypto', () => ({
hashToken: jest.fn().mockResolvedValue('hashed-token'),
}));
jest.mock('~/server/utils', () => ({
isEnabled: jest.fn(() => false), // default to false; override per test if needed
isEnabled: jest.fn(() => false), // default to false, override per test if needed
}));
jest.mock('~/config', () => ({
logger: {
info: jest.fn(),
debug: jest.fn(),
error: jest.fn(),
warn: jest.fn(),
},
}));
// Update Issuer.discover mock so that the returned issuer has an 'issuer' property.
// Mock Issuer.discover so that setupOpenId gets a fake issuer and client
Issuer.discover = jest.fn().mockResolvedValue({
issuer: 'https://fake-issuer.com',
id_token_signing_alg_values_supported: ['RS256'],
Client: jest.fn().mockImplementation((clientMetadata) => {
return {
@@ -44,7 +43,7 @@ Issuer.discover = jest.fn().mockResolvedValue({
}),
});
// To capture the verify callback from the strategy, we grab it from the mock constructor.
// To capture the verify callback from the strategy, we grab it from the mock constructor
let verifyCallback;
OpenIDStrategy.mockImplementation((options, verify) => {
verifyCallback = verify;
@@ -52,21 +51,21 @@ OpenIDStrategy.mockImplementation((options, verify) => {
});
describe('setupOpenId', () => {
// Helper to wrap the verify callback in a promise.
// Helper to wrap the verify callback in a promise
const validate = (tokenset, userinfo) =>
new Promise((resolve, reject) => {
verifyCallback(tokenset, userinfo, (err, user, details) => {
if (err) {
return reject(err);
reject(err);
} else {
resolve({ user, details });
}
resolve({ user, details });
});
});
// Default tokenset: tokens include a period to simulate a JWT.
const validTokenSet = {
id_token: 'header.payload.signature',
access_token: 'header.payload.signature',
const tokenset = {
id_token: 'fake_id_token',
access_token: 'fake_access_token',
};
const baseUserinfo = {
@@ -78,14 +77,13 @@ describe('setupOpenId', () => {
name: 'My Full',
username: 'flast',
picture: 'https://example.com/avatar.png',
roles: ['requiredRole'],
};
beforeEach(async () => {
// Clear previous mock calls and reset implementations.
// Clear previous mock calls and reset implementations
jest.clearAllMocks();
// Reset environment variables needed by the strategy.
// Reset environment variables needed by the strategy
process.env.OPENID_ISSUER = 'https://fake-issuer.com';
process.env.OPENID_CLIENT_ID = 'fake_client_id';
process.env.OPENID_CLIENT_SECRET = 'fake_client_secret';
@@ -95,29 +93,26 @@ describe('setupOpenId', () => {
process.env.OPENID_REQUIRED_ROLE = 'requiredRole';
process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'roles';
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id';
process.env.OPENID_REQUIRED_ROLE_SOURCE = 'token';
delete process.env.OPENID_USERNAME_CLAIM;
delete process.env.OPENID_NAME_CLAIM;
delete process.env.PROXY;
delete process.env.OPENID_USE_PKCE;
delete process.env.OPENID_SET_FIRST_SUPPORTED_ALGORITHM;
// By default, jwtDecode returns a token that includes the required role.
// Default jwtDecode mock returns a token that includes the required role.
jwtDecode.mockReturnValue({
roles: ['requiredRole'],
});
// By default, assume that no user is found so that createUser will be called.
// By default, assume that no user is found, so createUser will be called
findUser.mockResolvedValue(null);
createUser.mockImplementation(async (userData) => {
// Simulate created user with an _id property.
// simulate created user with an _id property
return { _id: 'newUserId', ...userData };
});
updateUser.mockImplementation(async (id, userData) => {
return { _id: id, ...userData };
});
// For image download, simulate a successful response.
// For image download, simulate a successful response
const fakeBuffer = Buffer.from('fake image');
const fakeResponse = {
ok: true,
@@ -125,13 +120,18 @@ describe('setupOpenId', () => {
};
fetch.mockResolvedValue(fakeResponse);
// (Re)initialize the strategy with current env settings.
// Finally, call the setup function so that passport.use gets called
await setupOpenId();
});
it('should create a new user with correct username when username claim exists', async () => {
// Arrange our userinfo already has username 'flast'
const userinfo = { ...baseUserinfo };
const { user } = await validate(validTokenSet, userinfo);
// Act
const { user } = await validate(tokenset, userinfo);
// Assert
expect(user.username).toBe(userinfo.username);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({
@@ -147,10 +147,16 @@ describe('setupOpenId', () => {
});
it('should use given_name as username when username claim is missing', async () => {
// Arrange remove username from userinfo
const userinfo = { ...baseUserinfo };
delete userinfo.username;
// Expect the username to be the given name (unchanged case)
const expectUsername = userinfo.given_name;
const { user } = await validate(validTokenSet, userinfo);
// Act
const { user } = await validate(tokenset, userinfo);
// Assert
expect(user.username).toBe(expectUsername);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({ username: expectUsername }),
@@ -160,11 +166,16 @@ describe('setupOpenId', () => {
});
it('should use email as username when username and given_name are missing', async () => {
// Arrange remove username and given_name
const userinfo = { ...baseUserinfo };
delete userinfo.username;
delete userinfo.given_name;
const expectUsername = userinfo.email;
const { user } = await validate(validTokenSet, userinfo);
// Act
const { user } = await validate(tokenset, userinfo);
// Assert
expect(user.username).toBe(expectUsername);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({ username: expectUsername }),
@@ -174,10 +185,14 @@ describe('setupOpenId', () => {
});
it('should override username with OPENID_USERNAME_CLAIM when set', async () => {
// Arrange set OPENID_USERNAME_CLAIM so that the sub claim is used
process.env.OPENID_USERNAME_CLAIM = 'sub';
const userinfo = { ...baseUserinfo };
await setupOpenId();
const { user } = await validate(validTokenSet, userinfo);
// Act
const { user } = await validate(tokenset, userinfo);
// Assert username should equal the sub (converted as-is)
expect(user.username).toBe(userinfo.sub);
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({ username: userinfo.sub }),
@@ -187,21 +202,31 @@ describe('setupOpenId', () => {
});
it('should set the full name correctly when given_name and family_name exist', async () => {
// Arrange
const userinfo = { ...baseUserinfo };
const expectedFullName = `${userinfo.given_name} ${userinfo.family_name}`;
const { user } = await validate(validTokenSet, userinfo);
// Act
const { user } = await validate(tokenset, userinfo);
// Assert
expect(user.name).toBe(expectedFullName);
});
it('should override full name with OPENID_NAME_CLAIM when set', async () => {
// Arrange use the name claim as the full name
process.env.OPENID_NAME_CLAIM = 'name';
const userinfo = { ...baseUserinfo, name: 'Custom Name' };
await setupOpenId();
const { user } = await validate(validTokenSet, userinfo);
// Act
const { user } = await validate(tokenset, userinfo);
// Assert
expect(user.name).toBe('Custom Name');
});
it('should update an existing user on login', async () => {
// Arrange simulate that a user already exists
const existingUser = {
_id: 'existingUserId',
provider: 'local',
@@ -216,8 +241,13 @@ describe('setupOpenId', () => {
}
return null;
});
const userinfo = { ...baseUserinfo };
await validate(validTokenSet, userinfo);
// Act
await validate(tokenset, userinfo);
// Assert updateUser should be called and the user object updated
expect(updateUser).toHaveBeenCalledWith(
existingUser._id,
expect.objectContaining({
@@ -230,154 +260,43 @@ describe('setupOpenId', () => {
});
it('should enforce the required role and reject login if missing', async () => {
jwtDecode.mockReturnValue({ roles: ['SomeOtherRole'] });
// Arrange simulate a token without the required role.
jwtDecode.mockReturnValue({
roles: ['SomeOtherRole'],
});
const userinfo = { ...baseUserinfo };
const { user, details } = await validate(validTokenSet, userinfo);
// Act
const { user, details } = await validate(tokenset, userinfo);
// Assert verify that the strategy rejects login
expect(user).toBe(false);
expect(details.message).toBe('You must have the "requiredRole" role to log in.');
});
it('should attempt to download and save the avatar if picture is provided', async () => {
// Arrange ensure userinfo contains a picture URL
const userinfo = { ...baseUserinfo };
const { user } = await validate(validTokenSet, userinfo);
// Act
const { user } = await validate(tokenset, userinfo);
// Assert verify that download was attempted and the avatar field was set via updateUser
expect(fetch).toHaveBeenCalled();
// Our mock getStrategyFunctions.saveBuffer returns '/fake/path/to/avatar.png'
expect(user.avatar).toBe('/fake/path/to/avatar.png');
});
it('should not attempt to download avatar if picture is not provided', async () => {
// Arrange remove picture
const userinfo = { ...baseUserinfo };
delete userinfo.picture;
await validate(validTokenSet, userinfo);
// Act
await validate(tokenset, userinfo);
// Assert fetch should not be called and avatar should remain undefined or empty
expect(fetch).not.toHaveBeenCalled();
});
it('should fallback to userinfo roles if the id_token is invalid (missing a period)', async () => {
const invalidTokenSet = { ...validTokenSet, id_token: 'invalidtoken' };
const userinfo = { ...baseUserinfo, roles: ['requiredRole'] };
const { user } = await validate(invalidTokenSet, userinfo);
expect(user).toBeDefined();
expect(createUser).toHaveBeenCalled();
});
it('should handle downloadImage failure gracefully and not set an avatar', async () => {
fetch.mockRejectedValue(new Error('network error'));
const userinfo = { ...baseUserinfo };
const { user } = await validate(validTokenSet, userinfo);
expect(fetch).toHaveBeenCalled();
expect(user.avatar).toBeUndefined();
});
it('should allow login if no required role is specified', async () => {
delete process.env.OPENID_REQUIRED_ROLE;
delete process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
jwtDecode.mockReturnValue({});
const userinfo = { ...baseUserinfo };
const { user } = await validate(validTokenSet, userinfo);
expect(user).toBeDefined();
expect(createUser).toHaveBeenCalled();
});
it('should use roles from userinfo when OPENID_REQUIRED_ROLE_SOURCE is set to "userinfo"', async () => {
process.env.OPENID_REQUIRED_ROLE_SOURCE = 'userinfo';
jwtDecode.mockReturnValue({});
const userinfo = { ...baseUserinfo, roles: ['requiredRole'] };
await setupOpenId();
const { user } = await validate(validTokenSet, userinfo);
expect(user).toBeDefined();
expect(createUser).toHaveBeenCalled();
});
it('should merge roles from both token and userinfo when OPENID_REQUIRED_ROLE_SOURCE is "both"', async () => {
process.env.OPENID_REQUIRED_ROLE_SOURCE = 'both';
jwtDecode.mockReturnValue({ roles: ['extraRole'] });
const userinfo = { ...baseUserinfo, roles: ['requiredRole'] };
await setupOpenId();
const { user } = await validate(validTokenSet, userinfo);
expect(user).toBeDefined();
expect(createUser).toHaveBeenCalled();
});
it('should fall back to userinfo roles when token decode fails and roleSource is "both"', async () => {
process.env.OPENID_REQUIRED_ROLE_SOURCE = 'both';
jwtDecode.mockImplementation(() => {
throw new Error('Decode error');
});
const userinfo = { ...baseUserinfo, roles: ['requiredRole'] };
await setupOpenId();
const { user } = await validate(validTokenSet, userinfo);
expect(user).toBeDefined();
expect(createUser).toHaveBeenCalled();
});
it('should merge roles from both token and userinfo when token is invalid and roleSource is "both"', async () => {
process.env.OPENID_REQUIRED_ROLE_SOURCE = 'both';
const invalidTokenSet = { ...validTokenSet, id_token: 'invalidtoken' };
const userinfo = { ...baseUserinfo, roles: ['requiredRole'] };
await setupOpenId();
const { user } = await validate(invalidTokenSet, userinfo);
expect(user).toBeDefined();
expect(createUser).toHaveBeenCalled();
});
it('should reject login if merged roles from both token and userinfo do not include required role', async () => {
process.env.OPENID_REQUIRED_ROLE_SOURCE = 'both';
jwtDecode.mockReturnValue({ roles: ['SomeOtherRole'] });
const userinfo = { ...baseUserinfo, roles: ['AnotherRole'] };
await setupOpenId();
const { user, details } = await validate(validTokenSet, userinfo);
expect(user).toBe(false);
expect(details.message).toBe('You must have the "requiredRole" role to log in.');
});
it('should pass usePKCE true and set code_challenge_method in params when OPENID_USE_PKCE is "true"', async () => {
process.env.OPENID_USE_PKCE = 'true';
await setupOpenId();
const callOptions = OpenIDStrategy.mock.calls[OpenIDStrategy.mock.calls.length - 1][0];
expect(callOptions.usePKCE).toBe(true);
expect(callOptions.params.code_challenge_method).toBe('S256');
});
it('should pass usePKCE false and not set code_challenge_method in params when OPENID_USE_PKCE is "false"', async () => {
process.env.OPENID_USE_PKCE = 'false';
await setupOpenId();
const callOptions = OpenIDStrategy.mock.calls[OpenIDStrategy.mock.calls.length - 1][0];
expect(callOptions.usePKCE).toBe(false);
expect(callOptions.params.code_challenge_method).toBeUndefined();
});
it('should default to usePKCE false when OPENID_USE_PKCE is not defined', async () => {
delete process.env.OPENID_USE_PKCE;
await setupOpenId();
const callOptions = OpenIDStrategy.mock.calls[OpenIDStrategy.mock.calls.length - 1][0];
expect(callOptions.usePKCE).toBe(false);
expect(callOptions.params.code_challenge_method).toBeUndefined();
});
it('should set id_token_signed_response_alg if OPENID_SET_FIRST_SUPPORTED_ALGORITHM is enabled', async () => {
process.env.OPENID_SET_FIRST_SUPPORTED_ALGORITHM = 'true';
// Override isEnabled so that it returns true.
const { isEnabled } = require('~/server/utils');
isEnabled.mockReturnValue(true);
await setupOpenId();
const callOptions = OpenIDStrategy.mock.calls[OpenIDStrategy.mock.calls.length - 1][0];
expect(callOptions.client.metadata.id_token_signed_response_alg).toBe('RS256');
});
it('should use access token when OPENID_REQUIRED_ROLE_TOKEN_KIND is set to "access"', async () => {
process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'access';
// Reinitialize strategy so that the new token kind is used.
await setupOpenId();
jwtDecode.mockClear();
jwtDecode.mockReturnValue({ roles: ['requiredRole'] });
const userinfo = { ...baseUserinfo };
await validate(validTokenSet, userinfo);
expect(jwtDecode).toHaveBeenCalledWith(validTokenSet.access_token);
});
it('should use proxy agent if PROXY is provided', async () => {
process.env.PROXY = 'http://fake-proxy.com';
await setupOpenId();
const { logger } = require('~/config');
expect(logger.info).toHaveBeenCalledWith(`[openidStrategy] Using proxy: ${process.env.PROXY}`);
// Depending on your implementation, user.avatar may be undefined or an empty string.
});
});

View File

@@ -71,6 +71,17 @@ interface:
multiConvo: true
agents: true
# Example Custom Token Rates (optional)
#tokenRates:
# gpt-4o-mini:
# prompt: 200.0
# completion: 400.0
# claude-3.7-sonnet:
# Cache:
# read: 200.0
# write: 400.0
# Example Registration Object Structure (optional)
registration:
socialLogins: ['github', 'google', 'discord', 'openid', 'facebook', 'apple']

View File

@@ -536,6 +536,7 @@ export type TStartupConfig = {
helpAndFaqURL: string;
customFooter?: string;
modelSpecs?: TSpecsConfig;
tokenRates?: TModelTokenRates;
sharedLinksEnabled: boolean;
publicSharedLinksEnabled: boolean;
analyticsGtmId?: string;
@@ -544,6 +545,31 @@ export type TStartupConfig = {
staticBundlerURL?: string;
};
// Token cost schema type
export type TTokenCost = {
prompt?: number;
completion?: number;
cache?: {
write?: number;
read?: number;
};
};
// Endpoint token rates schema type
export type TModelTokenRates = Record<string, TTokenCost>;
const tokenCostSchema = z.object({
prompt: z.number().optional(), // e.g. 1.5 => $1.50 / 1M tokens
completion: z.number().optional(), // e.g. 2.0 => $2.00 / 1M tokens
cache: z
.object({
write: z.number().optional(),
read: z.number().optional(),
})
.optional(),
});
export enum OCRStrategy {
MISTRAL_OCR = 'mistral_ocr',
CUSTOM_OCR = 'custom_ocr',
@@ -601,6 +627,7 @@ export const configSchema = z.object({
rateLimits: rateLimitSchema.optional(),
fileConfig: fileConfigSchema.optional(),
modelSpecs: specsConfigSchema.optional(),
tokenRates: tokenCostSchema.optional(),
endpoints: z
.object({
all: baseEndpointSchema.optional(),