Compare commits

..

18 Commits

Author SHA1 Message Date
Danny Avila
278590d0bb refactor: update processOpenIDAuth to add a flag for processing existing users only 2025-09-17 20:43:27 -04:00
Danny Avila
41a4674469 chore: update logger import to use data-schemas in login controller 2025-09-17 20:14:15 -04:00
Danny Avila
e7a9cf88ac chore: update middleware imports/exports 2025-09-17 20:14:14 -04:00
Danny Avila
f6925f906b WIP: first pass, OpenID Proxy Auth 2025-09-17 20:14:12 -04:00
Danny Avila
e90fd1df15 refactor: reorder middleware imports for clarity 2025-09-17 20:13:46 -04:00
Danny Avila
a1f9f3dd39 refactor: re-use logic for admin routes 2025-09-17 20:13:46 -04:00
Danny Avila
fbe0def2fa WIP: admin auth 2025-09-17 20:13:46 -04:00
Federico Ruggi
d04da60b3b 💫 feat: MCP OAuth Auto-Reconnect (#9646)
* add oauth reconnect tracker

* add connection tracker to mcp manager

* reconnect oauth mcp servers function

* call reconnection in auth controller

* make sure to check connection in panel

* wait for isConnected

* add const for poll interval

* add logging to tryReconnect

* check expiration

* check mcp manager is not null

* check mcp manager is not null

* add test for reconnecting mcp server

* unify logic inside OAuthReconnectionManager

* test reconnection manager, adjust

* chore: reorder import statements in index.js

* chore: imports

* chore: imports

* chore: imports

* chore: imports

* chore: imports

* chore: imports and use types explicitly

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
2025-09-17 16:49:36 -04:00
keltschdt
0e94d97bfb fix: Disable TTL For Transient OIDC Users In Permission Service (#9643) 2025-09-17 14:21:36 -04:00
Danny Avila
45ab4d4503 🎋 refactor: Improve Message UI State Handling (#9678)
* refactor: `ExecuteCode` component with submission state handling and cancellation message

* fix: Remove unnecessary argument check for execute_code tool call

* refactor: streamlined messages context

* chore: remove unused Convo prop

* chore: remove unnecessary whitespace in Message component

* refactor: enhance message context with submission state and latest message tracking

* chore: import order
2025-09-17 13:07:56 -04:00
github-actions[bot]
0ceef12eea 🌍 i18n: Update translation.json with latest translations (#9648)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2025-09-15 18:41:34 -04:00
Danny Avila
6738360051 📋 refactor: Agent Tool Permissions for File Upload Options (#9647)
- Added isEphemeralAgent function to streamline checks for ephemeral agents.
- Updated logic in useAgentToolPermissions to utilize the new function for determining tool access.
- Introduced comprehensive tests for useAgentToolPermissions covering various scenarios including ephemeral agents, regular agents with tools, and edge cases.
2025-09-15 12:57:40 -04:00
Dustin Healy
52b65492d5 👻 fix: Phantom MCP Tool Calls (#9634)
* fix: mcp tool calls no longer happening when unselected (without breaking new convo behavior)

* refactor: Improve ephemeral agent synchronization logic in useMCPSelect

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
2025-09-15 10:35:15 -04:00
Danny Avila
7a9a99d2a0 🔗 refactor: URL sanitization for MCP logging (#9632) 2025-09-14 18:55:32 -04:00
Danny Avila
5bfb06b417 💻 feat: Add Proxy Config for Mistral OCR API (#9629)
* 💻 feat: Add proxy configuration support for Mistral OCR API requests

* refactor: Implement proxy support for Mistral API requests using HttpsProxyAgent
2025-09-14 18:50:41 -04:00
github-actions[bot]
2ce8f1f686 🌍 i18n: Update translation.json with latest translations (#9626)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2025-09-14 18:48:17 -04:00
Danny Avila
1a47601533 🔃 fix: Refresh Token Edge Cases (#9625)
* 🔃 fix: Refresh Token Edge Cases

* chore: Update parameter type for setAuthTokens function
2025-09-13 21:36:45 -04:00
Danny Avila
5245aeea8f 🔧 refactor: Consolidate MCP tool removal and Improve UX (#9609)
* 🔧 refactor: Consolidate MCP tool removal and Improve UX

- Removed redundant tool removal logic from MCPTool, UnconfiguredMCPTool, and UninitializedMCPTool components.
- Introduced `useRemoveMCPTool` hook to handle tool removal and toast notifications.
- Updated translation.json to include a reminder message for saving changes after tool removal.

* chore: remove unused i18n key
2025-09-12 21:37:07 -04:00
74 changed files with 3084 additions and 540 deletions

View File

@@ -1,6 +1,6 @@
const { MCPManager, FlowStateManager } = require('@librechat/api');
const { EventSource } = require('eventsource');
const { Time } = require('librechat-data-provider');
const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api');
const logger = require('./winston');
global.EventSource = EventSource;
@@ -26,4 +26,6 @@ module.exports = {
createMCPManager: MCPManager.createInstance,
getMCPManager: MCPManager.getInstance,
getFlowStateManager,
createOAuthReconnectionManager: OAuthReconnectionManager.createInstance,
getOAuthReconnectionManager: OAuthReconnectionManager.getInstance,
};

View File

@@ -11,8 +11,9 @@ const {
registerUser,
} = require('~/server/services/AuthService');
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
const { getOpenIdConfig } = require('~/strategies');
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
const { getOAuthReconnectionManager } = require('~/config');
const { getOpenIdConfig } = require('~/strategies');
const registrationController = async (req, res) => {
try {
@@ -96,14 +97,25 @@ const refreshController = async (req, res) => {
return res.status(200).send({ token, user });
}
// Find the session with the hashed refresh token
const session = await findSession({
userId: userId,
refreshToken: refreshToken,
});
/** Session with the hashed refresh token */
const session = await findSession(
{
userId: userId,
refreshToken: refreshToken,
},
{ lean: false },
);
if (session && session.expiration > new Date()) {
const token = await setAuthTokens(userId, res, session._id);
const token = await setAuthTokens(userId, res, session);
// trigger OAuth MCP server reconnection asynchronously (best effort)
void getOAuthReconnectionManager()
.reconnectServers(userId)
.catch((err) => {
logger.error('Error reconnecting OAuth MCP servers:', err);
});
res.status(200).send({ token, user });
} else if (req?.query?.retry) {
// Retrying from a refresh token request that failed (401)

View File

@@ -1,6 +1,6 @@
const { logger } = require('@librechat/data-schemas');
const { generate2FATempToken } = require('~/server/services/twoFactorService');
const { setAuthTokens } = require('~/server/services/AuthService');
const { logger } = require('~/config');
const loginController = async (req, res) => {
try {

View File

@@ -0,0 +1,50 @@
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
const { checkBan } = require('~/server/middleware');
const domains = {
client: process.env.DOMAIN_CLIENT,
server: process.env.DOMAIN_SERVER,
};
function createOAuthHandler(redirectUri = domains.client) {
/**
* A handler to process OAuth authentication results.
* @type {Function}
* @param {ServerRequest} req - Express request object.
* @param {ServerResponse} res - Express response object.
* @param {NextFunction} next - Express next middleware function.
*/
return async (req, res, next) => {
try {
if (res.headersSent) {
return;
}
await checkBan(req, res);
if (req.banned) {
return;
}
if (
req.user &&
req.user.provider == 'openid' &&
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
) {
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
} else {
await setAuthTokens(req.user._id, res);
}
res.redirect(redirectUri);
} catch (err) {
logger.error('Error in setting authentication tokens:', err);
next(err);
}
};
}
module.exports = {
createOAuthHandler,
};

View File

@@ -12,6 +12,7 @@ const { logger } = require('@librechat/data-schemas');
const mongoSanitize = require('express-mongo-sanitize');
const { isEnabled, ErrorController } = require('@librechat/api');
const { connectDb, indexSync } = require('~/db');
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
const createValidateImageRequest = require('./middleware/validateImageRequest');
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
const { updateInterfacePermissions } = require('~/models/interface');
@@ -108,6 +109,7 @@ const startServer = async () => {
app.use('/oauth', routes.oauth);
/* API Endpoints */
app.use('/api/auth', routes.auth);
app.use('/api/admin', routes.adminAuth);
app.use('/api/actions', routes.actions);
app.use('/api/keys', routes.keys);
app.use('/api/user', routes.user);
@@ -154,7 +156,7 @@ const startServer = async () => {
res.send(updatedIndexHtml);
});
app.listen(port, host, () => {
app.listen(port, host, async () => {
if (host === '0.0.0.0') {
logger.info(
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
@@ -163,7 +165,9 @@ const startServer = async () => {
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
}
initializeMCPs().then(() => checkMigrations());
await initializeMCPs();
await initializeOAuthReconnectManager();
await checkMigrations();
});
};

View File

@@ -14,6 +14,7 @@ const checkInviteUser = require('./checkInviteUser');
const requireJwtAuth = require('./requireJwtAuth');
const configMiddleware = require('./config/app');
const validateModel = require('./validateModel');
const requireAdmin = require('./requireAdmin');
const moderateText = require('./moderateText');
const logHeaders = require('./logHeaders');
const setHeaders = require('./setHeaders');
@@ -36,6 +37,7 @@ module.exports = {
setHeaders,
logHeaders,
moderateText,
requireAdmin,
validateModel,
requireJwtAuth,
checkInviteUser,

View File

@@ -0,0 +1,22 @@
const { logger } = require('@librechat/data-schemas');
const { SystemRoles } = require('librechat-data-provider');
/**
* Middleware to check if authenticated user has admin role
* Should be used AFTER authentication middleware (requireJwtAuth, requireLocalAuth, etc.)
*/
const requireAdmin = (req, res, next) => {
if (!req.user) {
logger.warn('[requireAdmin] No user found in request');
return res.status(401).json({ message: 'Authentication required' });
}
if (!req.user.role || req.user.role !== SystemRoles.ADMIN) {
logger.debug('[requireAdmin] Access denied for non-admin user:', req.user.email);
return res.status(403).json({ message: 'Access denied: Admin privileges required' });
}
next();
};
module.exports = requireAdmin;

View File

@@ -1,6 +1,6 @@
const passport = require('passport');
const cookies = require('cookie');
const { isEnabled } = require('~/server/utils');
const passport = require('passport');
const { isEnabled } = require('@librechat/api');
/**
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse

View File

@@ -0,0 +1,66 @@
const express = require('express');
const passport = require('passport');
const { randomState } = require('openid-client');
const { createSetBalanceConfig } = require('@librechat/api');
const { loginController } = require('~/server/controllers/auth/LoginController');
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
const { getAppConfig } = require('~/server/services/Config');
const { getOpenIdConfig } = require('~/strategies');
const middleware = require('~/server/middleware');
const { Balance } = require('~/db/models');
const setBalanceConfig = createSetBalanceConfig({
getAppConfig,
Balance,
});
const router = express.Router();
router.post(
'/login/local',
middleware.logHeaders,
middleware.loginLimiter,
middleware.checkBan,
middleware.requireLocalAuth,
middleware.requireAdmin,
setBalanceConfig,
loginController,
);
router.get('/verify', middleware.requireJwtAuth, middleware.requireAdmin, (req, res) => {
const { password: _p, totpSecret: _t, __v, ...user } = req.user;
user.id = user._id.toString();
res.status(200).json({ user });
});
router.get('/oauth/openid/check', (req, res) => {
const openidConfig = getOpenIdConfig();
if (!openidConfig) {
return res.status(404).json({ message: 'OpenID configuration not found' });
}
res.status(200).json({ message: 'OpenID check successful' });
});
router.get('/oauth/openid', (req, res, next) => {
return passport.authenticate('openidAdmin', {
session: false,
state: randomState(),
})(req, res, next);
});
router.get(
'/oauth/openid/callback',
passport.authenticate('openidAdmin', {
failureRedirect: `${process.env.DOMAIN_CLIENT}/oauth/error`,
failureMessage: true,
session: false,
}),
middleware.requireAdmin,
setBalanceConfig,
middleware.checkDomainAllowed,
createOAuthHandler(
(process.env.ADMIN_PANEL_URL || 'http://localhost:3000') + '/auth/openid/callback',
),
);
module.exports = router;

View File

@@ -1,6 +1,7 @@
const accessPermissions = require('./accessPermissions');
const assistants = require('./assistants');
const categories = require('./categories');
const adminAuth = require('./admin/auth');
const tokenizer = require('./tokenizer');
const endpoints = require('./endpoints');
const staticRoute = require('./static');
@@ -32,6 +33,7 @@ module.exports = {
mcp,
edit,
auth,
adminAuth,
keys,
user,
tags,

View File

@@ -1,12 +1,12 @@
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
const { requireJwtAuth } = require('~/server/middleware');
const { findPluginAuthsByKeys } = require('~/models');
@@ -144,6 +144,10 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
);
// clear any reconnection attempts
const oauthReconnectionManager = getOAuthReconnectionManager();
oauthReconnectionManager.clearReconnection(flowState.userId, serverName);
const tools = await userConnection.fetchTools();
await updateMCPUserTools({
userId: flowState.userId,

View File

@@ -4,10 +4,9 @@ const passport = require('passport');
const { randomState } = require('openid-client');
const { logger } = require('@librechat/data-schemas');
const { ErrorTypes } = require('librechat-data-provider');
const { isEnabled, createSetBalanceConfig } = require('@librechat/api');
const { checkDomainAllowed, loginLimiter, logHeaders, checkBan } = require('~/server/middleware');
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
const { createSetBalanceConfig } = require('@librechat/api');
const { checkDomainAllowed, loginLimiter, logHeaders } = require('~/server/middleware');
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
const { getAppConfig } = require('~/server/services/Config');
const { Balance } = require('~/db/models');
@@ -26,32 +25,7 @@ const domains = {
router.use(logHeaders);
router.use(loginLimiter);
const oauthHandler = async (req, res, next) => {
try {
if (res.headersSent) {
return;
}
await checkBan(req, res);
if (req.banned) {
return;
}
if (
req.user &&
req.user.provider == 'openid' &&
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
) {
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
} else {
await setAuthTokens(req.user._id, res);
}
res.redirect(domains.client);
} catch (err) {
logger.error('Error in setting authentication tokens:', err);
next(err);
}
};
const oauthHandler = createOAuthHandler();
router.get('/error', (req, res) => {
/** A single error message is pushed by passport when authentication fails. */

View File

@@ -357,23 +357,18 @@ const resetPassword = async (userId, token, password) => {
/**
* Set Auth Tokens
*
* @param {String | ObjectId} userId
* @param {Object} res
* @param {String} sessionId
* @param {ServerResponse} res
* @param {ISession | null} [session=null]
* @returns
*/
const setAuthTokens = async (userId, res, sessionId = null) => {
const setAuthTokens = async (userId, res, _session = null) => {
try {
const user = await getUserById(userId);
const token = await generateToken(user);
let session;
let session = _session;
let refreshToken;
let refreshTokenExpires;
if (sessionId) {
session = await findSession({ sessionId: sessionId }, { lean: false });
if (session && session._id && session.expiration != null) {
refreshTokenExpires = session.expiration.getTime();
refreshToken = await generateRefreshToken(session);
} else {
@@ -383,6 +378,9 @@ const setAuthTokens = async (userId, res, sessionId = null) => {
refreshTokenExpires = session.expiration.getTime();
}
const user = await getUserById(userId);
const token = await generateToken(user);
res.cookie('refreshToken', refreshToken, {
expires: new Date(refreshTokenExpires),
httpOnly: true,

View File

@@ -20,8 +20,8 @@ const {
ContentTypes,
isAssistantsEndpoint,
} = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, getAppConfig } = require('./Config');
const { reinitMCPServer } = require('./Tools/mcp');
const { getLogStores } = require('~/cache');
@@ -538,13 +538,20 @@ async function getServerConnectionStatus(
const baseConnectionState = getConnectionState();
let finalConnectionState = baseConnectionState;
// connection state overrides specific to OAuth servers
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
// check if server is actively being reconnected
const oauthReconnectionManager = getOAuthReconnectionManager();
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
finalConnectionState = 'connecting';
} else {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
finalConnectionState = 'connecting';
}
}
}

View File

@@ -31,6 +31,7 @@ jest.mock('./Config', () => ({
jest.mock('~/config', () => ({
getMCPManager: jest.fn(),
getFlowStateManager: jest.fn(),
getOAuthReconnectionManager: jest.fn(),
}));
jest.mock('~/cache', () => ({
@@ -48,6 +49,7 @@ describe('tests for the new helper functions used by the MCP connection status e
let mockGetMCPManager;
let mockGetFlowStateManager;
let mockGetLogStores;
let mockGetOAuthReconnectionManager;
beforeEach(() => {
jest.clearAllMocks();
@@ -56,6 +58,7 @@ describe('tests for the new helper functions used by the MCP connection status e
mockGetMCPManager = require('~/config').getMCPManager;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
});
describe('getMCPSetupData', () => {
@@ -354,6 +357,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
@@ -370,6 +379,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return failed flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@@ -401,6 +416,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return active flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@@ -432,6 +453,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return no flow
const mockFlowManager = {
getFlowState: jest.fn(() => null),
@@ -454,6 +481,35 @@ describe('tests for the new helper functions used by the MCP connection status e
});
});
it('should return connecting state when OAuth server is reconnecting', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager to return true for isReconnecting
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => true),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: true,
connectionState: 'connecting',
});
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
mockUserId,
mockServerName,
);
});
it('should not check OAuth flow status when server is connected', async () => {
const mockFlowManager = {
getFlowState: jest.fn(),

View File

@@ -313,7 +313,7 @@ const ensurePrincipalExists = async function (principal) {
idOnTheSource: principal.idOnTheSource,
};
const userId = await createUser(userData, true, false);
const userId = await createUser(userData, true, true);
return userId.toString();
}

View File

@@ -0,0 +1,26 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { getLogStores } = require('~/cache');
/**
* Initialize OAuth reconnect manager
*/
async function initializeOAuthReconnectManager() {
try {
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const tokenMethods = {
findToken,
updateToken,
createToken,
deleteTokens,
};
await createOAuthReconnectionManager(flowManager, tokenMethods);
logger.info(`OAuth reconnect manager initialized successfully.`);
} catch (error) {
logger.error('Failed to initialize OAuth reconnect manager:', error);
}
}
module.exports = initializeOAuthReconnectManager;

View File

@@ -1,14 +1,14 @@
const appleLogin = require('./appleStrategy');
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
const openIdJwtLogin = require('./openIdJwtStrategy');
const facebookLogin = require('./facebookStrategy');
const discordLogin = require('./discordStrategy');
const passportLogin = require('./localStrategy');
const googleLogin = require('./googleStrategy');
const githubLogin = require('./githubStrategy');
const discordLogin = require('./discordStrategy');
const facebookLogin = require('./facebookStrategy');
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
const jwtLogin = require('./jwtStrategy');
const ldapLogin = require('./ldapStrategy');
const { setupSaml } = require('./samlStrategy');
const openIdJwtLogin = require('./openIdJwtStrategy');
const appleLogin = require('./appleStrategy');
const ldapLogin = require('./ldapStrategy');
const jwtLogin = require('./jwtStrategy');
module.exports = {
appleLogin,

View File

@@ -281,6 +281,221 @@ function convertToUsername(input, defaultValue = '') {
return defaultValue;
}
/**
* Process OpenID authentication tokenset and userinfo
* This is the core logic extracted from the passport strategy callback
* Can be reused by both the passport strategy and proxy authentication
*
* @param {Object} tokenset - The OpenID tokenset containing access_token, id_token, etc.
* @param {boolean} existingUsersOnly - If true, only existing users will be processed
* @returns {Promise<Object>} The authenticated user object with tokenset
*/
async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
const claims = tokenset.claims ? tokenset.claims() : tokenset;
const userinfo = {
...claims,
};
// Get userinfo from provider if we have access_token and haven't already
if (tokenset.access_token) {
const providerUserinfo = await getUserInfo(openidConfig, tokenset.access_token, claims.sub);
Object.assign(userinfo, providerUserinfo);
}
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
logger.error(
`[OpenID Auth] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
);
throw new Error('Email domain not allowed');
}
const result = await findOpenIDUser({
openidId: claims.sub || userinfo.sub,
email: claims.email || userinfo.email,
strategyName: 'openidStrategy',
findUser,
});
let user = result.user;
const error = result.error;
if (error) {
throw new Error(ErrorTypes.AUTH_FAILED);
}
const fullName = getFullName(userinfo);
/** Required role if configured */
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
if (requiredRole) {
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
let decodedToken = '';
if (requiredRoleTokenKind === 'access' && tokenset.access_token) {
decodedToken = jwtDecode(tokenset.access_token);
} else if (requiredRoleTokenKind === 'id' && tokenset.id_token) {
decodedToken = jwtDecode(tokenset.id_token);
} else if (userinfo.roles) {
// If roles are already in userinfo, use them directly
const roles = Array.isArray(userinfo.roles) ? userinfo.roles : [userinfo.roles];
if (!roles.includes(requiredRole)) {
throw new Error(`You must have the "${requiredRole}" role to log in.`);
}
} else if (requiredRoleParameterPath) {
const pathParts = requiredRoleParameterPath.split('.');
let found = true;
let roles = pathParts.reduce((o, key) => {
if (o === null || o === undefined || !(key in o)) {
found = false;
return [];
}
return o[key];
}, decodedToken);
if (!found) {
logger.error(
`[OpenID Auth] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
);
}
if (!roles.includes(requiredRole)) {
throw new Error(`You must have the "${requiredRole}" role to log in.`);
}
}
}
let username = '';
if (process.env.OPENID_USERNAME_CLAIM) {
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
} else {
username = convertToUsername(
userinfo.preferred_username || userinfo.username || userinfo.email,
);
}
if (existingUsersOnly && !user) {
throw new Error('User does not exist');
}
if (!user) {
user = {
provider: 'openid',
openidId: userinfo.sub,
username,
email: userinfo.email || '',
emailVerified: userinfo.email_verified || false,
name: fullName,
idOnTheSource: userinfo.oid,
};
const balanceConfig = getBalanceConfig(appConfig);
user = await createUser(user, balanceConfig, true, true);
} else {
user.provider = 'openid';
user.openidId = userinfo.sub;
user.username = username;
user.name = fullName;
user.idOnTheSource = userinfo.oid;
}
// Handle avatar
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
const imageUrl = userinfo.picture;
let fileName;
if (crypto) {
fileName = (await hashToken(userinfo.sub)) + '.png';
} else {
fileName = userinfo.sub + '.png';
}
const imageBuffer = await downloadImage(
imageUrl,
openidConfig,
tokenset.access_token,
userinfo.sub,
);
if (imageBuffer) {
const { saveBuffer } = getStrategyFunctions(
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
);
const imagePath = await saveBuffer({
fileName,
userId: user._id.toString(),
buffer: imageBuffer,
});
user.avatar = imagePath ?? '';
}
}
user = await updateUser(user._id, user);
logger.info(
`[OpenID Auth] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username}`,
{
user: {
openidId: user.openidId,
username: user.username,
email: user.email,
name: user.name,
},
},
);
return { ...user, tokenset };
}
/**
* @param {boolean | undefined} [existingUsersOnly]
*/
function createOpenIDCallback(existingUsersOnly) {
return async (tokenset, done) => {
try {
const user = await processOpenIDAuth(tokenset, existingUsersOnly);
done(null, user);
} catch (err) {
if (err.message === 'Email domain not allowed') {
return done(null, false, { message: err.message });
}
if (err.message === ErrorTypes.AUTH_FAILED) {
return done(null, false, { message: err.message });
}
if (err.message && err.message.includes('role to log in')) {
return done(null, false, { message: err.message });
}
logger.error('[openidStrategy] login failed', err);
done(err);
}
};
}
/**
* Sets up the OpenID strategy specifically for admin authentication.
* @param {Configuration} openidConfig
*/
const setupOpenIdAdmin = (openidConfig) => {
try {
if (!openidConfig) {
throw new Error('OpenID configuration not initialized');
}
const openidAdminLogin = new CustomOpenIDStrategy(
{
config: openidConfig,
scope: process.env.OPENID_SCOPE,
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
callbackURL: process.env.DOMAIN_SERVER + '/api/admin/oauth/openid/callback',
},
createOpenIDCallback(true),
);
passport.use('openidAdmin', openidAdminLogin);
} catch (err) {
logger.error('[openidStrategy] setupOpenIdAdmin', err);
}
};
/**
* Sets up the OpenID strategy for authentication.
* This function configures the OpenID client, handles proxy settings,
@@ -318,10 +533,6 @@ async function setupOpenId() {
},
);
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
const usePKCE = isEnabled(process.env.OPENID_USE_PKCE);
logger.info(`[openidStrategy] OpenID authentication configuration`, {
generateNonce: shouldGenerateNonce,
reason: shouldGenerateNonce
@@ -335,159 +546,19 @@ async function setupOpenId() {
scope: process.env.OPENID_SCOPE,
callbackURL: process.env.DOMAIN_SERVER + process.env.OPENID_CALLBACK_URL,
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
usePKCE,
},
async (tokenset, done) => {
try {
const claims = tokenset.claims();
const userinfo = {
...claims,
...(await getUserInfo(openidConfig, tokenset.access_token, claims.sub)),
};
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
logger.error(
`[OpenID Strategy] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
);
return done(null, false, { message: 'Email domain not allowed' });
}
const result = await findOpenIDUser({
openidId: claims.sub,
email: claims.email,
strategyName: 'openidStrategy',
findUser,
});
let user = result.user;
const error = result.error;
if (error) {
return done(null, false, {
message: ErrorTypes.AUTH_FAILED,
});
}
const fullName = getFullName(userinfo);
if (requiredRole) {
let decodedToken = '';
if (requiredRoleTokenKind === 'access') {
decodedToken = jwtDecode(tokenset.access_token);
} else if (requiredRoleTokenKind === 'id') {
decodedToken = jwtDecode(tokenset.id_token);
}
const pathParts = requiredRoleParameterPath.split('.');
let found = true;
let roles = pathParts.reduce((o, key) => {
if (o === null || o === undefined || !(key in o)) {
found = false;
return [];
}
return o[key];
}, decodedToken);
if (!found) {
logger.error(
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
);
}
if (!roles.includes(requiredRole)) {
return done(null, false, {
message: `You must have the "${requiredRole}" role to log in.`,
});
}
}
let username = '';
if (process.env.OPENID_USERNAME_CLAIM) {
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
} else {
username = convertToUsername(
userinfo.preferred_username || userinfo.username || userinfo.email,
);
}
if (!user) {
user = {
provider: 'openid',
openidId: userinfo.sub,
username,
email: userinfo.email || '',
emailVerified: userinfo.email_verified || false,
name: fullName,
idOnTheSource: userinfo.oid,
};
const balanceConfig = getBalanceConfig(appConfig);
user = await createUser(user, balanceConfig, true, true);
} else {
user.provider = 'openid';
user.openidId = userinfo.sub;
user.username = username;
user.name = fullName;
user.idOnTheSource = userinfo.oid;
}
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
/** @type {string | undefined} */
const imageUrl = userinfo.picture;
let fileName;
if (crypto) {
fileName = (await hashToken(userinfo.sub)) + '.png';
} else {
fileName = userinfo.sub + '.png';
}
const imageBuffer = await downloadImage(
imageUrl,
openidConfig,
tokenset.access_token,
userinfo.sub,
);
if (imageBuffer) {
const { saveBuffer } = getStrategyFunctions(
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
);
const imagePath = await saveBuffer({
fileName,
userId: user._id.toString(),
buffer: imageBuffer,
});
user.avatar = imagePath ?? '';
}
}
user = await updateUser(user._id, user);
logger.info(
`[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `,
{
user: {
openidId: user.openidId,
username: user.username,
email: user.email,
name: user.name,
},
},
);
done(null, { ...user, tokenset });
} catch (err) {
logger.error('[openidStrategy] login failed', err);
done(err);
}
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
},
createOpenIDCallback(),
);
passport.use('openid', openidLogin);
setupOpenIdAdmin(openidConfig);
return openidConfig;
} catch (err) {
logger.error('[openidStrategy]', err);
return null;
}
}
/**
* @function getOpenIdConfig
* @description Returns the OpenID client instance.

View File

@@ -873,6 +873,13 @@
* @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile
* @memberof typedefs
*/
/**
* @exports ISession
* @typedef {import('@librechat/data-schemas').ISession} ISession
* @memberof typedefs
*/
/**
* @exports IBalance
* @typedef {import('@librechat/data-schemas').IBalance} IBalance

View File

@@ -1,10 +1,15 @@
import { createContext, useContext } from 'react';
type MessageContext = {
messageId: string;
nextType?: string;
partIndex?: number;
isExpanded: boolean;
conversationId?: string | null;
/** Submission state for cursor display - only true for latest message when submitting */
isSubmitting?: boolean;
/** Whether this is the latest message in the conversation */
isLatestMessage?: boolean;
};
export const MessageContext = createContext<MessageContext>({} as MessageContext);

View File

@@ -0,0 +1,150 @@
import React, { createContext, useContext, useMemo } from 'react';
import { useAddedChatContext } from './AddedChatContext';
import { useChatContext } from './ChatContext';
interface MessagesViewContextValue {
/** Core conversation data */
conversation: ReturnType<typeof useChatContext>['conversation'];
conversationId: string | null | undefined;
/** Submission and control states */
isSubmitting: ReturnType<typeof useChatContext>['isSubmitting'];
isSubmittingFamily: boolean;
abortScroll: ReturnType<typeof useChatContext>['abortScroll'];
setAbortScroll: ReturnType<typeof useChatContext>['setAbortScroll'];
/** Message operations */
ask: ReturnType<typeof useChatContext>['ask'];
regenerate: ReturnType<typeof useChatContext>['regenerate'];
handleContinue: ReturnType<typeof useChatContext>['handleContinue'];
/** Message state management */
index: ReturnType<typeof useChatContext>['index'];
latestMessage: ReturnType<typeof useChatContext>['latestMessage'];
setLatestMessage: ReturnType<typeof useChatContext>['setLatestMessage'];
getMessages: ReturnType<typeof useChatContext>['getMessages'];
setMessages: ReturnType<typeof useChatContext>['setMessages'];
}
const MessagesViewContext = createContext<MessagesViewContextValue | undefined>(undefined);
export function MessagesViewProvider({ children }: { children: React.ReactNode }) {
const chatContext = useChatContext();
const addedChatContext = useAddedChatContext();
const {
ask,
index,
regenerate,
isSubmitting: isSubmittingRoot,
conversation,
latestMessage,
setAbortScroll,
handleContinue,
setLatestMessage,
abortScroll,
getMessages,
setMessages,
} = chatContext;
const { isSubmitting: isSubmittingAdditional } = addedChatContext;
/** Memoize conversation-related values */
const conversationValues = useMemo(
() => ({
conversation,
conversationId: conversation?.conversationId,
}),
[conversation],
);
/** Memoize submission states */
const submissionStates = useMemo(
() => ({
isSubmitting: isSubmittingRoot,
isSubmittingFamily: isSubmittingRoot || isSubmittingAdditional,
abortScroll,
setAbortScroll,
}),
[isSubmittingRoot, isSubmittingAdditional, abortScroll, setAbortScroll],
);
/** Memoize message operations (these are typically stable references) */
const messageOperations = useMemo(
() => ({
ask,
regenerate,
getMessages,
setMessages,
handleContinue,
}),
[ask, regenerate, handleContinue, getMessages, setMessages],
);
/** Memoize message state values */
const messageState = useMemo(
() => ({
index,
latestMessage,
setLatestMessage,
}),
[index, latestMessage, setLatestMessage],
);
/** Combine all values into final context value */
const contextValue = useMemo<MessagesViewContextValue>(
() => ({
...conversationValues,
...submissionStates,
...messageOperations,
...messageState,
}),
[conversationValues, submissionStates, messageOperations, messageState],
);
return (
<MessagesViewContext.Provider value={contextValue}>{children}</MessagesViewContext.Provider>
);
}
export function useMessagesViewContext() {
const context = useContext(MessagesViewContext);
if (!context) {
throw new Error('useMessagesViewContext must be used within MessagesViewProvider');
}
return context;
}
/** Hook for components that only need conversation data */
export function useMessagesConversation() {
const { conversation, conversationId } = useMessagesViewContext();
return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]);
}
/** Hook for components that only need submission states */
export function useMessagesSubmission() {
const { isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll } =
useMessagesViewContext();
return useMemo(
() => ({ isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll }),
[isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll],
);
}
/** Hook for components that only need message operations */
export function useMessagesOperations() {
const { ask, regenerate, handleContinue, getMessages, setMessages } = useMessagesViewContext();
return useMemo(
() => ({ ask, regenerate, handleContinue, getMessages, setMessages }),
[ask, regenerate, handleContinue, getMessages, setMessages],
);
}
/** Hook for components that only need message state */
export function useMessagesState() {
const { index, latestMessage, setLatestMessage } = useMessagesViewContext();
return useMemo(
() => ({ index, latestMessage, setLatestMessage }),
[index, latestMessage, setLatestMessage],
);
}

View File

@@ -26,4 +26,5 @@ export * from './SidePanelContext';
export * from './MCPPanelContext';
export * from './ArtifactsContext';
export * from './PromptGroupsContext';
export * from './MessagesViewContext';
export { default as BadgeRowProvider } from './BadgeRowContext';

View File

@@ -26,6 +26,7 @@ type ContentPartsProps = {
isCreatedByUser: boolean;
isLast: boolean;
isSubmitting: boolean;
isLatestMessage?: boolean;
edit?: boolean;
enterEdit?: (cancel?: boolean) => void | null | undefined;
siblingIdx?: number;
@@ -45,6 +46,7 @@ const ContentParts = memo(
isCreatedByUser,
isLast,
isSubmitting,
isLatestMessage,
edit,
enterEdit,
siblingIdx,
@@ -55,6 +57,8 @@ const ContentParts = memo(
const [isExpanded, setIsExpanded] = useState(showThinking);
const attachmentMap = useMemo(() => mapAttachments(attachments ?? []), [attachments]);
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const hasReasoningParts = useMemo(() => {
const hasThinkPart = content?.some((part) => part?.type === ContentTypes.THINK) ?? false;
const allThinkPartsHaveContent =
@@ -134,7 +138,9 @@ const ContentParts = memo(
})
}
label={
isSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts')
effectiveIsSubmitting && isLast
? localize('com_ui_thinking')
: localize('com_ui_thoughts')
}
/>
</div>
@@ -155,12 +161,14 @@ const ContentParts = memo(
conversationId,
partIndex: idx,
nextType: content[idx + 1]?.type,
isSubmitting: effectiveIsSubmitting,
isLatestMessage,
}}
>
<Part
part={part}
attachments={attachments}
isSubmitting={isSubmitting}
isSubmitting={effectiveIsSubmitting}
key={`part-${messageId}-${idx}`}
isCreatedByUser={isCreatedByUser}
isLast={idx === content.length - 1}

View File

@@ -4,7 +4,7 @@ import { useRecoilState, useRecoilValue } from 'recoil';
import { TextareaAutosize, TooltipAnchor } from '@librechat/client';
import { useUpdateMessageMutation } from 'librechat-data-provider/react-query';
import type { TEditProps } from '~/common';
import { useChatContext, useAddedChatContext } from '~/Providers';
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
import { cn, removeFocusRings } from '~/utils';
import { useLocalize } from '~/hooks';
import Container from './Container';
@@ -22,7 +22,8 @@ const EditMessage = ({
const { addedIndex } = useAddedChatContext();
const saveButtonRef = useRef<HTMLButtonElement | null>(null);
const submitButtonRef = useRef<HTMLButtonElement | null>(null);
const { getMessages, setMessages, conversation } = useChatContext();
const { conversation } = useMessagesConversation();
const { getMessages, setMessages } = useMessagesOperations();
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
store.latestMessageFamily(addedIndex),
);

View File

@@ -5,7 +5,7 @@ import type { TMessage } from 'librechat-data-provider';
import type { TMessageContentProps, TDisplayProps } from '~/common';
import Error from '~/components/Messages/Content/Error';
import Thinking from '~/components/Artifacts/Thinking';
import { useChatContext } from '~/Providers';
import { useMessageContext } from '~/Providers';
import MarkdownLite from './MarkdownLite';
import EditMessage from './EditMessage';
import { useLocalize } from '~/hooks';
@@ -70,16 +70,12 @@ export const ErrorMessage = ({
};
const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => {
const { isSubmitting, latestMessage } = useChatContext();
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
const showCursorState = useMemo(
() => showCursor === true && isSubmitting,
[showCursor, isSubmitting],
);
const isLatestMessage = useMemo(
() => message.messageId === latestMessage?.messageId,
[message.messageId, latestMessage?.messageId],
);
let content: React.ReactElement;
if (!isCreatedByUser) {

View File

@@ -85,13 +85,14 @@ const Part = memo(
const isToolCall =
'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL);
if (isToolCall && toolCall.name === Tools.execute_code && toolCall.args) {
if (isToolCall && toolCall.name === Tools.execute_code) {
return (
<ExecuteCode
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
attachments={attachments}
isSubmitting={isSubmitting}
output={toolCall.output ?? ''}
initialProgress={toolCall.progress ?? 0.1}
attachments={attachments}
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
/>
);
} else if (

View File

@@ -6,8 +6,8 @@ import { useRecoilState, useRecoilValue } from 'recoil';
import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query';
import type { Agents } from 'librechat-data-provider';
import type { TEditProps } from '~/common';
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
import Container from '~/components/Chat/Messages/Content/Container';
import { useChatContext, useAddedChatContext } from '~/Providers';
import { cn, removeFocusRings } from '~/utils';
import { useLocalize } from '~/hooks';
import store from '~/store';
@@ -25,7 +25,8 @@ const EditTextPart = ({
}) => {
const localize = useLocalize();
const { addedIndex } = useAddedChatContext();
const { ask, getMessages, setMessages, conversation } = useChatContext();
const { conversation } = useMessagesConversation();
const { ask, getMessages, setMessages } = useMessagesOperations();
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
store.latestMessageFamily(addedIndex),
);

View File

@@ -45,26 +45,28 @@ export function useParseArgs(args?: string): ParsedArgs | null {
}
export default function ExecuteCode({
isSubmitting,
initialProgress = 0.1,
args,
output = '',
attachments,
}: {
initialProgress: number;
isSubmitting: boolean;
args?: string;
output?: string;
attachments?: TAttachment[];
}) {
const localize = useLocalize();
const showAnalysisCode = useRecoilValue(store.showCode);
const [showCode, setShowCode] = useState(showAnalysisCode);
const codeContentRef = useRef<HTMLDivElement>(null);
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
const [isAnimating, setIsAnimating] = useState(false);
const hasOutput = output.length > 0;
const outputRef = useRef<string>(output);
const prevShowCodeRef = useRef<boolean>(showCode);
const codeContentRef = useRef<HTMLDivElement>(null);
const [isAnimating, setIsAnimating] = useState(false);
const showAnalysisCode = useRecoilValue(store.showCode);
const [showCode, setShowCode] = useState(showAnalysisCode);
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
const prevShowCodeRef = useRef<boolean>(showCode);
const { lang, code } = useParseArgs(args) ?? ({} as ParsedArgs);
const progress = useProgress(initialProgress);
@@ -136,6 +138,8 @@ export default function ExecuteCode({
};
}, [showCode, isAnimating]);
const cancelled = !isSubmitting && progress < 1;
return (
<>
<div className="relative my-2.5 flex size-5 shrink-0 items-center gap-2.5">
@@ -143,9 +147,12 @@ export default function ExecuteCode({
progress={progress}
onClick={() => setShowCode((prev) => !prev)}
inProgressText={localize('com_ui_analyzing')}
finishedText={localize('com_ui_analyzing_finished')}
finishedText={
cancelled ? localize('com_ui_cancelled') : localize('com_ui_analyzing_finished')
}
hasInput={!!code?.length}
isExpanded={showCode}
error={cancelled}
/>
</div>
<div

View File

@@ -2,7 +2,7 @@ import { memo, useMemo, ReactElement } from 'react';
import { useRecoilValue } from 'recoil';
import MarkdownLite from '~/components/Chat/Messages/Content/MarkdownLite';
import Markdown from '~/components/Chat/Messages/Content/Markdown';
import { useChatContext, useMessageContext } from '~/Providers';
import { useMessageContext } from '~/Providers';
import { cn } from '~/utils';
import store from '~/store';
@@ -18,14 +18,9 @@ type ContentType =
| ReactElement;
const TextPart = memo(({ text, isCreatedByUser, showCursor }: TextPartProps) => {
const { messageId } = useMessageContext();
const { isSubmitting, latestMessage } = useChatContext();
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
const showCursorState = useMemo(() => showCursor && isSubmitting, [showCursor, isSubmitting]);
const isLatestMessage = useMemo(
() => messageId === latestMessage?.messageId,
[messageId, latestMessage?.messageId],
);
const content: ContentType = useMemo(() => {
if (!isCreatedByUser) {

View File

@@ -21,7 +21,7 @@ type THoverButtons = {
latestMessage: TMessage | null;
isLast: boolean;
index: number;
handleFeedback: ({ feedback }: { feedback: TFeedback | undefined }) => void;
handleFeedback?: ({ feedback }: { feedback: TFeedback | undefined }) => void;
};
type HoverButtonProps = {
@@ -238,7 +238,7 @@ const HoverButtons = ({
/>
{/* Feedback Buttons */}
{!isCreatedByUser && (
{!isCreatedByUser && handleFeedback != null && (
<Feedback handleFeedback={handleFeedback} feedback={message.feedback} isLast={isLast} />
)}

View File

@@ -73,7 +73,7 @@ export default function Message(props: TMessageProps) {
</div>
</div>
) : (
<div className="m-auto justify-center p-4 py-2 md:gap-6 ">
<div className="m-auto justify-center p-4 py-2 md:gap-6">
<MessageRender {...props} />
</div>
)}

View File

@@ -125,6 +125,7 @@ export default function Message(props: TMessageProps) {
setSiblingIdx={setSiblingIdx}
isCreatedByUser={message.isCreatedByUser}
conversationId={conversation?.conversationId}
isLatestMessage={messageId === latestMessage?.messageId}
content={message.content as Array<TMessageContentParts | undefined>}
/>
</div>

View File

@@ -4,11 +4,12 @@ import { CSSTransition } from 'react-transition-group';
import type { TMessage } from 'librechat-data-provider';
import { useScreenshot, useMessageScrolling, useLocalize } from '~/hooks';
import ScrollToBottom from '~/components/Messages/ScrollToBottom';
import { MessagesViewProvider } from '~/Providers';
import MultiMessage from './MultiMessage';
import { cn } from '~/utils';
import store from '~/store';
export default function MessagesView({
function MessagesViewContent({
messagesTree: _messagesTree,
}: {
messagesTree?: TMessage[] | null;
@@ -92,3 +93,11 @@ export default function MessagesView({
</>
);
}
export default function MessagesView({ messagesTree }: { messagesTree?: TMessage[] | null }) {
return (
<MessagesViewProvider>
<MessagesViewContent messagesTree={messagesTree} />
</MessagesViewProvider>
);
}

View File

@@ -27,7 +27,7 @@ export default function MultiMessage({
useEffect(() => {
// reset siblingIdx when the tree changes, mostly when a new message is submitting.
setSiblingIdx(0);
}, [messagesTree?.length]);
}, [messagesTree?.length, setSiblingIdx]);
useEffect(() => {
if (messagesTree?.length && siblingIdx >= messagesTree.length) {

View File

@@ -71,6 +71,9 @@ const MessageRender = memo(
const showCardRender = isLast && !isSubmittingFamily && isCard;
const isLatestCard = isCard && !isSubmittingFamily && isLatestMessage;
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const iconData: TMessageIcon = useMemo(
() => ({
endpoint: msg?.endpoint ?? conversation?.endpoint,
@@ -166,6 +169,8 @@ const MessageRender = memo(
messageId: msg.messageId,
conversationId: conversation?.conversationId,
isExpanded: false,
isSubmitting: effectiveIsSubmitting,
isLatestMessage,
}}
>
{msg.plugin && <Plugin plugin={msg.plugin} />}
@@ -177,7 +182,7 @@ const MessageRender = memo(
message={msg}
enterEdit={enterEdit}
error={!!(msg.error ?? false)}
isSubmitting={isSubmitting}
isSubmitting={effectiveIsSubmitting}
unfinished={msg.unfinished ?? false}
isCreatedByUser={msg.isCreatedByUser ?? true}
siblingIdx={siblingIdx ?? 0}
@@ -186,7 +191,7 @@ const MessageRender = memo(
</MessageContext.Provider>
</div>
{hasNoChildren && (isSubmittingFamily === true || isSubmitting) ? (
{hasNoChildren && (isSubmittingFamily === true || effectiveIsSubmitting) ? (
<PlaceholderRow isCard={isCard} />
) : (
<SubRow classes="text-xs">

View File

@@ -1,6 +1,5 @@
import { useMemo, memo, type FC, useCallback } from 'react';
import throttle from 'lodash/throttle';
import { parseISO, isToday } from 'date-fns';
import { Spinner, useMediaQuery } from '@librechat/client';
import { List, AutoSizer, CellMeasurer, CellMeasurerCache } from 'react-virtualized';
import { TConversation } from 'librechat-data-provider';
@@ -50,27 +49,17 @@ const MemoizedConvo = memo(
conversation,
retainView,
toggleNav,
isLatestConvo,
}: {
conversation: TConversation;
retainView: () => void;
toggleNav: () => void;
isLatestConvo: boolean;
}) => {
return (
<Convo
conversation={conversation}
retainView={retainView}
toggleNav={toggleNav}
isLatestConvo={isLatestConvo}
/>
);
return <Convo conversation={conversation} retainView={retainView} toggleNav={toggleNav} />;
},
(prevProps, nextProps) => {
return (
prevProps.conversation.conversationId === nextProps.conversation.conversationId &&
prevProps.conversation.title === nextProps.conversation.title &&
prevProps.isLatestConvo === nextProps.isLatestConvo &&
prevProps.conversation.endpoint === nextProps.conversation.endpoint
);
},
@@ -98,13 +87,6 @@ const Conversations: FC<ConversationsProps> = ({
[filteredConversations],
);
const firstTodayConvoId = useMemo(
() =>
filteredConversations.find((convo) => convo.updatedAt && isToday(parseISO(convo.updatedAt)))
?.conversationId ?? undefined,
[filteredConversations],
);
const flattenedItems = useMemo(() => {
const items: FlattenedItem[] = [];
groupedConversations.forEach(([groupName, convos]) => {
@@ -154,26 +136,25 @@ const Conversations: FC<ConversationsProps> = ({
</CellMeasurer>
);
}
let rendering: JSX.Element;
if (item.type === 'header') {
rendering = <DateLabel groupName={item.groupName} />;
} else if (item.type === 'convo') {
rendering = (
<MemoizedConvo conversation={item.convo} retainView={moveToTop} toggleNav={toggleNav} />
);
}
return (
<CellMeasurer cache={cache} columnIndex={0} key={key} parent={parent} rowIndex={index}>
{({ registerChild }) => (
<div ref={registerChild} style={style}>
{item.type === 'header' ? (
<DateLabel groupName={item.groupName} />
) : item.type === 'convo' ? (
<MemoizedConvo
conversation={item.convo}
retainView={moveToTop}
toggleNav={toggleNav}
isLatestConvo={item.convo.conversationId === firstTodayConvoId}
/>
) : null}
{rendering}
</div>
)}
</CellMeasurer>
);
},
[cache, flattenedItems, firstTodayConvoId, moveToTop, toggleNav],
[cache, flattenedItems, moveToTop, toggleNav],
);
const getRowHeight = useCallback(

View File

@@ -11,23 +11,17 @@ import { useGetEndpointsQuery } from '~/data-provider';
import { NotificationSeverity } from '~/common';
import { ConvoOptions } from './ConvoOptions';
import RenameForm from './RenameForm';
import { cn, logger } from '~/utils';
import ConvoLink from './ConvoLink';
import { cn } from '~/utils';
import store from '~/store';
interface ConversationProps {
conversation: TConversation;
retainView: () => void;
toggleNav: () => void;
isLatestConvo: boolean;
}
export default function Conversation({
conversation,
retainView,
toggleNav,
isLatestConvo,
}: ConversationProps) {
export default function Conversation({ conversation, retainView, toggleNav }: ConversationProps) {
const params = useParams();
const localize = useLocalize();
const { showToast } = useToastContext();
@@ -84,6 +78,7 @@ export default function Conversation({
});
setRenaming(false);
} catch (error) {
logger.error('Error renaming conversation', error);
setTitleInput(title as string);
showToast({
message: localize('com_ui_rename_failed'),

View File

@@ -173,6 +173,7 @@ const ContentRender = memo(
isSubmitting={isSubmitting}
searchResults={searchResults}
setSiblingIdx={setSiblingIdx}
isLatestMessage={isLatestMessage}
isCreatedByUser={msg.isCreatedByUser}
conversationId={conversation?.conversationId}
content={msg.content as Array<TMessageContentParts | undefined>}

View File

@@ -76,6 +76,8 @@ export default function Message(props: TMessageProps) {
messageId,
isExpanded: false,
conversationId: conversation?.conversationId,
isSubmitting: false, // Share view is always read-only
isLatestMessage: false, // No concept of latest message in share view
}}
>
{/* Legacy Plugins */}

View File

@@ -4,7 +4,6 @@ import { ChevronDown } from 'lucide-react';
import { useFormContext } from 'react-hook-form';
import { Constants } from 'librechat-data-provider';
import * as AccordionPrimitive from '@radix-ui/react-accordion';
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
import {
Label,
Checkbox,
@@ -14,20 +13,18 @@ import {
AccordionItem,
CircleHelpIcon,
OGDialogTrigger,
useToastContext,
AccordionContent,
OGDialogTemplate,
} from '@librechat/client';
import type { AgentForm, MCPServerInfo } from '~/common';
import { useLocalize, useMCPServerManager, useRemoveMCPTool } from '~/hooks';
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
import { useLocalize, useMCPServerManager } from '~/hooks';
import { cn } from '~/utils';
export default function MCPTool({ serverInfo }: { serverInfo?: MCPServerInfo }) {
const localize = useLocalize();
const { showToast } = useToastContext();
const updateUserPlugins = useUpdateUserPluginsMutation();
const { removeTool } = useRemoveMCPTool();
const { getValues, setValue } = useFormContext<AgentForm>();
const { getServerStatusIconProps, getConfigDialogProps } = useMCPServerManager();
@@ -56,36 +53,6 @@ export default function MCPTool({ serverInfo }: { serverInfo?: MCPServerInfo })
setValue('tools', [...otherTools, ...newSelectedTools]);
};
const removeTool = (serverName: string) => {
if (!serverName) {
return;
}
updateUserPlugins.mutate(
{
pluginKey: `${Constants.mcp_prefix}${serverName}`,
action: 'uninstall',
auth: {},
isEntityTool: true,
},
{
onError: (error: unknown) => {
showToast({ message: `Error while deleting the tool: ${error}`, status: 'error' });
},
onSuccess: () => {
const currentTools = getValues('tools');
const remainingToolIds =
currentTools?.filter(
(currentToolId) =>
currentToolId !== serverName &&
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
) || [];
setValue('tools', remainingToolIds);
showToast({ message: 'Tool deleted successfully', status: 'success' });
},
},
);
};
const selectedTools = getSelectedTools();
const isExpanded = accordionValue === currentServerName;

View File

@@ -1,25 +1,12 @@
import React, { useState } from 'react';
import { CircleX } from 'lucide-react';
import { useFormContext } from 'react-hook-form';
import { Constants } from 'librechat-data-provider';
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
import {
Label,
OGDialog,
TrashIcon,
useToastContext,
OGDialogTrigger,
OGDialogTemplate,
} from '@librechat/client';
import type { AgentForm } from '~/common';
import { useLocalize } from '~/hooks';
import { Label, OGDialog, TrashIcon, OGDialogTrigger, OGDialogTemplate } from '@librechat/client';
import { useLocalize, useRemoveMCPTool } from '~/hooks';
import { cn } from '~/utils';
export default function UnconfiguredMCPTool({ serverName }: { serverName?: string }) {
const localize = useLocalize();
const { showToast } = useToastContext();
const updateUserPlugins = useUpdateUserPluginsMutation();
const { getValues, setValue } = useFormContext<AgentForm>();
const { removeTool } = useRemoveMCPTool();
const [isFocused, setIsFocused] = useState(false);
const [isHovering, setIsHovering] = useState(false);
@@ -28,36 +15,6 @@ export default function UnconfiguredMCPTool({ serverName }: { serverName?: strin
return null;
}
const removeTool = () => {
updateUserPlugins.mutate(
{
pluginKey: `${Constants.mcp_prefix}${serverName}`,
action: 'uninstall',
auth: {},
isEntityTool: true,
},
{
onError: (error: unknown) => {
showToast({
message: localize('com_ui_delete_tool_error', { error: String(error) }),
status: 'error',
});
},
onSuccess: () => {
const currentTools = getValues('tools');
const remainingToolIds =
currentTools?.filter(
(currentToolId) =>
currentToolId !== serverName &&
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
) || [];
setValue('tools', remainingToolIds);
showToast({ message: localize('com_ui_delete_tool_success'), status: 'success' });
},
},
);
};
return (
<OGDialog>
<div
@@ -116,7 +73,7 @@ export default function UnconfiguredMCPTool({ serverName }: { serverName?: strin
</Label>
}
selection={{
selectHandler: () => removeTool(),
selectHandler: () => removeTool(serverName || ''),
selectClasses:
'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 transition-color duration-200 text-white',
selectText: localize('com_ui_delete'),

View File

@@ -1,29 +1,18 @@
import React, { useState } from 'react';
import { useFormContext } from 'react-hook-form';
import { Constants } from 'librechat-data-provider';
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
import {
Label,
OGDialog,
TrashIcon,
OGDialogTrigger,
useToastContext,
OGDialogTemplate,
} from '@librechat/client';
import type { AgentForm, MCPServerInfo } from '~/common';
import { Label, OGDialog, TrashIcon, OGDialogTrigger, OGDialogTemplate } from '@librechat/client';
import type { MCPServerInfo } from '~/common';
import { useLocalize, useMCPServerManager, useRemoveMCPTool } from '~/hooks';
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
import { useLocalize, useMCPServerManager } from '~/hooks';
import { cn } from '~/utils';
export default function UninitializedMCPTool({ serverInfo }: { serverInfo?: MCPServerInfo }) {
const localize = useLocalize();
const { removeTool } = useRemoveMCPTool();
const [isFocused, setIsFocused] = useState(false);
const [isHovering, setIsHovering] = useState(false);
const localize = useLocalize();
const { showToast } = useToastContext();
const updateUserPlugins = useUpdateUserPluginsMutation();
const { getValues, setValue } = useFormContext<AgentForm>();
const { initializeServer, isInitializing, getServerStatusIconProps, getConfigDialogProps } =
useMCPServerManager();
@@ -31,39 +20,6 @@ export default function UninitializedMCPTool({ serverInfo }: { serverInfo?: MCPS
return null;
}
const removeTool = (serverName: string) => {
if (!serverName) {
return;
}
updateUserPlugins.mutate(
{
pluginKey: `${Constants.mcp_prefix}${serverName}`,
action: 'uninstall',
auth: {},
isEntityTool: true,
},
{
onError: (error: unknown) => {
showToast({
message: localize('com_ui_delete_tool_error', { error: String(error) }),
status: 'error',
});
},
onSuccess: () => {
const currentTools = getValues('tools');
const remainingToolIds =
currentTools?.filter(
(currentToolId) =>
currentToolId !== serverName &&
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
) || [];
setValue('tools', remainingToolIds);
showToast({ message: localize('com_ui_delete_tool_success'), status: 'success' });
},
},
);
};
const serverName = serverInfo.serverName;
const isServerInitializing = isInitializing(serverName);
const statusIconProps = getServerStatusIconProps(serverName);

View File

@@ -1,4 +1,4 @@
import React, { useState, useMemo, useCallback } from 'react';
import React, { useState, useMemo, useCallback, useEffect } from 'react';
import { ChevronLeft, Trash2 } from 'lucide-react';
import { useQueryClient } from '@tanstack/react-query';
import { Button, useToastContext } from '@librechat/client';
@@ -12,6 +12,8 @@ import { useLocalize, useMCPConnectionStatus } from '~/hooks';
import { useGetStartupConfig } from '~/data-provider';
import MCPPanelSkeleton from './MCPPanelSkeleton';
const POLL_FOR_CONNECTION_STATUS_INTERVAL = 2_000; // ms
function MCPPanelContent() {
const localize = useLocalize();
const queryClient = useQueryClient();
@@ -26,6 +28,29 @@ function MCPPanelContent() {
null,
);
// Check if any connections are in 'connecting' state
const hasConnectingServers = useMemo(() => {
if (!connectionStatus) {
return false;
}
return Object.values(connectionStatus).some(
(status) => status?.connectionState === 'connecting',
);
}, [connectionStatus]);
// Set up polling when servers are connecting
useEffect(() => {
if (!hasConnectingServers) {
return;
}
const intervalId = setInterval(() => {
queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]);
}, POLL_FOR_CONNECTION_STATUS_INTERVAL);
return () => clearInterval(intervalId);
}, [hasConnectingServers, queryClient]);
const updateUserPluginsMutation = useUpdateUserPluginsMutation({
onSuccess: async () => {
showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' });

View File

@@ -0,0 +1,440 @@
import { renderHook } from '@testing-library/react';
import { Tools, Constants } from 'librechat-data-provider';
import useAgentToolPermissions from '../useAgentToolPermissions';
// Mock dependencies
jest.mock('~/data-provider', () => ({
useGetAgentByIdQuery: jest.fn(),
}));
jest.mock('~/Providers', () => ({
useAgentsMapContext: jest.fn(),
}));
// Import mocked functions after mocking
import { useGetAgentByIdQuery } from '~/data-provider';
import { useAgentsMapContext } from '~/Providers';
describe('useAgentToolPermissions', () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe('Ephemeral Agent Scenarios', () => {
it('should return true for all tools when agentId is null', () => {
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(null));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toBeUndefined();
});
it('should return true for all tools when agentId is undefined', () => {
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(undefined));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toBeUndefined();
});
it('should return true for all tools when agentId is empty string', () => {
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(''));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toBeUndefined();
});
it('should return true for all tools when agentId is EPHEMERAL_AGENT_ID', () => {
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() =>
useAgentToolPermissions(Constants.EPHEMERAL_AGENT_ID)
);
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toBeUndefined();
});
});
describe('Regular Agent with Tools', () => {
it('should allow file_search when agent has the tool', () => {
const agentId = 'agent-123';
const mockAgent = {
id: agentId,
tools: [Tools.file_search, 'other_tool'],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toEqual([Tools.file_search, 'other_tool']);
});
it('should allow execute_code when agent has the tool', () => {
const agentId = 'agent-456';
const mockAgent = {
id: agentId,
tools: [Tools.execute_code, 'another_tool'],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toEqual([Tools.execute_code, 'another_tool']);
});
it('should allow both tools when agent has both', () => {
const agentId = 'agent-789';
const mockAgent = {
id: agentId,
tools: [Tools.file_search, Tools.execute_code, 'custom_tool'],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toEqual([Tools.file_search, Tools.execute_code, 'custom_tool']);
});
it('should disallow both tools when agent has neither', () => {
const agentId = 'agent-no-tools';
const mockAgent = {
id: agentId,
tools: ['custom_tool1', 'custom_tool2'],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toEqual(['custom_tool1', 'custom_tool2']);
});
it('should handle agent with empty tools array', () => {
const agentId = 'agent-empty-tools';
const mockAgent = {
id: agentId,
tools: [],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toEqual([]);
});
it('should handle agent with undefined tools', () => {
const agentId = 'agent-undefined-tools';
const mockAgent = {
id: agentId,
tools: undefined,
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
});
describe('Agent Data from Query', () => {
it('should prioritize agentData tools over selectedAgent tools', () => {
const agentId = 'agent-with-query-data';
const mockAgent = {
id: agentId,
tools: ['old_tool'],
};
const mockAgentData = {
id: agentId,
tools: [Tools.file_search, Tools.execute_code],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: mockAgentData });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toEqual([Tools.file_search, Tools.execute_code]);
});
it('should fallback to selectedAgent tools when agentData has no tools', () => {
const agentId = 'agent-fallback';
const mockAgent = {
id: agentId,
tools: [Tools.file_search],
};
const mockAgentData = {
id: agentId,
tools: undefined,
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: mockAgentData });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toEqual([Tools.file_search]);
});
});
describe('Agent Not Found Scenarios', () => {
it('should disallow all tools when agent is not found in map', () => {
const agentId = 'non-existent-agent';
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
it('should disallow all tools when agentsMap is null', () => {
const agentId = 'agent-with-null-map';
(useAgentsMapContext as jest.Mock).mockReturnValue(null);
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
it('should disallow all tools when agentsMap is undefined', () => {
const agentId = 'agent-with-undefined-map';
(useAgentsMapContext as jest.Mock).mockReturnValue(undefined);
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
});
describe('Memoization and Performance', () => {
it('should memoize results when inputs do not change', () => {
const agentId = 'memoized-agent';
const mockAgent = {
id: agentId,
tools: [Tools.file_search],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result, rerender } = renderHook(() => useAgentToolPermissions(agentId));
const firstResult = result.current;
// Rerender without changing inputs
rerender();
const secondResult = result.current;
// The hook returns a new object each time, but the values should be equal
expect(firstResult.fileSearchAllowedByAgent).toBe(secondResult.fileSearchAllowedByAgent);
expect(firstResult.codeAllowedByAgent).toBe(secondResult.codeAllowedByAgent);
// Tools array reference should be the same since it comes from useMemo
expect(firstResult.tools).toBe(secondResult.tools);
// Verify the actual values are correct
expect(secondResult.fileSearchAllowedByAgent).toBe(true);
expect(secondResult.codeAllowedByAgent).toBe(false);
expect(secondResult.tools).toEqual([Tools.file_search]);
});
it('should recompute when agentId changes', () => {
const agentId1 = 'agent-1';
const agentId2 = 'agent-2';
const mockAgents = {
[agentId1]: { id: agentId1, tools: [Tools.file_search] },
[agentId2]: { id: agentId2, tools: [Tools.execute_code] },
};
(useAgentsMapContext as jest.Mock).mockReturnValue(mockAgents);
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result, rerender } = renderHook(
({ agentId }) => useAgentToolPermissions(agentId),
{ initialProps: { agentId: agentId1 } }
);
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(false);
// Change agentId
rerender({ agentId: agentId2 });
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(true);
});
it('should handle switching between ephemeral and regular agents', () => {
const regularAgentId = 'regular-agent';
const mockAgent = {
id: regularAgentId,
tools: [],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[regularAgentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result, rerender } = renderHook(
({ agentId }) => useAgentToolPermissions(agentId),
{ initialProps: { agentId: null } }
);
// Start with ephemeral agent (null)
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
// Switch to regular agent
rerender({ agentId: regularAgentId });
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
// Switch back to ephemeral
rerender({ agentId: '' });
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
});
});
describe('Edge Cases', () => {
it('should handle agents with null tools gracefully', () => {
const agentId = 'agent-null-tools';
const mockAgent = {
id: agentId,
tools: null as any,
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeNull();
});
it('should handle whitespace-only agentId as ephemeral', () => {
// Note: Based on the current implementation, only empty string is treated as ephemeral
// Whitespace-only strings would be treated as regular agent IDs
const whitespaceId = ' ';
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(whitespaceId));
// Whitespace ID is not considered ephemeral in current implementation
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
});
it('should handle query loading state', () => {
const agentId = 'loading-agent';
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({
data: undefined,
isLoading: true,
error: null,
});
const { result } = renderHook(() => useAgentToolPermissions(agentId));
// During loading, should return false for non-ephemeral agents
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
it('should handle query error state', () => {
const agentId = 'error-agent';
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({
data: undefined,
isLoading: false,
error: new Error('Failed to fetch agent'),
});
const { result } = renderHook(() => useAgentToolPermissions(agentId));
// On error, should return false for non-ephemeral agents
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
});
});

View File

@@ -1,5 +1,5 @@
import { useMemo } from 'react';
import { Tools } from 'librechat-data-provider';
import { Tools, Constants } from 'librechat-data-provider';
import { useGetAgentByIdQuery } from '~/data-provider';
import { useAgentsMapContext } from '~/Providers';
@@ -9,6 +9,10 @@ interface AgentToolPermissionsResult {
tools: string[] | undefined;
}
function isEphemeralAgent(agentId: string | null | undefined): boolean {
return agentId == null || agentId === '' || agentId === Constants.EPHEMERAL_AGENT_ID;
}
/**
* Hook to determine whether specific tools are allowed for a given agent.
*
@@ -33,8 +37,8 @@ export default function useAgentToolPermissions(
);
const fileSearchAllowedByAgent = useMemo(() => {
// If no agentId, allow for ephemeral agents
if (!agentId) return true;
// Allow for ephemeral agents
if (isEphemeralAgent(agentId)) return true;
// If agentId exists but agent not found, disallow
if (!selectedAgent) return false;
// Check if the agent has the file_search tool
@@ -42,8 +46,8 @@ export default function useAgentToolPermissions(
}, [agentId, selectedAgent, tools]);
const codeAllowedByAgent = useMemo(() => {
// If no agentId, allow for ephemeral agents
if (!agentId) return true;
// Allow for ephemeral agents
if (isEphemeralAgent(agentId)) return true;
// If agentId exists but agent not found, disallow
if (!selectedAgent) return false;
// Check if the agent has the execute_code tool

View File

@@ -0,0 +1,495 @@
import React from 'react';
import { Provider, createStore } from 'jotai';
import { renderHook, act, waitFor } from '@testing-library/react';
import { RecoilRoot, useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
import { ephemeralAgentByConvoId } from '~/store';
import { setTimestamp } from '~/utils/timestamps';
import { useMCPSelect } from '../useMCPSelect';
// Mock dependencies
jest.mock('~/utils/timestamps', () => ({
setTimestamp: jest.fn(),
}));
jest.mock('lodash/isEqual', () => jest.fn((a, b) => JSON.stringify(a) === JSON.stringify(b)));
const createWrapper = () => {
// Create a new Jotai store for each test to ensure clean state
const store = createStore();
const Wrapper: React.FC<{ children: React.ReactNode }> = ({ children }) => (
<RecoilRoot>
<Provider store={store}>{children}</Provider>
</RecoilRoot>
);
return Wrapper;
};
describe('useMCPSelect', () => {
beforeEach(() => {
jest.clearAllMocks();
localStorage.clear();
});
describe('Basic Functionality', () => {
it('should initialize with default values', () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
expect(result.current.isPinned).toBe(true); // Default value from mcpPinnedAtom is true
expect(typeof result.current.setMCPValues).toBe('function');
expect(typeof result.current.setIsPinned).toBe('function');
});
it('should use conversationId when provided', () => {
const conversationId = 'test-convo-123';
const { result } = renderHook(() => useMCPSelect({ conversationId }), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
});
it('should use NEW_CONVO constant when conversationId is null', () => {
const { result } = renderHook(() => useMCPSelect({ conversationId: null }), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
});
});
describe('State Updates', () => {
it('should update mcpValues when setMCPValues is called', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
const newValues = ['value1', 'value2'];
act(() => {
result.current.setMCPValues(newValues);
});
await waitFor(() => {
expect(result.current.mcpValues).toEqual(newValues);
});
});
it('should not update mcpValues if non-array is passed', () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
act(() => {
// @ts-ignore - Testing invalid input
result.current.setMCPValues('not-an-array');
});
expect(result.current.mcpValues).toEqual([]);
});
it('should update isPinned state', () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
// Default is true
expect(result.current.isPinned).toBe(true);
// Toggle to false
act(() => {
result.current.setIsPinned(false);
});
expect(result.current.isPinned).toBe(false);
// Toggle back to true
act(() => {
result.current.setIsPinned(true);
});
expect(result.current.isPinned).toBe(true);
});
});
describe('Timestamp Management', () => {
it('should set timestamp when mcpValues is updated with values', async () => {
const conversationId = 'test-convo';
const { result } = renderHook(() => useMCPSelect({ conversationId }), {
wrapper: createWrapper(),
});
const newValues = ['value1', 'value2'];
act(() => {
result.current.setMCPValues(newValues);
});
await waitFor(() => {
const expectedKey = `${LocalStorageKeys.LAST_MCP_}${conversationId}`;
expect(setTimestamp).toHaveBeenCalledWith(expectedKey);
});
});
it('should not set timestamp when mcpValues is empty', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
act(() => {
result.current.setMCPValues([]);
});
await waitFor(() => {
expect(setTimestamp).not.toHaveBeenCalled();
});
});
});
describe('Race Conditions and Infinite Loops Prevention', () => {
it('should not create infinite loop when syncing between Jotai and Recoil states', async () => {
const { result, rerender } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
let renderCount = 0;
const maxRenders = 10;
// Track renders to detect infinite loops
const trackRender = () => {
renderCount++;
if (renderCount > maxRenders) {
throw new Error('Potential infinite loop detected');
}
};
// Set initial value
act(() => {
trackRender();
result.current.setMCPValues(['initial']);
});
// Trigger multiple rerenders
for (let i = 0; i < 3; i++) {
rerender();
trackRender();
}
// Should not exceed max renders
expect(renderCount).toBeLessThanOrEqual(maxRenders);
});
it('should handle rapid consecutive updates without race conditions', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
const updates = [
['value1'],
['value1', 'value2'],
['value1', 'value2', 'value3'],
['value4'],
[],
];
// Rapid fire updates
act(() => {
updates.forEach((update) => {
result.current.setMCPValues(update);
});
});
await waitFor(() => {
// Should settle on the last update
expect(result.current.mcpValues).toEqual([]);
});
});
it('should maintain stable setter function reference', () => {
const { result, rerender } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
const firstSetMCPValues = result.current.setMCPValues;
// Trigger multiple rerenders
rerender();
rerender();
rerender();
// Setter should remain the same reference (memoized)
expect(result.current.setMCPValues).toBe(firstSetMCPValues);
});
it('should handle switching conversation IDs without issues', async () => {
const { result, rerender } = renderHook(
({ conversationId }) => useMCPSelect({ conversationId }),
{
wrapper: createWrapper(),
initialProps: { conversationId: 'convo1' },
},
);
// Set values for first conversation
act(() => {
result.current.setMCPValues(['convo1-value']);
});
await waitFor(() => {
expect(result.current.mcpValues).toEqual(['convo1-value']);
});
// Switch to different conversation
rerender({ conversationId: 'convo2' });
// Should have different state for new conversation
expect(result.current.mcpValues).toEqual([]);
// Set values for second conversation
act(() => {
result.current.setMCPValues(['convo2-value']);
});
await waitFor(() => {
expect(result.current.mcpValues).toEqual(['convo2-value']);
});
// Switch back to first conversation
rerender({ conversationId: 'convo1' });
// Should maintain separate state
await waitFor(() => {
expect(result.current.mcpValues).toEqual(['convo1-value']);
});
});
});
describe('Ephemeral Agent Synchronization', () => {
it('should sync mcpValues when ephemeralAgent is updated externally', async () => {
// Create a shared wrapper for both hooks to share the same Recoil/Jotai context
const wrapper = createWrapper();
// Create a component that uses both hooks to ensure they share state
const TestComponent = () => {
const mcpHook = useMCPSelect({});
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(
ephemeralAgentByConvoId(Constants.NEW_CONVO),
);
return { mcpHook, ephemeralAgent, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper });
// Simulate external update to ephemeralAgent (e.g., from another component)
const externalMcpValues = ['external-value1', 'external-value2'];
act(() => {
result.current.setEphemeralAgent({
mcp: externalMcpValues,
});
});
// The hook should sync with the external update
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(externalMcpValues);
});
});
it('should update ephemeralAgent when mcpValues changes through hook', async () => {
// Create a shared wrapper for both hooks
const wrapper = createWrapper();
// Create a component that uses both the hook and accesses Recoil state
const TestComponent = () => {
const mcpHook = useMCPSelect({});
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(Constants.NEW_CONVO));
return { mcpHook, ephemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper });
const newValues = ['hook-value1', 'hook-value2'];
act(() => {
result.current.mcpHook.setMCPValues(newValues);
});
// Verify both mcpValues and ephemeralAgent are updated
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(newValues);
expect(result.current.ephemeralAgent?.mcp).toEqual(newValues);
});
});
it('should handle empty ephemeralAgent.mcp array correctly', async () => {
// Create a shared wrapper
const wrapper = createWrapper();
// Create a component that uses both hooks
const TestComponent = () => {
const mcpHook = useMCPSelect({});
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
return { mcpHook, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper });
// Set initial values
act(() => {
result.current.mcpHook.setMCPValues(['initial-value']);
});
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
});
// Try to set empty array externally
act(() => {
result.current.setEphemeralAgent({
mcp: [],
});
});
// Values should remain unchanged since empty mcp array doesn't trigger update
// (due to the condition: ephemeralAgent?.mcp && ephemeralAgent.mcp.length > 0)
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
});
it('should properly sync non-empty arrays from ephemeralAgent', async () => {
// Additional test to ensure non-empty arrays DO sync
const wrapper = createWrapper();
const TestComponent = () => {
const mcpHook = useMCPSelect({});
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
return { mcpHook, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper });
// Set initial values through ephemeralAgent with non-empty array
const initialValues = ['value1', 'value2'];
act(() => {
result.current.setEphemeralAgent({
mcp: initialValues,
});
});
// Should sync since it's non-empty
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(initialValues);
});
// Update with different non-empty values
const updatedValues = ['value3', 'value4', 'value5'];
act(() => {
result.current.setEphemeralAgent({
mcp: updatedValues,
});
});
// Should sync the new values
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(updatedValues);
});
});
});
describe('Edge Cases', () => {
it('should handle undefined conversationId', () => {
const { result } = renderHook(() => useMCPSelect({ conversationId: undefined }), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
act(() => {
result.current.setMCPValues(['test']);
});
expect(() => result.current).not.toThrow();
});
it('should handle empty string conversationId', () => {
const { result } = renderHook(() => useMCPSelect({ conversationId: '' }), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
});
it('should handle very large arrays without performance issues', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
const largeArray = Array.from({ length: 1000 }, (_, i) => `value-${i}`);
const startTime = performance.now();
act(() => {
result.current.setMCPValues(largeArray);
});
const endTime = performance.now();
const executionTime = endTime - startTime;
// Should complete within reasonable time (< 100ms)
expect(executionTime).toBeLessThan(100);
await waitFor(() => {
expect(result.current.mcpValues).toEqual(largeArray);
});
});
it('should cleanup properly on unmount', () => {
const { unmount } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
// Should unmount without errors
expect(() => unmount()).not.toThrow();
});
});
describe('Memory Leak Prevention', () => {
it('should not leak memory on repeated updates', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
// Perform many updates to test for memory leaks
for (let i = 0; i < 100; i++) {
act(() => {
result.current.setMCPValues([`value-${i}`]);
});
}
// If we get here without crashing, memory management is likely OK
expect(result.current.mcpValues).toEqual(['value-99']);
});
it('should handle component remounting', () => {
const { result, unmount } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
act(() => {
result.current.setMCPValues(['before-unmount']);
});
unmount();
// Remount
const { result: newResult } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
// Should handle remounting gracefully
expect(newResult.current.mcpValues).toBeDefined();
});
});
});

View File

@@ -3,3 +3,4 @@ export * from './useMCPConnectionStatus';
export * from './useMCPSelect';
export * from './useVisibleTools';
export { useMCPServerManager } from './useMCPServerManager';
export { useRemoveMCPTool } from './useRemoveMCPTool';

View File

@@ -1,5 +1,6 @@
import { useCallback, useEffect } from 'react';
import { useAtom } from 'jotai';
import isEqual from 'lodash/isEqual';
import { useRecoilState } from 'recoil';
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
import { ephemeralAgentByConvoId, mcpValuesAtomFamily, mcpPinnedAtom } from '~/store';
@@ -19,15 +20,14 @@ export function useMCPSelect({ conversationId }: { conversationId?: string | nul
}
}, [ephemeralAgent?.mcp, setMCPValuesRaw]);
// Update ephemeral agent when Jotai state changes
useEffect(() => {
if (mcpValues.length > 0 && JSON.stringify(mcpValues) !== JSON.stringify(ephemeralAgent?.mcp)) {
setEphemeralAgent((prev) => ({
...prev,
mcp: mcpValues,
}));
}
}, [mcpValues, ephemeralAgent?.mcp, setEphemeralAgent]);
setEphemeralAgent((prev) => {
if (!isEqual(prev?.mcp, mcpValues)) {
return { ...(prev ?? {}), mcp: mcpValues };
}
return prev;
});
}, [mcpValues, setEphemeralAgent]);
useEffect(() => {
const mcpStorageKey = `${LocalStorageKeys.LAST_MCP_}${key}`;

View File

@@ -0,0 +1,61 @@
import { useCallback } from 'react';
import { useFormContext } from 'react-hook-form';
import { Constants } from 'librechat-data-provider';
import { useToastContext } from '@librechat/client';
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
import type { AgentForm } from '~/common';
import { useLocalize } from '~/hooks';
/**
* Hook for removing MCP tools/servers from an agent
* Provides unified logic for MCPTool, UninitializedMCPTool, and UnconfiguredMCPTool components
*/
export function useRemoveMCPTool() {
const localize = useLocalize();
const { showToast } = useToastContext();
const updateUserPlugins = useUpdateUserPluginsMutation();
const { getValues, setValue } = useFormContext<AgentForm>();
const removeTool = useCallback(
(serverName: string) => {
if (!serverName) {
return;
}
updateUserPlugins.mutate(
{
pluginKey: `${Constants.mcp_prefix}${serverName}`,
action: 'uninstall',
auth: {},
isEntityTool: true,
},
{
onError: (error: unknown) => {
showToast({
message: localize('com_ui_delete_tool_error', { error: String(error) }),
status: 'error',
});
},
onSuccess: () => {
const currentTools = getValues('tools');
const remainingToolIds =
currentTools?.filter(
(currentToolId) =>
currentToolId !== serverName &&
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
) || [];
setValue('tools', remainingToolIds, { shouldDirty: true });
showToast({
message: localize('com_ui_delete_tool_save_reminder'),
status: 'warning',
});
},
},
);
},
[getValues, setValue, updateUserPlugins, showToast, localize],
);
return { removeTool };
}

View File

@@ -2,7 +2,7 @@ import throttle from 'lodash/throttle';
import { useEffect, useRef, useCallback, useMemo } from 'react';
import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
import { useChatContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
import { useMessagesViewContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
import useCopyToClipboard from './useCopyToClipboard';
import { getTextKey, logger } from '~/utils';
@@ -20,9 +20,9 @@ export default function useMessageHelpers(props: TMessageProps) {
setAbortScroll,
handleContinue,
setLatestMessage,
} = useChatContext();
const assistantMap = useAssistantsMapContext();
} = useMessagesViewContext();
const agentsMap = useAgentsMapContext();
const assistantMap = useAssistantsMapContext();
const { text, content, children, messageId = null, isCreatedByUser } = message ?? {};
const edit = messageId === currentEditId;

View File

@@ -3,7 +3,7 @@ import { useRecoilValue } from 'recoil';
import { Constants } from 'librechat-data-provider';
import { useEffect, useRef, useCallback, useMemo, useState } from 'react';
import type { TMessage } from 'librechat-data-provider';
import { useChatContext, useAddedChatContext } from '~/Providers';
import { useMessagesViewContext } from '~/Providers';
import { getTextKey, logger } from '~/utils';
import store from '~/store';
@@ -18,14 +18,9 @@ export default function useMessageProcess({ message }: { message?: TMessage | nu
latestMessage,
setAbortScroll,
setLatestMessage,
isSubmitting: isSubmittingRoot,
} = useChatContext();
const { isSubmitting: isSubmittingAdditional } = useAddedChatContext();
isSubmittingFamily,
} = useMessagesViewContext();
const latestMultiMessage = useRecoilValue(store.latestMessageFamily(index + 1));
const isSubmittingFamily = useMemo(
() => isSubmittingRoot || isSubmittingAdditional,
[isSubmittingRoot, isSubmittingAdditional],
);
useEffect(() => {
const convoId = conversation?.conversationId;

View File

@@ -2,8 +2,8 @@ import { useRecoilValue } from 'recoil';
import { Constants } from 'librechat-data-provider';
import { useState, useRef, useCallback, useEffect } from 'react';
import type { TMessage } from 'librechat-data-provider';
import { useMessagesConversation, useMessagesSubmission } from '~/Providers';
import useScrollToRef from '~/hooks/useScrollToRef';
import { useChatContext } from '~/Providers';
import store from '~/store';
const threshold = 0.85;
@@ -15,8 +15,8 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
const scrollableRef = useRef<HTMLDivElement | null>(null);
const messagesEndRef = useRef<HTMLDivElement | null>(null);
const [showScrollButton, setShowScrollButton] = useState(false);
const { conversation, setAbortScroll, isSubmitting, abortScroll } = useChatContext();
const { conversationId } = conversation ?? {};
const { conversation, conversationId } = useMessagesConversation();
const { setAbortScroll, isSubmitting, abortScroll } = useMessagesSubmission();
const timeoutIdRef = useRef<NodeJS.Timeout>();

View File

@@ -833,7 +833,6 @@
"com_ui_delete_tool": "Werkzeug löschen",
"com_ui_delete_tool_confirm": "Bist du sicher, dass du dieses Werkzeug löschen möchtest?",
"com_ui_delete_tool_error": "Fehler beim Löschen des Tools: {{error}}",
"com_ui_delete_tool_success": "Tool erfolgreich gelöscht",
"com_ui_deleted": "Gelöscht",
"com_ui_deleting_file": "Lösche Datei...",
"com_ui_descending": "Absteigend",

View File

@@ -834,7 +834,7 @@
"com_ui_delete_tool": "Delete Tool",
"com_ui_delete_tool_confirm": "Are you sure you want to delete this tool?",
"com_ui_delete_tool_error": "Error while deleting the tool: {{error}}",
"com_ui_delete_tool_success": "Tool deleted successfully",
"com_ui_delete_tool_save_reminder": "Tool removed. Save the agent to apply changes.",
"com_ui_deleted": "Deleted",
"com_ui_deleting_file": "Deleting file...",
"com_ui_descending": "Desc",

View File

@@ -337,7 +337,7 @@
"com_endpoint_prompt_prefix_assistants": "Papildu instrukcijas",
"com_endpoint_prompt_prefix_assistants_placeholder": "Iestatiet papildu norādījumus vai kontekstu virs Asistenta galvenajiem norādījumiem. Ja lauks ir tukšs, tas tiek ignorēts.",
"com_endpoint_prompt_prefix_placeholder": "Iestatiet pielāgotas instrukcijas vai kontekstu. Ja lauks ir tukšs, tas tiek ignorēts.",
"com_endpoint_reasoning_effort": "Spriešanas piepūle",
"com_endpoint_reasoning_effort": "Spriešanas līmenis",
"com_endpoint_reasoning_summary": "Spriešanas kopsavilkums",
"com_endpoint_save_as_preset": "Saglabāt kā iestatījumu",
"com_endpoint_search": "Meklēt galapunktu pēc nosaukuma",
@@ -834,7 +834,7 @@
"com_ui_delete_tool": "Dzēst rīku",
"com_ui_delete_tool_confirm": "Vai tiešām vēlaties dzēst šo rīku?",
"com_ui_delete_tool_error": "Kļūda, dzēšot rīku: {{error}}",
"com_ui_delete_tool_success": "Rīks veiksmīgi izdzēsts",
"com_ui_delete_tool_save_reminder": "Rīks noņemts. Saglabājiet aģentu, lai piemērotu izmaiņas.",
"com_ui_deleted": "Dzēsts",
"com_ui_deleting_file": "Dzēšu failu...",
"com_ui_descending": "Dilstošs",
@@ -1012,7 +1012,7 @@
"com_ui_memory_would_exceed": "Nevar saglabāt - pārsniegtu tokenu limitu par {{tokens}}. Izdzēsiet esošās atmiņas, lai atbrīvotu vietu.",
"com_ui_mention": "Pieminiet galapunktu, assistentu vai iestatījumu, lai ātri uz to pārslēgtos",
"com_ui_min_tags": "Nevar noņemt vairāk vērtību, vismaz {{0}} ir nepieciešamas.",
"com_ui_minimal": "Minimāla",
"com_ui_minimal": "Minimāls",
"com_ui_misc": "Dažādi",
"com_ui_model": "Modelis",
"com_ui_model_parameters": "Modeļa Parametrus",
@@ -1035,7 +1035,7 @@
"com_ui_no_results_found": "Nav atrastu rezultātu",
"com_ui_no_terms_content": "Nav noteikumu un nosacījumu satura, ko parādīt",
"com_ui_no_valid_items": "Nav rezultātu",
"com_ui_none": "Neviens",
"com_ui_none": "Nekāds",
"com_ui_not_used": "Nav izmantots",
"com_ui_nothing_found": "Nekas nav atrasts",
"com_ui_oauth": "OAuth",

View File

@@ -834,7 +834,6 @@
"com_ui_delete_tool": "Slett verktøy",
"com_ui_delete_tool_confirm": "Er du sikker på at du vil slette dette verktøyet?",
"com_ui_delete_tool_error": "En feil oppstod ved sletting av verktøyet: {{error}}",
"com_ui_delete_tool_success": "Verktøyet ble slettet",
"com_ui_deleted": "Slettet",
"com_ui_deleting_file": "Sletter fil ...",
"com_ui_descending": "Synkende",

View File

@@ -27,6 +27,13 @@
"com_agents_file_search_disabled": "Maak eerst een Agent aan voordat je bestanden uploadt voor File Search.",
"com_agents_file_search_info": "Als deze functie is ingeschakeld, krijgt de agent informatie over de exacte bestandsnamen die hieronder staan vermeld, zodat deze relevante context uit deze bestanden kan ophalen.",
"com_agents_instructions_placeholder": "De systeeminstructies die de agent gebruikt",
"com_agents_link_copied": "Link gekopieerd",
"com_agents_link_copy_failed": "Link niet gekopieerd",
"com_agents_load_more_label": "Laad meer agenten van {{category}} categorie",
"com_agents_loading": "Aan het laden...",
"com_agents_mcp_icon_size": "Minimum formaat 128 x 128 px",
"com_agents_mcp_info": "MCP-servers toevoegen aan je agent zodat deze taken kan uitvoeren en kan communiceren met externe services",
"com_agents_mcp_name_placeholder": "Aangepast hulpmiddel",
"com_agents_missing_provider_model": "Selecteer een provider en model voordat je een agent aanmaakt.",
"com_agents_name_placeholder": "De naam van de agent",
"com_agents_no_access": "Je hebt geen toegang om deze agent te bewerken.",

View File

@@ -834,7 +834,6 @@
"com_ui_delete_tool": "Видалити інструмент",
"com_ui_delete_tool_confirm": "Ви дійсно хочете видалити цей інструмент?",
"com_ui_delete_tool_error": "Помилка під час видалення інструменту: {{error}}",
"com_ui_delete_tool_success": "Інструмент успішно видалено",
"com_ui_deleted": "Видалено",
"com_ui_deleting_file": "Видалення файлу...",
"com_ui_descending": "За спаданням",

View File

@@ -834,7 +834,6 @@
"com_ui_delete_tool": "删除工具",
"com_ui_delete_tool_confirm": "您确定要删除此工具吗?",
"com_ui_delete_tool_error": "删除工具时发生错误:{{error}}",
"com_ui_delete_tool_success": "工具删除成功",
"com_ui_deleted": "已删除",
"com_ui_deleting_file": "删除文件中...",
"com_ui_descending": "降序",

View File

@@ -49,7 +49,7 @@ export default defineConfig(({ command }) => ({
],
globIgnores: ['images/**/*', '**/*.map', 'index.html'],
maximumFileSizeToCacheInBytes: 4 * 1024 * 1024,
navigateFallbackDenylist: [/^\/oauth/, /^\/api/],
navigateFallbackDenylist: [/^\/oauth/, /^\/api/, /^\/admin\/openid/],
},
includeAssets: [],
manifest: {

View File

@@ -10,6 +10,9 @@ jest.mock('form-data', () => {
getLength: jest.fn().mockReturnValue(100),
}));
});
jest.mock('https-proxy-agent', () => ({
HttpsProxyAgent: jest.fn().mockImplementation((url) => ({ proxyUrl: url })),
}));
jest.mock('axios', () => {
const mockAxiosInstance = {
get: jest.fn().mockResolvedValue({ data: {} }),
@@ -44,6 +47,7 @@ jest.mock('~/utils/axios', () => ({
import * as fs from 'fs';
import axios from 'axios';
import { HttpsProxyAgent } from 'https-proxy-agent';
import type { Readable } from 'stream';
import type {
MistralFileUploadResponse,
@@ -1182,6 +1186,8 @@ describe('MistralOCR Service', () => {
describe('Mixed env var and hardcoded configuration', () => {
beforeEach(() => {
// Clean up any PROXY env var from previous tests
delete process.env.PROXY;
const mockReadStream: MockReadStream = {
on: jest.fn().mockImplementation(function (
this: MockReadStream,
@@ -1708,9 +1714,403 @@ describe('MistralOCR Service', () => {
});
});
describe('Proxy Configuration', () => {
const originalProxy = process.env.PROXY;
beforeEach(() => {
// Reset the HttpsProxyAgent mock to its default implementation
(HttpsProxyAgent as unknown as jest.Mock).mockImplementation((url) => ({ proxyUrl: url }));
// Clear any previous axios mock calls
mockAxios.post!.mockClear();
mockAxios.get!.mockClear();
mockAxios.delete!.mockClear();
});
afterEach(() => {
if (originalProxy) {
process.env.PROXY = originalProxy;
} else {
delete process.env.PROXY;
}
// Clear mocks after each test to prevent leaking
mockAxios.post!.mockClear();
mockAxios.get!.mockClear();
mockAxios.delete!.mockClear();
});
describe('uploadDocumentToMistral with proxy', () => {
beforeEach(() => {
const mockReadStream: MockReadStream = {
on: jest.fn().mockImplementation(function (
this: MockReadStream,
event: string,
handler: () => void,
) {
if (event === 'end') {
handler();
}
return this;
}),
pipe: jest.fn().mockImplementation(function (this: MockReadStream) {
return this;
}),
pause: jest.fn(),
resume: jest.fn(),
emit: jest.fn(),
once: jest.fn(),
destroy: jest.fn(),
path: '/path/to/test.pdf',
fd: 1,
flags: 'r',
mode: 0o666,
autoClose: true,
bytesRead: 0,
closed: false,
pending: false,
};
(jest.mocked(fs).createReadStream as jest.Mock).mockReturnValue(mockReadStream);
});
it('should use proxy configuration when PROXY env var is set', async () => {
process.env.PROXY = 'http://proxy.example.com:8080';
const mockResponse: { data: MistralFileUploadResponse } = {
data: {
id: 'file-proxy-123',
object: 'file',
bytes: 1024,
created_at: Date.now(),
filename: 'test.pdf',
purpose: 'ocr',
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://proxy.example.com:8080',
}),
}),
);
});
it('should handle proxy URL with authentication', async () => {
process.env.PROXY = 'http://user:pass@proxy.example.com:8080';
const mockResponse: { data: MistralFileUploadResponse } = {
data: {
id: 'file-proxy-auth-123',
object: 'file',
bytes: 1024,
created_at: Date.now(),
filename: 'test.pdf',
purpose: 'ocr',
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://user:pass@proxy.example.com:8080',
}),
}),
);
});
it('should handle IPv6 proxy addresses', async () => {
process.env.PROXY = 'http://[::1]:8080';
const mockResponse: { data: MistralFileUploadResponse } = {
data: {
id: 'file-proxy-ipv6-123',
object: 'file',
bytes: 1024,
created_at: Date.now(),
filename: 'test.pdf',
purpose: 'ocr',
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://[::1]:8080',
}),
}),
);
});
it('should not use proxy when PROXY env var is not set', async () => {
delete process.env.PROXY;
const mockResponse: { data: MistralFileUploadResponse } = {
data: {
id: 'file-no-proxy-123',
object: 'file',
bytes: 1024,
created_at: Date.now(),
filename: 'test.pdf',
purpose: 'ocr',
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.not.objectContaining({
httpsAgent: expect.anything(),
}),
);
});
});
describe('performOCR with proxy', () => {
it('should use proxy configuration when PROXY env var is set', async () => {
process.env.PROXY = 'http://proxy.example.com:3128';
const mockResponse: { data: OCRResult } = {
data: {
model: 'mistral-ocr-latest',
pages: [
{
index: 0,
markdown: 'Proxy test content',
images: [],
dimensions: { dpi: 300, height: 1100, width: 850 },
},
],
document_annotation: '',
usage_info: {
pages_processed: 1,
doc_size_bytes: 1024,
},
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await performOCR({
apiKey: 'test-api-key',
url: 'https://document-url.com',
model: 'mistral-ocr-latest',
documentType: 'document_url',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://proxy.example.com:3128',
}),
}),
);
});
it('should handle malformed proxy URLs gracefully', async () => {
(HttpsProxyAgent as unknown as jest.Mock).mockImplementationOnce(() => {
throw new Error('Invalid URL');
});
process.env.PROXY = 'not-a-valid-url';
const mockResponse: { data: OCRResult } = {
data: {
model: 'mistral-ocr-latest',
pages: [
{
index: 0,
markdown: 'Test content',
images: [],
dimensions: { dpi: 300, height: 1100, width: 850 },
},
],
document_annotation: '',
usage_info: {
pages_processed: 1,
doc_size_bytes: 1024,
},
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await expect(
performOCR({
apiKey: 'test-api-key',
url: 'https://document-url.com',
}),
).rejects.toThrow('Invalid URL');
});
});
describe('Azure Mistral OCR with proxy', () => {
beforeEach(() => {
(jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue(
Buffer.from('mock-file-content'),
);
});
it('should use proxy for Azure Mistral OCR requests', async () => {
process.env.PROXY = 'http://proxy.example.com:8080';
mockLoadAuthValues.mockResolvedValue({
OCR_API_KEY: 'azure-api-key',
OCR_BASEURL: 'https://azure.mistral.ai/v1',
});
mockAxios.post!.mockResolvedValueOnce({
data: {
model: 'mistral-ocr-latest',
pages: [
{
index: 0,
markdown: 'Azure OCR with proxy',
images: [],
dimensions: { dpi: 300, height: 1100, width: 850 },
},
],
document_annotation: '',
usage_info: {
pages_processed: 1,
doc_size_bytes: 1024,
},
},
});
const req = {
user: { id: 'user123' },
config: {
ocr: {
apiKey: '${OCR_API_KEY}',
baseURL: '${OCR_BASEURL}',
mistralModel: 'mistral-ocr-latest',
},
},
} as unknown as ServerRequest;
const file = {
path: '/tmp/upload/azure-file.pdf',
originalname: 'azure-document.pdf',
mimetype: 'application/pdf',
} as Express.Multer.File;
await uploadAzureMistralOCR({
req,
file,
loadAuthValues: mockLoadAuthValues,
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://azure.mistral.ai/v1/ocr',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://proxy.example.com:8080',
}),
}),
);
});
});
describe('getSignedUrl with proxy', () => {
it('should use proxy configuration when PROXY env var is set', async () => {
process.env.PROXY = 'https://secure-proxy.example.com:443';
const mockResponse: { data: MistralSignedUrlResponse } = {
data: {
url: 'https://signed-url.com',
expires_at: Date.now() + 86400000,
},
};
mockAxios.get!.mockResolvedValueOnce(mockResponse);
await getSignedUrl({
fileId: 'file-123',
apiKey: 'test-api-key',
});
expect(mockAxios.get).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'https://secure-proxy.example.com:443',
}),
}),
);
});
});
describe('deleteMistralFile with proxy', () => {
it('should use proxy configuration when PROXY env var is set', async () => {
process.env.PROXY = 'socks5://proxy.example.com:1080';
mockAxios.delete!.mockResolvedValueOnce({ data: {} });
await deleteMistralFile({
fileId: 'file-123',
apiKey: 'test-api-key',
});
expect(mockAxios.delete).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files/file-123',
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'socks5://proxy.example.com:1080',
}),
}),
);
});
});
});
describe('uploadAzureMistralOCR', () => {
beforeEach(() => {
(jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue(Buffer.from('mock-file-content'));
// Reset the HttpsProxyAgent mock to its default implementation for Azure tests
(HttpsProxyAgent as unknown as jest.Mock).mockImplementation((url) => ({ proxyUrl: url }));
// Clean up any PROXY env var from previous tests
delete process.env.PROXY;
// Reset axios mocks completely to clear any queued responses
mockAxios.post!.mockReset();
mockAxios.get!.mockReset();
mockAxios.delete!.mockReset();
// Re-establish default resolved values
mockAxios.post!.mockResolvedValue({ data: {} });
mockAxios.get!.mockResolvedValue({ data: {} });
mockAxios.delete!.mockResolvedValue({ data: {} });
});
it('should process OCR using Azure Mistral with base64 encoding', async () => {
@@ -1796,6 +2196,11 @@ describe('MistralOCR Service', () => {
});
describe('Mixed env var and hardcoded configuration', () => {
beforeEach(() => {
// Clean up any PROXY env var from previous tests
delete process.env.PROXY;
});
it('should preserve hardcoded baseURL when only apiKey is an env var', async () => {
// This test demonstrates the current bug
mockLoadAuthValues.mockResolvedValue({

View File

@@ -2,6 +2,7 @@ import * as fs from 'fs';
import * as path from 'path';
import FormData from 'form-data';
import { logger } from '@librechat/data-schemas';
import { HttpsProxyAgent } from 'https-proxy-agent';
import {
FileSources,
envVarRegex,
@@ -9,7 +10,7 @@ import {
extractVariableName,
} from 'librechat-data-provider';
import type { TCustomConfig } from 'librechat-data-provider';
import type { AxiosError } from 'axios';
import type { AxiosError, AxiosRequestConfig } from 'axios';
import type {
MistralFileUploadResponse,
MistralSignedUrlResponse,
@@ -77,15 +78,21 @@ export async function uploadDocumentToMistral({
const fileStream = fs.createReadStream(filePath);
form.append('file', fileStream, { filename: actualFileName });
const config: AxiosRequestConfig = {
headers: {
Authorization: `Bearer ${apiKey}`,
...form.getHeaders(),
},
maxBodyLength: Infinity,
maxContentLength: Infinity,
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
return axios
.post(`${baseURL}/files`, form, {
headers: {
Authorization: `Bearer ${apiKey}`,
...form.getHeaders(),
},
maxBodyLength: Infinity,
maxContentLength: Infinity,
})
.post(`${baseURL}/files`, form, config)
.then((res) => res.data)
.catch((error) => {
throw error;
@@ -103,12 +110,18 @@ export async function getSignedUrl({
expiry?: number;
baseURL?: string;
}): Promise<MistralSignedUrlResponse> {
const config: AxiosRequestConfig = {
headers: {
Authorization: `Bearer ${apiKey}`,
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
return axios
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
})
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, config)
.then((res) => res.data)
.catch((error) => {
logger.error('Error fetching signed URL:', error.message);
@@ -139,6 +152,18 @@ export async function performOCR({
documentType?: 'document_url' | 'image_url';
}): Promise<OCRResult> {
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
const config: AxiosRequestConfig = {
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
return axios
.post(
`${baseURL}/ocr`,
@@ -151,12 +176,7 @@ export async function performOCR({
[documentKey]: url,
},
},
{
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
},
},
config,
)
.then((res) => res.data)
.catch((error) => {
@@ -182,12 +202,18 @@ export async function deleteMistralFile({
apiKey: string;
baseURL?: string;
}): Promise<void> {
const config: AxiosRequestConfig = {
headers: {
Authorization: `Bearer ${apiKey}`,
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
try {
const result = await axios.delete(`${baseURL}/files/${fileId}`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
});
const result = await axios.delete(`${baseURL}/files/${fileId}`, config);
logger.debug(`Mistral file ${fileId} deleted successfully:`, result.data);
} catch (error) {
logger.error(`Error deleting Mistral file ${fileId}:`, error);
@@ -543,17 +569,23 @@ async function createJWT(serviceKey: GoogleServiceAccount): Promise<string> {
* Exchanges JWT for access token
*/
async function exchangeJWTForAccessToken(jwt: string): Promise<string> {
const config: AxiosRequestConfig = {
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
const response = await axios.post(
'https://oauth2.googleapis.com/token',
new URLSearchParams({
grant_type: 'urn:ietf:params:oauth:grant-type:jwt-bearer',
assertion: jwt,
}),
{
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
},
config,
);
if (!response.data?.access_token) {
@@ -608,14 +640,20 @@ async function performGoogleVertexOCR({
},
});
const config: AxiosRequestConfig = {
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${accessToken}`,
Accept: 'application/json',
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
return axios
.post(baseURL, requestBody, {
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${accessToken}`,
Accept: 'application/json',
},
})
.post(baseURL, requestBody, config)
.then((res) => {
logger.debug('Google Vertex AI response received');
return res.data;

View File

@@ -14,6 +14,7 @@ export * from './utils';
export * from './db/utils';
/* OAuth */
export * from './oauth';
export * from './mcp/oauth/OAuthReconnectionManager';
/* Crypto */
export * from './crypto';
/* Flow */

View File

@@ -6,6 +6,7 @@ import type { FlowStateManager } from '~/flow/manager';
import type { FlowMetadata } from '~/flow/types';
import type * as t from './types';
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
import { sanitizeUrlForLogging } from './utils';
import { MCPConnection } from './connection';
import { processMCPEnv } from '~/utils';
@@ -308,7 +309,9 @@ export class MCPConnectionFactory {
metadata?: OAuthMetadata;
} | null> {
const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url;
logger.debug(`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`);
logger.debug(
`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
);
if (!this.flowManager || !serverUrl) {
logger.error(

View File

@@ -7,6 +7,7 @@ import type { JsonSchemaType } from '~/types';
import type * as t from '~/mcp/types';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { sanitizeUrlForLogging } from '~/mcp/utils';
import { processMCPEnv } from '~/utils';
import { CONSTANTS } from '~/mcp/enum';
@@ -183,7 +184,7 @@ export class MCPServersRegistry {
const prefix = this.prefix(serverName);
const config = this.parsedConfigs[serverName];
logger.info(`${prefix} -------------------------------------------------┐`);
logger.info(`${prefix} URL: ${config.url}`);
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
logger.info(`${prefix} Tools: ${config.tools}`);

View File

@@ -1,15 +1,15 @@
import { EventEmitter } from 'events';
import { logger } from '@librechat/data-schemas';
import { fetch as undiciFetch, Agent } from 'undici';
import {
StdioClientTransport,
getDefaultEnvironment,
} from '@modelcontextprotocol/sdk/client/stdio.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { logger } from '@librechat/data-schemas';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import type {
@@ -18,8 +18,9 @@ import type {
Response as UndiciResponse,
} from 'undici';
import type { MCPOAuthTokens } from './oauth/types';
import { mcpConfig } from './mcpConfig';
import type * as t from './types';
import { sanitizeUrlForLogging } from './utils';
import { mcpConfig } from './mcpConfig';
type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;
@@ -238,7 +239,9 @@ export class MCPConnection extends EventEmitter {
}
this.url = options.url;
const url = new URL(options.url);
logger.info(`${this.getLogPrefix()} Creating SSE transport: ${url.toString()}`);
logger.info(
`${this.getLogPrefix()} Creating SSE transport: ${sanitizeUrlForLogging(url)}`,
);
const abortController = new AbortController();
/** Add OAuth token to headers if available */
@@ -293,7 +296,7 @@ export class MCPConnection extends EventEmitter {
this.url = options.url;
const url = new URL(options.url);
logger.info(
`${this.getLogPrefix()} Creating streamable-http transport: ${url.toString()}`,
`${this.getLogPrefix()} Creating streamable-http transport: ${sanitizeUrlForLogging(url)}`,
);
const abortController = new AbortController();
@@ -473,7 +476,9 @@ export class MCPConnection extends EventEmitter {
logger.warn(`${this.getLogPrefix()} OAuth authentication required`);
this.oauthRequired = true;
const serverUrl = this.url;
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
logger.debug(
`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
);
const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
/** Promise that will resolve when OAuth is handled */

View File

@@ -0,0 +1,294 @@
import { TokenMethods } from '@librechat/data-schemas';
import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..';
import { MCPManager } from '../MCPManager';
import { OAuthReconnectionManager } from './OAuthReconnectionManager';
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
jest.mock('../MCPManager');
describe('OAuthReconnectionManager', () => {
let flowManager: jest.Mocked<FlowStateManager<null>>;
let tokenMethods: jest.Mocked<TokenMethods>;
let mockMCPManager: jest.Mocked<MCPManager>;
let reconnectionManager: OAuthReconnectionManager;
beforeEach(() => {
jest.clearAllMocks();
// Reset singleton instance
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(OAuthReconnectionManager as any).instance = null;
// Setup mock flow manager
flowManager = {
createFlow: jest.fn(),
completeFlow: jest.fn(),
failFlow: jest.fn(),
deleteFlow: jest.fn(),
getFlow: jest.fn(),
} as unknown as jest.Mocked<FlowStateManager<null>>;
// Setup mock token methods
tokenMethods = {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteToken: jest.fn(),
} as unknown as jest.Mocked<TokenMethods>;
// Setup mock MCP Manager
mockMCPManager = {
getOAuthServers: jest.fn(),
getUserConnection: jest.fn(),
getUserConnections: jest.fn(),
disconnectUserConnection: jest.fn(),
getRawConfig: jest.fn(),
} as unknown as jest.Mocked<MCPManager>;
(MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager);
});
afterEach(() => {
jest.clearAllMocks();
});
describe('Singleton Pattern', () => {
it('should create instance successfully', async () => {
const instance = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
expect(instance).toBeInstanceOf(OAuthReconnectionManager);
});
it('should throw error when creating instance twice', async () => {
await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
await expect(
OAuthReconnectionManager.createInstance(flowManager, tokenMethods),
).rejects.toThrow('OAuthReconnectionManager already initialized');
});
it('should throw error when getting instance before creation', () => {
expect(() => OAuthReconnectionManager.getInstance()).toThrow(
'OAuthReconnectionManager not initialized',
);
});
});
describe('isReconnecting', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should return true when server is actively reconnecting', () => {
const userId = 'user-123';
const serverName = 'test-server';
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
reconnectionTracker.setActive(userId, serverName);
const result = reconnectionManager.isReconnecting(userId, serverName);
expect(result).toBe(true);
});
it('should return false when server is not reconnecting', () => {
const userId = 'user-123';
const serverName = 'test-server';
const result = reconnectionManager.isReconnecting(userId, serverName);
expect(result).toBe(false);
});
});
describe('clearReconnection', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should clear both failed and active reconnection states', () => {
const userId = 'user-123';
const serverName = 'test-server';
reconnectionTracker.setFailed(userId, serverName);
reconnectionTracker.setActive(userId, serverName);
reconnectionManager.clearReconnection(userId, serverName);
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
expect(reconnectionTracker.isFailed(userId, serverName)).toBe(false);
expect(reconnectionTracker.isActive(userId, serverName)).toBe(false);
});
});
describe('reconnectServers', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should reconnect eligible servers', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1', 'server2', 'server3']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
// server1: has failed reconnection
reconnectionTracker.setFailed(userId, 'server1');
// server2: already connected
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
};
const userConnections = new Map([['server2', mockConnection]]);
mockMCPManager.getUserConnections.mockReturnValue(
userConnections as unknown as Map<string, MCPConnection>,
);
// server3: has valid token and not connected
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server3') {
return {
userId,
identifier,
expiresAt: new Date(Date.now() + 3600000), // 1 hour from now
} as unknown as MCPOAuthTokens;
}
return null;
});
// Mock successful reconnection
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
// Verify server3 was marked as active
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(true);
// Wait for async tryReconnect to complete
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify reconnection was attempted for server3
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith({
serverName: 'server3',
user: { id: userId },
flowManager,
tokenMethods,
forceNew: false,
connectionTimeout: 5000,
returnOnOAuth: true,
});
// Verify successful reconnection cleared the states
expect(reconnectionTracker.isFailed(userId, 'server3')).toBe(false);
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(false);
});
it('should handle failed reconnection attempts', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
// server1: has valid token
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() + 3600000),
} as unknown as MCPOAuthTokens);
// Mock failed connection
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
// Wait for async tryReconnect to complete
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify failure handling
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
it('should not reconnect servers with expired tokens', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
// server1: has expired token
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
} as unknown as MCPOAuthTokens);
await reconnectionManager.reconnectServers(userId);
// Verify no reconnection attempt was made
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
});
it('should handle connection that returns but is not connected', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() + 3600000),
} as unknown as MCPOAuthTokens);
// Mock connection that returns but is not connected
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(false),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockConnection as unknown as MCPConnection,
);
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
// Wait for async tryReconnect to complete
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify failure handling
expect(mockConnection.disconnect).toHaveBeenCalled();
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
});
});

View File

@@ -0,0 +1,163 @@
import { logger } from '@librechat/data-schemas';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { MCPOAuthTokens } from './types';
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
import { FlowStateManager } from '~/flow/manager';
import { MCPManager } from '~/mcp/MCPManager';
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
export class OAuthReconnectionManager {
private static instance: OAuthReconnectionManager | null = null;
protected readonly flowManager: FlowStateManager<MCPOAuthTokens | null>;
protected readonly tokenMethods: TokenMethods;
private readonly reconnectionsTracker: OAuthReconnectionTracker;
public static getInstance(): OAuthReconnectionManager {
if (!OAuthReconnectionManager.instance) {
throw new Error('OAuthReconnectionManager not initialized');
}
return OAuthReconnectionManager.instance;
}
public static async createInstance(
flowManager: FlowStateManager<MCPOAuthTokens | null>,
tokenMethods: TokenMethods,
reconnections?: OAuthReconnectionTracker,
): Promise<OAuthReconnectionManager> {
if (OAuthReconnectionManager.instance != null) {
throw new Error('OAuthReconnectionManager already initialized');
}
const manager = new OAuthReconnectionManager(flowManager, tokenMethods, reconnections);
OAuthReconnectionManager.instance = manager;
return manager;
}
public constructor(
flowManager: FlowStateManager<MCPOAuthTokens | null>,
tokenMethods: TokenMethods,
reconnections?: OAuthReconnectionTracker,
) {
this.flowManager = flowManager;
this.tokenMethods = tokenMethods;
this.reconnectionsTracker = reconnections ?? new OAuthReconnectionTracker();
}
public isReconnecting(userId: string, serverName: string): boolean {
return this.reconnectionsTracker.isActive(userId, serverName);
}
public async reconnectServers(userId: string) {
const mcpManager = MCPManager.getInstance();
// 1. derive the servers to reconnect
const serversToReconnect = [];
for (const serverName of mcpManager.getOAuthServers() ?? []) {
const canReconnect = await this.canReconnect(userId, serverName);
if (canReconnect) {
serversToReconnect.push(serverName);
}
}
// 2. mark the servers as reconnecting
for (const serverName of serversToReconnect) {
this.reconnectionsTracker.setActive(userId, serverName);
}
// 3. attempt to reconnect the servers
for (const serverName of serversToReconnect) {
void this.tryReconnect(userId, serverName);
}
}
public clearReconnection(userId: string, serverName: string) {
this.reconnectionsTracker.removeFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
}
private async tryReconnect(userId: string, serverName: string) {
const mcpManager = MCPManager.getInstance();
const logPrefix = `[tryReconnectOAuthMCPServer][User: ${userId}][${serverName}]`;
logger.info(`${logPrefix} Attempting reconnection`);
const config = mcpManager.getRawConfig(serverName);
const cleanupOnFailedReconnect = () => {
this.reconnectionsTracker.setFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
mcpManager.disconnectUserConnection(userId, serverName);
};
try {
// attempt to get connection (this will use existing tokens and refresh if needed)
const connection = await mcpManager.getUserConnection({
serverName,
user: { id: userId } as TUser,
flowManager: this.flowManager,
tokenMethods: this.tokenMethods,
// don't force new connection, let it reuse existing or create new as needed
forceNew: false,
// set a reasonable timeout for reconnection attempts
connectionTimeout: config?.initTimeout ?? DEFAULT_CONNECTION_TIMEOUT_MS,
// don't trigger OAuth flow during reconnection
returnOnOAuth: true,
});
if (connection && (await connection.isConnected())) {
logger.info(`${logPrefix} Successfully reconnected`);
this.clearReconnection(userId, serverName);
} else {
logger.warn(`${logPrefix} Failed to reconnect`);
await connection?.disconnect();
cleanupOnFailedReconnect();
}
} catch (error) {
logger.warn(`${logPrefix} Failed to reconnect: ${error}`);
cleanupOnFailedReconnect();
}
}
private async canReconnect(userId: string, serverName: string) {
const mcpManager = MCPManager.getInstance();
// if the server has failed reconnection, don't attempt to reconnect
if (this.reconnectionsTracker.isFailed(userId, serverName)) {
return false;
}
// if the server is already connected, don't attempt to reconnect
const existingConnections = mcpManager.getUserConnections(userId);
if (existingConnections?.has(serverName)) {
const isConnected = await existingConnections.get(serverName)?.isConnected();
if (isConnected) {
return false;
}
}
// if the server has no tokens for the user, don't attempt to reconnect
const accessToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}`,
});
if (accessToken == null) {
return false;
}
// if the token has expired, don't attempt to reconnect
const now = new Date();
if (accessToken.expiresAt && accessToken.expiresAt < now) {
return false;
}
// …otherwise, we're good to go with the reconnect attempt
return true;
}
}

View File

@@ -0,0 +1,181 @@
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
describe('OAuthReconnectTracker', () => {
let tracker: OAuthReconnectionTracker;
const userId = 'user123';
const serverName = 'test-server';
const anotherServer = 'another-server';
beforeEach(() => {
tracker = new OAuthReconnectionTracker();
});
describe('setFailed', () => {
it('should record a failed reconnection attempt', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
});
it('should track multiple servers for the same user', () => {
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, anotherServer);
expect(tracker.isFailed(userId, serverName)).toBe(true);
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
});
it('should track different users independently', () => {
const anotherUserId = 'user456';
tracker.setFailed(userId, serverName);
tracker.setFailed(anotherUserId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
expect(tracker.isFailed(anotherUserId, serverName)).toBe(true);
});
});
describe('isFailed', () => {
it('should return false when no failed attempt is recorded', () => {
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should return true after a failed attempt is recorded', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
});
it('should return false for a different server even after another server failed', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, anotherServer)).toBe(false);
});
});
describe('removeFailed', () => {
it('should clear a failed reconnect record', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
tracker.removeFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should only clear the specific server for the user', () => {
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, anotherServer);
tracker.removeFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(false);
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
});
it('should handle clearing non-existent records gracefully', () => {
expect(() => tracker.removeFailed(userId, serverName)).not.toThrow();
});
});
describe('setActive', () => {
it('should mark a server as reconnecting', () => {
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
});
it('should track multiple reconnecting servers', () => {
tracker.setActive(userId, serverName);
tracker.setActive(userId, anotherServer);
expect(tracker.isActive(userId, serverName)).toBe(true);
expect(tracker.isActive(userId, anotherServer)).toBe(true);
});
});
describe('isActive', () => {
it('should return false when server is not reconnecting', () => {
expect(tracker.isActive(userId, serverName)).toBe(false);
});
it('should return true when server is marked as reconnecting', () => {
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
});
it('should handle non-existent user gracefully', () => {
expect(tracker.isActive('non-existent-user', serverName)).toBe(false);
});
});
describe('removeActive', () => {
it('should clear reconnecting state for a server', () => {
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
tracker.removeActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
});
it('should only clear specific server state', () => {
tracker.setActive(userId, serverName);
tracker.setActive(userId, anotherServer);
tracker.removeActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
expect(tracker.isActive(userId, anotherServer)).toBe(true);
});
it('should handle clearing non-existent state gracefully', () => {
expect(() => tracker.removeActive(userId, serverName)).not.toThrow();
});
});
describe('cleanup behavior', () => {
it('should clean up empty user sets for failed reconnects', () => {
tracker.setFailed(userId, serverName);
tracker.removeFailed(userId, serverName);
// Record and clear another user to ensure internal cleanup
const anotherUserId = 'user456';
tracker.setFailed(anotherUserId, serverName);
// Original user should still be able to reconnect
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should clean up empty user sets for active reconnections', () => {
tracker.setActive(userId, serverName);
tracker.removeActive(userId, serverName);
// Mark another user to ensure internal cleanup
const anotherUserId = 'user456';
tracker.setActive(anotherUserId, serverName);
// Original user should not be reconnecting
expect(tracker.isActive(userId, serverName)).toBe(false);
});
});
describe('combined state management', () => {
it('should handle both failed and reconnecting states independently', () => {
// Mark as reconnecting
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
expect(tracker.isFailed(userId, serverName)).toBe(false);
// Record failed attempt
tracker.setFailed(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
expect(tracker.isFailed(userId, serverName)).toBe(true);
// Clear reconnecting state
tracker.removeActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
expect(tracker.isFailed(userId, serverName)).toBe(true);
// Clear failed state
tracker.removeFailed(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
});
});

View File

@@ -0,0 +1,46 @@
export class OAuthReconnectionTracker {
// Map of userId -> Set of serverNames that have failed reconnection
private failed: Map<string, Set<string>> = new Map();
// Map of userId -> Set of serverNames that are actively reconnecting
private active: Map<string, Set<string>> = new Map();
public isFailed(userId: string, serverName: string): boolean {
return this.failed.get(userId)?.has(serverName) ?? false;
}
public isActive(userId: string, serverName: string): boolean {
return this.active.get(userId)?.has(serverName) ?? false;
}
public setFailed(userId: string, serverName: string): void {
if (!this.failed.has(userId)) {
this.failed.set(userId, new Set());
}
this.failed.get(userId)?.add(serverName);
}
public setActive(userId: string, serverName: string): void {
if (!this.active.has(userId)) {
this.active.set(userId, new Set());
}
this.active.get(userId)?.add(serverName);
}
public removeFailed(userId: string, serverName: string): void {
const userServers = this.failed.get(userId);
userServers?.delete(serverName);
if (userServers?.size === 0) {
this.failed.delete(userId);
}
}
public removeActive(userId: string, serverName: string): void {
const userServers = this.active.get(userId);
userServers?.delete(serverName);
if (userServers?.size === 0) {
this.active.delete(userId);
}
}
}

View File

@@ -17,6 +17,7 @@ import type {
MCPOAuthTokens,
OAuthMetadata,
} from './types';
import { sanitizeUrlForLogging } from '~/mcp/utils';
/** Type for the OAuth metadata from the SDK */
type SDKOAuthMetadata = Parameters<typeof registerClient>[1]['metadata'];
@@ -33,7 +34,9 @@ export class MCPOAuthHandler {
resourceMetadata?: OAuthProtectedResourceMetadata;
authServerUrl: URL;
}> {
logger.debug(`[MCPOAuth] discoverMetadata called with serverUrl: ${serverUrl}`);
logger.debug(
`[MCPOAuth] discoverMetadata called with serverUrl: ${sanitizeUrlForLogging(serverUrl)}`,
);
let authServerUrl = new URL(serverUrl);
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
@@ -60,11 +63,15 @@ export class MCPOAuthHandler {
}
// Discover OAuth metadata
logger.debug(`[MCPOAuth] Discovering OAuth metadata from ${authServerUrl}`);
logger.debug(
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
);
const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl);
if (!rawMetadata) {
logger.error(`[MCPOAuth] Failed to discover OAuth metadata from ${authServerUrl}`);
logger.error(
`[MCPOAuth] Failed to discover OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
);
throw new Error('Failed to discover OAuth metadata');
}
@@ -88,12 +95,15 @@ export class MCPOAuthHandler {
resourceMetadata?: OAuthProtectedResourceMetadata,
redirectUri?: string,
): Promise<OAuthClientInformation> {
logger.debug(`[MCPOAuth] Starting client registration for ${serverUrl}, server metadata:`, {
grant_types_supported: metadata.grant_types_supported,
response_types_supported: metadata.response_types_supported,
token_endpoint_auth_methods_supported: metadata.token_endpoint_auth_methods_supported,
scopes_supported: metadata.scopes_supported,
});
logger.debug(
`[MCPOAuth] Starting client registration for ${sanitizeUrlForLogging(serverUrl)}, server metadata:`,
{
grant_types_supported: metadata.grant_types_supported,
response_types_supported: metadata.response_types_supported,
token_endpoint_auth_methods_supported: metadata.token_endpoint_auth_methods_supported,
scopes_supported: metadata.scopes_supported,
},
);
/** Client metadata based on what the server supports */
const clientMetadata = {
@@ -114,7 +124,9 @@ export class MCPOAuthHandler {
`[MCPOAuth] Server ${serverUrl} supports \`refresh_token\` grant type, adding to request`,
);
} else {
logger.debug(`[MCPOAuth] Server ${serverUrl} does not support \`refresh_token\` grant type`);
logger.debug(
`[MCPOAuth] Server ${sanitizeUrlForLogging(serverUrl)} does not support \`refresh_token\` grant type`,
);
}
clientMetadata.grant_types = requestedGrantTypes;
@@ -139,19 +151,25 @@ export class MCPOAuthHandler {
clientMetadata.scope = availableScopes.join(' ');
}
logger.debug(`[MCPOAuth] Registering client for ${serverUrl} with metadata:`, clientMetadata);
logger.debug(
`[MCPOAuth] Registering client for ${sanitizeUrlForLogging(serverUrl)} with metadata:`,
clientMetadata,
);
const clientInfo = await registerClient(serverUrl, {
metadata: metadata as unknown as SDKOAuthMetadata,
clientMetadata,
});
logger.debug(`[MCPOAuth] Client registered successfully for ${serverUrl}:`, {
client_id: clientInfo.client_id,
has_client_secret: !!clientInfo.client_secret,
grant_types: clientInfo.grant_types,
scope: clientInfo.scope,
});
logger.debug(
`[MCPOAuth] Client registered successfully for ${sanitizeUrlForLogging(serverUrl)}:`,
{
client_id: clientInfo.client_id,
has_client_secret: !!clientInfo.client_secret,
grant_types: clientInfo.grant_types,
scope: clientInfo.scope,
},
);
return clientInfo;
}
@@ -165,7 +183,9 @@ export class MCPOAuthHandler {
userId: string,
config: MCPOptions['oauth'] | undefined,
): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> {
logger.debug(`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${serverUrl}`);
logger.debug(
`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`,
);
const flowId = this.generateFlowId(userId, serverName);
const state = this.generateState();
@@ -226,7 +246,9 @@ export class MCPOAuthHandler {
metadata,
};
logger.debug(`[MCPOAuth] Authorization URL generated: ${authorizationUrl.toString()}`);
logger.debug(
`[MCPOAuth] Authorization URL generated: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
);
return {
authorizationUrl: authorizationUrl.toString(),
flowId,
@@ -234,10 +256,14 @@ export class MCPOAuthHandler {
};
}
logger.debug(`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${serverUrl}`);
logger.debug(
`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${sanitizeUrlForLogging(serverUrl)}`,
);
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(serverUrl);
logger.debug(`[MCPOAuth] OAuth metadata discovered, auth server URL: ${authServerUrl}`);
logger.debug(
`[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`,
);
/** Dynamic client registration based on the discovered metadata */
const redirectUri = config?.redirect_uri || this.getDefaultRedirectUri(serverName);
@@ -276,7 +302,9 @@ export class MCPOAuthHandler {
codeVerifier = authResult.codeVerifier;
logger.debug(`[MCPOAuth] startAuthorization completed successfully`);
logger.debug(`[MCPOAuth] Authorization URL: ${authorizationUrl.toString()}`);
logger.debug(
`[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
);
/** Add state parameter with flowId to the authorization URL */
authorizationUrl.searchParams.set('state', flowId);
@@ -515,7 +543,7 @@ export class MCPOAuthHandler {
body.append('client_id', metadata.clientInfo.client_id);
}
logger.debug(`[MCPOAuth] Refresh request to: ${tokenUrl}`, {
logger.debug(`[MCPOAuth] Refresh request to: ${sanitizeUrlForLogging(tokenUrl)}`, {
body: body.toString(),
headers,
});
@@ -695,7 +723,9 @@ export class MCPOAuthHandler {
}
// perform the revoke request
logger.info(`[MCPOAuth] Revoking tokens for ${serverName} via ${revokeUrl.toString()}`);
logger.info(
`[MCPOAuth] Revoking tokens for ${serverName} via ${sanitizeUrlForLogging(revokeUrl.toString())}`,
);
const response = await fetch(revokeUrl, {
method: 'POST',
body: body.toString(),

View File

@@ -31,3 +31,17 @@ export function normalizeServerName(serverName: string): string {
return normalized;
}
/**
* Sanitizes a URL by removing query parameters to prevent credential leakage in logs.
* @param url - The URL to sanitize (string or URL object)
* @returns The sanitized URL string without query parameters
*/
export function sanitizeUrlForLogging(url: string | URL): string {
try {
const urlObj = typeof url === 'string' ? new URL(url) : url;
return `${urlObj.protocol}//${urlObj.host}${urlObj.pathname}`;
} catch {
return '[invalid URL]';
}
}

View File

@@ -98,12 +98,6 @@ if (typeof window !== 'undefined') {
if (originalRequest.url?.includes('/api/auth/logout') === true) {
return Promise.reject(error);
}
if (originalRequest.url?.includes('/api/auth/refresh') === true) {
// Refresh token itself failed - redirect to login
console.log('Refresh token request failed, redirecting to login...');
window.location.href = '/login';
return Promise.reject(error);
}
if (error.response.status === 401 && !originalRequest._retry) {
console.warn('401 error, refreshing token');
@@ -124,7 +118,10 @@ if (typeof window !== 'undefined') {
isRefreshing = true;
try {
const response = await refreshToken();
const response = await refreshToken(
// Handle edge case where we get a blank screen if the initial 401 error is from a refresh token request
originalRequest.url?.includes('api/auth/refresh') === true ? true : false,
);
const token = response?.token ?? '';