Compare commits
18 Commits
chart-1.9.
...
feat/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
278590d0bb | ||
|
|
41a4674469 | ||
|
|
e7a9cf88ac | ||
|
|
f6925f906b | ||
|
|
e90fd1df15 | ||
|
|
a1f9f3dd39 | ||
|
|
fbe0def2fa | ||
|
|
d04da60b3b | ||
|
|
0e94d97bfb | ||
|
|
45ab4d4503 | ||
|
|
0ceef12eea | ||
|
|
6738360051 | ||
|
|
52b65492d5 | ||
|
|
7a9a99d2a0 | ||
|
|
5bfb06b417 | ||
|
|
2ce8f1f686 | ||
|
|
1a47601533 | ||
|
|
5245aeea8f |
@@ -1,6 +1,6 @@
|
||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||
const { EventSource } = require('eventsource');
|
||||
const { Time } = require('librechat-data-provider');
|
||||
const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api');
|
||||
const logger = require('./winston');
|
||||
|
||||
global.EventSource = EventSource;
|
||||
@@ -26,4 +26,6 @@ module.exports = {
|
||||
createMCPManager: MCPManager.createInstance,
|
||||
getMCPManager: MCPManager.getInstance,
|
||||
getFlowStateManager,
|
||||
createOAuthReconnectionManager: OAuthReconnectionManager.createInstance,
|
||||
getOAuthReconnectionManager: OAuthReconnectionManager.getInstance,
|
||||
};
|
||||
|
||||
@@ -11,8 +11,9 @@ const {
|
||||
registerUser,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||
const { getOAuthReconnectionManager } = require('~/config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
@@ -96,14 +97,25 @@ const refreshController = async (req, res) => {
|
||||
return res.status(200).send({ token, user });
|
||||
}
|
||||
|
||||
// Find the session with the hashed refresh token
|
||||
const session = await findSession({
|
||||
userId: userId,
|
||||
refreshToken: refreshToken,
|
||||
});
|
||||
/** Session with the hashed refresh token */
|
||||
const session = await findSession(
|
||||
{
|
||||
userId: userId,
|
||||
refreshToken: refreshToken,
|
||||
},
|
||||
{ lean: false },
|
||||
);
|
||||
|
||||
if (session && session.expiration > new Date()) {
|
||||
const token = await setAuthTokens(userId, res, session._id);
|
||||
const token = await setAuthTokens(userId, res, session);
|
||||
|
||||
// trigger OAuth MCP server reconnection asynchronously (best effort)
|
||||
void getOAuthReconnectionManager()
|
||||
.reconnectServers(userId)
|
||||
.catch((err) => {
|
||||
logger.error('Error reconnecting OAuth MCP servers:', err);
|
||||
});
|
||||
|
||||
res.status(200).send({ token, user });
|
||||
} else if (req?.query?.retry) {
|
||||
// Retrying from a refresh token request that failed (401)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generate2FATempToken } = require('~/server/services/twoFactorService');
|
||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const loginController = async (req, res) => {
|
||||
try {
|
||||
|
||||
50
api/server/controllers/auth/oauth.js
Normal file
50
api/server/controllers/auth/oauth.js
Normal file
@@ -0,0 +1,50 @@
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
|
||||
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
|
||||
const { checkBan } = require('~/server/middleware');
|
||||
|
||||
const domains = {
|
||||
client: process.env.DOMAIN_CLIENT,
|
||||
server: process.env.DOMAIN_SERVER,
|
||||
};
|
||||
|
||||
function createOAuthHandler(redirectUri = domains.client) {
|
||||
/**
|
||||
* A handler to process OAuth authentication results.
|
||||
* @type {Function}
|
||||
* @param {ServerRequest} req - Express request object.
|
||||
* @param {ServerResponse} res - Express response object.
|
||||
* @param {NextFunction} next - Express next middleware function.
|
||||
*/
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
if (res.headersSent) {
|
||||
return;
|
||||
}
|
||||
|
||||
await checkBan(req, res);
|
||||
if (req.banned) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
req.user &&
|
||||
req.user.provider == 'openid' &&
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
}
|
||||
res.redirect(redirectUri);
|
||||
} catch (err) {
|
||||
logger.error('Error in setting authentication tokens:', err);
|
||||
next(err);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
createOAuthHandler,
|
||||
};
|
||||
@@ -12,6 +12,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { isEnabled, ErrorController } = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const { updateInterfacePermissions } = require('~/models/interface');
|
||||
@@ -108,6 +109,7 @@ const startServer = async () => {
|
||||
app.use('/oauth', routes.oauth);
|
||||
/* API Endpoints */
|
||||
app.use('/api/auth', routes.auth);
|
||||
app.use('/api/admin', routes.adminAuth);
|
||||
app.use('/api/actions', routes.actions);
|
||||
app.use('/api/keys', routes.keys);
|
||||
app.use('/api/user', routes.user);
|
||||
@@ -154,7 +156,7 @@ const startServer = async () => {
|
||||
res.send(updatedIndexHtml);
|
||||
});
|
||||
|
||||
app.listen(port, host, () => {
|
||||
app.listen(port, host, async () => {
|
||||
if (host === '0.0.0.0') {
|
||||
logger.info(
|
||||
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
|
||||
@@ -163,7 +165,9 @@ const startServer = async () => {
|
||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
initializeMCPs().then(() => checkMigrations());
|
||||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
await checkMigrations();
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ const checkInviteUser = require('./checkInviteUser');
|
||||
const requireJwtAuth = require('./requireJwtAuth');
|
||||
const configMiddleware = require('./config/app');
|
||||
const validateModel = require('./validateModel');
|
||||
const requireAdmin = require('./requireAdmin');
|
||||
const moderateText = require('./moderateText');
|
||||
const logHeaders = require('./logHeaders');
|
||||
const setHeaders = require('./setHeaders');
|
||||
@@ -36,6 +37,7 @@ module.exports = {
|
||||
setHeaders,
|
||||
logHeaders,
|
||||
moderateText,
|
||||
requireAdmin,
|
||||
validateModel,
|
||||
requireJwtAuth,
|
||||
checkInviteUser,
|
||||
|
||||
22
api/server/middleware/requireAdmin.js
Normal file
22
api/server/middleware/requireAdmin.js
Normal file
@@ -0,0 +1,22 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Middleware to check if authenticated user has admin role
|
||||
* Should be used AFTER authentication middleware (requireJwtAuth, requireLocalAuth, etc.)
|
||||
*/
|
||||
const requireAdmin = (req, res, next) => {
|
||||
if (!req.user) {
|
||||
logger.warn('[requireAdmin] No user found in request');
|
||||
return res.status(401).json({ message: 'Authentication required' });
|
||||
}
|
||||
|
||||
if (!req.user.role || req.user.role !== SystemRoles.ADMIN) {
|
||||
logger.debug('[requireAdmin] Access denied for non-admin user:', req.user.email);
|
||||
return res.status(403).json({ message: 'Access denied: Admin privileges required' });
|
||||
}
|
||||
|
||||
next();
|
||||
};
|
||||
|
||||
module.exports = requireAdmin;
|
||||
@@ -1,6 +1,6 @@
|
||||
const passport = require('passport');
|
||||
const cookies = require('cookie');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const passport = require('passport');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse
|
||||
|
||||
66
api/server/routes/admin/auth.js
Normal file
66
api/server/routes/admin/auth.js
Normal file
@@ -0,0 +1,66 @@
|
||||
const express = require('express');
|
||||
const passport = require('passport');
|
||||
const { randomState } = require('openid-client');
|
||||
const { createSetBalanceConfig } = require('@librechat/api');
|
||||
const { loginController } = require('~/server/controllers/auth/LoginController');
|
||||
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const middleware = require('~/server/middleware');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
const setBalanceConfig = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
});
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post(
|
||||
'/login/local',
|
||||
middleware.logHeaders,
|
||||
middleware.loginLimiter,
|
||||
middleware.checkBan,
|
||||
middleware.requireLocalAuth,
|
||||
middleware.requireAdmin,
|
||||
setBalanceConfig,
|
||||
loginController,
|
||||
);
|
||||
|
||||
router.get('/verify', middleware.requireJwtAuth, middleware.requireAdmin, (req, res) => {
|
||||
const { password: _p, totpSecret: _t, __v, ...user } = req.user;
|
||||
user.id = user._id.toString();
|
||||
res.status(200).json({ user });
|
||||
});
|
||||
|
||||
router.get('/oauth/openid/check', (req, res) => {
|
||||
const openidConfig = getOpenIdConfig();
|
||||
if (!openidConfig) {
|
||||
return res.status(404).json({ message: 'OpenID configuration not found' });
|
||||
}
|
||||
res.status(200).json({ message: 'OpenID check successful' });
|
||||
});
|
||||
|
||||
router.get('/oauth/openid', (req, res, next) => {
|
||||
return passport.authenticate('openidAdmin', {
|
||||
session: false,
|
||||
state: randomState(),
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.get(
|
||||
'/oauth/openid/callback',
|
||||
passport.authenticate('openidAdmin', {
|
||||
failureRedirect: `${process.env.DOMAIN_CLIENT}/oauth/error`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
middleware.requireAdmin,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(
|
||||
(process.env.ADMIN_PANEL_URL || 'http://localhost:3000') + '/auth/openid/callback',
|
||||
),
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,6 +1,7 @@
|
||||
const accessPermissions = require('./accessPermissions');
|
||||
const assistants = require('./assistants');
|
||||
const categories = require('./categories');
|
||||
const adminAuth = require('./admin/auth');
|
||||
const tokenizer = require('./tokenizer');
|
||||
const endpoints = require('./endpoints');
|
||||
const staticRoute = require('./static');
|
||||
@@ -32,6 +33,7 @@ module.exports = {
|
||||
mcp,
|
||||
edit,
|
||||
auth,
|
||||
adminAuth,
|
||||
keys,
|
||||
user,
|
||||
tags,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
const { Router } = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { findPluginAuthsByKeys } = require('~/models');
|
||||
@@ -144,6 +144,10 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
|
||||
);
|
||||
|
||||
// clear any reconnection attempts
|
||||
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||
oauthReconnectionManager.clearReconnection(flowState.userId, serverName);
|
||||
|
||||
const tools = await userConnection.fetchTools();
|
||||
await updateMCPUserTools({
|
||||
userId: flowState.userId,
|
||||
|
||||
@@ -4,10 +4,9 @@ const passport = require('passport');
|
||||
const { randomState } = require('openid-client');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, createSetBalanceConfig } = require('@librechat/api');
|
||||
const { checkDomainAllowed, loginLimiter, logHeaders, checkBan } = require('~/server/middleware');
|
||||
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
|
||||
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
|
||||
const { createSetBalanceConfig } = require('@librechat/api');
|
||||
const { checkDomainAllowed, loginLimiter, logHeaders } = require('~/server/middleware');
|
||||
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
@@ -26,32 +25,7 @@ const domains = {
|
||||
router.use(logHeaders);
|
||||
router.use(loginLimiter);
|
||||
|
||||
const oauthHandler = async (req, res, next) => {
|
||||
try {
|
||||
if (res.headersSent) {
|
||||
return;
|
||||
}
|
||||
|
||||
await checkBan(req, res);
|
||||
if (req.banned) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
req.user &&
|
||||
req.user.provider == 'openid' &&
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
}
|
||||
res.redirect(domains.client);
|
||||
} catch (err) {
|
||||
logger.error('Error in setting authentication tokens:', err);
|
||||
next(err);
|
||||
}
|
||||
};
|
||||
const oauthHandler = createOAuthHandler();
|
||||
|
||||
router.get('/error', (req, res) => {
|
||||
/** A single error message is pushed by passport when authentication fails. */
|
||||
|
||||
@@ -357,23 +357,18 @@ const resetPassword = async (userId, token, password) => {
|
||||
|
||||
/**
|
||||
* Set Auth Tokens
|
||||
*
|
||||
* @param {String | ObjectId} userId
|
||||
* @param {Object} res
|
||||
* @param {String} sessionId
|
||||
* @param {ServerResponse} res
|
||||
* @param {ISession | null} [session=null]
|
||||
* @returns
|
||||
*/
|
||||
const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
const setAuthTokens = async (userId, res, _session = null) => {
|
||||
try {
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
|
||||
let session;
|
||||
let session = _session;
|
||||
let refreshToken;
|
||||
let refreshTokenExpires;
|
||||
|
||||
if (sessionId) {
|
||||
session = await findSession({ sessionId: sessionId }, { lean: false });
|
||||
if (session && session._id && session.expiration != null) {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
refreshToken = await generateRefreshToken(session);
|
||||
} else {
|
||||
@@ -383,6 +378,9 @@ const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
}
|
||||
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
|
||||
res.cookie('refreshToken', refreshToken, {
|
||||
expires: new Date(refreshTokenExpires),
|
||||
httpOnly: true,
|
||||
|
||||
@@ -20,8 +20,8 @@ const {
|
||||
ContentTypes,
|
||||
isAssistantsEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getCachedTools, getAppConfig } = require('./Config');
|
||||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getLogStores } = require('~/cache');
|
||||
@@ -538,13 +538,20 @@ async function getServerConnectionStatus(
|
||||
const baseConnectionState = getConnectionState();
|
||||
let finalConnectionState = baseConnectionState;
|
||||
|
||||
// connection state overrides specific to OAuth servers
|
||||
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
// check if server is actively being reconnected
|
||||
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
|
||||
finalConnectionState = 'connecting';
|
||||
} else {
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
finalConnectionState = 'connecting';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ jest.mock('./Config', () => ({
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(),
|
||||
getFlowStateManager: jest.fn(),
|
||||
getOAuthReconnectionManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
@@ -48,6 +49,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
let mockGetMCPManager;
|
||||
let mockGetFlowStateManager;
|
||||
let mockGetLogStores;
|
||||
let mockGetOAuthReconnectionManager;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
@@ -56,6 +58,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
mockGetMCPManager = require('~/config').getMCPManager;
|
||||
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||
mockGetLogStores = require('~/cache').getLogStores;
|
||||
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
|
||||
});
|
||||
|
||||
describe('getMCPSetupData', () => {
|
||||
@@ -354,6 +357,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
@@ -370,6 +379,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return failed flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
@@ -401,6 +416,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return active flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
@@ -432,6 +453,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return no flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => null),
|
||||
@@ -454,6 +481,35 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
});
|
||||
});
|
||||
|
||||
it('should return connecting state when OAuth server is reconnecting', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager to return true for isReconnecting
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => true),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: true,
|
||||
connectionState: 'connecting',
|
||||
});
|
||||
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not check OAuth flow status when server is connected', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(),
|
||||
|
||||
@@ -313,7 +313,7 @@ const ensurePrincipalExists = async function (principal) {
|
||||
idOnTheSource: principal.idOnTheSource,
|
||||
};
|
||||
|
||||
const userId = await createUser(userData, true, false);
|
||||
const userId = await createUser(userData, true, true);
|
||||
return userId.toString();
|
||||
}
|
||||
|
||||
|
||||
26
api/server/services/initializeOAuthReconnectManager.js
Normal file
26
api/server/services/initializeOAuthReconnectManager.js
Normal file
@@ -0,0 +1,26 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* Initialize OAuth reconnect manager
|
||||
*/
|
||||
async function initializeOAuthReconnectManager() {
|
||||
try {
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
const tokenMethods = {
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
deleteTokens,
|
||||
};
|
||||
await createOAuthReconnectionManager(flowManager, tokenMethods);
|
||||
logger.info(`OAuth reconnect manager initialized successfully.`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to initialize OAuth reconnect manager:', error);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = initializeOAuthReconnectManager;
|
||||
@@ -1,14 +1,14 @@
|
||||
const appleLogin = require('./appleStrategy');
|
||||
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
|
||||
const openIdJwtLogin = require('./openIdJwtStrategy');
|
||||
const facebookLogin = require('./facebookStrategy');
|
||||
const discordLogin = require('./discordStrategy');
|
||||
const passportLogin = require('./localStrategy');
|
||||
const googleLogin = require('./googleStrategy');
|
||||
const githubLogin = require('./githubStrategy');
|
||||
const discordLogin = require('./discordStrategy');
|
||||
const facebookLogin = require('./facebookStrategy');
|
||||
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
|
||||
const jwtLogin = require('./jwtStrategy');
|
||||
const ldapLogin = require('./ldapStrategy');
|
||||
const { setupSaml } = require('./samlStrategy');
|
||||
const openIdJwtLogin = require('./openIdJwtStrategy');
|
||||
const appleLogin = require('./appleStrategy');
|
||||
const ldapLogin = require('./ldapStrategy');
|
||||
const jwtLogin = require('./jwtStrategy');
|
||||
|
||||
module.exports = {
|
||||
appleLogin,
|
||||
|
||||
@@ -281,6 +281,221 @@ function convertToUsername(input, defaultValue = '') {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process OpenID authentication tokenset and userinfo
|
||||
* This is the core logic extracted from the passport strategy callback
|
||||
* Can be reused by both the passport strategy and proxy authentication
|
||||
*
|
||||
* @param {Object} tokenset - The OpenID tokenset containing access_token, id_token, etc.
|
||||
* @param {boolean} existingUsersOnly - If true, only existing users will be processed
|
||||
* @returns {Promise<Object>} The authenticated user object with tokenset
|
||||
*/
|
||||
async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
|
||||
const claims = tokenset.claims ? tokenset.claims() : tokenset;
|
||||
const userinfo = {
|
||||
...claims,
|
||||
};
|
||||
|
||||
// Get userinfo from provider if we have access_token and haven't already
|
||||
if (tokenset.access_token) {
|
||||
const providerUserinfo = await getUserInfo(openidConfig, tokenset.access_token, claims.sub);
|
||||
Object.assign(userinfo, providerUserinfo);
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[OpenID Auth] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
|
||||
);
|
||||
throw new Error('Email domain not allowed');
|
||||
}
|
||||
|
||||
const result = await findOpenIDUser({
|
||||
openidId: claims.sub || userinfo.sub,
|
||||
email: claims.email || userinfo.email,
|
||||
strategyName: 'openidStrategy',
|
||||
findUser,
|
||||
});
|
||||
let user = result.user;
|
||||
const error = result.error;
|
||||
|
||||
if (error) {
|
||||
throw new Error(ErrorTypes.AUTH_FAILED);
|
||||
}
|
||||
|
||||
const fullName = getFullName(userinfo);
|
||||
|
||||
/** Required role if configured */
|
||||
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
|
||||
if (requiredRole) {
|
||||
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
|
||||
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
|
||||
|
||||
let decodedToken = '';
|
||||
if (requiredRoleTokenKind === 'access' && tokenset.access_token) {
|
||||
decodedToken = jwtDecode(tokenset.access_token);
|
||||
} else if (requiredRoleTokenKind === 'id' && tokenset.id_token) {
|
||||
decodedToken = jwtDecode(tokenset.id_token);
|
||||
} else if (userinfo.roles) {
|
||||
// If roles are already in userinfo, use them directly
|
||||
const roles = Array.isArray(userinfo.roles) ? userinfo.roles : [userinfo.roles];
|
||||
if (!roles.includes(requiredRole)) {
|
||||
throw new Error(`You must have the "${requiredRole}" role to log in.`);
|
||||
}
|
||||
} else if (requiredRoleParameterPath) {
|
||||
const pathParts = requiredRoleParameterPath.split('.');
|
||||
let found = true;
|
||||
let roles = pathParts.reduce((o, key) => {
|
||||
if (o === null || o === undefined || !(key in o)) {
|
||||
found = false;
|
||||
return [];
|
||||
}
|
||||
return o[key];
|
||||
}, decodedToken);
|
||||
|
||||
if (!found) {
|
||||
logger.error(
|
||||
`[OpenID Auth] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!roles.includes(requiredRole)) {
|
||||
throw new Error(`You must have the "${requiredRole}" role to log in.`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let username = '';
|
||||
if (process.env.OPENID_USERNAME_CLAIM) {
|
||||
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
|
||||
} else {
|
||||
username = convertToUsername(
|
||||
userinfo.preferred_username || userinfo.username || userinfo.email,
|
||||
);
|
||||
}
|
||||
|
||||
if (existingUsersOnly && !user) {
|
||||
throw new Error('User does not exist');
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
user = {
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username,
|
||||
email: userinfo.email || '',
|
||||
emailVerified: userinfo.email_verified || false,
|
||||
name: fullName,
|
||||
idOnTheSource: userinfo.oid,
|
||||
};
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
user = await createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
user.provider = 'openid';
|
||||
user.openidId = userinfo.sub;
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
user.idOnTheSource = userinfo.oid;
|
||||
}
|
||||
|
||||
// Handle avatar
|
||||
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
|
||||
const imageUrl = userinfo.picture;
|
||||
let fileName;
|
||||
if (crypto) {
|
||||
fileName = (await hashToken(userinfo.sub)) + '.png';
|
||||
} else {
|
||||
fileName = userinfo.sub + '.png';
|
||||
}
|
||||
|
||||
const imageBuffer = await downloadImage(
|
||||
imageUrl,
|
||||
openidConfig,
|
||||
tokenset.access_token,
|
||||
userinfo.sub,
|
||||
);
|
||||
if (imageBuffer) {
|
||||
const { saveBuffer } = getStrategyFunctions(
|
||||
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
|
||||
);
|
||||
const imagePath = await saveBuffer({
|
||||
fileName,
|
||||
userId: user._id.toString(),
|
||||
buffer: imageBuffer,
|
||||
});
|
||||
user.avatar = imagePath ?? '';
|
||||
}
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
|
||||
logger.info(
|
||||
`[OpenID Auth] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username}`,
|
||||
{
|
||||
user: {
|
||||
openidId: user.openidId,
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
return { ...user, tokenset };
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {boolean | undefined} [existingUsersOnly]
|
||||
*/
|
||||
function createOpenIDCallback(existingUsersOnly) {
|
||||
return async (tokenset, done) => {
|
||||
try {
|
||||
const user = await processOpenIDAuth(tokenset, existingUsersOnly);
|
||||
done(null, user);
|
||||
} catch (err) {
|
||||
if (err.message === 'Email domain not allowed') {
|
||||
return done(null, false, { message: err.message });
|
||||
}
|
||||
if (err.message === ErrorTypes.AUTH_FAILED) {
|
||||
return done(null, false, { message: err.message });
|
||||
}
|
||||
if (err.message && err.message.includes('role to log in')) {
|
||||
return done(null, false, { message: err.message });
|
||||
}
|
||||
logger.error('[openidStrategy] login failed', err);
|
||||
done(err);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up the OpenID strategy specifically for admin authentication.
|
||||
* @param {Configuration} openidConfig
|
||||
*/
|
||||
const setupOpenIdAdmin = (openidConfig) => {
|
||||
try {
|
||||
if (!openidConfig) {
|
||||
throw new Error('OpenID configuration not initialized');
|
||||
}
|
||||
|
||||
const openidAdminLogin = new CustomOpenIDStrategy(
|
||||
{
|
||||
config: openidConfig,
|
||||
scope: process.env.OPENID_SCOPE,
|
||||
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
|
||||
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
|
||||
callbackURL: process.env.DOMAIN_SERVER + '/api/admin/oauth/openid/callback',
|
||||
},
|
||||
createOpenIDCallback(true),
|
||||
);
|
||||
|
||||
passport.use('openidAdmin', openidAdminLogin);
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy] setupOpenIdAdmin', err);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Sets up the OpenID strategy for authentication.
|
||||
* This function configures the OpenID client, handles proxy settings,
|
||||
@@ -318,10 +533,6 @@ async function setupOpenId() {
|
||||
},
|
||||
);
|
||||
|
||||
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
|
||||
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
|
||||
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
|
||||
const usePKCE = isEnabled(process.env.OPENID_USE_PKCE);
|
||||
logger.info(`[openidStrategy] OpenID authentication configuration`, {
|
||||
generateNonce: shouldGenerateNonce,
|
||||
reason: shouldGenerateNonce
|
||||
@@ -335,159 +546,19 @@ async function setupOpenId() {
|
||||
scope: process.env.OPENID_SCOPE,
|
||||
callbackURL: process.env.DOMAIN_SERVER + process.env.OPENID_CALLBACK_URL,
|
||||
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
|
||||
usePKCE,
|
||||
},
|
||||
async (tokenset, done) => {
|
||||
try {
|
||||
const claims = tokenset.claims();
|
||||
const userinfo = {
|
||||
...claims,
|
||||
...(await getUserInfo(openidConfig, tokenset.access_token, claims.sub)),
|
||||
};
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[OpenID Strategy] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
|
||||
);
|
||||
return done(null, false, { message: 'Email domain not allowed' });
|
||||
}
|
||||
|
||||
const result = await findOpenIDUser({
|
||||
openidId: claims.sub,
|
||||
email: claims.email,
|
||||
strategyName: 'openidStrategy',
|
||||
findUser,
|
||||
});
|
||||
let user = result.user;
|
||||
const error = result.error;
|
||||
|
||||
if (error) {
|
||||
return done(null, false, {
|
||||
message: ErrorTypes.AUTH_FAILED,
|
||||
});
|
||||
}
|
||||
|
||||
const fullName = getFullName(userinfo);
|
||||
|
||||
if (requiredRole) {
|
||||
let decodedToken = '';
|
||||
if (requiredRoleTokenKind === 'access') {
|
||||
decodedToken = jwtDecode(tokenset.access_token);
|
||||
} else if (requiredRoleTokenKind === 'id') {
|
||||
decodedToken = jwtDecode(tokenset.id_token);
|
||||
}
|
||||
const pathParts = requiredRoleParameterPath.split('.');
|
||||
let found = true;
|
||||
let roles = pathParts.reduce((o, key) => {
|
||||
if (o === null || o === undefined || !(key in o)) {
|
||||
found = false;
|
||||
return [];
|
||||
}
|
||||
return o[key];
|
||||
}, decodedToken);
|
||||
|
||||
if (!found) {
|
||||
logger.error(
|
||||
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!roles.includes(requiredRole)) {
|
||||
return done(null, false, {
|
||||
message: `You must have the "${requiredRole}" role to log in.`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let username = '';
|
||||
if (process.env.OPENID_USERNAME_CLAIM) {
|
||||
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
|
||||
} else {
|
||||
username = convertToUsername(
|
||||
userinfo.preferred_username || userinfo.username || userinfo.email,
|
||||
);
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
user = {
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username,
|
||||
email: userinfo.email || '',
|
||||
emailVerified: userinfo.email_verified || false,
|
||||
name: fullName,
|
||||
idOnTheSource: userinfo.oid,
|
||||
};
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
user = await createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
user.provider = 'openid';
|
||||
user.openidId = userinfo.sub;
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
user.idOnTheSource = userinfo.oid;
|
||||
}
|
||||
|
||||
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
|
||||
/** @type {string | undefined} */
|
||||
const imageUrl = userinfo.picture;
|
||||
|
||||
let fileName;
|
||||
if (crypto) {
|
||||
fileName = (await hashToken(userinfo.sub)) + '.png';
|
||||
} else {
|
||||
fileName = userinfo.sub + '.png';
|
||||
}
|
||||
|
||||
const imageBuffer = await downloadImage(
|
||||
imageUrl,
|
||||
openidConfig,
|
||||
tokenset.access_token,
|
||||
userinfo.sub,
|
||||
);
|
||||
if (imageBuffer) {
|
||||
const { saveBuffer } = getStrategyFunctions(
|
||||
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
|
||||
);
|
||||
const imagePath = await saveBuffer({
|
||||
fileName,
|
||||
userId: user._id.toString(),
|
||||
buffer: imageBuffer,
|
||||
});
|
||||
user.avatar = imagePath ?? '';
|
||||
}
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
|
||||
logger.info(
|
||||
`[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `,
|
||||
{
|
||||
user: {
|
||||
openidId: user.openidId,
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
done(null, { ...user, tokenset });
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy] login failed', err);
|
||||
done(err);
|
||||
}
|
||||
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
|
||||
},
|
||||
createOpenIDCallback(),
|
||||
);
|
||||
passport.use('openid', openidLogin);
|
||||
setupOpenIdAdmin(openidConfig);
|
||||
return openidConfig;
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy]', err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @function getOpenIdConfig
|
||||
* @description Returns the OpenID client instance.
|
||||
|
||||
@@ -873,6 +873,13 @@
|
||||
* @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports ISession
|
||||
* @typedef {import('@librechat/data-schemas').ISession} ISession
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports IBalance
|
||||
* @typedef {import('@librechat/data-schemas').IBalance} IBalance
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import { createContext, useContext } from 'react';
|
||||
|
||||
type MessageContext = {
|
||||
messageId: string;
|
||||
nextType?: string;
|
||||
partIndex?: number;
|
||||
isExpanded: boolean;
|
||||
conversationId?: string | null;
|
||||
/** Submission state for cursor display - only true for latest message when submitting */
|
||||
isSubmitting?: boolean;
|
||||
/** Whether this is the latest message in the conversation */
|
||||
isLatestMessage?: boolean;
|
||||
};
|
||||
|
||||
export const MessageContext = createContext<MessageContext>({} as MessageContext);
|
||||
|
||||
150
client/src/Providers/MessagesViewContext.tsx
Normal file
150
client/src/Providers/MessagesViewContext.tsx
Normal file
@@ -0,0 +1,150 @@
|
||||
import React, { createContext, useContext, useMemo } from 'react';
|
||||
import { useAddedChatContext } from './AddedChatContext';
|
||||
import { useChatContext } from './ChatContext';
|
||||
|
||||
interface MessagesViewContextValue {
|
||||
/** Core conversation data */
|
||||
conversation: ReturnType<typeof useChatContext>['conversation'];
|
||||
conversationId: string | null | undefined;
|
||||
|
||||
/** Submission and control states */
|
||||
isSubmitting: ReturnType<typeof useChatContext>['isSubmitting'];
|
||||
isSubmittingFamily: boolean;
|
||||
abortScroll: ReturnType<typeof useChatContext>['abortScroll'];
|
||||
setAbortScroll: ReturnType<typeof useChatContext>['setAbortScroll'];
|
||||
|
||||
/** Message operations */
|
||||
ask: ReturnType<typeof useChatContext>['ask'];
|
||||
regenerate: ReturnType<typeof useChatContext>['regenerate'];
|
||||
handleContinue: ReturnType<typeof useChatContext>['handleContinue'];
|
||||
|
||||
/** Message state management */
|
||||
index: ReturnType<typeof useChatContext>['index'];
|
||||
latestMessage: ReturnType<typeof useChatContext>['latestMessage'];
|
||||
setLatestMessage: ReturnType<typeof useChatContext>['setLatestMessage'];
|
||||
getMessages: ReturnType<typeof useChatContext>['getMessages'];
|
||||
setMessages: ReturnType<typeof useChatContext>['setMessages'];
|
||||
}
|
||||
|
||||
const MessagesViewContext = createContext<MessagesViewContextValue | undefined>(undefined);
|
||||
|
||||
export function MessagesViewProvider({ children }: { children: React.ReactNode }) {
|
||||
const chatContext = useChatContext();
|
||||
const addedChatContext = useAddedChatContext();
|
||||
|
||||
const {
|
||||
ask,
|
||||
index,
|
||||
regenerate,
|
||||
isSubmitting: isSubmittingRoot,
|
||||
conversation,
|
||||
latestMessage,
|
||||
setAbortScroll,
|
||||
handleContinue,
|
||||
setLatestMessage,
|
||||
abortScroll,
|
||||
getMessages,
|
||||
setMessages,
|
||||
} = chatContext;
|
||||
|
||||
const { isSubmitting: isSubmittingAdditional } = addedChatContext;
|
||||
|
||||
/** Memoize conversation-related values */
|
||||
const conversationValues = useMemo(
|
||||
() => ({
|
||||
conversation,
|
||||
conversationId: conversation?.conversationId,
|
||||
}),
|
||||
[conversation],
|
||||
);
|
||||
|
||||
/** Memoize submission states */
|
||||
const submissionStates = useMemo(
|
||||
() => ({
|
||||
isSubmitting: isSubmittingRoot,
|
||||
isSubmittingFamily: isSubmittingRoot || isSubmittingAdditional,
|
||||
abortScroll,
|
||||
setAbortScroll,
|
||||
}),
|
||||
[isSubmittingRoot, isSubmittingAdditional, abortScroll, setAbortScroll],
|
||||
);
|
||||
|
||||
/** Memoize message operations (these are typically stable references) */
|
||||
const messageOperations = useMemo(
|
||||
() => ({
|
||||
ask,
|
||||
regenerate,
|
||||
getMessages,
|
||||
setMessages,
|
||||
handleContinue,
|
||||
}),
|
||||
[ask, regenerate, handleContinue, getMessages, setMessages],
|
||||
);
|
||||
|
||||
/** Memoize message state values */
|
||||
const messageState = useMemo(
|
||||
() => ({
|
||||
index,
|
||||
latestMessage,
|
||||
setLatestMessage,
|
||||
}),
|
||||
[index, latestMessage, setLatestMessage],
|
||||
);
|
||||
|
||||
/** Combine all values into final context value */
|
||||
const contextValue = useMemo<MessagesViewContextValue>(
|
||||
() => ({
|
||||
...conversationValues,
|
||||
...submissionStates,
|
||||
...messageOperations,
|
||||
...messageState,
|
||||
}),
|
||||
[conversationValues, submissionStates, messageOperations, messageState],
|
||||
);
|
||||
|
||||
return (
|
||||
<MessagesViewContext.Provider value={contextValue}>{children}</MessagesViewContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useMessagesViewContext() {
|
||||
const context = useContext(MessagesViewContext);
|
||||
if (!context) {
|
||||
throw new Error('useMessagesViewContext must be used within MessagesViewProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
/** Hook for components that only need conversation data */
|
||||
export function useMessagesConversation() {
|
||||
const { conversation, conversationId } = useMessagesViewContext();
|
||||
return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]);
|
||||
}
|
||||
|
||||
/** Hook for components that only need submission states */
|
||||
export function useMessagesSubmission() {
|
||||
const { isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll } =
|
||||
useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll }),
|
||||
[isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll],
|
||||
);
|
||||
}
|
||||
|
||||
/** Hook for components that only need message operations */
|
||||
export function useMessagesOperations() {
|
||||
const { ask, regenerate, handleContinue, getMessages, setMessages } = useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ ask, regenerate, handleContinue, getMessages, setMessages }),
|
||||
[ask, regenerate, handleContinue, getMessages, setMessages],
|
||||
);
|
||||
}
|
||||
|
||||
/** Hook for components that only need message state */
|
||||
export function useMessagesState() {
|
||||
const { index, latestMessage, setLatestMessage } = useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ index, latestMessage, setLatestMessage }),
|
||||
[index, latestMessage, setLatestMessage],
|
||||
);
|
||||
}
|
||||
@@ -26,4 +26,5 @@ export * from './SidePanelContext';
|
||||
export * from './MCPPanelContext';
|
||||
export * from './ArtifactsContext';
|
||||
export * from './PromptGroupsContext';
|
||||
export * from './MessagesViewContext';
|
||||
export { default as BadgeRowProvider } from './BadgeRowContext';
|
||||
|
||||
@@ -26,6 +26,7 @@ type ContentPartsProps = {
|
||||
isCreatedByUser: boolean;
|
||||
isLast: boolean;
|
||||
isSubmitting: boolean;
|
||||
isLatestMessage?: boolean;
|
||||
edit?: boolean;
|
||||
enterEdit?: (cancel?: boolean) => void | null | undefined;
|
||||
siblingIdx?: number;
|
||||
@@ -45,6 +46,7 @@ const ContentParts = memo(
|
||||
isCreatedByUser,
|
||||
isLast,
|
||||
isSubmitting,
|
||||
isLatestMessage,
|
||||
edit,
|
||||
enterEdit,
|
||||
siblingIdx,
|
||||
@@ -55,6 +57,8 @@ const ContentParts = memo(
|
||||
const [isExpanded, setIsExpanded] = useState(showThinking);
|
||||
const attachmentMap = useMemo(() => mapAttachments(attachments ?? []), [attachments]);
|
||||
|
||||
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
|
||||
|
||||
const hasReasoningParts = useMemo(() => {
|
||||
const hasThinkPart = content?.some((part) => part?.type === ContentTypes.THINK) ?? false;
|
||||
const allThinkPartsHaveContent =
|
||||
@@ -134,7 +138,9 @@ const ContentParts = memo(
|
||||
})
|
||||
}
|
||||
label={
|
||||
isSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts')
|
||||
effectiveIsSubmitting && isLast
|
||||
? localize('com_ui_thinking')
|
||||
: localize('com_ui_thoughts')
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
@@ -155,12 +161,14 @@ const ContentParts = memo(
|
||||
conversationId,
|
||||
partIndex: idx,
|
||||
nextType: content[idx + 1]?.type,
|
||||
isSubmitting: effectiveIsSubmitting,
|
||||
isLatestMessage,
|
||||
}}
|
||||
>
|
||||
<Part
|
||||
part={part}
|
||||
attachments={attachments}
|
||||
isSubmitting={isSubmitting}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
key={`part-${messageId}-${idx}`}
|
||||
isCreatedByUser={isCreatedByUser}
|
||||
isLast={idx === content.length - 1}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { TextareaAutosize, TooltipAnchor } from '@librechat/client';
|
||||
import { useUpdateMessageMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TEditProps } from '~/common';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
|
||||
import { cn, removeFocusRings } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import Container from './Container';
|
||||
@@ -22,7 +22,8 @@ const EditMessage = ({
|
||||
const { addedIndex } = useAddedChatContext();
|
||||
const saveButtonRef = useRef<HTMLButtonElement | null>(null);
|
||||
const submitButtonRef = useRef<HTMLButtonElement | null>(null);
|
||||
const { getMessages, setMessages, conversation } = useChatContext();
|
||||
const { conversation } = useMessagesConversation();
|
||||
const { getMessages, setMessages } = useMessagesOperations();
|
||||
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
|
||||
store.latestMessageFamily(addedIndex),
|
||||
);
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { TMessage } from 'librechat-data-provider';
|
||||
import type { TMessageContentProps, TDisplayProps } from '~/common';
|
||||
import Error from '~/components/Messages/Content/Error';
|
||||
import Thinking from '~/components/Artifacts/Thinking';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { useMessageContext } from '~/Providers';
|
||||
import MarkdownLite from './MarkdownLite';
|
||||
import EditMessage from './EditMessage';
|
||||
import { useLocalize } from '~/hooks';
|
||||
@@ -70,16 +70,12 @@ export const ErrorMessage = ({
|
||||
};
|
||||
|
||||
const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => {
|
||||
const { isSubmitting, latestMessage } = useChatContext();
|
||||
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
|
||||
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
|
||||
const showCursorState = useMemo(
|
||||
() => showCursor === true && isSubmitting,
|
||||
[showCursor, isSubmitting],
|
||||
);
|
||||
const isLatestMessage = useMemo(
|
||||
() => message.messageId === latestMessage?.messageId,
|
||||
[message.messageId, latestMessage?.messageId],
|
||||
);
|
||||
|
||||
let content: React.ReactElement;
|
||||
if (!isCreatedByUser) {
|
||||
|
||||
@@ -85,13 +85,14 @@ const Part = memo(
|
||||
|
||||
const isToolCall =
|
||||
'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL);
|
||||
if (isToolCall && toolCall.name === Tools.execute_code && toolCall.args) {
|
||||
if (isToolCall && toolCall.name === Tools.execute_code) {
|
||||
return (
|
||||
<ExecuteCode
|
||||
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
|
||||
attachments={attachments}
|
||||
isSubmitting={isSubmitting}
|
||||
output={toolCall.output ?? ''}
|
||||
initialProgress={toolCall.progress ?? 0.1}
|
||||
attachments={attachments}
|
||||
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
|
||||
/>
|
||||
);
|
||||
} else if (
|
||||
|
||||
@@ -6,8 +6,8 @@ import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query';
|
||||
import type { Agents } from 'librechat-data-provider';
|
||||
import type { TEditProps } from '~/common';
|
||||
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
|
||||
import Container from '~/components/Chat/Messages/Content/Container';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
import { cn, removeFocusRings } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import store from '~/store';
|
||||
@@ -25,7 +25,8 @@ const EditTextPart = ({
|
||||
}) => {
|
||||
const localize = useLocalize();
|
||||
const { addedIndex } = useAddedChatContext();
|
||||
const { ask, getMessages, setMessages, conversation } = useChatContext();
|
||||
const { conversation } = useMessagesConversation();
|
||||
const { ask, getMessages, setMessages } = useMessagesOperations();
|
||||
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
|
||||
store.latestMessageFamily(addedIndex),
|
||||
);
|
||||
|
||||
@@ -45,26 +45,28 @@ export function useParseArgs(args?: string): ParsedArgs | null {
|
||||
}
|
||||
|
||||
export default function ExecuteCode({
|
||||
isSubmitting,
|
||||
initialProgress = 0.1,
|
||||
args,
|
||||
output = '',
|
||||
attachments,
|
||||
}: {
|
||||
initialProgress: number;
|
||||
isSubmitting: boolean;
|
||||
args?: string;
|
||||
output?: string;
|
||||
attachments?: TAttachment[];
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const showAnalysisCode = useRecoilValue(store.showCode);
|
||||
const [showCode, setShowCode] = useState(showAnalysisCode);
|
||||
const codeContentRef = useRef<HTMLDivElement>(null);
|
||||
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
|
||||
const [isAnimating, setIsAnimating] = useState(false);
|
||||
const hasOutput = output.length > 0;
|
||||
const outputRef = useRef<string>(output);
|
||||
const prevShowCodeRef = useRef<boolean>(showCode);
|
||||
const codeContentRef = useRef<HTMLDivElement>(null);
|
||||
const [isAnimating, setIsAnimating] = useState(false);
|
||||
const showAnalysisCode = useRecoilValue(store.showCode);
|
||||
const [showCode, setShowCode] = useState(showAnalysisCode);
|
||||
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
|
||||
|
||||
const prevShowCodeRef = useRef<boolean>(showCode);
|
||||
const { lang, code } = useParseArgs(args) ?? ({} as ParsedArgs);
|
||||
const progress = useProgress(initialProgress);
|
||||
|
||||
@@ -136,6 +138,8 @@ export default function ExecuteCode({
|
||||
};
|
||||
}, [showCode, isAnimating]);
|
||||
|
||||
const cancelled = !isSubmitting && progress < 1;
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="relative my-2.5 flex size-5 shrink-0 items-center gap-2.5">
|
||||
@@ -143,9 +147,12 @@ export default function ExecuteCode({
|
||||
progress={progress}
|
||||
onClick={() => setShowCode((prev) => !prev)}
|
||||
inProgressText={localize('com_ui_analyzing')}
|
||||
finishedText={localize('com_ui_analyzing_finished')}
|
||||
finishedText={
|
||||
cancelled ? localize('com_ui_cancelled') : localize('com_ui_analyzing_finished')
|
||||
}
|
||||
hasInput={!!code?.length}
|
||||
isExpanded={showCode}
|
||||
error={cancelled}
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
|
||||
@@ -2,7 +2,7 @@ import { memo, useMemo, ReactElement } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import MarkdownLite from '~/components/Chat/Messages/Content/MarkdownLite';
|
||||
import Markdown from '~/components/Chat/Messages/Content/Markdown';
|
||||
import { useChatContext, useMessageContext } from '~/Providers';
|
||||
import { useMessageContext } from '~/Providers';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
@@ -18,14 +18,9 @@ type ContentType =
|
||||
| ReactElement;
|
||||
|
||||
const TextPart = memo(({ text, isCreatedByUser, showCursor }: TextPartProps) => {
|
||||
const { messageId } = useMessageContext();
|
||||
const { isSubmitting, latestMessage } = useChatContext();
|
||||
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
|
||||
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
|
||||
const showCursorState = useMemo(() => showCursor && isSubmitting, [showCursor, isSubmitting]);
|
||||
const isLatestMessage = useMemo(
|
||||
() => messageId === latestMessage?.messageId,
|
||||
[messageId, latestMessage?.messageId],
|
||||
);
|
||||
|
||||
const content: ContentType = useMemo(() => {
|
||||
if (!isCreatedByUser) {
|
||||
|
||||
@@ -21,7 +21,7 @@ type THoverButtons = {
|
||||
latestMessage: TMessage | null;
|
||||
isLast: boolean;
|
||||
index: number;
|
||||
handleFeedback: ({ feedback }: { feedback: TFeedback | undefined }) => void;
|
||||
handleFeedback?: ({ feedback }: { feedback: TFeedback | undefined }) => void;
|
||||
};
|
||||
|
||||
type HoverButtonProps = {
|
||||
@@ -238,7 +238,7 @@ const HoverButtons = ({
|
||||
/>
|
||||
|
||||
{/* Feedback Buttons */}
|
||||
{!isCreatedByUser && (
|
||||
{!isCreatedByUser && handleFeedback != null && (
|
||||
<Feedback handleFeedback={handleFeedback} feedback={message.feedback} isLast={isLast} />
|
||||
)}
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ export default function Message(props: TMessageProps) {
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6 ">
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6">
|
||||
<MessageRender {...props} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -125,6 +125,7 @@ export default function Message(props: TMessageProps) {
|
||||
setSiblingIdx={setSiblingIdx}
|
||||
isCreatedByUser={message.isCreatedByUser}
|
||||
conversationId={conversation?.conversationId}
|
||||
isLatestMessage={messageId === latestMessage?.messageId}
|
||||
content={message.content as Array<TMessageContentParts | undefined>}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -4,11 +4,12 @@ import { CSSTransition } from 'react-transition-group';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { useScreenshot, useMessageScrolling, useLocalize } from '~/hooks';
|
||||
import ScrollToBottom from '~/components/Messages/ScrollToBottom';
|
||||
import { MessagesViewProvider } from '~/Providers';
|
||||
import MultiMessage from './MultiMessage';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
export default function MessagesView({
|
||||
function MessagesViewContent({
|
||||
messagesTree: _messagesTree,
|
||||
}: {
|
||||
messagesTree?: TMessage[] | null;
|
||||
@@ -92,3 +93,11 @@ export default function MessagesView({
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function MessagesView({ messagesTree }: { messagesTree?: TMessage[] | null }) {
|
||||
return (
|
||||
<MessagesViewProvider>
|
||||
<MessagesViewContent messagesTree={messagesTree} />
|
||||
</MessagesViewProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ export default function MultiMessage({
|
||||
useEffect(() => {
|
||||
// reset siblingIdx when the tree changes, mostly when a new message is submitting.
|
||||
setSiblingIdx(0);
|
||||
}, [messagesTree?.length]);
|
||||
}, [messagesTree?.length, setSiblingIdx]);
|
||||
|
||||
useEffect(() => {
|
||||
if (messagesTree?.length && siblingIdx >= messagesTree.length) {
|
||||
|
||||
@@ -71,6 +71,9 @@ const MessageRender = memo(
|
||||
const showCardRender = isLast && !isSubmittingFamily && isCard;
|
||||
const isLatestCard = isCard && !isSubmittingFamily && isLatestMessage;
|
||||
|
||||
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
|
||||
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
|
||||
|
||||
const iconData: TMessageIcon = useMemo(
|
||||
() => ({
|
||||
endpoint: msg?.endpoint ?? conversation?.endpoint,
|
||||
@@ -166,6 +169,8 @@ const MessageRender = memo(
|
||||
messageId: msg.messageId,
|
||||
conversationId: conversation?.conversationId,
|
||||
isExpanded: false,
|
||||
isSubmitting: effectiveIsSubmitting,
|
||||
isLatestMessage,
|
||||
}}
|
||||
>
|
||||
{msg.plugin && <Plugin plugin={msg.plugin} />}
|
||||
@@ -177,7 +182,7 @@ const MessageRender = memo(
|
||||
message={msg}
|
||||
enterEdit={enterEdit}
|
||||
error={!!(msg.error ?? false)}
|
||||
isSubmitting={isSubmitting}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
unfinished={msg.unfinished ?? false}
|
||||
isCreatedByUser={msg.isCreatedByUser ?? true}
|
||||
siblingIdx={siblingIdx ?? 0}
|
||||
@@ -186,7 +191,7 @@ const MessageRender = memo(
|
||||
</MessageContext.Provider>
|
||||
</div>
|
||||
|
||||
{hasNoChildren && (isSubmittingFamily === true || isSubmitting) ? (
|
||||
{hasNoChildren && (isSubmittingFamily === true || effectiveIsSubmitting) ? (
|
||||
<PlaceholderRow isCard={isCard} />
|
||||
) : (
|
||||
<SubRow classes="text-xs">
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useMemo, memo, type FC, useCallback } from 'react';
|
||||
import throttle from 'lodash/throttle';
|
||||
import { parseISO, isToday } from 'date-fns';
|
||||
import { Spinner, useMediaQuery } from '@librechat/client';
|
||||
import { List, AutoSizer, CellMeasurer, CellMeasurerCache } from 'react-virtualized';
|
||||
import { TConversation } from 'librechat-data-provider';
|
||||
@@ -50,27 +49,17 @@ const MemoizedConvo = memo(
|
||||
conversation,
|
||||
retainView,
|
||||
toggleNav,
|
||||
isLatestConvo,
|
||||
}: {
|
||||
conversation: TConversation;
|
||||
retainView: () => void;
|
||||
toggleNav: () => void;
|
||||
isLatestConvo: boolean;
|
||||
}) => {
|
||||
return (
|
||||
<Convo
|
||||
conversation={conversation}
|
||||
retainView={retainView}
|
||||
toggleNav={toggleNav}
|
||||
isLatestConvo={isLatestConvo}
|
||||
/>
|
||||
);
|
||||
return <Convo conversation={conversation} retainView={retainView} toggleNav={toggleNav} />;
|
||||
},
|
||||
(prevProps, nextProps) => {
|
||||
return (
|
||||
prevProps.conversation.conversationId === nextProps.conversation.conversationId &&
|
||||
prevProps.conversation.title === nextProps.conversation.title &&
|
||||
prevProps.isLatestConvo === nextProps.isLatestConvo &&
|
||||
prevProps.conversation.endpoint === nextProps.conversation.endpoint
|
||||
);
|
||||
},
|
||||
@@ -98,13 +87,6 @@ const Conversations: FC<ConversationsProps> = ({
|
||||
[filteredConversations],
|
||||
);
|
||||
|
||||
const firstTodayConvoId = useMemo(
|
||||
() =>
|
||||
filteredConversations.find((convo) => convo.updatedAt && isToday(parseISO(convo.updatedAt)))
|
||||
?.conversationId ?? undefined,
|
||||
[filteredConversations],
|
||||
);
|
||||
|
||||
const flattenedItems = useMemo(() => {
|
||||
const items: FlattenedItem[] = [];
|
||||
groupedConversations.forEach(([groupName, convos]) => {
|
||||
@@ -154,26 +136,25 @@ const Conversations: FC<ConversationsProps> = ({
|
||||
</CellMeasurer>
|
||||
);
|
||||
}
|
||||
let rendering: JSX.Element;
|
||||
if (item.type === 'header') {
|
||||
rendering = <DateLabel groupName={item.groupName} />;
|
||||
} else if (item.type === 'convo') {
|
||||
rendering = (
|
||||
<MemoizedConvo conversation={item.convo} retainView={moveToTop} toggleNav={toggleNav} />
|
||||
);
|
||||
}
|
||||
return (
|
||||
<CellMeasurer cache={cache} columnIndex={0} key={key} parent={parent} rowIndex={index}>
|
||||
{({ registerChild }) => (
|
||||
<div ref={registerChild} style={style}>
|
||||
{item.type === 'header' ? (
|
||||
<DateLabel groupName={item.groupName} />
|
||||
) : item.type === 'convo' ? (
|
||||
<MemoizedConvo
|
||||
conversation={item.convo}
|
||||
retainView={moveToTop}
|
||||
toggleNav={toggleNav}
|
||||
isLatestConvo={item.convo.conversationId === firstTodayConvoId}
|
||||
/>
|
||||
) : null}
|
||||
{rendering}
|
||||
</div>
|
||||
)}
|
||||
</CellMeasurer>
|
||||
);
|
||||
},
|
||||
[cache, flattenedItems, firstTodayConvoId, moveToTop, toggleNav],
|
||||
[cache, flattenedItems, moveToTop, toggleNav],
|
||||
);
|
||||
|
||||
const getRowHeight = useCallback(
|
||||
|
||||
@@ -11,23 +11,17 @@ import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { NotificationSeverity } from '~/common';
|
||||
import { ConvoOptions } from './ConvoOptions';
|
||||
import RenameForm from './RenameForm';
|
||||
import { cn, logger } from '~/utils';
|
||||
import ConvoLink from './ConvoLink';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
interface ConversationProps {
|
||||
conversation: TConversation;
|
||||
retainView: () => void;
|
||||
toggleNav: () => void;
|
||||
isLatestConvo: boolean;
|
||||
}
|
||||
|
||||
export default function Conversation({
|
||||
conversation,
|
||||
retainView,
|
||||
toggleNav,
|
||||
isLatestConvo,
|
||||
}: ConversationProps) {
|
||||
export default function Conversation({ conversation, retainView, toggleNav }: ConversationProps) {
|
||||
const params = useParams();
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
@@ -84,6 +78,7 @@ export default function Conversation({
|
||||
});
|
||||
setRenaming(false);
|
||||
} catch (error) {
|
||||
logger.error('Error renaming conversation', error);
|
||||
setTitleInput(title as string);
|
||||
showToast({
|
||||
message: localize('com_ui_rename_failed'),
|
||||
|
||||
@@ -173,6 +173,7 @@ const ContentRender = memo(
|
||||
isSubmitting={isSubmitting}
|
||||
searchResults={searchResults}
|
||||
setSiblingIdx={setSiblingIdx}
|
||||
isLatestMessage={isLatestMessage}
|
||||
isCreatedByUser={msg.isCreatedByUser}
|
||||
conversationId={conversation?.conversationId}
|
||||
content={msg.content as Array<TMessageContentParts | undefined>}
|
||||
|
||||
@@ -76,6 +76,8 @@ export default function Message(props: TMessageProps) {
|
||||
messageId,
|
||||
isExpanded: false,
|
||||
conversationId: conversation?.conversationId,
|
||||
isSubmitting: false, // Share view is always read-only
|
||||
isLatestMessage: false, // No concept of latest message in share view
|
||||
}}
|
||||
>
|
||||
{/* Legacy Plugins */}
|
||||
|
||||
@@ -4,7 +4,6 @@ import { ChevronDown } from 'lucide-react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import * as AccordionPrimitive from '@radix-ui/react-accordion';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import {
|
||||
Label,
|
||||
Checkbox,
|
||||
@@ -14,20 +13,18 @@ import {
|
||||
AccordionItem,
|
||||
CircleHelpIcon,
|
||||
OGDialogTrigger,
|
||||
useToastContext,
|
||||
AccordionContent,
|
||||
OGDialogTemplate,
|
||||
} from '@librechat/client';
|
||||
import type { AgentForm, MCPServerInfo } from '~/common';
|
||||
import { useLocalize, useMCPServerManager, useRemoveMCPTool } from '~/hooks';
|
||||
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
|
||||
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
|
||||
import { useLocalize, useMCPServerManager } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function MCPTool({ serverInfo }: { serverInfo?: MCPServerInfo }) {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const updateUserPlugins = useUpdateUserPluginsMutation();
|
||||
const { removeTool } = useRemoveMCPTool();
|
||||
const { getValues, setValue } = useFormContext<AgentForm>();
|
||||
const { getServerStatusIconProps, getConfigDialogProps } = useMCPServerManager();
|
||||
|
||||
@@ -56,36 +53,6 @@ export default function MCPTool({ serverInfo }: { serverInfo?: MCPServerInfo })
|
||||
setValue('tools', [...otherTools, ...newSelectedTools]);
|
||||
};
|
||||
|
||||
const removeTool = (serverName: string) => {
|
||||
if (!serverName) {
|
||||
return;
|
||||
}
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'uninstall',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => {
|
||||
showToast({ message: `Error while deleting the tool: ${error}`, status: 'error' });
|
||||
},
|
||||
onSuccess: () => {
|
||||
const currentTools = getValues('tools');
|
||||
const remainingToolIds =
|
||||
currentTools?.filter(
|
||||
(currentToolId) =>
|
||||
currentToolId !== serverName &&
|
||||
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
|
||||
) || [];
|
||||
setValue('tools', remainingToolIds);
|
||||
showToast({ message: 'Tool deleted successfully', status: 'success' });
|
||||
},
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
const selectedTools = getSelectedTools();
|
||||
const isExpanded = accordionValue === currentServerName;
|
||||
|
||||
|
||||
@@ -1,25 +1,12 @@
|
||||
import React, { useState } from 'react';
|
||||
import { CircleX } from 'lucide-react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import {
|
||||
Label,
|
||||
OGDialog,
|
||||
TrashIcon,
|
||||
useToastContext,
|
||||
OGDialogTrigger,
|
||||
OGDialogTemplate,
|
||||
} from '@librechat/client';
|
||||
import type { AgentForm } from '~/common';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { Label, OGDialog, TrashIcon, OGDialogTrigger, OGDialogTemplate } from '@librechat/client';
|
||||
import { useLocalize, useRemoveMCPTool } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function UnconfiguredMCPTool({ serverName }: { serverName?: string }) {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const updateUserPlugins = useUpdateUserPluginsMutation();
|
||||
const { getValues, setValue } = useFormContext<AgentForm>();
|
||||
const { removeTool } = useRemoveMCPTool();
|
||||
|
||||
const [isFocused, setIsFocused] = useState(false);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
@@ -28,36 +15,6 @@ export default function UnconfiguredMCPTool({ serverName }: { serverName?: strin
|
||||
return null;
|
||||
}
|
||||
|
||||
const removeTool = () => {
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'uninstall',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => {
|
||||
showToast({
|
||||
message: localize('com_ui_delete_tool_error', { error: String(error) }),
|
||||
status: 'error',
|
||||
});
|
||||
},
|
||||
onSuccess: () => {
|
||||
const currentTools = getValues('tools');
|
||||
const remainingToolIds =
|
||||
currentTools?.filter(
|
||||
(currentToolId) =>
|
||||
currentToolId !== serverName &&
|
||||
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
|
||||
) || [];
|
||||
setValue('tools', remainingToolIds);
|
||||
showToast({ message: localize('com_ui_delete_tool_success'), status: 'success' });
|
||||
},
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<OGDialog>
|
||||
<div
|
||||
@@ -116,7 +73,7 @@ export default function UnconfiguredMCPTool({ serverName }: { serverName?: strin
|
||||
</Label>
|
||||
}
|
||||
selection={{
|
||||
selectHandler: () => removeTool(),
|
||||
selectHandler: () => removeTool(serverName || ''),
|
||||
selectClasses:
|
||||
'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 transition-color duration-200 text-white',
|
||||
selectText: localize('com_ui_delete'),
|
||||
|
||||
@@ -1,29 +1,18 @@
|
||||
import React, { useState } from 'react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import {
|
||||
Label,
|
||||
OGDialog,
|
||||
TrashIcon,
|
||||
OGDialogTrigger,
|
||||
useToastContext,
|
||||
OGDialogTemplate,
|
||||
} from '@librechat/client';
|
||||
import type { AgentForm, MCPServerInfo } from '~/common';
|
||||
import { Label, OGDialog, TrashIcon, OGDialogTrigger, OGDialogTemplate } from '@librechat/client';
|
||||
import type { MCPServerInfo } from '~/common';
|
||||
import { useLocalize, useMCPServerManager, useRemoveMCPTool } from '~/hooks';
|
||||
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
|
||||
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
|
||||
import { useLocalize, useMCPServerManager } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function UninitializedMCPTool({ serverInfo }: { serverInfo?: MCPServerInfo }) {
|
||||
const localize = useLocalize();
|
||||
const { removeTool } = useRemoveMCPTool();
|
||||
|
||||
const [isFocused, setIsFocused] = useState(false);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const updateUserPlugins = useUpdateUserPluginsMutation();
|
||||
const { getValues, setValue } = useFormContext<AgentForm>();
|
||||
const { initializeServer, isInitializing, getServerStatusIconProps, getConfigDialogProps } =
|
||||
useMCPServerManager();
|
||||
|
||||
@@ -31,39 +20,6 @@ export default function UninitializedMCPTool({ serverInfo }: { serverInfo?: MCPS
|
||||
return null;
|
||||
}
|
||||
|
||||
const removeTool = (serverName: string) => {
|
||||
if (!serverName) {
|
||||
return;
|
||||
}
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'uninstall',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => {
|
||||
showToast({
|
||||
message: localize('com_ui_delete_tool_error', { error: String(error) }),
|
||||
status: 'error',
|
||||
});
|
||||
},
|
||||
onSuccess: () => {
|
||||
const currentTools = getValues('tools');
|
||||
const remainingToolIds =
|
||||
currentTools?.filter(
|
||||
(currentToolId) =>
|
||||
currentToolId !== serverName &&
|
||||
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
|
||||
) || [];
|
||||
setValue('tools', remainingToolIds);
|
||||
showToast({ message: localize('com_ui_delete_tool_success'), status: 'success' });
|
||||
},
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
const serverName = serverInfo.serverName;
|
||||
const isServerInitializing = isInitializing(serverName);
|
||||
const statusIconProps = getServerStatusIconProps(serverName);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useState, useMemo, useCallback } from 'react';
|
||||
import React, { useState, useMemo, useCallback, useEffect } from 'react';
|
||||
import { ChevronLeft, Trash2 } from 'lucide-react';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { Button, useToastContext } from '@librechat/client';
|
||||
@@ -12,6 +12,8 @@ import { useLocalize, useMCPConnectionStatus } from '~/hooks';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import MCPPanelSkeleton from './MCPPanelSkeleton';
|
||||
|
||||
const POLL_FOR_CONNECTION_STATUS_INTERVAL = 2_000; // ms
|
||||
|
||||
function MCPPanelContent() {
|
||||
const localize = useLocalize();
|
||||
const queryClient = useQueryClient();
|
||||
@@ -26,6 +28,29 @@ function MCPPanelContent() {
|
||||
null,
|
||||
);
|
||||
|
||||
// Check if any connections are in 'connecting' state
|
||||
const hasConnectingServers = useMemo(() => {
|
||||
if (!connectionStatus) {
|
||||
return false;
|
||||
}
|
||||
return Object.values(connectionStatus).some(
|
||||
(status) => status?.connectionState === 'connecting',
|
||||
);
|
||||
}, [connectionStatus]);
|
||||
|
||||
// Set up polling when servers are connecting
|
||||
useEffect(() => {
|
||||
if (!hasConnectingServers) {
|
||||
return;
|
||||
}
|
||||
|
||||
const intervalId = setInterval(() => {
|
||||
queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]);
|
||||
}, POLL_FOR_CONNECTION_STATUS_INTERVAL);
|
||||
|
||||
return () => clearInterval(intervalId);
|
||||
}, [hasConnectingServers, queryClient]);
|
||||
|
||||
const updateUserPluginsMutation = useUpdateUserPluginsMutation({
|
||||
onSuccess: async () => {
|
||||
showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' });
|
||||
|
||||
@@ -0,0 +1,440 @@
|
||||
import { renderHook } from '@testing-library/react';
|
||||
import { Tools, Constants } from 'librechat-data-provider';
|
||||
import useAgentToolPermissions from '../useAgentToolPermissions';
|
||||
|
||||
// Mock dependencies
|
||||
jest.mock('~/data-provider', () => ({
|
||||
useGetAgentByIdQuery: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/Providers', () => ({
|
||||
useAgentsMapContext: jest.fn(),
|
||||
}));
|
||||
|
||||
// Import mocked functions after mocking
|
||||
import { useGetAgentByIdQuery } from '~/data-provider';
|
||||
import { useAgentsMapContext } from '~/Providers';
|
||||
|
||||
describe('useAgentToolPermissions', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Ephemeral Agent Scenarios', () => {
|
||||
it('should return true for all tools when agentId is null', () => {
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(null));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return true for all tools when agentId is undefined', () => {
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(undefined));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return true for all tools when agentId is empty string', () => {
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(''));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return true for all tools when agentId is EPHEMERAL_AGENT_ID', () => {
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useAgentToolPermissions(Constants.EPHEMERAL_AGENT_ID)
|
||||
);
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Regular Agent with Tools', () => {
|
||||
it('should allow file_search when agent has the tool', () => {
|
||||
const agentId = 'agent-123';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search, 'other_tool'],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toEqual([Tools.file_search, 'other_tool']);
|
||||
});
|
||||
|
||||
it('should allow execute_code when agent has the tool', () => {
|
||||
const agentId = 'agent-456';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.execute_code, 'another_tool'],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toEqual([Tools.execute_code, 'another_tool']);
|
||||
});
|
||||
|
||||
it('should allow both tools when agent has both', () => {
|
||||
const agentId = 'agent-789';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search, Tools.execute_code, 'custom_tool'],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toEqual([Tools.file_search, Tools.execute_code, 'custom_tool']);
|
||||
});
|
||||
|
||||
it('should disallow both tools when agent has neither', () => {
|
||||
const agentId = 'agent-no-tools';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: ['custom_tool1', 'custom_tool2'],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toEqual(['custom_tool1', 'custom_tool2']);
|
||||
});
|
||||
|
||||
it('should handle agent with empty tools array', () => {
|
||||
const agentId = 'agent-empty-tools';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle agent with undefined tools', () => {
|
||||
const agentId = 'agent-undefined-tools';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: undefined,
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Agent Data from Query', () => {
|
||||
it('should prioritize agentData tools over selectedAgent tools', () => {
|
||||
const agentId = 'agent-with-query-data';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: ['old_tool'],
|
||||
};
|
||||
const mockAgentData = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search, Tools.execute_code],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: mockAgentData });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toEqual([Tools.file_search, Tools.execute_code]);
|
||||
});
|
||||
|
||||
it('should fallback to selectedAgent tools when agentData has no tools', () => {
|
||||
const agentId = 'agent-fallback';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search],
|
||||
};
|
||||
const mockAgentData = {
|
||||
id: agentId,
|
||||
tools: undefined,
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: mockAgentData });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toEqual([Tools.file_search]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Agent Not Found Scenarios', () => {
|
||||
it('should disallow all tools when agent is not found in map', () => {
|
||||
const agentId = 'non-existent-agent';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should disallow all tools when agentsMap is null', () => {
|
||||
const agentId = 'agent-with-null-map';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue(null);
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should disallow all tools when agentsMap is undefined', () => {
|
||||
const agentId = 'agent-with-undefined-map';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue(undefined);
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Memoization and Performance', () => {
|
||||
it('should memoize results when inputs do not change', () => {
|
||||
const agentId = 'memoized-agent';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result, rerender } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
const firstResult = result.current;
|
||||
|
||||
// Rerender without changing inputs
|
||||
rerender();
|
||||
|
||||
const secondResult = result.current;
|
||||
|
||||
// The hook returns a new object each time, but the values should be equal
|
||||
expect(firstResult.fileSearchAllowedByAgent).toBe(secondResult.fileSearchAllowedByAgent);
|
||||
expect(firstResult.codeAllowedByAgent).toBe(secondResult.codeAllowedByAgent);
|
||||
// Tools array reference should be the same since it comes from useMemo
|
||||
expect(firstResult.tools).toBe(secondResult.tools);
|
||||
|
||||
// Verify the actual values are correct
|
||||
expect(secondResult.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(secondResult.codeAllowedByAgent).toBe(false);
|
||||
expect(secondResult.tools).toEqual([Tools.file_search]);
|
||||
});
|
||||
|
||||
it('should recompute when agentId changes', () => {
|
||||
const agentId1 = 'agent-1';
|
||||
const agentId2 = 'agent-2';
|
||||
const mockAgents = {
|
||||
[agentId1]: { id: agentId1, tools: [Tools.file_search] },
|
||||
[agentId2]: { id: agentId2, tools: [Tools.execute_code] },
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue(mockAgents);
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ agentId }) => useAgentToolPermissions(agentId),
|
||||
{ initialProps: { agentId: agentId1 } }
|
||||
);
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
|
||||
// Change agentId
|
||||
rerender({ agentId: agentId2 });
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle switching between ephemeral and regular agents', () => {
|
||||
const regularAgentId = 'regular-agent';
|
||||
const mockAgent = {
|
||||
id: regularAgentId,
|
||||
tools: [],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[regularAgentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ agentId }) => useAgentToolPermissions(agentId),
|
||||
{ initialProps: { agentId: null } }
|
||||
);
|
||||
|
||||
// Start with ephemeral agent (null)
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
|
||||
// Switch to regular agent
|
||||
rerender({ agentId: regularAgentId });
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
|
||||
// Switch back to ephemeral
|
||||
rerender({ agentId: '' });
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle agents with null tools gracefully', () => {
|
||||
const agentId = 'agent-null-tools';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: null as any,
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle whitespace-only agentId as ephemeral', () => {
|
||||
// Note: Based on the current implementation, only empty string is treated as ephemeral
|
||||
// Whitespace-only strings would be treated as regular agent IDs
|
||||
const whitespaceId = ' ';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(whitespaceId));
|
||||
|
||||
// Whitespace ID is not considered ephemeral in current implementation
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle query loading state', () => {
|
||||
const agentId = 'loading-agent';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({
|
||||
data: undefined,
|
||||
isLoading: true,
|
||||
error: null,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
// During loading, should return false for non-ephemeral agents
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle query error state', () => {
|
||||
const agentId = 'error-agent';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({
|
||||
data: undefined,
|
||||
isLoading: false,
|
||||
error: new Error('Failed to fetch agent'),
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
// On error, should return false for non-ephemeral agents
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useMemo } from 'react';
|
||||
import { Tools } from 'librechat-data-provider';
|
||||
import { Tools, Constants } from 'librechat-data-provider';
|
||||
import { useGetAgentByIdQuery } from '~/data-provider';
|
||||
import { useAgentsMapContext } from '~/Providers';
|
||||
|
||||
@@ -9,6 +9,10 @@ interface AgentToolPermissionsResult {
|
||||
tools: string[] | undefined;
|
||||
}
|
||||
|
||||
function isEphemeralAgent(agentId: string | null | undefined): boolean {
|
||||
return agentId == null || agentId === '' || agentId === Constants.EPHEMERAL_AGENT_ID;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to determine whether specific tools are allowed for a given agent.
|
||||
*
|
||||
@@ -33,8 +37,8 @@ export default function useAgentToolPermissions(
|
||||
);
|
||||
|
||||
const fileSearchAllowedByAgent = useMemo(() => {
|
||||
// If no agentId, allow for ephemeral agents
|
||||
if (!agentId) return true;
|
||||
// Allow for ephemeral agents
|
||||
if (isEphemeralAgent(agentId)) return true;
|
||||
// If agentId exists but agent not found, disallow
|
||||
if (!selectedAgent) return false;
|
||||
// Check if the agent has the file_search tool
|
||||
@@ -42,8 +46,8 @@ export default function useAgentToolPermissions(
|
||||
}, [agentId, selectedAgent, tools]);
|
||||
|
||||
const codeAllowedByAgent = useMemo(() => {
|
||||
// If no agentId, allow for ephemeral agents
|
||||
if (!agentId) return true;
|
||||
// Allow for ephemeral agents
|
||||
if (isEphemeralAgent(agentId)) return true;
|
||||
// If agentId exists but agent not found, disallow
|
||||
if (!selectedAgent) return false;
|
||||
// Check if the agent has the execute_code tool
|
||||
|
||||
495
client/src/hooks/MCP/__tests__/useMCPSelect.test.tsx
Normal file
495
client/src/hooks/MCP/__tests__/useMCPSelect.test.tsx
Normal file
@@ -0,0 +1,495 @@
|
||||
import React from 'react';
|
||||
import { Provider, createStore } from 'jotai';
|
||||
import { renderHook, act, waitFor } from '@testing-library/react';
|
||||
import { RecoilRoot, useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
|
||||
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import { setTimestamp } from '~/utils/timestamps';
|
||||
import { useMCPSelect } from '../useMCPSelect';
|
||||
|
||||
// Mock dependencies
|
||||
jest.mock('~/utils/timestamps', () => ({
|
||||
setTimestamp: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('lodash/isEqual', () => jest.fn((a, b) => JSON.stringify(a) === JSON.stringify(b)));
|
||||
|
||||
const createWrapper = () => {
|
||||
// Create a new Jotai store for each test to ensure clean state
|
||||
const store = createStore();
|
||||
|
||||
const Wrapper: React.FC<{ children: React.ReactNode }> = ({ children }) => (
|
||||
<RecoilRoot>
|
||||
<Provider store={store}>{children}</Provider>
|
||||
</RecoilRoot>
|
||||
);
|
||||
return Wrapper;
|
||||
};
|
||||
|
||||
describe('useMCPSelect', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
describe('Basic Functionality', () => {
|
||||
it('should initialize with default values', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
expect(result.current.isPinned).toBe(true); // Default value from mcpPinnedAtom is true
|
||||
expect(typeof result.current.setMCPValues).toBe('function');
|
||||
expect(typeof result.current.setIsPinned).toBe('function');
|
||||
});
|
||||
|
||||
it('should use conversationId when provided', () => {
|
||||
const conversationId = 'test-convo-123';
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
|
||||
it('should use NEW_CONVO constant when conversationId is null', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId: null }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('State Updates', () => {
|
||||
it('should update mcpValues when setMCPValues is called', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const newValues = ['value1', 'value2'];
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(newValues);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(newValues);
|
||||
});
|
||||
});
|
||||
|
||||
it('should not update mcpValues if non-array is passed', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
// @ts-ignore - Testing invalid input
|
||||
result.current.setMCPValues('not-an-array');
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
|
||||
it('should update isPinned state', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Default is true
|
||||
expect(result.current.isPinned).toBe(true);
|
||||
|
||||
// Toggle to false
|
||||
act(() => {
|
||||
result.current.setIsPinned(false);
|
||||
});
|
||||
|
||||
expect(result.current.isPinned).toBe(false);
|
||||
|
||||
// Toggle back to true
|
||||
act(() => {
|
||||
result.current.setIsPinned(true);
|
||||
});
|
||||
|
||||
expect(result.current.isPinned).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Timestamp Management', () => {
|
||||
it('should set timestamp when mcpValues is updated with values', async () => {
|
||||
const conversationId = 'test-convo';
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const newValues = ['value1', 'value2'];
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(newValues);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const expectedKey = `${LocalStorageKeys.LAST_MCP_}${conversationId}`;
|
||||
expect(setTimestamp).toHaveBeenCalledWith(expectedKey);
|
||||
});
|
||||
});
|
||||
|
||||
it('should not set timestamp when mcpValues is empty', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues([]);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(setTimestamp).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Race Conditions and Infinite Loops Prevention', () => {
|
||||
it('should not create infinite loop when syncing between Jotai and Recoil states', async () => {
|
||||
const { result, rerender } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
let renderCount = 0;
|
||||
const maxRenders = 10;
|
||||
|
||||
// Track renders to detect infinite loops
|
||||
const trackRender = () => {
|
||||
renderCount++;
|
||||
if (renderCount > maxRenders) {
|
||||
throw new Error('Potential infinite loop detected');
|
||||
}
|
||||
};
|
||||
|
||||
// Set initial value
|
||||
act(() => {
|
||||
trackRender();
|
||||
result.current.setMCPValues(['initial']);
|
||||
});
|
||||
|
||||
// Trigger multiple rerenders
|
||||
for (let i = 0; i < 3; i++) {
|
||||
rerender();
|
||||
trackRender();
|
||||
}
|
||||
|
||||
// Should not exceed max renders
|
||||
expect(renderCount).toBeLessThanOrEqual(maxRenders);
|
||||
});
|
||||
|
||||
it('should handle rapid consecutive updates without race conditions', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const updates = [
|
||||
['value1'],
|
||||
['value1', 'value2'],
|
||||
['value1', 'value2', 'value3'],
|
||||
['value4'],
|
||||
[],
|
||||
];
|
||||
|
||||
// Rapid fire updates
|
||||
act(() => {
|
||||
updates.forEach((update) => {
|
||||
result.current.setMCPValues(update);
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
// Should settle on the last update
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
it('should maintain stable setter function reference', () => {
|
||||
const { result, rerender } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const firstSetMCPValues = result.current.setMCPValues;
|
||||
|
||||
// Trigger multiple rerenders
|
||||
rerender();
|
||||
rerender();
|
||||
rerender();
|
||||
|
||||
// Setter should remain the same reference (memoized)
|
||||
expect(result.current.setMCPValues).toBe(firstSetMCPValues);
|
||||
});
|
||||
|
||||
it('should handle switching conversation IDs without issues', async () => {
|
||||
const { result, rerender } = renderHook(
|
||||
({ conversationId }) => useMCPSelect({ conversationId }),
|
||||
{
|
||||
wrapper: createWrapper(),
|
||||
initialProps: { conversationId: 'convo1' },
|
||||
},
|
||||
);
|
||||
|
||||
// Set values for first conversation
|
||||
act(() => {
|
||||
result.current.setMCPValues(['convo1-value']);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(['convo1-value']);
|
||||
});
|
||||
|
||||
// Switch to different conversation
|
||||
rerender({ conversationId: 'convo2' });
|
||||
|
||||
// Should have different state for new conversation
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
|
||||
// Set values for second conversation
|
||||
act(() => {
|
||||
result.current.setMCPValues(['convo2-value']);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(['convo2-value']);
|
||||
});
|
||||
|
||||
// Switch back to first conversation
|
||||
rerender({ conversationId: 'convo1' });
|
||||
|
||||
// Should maintain separate state
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(['convo1-value']);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Ephemeral Agent Synchronization', () => {
|
||||
it('should sync mcpValues when ephemeralAgent is updated externally', async () => {
|
||||
// Create a shared wrapper for both hooks to share the same Recoil/Jotai context
|
||||
const wrapper = createWrapper();
|
||||
|
||||
// Create a component that uses both hooks to ensure they share state
|
||||
const TestComponent = () => {
|
||||
const mcpHook = useMCPSelect({});
|
||||
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(
|
||||
ephemeralAgentByConvoId(Constants.NEW_CONVO),
|
||||
);
|
||||
return { mcpHook, ephemeralAgent, setEphemeralAgent };
|
||||
};
|
||||
|
||||
const { result } = renderHook(() => TestComponent(), { wrapper });
|
||||
|
||||
// Simulate external update to ephemeralAgent (e.g., from another component)
|
||||
const externalMcpValues = ['external-value1', 'external-value2'];
|
||||
act(() => {
|
||||
result.current.setEphemeralAgent({
|
||||
mcp: externalMcpValues,
|
||||
});
|
||||
});
|
||||
|
||||
// The hook should sync with the external update
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(externalMcpValues);
|
||||
});
|
||||
});
|
||||
|
||||
it('should update ephemeralAgent when mcpValues changes through hook', async () => {
|
||||
// Create a shared wrapper for both hooks
|
||||
const wrapper = createWrapper();
|
||||
|
||||
// Create a component that uses both the hook and accesses Recoil state
|
||||
const TestComponent = () => {
|
||||
const mcpHook = useMCPSelect({});
|
||||
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(Constants.NEW_CONVO));
|
||||
return { mcpHook, ephemeralAgent };
|
||||
};
|
||||
|
||||
const { result } = renderHook(() => TestComponent(), { wrapper });
|
||||
|
||||
const newValues = ['hook-value1', 'hook-value2'];
|
||||
|
||||
act(() => {
|
||||
result.current.mcpHook.setMCPValues(newValues);
|
||||
});
|
||||
|
||||
// Verify both mcpValues and ephemeralAgent are updated
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(newValues);
|
||||
expect(result.current.ephemeralAgent?.mcp).toEqual(newValues);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle empty ephemeralAgent.mcp array correctly', async () => {
|
||||
// Create a shared wrapper
|
||||
const wrapper = createWrapper();
|
||||
|
||||
// Create a component that uses both hooks
|
||||
const TestComponent = () => {
|
||||
const mcpHook = useMCPSelect({});
|
||||
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
|
||||
return { mcpHook, setEphemeralAgent };
|
||||
};
|
||||
|
||||
const { result } = renderHook(() => TestComponent(), { wrapper });
|
||||
|
||||
// Set initial values
|
||||
act(() => {
|
||||
result.current.mcpHook.setMCPValues(['initial-value']);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
|
||||
});
|
||||
|
||||
// Try to set empty array externally
|
||||
act(() => {
|
||||
result.current.setEphemeralAgent({
|
||||
mcp: [],
|
||||
});
|
||||
});
|
||||
|
||||
// Values should remain unchanged since empty mcp array doesn't trigger update
|
||||
// (due to the condition: ephemeralAgent?.mcp && ephemeralAgent.mcp.length > 0)
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
|
||||
});
|
||||
|
||||
it('should properly sync non-empty arrays from ephemeralAgent', async () => {
|
||||
// Additional test to ensure non-empty arrays DO sync
|
||||
const wrapper = createWrapper();
|
||||
|
||||
const TestComponent = () => {
|
||||
const mcpHook = useMCPSelect({});
|
||||
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
|
||||
return { mcpHook, setEphemeralAgent };
|
||||
};
|
||||
|
||||
const { result } = renderHook(() => TestComponent(), { wrapper });
|
||||
|
||||
// Set initial values through ephemeralAgent with non-empty array
|
||||
const initialValues = ['value1', 'value2'];
|
||||
act(() => {
|
||||
result.current.setEphemeralAgent({
|
||||
mcp: initialValues,
|
||||
});
|
||||
});
|
||||
|
||||
// Should sync since it's non-empty
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(initialValues);
|
||||
});
|
||||
|
||||
// Update with different non-empty values
|
||||
const updatedValues = ['value3', 'value4', 'value5'];
|
||||
act(() => {
|
||||
result.current.setEphemeralAgent({
|
||||
mcp: updatedValues,
|
||||
});
|
||||
});
|
||||
|
||||
// Should sync the new values
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(updatedValues);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle undefined conversationId', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId: undefined }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(['test']);
|
||||
});
|
||||
|
||||
expect(() => result.current).not.toThrow();
|
||||
});
|
||||
|
||||
it('should handle empty string conversationId', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId: '' }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle very large arrays without performance issues', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const largeArray = Array.from({ length: 1000 }, (_, i) => `value-${i}`);
|
||||
|
||||
const startTime = performance.now();
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(largeArray);
|
||||
});
|
||||
|
||||
const endTime = performance.now();
|
||||
const executionTime = endTime - startTime;
|
||||
|
||||
// Should complete within reasonable time (< 100ms)
|
||||
expect(executionTime).toBeLessThan(100);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(largeArray);
|
||||
});
|
||||
});
|
||||
|
||||
it('should cleanup properly on unmount', () => {
|
||||
const { unmount } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Should unmount without errors
|
||||
expect(() => unmount()).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Memory Leak Prevention', () => {
|
||||
it('should not leak memory on repeated updates', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Perform many updates to test for memory leaks
|
||||
for (let i = 0; i < 100; i++) {
|
||||
act(() => {
|
||||
result.current.setMCPValues([`value-${i}`]);
|
||||
});
|
||||
}
|
||||
|
||||
// If we get here without crashing, memory management is likely OK
|
||||
expect(result.current.mcpValues).toEqual(['value-99']);
|
||||
});
|
||||
|
||||
it('should handle component remounting', () => {
|
||||
const { result, unmount } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(['before-unmount']);
|
||||
});
|
||||
|
||||
unmount();
|
||||
|
||||
// Remount
|
||||
const { result: newResult } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Should handle remounting gracefully
|
||||
expect(newResult.current.mcpValues).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -3,3 +3,4 @@ export * from './useMCPConnectionStatus';
|
||||
export * from './useMCPSelect';
|
||||
export * from './useVisibleTools';
|
||||
export { useMCPServerManager } from './useMCPServerManager';
|
||||
export { useRemoveMCPTool } from './useRemoveMCPTool';
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import { useAtom } from 'jotai';
|
||||
import isEqual from 'lodash/isEqual';
|
||||
import { useRecoilState } from 'recoil';
|
||||
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import { ephemeralAgentByConvoId, mcpValuesAtomFamily, mcpPinnedAtom } from '~/store';
|
||||
@@ -19,15 +20,14 @@ export function useMCPSelect({ conversationId }: { conversationId?: string | nul
|
||||
}
|
||||
}, [ephemeralAgent?.mcp, setMCPValuesRaw]);
|
||||
|
||||
// Update ephemeral agent when Jotai state changes
|
||||
useEffect(() => {
|
||||
if (mcpValues.length > 0 && JSON.stringify(mcpValues) !== JSON.stringify(ephemeralAgent?.mcp)) {
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
mcp: mcpValues,
|
||||
}));
|
||||
}
|
||||
}, [mcpValues, ephemeralAgent?.mcp, setEphemeralAgent]);
|
||||
setEphemeralAgent((prev) => {
|
||||
if (!isEqual(prev?.mcp, mcpValues)) {
|
||||
return { ...(prev ?? {}), mcp: mcpValues };
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
}, [mcpValues, setEphemeralAgent]);
|
||||
|
||||
useEffect(() => {
|
||||
const mcpStorageKey = `${LocalStorageKeys.LAST_MCP_}${key}`;
|
||||
|
||||
61
client/src/hooks/MCP/useRemoveMCPTool.ts
Normal file
61
client/src/hooks/MCP/useRemoveMCPTool.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import { useCallback } from 'react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useToastContext } from '@librechat/client';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import type { AgentForm } from '~/common';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
/**
|
||||
* Hook for removing MCP tools/servers from an agent
|
||||
* Provides unified logic for MCPTool, UninitializedMCPTool, and UnconfiguredMCPTool components
|
||||
*/
|
||||
export function useRemoveMCPTool() {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const updateUserPlugins = useUpdateUserPluginsMutation();
|
||||
const { getValues, setValue } = useFormContext<AgentForm>();
|
||||
|
||||
const removeTool = useCallback(
|
||||
(serverName: string) => {
|
||||
if (!serverName) {
|
||||
return;
|
||||
}
|
||||
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'uninstall',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => {
|
||||
showToast({
|
||||
message: localize('com_ui_delete_tool_error', { error: String(error) }),
|
||||
status: 'error',
|
||||
});
|
||||
},
|
||||
onSuccess: () => {
|
||||
const currentTools = getValues('tools');
|
||||
const remainingToolIds =
|
||||
currentTools?.filter(
|
||||
(currentToolId) =>
|
||||
currentToolId !== serverName &&
|
||||
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
|
||||
) || [];
|
||||
setValue('tools', remainingToolIds, { shouldDirty: true });
|
||||
|
||||
showToast({
|
||||
message: localize('com_ui_delete_tool_save_reminder'),
|
||||
status: 'warning',
|
||||
});
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
[getValues, setValue, updateUserPlugins, showToast, localize],
|
||||
);
|
||||
|
||||
return { removeTool };
|
||||
}
|
||||
@@ -2,7 +2,7 @@ import throttle from 'lodash/throttle';
|
||||
import { useEffect, useRef, useCallback, useMemo } from 'react';
|
||||
import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import type { TMessageProps } from '~/common';
|
||||
import { useChatContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
|
||||
import { useMessagesViewContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
|
||||
import useCopyToClipboard from './useCopyToClipboard';
|
||||
import { getTextKey, logger } from '~/utils';
|
||||
|
||||
@@ -20,9 +20,9 @@ export default function useMessageHelpers(props: TMessageProps) {
|
||||
setAbortScroll,
|
||||
handleContinue,
|
||||
setLatestMessage,
|
||||
} = useChatContext();
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
} = useMessagesViewContext();
|
||||
const agentsMap = useAgentsMapContext();
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
|
||||
const { text, content, children, messageId = null, isCreatedByUser } = message ?? {};
|
||||
const edit = messageId === currentEditId;
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useRecoilValue } from 'recoil';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useEffect, useRef, useCallback, useMemo, useState } from 'react';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
import { useMessagesViewContext } from '~/Providers';
|
||||
import { getTextKey, logger } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
@@ -18,14 +18,9 @@ export default function useMessageProcess({ message }: { message?: TMessage | nu
|
||||
latestMessage,
|
||||
setAbortScroll,
|
||||
setLatestMessage,
|
||||
isSubmitting: isSubmittingRoot,
|
||||
} = useChatContext();
|
||||
const { isSubmitting: isSubmittingAdditional } = useAddedChatContext();
|
||||
isSubmittingFamily,
|
||||
} = useMessagesViewContext();
|
||||
const latestMultiMessage = useRecoilValue(store.latestMessageFamily(index + 1));
|
||||
const isSubmittingFamily = useMemo(
|
||||
() => isSubmittingRoot || isSubmittingAdditional,
|
||||
[isSubmittingRoot, isSubmittingAdditional],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const convoId = conversation?.conversationId;
|
||||
|
||||
@@ -2,8 +2,8 @@ import { useRecoilValue } from 'recoil';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { useMessagesConversation, useMessagesSubmission } from '~/Providers';
|
||||
import useScrollToRef from '~/hooks/useScrollToRef';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import store from '~/store';
|
||||
|
||||
const threshold = 0.85;
|
||||
@@ -15,8 +15,8 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
|
||||
const scrollableRef = useRef<HTMLDivElement | null>(null);
|
||||
const messagesEndRef = useRef<HTMLDivElement | null>(null);
|
||||
const [showScrollButton, setShowScrollButton] = useState(false);
|
||||
const { conversation, setAbortScroll, isSubmitting, abortScroll } = useChatContext();
|
||||
const { conversationId } = conversation ?? {};
|
||||
const { conversation, conversationId } = useMessagesConversation();
|
||||
const { setAbortScroll, isSubmitting, abortScroll } = useMessagesSubmission();
|
||||
|
||||
const timeoutIdRef = useRef<NodeJS.Timeout>();
|
||||
|
||||
|
||||
@@ -833,7 +833,6 @@
|
||||
"com_ui_delete_tool": "Werkzeug löschen",
|
||||
"com_ui_delete_tool_confirm": "Bist du sicher, dass du dieses Werkzeug löschen möchtest?",
|
||||
"com_ui_delete_tool_error": "Fehler beim Löschen des Tools: {{error}}",
|
||||
"com_ui_delete_tool_success": "Tool erfolgreich gelöscht",
|
||||
"com_ui_deleted": "Gelöscht",
|
||||
"com_ui_deleting_file": "Lösche Datei...",
|
||||
"com_ui_descending": "Absteigend",
|
||||
|
||||
@@ -834,7 +834,7 @@
|
||||
"com_ui_delete_tool": "Delete Tool",
|
||||
"com_ui_delete_tool_confirm": "Are you sure you want to delete this tool?",
|
||||
"com_ui_delete_tool_error": "Error while deleting the tool: {{error}}",
|
||||
"com_ui_delete_tool_success": "Tool deleted successfully",
|
||||
"com_ui_delete_tool_save_reminder": "Tool removed. Save the agent to apply changes.",
|
||||
"com_ui_deleted": "Deleted",
|
||||
"com_ui_deleting_file": "Deleting file...",
|
||||
"com_ui_descending": "Desc",
|
||||
|
||||
@@ -337,7 +337,7 @@
|
||||
"com_endpoint_prompt_prefix_assistants": "Papildu instrukcijas",
|
||||
"com_endpoint_prompt_prefix_assistants_placeholder": "Iestatiet papildu norādījumus vai kontekstu virs Asistenta galvenajiem norādījumiem. Ja lauks ir tukšs, tas tiek ignorēts.",
|
||||
"com_endpoint_prompt_prefix_placeholder": "Iestatiet pielāgotas instrukcijas vai kontekstu. Ja lauks ir tukšs, tas tiek ignorēts.",
|
||||
"com_endpoint_reasoning_effort": "Spriešanas piepūle",
|
||||
"com_endpoint_reasoning_effort": "Spriešanas līmenis",
|
||||
"com_endpoint_reasoning_summary": "Spriešanas kopsavilkums",
|
||||
"com_endpoint_save_as_preset": "Saglabāt kā iestatījumu",
|
||||
"com_endpoint_search": "Meklēt galapunktu pēc nosaukuma",
|
||||
@@ -834,7 +834,7 @@
|
||||
"com_ui_delete_tool": "Dzēst rīku",
|
||||
"com_ui_delete_tool_confirm": "Vai tiešām vēlaties dzēst šo rīku?",
|
||||
"com_ui_delete_tool_error": "Kļūda, dzēšot rīku: {{error}}",
|
||||
"com_ui_delete_tool_success": "Rīks veiksmīgi izdzēsts",
|
||||
"com_ui_delete_tool_save_reminder": "Rīks noņemts. Saglabājiet aģentu, lai piemērotu izmaiņas.",
|
||||
"com_ui_deleted": "Dzēsts",
|
||||
"com_ui_deleting_file": "Dzēšu failu...",
|
||||
"com_ui_descending": "Dilstošs",
|
||||
@@ -1012,7 +1012,7 @@
|
||||
"com_ui_memory_would_exceed": "Nevar saglabāt - pārsniegtu tokenu limitu par {{tokens}}. Izdzēsiet esošās atmiņas, lai atbrīvotu vietu.",
|
||||
"com_ui_mention": "Pieminiet galapunktu, assistentu vai iestatījumu, lai ātri uz to pārslēgtos",
|
||||
"com_ui_min_tags": "Nevar noņemt vairāk vērtību, vismaz {{0}} ir nepieciešamas.",
|
||||
"com_ui_minimal": "Minimāla",
|
||||
"com_ui_minimal": "Minimāls",
|
||||
"com_ui_misc": "Dažādi",
|
||||
"com_ui_model": "Modelis",
|
||||
"com_ui_model_parameters": "Modeļa Parametrus",
|
||||
@@ -1035,7 +1035,7 @@
|
||||
"com_ui_no_results_found": "Nav atrastu rezultātu",
|
||||
"com_ui_no_terms_content": "Nav noteikumu un nosacījumu satura, ko parādīt",
|
||||
"com_ui_no_valid_items": "Nav rezultātu",
|
||||
"com_ui_none": "Neviens",
|
||||
"com_ui_none": "Nekāds",
|
||||
"com_ui_not_used": "Nav izmantots",
|
||||
"com_ui_nothing_found": "Nekas nav atrasts",
|
||||
"com_ui_oauth": "OAuth",
|
||||
|
||||
@@ -834,7 +834,6 @@
|
||||
"com_ui_delete_tool": "Slett verktøy",
|
||||
"com_ui_delete_tool_confirm": "Er du sikker på at du vil slette dette verktøyet?",
|
||||
"com_ui_delete_tool_error": "En feil oppstod ved sletting av verktøyet: {{error}}",
|
||||
"com_ui_delete_tool_success": "Verktøyet ble slettet",
|
||||
"com_ui_deleted": "Slettet",
|
||||
"com_ui_deleting_file": "Sletter fil ...",
|
||||
"com_ui_descending": "Synkende",
|
||||
|
||||
@@ -27,6 +27,13 @@
|
||||
"com_agents_file_search_disabled": "Maak eerst een Agent aan voordat je bestanden uploadt voor File Search.",
|
||||
"com_agents_file_search_info": "Als deze functie is ingeschakeld, krijgt de agent informatie over de exacte bestandsnamen die hieronder staan vermeld, zodat deze relevante context uit deze bestanden kan ophalen.",
|
||||
"com_agents_instructions_placeholder": "De systeeminstructies die de agent gebruikt",
|
||||
"com_agents_link_copied": "Link gekopieerd",
|
||||
"com_agents_link_copy_failed": "Link niet gekopieerd",
|
||||
"com_agents_load_more_label": "Laad meer agenten van {{category}} categorie",
|
||||
"com_agents_loading": "Aan het laden...",
|
||||
"com_agents_mcp_icon_size": "Minimum formaat 128 x 128 px",
|
||||
"com_agents_mcp_info": "MCP-servers toevoegen aan je agent zodat deze taken kan uitvoeren en kan communiceren met externe services",
|
||||
"com_agents_mcp_name_placeholder": "Aangepast hulpmiddel",
|
||||
"com_agents_missing_provider_model": "Selecteer een provider en model voordat je een agent aanmaakt.",
|
||||
"com_agents_name_placeholder": "De naam van de agent",
|
||||
"com_agents_no_access": "Je hebt geen toegang om deze agent te bewerken.",
|
||||
|
||||
@@ -834,7 +834,6 @@
|
||||
"com_ui_delete_tool": "Видалити інструмент",
|
||||
"com_ui_delete_tool_confirm": "Ви дійсно хочете видалити цей інструмент?",
|
||||
"com_ui_delete_tool_error": "Помилка під час видалення інструменту: {{error}}",
|
||||
"com_ui_delete_tool_success": "Інструмент успішно видалено",
|
||||
"com_ui_deleted": "Видалено",
|
||||
"com_ui_deleting_file": "Видалення файлу...",
|
||||
"com_ui_descending": "За спаданням",
|
||||
|
||||
@@ -834,7 +834,6 @@
|
||||
"com_ui_delete_tool": "删除工具",
|
||||
"com_ui_delete_tool_confirm": "您确定要删除此工具吗?",
|
||||
"com_ui_delete_tool_error": "删除工具时发生错误:{{error}}",
|
||||
"com_ui_delete_tool_success": "工具删除成功",
|
||||
"com_ui_deleted": "已删除",
|
||||
"com_ui_deleting_file": "删除文件中...",
|
||||
"com_ui_descending": "降序",
|
||||
|
||||
@@ -49,7 +49,7 @@ export default defineConfig(({ command }) => ({
|
||||
],
|
||||
globIgnores: ['images/**/*', '**/*.map', 'index.html'],
|
||||
maximumFileSizeToCacheInBytes: 4 * 1024 * 1024,
|
||||
navigateFallbackDenylist: [/^\/oauth/, /^\/api/],
|
||||
navigateFallbackDenylist: [/^\/oauth/, /^\/api/, /^\/admin\/openid/],
|
||||
},
|
||||
includeAssets: [],
|
||||
manifest: {
|
||||
|
||||
@@ -10,6 +10,9 @@ jest.mock('form-data', () => {
|
||||
getLength: jest.fn().mockReturnValue(100),
|
||||
}));
|
||||
});
|
||||
jest.mock('https-proxy-agent', () => ({
|
||||
HttpsProxyAgent: jest.fn().mockImplementation((url) => ({ proxyUrl: url })),
|
||||
}));
|
||||
jest.mock('axios', () => {
|
||||
const mockAxiosInstance = {
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
@@ -44,6 +47,7 @@ jest.mock('~/utils/axios', () => ({
|
||||
|
||||
import * as fs from 'fs';
|
||||
import axios from 'axios';
|
||||
import { HttpsProxyAgent } from 'https-proxy-agent';
|
||||
import type { Readable } from 'stream';
|
||||
import type {
|
||||
MistralFileUploadResponse,
|
||||
@@ -1182,6 +1186,8 @@ describe('MistralOCR Service', () => {
|
||||
|
||||
describe('Mixed env var and hardcoded configuration', () => {
|
||||
beforeEach(() => {
|
||||
// Clean up any PROXY env var from previous tests
|
||||
delete process.env.PROXY;
|
||||
const mockReadStream: MockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (
|
||||
this: MockReadStream,
|
||||
@@ -1708,9 +1714,403 @@ describe('MistralOCR Service', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Proxy Configuration', () => {
|
||||
const originalProxy = process.env.PROXY;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset the HttpsProxyAgent mock to its default implementation
|
||||
(HttpsProxyAgent as unknown as jest.Mock).mockImplementation((url) => ({ proxyUrl: url }));
|
||||
// Clear any previous axios mock calls
|
||||
mockAxios.post!.mockClear();
|
||||
mockAxios.get!.mockClear();
|
||||
mockAxios.delete!.mockClear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
if (originalProxy) {
|
||||
process.env.PROXY = originalProxy;
|
||||
} else {
|
||||
delete process.env.PROXY;
|
||||
}
|
||||
// Clear mocks after each test to prevent leaking
|
||||
mockAxios.post!.mockClear();
|
||||
mockAxios.get!.mockClear();
|
||||
mockAxios.delete!.mockClear();
|
||||
});
|
||||
|
||||
describe('uploadDocumentToMistral with proxy', () => {
|
||||
beforeEach(() => {
|
||||
const mockReadStream: MockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (
|
||||
this: MockReadStream,
|
||||
event: string,
|
||||
handler: () => void,
|
||||
) {
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function (this: MockReadStream) {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
path: '/path/to/test.pdf',
|
||||
fd: 1,
|
||||
flags: 'r',
|
||||
mode: 0o666,
|
||||
autoClose: true,
|
||||
bytesRead: 0,
|
||||
closed: false,
|
||||
pending: false,
|
||||
};
|
||||
|
||||
(jest.mocked(fs).createReadStream as jest.Mock).mockReturnValue(mockReadStream);
|
||||
});
|
||||
|
||||
it('should use proxy configuration when PROXY env var is set', async () => {
|
||||
process.env.PROXY = 'http://proxy.example.com:8080';
|
||||
|
||||
const mockResponse: { data: MistralFileUploadResponse } = {
|
||||
data: {
|
||||
id: 'file-proxy-123',
|
||||
object: 'file',
|
||||
bytes: 1024,
|
||||
created_at: Date.now(),
|
||||
filename: 'test.pdf',
|
||||
purpose: 'ocr',
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://proxy.example.com:8080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle proxy URL with authentication', async () => {
|
||||
process.env.PROXY = 'http://user:pass@proxy.example.com:8080';
|
||||
|
||||
const mockResponse: { data: MistralFileUploadResponse } = {
|
||||
data: {
|
||||
id: 'file-proxy-auth-123',
|
||||
object: 'file',
|
||||
bytes: 1024,
|
||||
created_at: Date.now(),
|
||||
filename: 'test.pdf',
|
||||
purpose: 'ocr',
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://user:pass@proxy.example.com:8080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle IPv6 proxy addresses', async () => {
|
||||
process.env.PROXY = 'http://[::1]:8080';
|
||||
|
||||
const mockResponse: { data: MistralFileUploadResponse } = {
|
||||
data: {
|
||||
id: 'file-proxy-ipv6-123',
|
||||
object: 'file',
|
||||
bytes: 1024,
|
||||
created_at: Date.now(),
|
||||
filename: 'test.pdf',
|
||||
purpose: 'ocr',
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://[::1]:8080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not use proxy when PROXY env var is not set', async () => {
|
||||
delete process.env.PROXY;
|
||||
|
||||
const mockResponse: { data: MistralFileUploadResponse } = {
|
||||
data: {
|
||||
id: 'file-no-proxy-123',
|
||||
object: 'file',
|
||||
bytes: 1024,
|
||||
created_at: Date.now(),
|
||||
filename: 'test.pdf',
|
||||
purpose: 'ocr',
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.not.objectContaining({
|
||||
httpsAgent: expect.anything(),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('performOCR with proxy', () => {
|
||||
it('should use proxy configuration when PROXY env var is set', async () => {
|
||||
process.env.PROXY = 'http://proxy.example.com:3128';
|
||||
|
||||
const mockResponse: { data: OCRResult } = {
|
||||
data: {
|
||||
model: 'mistral-ocr-latest',
|
||||
pages: [
|
||||
{
|
||||
index: 0,
|
||||
markdown: 'Proxy test content',
|
||||
images: [],
|
||||
dimensions: { dpi: 300, height: 1100, width: 850 },
|
||||
},
|
||||
],
|
||||
document_annotation: '',
|
||||
usage_info: {
|
||||
pages_processed: 1,
|
||||
doc_size_bytes: 1024,
|
||||
},
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://document-url.com',
|
||||
model: 'mistral-ocr-latest',
|
||||
documentType: 'document_url',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://proxy.example.com:3128',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle malformed proxy URLs gracefully', async () => {
|
||||
(HttpsProxyAgent as unknown as jest.Mock).mockImplementationOnce(() => {
|
||||
throw new Error('Invalid URL');
|
||||
});
|
||||
process.env.PROXY = 'not-a-valid-url';
|
||||
|
||||
const mockResponse: { data: OCRResult } = {
|
||||
data: {
|
||||
model: 'mistral-ocr-latest',
|
||||
pages: [
|
||||
{
|
||||
index: 0,
|
||||
markdown: 'Test content',
|
||||
images: [],
|
||||
dimensions: { dpi: 300, height: 1100, width: 850 },
|
||||
},
|
||||
],
|
||||
document_annotation: '',
|
||||
usage_info: {
|
||||
pages_processed: 1,
|
||||
doc_size_bytes: 1024,
|
||||
},
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await expect(
|
||||
performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://document-url.com',
|
||||
}),
|
||||
).rejects.toThrow('Invalid URL');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Azure Mistral OCR with proxy', () => {
|
||||
beforeEach(() => {
|
||||
(jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue(
|
||||
Buffer.from('mock-file-content'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use proxy for Azure Mistral OCR requests', async () => {
|
||||
process.env.PROXY = 'http://proxy.example.com:8080';
|
||||
|
||||
mockLoadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'azure-api-key',
|
||||
OCR_BASEURL: 'https://azure.mistral.ai/v1',
|
||||
});
|
||||
|
||||
mockAxios.post!.mockResolvedValueOnce({
|
||||
data: {
|
||||
model: 'mistral-ocr-latest',
|
||||
pages: [
|
||||
{
|
||||
index: 0,
|
||||
markdown: 'Azure OCR with proxy',
|
||||
images: [],
|
||||
dimensions: { dpi: 300, height: 1100, width: 850 },
|
||||
},
|
||||
],
|
||||
document_annotation: '',
|
||||
usage_info: {
|
||||
pages_processed: 1,
|
||||
doc_size_bytes: 1024,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
config: {
|
||||
ocr: {
|
||||
apiKey: '${OCR_API_KEY}',
|
||||
baseURL: '${OCR_BASEURL}',
|
||||
mistralModel: 'mistral-ocr-latest',
|
||||
},
|
||||
},
|
||||
} as unknown as ServerRequest;
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/azure-file.pdf',
|
||||
originalname: 'azure-document.pdf',
|
||||
mimetype: 'application/pdf',
|
||||
} as Express.Multer.File;
|
||||
|
||||
await uploadAzureMistralOCR({
|
||||
req,
|
||||
file,
|
||||
loadAuthValues: mockLoadAuthValues,
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://azure.mistral.ai/v1/ocr',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://proxy.example.com:8080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSignedUrl with proxy', () => {
|
||||
it('should use proxy configuration when PROXY env var is set', async () => {
|
||||
process.env.PROXY = 'https://secure-proxy.example.com:443';
|
||||
|
||||
const mockResponse: { data: MistralSignedUrlResponse } = {
|
||||
data: {
|
||||
url: 'https://signed-url.com',
|
||||
expires_at: Date.now() + 86400000,
|
||||
},
|
||||
};
|
||||
mockAxios.get!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.get).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'https://secure-proxy.example.com:443',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMistralFile with proxy', () => {
|
||||
it('should use proxy configuration when PROXY env var is set', async () => {
|
||||
process.env.PROXY = 'socks5://proxy.example.com:1080';
|
||||
|
||||
mockAxios.delete!.mockResolvedValueOnce({ data: {} });
|
||||
|
||||
await deleteMistralFile({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.delete).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files/file-123',
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'socks5://proxy.example.com:1080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('uploadAzureMistralOCR', () => {
|
||||
beforeEach(() => {
|
||||
(jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue(Buffer.from('mock-file-content'));
|
||||
// Reset the HttpsProxyAgent mock to its default implementation for Azure tests
|
||||
(HttpsProxyAgent as unknown as jest.Mock).mockImplementation((url) => ({ proxyUrl: url }));
|
||||
// Clean up any PROXY env var from previous tests
|
||||
delete process.env.PROXY;
|
||||
// Reset axios mocks completely to clear any queued responses
|
||||
mockAxios.post!.mockReset();
|
||||
mockAxios.get!.mockReset();
|
||||
mockAxios.delete!.mockReset();
|
||||
// Re-establish default resolved values
|
||||
mockAxios.post!.mockResolvedValue({ data: {} });
|
||||
mockAxios.get!.mockResolvedValue({ data: {} });
|
||||
mockAxios.delete!.mockResolvedValue({ data: {} });
|
||||
});
|
||||
|
||||
it('should process OCR using Azure Mistral with base64 encoding', async () => {
|
||||
@@ -1796,6 +2196,11 @@ describe('MistralOCR Service', () => {
|
||||
});
|
||||
|
||||
describe('Mixed env var and hardcoded configuration', () => {
|
||||
beforeEach(() => {
|
||||
// Clean up any PROXY env var from previous tests
|
||||
delete process.env.PROXY;
|
||||
});
|
||||
|
||||
it('should preserve hardcoded baseURL when only apiKey is an env var', async () => {
|
||||
// This test demonstrates the current bug
|
||||
mockLoadAuthValues.mockResolvedValue({
|
||||
|
||||
@@ -2,6 +2,7 @@ import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import FormData from 'form-data';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { HttpsProxyAgent } from 'https-proxy-agent';
|
||||
import {
|
||||
FileSources,
|
||||
envVarRegex,
|
||||
@@ -9,7 +10,7 @@ import {
|
||||
extractVariableName,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TCustomConfig } from 'librechat-data-provider';
|
||||
import type { AxiosError } from 'axios';
|
||||
import type { AxiosError, AxiosRequestConfig } from 'axios';
|
||||
import type {
|
||||
MistralFileUploadResponse,
|
||||
MistralSignedUrlResponse,
|
||||
@@ -77,15 +78,21 @@ export async function uploadDocumentToMistral({
|
||||
const fileStream = fs.createReadStream(filePath);
|
||||
form.append('file', fileStream, { filename: actualFileName });
|
||||
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...form.getHeaders(),
|
||||
},
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
return axios
|
||||
.post(`${baseURL}/files`, form, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...form.getHeaders(),
|
||||
},
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
})
|
||||
.post(`${baseURL}/files`, form, config)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
throw error;
|
||||
@@ -103,12 +110,18 @@ export async function getSignedUrl({
|
||||
expiry?: number;
|
||||
baseURL?: string;
|
||||
}): Promise<MistralSignedUrlResponse> {
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
return axios
|
||||
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
})
|
||||
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, config)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error fetching signed URL:', error.message);
|
||||
@@ -139,6 +152,18 @@ export async function performOCR({
|
||||
documentType?: 'document_url' | 'image_url';
|
||||
}): Promise<OCRResult> {
|
||||
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
|
||||
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
return axios
|
||||
.post(
|
||||
`${baseURL}/ocr`,
|
||||
@@ -151,12 +176,7 @@ export async function performOCR({
|
||||
[documentKey]: url,
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
},
|
||||
config,
|
||||
)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
@@ -182,12 +202,18 @@ export async function deleteMistralFile({
|
||||
apiKey: string;
|
||||
baseURL?: string;
|
||||
}): Promise<void> {
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await axios.delete(`${baseURL}/files/${fileId}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
});
|
||||
const result = await axios.delete(`${baseURL}/files/${fileId}`, config);
|
||||
logger.debug(`Mistral file ${fileId} deleted successfully:`, result.data);
|
||||
} catch (error) {
|
||||
logger.error(`Error deleting Mistral file ${fileId}:`, error);
|
||||
@@ -543,17 +569,23 @@ async function createJWT(serviceKey: GoogleServiceAccount): Promise<string> {
|
||||
* Exchanges JWT for access token
|
||||
*/
|
||||
async function exchangeJWTForAccessToken(jwt: string): Promise<string> {
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
const response = await axios.post(
|
||||
'https://oauth2.googleapis.com/token',
|
||||
new URLSearchParams({
|
||||
grant_type: 'urn:ietf:params:oauth:grant-type:jwt-bearer',
|
||||
assertion: jwt,
|
||||
}),
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
},
|
||||
config,
|
||||
);
|
||||
|
||||
if (!response.data?.access_token) {
|
||||
@@ -608,14 +640,20 @@ async function performGoogleVertexOCR({
|
||||
},
|
||||
});
|
||||
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
Accept: 'application/json',
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
return axios
|
||||
.post(baseURL, requestBody, {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
Accept: 'application/json',
|
||||
},
|
||||
})
|
||||
.post(baseURL, requestBody, config)
|
||||
.then((res) => {
|
||||
logger.debug('Google Vertex AI response received');
|
||||
return res.data;
|
||||
|
||||
@@ -14,6 +14,7 @@ export * from './utils';
|
||||
export * from './db/utils';
|
||||
/* OAuth */
|
||||
export * from './oauth';
|
||||
export * from './mcp/oauth/OAuthReconnectionManager';
|
||||
/* Crypto */
|
||||
export * from './crypto';
|
||||
/* Flow */
|
||||
|
||||
@@ -6,6 +6,7 @@ import type { FlowStateManager } from '~/flow/manager';
|
||||
import type { FlowMetadata } from '~/flow/types';
|
||||
import type * as t from './types';
|
||||
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { MCPConnection } from './connection';
|
||||
import { processMCPEnv } from '~/utils';
|
||||
|
||||
@@ -308,7 +309,9 @@ export class MCPConnectionFactory {
|
||||
metadata?: OAuthMetadata;
|
||||
} | null> {
|
||||
const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url;
|
||||
logger.debug(`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`);
|
||||
logger.debug(
|
||||
`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
|
||||
);
|
||||
|
||||
if (!this.flowManager || !serverUrl) {
|
||||
logger.error(
|
||||
|
||||
@@ -7,6 +7,7 @@ import type { JsonSchemaType } from '~/types';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
import { processMCPEnv } from '~/utils';
|
||||
import { CONSTANTS } from '~/mcp/enum';
|
||||
|
||||
@@ -183,7 +184,7 @@ export class MCPServersRegistry {
|
||||
const prefix = this.prefix(serverName);
|
||||
const config = this.parsedConfigs[serverName];
|
||||
logger.info(`${prefix} -------------------------------------------------┐`);
|
||||
logger.info(`${prefix} URL: ${config.url}`);
|
||||
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
|
||||
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
|
||||
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
|
||||
logger.info(`${prefix} Tools: ${config.tools}`);
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import { EventEmitter } from 'events';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { fetch as undiciFetch, Agent } from 'undici';
|
||||
import {
|
||||
StdioClientTransport,
|
||||
getDefaultEnvironment,
|
||||
} from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
|
||||
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type {
|
||||
@@ -18,8 +18,9 @@ import type {
|
||||
Response as UndiciResponse,
|
||||
} from 'undici';
|
||||
import type { MCPOAuthTokens } from './oauth/types';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
import type * as t from './types';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
|
||||
type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;
|
||||
|
||||
@@ -238,7 +239,9 @@ export class MCPConnection extends EventEmitter {
|
||||
}
|
||||
this.url = options.url;
|
||||
const url = new URL(options.url);
|
||||
logger.info(`${this.getLogPrefix()} Creating SSE transport: ${url.toString()}`);
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} Creating SSE transport: ${sanitizeUrlForLogging(url)}`,
|
||||
);
|
||||
const abortController = new AbortController();
|
||||
|
||||
/** Add OAuth token to headers if available */
|
||||
@@ -293,7 +296,7 @@ export class MCPConnection extends EventEmitter {
|
||||
this.url = options.url;
|
||||
const url = new URL(options.url);
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} Creating streamable-http transport: ${url.toString()}`,
|
||||
`${this.getLogPrefix()} Creating streamable-http transport: ${sanitizeUrlForLogging(url)}`,
|
||||
);
|
||||
const abortController = new AbortController();
|
||||
|
||||
@@ -473,7 +476,9 @@ export class MCPConnection extends EventEmitter {
|
||||
logger.warn(`${this.getLogPrefix()} OAuth authentication required`);
|
||||
this.oauthRequired = true;
|
||||
const serverUrl = this.url;
|
||||
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
|
||||
logger.debug(
|
||||
`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
|
||||
);
|
||||
|
||||
const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
|
||||
/** Promise that will resolve when OAuth is handled */
|
||||
|
||||
294
packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts
Normal file
294
packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts
Normal file
@@ -0,0 +1,294 @@
|
||||
import { TokenMethods } from '@librechat/data-schemas';
|
||||
import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..';
|
||||
import { MCPManager } from '../MCPManager';
|
||||
import { OAuthReconnectionManager } from './OAuthReconnectionManager';
|
||||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('../MCPManager');
|
||||
|
||||
describe('OAuthReconnectionManager', () => {
|
||||
let flowManager: jest.Mocked<FlowStateManager<null>>;
|
||||
let tokenMethods: jest.Mocked<TokenMethods>;
|
||||
let mockMCPManager: jest.Mocked<MCPManager>;
|
||||
let reconnectionManager: OAuthReconnectionManager;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Reset singleton instance
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(OAuthReconnectionManager as any).instance = null;
|
||||
|
||||
// Setup mock flow manager
|
||||
flowManager = {
|
||||
createFlow: jest.fn(),
|
||||
completeFlow: jest.fn(),
|
||||
failFlow: jest.fn(),
|
||||
deleteFlow: jest.fn(),
|
||||
getFlow: jest.fn(),
|
||||
} as unknown as jest.Mocked<FlowStateManager<null>>;
|
||||
|
||||
// Setup mock token methods
|
||||
tokenMethods = {
|
||||
findToken: jest.fn(),
|
||||
createToken: jest.fn(),
|
||||
updateToken: jest.fn(),
|
||||
deleteToken: jest.fn(),
|
||||
} as unknown as jest.Mocked<TokenMethods>;
|
||||
|
||||
// Setup mock MCP Manager
|
||||
mockMCPManager = {
|
||||
getOAuthServers: jest.fn(),
|
||||
getUserConnection: jest.fn(),
|
||||
getUserConnections: jest.fn(),
|
||||
disconnectUserConnection: jest.fn(),
|
||||
getRawConfig: jest.fn(),
|
||||
} as unknown as jest.Mocked<MCPManager>;
|
||||
|
||||
(MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Singleton Pattern', () => {
|
||||
it('should create instance successfully', async () => {
|
||||
const instance = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
|
||||
expect(instance).toBeInstanceOf(OAuthReconnectionManager);
|
||||
});
|
||||
|
||||
it('should throw error when creating instance twice', async () => {
|
||||
await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
|
||||
await expect(
|
||||
OAuthReconnectionManager.createInstance(flowManager, tokenMethods),
|
||||
).rejects.toThrow('OAuthReconnectionManager already initialized');
|
||||
});
|
||||
|
||||
it('should throw error when getting instance before creation', () => {
|
||||
expect(() => OAuthReconnectionManager.getInstance()).toThrow(
|
||||
'OAuthReconnectionManager not initialized',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isReconnecting', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
beforeEach(async () => {
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return true when server is actively reconnecting', () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'test-server';
|
||||
|
||||
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
|
||||
|
||||
reconnectionTracker.setActive(userId, serverName);
|
||||
const result = reconnectionManager.isReconnecting(userId, serverName);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false when server is not reconnecting', () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'test-server';
|
||||
|
||||
const result = reconnectionManager.isReconnecting(userId, serverName);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearReconnection', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
beforeEach(async () => {
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
it('should clear both failed and active reconnection states', () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'test-server';
|
||||
|
||||
reconnectionTracker.setFailed(userId, serverName);
|
||||
reconnectionTracker.setActive(userId, serverName);
|
||||
|
||||
reconnectionManager.clearReconnection(userId, serverName);
|
||||
|
||||
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
|
||||
expect(reconnectionTracker.isFailed(userId, serverName)).toBe(false);
|
||||
expect(reconnectionTracker.isActive(userId, serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('reconnectServers', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
beforeEach(async () => {
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
it('should reconnect eligible servers', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1', 'server2', 'server3']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
// server1: has failed reconnection
|
||||
reconnectionTracker.setFailed(userId, 'server1');
|
||||
|
||||
// server2: already connected
|
||||
const mockConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const userConnections = new Map([['server2', mockConnection]]);
|
||||
mockMCPManager.getUserConnections.mockReturnValue(
|
||||
userConnections as unknown as Map<string, MCPConnection>,
|
||||
);
|
||||
|
||||
// server3: has valid token and not connected
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
if (identifier === 'mcp:server3') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
expiresAt: new Date(Date.now() + 3600000), // 1 hour from now
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
// Mock successful reconnection
|
||||
const mockNewConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Verify server3 was marked as active
|
||||
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(true);
|
||||
|
||||
// Wait for async tryReconnect to complete
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Verify reconnection was attempted for server3
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith({
|
||||
serverName: 'server3',
|
||||
user: { id: userId },
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
forceNew: false,
|
||||
connectionTimeout: 5000,
|
||||
returnOnOAuth: true,
|
||||
});
|
||||
|
||||
// Verify successful reconnection cleared the states
|
||||
expect(reconnectionTracker.isFailed(userId, 'server3')).toBe(false);
|
||||
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle failed reconnection attempts', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
// server1: has valid token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
identifier: 'mcp:server1',
|
||||
expiresAt: new Date(Date.now() + 3600000),
|
||||
} as unknown as MCPOAuthTokens);
|
||||
|
||||
// Mock failed connection
|
||||
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Wait for async tryReconnect to complete
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Verify failure handling
|
||||
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
|
||||
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
|
||||
});
|
||||
|
||||
it('should not reconnect servers with expired tokens', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
// server1: has expired token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
identifier: 'mcp:server1',
|
||||
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
|
||||
} as unknown as MCPOAuthTokens);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Verify no reconnection attempt was made
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
|
||||
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle connection that returns but is not connected', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
identifier: 'mcp:server1',
|
||||
expiresAt: new Date(Date.now() + 3600000),
|
||||
} as unknown as MCPOAuthTokens);
|
||||
|
||||
// Mock connection that returns but is not connected
|
||||
const mockConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(false),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockConnection as unknown as MCPConnection,
|
||||
);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Wait for async tryReconnect to complete
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Verify failure handling
|
||||
expect(mockConnection.disconnect).toHaveBeenCalled();
|
||||
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
|
||||
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
|
||||
});
|
||||
});
|
||||
});
|
||||
163
packages/api/src/mcp/oauth/OAuthReconnectionManager.ts
Normal file
163
packages/api/src/mcp/oauth/OAuthReconnectionManager.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { TokenMethods } from '@librechat/data-schemas';
|
||||
import type { TUser } from 'librechat-data-provider';
|
||||
import type { MCPOAuthTokens } from './types';
|
||||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
import { FlowStateManager } from '~/flow/manager';
|
||||
import { MCPManager } from '~/mcp/MCPManager';
|
||||
|
||||
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
|
||||
|
||||
export class OAuthReconnectionManager {
|
||||
private static instance: OAuthReconnectionManager | null = null;
|
||||
|
||||
protected readonly flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
||||
protected readonly tokenMethods: TokenMethods;
|
||||
|
||||
private readonly reconnectionsTracker: OAuthReconnectionTracker;
|
||||
|
||||
public static getInstance(): OAuthReconnectionManager {
|
||||
if (!OAuthReconnectionManager.instance) {
|
||||
throw new Error('OAuthReconnectionManager not initialized');
|
||||
}
|
||||
return OAuthReconnectionManager.instance;
|
||||
}
|
||||
|
||||
public static async createInstance(
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
tokenMethods: TokenMethods,
|
||||
reconnections?: OAuthReconnectionTracker,
|
||||
): Promise<OAuthReconnectionManager> {
|
||||
if (OAuthReconnectionManager.instance != null) {
|
||||
throw new Error('OAuthReconnectionManager already initialized');
|
||||
}
|
||||
|
||||
const manager = new OAuthReconnectionManager(flowManager, tokenMethods, reconnections);
|
||||
OAuthReconnectionManager.instance = manager;
|
||||
|
||||
return manager;
|
||||
}
|
||||
|
||||
public constructor(
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
tokenMethods: TokenMethods,
|
||||
reconnections?: OAuthReconnectionTracker,
|
||||
) {
|
||||
this.flowManager = flowManager;
|
||||
this.tokenMethods = tokenMethods;
|
||||
this.reconnectionsTracker = reconnections ?? new OAuthReconnectionTracker();
|
||||
}
|
||||
|
||||
public isReconnecting(userId: string, serverName: string): boolean {
|
||||
return this.reconnectionsTracker.isActive(userId, serverName);
|
||||
}
|
||||
|
||||
public async reconnectServers(userId: string) {
|
||||
const mcpManager = MCPManager.getInstance();
|
||||
|
||||
// 1. derive the servers to reconnect
|
||||
const serversToReconnect = [];
|
||||
for (const serverName of mcpManager.getOAuthServers() ?? []) {
|
||||
const canReconnect = await this.canReconnect(userId, serverName);
|
||||
if (canReconnect) {
|
||||
serversToReconnect.push(serverName);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. mark the servers as reconnecting
|
||||
for (const serverName of serversToReconnect) {
|
||||
this.reconnectionsTracker.setActive(userId, serverName);
|
||||
}
|
||||
|
||||
// 3. attempt to reconnect the servers
|
||||
for (const serverName of serversToReconnect) {
|
||||
void this.tryReconnect(userId, serverName);
|
||||
}
|
||||
}
|
||||
|
||||
public clearReconnection(userId: string, serverName: string) {
|
||||
this.reconnectionsTracker.removeFailed(userId, serverName);
|
||||
this.reconnectionsTracker.removeActive(userId, serverName);
|
||||
}
|
||||
|
||||
private async tryReconnect(userId: string, serverName: string) {
|
||||
const mcpManager = MCPManager.getInstance();
|
||||
|
||||
const logPrefix = `[tryReconnectOAuthMCPServer][User: ${userId}][${serverName}]`;
|
||||
|
||||
logger.info(`${logPrefix} Attempting reconnection`);
|
||||
|
||||
const config = mcpManager.getRawConfig(serverName);
|
||||
|
||||
const cleanupOnFailedReconnect = () => {
|
||||
this.reconnectionsTracker.setFailed(userId, serverName);
|
||||
this.reconnectionsTracker.removeActive(userId, serverName);
|
||||
mcpManager.disconnectUserConnection(userId, serverName);
|
||||
};
|
||||
|
||||
try {
|
||||
// attempt to get connection (this will use existing tokens and refresh if needed)
|
||||
const connection = await mcpManager.getUserConnection({
|
||||
serverName,
|
||||
user: { id: userId } as TUser,
|
||||
flowManager: this.flowManager,
|
||||
tokenMethods: this.tokenMethods,
|
||||
// don't force new connection, let it reuse existing or create new as needed
|
||||
forceNew: false,
|
||||
// set a reasonable timeout for reconnection attempts
|
||||
connectionTimeout: config?.initTimeout ?? DEFAULT_CONNECTION_TIMEOUT_MS,
|
||||
// don't trigger OAuth flow during reconnection
|
||||
returnOnOAuth: true,
|
||||
});
|
||||
|
||||
if (connection && (await connection.isConnected())) {
|
||||
logger.info(`${logPrefix} Successfully reconnected`);
|
||||
this.clearReconnection(userId, serverName);
|
||||
} else {
|
||||
logger.warn(`${logPrefix} Failed to reconnect`);
|
||||
await connection?.disconnect();
|
||||
cleanupOnFailedReconnect();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`${logPrefix} Failed to reconnect: ${error}`);
|
||||
cleanupOnFailedReconnect();
|
||||
}
|
||||
}
|
||||
|
||||
private async canReconnect(userId: string, serverName: string) {
|
||||
const mcpManager = MCPManager.getInstance();
|
||||
|
||||
// if the server has failed reconnection, don't attempt to reconnect
|
||||
if (this.reconnectionsTracker.isFailed(userId, serverName)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if the server is already connected, don't attempt to reconnect
|
||||
const existingConnections = mcpManager.getUserConnections(userId);
|
||||
if (existingConnections?.has(serverName)) {
|
||||
const isConnected = await existingConnections.get(serverName)?.isConnected();
|
||||
if (isConnected) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// if the server has no tokens for the user, don't attempt to reconnect
|
||||
const accessToken = await this.tokenMethods.findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier: `mcp:${serverName}`,
|
||||
});
|
||||
if (accessToken == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if the token has expired, don't attempt to reconnect
|
||||
const now = new Date();
|
||||
if (accessToken.expiresAt && accessToken.expiresAt < now) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// …otherwise, we're good to go with the reconnect attempt
|
||||
return true;
|
||||
}
|
||||
}
|
||||
181
packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts
Normal file
181
packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts
Normal file
@@ -0,0 +1,181 @@
|
||||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
|
||||
describe('OAuthReconnectTracker', () => {
|
||||
let tracker: OAuthReconnectionTracker;
|
||||
const userId = 'user123';
|
||||
const serverName = 'test-server';
|
||||
const anotherServer = 'another-server';
|
||||
|
||||
beforeEach(() => {
|
||||
tracker = new OAuthReconnectionTracker();
|
||||
});
|
||||
|
||||
describe('setFailed', () => {
|
||||
it('should record a failed reconnection attempt', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should track multiple servers for the same user', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(userId, anotherServer);
|
||||
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
|
||||
});
|
||||
|
||||
it('should track different users independently', () => {
|
||||
const anotherUserId = 'user456';
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(anotherUserId, serverName);
|
||||
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
expect(tracker.isFailed(anotherUserId, serverName)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFailed', () => {
|
||||
it('should return false when no failed attempt is recorded', () => {
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true after a failed attempt is recorded', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for a different server even after another server failed', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, anotherServer)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('removeFailed', () => {
|
||||
it('should clear a failed reconnect record', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
tracker.removeFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should only clear the specific server for the user', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(userId, anotherServer);
|
||||
|
||||
tracker.removeFailed(userId, serverName);
|
||||
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle clearing non-existent records gracefully', () => {
|
||||
expect(() => tracker.removeFailed(userId, serverName)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('setActive', () => {
|
||||
it('should mark a server as reconnecting', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should track multiple reconnecting servers', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
tracker.setActive(userId, anotherServer);
|
||||
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
expect(tracker.isActive(userId, anotherServer)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isActive', () => {
|
||||
it('should return false when server is not reconnecting', () => {
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true when server is marked as reconnecting', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle non-existent user gracefully', () => {
|
||||
expect(tracker.isActive('non-existent-user', serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('removeActive', () => {
|
||||
it('should clear reconnecting state for a server', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
|
||||
tracker.removeActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should only clear specific server state', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
tracker.setActive(userId, anotherServer);
|
||||
|
||||
tracker.removeActive(userId, serverName);
|
||||
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
expect(tracker.isActive(userId, anotherServer)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle clearing non-existent state gracefully', () => {
|
||||
expect(() => tracker.removeActive(userId, serverName)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('cleanup behavior', () => {
|
||||
it('should clean up empty user sets for failed reconnects', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.removeFailed(userId, serverName);
|
||||
|
||||
// Record and clear another user to ensure internal cleanup
|
||||
const anotherUserId = 'user456';
|
||||
tracker.setFailed(anotherUserId, serverName);
|
||||
|
||||
// Original user should still be able to reconnect
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should clean up empty user sets for active reconnections', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
tracker.removeActive(userId, serverName);
|
||||
|
||||
// Mark another user to ensure internal cleanup
|
||||
const anotherUserId = 'user456';
|
||||
tracker.setActive(anotherUserId, serverName);
|
||||
|
||||
// Original user should not be reconnecting
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('combined state management', () => {
|
||||
it('should handle both failed and reconnecting states independently', () => {
|
||||
// Mark as reconnecting
|
||||
tracker.setActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
// Record failed attempt
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
// Clear reconnecting state
|
||||
tracker.removeActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
// Clear failed state
|
||||
tracker.removeFailed(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
46
packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts
Normal file
46
packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
export class OAuthReconnectionTracker {
|
||||
// Map of userId -> Set of serverNames that have failed reconnection
|
||||
private failed: Map<string, Set<string>> = new Map();
|
||||
// Map of userId -> Set of serverNames that are actively reconnecting
|
||||
private active: Map<string, Set<string>> = new Map();
|
||||
|
||||
public isFailed(userId: string, serverName: string): boolean {
|
||||
return this.failed.get(userId)?.has(serverName) ?? false;
|
||||
}
|
||||
|
||||
public isActive(userId: string, serverName: string): boolean {
|
||||
return this.active.get(userId)?.has(serverName) ?? false;
|
||||
}
|
||||
|
||||
public setFailed(userId: string, serverName: string): void {
|
||||
if (!this.failed.has(userId)) {
|
||||
this.failed.set(userId, new Set());
|
||||
}
|
||||
|
||||
this.failed.get(userId)?.add(serverName);
|
||||
}
|
||||
|
||||
public setActive(userId: string, serverName: string): void {
|
||||
if (!this.active.has(userId)) {
|
||||
this.active.set(userId, new Set());
|
||||
}
|
||||
|
||||
this.active.get(userId)?.add(serverName);
|
||||
}
|
||||
|
||||
public removeFailed(userId: string, serverName: string): void {
|
||||
const userServers = this.failed.get(userId);
|
||||
userServers?.delete(serverName);
|
||||
if (userServers?.size === 0) {
|
||||
this.failed.delete(userId);
|
||||
}
|
||||
}
|
||||
|
||||
public removeActive(userId: string, serverName: string): void {
|
||||
const userServers = this.active.get(userId);
|
||||
userServers?.delete(serverName);
|
||||
if (userServers?.size === 0) {
|
||||
this.active.delete(userId);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import type {
|
||||
MCPOAuthTokens,
|
||||
OAuthMetadata,
|
||||
} from './types';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
|
||||
/** Type for the OAuth metadata from the SDK */
|
||||
type SDKOAuthMetadata = Parameters<typeof registerClient>[1]['metadata'];
|
||||
@@ -33,7 +34,9 @@ export class MCPOAuthHandler {
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||
authServerUrl: URL;
|
||||
}> {
|
||||
logger.debug(`[MCPOAuth] discoverMetadata called with serverUrl: ${serverUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] discoverMetadata called with serverUrl: ${sanitizeUrlForLogging(serverUrl)}`,
|
||||
);
|
||||
|
||||
let authServerUrl = new URL(serverUrl);
|
||||
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
|
||||
@@ -60,11 +63,15 @@ export class MCPOAuthHandler {
|
||||
}
|
||||
|
||||
// Discover OAuth metadata
|
||||
logger.debug(`[MCPOAuth] Discovering OAuth metadata from ${authServerUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||
);
|
||||
const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl);
|
||||
|
||||
if (!rawMetadata) {
|
||||
logger.error(`[MCPOAuth] Failed to discover OAuth metadata from ${authServerUrl}`);
|
||||
logger.error(
|
||||
`[MCPOAuth] Failed to discover OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||
);
|
||||
throw new Error('Failed to discover OAuth metadata');
|
||||
}
|
||||
|
||||
@@ -88,12 +95,15 @@ export class MCPOAuthHandler {
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata,
|
||||
redirectUri?: string,
|
||||
): Promise<OAuthClientInformation> {
|
||||
logger.debug(`[MCPOAuth] Starting client registration for ${serverUrl}, server metadata:`, {
|
||||
grant_types_supported: metadata.grant_types_supported,
|
||||
response_types_supported: metadata.response_types_supported,
|
||||
token_endpoint_auth_methods_supported: metadata.token_endpoint_auth_methods_supported,
|
||||
scopes_supported: metadata.scopes_supported,
|
||||
});
|
||||
logger.debug(
|
||||
`[MCPOAuth] Starting client registration for ${sanitizeUrlForLogging(serverUrl)}, server metadata:`,
|
||||
{
|
||||
grant_types_supported: metadata.grant_types_supported,
|
||||
response_types_supported: metadata.response_types_supported,
|
||||
token_endpoint_auth_methods_supported: metadata.token_endpoint_auth_methods_supported,
|
||||
scopes_supported: metadata.scopes_supported,
|
||||
},
|
||||
);
|
||||
|
||||
/** Client metadata based on what the server supports */
|
||||
const clientMetadata = {
|
||||
@@ -114,7 +124,9 @@ export class MCPOAuthHandler {
|
||||
`[MCPOAuth] Server ${serverUrl} supports \`refresh_token\` grant type, adding to request`,
|
||||
);
|
||||
} else {
|
||||
logger.debug(`[MCPOAuth] Server ${serverUrl} does not support \`refresh_token\` grant type`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Server ${sanitizeUrlForLogging(serverUrl)} does not support \`refresh_token\` grant type`,
|
||||
);
|
||||
}
|
||||
clientMetadata.grant_types = requestedGrantTypes;
|
||||
|
||||
@@ -139,19 +151,25 @@ export class MCPOAuthHandler {
|
||||
clientMetadata.scope = availableScopes.join(' ');
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] Registering client for ${serverUrl} with metadata:`, clientMetadata);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Registering client for ${sanitizeUrlForLogging(serverUrl)} with metadata:`,
|
||||
clientMetadata,
|
||||
);
|
||||
|
||||
const clientInfo = await registerClient(serverUrl, {
|
||||
metadata: metadata as unknown as SDKOAuthMetadata,
|
||||
clientMetadata,
|
||||
});
|
||||
|
||||
logger.debug(`[MCPOAuth] Client registered successfully for ${serverUrl}:`, {
|
||||
client_id: clientInfo.client_id,
|
||||
has_client_secret: !!clientInfo.client_secret,
|
||||
grant_types: clientInfo.grant_types,
|
||||
scope: clientInfo.scope,
|
||||
});
|
||||
logger.debug(
|
||||
`[MCPOAuth] Client registered successfully for ${sanitizeUrlForLogging(serverUrl)}:`,
|
||||
{
|
||||
client_id: clientInfo.client_id,
|
||||
has_client_secret: !!clientInfo.client_secret,
|
||||
grant_types: clientInfo.grant_types,
|
||||
scope: clientInfo.scope,
|
||||
},
|
||||
);
|
||||
|
||||
return clientInfo;
|
||||
}
|
||||
@@ -165,7 +183,9 @@ export class MCPOAuthHandler {
|
||||
userId: string,
|
||||
config: MCPOptions['oauth'] | undefined,
|
||||
): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> {
|
||||
logger.debug(`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${serverUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`,
|
||||
);
|
||||
|
||||
const flowId = this.generateFlowId(userId, serverName);
|
||||
const state = this.generateState();
|
||||
@@ -226,7 +246,9 @@ export class MCPOAuthHandler {
|
||||
metadata,
|
||||
};
|
||||
|
||||
logger.debug(`[MCPOAuth] Authorization URL generated: ${authorizationUrl.toString()}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Authorization URL generated: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
|
||||
);
|
||||
return {
|
||||
authorizationUrl: authorizationUrl.toString(),
|
||||
flowId,
|
||||
@@ -234,10 +256,14 @@ export class MCPOAuthHandler {
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${serverUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${sanitizeUrlForLogging(serverUrl)}`,
|
||||
);
|
||||
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(serverUrl);
|
||||
|
||||
logger.debug(`[MCPOAuth] OAuth metadata discovered, auth server URL: ${authServerUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||
);
|
||||
|
||||
/** Dynamic client registration based on the discovered metadata */
|
||||
const redirectUri = config?.redirect_uri || this.getDefaultRedirectUri(serverName);
|
||||
@@ -276,7 +302,9 @@ export class MCPOAuthHandler {
|
||||
codeVerifier = authResult.codeVerifier;
|
||||
|
||||
logger.debug(`[MCPOAuth] startAuthorization completed successfully`);
|
||||
logger.debug(`[MCPOAuth] Authorization URL: ${authorizationUrl.toString()}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
|
||||
);
|
||||
|
||||
/** Add state parameter with flowId to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', flowId);
|
||||
@@ -515,7 +543,7 @@ export class MCPOAuthHandler {
|
||||
body.append('client_id', metadata.clientInfo.client_id);
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] Refresh request to: ${tokenUrl}`, {
|
||||
logger.debug(`[MCPOAuth] Refresh request to: ${sanitizeUrlForLogging(tokenUrl)}`, {
|
||||
body: body.toString(),
|
||||
headers,
|
||||
});
|
||||
@@ -695,7 +723,9 @@ export class MCPOAuthHandler {
|
||||
}
|
||||
|
||||
// perform the revoke request
|
||||
logger.info(`[MCPOAuth] Revoking tokens for ${serverName} via ${revokeUrl.toString()}`);
|
||||
logger.info(
|
||||
`[MCPOAuth] Revoking tokens for ${serverName} via ${sanitizeUrlForLogging(revokeUrl.toString())}`,
|
||||
);
|
||||
const response = await fetch(revokeUrl, {
|
||||
method: 'POST',
|
||||
body: body.toString(),
|
||||
|
||||
@@ -31,3 +31,17 @@ export function normalizeServerName(serverName: string): string {
|
||||
|
||||
return normalized;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitizes a URL by removing query parameters to prevent credential leakage in logs.
|
||||
* @param url - The URL to sanitize (string or URL object)
|
||||
* @returns The sanitized URL string without query parameters
|
||||
*/
|
||||
export function sanitizeUrlForLogging(url: string | URL): string {
|
||||
try {
|
||||
const urlObj = typeof url === 'string' ? new URL(url) : url;
|
||||
return `${urlObj.protocol}//${urlObj.host}${urlObj.pathname}`;
|
||||
} catch {
|
||||
return '[invalid URL]';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,12 +98,6 @@ if (typeof window !== 'undefined') {
|
||||
if (originalRequest.url?.includes('/api/auth/logout') === true) {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
if (originalRequest.url?.includes('/api/auth/refresh') === true) {
|
||||
// Refresh token itself failed - redirect to login
|
||||
console.log('Refresh token request failed, redirecting to login...');
|
||||
window.location.href = '/login';
|
||||
return Promise.reject(error);
|
||||
}
|
||||
|
||||
if (error.response.status === 401 && !originalRequest._retry) {
|
||||
console.warn('401 error, refreshing token');
|
||||
@@ -124,7 +118,10 @@ if (typeof window !== 'undefined') {
|
||||
isRefreshing = true;
|
||||
|
||||
try {
|
||||
const response = await refreshToken();
|
||||
const response = await refreshToken(
|
||||
// Handle edge case where we get a blank screen if the initial 401 error is from a refresh token request
|
||||
originalRequest.url?.includes('api/auth/refresh') === true ? true : false,
|
||||
);
|
||||
|
||||
const token = response?.token ?? '';
|
||||
|
||||
|
||||
Reference in New Issue
Block a user