feat: Add custom fields & role assignment to OpenID strategy (#5612)

* started with Support for Customizable OpenID Profile Fields via Environment Variable

* kept as much of the original code as possible but still added the custom data mapper

* kept as much of the original code as possible but still added the custom data mapper

* resolved merge conflicts

* resolved merge conflicts

* resolved merge conflicts

* resolved merge conflicts

* removed some unneeded comments

* fix: conflicted issue

---------

Co-authored-by: Talstra Ruben SRSNL <ruben.talstra@stadlerrail.com>
This commit is contained in:
Ruben Talstra
2025-02-11 16:42:05 +01:00
committed by GitHub
parent 404b27d045
commit 2ef6e4462d
5 changed files with 250 additions and 2 deletions

View File

@@ -422,6 +422,9 @@ OPENID_NAME_CLAIM=
OPENID_BUTTON_LABEL=
OPENID_IMAGE_URL=
OPENID_CUSTOM_DATA=
OPENID_PROVIDER=
OPENID_ADMIN_ROLE=
# LDAP
LDAP_URL=
@@ -527,4 +530,4 @@ HELP_AND_FAQ_URL=https://librechat.ai
#=====================================================#
# OpenWeather #
#=====================================================#
OPENWEATHER_API_KEY=
OPENWEATHER_API_KEY=

View File

@@ -20,6 +20,7 @@ const { SystemRoles } = require('librechat-data-provider');
* @property {string} [googleId] - Optional Google ID for the user
* @property {string} [facebookId] - Optional Facebook ID for the user
* @property {string} [openidId] - Optional OpenID ID for the user
* @property {Map<string, string>} [customOpenIdData] - A map containing provider-specific custom data retrieved via OpenID Connect.
* @property {string} [ldapId] - Optional LDAP ID for the user
* @property {string} [githubId] - Optional GitHub ID for the user
* @property {string} [discordId] - Optional Discord ID for the user
@@ -97,6 +98,11 @@ const userSchema = mongoose.Schema(
unique: true,
sparse: true,
},
customOpenIdData: {
type: Map,
of: mongoose.Schema.Types.Mixed,
default: {},
},
ldapId: {
type: String,
unique: true,

View File

@@ -0,0 +1,168 @@
const fetch = require('node-fetch');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { logger } = require('~/config');
const { URL } = require('url');
// Microsoft SDK
const { Client: MicrosoftGraphClient } = require('@microsoft/microsoft-graph-client');
/**
* Base class for provider-specific data mappers.
*/
class BaseDataMapper {
/**
* Map custom OpenID data.
* @param {string} accessToken - The access token to authenticate the request.
* @param {string|Array<string>} customQuery - Either a full query string (if it contains operators)
* or an array of fields to select.
* @returns {Promise<Map<string, any>>} A promise that resolves to a map of custom fields.
* @throws {Error} Throws an error if not implemented in the subclass.
*/
async mapCustomData(accessToken, customQuery) {
throw new Error('mapCustomData() must be implemented by subclasses');
}
/**
* Optionally handle proxy settings for HTTP requests.
* @returns {Object} Configuration object with proxy settings if PROXY is set.
*/
getProxyOptions() {
if (process.env.PROXY) {
const agent = new HttpsProxyAgent(process.env.PROXY);
return { agent };
}
return {};
}
}
/**
* Microsoft-specific data mapper using the Microsoft Graph SDK.
*/
class MicrosoftDataMapper extends BaseDataMapper {
/**
* Initializes the MicrosoftGraphClient once for reuse.
*/
constructor() {
super();
this.accessToken = null;
this.client = MicrosoftGraphClient.init({
defaultVersion: 'beta',
authProvider: (done) => {
// The authProvider will be called for each request to get the token
if (this.accessToken) {
done(null, this.accessToken);
} else {
done(new Error('Access token is not set.'), null);
}
},
fetch: fetch,
...this.getProxyOptions(),
});
// Bind methods to maintain context
this.mapCustomData = this.mapCustomData.bind(this);
this.cleanData = this.cleanData.bind(this);
}
/**
* Set the access token for the client.
* This method should be called before making any requests.
*
* @param {string} accessToken - The access token.
*/
setAccessToken(accessToken) {
if (!accessToken || typeof accessToken !== 'string') {
throw new Error('[MicrosoftDataMapper] Invalid access token provided.');
}
this.accessToken = accessToken;
}
/**
* Map custom OpenID data using the Microsoft Graph SDK.
*
* @param {string} accessToken - The access token to authenticate the request.
* @param {string|Array<string>} customQuery - Fields to select from the Microsoft Graph API.
* @returns {Promise<Map<string, any>>} A promise that resolves to a map of custom fields.
*/
async mapCustomData(accessToken, customQuery) {
try {
this.setAccessToken(accessToken);
if (!customQuery) {
logger.warn('[MicrosoftDataMapper] No customQuery provided.');
return new Map();
}
// Convert customQuery to a comma-separated string if it's an array
const fields = Array.isArray(customQuery) ? customQuery.join(',') : customQuery;
if (!fields) {
logger.warn('[MicrosoftDataMapper] No fields specified in customQuery.');
return new Map();
}
const result = await this.client
.api('/me')
.select(fields)
.get();
const cleanedData = this.cleanData(result);
return new Map(Object.entries(cleanedData));
} catch (error) {
// Handle specific Microsoft Graph errors if needed
logger.error(`[MicrosoftDataMapper] Error fetching user data: ${error.message}`, { stack: error.stack });
return new Map();
}
}
/**
* Recursively remove all keys starting with @odata. from an object and convert Maps.
*
* @param {object|Array} obj - The object or array to clean.
* @returns {object|Array} - The cleaned object or array.
*/
cleanData(obj) {
if (Array.isArray(obj)) {
return obj.map(this.cleanData);
} else if (obj && typeof obj === 'object') {
return Object.entries(obj).reduce((acc, [key, value]) => {
if (!key.startsWith('@odata.')) {
acc[key] = this.cleanData(value);
}
return acc;
}, {});
}
return obj;
}
}
/**
* Map provider names to their specific data mappers.
*/
const PROVIDER_MAPPERS = {
// Fully Working
microsoft: MicrosoftDataMapper,
};
/**
* Abstraction layer that returns a provider-specific mapper instance.
*/
class OpenIdDataMapper {
/**
* Retrieve an instance of the mapper for the specified provider.
*
* @param {string} provider - The name of the provider (e.g., 'microsoft').
* @returns {BaseDataMapper} An instance of the specific data mapper for the provider.
* @throws {Error} Throws an error if no mapper is found for the specified provider.
*/
static getMapper(provider) {
const MapperClass = PROVIDER_MAPPERS[provider.toLowerCase()];
if (!MapperClass) {
throw new Error(`No mapper found for provider: ${provider}`);
}
return new MapperClass();
}
}
module.exports = OpenIdDataMapper;

View File

@@ -8,6 +8,8 @@ const { findUser, createUser, updateUser } = require('~/models/userMethods');
const { hashToken } = require('~/server/utils/crypto');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const { SystemRoles } = require('librechat-data-provider');
const OpenIdDataMapper = require('./OpenId/openidDataMapper');
let crypto;
try {
@@ -105,6 +107,45 @@ function convertToUsername(input, defaultValue = '') {
return defaultValue;
}
/**
* Decodes a JWT token safely.
* @param {string} token
* @returns {Object|null}
*/
function safeDecode(token) {
try {
const decoded = jwtDecode(token);
if (decoded && typeof decoded === 'object') {
return decoded;
}
logger.error('[openidStrategy] Decoded token is not an object.');
return null;
} catch (error) {
logger.error('[openidStrategy] safeDecode: Error decoding token:', error);
return null;
}
}
/**
* Extracts roles from a decoded token based on the provided path.
* @param {Object} decodedToken
* @param {string} parameterPath
* @returns {string[]}
*/
function extractRolesFromToken(decodedToken, parameterPath) {
if (!decodedToken) {
return [];
}
const roles = parameterPath.split('.').reduce((obj, key) => (obj?.[key] ?? null), decodedToken);
if (!Array.isArray(roles)) {
logger.error('[openidStrategy] extractRolesFromToken: Roles extracted from token are not in array format.');
return [];
}
return roles;
}
async function setupOpenId() {
try {
if (process.env.PROXY) {
@@ -136,6 +177,7 @@ async function setupOpenId() {
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
const adminRole = process.env.OPENID_ADMIN_ROLE;
const openidLogin = new OpenIDStrategy(
{
client,
@@ -194,6 +236,30 @@ async function setupOpenId() {
}
}
let customOpenIdData = new Map();
if (process.env.OPENID_CUSTOM_DATA) {
const dataMapper = OpenIdDataMapper.getMapper(process.env.OPENID_PROVIDER.toLowerCase());
customOpenIdData = await dataMapper.mapCustomData(tokenset.access_token, process.env.OPENID_CUSTOM_DATA);
const tokenBasedRoles =
requiredRole &&
extractRolesFromToken(
safeDecode(requiredRoleTokenKind === 'access' ? tokenset.access_token : tokenset.id_token),
requiredRoleParameterPath,
);
if (tokenBasedRoles && tokenBasedRoles.length) {
customOpenIdData.set('roles', tokenBasedRoles);
} else {
logger.warn('[openidStrategy] tokenBasedRoles is missing or invalid.');
}
}
const token = requiredRoleTokenKind === 'access' ? tokenset.access_token : tokenset.id_token;
const decodedToken = safeDecode(token);
const tokenBasedRoles = extractRolesFromToken(decodedToken, requiredRoleParameterPath);
const isAdmin = tokenBasedRoles.includes(adminRole);
const assignedRole = isAdmin ? SystemRoles.ADMIN : SystemRoles.USER;
logger.debug(`[openidStrategy] Assigned system role: ${assignedRole} (isAdmin: ${isAdmin})`);
let username = '';
if (process.env.OPENID_USERNAME_CLAIM) {
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
@@ -211,6 +277,8 @@ async function setupOpenId() {
email: userinfo.email || '',
emailVerified: userinfo.email_verified || false,
name: fullName,
role: assignedRole,
customOpenIdData: customOpenIdData,
};
user = await createUser(user, true, true);
} else {
@@ -218,6 +286,8 @@ async function setupOpenId() {
user.openidId = userinfo.sub;
user.username = username;
user.name = fullName;
user.role = assignedRole;
user.customOpenIdData = customOpenIdData;
}
if (userinfo.picture && !user.avatar?.includes('manual=true')) {
@@ -271,4 +341,4 @@ async function setupOpenId() {
}
}
module.exports = setupOpenId;
module.exports = setupOpenId;

View File

@@ -110,6 +110,7 @@ export type TUser = {
plugins?: string[];
createdAt: string;
updatedAt: string;
customOpenIdData: { [key: string]: any };
};
export type TGetConversationsResponse = {