From 2ef6e4462deb291279ca4d03c1c72dbd71b0385e Mon Sep 17 00:00:00 2001 From: Ruben Talstra Date: Tue, 11 Feb 2025 16:42:05 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20Add=20custom=20fields=20&?= =?UTF-8?q?=20role=20assignment=20to=20OpenID=20strategy=20(#5612)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * started with Support for Customizable OpenID Profile Fields via Environment Variable * kept as much of the original code as possible but still added the custom data mapper * kept as much of the original code as possible but still added the custom data mapper * resolved merge conflicts * resolved merge conflicts * resolved merge conflicts * resolved merge conflicts * removed some unneeded comments * fix: conflicted issue --------- Co-authored-by: Talstra Ruben SRSNL --- .env.example | 5 +- api/models/schema/userSchema.js | 6 + api/strategies/OpenId/openidDataMapper.js | 168 ++++++++++++++++++++++ api/strategies/openidStrategy.js | 72 +++++++++- packages/data-provider/src/types.ts | 1 + 5 files changed, 250 insertions(+), 2 deletions(-) create mode 100644 api/strategies/OpenId/openidDataMapper.js diff --git a/.env.example b/.env.example index d87021ea4..a394b5de2 100644 --- a/.env.example +++ b/.env.example @@ -422,6 +422,9 @@ OPENID_NAME_CLAIM= OPENID_BUTTON_LABEL= OPENID_IMAGE_URL= +OPENID_CUSTOM_DATA= +OPENID_PROVIDER= +OPENID_ADMIN_ROLE= # LDAP LDAP_URL= @@ -527,4 +530,4 @@ HELP_AND_FAQ_URL=https://librechat.ai #=====================================================# # OpenWeather # #=====================================================# -OPENWEATHER_API_KEY= \ No newline at end of file +OPENWEATHER_API_KEY= diff --git a/api/models/schema/userSchema.js b/api/models/schema/userSchema.js index f58655336..3bab46bca 100644 --- a/api/models/schema/userSchema.js +++ b/api/models/schema/userSchema.js @@ -20,6 +20,7 @@ const { SystemRoles } = require('librechat-data-provider'); * @property {string} [googleId] - Optional Google ID for the user * @property {string} [facebookId] - Optional Facebook ID for the user * @property {string} [openidId] - Optional OpenID ID for the user + * @property {Map} [customOpenIdData] - A map containing provider-specific custom data retrieved via OpenID Connect. * @property {string} [ldapId] - Optional LDAP ID for the user * @property {string} [githubId] - Optional GitHub ID for the user * @property {string} [discordId] - Optional Discord ID for the user @@ -97,6 +98,11 @@ const userSchema = mongoose.Schema( unique: true, sparse: true, }, + customOpenIdData: { + type: Map, + of: mongoose.Schema.Types.Mixed, + default: {}, + }, ldapId: { type: String, unique: true, diff --git a/api/strategies/OpenId/openidDataMapper.js b/api/strategies/OpenId/openidDataMapper.js new file mode 100644 index 000000000..b3bb85f47 --- /dev/null +++ b/api/strategies/OpenId/openidDataMapper.js @@ -0,0 +1,168 @@ +const fetch = require('node-fetch'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { logger } = require('~/config'); +const { URL } = require('url'); + +// Microsoft SDK +const { Client: MicrosoftGraphClient } = require('@microsoft/microsoft-graph-client'); + +/** + * Base class for provider-specific data mappers. + */ +class BaseDataMapper { + /** + * Map custom OpenID data. + * @param {string} accessToken - The access token to authenticate the request. + * @param {string|Array} customQuery - Either a full query string (if it contains operators) + * or an array of fields to select. + * @returns {Promise>} A promise that resolves to a map of custom fields. + * @throws {Error} Throws an error if not implemented in the subclass. + */ + async mapCustomData(accessToken, customQuery) { + throw new Error('mapCustomData() must be implemented by subclasses'); + } + + /** + * Optionally handle proxy settings for HTTP requests. + * @returns {Object} Configuration object with proxy settings if PROXY is set. + */ + getProxyOptions() { + if (process.env.PROXY) { + const agent = new HttpsProxyAgent(process.env.PROXY); + return { agent }; + } + return {}; + } +} + +/** + * Microsoft-specific data mapper using the Microsoft Graph SDK. + */ +class MicrosoftDataMapper extends BaseDataMapper { + /** + * Initializes the MicrosoftGraphClient once for reuse. + */ + constructor() { + super(); + this.accessToken = null; + + this.client = MicrosoftGraphClient.init({ + defaultVersion: 'beta', + authProvider: (done) => { + // The authProvider will be called for each request to get the token + if (this.accessToken) { + done(null, this.accessToken); + } else { + done(new Error('Access token is not set.'), null); + } + }, + fetch: fetch, + ...this.getProxyOptions(), + }); + + // Bind methods to maintain context + this.mapCustomData = this.mapCustomData.bind(this); + this.cleanData = this.cleanData.bind(this); + } + + /** + * Set the access token for the client. + * This method should be called before making any requests. + * + * @param {string} accessToken - The access token. + */ + setAccessToken(accessToken) { + if (!accessToken || typeof accessToken !== 'string') { + throw new Error('[MicrosoftDataMapper] Invalid access token provided.'); + } + this.accessToken = accessToken; + } + + /** + * Map custom OpenID data using the Microsoft Graph SDK. + * + * @param {string} accessToken - The access token to authenticate the request. + * @param {string|Array} customQuery - Fields to select from the Microsoft Graph API. + * @returns {Promise>} A promise that resolves to a map of custom fields. + */ + async mapCustomData(accessToken, customQuery) { + try { + this.setAccessToken(accessToken); + + if (!customQuery) { + logger.warn('[MicrosoftDataMapper] No customQuery provided.'); + return new Map(); + } + + // Convert customQuery to a comma-separated string if it's an array + const fields = Array.isArray(customQuery) ? customQuery.join(',') : customQuery; + + if (!fields) { + logger.warn('[MicrosoftDataMapper] No fields specified in customQuery.'); + return new Map(); + } + + const result = await this.client + .api('/me') + .select(fields) + .get(); + + const cleanedData = this.cleanData(result); + return new Map(Object.entries(cleanedData)); + } catch (error) { + // Handle specific Microsoft Graph errors if needed + logger.error(`[MicrosoftDataMapper] Error fetching user data: ${error.message}`, { stack: error.stack }); + return new Map(); + } + } + + /** + * Recursively remove all keys starting with @odata. from an object and convert Maps. + * + * @param {object|Array} obj - The object or array to clean. + * @returns {object|Array} - The cleaned object or array. + */ + cleanData(obj) { + if (Array.isArray(obj)) { + return obj.map(this.cleanData); + } else if (obj && typeof obj === 'object') { + return Object.entries(obj).reduce((acc, [key, value]) => { + if (!key.startsWith('@odata.')) { + acc[key] = this.cleanData(value); + } + return acc; + }, {}); + } + return obj; + } +} + +/** + * Map provider names to their specific data mappers. + */ +const PROVIDER_MAPPERS = { + // Fully Working + microsoft: MicrosoftDataMapper, +}; + +/** + * Abstraction layer that returns a provider-specific mapper instance. + */ +class OpenIdDataMapper { + /** + * Retrieve an instance of the mapper for the specified provider. + * + * @param {string} provider - The name of the provider (e.g., 'microsoft'). + * @returns {BaseDataMapper} An instance of the specific data mapper for the provider. + * @throws {Error} Throws an error if no mapper is found for the specified provider. + */ + static getMapper(provider) { + const MapperClass = PROVIDER_MAPPERS[provider.toLowerCase()]; + if (!MapperClass) { + throw new Error(`No mapper found for provider: ${provider}`); + } + return new MapperClass(); + } +} + +module.exports = OpenIdDataMapper; \ No newline at end of file diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index b26b11efe..1b37a0114 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -8,6 +8,8 @@ const { findUser, createUser, updateUser } = require('~/models/userMethods'); const { hashToken } = require('~/server/utils/crypto'); const { isEnabled } = require('~/server/utils'); const { logger } = require('~/config'); +const { SystemRoles } = require('librechat-data-provider'); +const OpenIdDataMapper = require('./OpenId/openidDataMapper'); let crypto; try { @@ -105,6 +107,45 @@ function convertToUsername(input, defaultValue = '') { return defaultValue; } +/** + * Decodes a JWT token safely. + * @param {string} token + * @returns {Object|null} + */ +function safeDecode(token) { + try { + const decoded = jwtDecode(token); + if (decoded && typeof decoded === 'object') { + return decoded; + } + logger.error('[openidStrategy] Decoded token is not an object.'); + return null; + } catch (error) { + logger.error('[openidStrategy] safeDecode: Error decoding token:', error); + return null; + } +} + +/** + * Extracts roles from a decoded token based on the provided path. + * @param {Object} decodedToken + * @param {string} parameterPath + * @returns {string[]} + */ +function extractRolesFromToken(decodedToken, parameterPath) { + if (!decodedToken) { + return []; + } + + const roles = parameterPath.split('.').reduce((obj, key) => (obj?.[key] ?? null), decodedToken); + if (!Array.isArray(roles)) { + logger.error('[openidStrategy] extractRolesFromToken: Roles extracted from token are not in array format.'); + return []; + } + + return roles; +} + async function setupOpenId() { try { if (process.env.PROXY) { @@ -136,6 +177,7 @@ async function setupOpenId() { const requiredRole = process.env.OPENID_REQUIRED_ROLE; const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; + const adminRole = process.env.OPENID_ADMIN_ROLE; const openidLogin = new OpenIDStrategy( { client, @@ -194,6 +236,30 @@ async function setupOpenId() { } } + let customOpenIdData = new Map(); + if (process.env.OPENID_CUSTOM_DATA) { + const dataMapper = OpenIdDataMapper.getMapper(process.env.OPENID_PROVIDER.toLowerCase()); + customOpenIdData = await dataMapper.mapCustomData(tokenset.access_token, process.env.OPENID_CUSTOM_DATA); + const tokenBasedRoles = + requiredRole && + extractRolesFromToken( + safeDecode(requiredRoleTokenKind === 'access' ? tokenset.access_token : tokenset.id_token), + requiredRoleParameterPath, + ); + if (tokenBasedRoles && tokenBasedRoles.length) { + customOpenIdData.set('roles', tokenBasedRoles); + } else { + logger.warn('[openidStrategy] tokenBasedRoles is missing or invalid.'); + } + } + + const token = requiredRoleTokenKind === 'access' ? tokenset.access_token : tokenset.id_token; + const decodedToken = safeDecode(token); + const tokenBasedRoles = extractRolesFromToken(decodedToken, requiredRoleParameterPath); + const isAdmin = tokenBasedRoles.includes(adminRole); + const assignedRole = isAdmin ? SystemRoles.ADMIN : SystemRoles.USER; + logger.debug(`[openidStrategy] Assigned system role: ${assignedRole} (isAdmin: ${isAdmin})`); + let username = ''; if (process.env.OPENID_USERNAME_CLAIM) { username = userinfo[process.env.OPENID_USERNAME_CLAIM]; @@ -211,6 +277,8 @@ async function setupOpenId() { email: userinfo.email || '', emailVerified: userinfo.email_verified || false, name: fullName, + role: assignedRole, + customOpenIdData: customOpenIdData, }; user = await createUser(user, true, true); } else { @@ -218,6 +286,8 @@ async function setupOpenId() { user.openidId = userinfo.sub; user.username = username; user.name = fullName; + user.role = assignedRole; + user.customOpenIdData = customOpenIdData; } if (userinfo.picture && !user.avatar?.includes('manual=true')) { @@ -271,4 +341,4 @@ async function setupOpenId() { } } -module.exports = setupOpenId; +module.exports = setupOpenId; \ No newline at end of file diff --git a/packages/data-provider/src/types.ts b/packages/data-provider/src/types.ts index 6d9cd87c8..59f434fb8 100644 --- a/packages/data-provider/src/types.ts +++ b/packages/data-provider/src/types.ts @@ -110,6 +110,7 @@ export type TUser = { plugins?: string[]; createdAt: string; updatedAt: string; + customOpenIdData: { [key: string]: any }; }; export type TGetConversationsResponse = {