Compare commits

...

13 Commits

Author SHA1 Message Date
Rakshit
af4520df29 Merge branch 'main' into docs/azure-instance-name-clarification 2025-11-05 21:37:57 +05:30
constanttime
fabce6f823 docs: clarify Azure instanceName format for speech services
- Add Azure OpenAI STT/TTS examples to librechat.example.yaml
- Clarify that instanceName should be the <NAME> part only (e.g., 'my-instance')
- Not the full URL (e.g., 'my-instance.cognitiveservices.azure.com')
- Add note that full domain format is still supported for backward compatibility

Examples:
- Correct: instanceName: 'my-instance'
- Also works: instanceName: 'my-instance.cognitiveservices.azure.com'

Related to #10283
2025-11-05 21:35:45 +05:30
Danny Avila
0f4222a908 🪞 fix: Prevent Revoked Blob URLs in Uploaded Images (FileRow) (#10361) 2025-11-05 10:28:06 -05:00
Rakshit
772b706e20 🎙️ fix: Azure OpenAI Speech-to-Text 400 Bad Request Error (#10355) 2025-11-05 10:27:34 -05:00
Danny Avila
06fcf79d56 🛂 feat: Social Login by Provider ID First then Email (#10358) 2025-11-05 09:20:35 -05:00
Eduardo Cruz Guedes
c9e1127b85 🌅 docs: Add OpenAI Image Gen Env Vars (#10335) 2025-11-04 13:52:47 -05:00
Max Sanna
14e4941367 📎 fix: Document Uploads for Custom Endpoints (#10336)
* Fixed upload to provider for custom endpoints + unit tests

* fix: add support back for agents to be able to use Upload to Provider with supported providers

* ci: add test for agents endpoint still recognizing document supported providers

* chore: address ESLint suggestions

* Improved unit tests

* Linting error on unit tests fixed

---------

Co-authored-by: Dustin Healy <dustinhealy1@gmail.com>
2025-11-04 13:40:24 -05:00
Theo N. Truong
ce7e6edad8 🔄 refactor: MCP Registry System with Distributed Caching (#10191)
* refactor: Restructure MCP registry system with caching

- Split MCPServersRegistry into modular components:
  - MCPServerInspector: handles server inspection and health checks
  - MCPServersInitializer: manages server initialization logic
  - MCPServersRegistry: simplified registry coordination
- Add distributed caching layer:
  - ServerConfigsCacheRedis: Redis-backed configuration cache
  - ServerConfigsCacheInMemory: in-memory fallback cache
  - RegistryStatusCache: distributed leader election state
- Add promise utilities (withTimeout) replacing Promise.race patterns
- Add comprehensive cache integration tests for all cache implementations
- Remove unused MCPManager.getAllToolFunctions method

* fix: Update OAuth flow to include user-specific headers

* chore: Update Jest configuration to ignore additional test files

- Added patterns to ignore files ending with .helper.ts and .helper.d.ts in testPathIgnorePatterns for cleaner test runs.

* fix: oauth headers in callback

* chore: Update Jest testPathIgnorePatterns to exclude helper files

- Modified testPathIgnorePatterns in package.json to ignore files ending with .helper.ts and .helper.d.ts for cleaner test execution.

* ci: update test mocks

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
2025-10-31 15:00:21 -04:00
github-actions[bot]
961f87cfda 🌍 i18n: Update translation.json with latest translations (#10323)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2025-10-31 14:36:32 -04:00
Danny Avila
9b4c4cafb6 🧠 refactor: Improve Reasoning Component Structure and UX (#10320)
* refactor: Reasoning components with independent toggle buttons

- Refactored ThinkingButton to remove unnecessary state and props.
- Updated ContentParts to simplify content rendering and remove hover handling.
- Improved Reasoning component to include independent toggle functionality for each THINK part.
- Adjusted styles for better layout consistency and user experience.

* refactor: isolate hover effects for Reasoning

- Updated ThinkingButton to improve hover effects and layout consistency.
- Refactored Reasoning component to include a new wrapper class for better styling.
- Adjusted icon visibility and transitions for a smoother user experience.

* fix: Prevent rendering of empty messages in Chat component

- Added a check to skip rendering if the message text is only whitespace, improving the user interface by avoiding empty containers.

* chore: Replace div with fragment in Thinking component for cleaner markup

* chore: move Thinking component to Content Parts directory

* refactor: prevent rendering of whitespace-only text in Part component only for edge cases
2025-10-31 13:05:12 -04:00
Marco Beretta
c0f1cfcaba 💡 feat: Improve Reasoning Content UI, copy-to-clipboard, and error handling (#10278)
*  feat: Refactor error handling and improve loading states in MessageContent component

*  feat: Enhance Thinking and ContentParts components with improved hover functionality and clipboard support

* fix: Adjust padding in Thinking and ContentParts components for consistent layout

*  feat: Add response label and improve message editing UI with contextual indicators

*  feat: Add isEditing prop to Feedback and Fork components for improved editing state handling

* refactor: Remove isEditing prop from Feedback and Fork components for cleaner state management

* refactor: Migrate state management from Recoil to Jotai for font size and show thinking features

* refactor: Separate ToggleSwitch into RecoilToggle and JotaiToggle components for improved clarity and state management

* refactor: Remove unnecessary comments in ToggleSwitch and MessageContent components for cleaner code

* chore: reorder import statements in Thinking.tsx

* chore: reorder import statement in EditTextPart.tsx

* chore: reorder import statement

* chore: Reorganize imports in ToggleSwitch.tsx

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
2025-10-30 17:14:38 -04:00
Federico Ruggi
ea45d0b9c6 🏷️ fix: Add user ID to MCP tools cache keys (#10201)
* add user id to mcp tools cache key

* tests

* clean up redundant tests

* remove unused imports
2025-10-30 17:09:56 -04:00
Theo N. Truong
8f4705f683 👑 feat: Distributed Leader Election with Redis for Multi-instance Coordination (#10189)
* 🔧 refactor: Move GLOBAL_PREFIX_SEPARATOR to cacheConfig for consistency

* 👑 feat: Implement distributed leader election using Redis
2025-10-30 17:08:04 -04:00
94 changed files with 5667 additions and 1507 deletions

View File

@@ -254,6 +254,10 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
# OpenAI Image Tools Customization
#----------------
# IMAGE_GEN_OAI_API_KEY= # Create or reuse OpenAI API key for image generation tool
# IMAGE_GEN_OAI_BASEURL= # Custom OpenAI base URL for image generation tool
# IMAGE_GEN_OAI_AZURE_API_VERSION= # Custom Azure OpenAI deployments
# IMAGE_GEN_OAI_DESCRIPTION=
# IMAGE_GEN_OAI_DESCRIPTION_WITH_FILES=Custom description for image generation tool when files are present
# IMAGE_GEN_OAI_DESCRIPTION_NO_FILES=Custom description for image generation tool when no files are present
# IMAGE_EDIT_OAI_DESCRIPTION=Custom description for image editing tool
@@ -702,6 +706,16 @@ HELP_AND_FAQ_URL=https://librechat.ai
# Comma-separated list of CacheKeys (e.g., ROLES,MESSAGES)
# FORCED_IN_MEMORY_CACHE_NAMESPACES=ROLES,MESSAGES
# Leader Election Configuration (for multi-instance deployments with Redis)
# Duration in seconds that the leader lease is valid before it expires (default: 25)
# LEADER_LEASE_DURATION=25
# Interval in seconds at which the leader renews its lease (default: 10)
# LEADER_RENEW_INTERVAL=10
# Maximum number of retry attempts when renewing the lease fails (default: 3)
# LEADER_RENEW_ATTEMPTS=3
# Delay in seconds between retry attempts when renewing the lease (default: 0.5)
# LEADER_RENEW_RETRY_DELAY=0.5
#==================================================#
# Others #
#==================================================#

View File

@@ -8,12 +8,14 @@ on:
- release/*
paths:
- 'packages/api/src/cache/**'
- 'packages/api/src/cluster/**'
- 'packages/api/src/mcp/**'
- 'redis-config/**'
- '.github/workflows/cache-integration-tests.yml'
jobs:
cache_integration_tests:
name: Run Cache Integration Tests
name: Integration Tests that use actual Redis Cache
timeout-minutes: 30
runs-on: ubuntu-latest
@@ -66,7 +68,23 @@ jobs:
USE_REDIS: true
REDIS_URI: redis://127.0.0.1:6379
REDIS_CLUSTER_URI: redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003
run: npm run test:cache:integration
run: npm run test:cache-integration:core
- name: Run cluster integration tests
working-directory: packages/api
env:
NODE_ENV: test
USE_REDIS: true
REDIS_URI: redis://127.0.0.1:6379
run: npm run test:cache-integration:cluster
- name: Run mcp integration tests
working-directory: packages/api
env:
NODE_ENV: test
USE_REDIS: true
REDIS_URI: redis://127.0.0.1:6379
run: npm run test:cache-integration:mcp
- name: Stop Redis Cluster
if: always()

View File

@@ -448,7 +448,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
}
if (!availableTools) {
try {
availableTools = await getMCPServerTools(serverName);
availableTools = await getMCPServerTools(safeUser.id, serverName);
} catch (error) {
logger.error(`Error fetching available tools for MCP server ${serverName}:`, error);
}

View File

@@ -79,6 +79,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
/** @type {TEphemeralAgent | null} */
const ephemeralAgent = req.body.ephemeralAgent;
const mcpServers = new Set(ephemeralAgent?.mcp);
const userId = req.user?.id; // note: userId cannot be undefined at runtime
if (modelSpec?.mcpServers) {
for (const mcpServer of modelSpec.mcpServers) {
mcpServers.add(mcpServer);
@@ -102,7 +103,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet
if (addedServers.has(mcpServer)) {
continue;
}
const serverTools = await getMCPServerTools(mcpServer);
const serverTools = await getMCPServerTools(userId, mcpServer);
if (!serverTools) {
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
addedServers.add(mcpServer);

View File

@@ -1931,7 +1931,7 @@ describe('models/Agent', () => {
});
// Mock getMCPServerTools to return tools for each server
getMCPServerTools.mockImplementation(async (server) => {
getMCPServerTools.mockImplementation(async (_userId, server) => {
if (server === 'server1') {
return { tool1_mcp_server1: {} };
} else if (server === 'server2') {
@@ -2125,7 +2125,7 @@ describe('models/Agent', () => {
getCachedTools.mockResolvedValue(availableTools);
// Mock getMCPServerTools to return all tools for server1
getMCPServerTools.mockImplementation(async (server) => {
getMCPServerTools.mockImplementation(async (_userId, server) => {
if (server === 'server1') {
return availableTools; // All 100 tools belong to server1
}
@@ -2674,7 +2674,7 @@ describe('models/Agent', () => {
});
// Mock getMCPServerTools to return only tools matching the server
getMCPServerTools.mockImplementation(async (server) => {
getMCPServerTools.mockImplementation(async (_userId, server) => {
if (server === 'server1') {
// Only return tool that correctly matches server1 format
return { tool_mcp_server1: {} };

View File

@@ -28,6 +28,7 @@ const { getMCPManager, getFlowStateManager } = require('~/config');
const { getAppConfig } = require('~/server/services/Config');
const { deleteToolCalls } = require('~/models/ToolCall');
const { getLogStores } = require('~/cache');
const { mcpServersRegistry } = require('@librechat/api');
const getUserController = async (req, res) => {
const appConfig = await getAppConfig({ role: req.user?.role });
@@ -198,7 +199,7 @@ const updateUserPluginsController = async (req, res) => {
// If auth was updated successfully, disconnect MCP sessions as they might use these credentials
if (pluginKey.startsWith(Constants.mcp_prefix)) {
try {
const mcpManager = getMCPManager(user.id);
const mcpManager = getMCPManager();
if (mcpManager) {
// Extract server name from pluginKey (format: "mcp_<serverName>")
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
@@ -295,10 +296,11 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
}
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
const mcpManager = getMCPManager(userId);
const serverConfig = mcpManager.getRawConfig(serverName) ?? appConfig?.mcpServers?.[serverName];
if (!mcpManager.getOAuthServers().has(serverName)) {
const serverConfig =
(await mcpServersRegistry.getServerConfig(serverName, userId)) ??
appConfig?.mcpServers?.[serverName];
const oauthServers = await mcpServersRegistry.getOAuthServers();
if (!oauthServers.has(serverName)) {
// this server does not use OAuth, so nothing to do here as well
return;
}

View File

@@ -10,6 +10,7 @@ const {
getAppConfig,
} = require('~/server/services/Config');
const { getMCPManager } = require('~/config');
const { mcpServersRegistry } = require('@librechat/api');
/**
* Get all MCP tools available to the user
@@ -32,7 +33,7 @@ const getMCPTools = async (req, res) => {
const mcpServers = {};
const cachePromises = configuredServers.map((serverName) =>
getMCPServerTools(serverName).then((tools) => ({ serverName, tools })),
getMCPServerTools(userId, serverName).then((tools) => ({ serverName, tools })),
);
const cacheResults = await Promise.all(cachePromises);
@@ -52,7 +53,7 @@ const getMCPTools = async (req, res) => {
if (Object.keys(serverTools).length > 0) {
// Cache asynchronously without blocking
cacheMCPServerTools({ serverName, serverTools }).catch((err) =>
cacheMCPServerTools({ userId, serverName, serverTools }).catch((err) =>
logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err),
);
}
@@ -65,7 +66,7 @@ const getMCPTools = async (req, res) => {
// Get server config once
const serverConfig = appConfig.mcpConfig[serverName];
const rawServerConfig = mcpManager.getRawConfig(serverName);
const rawServerConfig = await mcpServersRegistry.getServerConfig(serverName, userId);
// Initialize server object with all server-level data
const server = {

View File

@@ -15,6 +15,10 @@ jest.mock('@librechat/api', () => ({
storeTokens: jest.fn(),
},
getUserMCPAuthMap: jest.fn(),
mcpServersRegistry: {
getServerConfig: jest.fn(),
getOAuthServers: jest.fn(),
},
}));
jest.mock('@librechat/data-schemas', () => ({
@@ -47,6 +51,7 @@ jest.mock('~/models', () => ({
jest.mock('~/server/services/Config', () => ({
setCachedTools: jest.fn(),
getCachedTools: jest.fn(),
getMCPServerTools: jest.fn(),
loadCustomConfig: jest.fn(),
}));
@@ -114,7 +119,7 @@ describe('MCP Routes', () => {
});
describe('GET /:serverName/oauth/initiate', () => {
const { MCPOAuthHandler } = require('@librechat/api');
const { MCPOAuthHandler, mcpServersRegistry } = require('@librechat/api');
const { getLogStores } = require('~/cache');
it('should initiate OAuth flow successfully', async () => {
@@ -127,13 +132,9 @@ describe('MCP Routes', () => {
}),
};
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({}),
};
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
mcpServersRegistry.getServerConfig.mockResolvedValue({});
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
authorizationUrl: 'https://oauth.example.com/auth',
@@ -287,6 +288,7 @@ describe('MCP Routes', () => {
});
it('should handle OAuth callback successfully', async () => {
const { mcpServersRegistry } = require('@librechat/api');
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
@@ -306,6 +308,7 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
mcpServersRegistry.getServerConfig.mockResolvedValue({});
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@@ -320,7 +323,6 @@ describe('MCP Routes', () => {
};
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
@@ -378,6 +380,7 @@ describe('MCP Routes', () => {
});
it('should handle system-level OAuth completion', async () => {
const { mcpServersRegistry } = require('@librechat/api');
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
@@ -397,14 +400,10 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
mcpServersRegistry.getServerConfig.mockResolvedValue({});
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
@@ -416,6 +415,7 @@ describe('MCP Routes', () => {
});
it('should handle reconnection failure after OAuth', async () => {
const { mcpServersRegistry } = require('@librechat/api');
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
@@ -435,12 +435,12 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
mcpServersRegistry.getServerConfig.mockResolvedValue({});
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const mockMcpManager = {
getUserConnection: jest.fn().mockRejectedValue(new Error('Reconnection failed')),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
@@ -460,6 +460,7 @@ describe('MCP Routes', () => {
});
it('should redirect to error page if token storage fails', async () => {
const { mcpServersRegistry } = require('@librechat/api');
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
@@ -479,6 +480,7 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed'));
mcpServersRegistry.getServerConfig.mockResolvedValue({});
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@@ -729,12 +731,14 @@ describe('MCP Routes', () => {
});
describe('POST /:serverName/reinitialize', () => {
const { mcpServersRegistry } = require('@librechat/api');
it('should return 404 when server is not found in configuration', async () => {
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue(null),
disconnectUserConnection: jest.fn().mockResolvedValue(),
};
mcpServersRegistry.getServerConfig.mockResolvedValue(null);
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
@@ -749,9 +753,6 @@ describe('MCP Routes', () => {
it('should handle OAuth requirement during reinitialize', async () => {
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({
customUserVars: {},
}),
disconnectUserConnection: jest.fn().mockResolvedValue(),
mcpConfigs: {},
getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => {
@@ -762,6 +763,9 @@ describe('MCP Routes', () => {
}),
};
mcpServersRegistry.getServerConfig.mockResolvedValue({
customUserVars: {},
});
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
@@ -787,12 +791,12 @@ describe('MCP Routes', () => {
it('should return 500 when reinitialize fails with non-OAuth error', async () => {
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({}),
disconnectUserConnection: jest.fn().mockResolvedValue(),
mcpConfigs: {},
getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')),
};
mcpServersRegistry.getServerConfig.mockResolvedValue({});
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
@@ -808,11 +812,12 @@ describe('MCP Routes', () => {
it('should return 500 when unexpected error occurs', async () => {
const mockMcpManager = {
getRawConfig: jest.fn().mockImplementation(() => {
throw new Error('Config loading failed');
}),
disconnectUserConnection: jest.fn(),
};
mcpServersRegistry.getServerConfig.mockImplementation(() => {
throw new Error('Config loading failed');
});
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).post('/api/mcp/test-server/reinitialize');
@@ -845,11 +850,11 @@ describe('MCP Routes', () => {
};
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }),
disconnectUserConnection: jest.fn().mockResolvedValue(),
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
};
mcpServersRegistry.getServerConfig.mockResolvedValue({ endpoint: 'http://test-server.com' });
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
@@ -890,16 +895,16 @@ describe('MCP Routes', () => {
};
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({
endpoint: 'http://test-server.com',
customUserVars: {
API_KEY: 'some-env-var',
},
}),
disconnectUserConnection: jest.fn().mockResolvedValue(),
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
};
mcpServersRegistry.getServerConfig.mockResolvedValue({
endpoint: 'http://test-server.com',
customUserVars: {
API_KEY: 'some-env-var',
},
});
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
require('~/config').getFlowStateManager.mockReturnValue({});
require('~/cache').getLogStores.mockReturnValue({});
@@ -1104,17 +1109,17 @@ describe('MCP Routes', () => {
describe('GET /:serverName/auth-values', () => {
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { mcpServersRegistry } = require('@librechat/api');
it('should return auth value flags for server', async () => {
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({
customUserVars: {
API_KEY: 'some-env-var',
SECRET_TOKEN: 'another-env-var',
},
}),
};
const mockMcpManager = {};
mcpServersRegistry.getServerConfig.mockResolvedValue({
customUserVars: {
API_KEY: 'some-env-var',
SECRET_TOKEN: 'another-env-var',
},
});
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce('');
@@ -1134,10 +1139,9 @@ describe('MCP Routes', () => {
});
it('should return 404 when server is not found in configuration', async () => {
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue(null),
};
const mockMcpManager = {};
mcpServersRegistry.getServerConfig.mockResolvedValue(null);
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/non-existent-server/auth-values');
@@ -1149,14 +1153,13 @@ describe('MCP Routes', () => {
});
it('should handle errors when checking auth values', async () => {
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({
customUserVars: {
API_KEY: 'some-env-var',
},
}),
};
const mockMcpManager = {};
mcpServersRegistry.getServerConfig.mockResolvedValue({
customUserVars: {
API_KEY: 'some-env-var',
},
});
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
getUserPluginAuthValue.mockRejectedValue(new Error('Database error'));
@@ -1173,12 +1176,11 @@ describe('MCP Routes', () => {
});
it('should return 500 when auth values check throws unexpected error', async () => {
const mockMcpManager = {
getRawConfig: jest.fn().mockImplementation(() => {
throw new Error('Config loading failed');
}),
};
const mockMcpManager = {};
mcpServersRegistry.getServerConfig.mockImplementation(() => {
throw new Error('Config loading failed');
});
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/auth-values');
@@ -1188,12 +1190,11 @@ describe('MCP Routes', () => {
});
it('should handle customUserVars that is not an object', async () => {
const mockMcpManager = {
getRawConfig: jest.fn().mockReturnValue({
customUserVars: 'not-an-object',
}),
};
const mockMcpManager = {};
mcpServersRegistry.getServerConfig.mockResolvedValue({
customUserVars: 'not-an-object',
});
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/auth-values');
@@ -1220,7 +1221,7 @@ describe('MCP Routes', () => {
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const { MCPOAuthHandler, MCPTokenStorage, mcpServersRegistry } = require('@librechat/api');
const mockTokens = {
access_token: 'edge-access-token',
refresh_token: 'edge-refresh-token',
@@ -1238,6 +1239,7 @@ describe('MCP Routes', () => {
});
MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
mcpServersRegistry.getServerConfig.mockResolvedValue({});
const mockFlowManager = {
completeFlow: jest.fn(),
@@ -1248,7 +1250,6 @@ describe('MCP Routes', () => {
getUserConnection: jest.fn().mockResolvedValue({
fetchTools: jest.fn().mockResolvedValue([]),
}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
@@ -1263,7 +1264,7 @@ describe('MCP Routes', () => {
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
const { getCachedTools } = require('~/server/services/Config');
getCachedTools.mockResolvedValue(null);
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const { MCPOAuthHandler, MCPTokenStorage, mcpServersRegistry } = require('@librechat/api');
const mockTokens = {
access_token: 'edge-access-token',
refresh_token: 'edge-refresh-token',
@@ -1289,6 +1290,7 @@ describe('MCP Routes', () => {
});
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
mcpServersRegistry.getServerConfig.mockResolvedValue({});
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue({
@@ -1296,7 +1298,6 @@ describe('MCP Routes', () => {
.fn()
.mockResolvedValue([{ name: 'test-tool', description: 'Test tool' }]),
}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);

View File

@@ -12,6 +12,7 @@ const { getAppConfig } = require('~/server/services/Config/app');
const { getProjectByName } = require('~/models/Project');
const { getMCPManager } = require('~/config');
const { getLogStores } = require('~/cache');
const { mcpServersRegistry } = require('@librechat/api');
const router = express.Router();
const emailLoginEnabled =
@@ -125,7 +126,7 @@ router.get('/', async function (req, res) {
payload.minPasswordLength = minPasswordLength;
}
const getMCPServers = () => {
const getMCPServers = async () => {
try {
if (appConfig?.mcpConfig == null) {
return;
@@ -134,9 +135,8 @@ router.get('/', async function (req, res) {
if (!mcpManager) {
return;
}
const mcpServers = mcpManager.getAllServers();
const mcpServers = await mcpServersRegistry.getAllServerConfigs();
if (!mcpServers) return;
const oauthServers = mcpManager.getOAuthServers();
for (const serverName in mcpServers) {
if (!payload.mcpServers) {
payload.mcpServers = {};
@@ -145,7 +145,7 @@ router.get('/', async function (req, res) {
payload.mcpServers[serverName] = removeNullishValues({
startup: serverConfig?.startup,
chatMenu: serverConfig?.chatMenu,
isOAuth: oauthServers?.has(serverName),
isOAuth: serverConfig.requiresOAuth,
customUserVars: serverConfig?.customUserVars,
});
}
@@ -154,7 +154,7 @@ router.get('/', async function (req, res) {
}
};
getMCPServers();
await getMCPServers();
const webSearchConfig = appConfig?.webSearch;
if (
webSearchConfig != null &&

View File

@@ -6,6 +6,7 @@ const {
MCPOAuthHandler,
MCPTokenStorage,
getUserMCPAuthMap,
mcpServersRegistry,
} = require('@librechat/api');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
@@ -61,11 +62,12 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
return res.status(400).json({ error: 'Invalid flow state' });
}
const oauthHeaders = await getOAuthHeaders(serverName, userId);
const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
serverName,
serverUrl,
userId,
getOAuthHeaders(serverName),
oauthHeaders,
oauthConfig,
);
@@ -133,12 +135,8 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
});
logger.debug('[MCP OAuth] Completing OAuth flow');
const tokens = await MCPOAuthHandler.completeOAuthFlow(
flowId,
code,
flowManager,
getOAuthHeaders(serverName),
);
const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId);
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders);
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
/** Persist tokens immediately so reconnection uses fresh credentials */
@@ -205,6 +203,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
const tools = await userConnection.fetchTools();
await updateMCPServerTools({
userId: flowState.userId,
serverName,
tools,
});
@@ -355,7 +354,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
const mcpManager = getMCPManager();
const serverConfig = mcpManager.getRawConfig(serverName);
const serverConfig = await mcpServersRegistry.getServerConfig(serverName, user.id);
if (!serverConfig) {
return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`,
@@ -504,8 +503,7 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
return res.status(401).json({ error: 'User not authenticated' });
}
const mcpManager = getMCPManager();
const serverConfig = mcpManager.getRawConfig(serverName);
const serverConfig = await mcpServersRegistry.getServerConfig(serverName, user.id);
if (!serverConfig) {
return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`,
@@ -544,9 +542,8 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
}
});
function getOAuthHeaders(serverName) {
const mcpManager = getMCPManager();
const serverConfig = mcpManager.getRawConfig(serverName);
async function getOAuthHeaders(serverName, userId) {
const serverConfig = await mcpServersRegistry.getServerConfig(serverName, userId);
return serverConfig?.oauth_headers ?? {};
}

View File

@@ -0,0 +1,10 @@
const { ToolCacheKeys } = require('../getCachedTools');
describe('getCachedTools - Cache Isolation Security', () => {
describe('ToolCacheKeys.MCP_SERVER', () => {
it('should generate cache keys that include userId', () => {
const key = ToolCacheKeys.MCP_SERVER('user123', 'github');
expect(key).toBe('tools:mcp:user123:github');
});
});
});

View File

@@ -7,24 +7,25 @@ const getLogStores = require('~/cache/getLogStores');
const ToolCacheKeys = {
/** Global tools available to all users */
GLOBAL: 'tools:global',
/** MCP tools cached by server name */
MCP_SERVER: (serverName) => `tools:mcp:${serverName}`,
/** MCP tools cached by user ID and server name */
MCP_SERVER: (userId, serverName) => `tools:mcp:${userId}:${serverName}`,
};
/**
* Retrieves available tools from cache
* @function getCachedTools
* @param {Object} options - Options for retrieving tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name to get cached tools for
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
*/
async function getCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName } = options;
const { userId, serverName } = options;
// Return MCP server-specific tools if requested
if (serverName) {
return await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
if (serverName && userId) {
return await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
}
// Default to global tools
@@ -36,17 +37,18 @@ async function getCachedTools(options = {}) {
* @function setCachedTools
* @param {Object} tools - The tools object to cache
* @param {Object} options - Options for caching tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name for server-specific tools
* @param {number} [options.ttl] - Time to live in milliseconds
* @returns {Promise<boolean>} Whether the operation was successful
*/
async function setCachedTools(tools, options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, ttl } = options;
const { userId, serverName, ttl } = options;
// Cache by MCP server if specified
if (serverName) {
return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl);
// Cache by MCP server if specified (requires userId)
if (serverName && userId) {
return await cache.set(ToolCacheKeys.MCP_SERVER(userId, serverName), tools, ttl);
}
// Default to global cache
@@ -57,13 +59,14 @@ async function setCachedTools(tools, options = {}) {
* Invalidates cached tools
* @function invalidateCachedTools
* @param {Object} options - Options for invalidating tools
* @param {string} [options.userId] - User ID for user-specific MCP tools
* @param {string} [options.serverName] - MCP server name to invalidate
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
* @returns {Promise<void>}
*/
async function invalidateCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, invalidateGlobal = false } = options;
const { userId, serverName, invalidateGlobal = false } = options;
const keysToDelete = [];
@@ -71,22 +74,23 @@ async function invalidateCachedTools(options = {}) {
keysToDelete.push(ToolCacheKeys.GLOBAL);
}
if (serverName) {
keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName));
if (serverName && userId) {
keysToDelete.push(ToolCacheKeys.MCP_SERVER(userId, serverName));
}
await Promise.all(keysToDelete.map((key) => cache.delete(key)));
}
/**
* Gets MCP tools for a specific server from cache or merges with global tools
* Gets MCP tools for a specific server from cache
* @function getMCPServerTools
* @param {string} userId - The user ID
* @param {string} serverName - The MCP server name
* @returns {Promise<LCAvailableTools|null>} The available tools for the server
*/
async function getMCPServerTools(serverName) {
async function getMCPServerTools(userId, serverName) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName));
if (serverTools) {
return serverTools;

View File

@@ -6,11 +6,12 @@ const { getLogStores } = require('~/cache');
/**
* Updates MCP tools in the cache for a specific server
* @param {Object} params - Parameters for updating MCP tools
* @param {string} params.userId - User ID for user-specific caching
* @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<LCAvailableTools>}
*/
async function updateMCPServerTools({ serverName, tools }) {
async function updateMCPServerTools({ userId, serverName, tools }) {
try {
const serverTools = {};
const mcpDelimiter = Constants.mcp_delimiter;
@@ -27,14 +28,16 @@ async function updateMCPServerTools({ serverName, tools }) {
};
}
await setCachedTools(serverTools, { serverName });
await setCachedTools(serverTools, { userId, serverName });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`);
logger.debug(
`[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`,
);
return serverTools;
} catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
logger.error(`[MCP Cache] Failed to update tools for ${serverName} (user: ${userId}):`, error);
throw error;
}
}
@@ -65,21 +68,22 @@ async function mergeAppTools(appTools) {
/**
* Caches MCP server tools (no longer merges with global)
* @param {object} params
* @param {string} params.userId - User ID for user-specific caching
* @param {string} params.serverName
* @param {import('@librechat/api').LCAvailableTools} params.serverTools
* @returns {Promise<void>}
*/
async function cacheMCPServerTools({ serverName, serverTools }) {
async function cacheMCPServerTools({ userId, serverName, serverTools }) {
try {
const count = Object.keys(serverTools).length;
if (!count) {
return;
}
// Only cache server-specific tools, no merging with global
await setCachedTools(serverTools, { serverName });
logger.debug(`Cached ${count} MCP server tools for ${serverName}`);
await setCachedTools(serverTools, { userId, serverName });
logger.debug(`Cached ${count} MCP server tools for ${serverName} (user: ${userId})`);
} catch (error) {
logger.error(`Failed to cache MCP server tools for ${serverName}:`, error);
logger.error(`Failed to cache MCP server tools for ${serverName} (user: ${userId}):`, error);
throw error;
}
}

View File

@@ -227,7 +227,6 @@ class STTService {
}
const headers = {
'Content-Type': 'multipart/form-data',
...(apiKey && { 'api-key': apiKey }),
};

View File

@@ -25,6 +25,7 @@ const { findToken, createToken, updateToken } = require('~/models');
const { reinitMCPServer } = require('./Tools/mcp');
const { getAppConfig } = require('./Config');
const { getLogStores } = require('~/cache');
const { mcpServersRegistry } = require('@librechat/api');
/**
* @param {object} params
@@ -450,7 +451,7 @@ async function getMCPSetupData(userId) {
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
}
const userConnections = mcpManager.getUserConnections(userId) || new Map();
const oauthServers = mcpManager.getOAuthServers();
const oauthServers = await mcpServersRegistry.getOAuthServers();
return {
mcpConfig,

View File

@@ -50,6 +50,9 @@ jest.mock('@librechat/api', () => ({
sendEvent: jest.fn(),
normalizeServerName: jest.fn((name) => name),
convertWithResolvedRefs: jest.fn((params) => params),
mcpServersRegistry: {
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
},
}));
jest.mock('librechat-data-provider', () => ({
@@ -100,6 +103,7 @@ describe('tests for the new helper functions used by the MCP connection status e
let mockGetFlowStateManager;
let mockGetLogStores;
let mockGetOAuthReconnectionManager;
let mockMcpServersRegistry;
beforeEach(() => {
jest.clearAllMocks();
@@ -108,6 +112,7 @@ describe('tests for the new helper functions used by the MCP connection status e
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
mockMcpServersRegistry = require('@librechat/api').mcpServersRegistry;
});
describe('getMCPSetupData', () => {
@@ -125,8 +130,8 @@ describe('tests for the new helper functions used by the MCP connection status e
mockGetMCPManager.mockReturnValue({
appConnections: { getAll: jest.fn(() => new Map()) },
getUserConnections: jest.fn(() => new Map()),
getOAuthServers: jest.fn(() => new Set()),
});
mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set());
});
it('should successfully return MCP setup data', async () => {
@@ -139,9 +144,9 @@ describe('tests for the new helper functions used by the MCP connection status e
const mockMCPManager = {
appConnections: { getAll: jest.fn(() => mockAppConnections) },
getUserConnections: jest.fn(() => mockUserConnections),
getOAuthServers: jest.fn(() => mockOAuthServers),
};
mockGetMCPManager.mockReturnValue(mockMCPManager);
mockMcpServersRegistry.getOAuthServers.mockResolvedValue(mockOAuthServers);
const result = await getMCPSetupData(mockUserId);
@@ -149,7 +154,7 @@ describe('tests for the new helper functions used by the MCP connection status e
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled();
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.getOAuthServers).toHaveBeenCalled();
expect(mockMcpServersRegistry.getOAuthServers).toHaveBeenCalled();
expect(result).toEqual({
mcpConfig: mockConfig.mcpServers,
@@ -170,9 +175,9 @@ describe('tests for the new helper functions used by the MCP connection status e
const mockMCPManager = {
appConnections: { getAll: jest.fn(() => null) },
getUserConnections: jest.fn(() => null),
getOAuthServers: jest.fn(() => new Set()),
};
mockGetMCPManager.mockReturnValue(mockMCPManager);
mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set());
const result = await getMCPSetupData(mockUserId);

View File

@@ -98,6 +98,7 @@ async function reinitMCPServer({
if (connection && !oauthRequired) {
tools = await connection.fetchTools();
availableTools = await updateMCPServerTools({
userId: user.id,
serverName,
tools,
});

View File

@@ -15,7 +15,7 @@ async function initializeMCPs() {
const mcpManager = await createMCPManager(mcpServers);
try {
const mcpTools = mcpManager.getAppToolFunctions() || {};
const mcpTools = (await mcpManager.getAppToolFunctions()) || {};
await mergeAppTools(mcpTools);
logger.info(

View File

@@ -304,6 +304,7 @@ describe('Apple Login Strategy', () => {
fileStrategy: 'local',
balance: { enabled: false },
}),
'jane.doe@example.com',
);
});

View File

@@ -5,22 +5,25 @@ const { resizeAvatar } = require('~/server/services/Files/images/avatar');
const { updateUser, createUser, getUserById } = require('~/models');
/**
* Updates the avatar URL of an existing user. If the user's avatar URL does not include the query parameter
* Updates the avatar URL and email of an existing user. If the user's avatar URL does not include the query parameter
* '?manual=true', it updates the user's avatar with the provided URL. For local file storage, it directly updates
* the avatar URL, while for other storage types, it processes the avatar URL using the specified file strategy.
* Also updates the email if it has changed (e.g., when a Google Workspace email is updated).
*
* @param {IUser} oldUser - The existing user object that needs to be updated.
* @param {string} avatarUrl - The new avatar URL to be set for the user.
* @param {AppConfig} appConfig - The application configuration object.
* @param {string} [email] - Optional. The new email address to update if it has changed.
*
* @returns {Promise<void>}
* The function updates the user's avatar and saves the user object. It does not return any value.
* The function updates the user's avatar and/or email and saves the user object. It does not return any value.
*
* @throws {Error} Throws an error if there's an issue saving the updated user object.
*/
const handleExistingUser = async (oldUser, avatarUrl, appConfig) => {
const handleExistingUser = async (oldUser, avatarUrl, appConfig, email) => {
const fileStrategy = appConfig?.fileStrategy ?? process.env.CDN_PROVIDER;
const isLocal = fileStrategy === FileSources.local;
const updates = {};
let updatedAvatar = false;
const hasManualFlag =
@@ -39,7 +42,16 @@ const handleExistingUser = async (oldUser, avatarUrl, appConfig) => {
}
if (updatedAvatar) {
await updateUser(oldUser._id, { avatar: updatedAvatar });
updates.avatar = updatedAvatar;
}
/** Update email if it has changed */
if (email && email.trim() !== oldUser.email) {
updates.email = email.trim();
}
if (Object.keys(updates).length > 0) {
await updateUser(oldUser._id, updates);
}
};

View File

@@ -167,4 +167,76 @@ describe('handleExistingUser', () => {
// This should throw an error when trying to access oldUser._id
await expect(handleExistingUser(null, avatarUrl)).rejects.toThrow();
});
it('should update email when it has changed', async () => {
const oldUser = {
_id: 'user123',
email: 'old@example.com',
avatar: 'https://example.com/avatar.png?manual=true',
};
const avatarUrl = 'https://example.com/avatar.png';
const newEmail = 'new@example.com';
await handleExistingUser(oldUser, avatarUrl, {}, newEmail);
expect(updateUser).toHaveBeenCalledWith('user123', { email: 'new@example.com' });
});
it('should update both avatar and email when both have changed', async () => {
const oldUser = {
_id: 'user123',
email: 'old@example.com',
avatar: null,
};
const avatarUrl = 'https://example.com/new-avatar.png';
const newEmail = 'new@example.com';
await handleExistingUser(oldUser, avatarUrl, {}, newEmail);
expect(updateUser).toHaveBeenCalledWith('user123', {
avatar: avatarUrl,
email: 'new@example.com',
});
});
it('should not update email when it has not changed', async () => {
const oldUser = {
_id: 'user123',
email: 'same@example.com',
avatar: 'https://example.com/avatar.png?manual=true',
};
const avatarUrl = 'https://example.com/avatar.png';
const sameEmail = 'same@example.com';
await handleExistingUser(oldUser, avatarUrl, {}, sameEmail);
expect(updateUser).not.toHaveBeenCalled();
});
it('should trim email before comparison and update', async () => {
const oldUser = {
_id: 'user123',
email: 'test@example.com',
avatar: 'https://example.com/avatar.png?manual=true',
};
const avatarUrl = 'https://example.com/avatar.png';
const newEmailWithSpaces = ' newemail@example.com ';
await handleExistingUser(oldUser, avatarUrl, {}, newEmailWithSpaces);
expect(updateUser).toHaveBeenCalledWith('user123', { email: 'newemail@example.com' });
});
it('should not update when email parameter is not provided', async () => {
const oldUser = {
_id: 'user123',
email: 'test@example.com',
avatar: 'https://example.com/avatar.png?manual=true',
};
const avatarUrl = 'https://example.com/avatar.png';
await handleExistingUser(oldUser, avatarUrl, {});
expect(updateUser).not.toHaveBeenCalled();
});
});

View File

@@ -25,10 +25,24 @@ const socialLogin =
return cb(error);
}
const existingUser = await findUser({ email: email.trim() });
const providerKey = `${provider}Id`;
let existingUser = null;
/** First try to find user by provider ID (e.g., googleId, facebookId) */
if (id && typeof id === 'string') {
existingUser = await findUser({ [providerKey]: id });
}
/** If not found by provider ID, try finding by email */
if (!existingUser) {
existingUser = await findUser({ email: email?.trim() });
if (existingUser) {
logger.warn(`[${provider}Login] User found by email: ${email} but not by ${providerKey}`);
}
}
if (existingUser?.provider === provider) {
await handleExistingUser(existingUser, avatarUrl, appConfig);
await handleExistingUser(existingUser, avatarUrl, appConfig, email);
return cb(null, existingUser);
} else if (existingUser) {
logger.info(

View File

@@ -0,0 +1,276 @@
const { logger } = require('@librechat/data-schemas');
const { ErrorTypes } = require('librechat-data-provider');
const { createSocialUser, handleExistingUser } = require('./process');
const socialLogin = require('./socialLogin');
const { findUser } = require('~/models');
jest.mock('@librechat/data-schemas', () => {
const actualModule = jest.requireActual('@librechat/data-schemas');
return {
...actualModule,
logger: {
error: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
},
};
});
jest.mock('./process', () => ({
createSocialUser: jest.fn(),
handleExistingUser: jest.fn(),
}));
jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
isEnabled: jest.fn().mockReturnValue(true),
isEmailDomainAllowed: jest.fn().mockReturnValue(true),
}));
jest.mock('~/models', () => ({
findUser: jest.fn(),
}));
jest.mock('~/server/services/Config', () => ({
getAppConfig: jest.fn().mockResolvedValue({
fileStrategy: 'local',
balance: { enabled: false },
}),
}));
describe('socialLogin', () => {
const mockGetProfileDetails = ({ profile }) => ({
email: profile.emails[0].value,
id: profile.id,
avatarUrl: profile.photos?.[0]?.value || null,
username: profile.name?.givenName || 'user',
name: `${profile.name?.givenName || ''} ${profile.name?.familyName || ''}`.trim(),
emailVerified: profile.emails[0].verified || false,
});
beforeEach(() => {
jest.clearAllMocks();
});
describe('Finding users by provider ID', () => {
it('should find user by provider ID (googleId) when email has changed', async () => {
const provider = 'google';
const googleId = 'google-user-123';
const oldEmail = 'old@example.com';
const newEmail = 'new@example.com';
const existingUser = {
_id: 'user123',
email: oldEmail,
provider: 'google',
googleId: googleId,
};
/** Mock findUser to return user on first call (by googleId), null on second call */
findUser
.mockResolvedValueOnce(existingUser) // First call: finds by googleId
.mockResolvedValueOnce(null); // Second call would be by email, but won't be reached
const mockProfile = {
id: googleId,
emails: [{ value: newEmail, verified: true }],
photos: [{ value: 'https://example.com/avatar.png' }],
name: { givenName: 'John', familyName: 'Doe' },
};
const loginFn = socialLogin(provider, mockGetProfileDetails);
const callback = jest.fn();
await loginFn(null, null, null, mockProfile, callback);
/** Verify it searched by googleId first */
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId });
/** Verify it did NOT search by email (because it found user by googleId) */
expect(findUser).toHaveBeenCalledTimes(1);
/** Verify handleExistingUser was called with the new email */
expect(handleExistingUser).toHaveBeenCalledWith(
existingUser,
'https://example.com/avatar.png',
expect.any(Object),
newEmail,
);
/** Verify callback was called with success */
expect(callback).toHaveBeenCalledWith(null, existingUser);
});
it('should find user by provider ID (facebookId) when using Facebook', async () => {
const provider = 'facebook';
const facebookId = 'fb-user-456';
const email = 'user@example.com';
const existingUser = {
_id: 'user456',
email: email,
provider: 'facebook',
facebookId: facebookId,
};
findUser.mockResolvedValue(existingUser); // Always returns user
const mockProfile = {
id: facebookId,
emails: [{ value: email, verified: true }],
photos: [{ value: 'https://example.com/fb-avatar.png' }],
name: { givenName: 'Jane', familyName: 'Smith' },
};
const loginFn = socialLogin(provider, mockGetProfileDetails);
const callback = jest.fn();
await loginFn(null, null, null, mockProfile, callback);
/** Verify it searched by facebookId first */
expect(findUser).toHaveBeenCalledWith({ facebookId: facebookId });
expect(findUser.mock.calls[0]).toEqual([{ facebookId: facebookId }]);
expect(handleExistingUser).toHaveBeenCalledWith(
existingUser,
'https://example.com/fb-avatar.png',
expect.any(Object),
email,
);
expect(callback).toHaveBeenCalledWith(null, existingUser);
});
it('should fallback to finding user by email if not found by provider ID', async () => {
const provider = 'google';
const googleId = 'google-user-789';
const email = 'user@example.com';
const existingUser = {
_id: 'user789',
email: email,
provider: 'google',
googleId: 'old-google-id', // Different googleId (edge case)
};
/** First call (by googleId) returns null, second call (by email) returns user */
findUser
.mockResolvedValueOnce(null) // By googleId
.mockResolvedValueOnce(existingUser); // By email
const mockProfile = {
id: googleId,
emails: [{ value: email, verified: true }],
photos: [{ value: 'https://example.com/avatar.png' }],
name: { givenName: 'Bob', familyName: 'Johnson' },
};
const loginFn = socialLogin(provider, mockGetProfileDetails);
const callback = jest.fn();
await loginFn(null, null, null, mockProfile, callback);
/** Verify both searches happened */
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId });
expect(findUser).toHaveBeenNthCalledWith(2, { email: email });
expect(findUser).toHaveBeenCalledTimes(2);
/** Verify warning log */
expect(logger.warn).toHaveBeenCalledWith(
`[${provider}Login] User found by email: ${email} but not by ${provider}Id`,
);
expect(handleExistingUser).toHaveBeenCalled();
expect(callback).toHaveBeenCalledWith(null, existingUser);
});
it('should create new user if not found by provider ID or email', async () => {
const provider = 'google';
const googleId = 'google-new-user';
const email = 'newuser@example.com';
const newUser = {
_id: 'newuser123',
email: email,
provider: 'google',
googleId: googleId,
};
/** Both searches return null */
findUser.mockResolvedValue(null);
createSocialUser.mockResolvedValue(newUser);
const mockProfile = {
id: googleId,
emails: [{ value: email, verified: true }],
photos: [{ value: 'https://example.com/avatar.png' }],
name: { givenName: 'New', familyName: 'User' },
};
const loginFn = socialLogin(provider, mockGetProfileDetails);
const callback = jest.fn();
await loginFn(null, null, null, mockProfile, callback);
/** Verify both searches happened */
expect(findUser).toHaveBeenCalledTimes(2);
/** Verify createSocialUser was called */
expect(createSocialUser).toHaveBeenCalledWith({
email: email,
avatarUrl: 'https://example.com/avatar.png',
provider: provider,
providerKey: 'googleId',
providerId: googleId,
username: 'New',
name: 'New User',
emailVerified: true,
appConfig: expect.any(Object),
});
expect(callback).toHaveBeenCalledWith(null, newUser);
});
});
describe('Error handling', () => {
it('should return error if user exists with different provider', async () => {
const provider = 'google';
const googleId = 'google-user-123';
const email = 'user@example.com';
const existingUser = {
_id: 'user123',
email: email,
provider: 'local', // Different provider
};
findUser
.mockResolvedValueOnce(null) // By googleId
.mockResolvedValueOnce(existingUser); // By email
const mockProfile = {
id: googleId,
emails: [{ value: email, verified: true }],
photos: [{ value: 'https://example.com/avatar.png' }],
name: { givenName: 'John', familyName: 'Doe' },
};
const loginFn = socialLogin(provider, mockGetProfileDetails);
const callback = jest.fn();
await loginFn(null, null, null, mockProfile, callback);
/** Verify error callback */
expect(callback).toHaveBeenCalledWith(
expect.objectContaining({
code: ErrorTypes.AUTH_FAILED,
provider: 'local',
}),
);
expect(logger.info).toHaveBeenCalledWith(
`[${provider}Login] User ${email} already exists with provider local`,
);
});
});
});

View File

@@ -1,89 +0,0 @@
import { useState, useMemo, memo, useCallback } from 'react';
import { useRecoilValue } from 'recoil';
import { Atom, ChevronDown } from 'lucide-react';
import type { MouseEvent, FC } from 'react';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
import store from '~/store';
const BUTTON_STYLES = {
base: 'group mt-3 flex w-fit items-center justify-center rounded-xl bg-surface-tertiary px-3 py-2 text-xs leading-[18px] animate-thinking-appear',
icon: 'icon-sm ml-1.5 transform-gpu text-text-primary transition-transform duration-200',
} as const;
const CONTENT_STYLES = {
wrapper: 'relative pl-3 text-text-secondary',
border:
'absolute left-0 h-[calc(100%-10px)] border-l-2 border-border-medium dark:border-border-heavy',
partBorder:
'absolute left-0 h-[calc(100%)] border-l-2 border-border-medium dark:border-border-heavy',
text: 'whitespace-pre-wrap leading-[26px]',
} as const;
export const ThinkingContent: FC<{ children: React.ReactNode; isPart?: boolean }> = memo(
({ isPart, children }) => (
<div className={CONTENT_STYLES.wrapper}>
<div className={isPart === true ? CONTENT_STYLES.partBorder : CONTENT_STYLES.border} />
<p className={CONTENT_STYLES.text}>{children}</p>
</div>
),
);
export const ThinkingButton = memo(
({
isExpanded,
onClick,
label,
}: {
isExpanded: boolean;
onClick: (e: MouseEvent<HTMLButtonElement>) => void;
label: string;
}) => (
<button type="button" onClick={onClick} className={BUTTON_STYLES.base}>
<Atom size={14} className="mr-1.5 text-text-secondary" />
{label}
<ChevronDown className={`${BUTTON_STYLES.icon} ${isExpanded ? 'rotate-180' : ''}`} />
</button>
),
);
const Thinking: React.ElementType = memo(({ children }: { children: React.ReactNode }) => {
const localize = useLocalize();
const showThinking = useRecoilValue<boolean>(store.showThinking);
const [isExpanded, setIsExpanded] = useState(showThinking);
const handleClick = useCallback((e: MouseEvent<HTMLButtonElement>) => {
e.preventDefault();
setIsExpanded((prev) => !prev);
}, []);
const label = useMemo(() => localize('com_ui_thoughts'), [localize]);
if (children == null) {
return null;
}
return (
<>
<div className="mb-5">
<ThinkingButton isExpanded={isExpanded} onClick={handleClick} label={label} />
</div>
<div
className={cn('grid transition-all duration-300 ease-out', isExpanded && 'mb-8')}
style={{
gridTemplateRows: isExpanded ? '1fr' : '0fr',
}}
>
<div className="overflow-hidden">
<ThinkingContent isPart={true}>{children}</ThinkingContent>
</div>
</div>
</>
);
});
ThinkingButton.displayName = 'ThinkingButton';
ThinkingContent.displayName = 'ThinkingContent';
Thinking.displayName = 'Thinking';
export default memo(Thinking);

View File

@@ -117,8 +117,10 @@ const AttachFileMenu = ({
const items: MenuItemProps[] = [];
const currentProvider = provider || endpoint;
if (isDocumentSupportedProvider(currentProvider || endpointType)) {
if (
isDocumentSupportedProvider(endpointType) ||
isDocumentSupportedProvider(currentProvider)
) {
items.push({
label: localize('com_ui_upload_provider'),
onClick: () => {

View File

@@ -57,7 +57,7 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
const currentProvider = provider || endpoint;
// Check if provider supports document upload
if (isDocumentSupportedProvider(currentProvider || endpointType)) {
if (isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider)) {
const isGoogleProvider = currentProvider === EModelEndpoint.google;
const validFileTypes = isGoogleProvider
? files.every(

View File

@@ -133,7 +133,7 @@ export default function FileRow({
>
{isImage ? (
<Image
url={file.preview ?? file.filepath}
url={file.progress === 1 ? file.filepath : (file.preview ?? file.filepath)}
onDelete={handleDelete}
progress={file.progress}
source={file.source}

View File

@@ -0,0 +1,602 @@
import React from 'react';
import { render, screen, fireEvent } from '@testing-library/react';
import '@testing-library/jest-dom';
import { RecoilRoot } from 'recoil';
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
import { EModelEndpoint } from 'librechat-data-provider';
import AttachFileMenu from '../AttachFileMenu';
// Mock all the hooks
jest.mock('~/hooks', () => ({
useAgentToolPermissions: jest.fn(),
useAgentCapabilities: jest.fn(),
useGetAgentsConfig: jest.fn(),
useFileHandling: jest.fn(),
useLocalize: jest.fn(),
}));
jest.mock('~/hooks/Files/useSharePointFileHandling', () => ({
__esModule: true,
default: jest.fn(),
}));
jest.mock('~/data-provider', () => ({
useGetStartupConfig: jest.fn(),
}));
jest.mock('~/components/SharePoint', () => ({
SharePointPickerDialog: jest.fn(() => null),
}));
jest.mock('@librechat/client', () => {
const React = jest.requireActual('react');
return {
FileUpload: React.forwardRef(({ children, handleFileChange }: any, ref: any) => (
<div data-testid="file-upload">
<input ref={ref} type="file" onChange={handleFileChange} data-testid="file-input" />
{children}
</div>
)),
TooltipAnchor: ({ render }: any) => render,
DropdownPopup: ({ trigger, items, isOpen, setIsOpen }: any) => {
const handleTriggerClick = () => {
if (setIsOpen) {
setIsOpen(!isOpen);
}
};
return (
<div>
<div onClick={handleTriggerClick}>{trigger}</div>
{isOpen && (
<div data-testid="dropdown-menu">
{items.map((item: any, idx: number) => (
<button key={idx} onClick={item.onClick} data-testid={`menu-item-${idx}`}>
{item.label}
</button>
))}
</div>
)}
</div>
);
},
AttachmentIcon: () => <span data-testid="attachment-icon">📎</span>,
SharePointIcon: () => <span data-testid="sharepoint-icon">SP</span>,
};
});
jest.mock('@ariakit/react', () => ({
MenuButton: ({ children, onClick, disabled, ...props }: any) => (
<button onClick={onClick} disabled={disabled} {...props}>
{children}
</button>
),
}));
const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions;
const mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities;
const mockUseGetAgentsConfig = jest.requireMock('~/hooks').useGetAgentsConfig;
const mockUseFileHandling = jest.requireMock('~/hooks').useFileHandling;
const mockUseLocalize = jest.requireMock('~/hooks').useLocalize;
const mockUseSharePointFileHandling = jest.requireMock(
'~/hooks/Files/useSharePointFileHandling',
).default;
const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig;
describe('AttachFileMenu', () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
});
const mockHandleFileChange = jest.fn();
beforeEach(() => {
jest.clearAllMocks();
// Default mock implementations
mockUseLocalize.mockReturnValue((key: string) => {
const translations: Record<string, string> = {
com_ui_upload_provider: 'Upload to Provider',
com_ui_upload_image_input: 'Upload Image',
com_ui_upload_ocr_text: 'Upload OCR Text',
com_ui_upload_file_search: 'Upload for File Search',
com_ui_upload_code_files: 'Upload Code Files',
com_sidepanel_attach_files: 'Attach Files',
com_files_upload_sharepoint: 'Upload from SharePoint',
};
return translations[key] || key;
});
mockUseAgentCapabilities.mockReturnValue({
contextEnabled: false,
fileSearchEnabled: false,
codeEnabled: false,
});
mockUseGetAgentsConfig.mockReturnValue({
agentsConfig: {
capabilities: {
contextEnabled: false,
fileSearchEnabled: false,
codeEnabled: false,
},
},
});
mockUseFileHandling.mockReturnValue({
handleFileChange: mockHandleFileChange,
});
mockUseSharePointFileHandling.mockReturnValue({
handleSharePointFiles: jest.fn(),
isProcessing: false,
downloadProgress: 0,
});
mockUseGetStartupConfig.mockReturnValue({
data: {
sharePointFilePickerEnabled: false,
},
});
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: undefined,
});
});
const renderAttachFileMenu = (props: any = {}) => {
return render(
<QueryClientProvider client={queryClient}>
<RecoilRoot>
<AttachFileMenu conversationId="test-conversation" {...props} />
</RecoilRoot>
</QueryClientProvider>,
);
};
describe('Basic Rendering', () => {
it('should render the attachment button', () => {
renderAttachFileMenu();
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeInTheDocument();
});
it('should be disabled when disabled prop is true', () => {
renderAttachFileMenu({ disabled: true });
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeDisabled();
});
it('should not be disabled when disabled prop is false', () => {
renderAttachFileMenu({ disabled: false });
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).not.toBeDisabled();
});
});
describe('Provider Detection Fix - endpointType Priority', () => {
it('should prioritize endpointType over currentProvider for LiteLLM gateway', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: 'litellm', // Custom gateway name NOT in documentSupportedProviders
});
renderAttachFileMenu({
endpoint: 'litellm',
endpointType: EModelEndpoint.openAI, // Backend override IS in documentSupportedProviders
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
// With the fix, should show "Upload to Provider" because endpointType is checked first
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
expect(screen.queryByText('Upload Image')).not.toBeInTheDocument();
});
it('should show Upload to Provider for custom endpoints with OpenAI endpointType', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: 'my-custom-gateway',
});
renderAttachFileMenu({
endpoint: 'my-custom-gateway',
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
});
it('should show Upload Image when neither endpointType nor provider support documents', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: 'unsupported-provider',
});
renderAttachFileMenu({
endpoint: 'unsupported-provider',
endpointType: 'unsupported-endpoint' as any,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload Image')).toBeInTheDocument();
expect(screen.queryByText('Upload to Provider')).not.toBeInTheDocument();
});
it('should fallback to currentProvider when endpointType is undefined', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.openAI,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.openAI,
endpointType: undefined,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
});
it('should fallback to currentProvider when endpointType is null', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.anthropic,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.anthropic,
endpointType: null,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
});
});
describe('Supported Providers', () => {
const supportedProviders = [
{ name: 'OpenAI', endpoint: EModelEndpoint.openAI },
{ name: 'Anthropic', endpoint: EModelEndpoint.anthropic },
{ name: 'Google', endpoint: EModelEndpoint.google },
{ name: 'Azure OpenAI', endpoint: EModelEndpoint.azureOpenAI },
{ name: 'Custom', endpoint: EModelEndpoint.custom },
];
supportedProviders.forEach(({ name, endpoint }) => {
it(`should show Upload to Provider for ${name}`, () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: endpoint,
});
renderAttachFileMenu({
endpoint,
endpointType: endpoint,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
});
});
});
describe('Agent Capabilities', () => {
it('should show OCR Text option when context is enabled', () => {
mockUseAgentCapabilities.mockReturnValue({
contextEnabled: true,
fileSearchEnabled: false,
codeEnabled: false,
});
renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload OCR Text')).toBeInTheDocument();
});
it('should show File Search option when enabled and allowed by agent', () => {
mockUseAgentCapabilities.mockReturnValue({
contextEnabled: false,
fileSearchEnabled: true,
codeEnabled: false,
});
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: true,
codeAllowedByAgent: false,
provider: undefined,
});
renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload for File Search')).toBeInTheDocument();
});
it('should NOT show File Search when enabled but not allowed by agent', () => {
mockUseAgentCapabilities.mockReturnValue({
contextEnabled: false,
fileSearchEnabled: true,
codeEnabled: false,
});
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: undefined,
});
renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.queryByText('Upload for File Search')).not.toBeInTheDocument();
});
it('should show Code Files option when enabled and allowed by agent', () => {
mockUseAgentCapabilities.mockReturnValue({
contextEnabled: false,
fileSearchEnabled: false,
codeEnabled: true,
});
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: true,
provider: undefined,
});
renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload Code Files')).toBeInTheDocument();
});
it('should show all options when all capabilities are enabled', () => {
mockUseAgentCapabilities.mockReturnValue({
contextEnabled: true,
fileSearchEnabled: true,
codeEnabled: true,
});
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: true,
codeAllowedByAgent: true,
provider: undefined,
});
renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
expect(screen.getByText('Upload OCR Text')).toBeInTheDocument();
expect(screen.getByText('Upload for File Search')).toBeInTheDocument();
expect(screen.getByText('Upload Code Files')).toBeInTheDocument();
});
});
describe('SharePoint Integration', () => {
it('should show SharePoint option when enabled', () => {
mockUseGetStartupConfig.mockReturnValue({
data: {
sharePointFilePickerEnabled: true,
},
});
renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload from SharePoint')).toBeInTheDocument();
});
it('should NOT show SharePoint option when disabled', () => {
mockUseGetStartupConfig.mockReturnValue({
data: {
sharePointFilePickerEnabled: false,
},
});
renderAttachFileMenu({
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.queryByText('Upload from SharePoint')).not.toBeInTheDocument();
});
});
describe('Edge Cases', () => {
it('should handle undefined endpoint and provider gracefully', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: undefined,
});
renderAttachFileMenu({
endpoint: undefined,
endpointType: undefined,
});
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeInTheDocument();
fireEvent.click(button);
// Should show Upload Image as fallback
expect(screen.getByText('Upload Image')).toBeInTheDocument();
});
it('should handle null endpoint and provider gracefully', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: null,
});
renderAttachFileMenu({
endpoint: null,
endpointType: null,
});
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeInTheDocument();
});
it('should handle missing agentId gracefully', () => {
renderAttachFileMenu({
agentId: undefined,
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeInTheDocument();
});
it('should handle empty string agentId', () => {
renderAttachFileMenu({
agentId: '',
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
expect(button).toBeInTheDocument();
});
});
describe('Google Provider Special Case', () => {
it('should use google_multimodal file type for Google provider', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.google,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.google,
endpointType: EModelEndpoint.google,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
const uploadProviderButton = screen.getByText('Upload to Provider');
expect(uploadProviderButton).toBeInTheDocument();
// Click the upload to provider option
fireEvent.click(uploadProviderButton);
// The file input should have been clicked (indirectly tested through the implementation)
});
it('should use multimodal file type for non-Google providers', () => {
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.openAI,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.openAI,
endpointType: EModelEndpoint.openAI,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
const uploadProviderButton = screen.getByText('Upload to Provider');
expect(uploadProviderButton).toBeInTheDocument();
fireEvent.click(uploadProviderButton);
// Implementation detail - multimodal type is used
});
});
describe('Regression Tests', () => {
it('should not break the previous behavior for direct provider attachments', () => {
// When using a direct supported provider (not through a gateway)
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.anthropic,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.anthropic,
endpointType: EModelEndpoint.anthropic,
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
});
it('should maintain correct priority when both are supported', () => {
// Both endpointType and provider are supported, endpointType should be checked first
mockUseAgentToolPermissions.mockReturnValue({
fileSearchAllowedByAgent: false,
codeAllowedByAgent: false,
provider: EModelEndpoint.google,
});
renderAttachFileMenu({
endpoint: EModelEndpoint.google,
endpointType: EModelEndpoint.openAI, // Different but both supported
});
const button = screen.getByRole('button', { name: /attach file options/i });
fireEvent.click(button);
// Should still work because endpointType (openAI) is supported
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
});
});
});

View File

@@ -0,0 +1,121 @@
import { EModelEndpoint, isDocumentSupportedProvider } from 'librechat-data-provider';
describe('DragDropModal - Provider Detection', () => {
describe('endpointType priority over currentProvider', () => {
it('should show upload option for LiteLLM with OpenAI endpointType', () => {
const currentProvider = 'litellm'; // NOT in documentSupportedProviders
const endpointType = EModelEndpoint.openAI; // IS in documentSupportedProviders
// With fix: endpointType checked
const withFix =
isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider);
expect(withFix).toBe(true);
// Without fix: only currentProvider checked = false
const withoutFix = isDocumentSupportedProvider(currentProvider || endpointType);
expect(withoutFix).toBe(false);
});
it('should show upload option for any custom gateway with OpenAI endpointType', () => {
const currentProvider = 'my-custom-gateway';
const endpointType = EModelEndpoint.openAI;
const result =
isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider);
expect(result).toBe(true);
});
it('should fallback to currentProvider when endpointType is undefined', () => {
const currentProvider = EModelEndpoint.openAI;
const endpointType = undefined;
const result =
isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider);
expect(result).toBe(true);
});
it('should fallback to currentProvider when endpointType is null', () => {
const currentProvider = EModelEndpoint.anthropic;
const endpointType = null;
const result =
isDocumentSupportedProvider(endpointType as any) ||
isDocumentSupportedProvider(currentProvider);
expect(result).toBe(true);
});
it('should return false when neither provider supports documents', () => {
const currentProvider = 'unsupported-provider';
const endpointType = 'unsupported-endpoint' as any;
const result =
isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider);
expect(result).toBe(false);
});
});
describe('supported providers', () => {
const supportedProviders = [
{ name: 'OpenAI', value: EModelEndpoint.openAI },
{ name: 'Anthropic', value: EModelEndpoint.anthropic },
{ name: 'Google', value: EModelEndpoint.google },
{ name: 'Azure OpenAI', value: EModelEndpoint.azureOpenAI },
{ name: 'Custom', value: EModelEndpoint.custom },
];
supportedProviders.forEach(({ name, value }) => {
it(`should recognize ${name} as supported`, () => {
expect(isDocumentSupportedProvider(value)).toBe(true);
});
});
});
describe('real-world scenarios', () => {
it('should handle LiteLLM gateway pointing to OpenAI', () => {
const scenario = {
currentProvider: 'litellm',
endpointType: EModelEndpoint.openAI,
};
expect(
isDocumentSupportedProvider(scenario.endpointType) ||
isDocumentSupportedProvider(scenario.currentProvider),
).toBe(true);
});
it('should handle direct OpenAI connection', () => {
const scenario = {
currentProvider: EModelEndpoint.openAI,
endpointType: EModelEndpoint.openAI,
};
expect(
isDocumentSupportedProvider(scenario.endpointType) ||
isDocumentSupportedProvider(scenario.currentProvider),
).toBe(true);
});
it('should handle unsupported custom endpoint without override', () => {
const scenario = {
currentProvider: 'my-unsupported-endpoint',
endpointType: undefined,
};
expect(
isDocumentSupportedProvider(scenario.endpointType) ||
isDocumentSupportedProvider(scenario.currentProvider),
).toBe(false);
});
it('should handle agents endpoints with document supported providers', () => {
const scenario = {
currentProvider: EModelEndpoint.google,
endpointType: EModelEndpoint.agents,
};
expect(
isDocumentSupportedProvider(scenario.endpointType) ||
isDocumentSupportedProvider(scenario.currentProvider),
).toBe(true);
});
});
});

View File

@@ -0,0 +1,347 @@
import React from 'react';
import { render, screen } from '@testing-library/react';
import '@testing-library/jest-dom';
import { FileSources } from 'librechat-data-provider';
import type { ExtendedFile } from '~/common';
import FileRow from '../FileRow';
jest.mock('~/hooks', () => ({
useLocalize: jest.fn(),
}));
jest.mock('~/data-provider', () => ({
useDeleteFilesMutation: jest.fn(),
}));
jest.mock('~/hooks/Files', () => ({
useFileDeletion: jest.fn(),
}));
jest.mock('~/utils', () => ({
logger: {
log: jest.fn(),
},
}));
jest.mock('../Image', () => {
return function MockImage({ url, progress, source }: any) {
return (
<div data-testid="mock-image">
<span data-testid="image-url">{url}</span>
<span data-testid="image-progress">{progress}</span>
<span data-testid="image-source">{source}</span>
</div>
);
};
});
jest.mock('../FileContainer', () => {
return function MockFileContainer({ file }: any) {
return (
<div data-testid="mock-file-container">
<span data-testid="file-name">{file.filename}</span>
</div>
);
};
});
const mockUseLocalize = jest.requireMock('~/hooks').useLocalize;
const mockUseDeleteFilesMutation = jest.requireMock('~/data-provider').useDeleteFilesMutation;
const mockUseFileDeletion = jest.requireMock('~/hooks/Files').useFileDeletion;
describe('FileRow', () => {
const mockSetFiles = jest.fn();
const mockSetFilesLoading = jest.fn();
const mockDeleteFile = jest.fn();
beforeEach(() => {
jest.clearAllMocks();
mockUseLocalize.mockReturnValue((key: string) => {
const translations: Record<string, string> = {
com_ui_deleting_file: 'Deleting file...',
};
return translations[key] || key;
});
mockUseDeleteFilesMutation.mockReturnValue({
mutateAsync: jest.fn(),
});
mockUseFileDeletion.mockReturnValue({
deleteFile: mockDeleteFile,
});
});
/**
* Creates a mock ExtendedFile with sensible defaults
*/
const createMockFile = (overrides: Partial<ExtendedFile> = {}): ExtendedFile => ({
file_id: 'test-file-id',
type: 'image/png',
preview: 'blob:http://localhost:3080/preview-blob-url',
filepath: '/images/user123/test-file-id__image.png',
filename: 'test-image.png',
progress: 1,
size: 1024,
source: FileSources.local,
...overrides,
});
const renderFileRow = (files: Map<string, ExtendedFile>) => {
return render(
<FileRow files={files} setFiles={mockSetFiles} setFilesLoading={mockSetFilesLoading} />,
);
};
describe('Image URL Selection Logic', () => {
it('should use filepath instead of preview when progress is 1 (upload complete)', () => {
const file = createMockFile({
file_id: 'uploaded-file',
preview: 'blob:http://localhost:3080/temp-preview',
filepath: '/images/user123/uploaded-file__image.png',
progress: 1,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const imageUrl = screen.getByTestId('image-url').textContent;
expect(imageUrl).toBe('/images/user123/uploaded-file__image.png');
expect(imageUrl).not.toContain('blob:');
});
it('should use preview when progress is less than 1 (uploading)', () => {
const file = createMockFile({
file_id: 'uploading-file',
preview: 'blob:http://localhost:3080/temp-preview',
filepath: undefined,
progress: 0.5,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const imageUrl = screen.getByTestId('image-url').textContent;
expect(imageUrl).toBe('blob:http://localhost:3080/temp-preview');
});
it('should fallback to filepath when preview is undefined and progress is less than 1', () => {
const file = createMockFile({
file_id: 'file-without-preview',
preview: undefined,
filepath: '/images/user123/file-without-preview__image.png',
progress: 0.7,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const imageUrl = screen.getByTestId('image-url').textContent;
expect(imageUrl).toBe('/images/user123/file-without-preview__image.png');
});
it('should use filepath when both preview and filepath exist and progress is exactly 1', () => {
const file = createMockFile({
file_id: 'complete-file',
preview: 'blob:http://localhost:3080/old-blob',
filepath: '/images/user123/complete-file__image.png',
progress: 1.0,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const imageUrl = screen.getByTestId('image-url').textContent;
expect(imageUrl).toBe('/images/user123/complete-file__image.png');
});
});
describe('Progress States', () => {
it('should pass correct progress value during upload', () => {
const file = createMockFile({
progress: 0.65,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const progress = screen.getByTestId('image-progress').textContent;
expect(progress).toBe('0.65');
});
it('should pass progress value of 1 when upload is complete', () => {
const file = createMockFile({
progress: 1,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const progress = screen.getByTestId('image-progress').textContent;
expect(progress).toBe('1');
});
});
describe('File Source', () => {
it('should pass local source to Image component', () => {
const file = createMockFile({
source: FileSources.local,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const source = screen.getByTestId('image-source').textContent;
expect(source).toBe(FileSources.local);
});
it('should pass openai source to Image component', () => {
const file = createMockFile({
source: FileSources.openai,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const source = screen.getByTestId('image-source').textContent;
expect(source).toBe(FileSources.openai);
});
});
describe('File Type Detection', () => {
it('should render Image component for image files', () => {
const file = createMockFile({
type: 'image/jpeg',
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
expect(screen.getByTestId('mock-image')).toBeInTheDocument();
expect(screen.queryByTestId('mock-file-container')).not.toBeInTheDocument();
});
it('should render FileContainer for non-image files', () => {
const file = createMockFile({
type: 'application/pdf',
filename: 'document.pdf',
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
expect(screen.getByTestId('mock-file-container')).toBeInTheDocument();
expect(screen.queryByTestId('mock-image')).not.toBeInTheDocument();
});
});
describe('Multiple Files', () => {
it('should render multiple image files with correct URLs based on their progress', () => {
const filesMap = new Map<string, ExtendedFile>();
const uploadingFile = createMockFile({
file_id: 'file-1',
preview: 'blob:http://localhost:3080/preview-1',
filepath: undefined,
progress: 0.3,
});
const completedFile = createMockFile({
file_id: 'file-2',
preview: 'blob:http://localhost:3080/preview-2',
filepath: '/images/user123/file-2__image.png',
progress: 1,
});
filesMap.set(uploadingFile.file_id, uploadingFile);
filesMap.set(completedFile.file_id, completedFile);
renderFileRow(filesMap);
const images = screen.getAllByTestId('mock-image');
expect(images).toHaveLength(2);
const urls = screen.getAllByTestId('image-url').map((el) => el.textContent);
expect(urls).toContain('blob:http://localhost:3080/preview-1');
expect(urls).toContain('/images/user123/file-2__image.png');
});
it('should deduplicate files with the same file_id', () => {
const filesMap = new Map<string, ExtendedFile>();
const file1 = createMockFile({ file_id: 'duplicate-id' });
const file2 = createMockFile({ file_id: 'duplicate-id' });
filesMap.set('key-1', file1);
filesMap.set('key-2', file2);
renderFileRow(filesMap);
const images = screen.getAllByTestId('mock-image');
expect(images).toHaveLength(1);
});
});
describe('Empty State', () => {
it('should render nothing when files map is empty', () => {
const filesMap = new Map<string, ExtendedFile>();
const { container } = renderFileRow(filesMap);
expect(container.firstChild).toBeNull();
});
it('should render nothing when files is undefined', () => {
const { container } = render(
<FileRow files={undefined} setFiles={mockSetFiles} setFilesLoading={mockSetFilesLoading} />,
);
expect(container.firstChild).toBeNull();
});
});
describe('Regression: Blob URL Bug Fix', () => {
it('should NOT use revoked blob URL after upload completes', () => {
const file = createMockFile({
file_id: 'regression-test',
preview: 'blob:http://localhost:3080/d25f730c-152d-41f7-8d79-c9fa448f606b',
filepath:
'/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png',
progress: 1,
});
const filesMap = new Map<string, ExtendedFile>();
filesMap.set(file.file_id, file);
renderFileRow(filesMap);
const imageUrl = screen.getByTestId('image-url').textContent;
expect(imageUrl).not.toContain('blob:');
expect(imageUrl).toBe(
'/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png',
);
});
});
});

View File

@@ -1,5 +1,4 @@
import { memo, useMemo, useState } from 'react';
import { useRecoilState } from 'recoil';
import { memo, useMemo } from 'react';
import { ContentTypes } from 'librechat-data-provider';
import type {
TMessageContentParts,
@@ -7,14 +6,11 @@ import type {
TAttachment,
Agents,
} from 'librechat-data-provider';
import { ThinkingButton } from '~/components/Artifacts/Thinking';
import { MessageContext, SearchContext } from '~/Providers';
import MemoryArtifacts from './MemoryArtifacts';
import Sources from '~/components/Web/Sources';
import { mapAttachments } from '~/utils/map';
import { EditTextPart } from './Parts';
import { useLocalize } from '~/hooks';
import store from '~/store';
import Part from './Part';
type ContentPartsProps = {
@@ -52,32 +48,10 @@ const ContentParts = memo(
siblingIdx,
setSiblingIdx,
}: ContentPartsProps) => {
const localize = useLocalize();
const [showThinking, setShowThinking] = useRecoilState<boolean>(store.showThinking);
const [isExpanded, setIsExpanded] = useState(showThinking);
const attachmentMap = useMemo(() => mapAttachments(attachments ?? []), [attachments]);
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const hasReasoningParts = useMemo(() => {
const hasThinkPart = content?.some((part) => part?.type === ContentTypes.THINK) ?? false;
const allThinkPartsHaveContent =
content?.every((part) => {
if (part?.type !== ContentTypes.THINK) {
return true;
}
if (typeof part.think === 'string') {
const cleanedContent = part.think.replace(/<\/?think>/g, '').trim();
return cleanedContent.length > 0;
}
return false;
}) ?? false;
return hasThinkPart && allThinkPartsHaveContent;
}, [content]);
if (!content) {
return null;
}
@@ -126,57 +100,40 @@ const ContentParts = memo(
<SearchContext.Provider value={{ searchResults }}>
<MemoryArtifacts attachments={attachments} />
<Sources messageId={messageId} conversationId={conversationId || undefined} />
{hasReasoningParts && (
<div className="mb-5">
<ThinkingButton
isExpanded={isExpanded}
onClick={() =>
setIsExpanded((prev) => {
const val = !prev;
setShowThinking(val);
return val;
})
}
label={
effectiveIsSubmitting && isLast
? localize('com_ui_thinking')
: localize('com_ui_thoughts')
}
/>
</div>
)}
{content
.filter((part) => part)
.map((part, idx) => {
const toolCallId =
(part?.[ContentTypes.TOOL_CALL] as Agents.ToolCall | undefined)?.id ?? '';
const attachments = attachmentMap[toolCallId];
{content.map((part, idx) => {
if (!part) {
return null;
}
return (
<MessageContext.Provider
key={`provider-${messageId}-${idx}`}
value={{
messageId,
isExpanded,
conversationId,
partIndex: idx,
nextType: content[idx + 1]?.type,
isSubmitting: effectiveIsSubmitting,
isLatestMessage,
}}
>
<Part
part={part}
attachments={attachments}
isSubmitting={effectiveIsSubmitting}
key={`part-${messageId}-${idx}`}
isCreatedByUser={isCreatedByUser}
isLast={idx === content.length - 1}
showCursor={idx === content.length - 1 && isLast}
/>
</MessageContext.Provider>
);
})}
const toolCallId =
(part?.[ContentTypes.TOOL_CALL] as Agents.ToolCall | undefined)?.id ?? '';
const partAttachments = attachmentMap[toolCallId];
return (
<MessageContext.Provider
key={`provider-${messageId}-${idx}`}
value={{
messageId,
isExpanded: true,
conversationId,
partIndex: idx,
nextType: content[idx + 1]?.type,
isSubmitting: effectiveIsSubmitting,
isLatestMessage,
}}
>
<Part
part={part}
attachments={partAttachments}
isSubmitting={effectiveIsSubmitting}
key={`part-${messageId}-${idx}`}
isCreatedByUser={isCreatedByUser}
isLast={idx === content.length - 1}
showCursor={idx === content.length - 1 && isLast}
/>
</MessageContext.Provider>
);
})}
</SearchContext.Provider>
</>
);

View File

@@ -151,7 +151,7 @@ const EditMessage = ({
return (
<Container message={message}>
<div className="bg-token-main-surface-primary relative flex w-full flex-grow flex-col overflow-hidden rounded-2xl border border-border-medium text-text-primary [&:has(textarea:focus)]:border-border-heavy [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)]">
<div className="bg-token-main-surface-primary relative mt-2 flex w-full flex-grow flex-col overflow-hidden rounded-2xl border border-border-medium text-text-primary [&:has(textarea:focus)]:border-border-heavy [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)]">
<TextareaAutosize
{...registerProps}
ref={(e) => {

View File

@@ -4,67 +4,89 @@ import { DelayedRender } from '@librechat/client';
import type { TMessage } from 'librechat-data-provider';
import type { TMessageContentProps, TDisplayProps } from '~/common';
import Error from '~/components/Messages/Content/Error';
import Thinking from '~/components/Artifacts/Thinking';
import { useMessageContext } from '~/Providers';
import MarkdownLite from './MarkdownLite';
import EditMessage from './EditMessage';
import Thinking from './Parts/Thinking';
import { useLocalize } from '~/hooks';
import Container from './Container';
import Markdown from './Markdown';
import { cn } from '~/utils';
import store from '~/store';
const ERROR_CONNECTION_TEXT = 'Error connecting to server, try refreshing the page.';
const DELAYED_ERROR_TIMEOUT = 5500;
const UNFINISHED_DELAY = 250;
const parseThinkingContent = (text: string) => {
const thinkingMatch = text.match(/:::thinking([\s\S]*?):::/);
return {
thinkingContent: thinkingMatch ? thinkingMatch[1].trim() : '',
regularContent: thinkingMatch ? text.replace(/:::thinking[\s\S]*?:::/, '').trim() : text,
};
};
const LoadingFallback = () => (
<div className="text-message mb-[0.625rem] flex min-h-[20px] flex-col items-start gap-3 overflow-visible">
<div className="markdown prose dark:prose-invert light w-full break-words dark:text-gray-100">
<div className="absolute">
<p className="submitting relative">
<span className="result-thinking" />
</p>
</div>
</div>
</div>
);
const ErrorBox = ({
children,
className = '',
}: {
children: React.ReactNode;
className?: string;
}) => (
<div
role="alert"
aria-live="assertive"
className={cn(
'rounded-xl border border-red-500/20 bg-red-500/5 px-3 py-2 text-sm text-gray-600 dark:text-gray-200',
className,
)}
>
{children}
</div>
);
const ConnectionError = ({ message }: { message?: TMessage }) => {
const localize = useLocalize();
return (
<Suspense fallback={<LoadingFallback />}>
<DelayedRender delay={DELAYED_ERROR_TIMEOUT}>
<Container message={message}>
<div className="mt-2 rounded-xl border border-red-500/20 bg-red-50/50 px-4 py-3 text-sm text-red-700 shadow-sm transition-all dark:bg-red-950/30 dark:text-red-100">
{localize('com_ui_error_connection')}
</div>
</Container>
</DelayedRender>
</Suspense>
);
};
export const ErrorMessage = ({
text,
message,
className = '',
}: Pick<TDisplayProps, 'text' | 'className'> & {
message?: TMessage;
}) => {
const localize = useLocalize();
if (text === 'Error connecting to server, try refreshing the page.') {
console.log('error message', message);
return (
<Suspense
fallback={
<div className="text-message mb-[0.625rem] flex min-h-[20px] flex-col items-start gap-3 overflow-visible">
<div className="markdown prose dark:prose-invert light w-full break-words dark:text-gray-100">
<div className="absolute">
<p className="submitting relative">
<span className="result-thinking" />
</p>
</div>
</div>
</div>
}
>
<DelayedRender delay={5500}>
<Container message={message}>
<div
className={cn(
'rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-200',
className,
)}
>
{localize('com_ui_error_connection')}
</div>
</Container>
</DelayedRender>
</Suspense>
);
}: Pick<TDisplayProps, 'text' | 'className'> & { message?: TMessage }) => {
if (text === ERROR_CONNECTION_TEXT) {
return <ConnectionError message={message} />;
}
return (
<Container message={message}>
<div
role="alert"
aria-live="assertive"
className={cn(
'rounded-xl border border-red-500/20 bg-red-500/5 px-3 py-2 text-sm text-gray-600 dark:text-gray-200',
className,
)}
>
<ErrorBox className={className}>
<Error text={text} />
</div>
</ErrorBox>
</Container>
);
};
@@ -72,27 +94,29 @@ export const ErrorMessage = ({
const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => {
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
const showCursorState = useMemo(
() => showCursor === true && isSubmitting,
[showCursor, isSubmitting],
);
let content: React.ReactElement;
if (!isCreatedByUser) {
content = <Markdown content={text} isLatestMessage={isLatestMessage} />;
} else if (enableUserMsgMarkdown) {
content = <MarkdownLite content={text} />;
} else {
content = <>{text}</>;
}
const content = useMemo(() => {
if (!isCreatedByUser) {
return <Markdown content={text} isLatestMessage={isLatestMessage} />;
}
if (enableUserMsgMarkdown) {
return <MarkdownLite content={text} />;
}
return <>{text}</>;
}, [isCreatedByUser, enableUserMsgMarkdown, text, isLatestMessage]);
return (
<Container message={message}>
<div
className={cn(
isSubmitting ? 'submitting' : '',
showCursorState && !!text.length ? 'result-streaming' : '',
'markdown prose message-content dark:prose-invert light w-full break-words',
isSubmitting && 'submitting',
showCursorState && text.length > 0 && 'result-streaming',
isCreatedByUser && !enableUserMsgMarkdown && 'whitespace-pre-wrap',
isCreatedByUser ? 'dark:text-gray-20' : 'dark:text-gray-100',
)}
@@ -103,7 +127,6 @@ const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplay
);
};
// Unfinished Message Component
export const UnfinishedMessage = ({ message }: { message: TMessage }) => (
<ErrorMessage
message={message}
@@ -123,21 +146,14 @@ const MessageContent = ({
const { message } = props;
const { messageId } = message;
const { thinkingContent, regularContent } = useMemo(() => {
const thinkingMatch = text.match(/:::thinking([\s\S]*?):::/);
return {
thinkingContent: thinkingMatch ? thinkingMatch[1].trim() : '',
regularContent: thinkingMatch ? text.replace(/:::thinking[\s\S]*?:::/, '').trim() : text,
};
}, [text]);
const { thinkingContent, regularContent } = useMemo(() => parseThinkingContent(text), [text]);
const showRegularCursor = useMemo(() => isLast && isSubmitting, [isLast, isSubmitting]);
const unfinishedMessage = useMemo(
() =>
!isSubmitting && unfinished ? (
<Suspense>
<DelayedRender delay={250}>
<DelayedRender delay={UNFINISHED_DELAY}>
<UnfinishedMessage message={message} />
</DelayedRender>
</Suspense>
@@ -146,8 +162,10 @@ const MessageContent = ({
);
if (error) {
return <ErrorMessage message={props.message} text={text} />;
} else if (edit) {
return <ErrorMessage message={message} text={text} />;
}
if (edit) {
return <EditMessage text={text} isSubmitting={isSubmitting} {...props} />;
}

View File

@@ -65,6 +65,10 @@ const Part = memo(
if (part.tool_call_ids != null && !text) {
return null;
}
/** Skip rendering if text is only whitespace to avoid empty Container */
if (!isLast && text.length > 0 && /^\s*$/.test(text)) {
return null;
}
return (
<Container>
<Text text={text} isCreatedByUser={isCreatedByUser} showCursor={showCursor} />
@@ -75,7 +79,7 @@ const Part = memo(
if (typeof reasoning !== 'string') {
return null;
}
return <Reasoning reasoning={reasoning} />;
return <Reasoning reasoning={reasoning} isLast={isLast ?? false} />;
} else if (part.type === ContentTypes.TOOL_CALL) {
const toolCall = part[ContentTypes.TOOL_CALL];

View File

@@ -3,6 +3,7 @@ import { useForm } from 'react-hook-form';
import { TextareaAutosize } from '@librechat/client';
import { ContentTypes } from 'librechat-data-provider';
import { useRecoilState, useRecoilValue } from 'recoil';
import { Lightbulb, MessageSquare } from 'lucide-react';
import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query';
import type { Agents } from 'librechat-data-provider';
import type { TEditProps } from '~/common';
@@ -153,6 +154,22 @@ const EditTextPart = ({
return (
<Container message={message}>
{part.type === ContentTypes.THINK && (
<div className="mt-2 flex items-center gap-1.5 text-xs text-text-secondary">
<span className="flex gap-2 rounded-lg bg-surface-tertiary px-1.5 py-1 font-medium">
<Lightbulb className="size-3.5" />
{localize('com_ui_thoughts')}
</span>
</div>
)}
{part.type !== ContentTypes.THINK && (
<div className="mt-2 flex items-center gap-1.5 text-xs text-text-secondary">
<span className="flex gap-2 rounded-lg bg-surface-tertiary px-1.5 py-1 font-medium">
<MessageSquare className="size-3.5" />
{localize('com_ui_response')}
</span>
</div>
)}
<div className="bg-token-main-surface-primary relative flex w-full flex-grow flex-col overflow-hidden rounded-2xl border border-border-medium text-text-primary [&:has(textarea:focus)]:border-border-heavy [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)]">
<TextareaAutosize
{...registerProps}

View File

@@ -1,15 +1,47 @@
import { memo, useMemo } from 'react';
import { memo, useMemo, useState, useCallback } from 'react';
import { useAtom } from 'jotai';
import type { MouseEvent } from 'react';
import { ContentTypes } from 'librechat-data-provider';
import { ThinkingContent } from '~/components/Artifacts/Thinking';
import { ThinkingContent, ThinkingButton } from './Thinking';
import { showThinkingAtom } from '~/store/showThinking';
import { useMessageContext } from '~/Providers';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
type ReasoningProps = {
reasoning: string;
isLast: boolean;
};
const Reasoning = memo(({ reasoning }: ReasoningProps) => {
const { isExpanded, nextType } = useMessageContext();
/**
* Reasoning Component (MODERN SYSTEM)
*
* Used for structured content parts with ContentTypes.THINK type.
* This handles modern message format where content is an array of typed parts.
*
* Pattern: `{ content: [{ type: "think", think: "<think>content</think>" }, ...] }`
*
* Used by:
* - ContentParts.tsx → Part.tsx for structured messages
* - Agent/Assistant responses (OpenAI Assistants, custom agents)
* - O-series models (o1, o3) with reasoning capabilities
* - Modern Claude responses with thinking blocks
*
* Key differences from legacy Thinking.tsx:
* - Works with content parts array instead of plain text
* - Strips `<think>` tags instead of `:::thinking:::` markers
* - Each THINK part has its own independent toggle button
* - Can be interleaved with other content types
*
* For legacy text-based messages, see Thinking.tsx component.
*/
const Reasoning = memo(({ reasoning, isLast }: ReasoningProps) => {
const localize = useLocalize();
const [showThinking] = useAtom(showThinkingAtom);
const [isExpanded, setIsExpanded] = useState(showThinking);
const { isSubmitting, isLatestMessage, nextType } = useMessageContext();
// Strip <think> tags from the reasoning content (modern format)
const reasoningText = useMemo(() => {
return reasoning
.replace(/^<think>\s*/, '')
@@ -17,22 +49,45 @@ const Reasoning = memo(({ reasoning }: ReasoningProps) => {
.trim();
}, [reasoning]);
const handleClick = useCallback((e: MouseEvent<HTMLButtonElement>) => {
e.preventDefault();
setIsExpanded((prev) => !prev);
}, []);
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const label = useMemo(
() =>
effectiveIsSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts'),
[effectiveIsSubmitting, localize, isLast],
);
if (!reasoningText) {
return null;
}
return (
<div
className={cn(
'grid transition-all duration-300 ease-out',
nextType !== ContentTypes.THINK && isExpanded && 'mb-8',
)}
style={{
gridTemplateRows: isExpanded ? '1fr' : '0fr',
}}
>
<div className="overflow-hidden">
<ThinkingContent isPart={true}>{reasoningText}</ThinkingContent>
<div className="group/reasoning">
<div className="sticky top-0 z-10 mb-2 bg-surface-secondary pb-2 pt-2">
<ThinkingButton
isExpanded={isExpanded}
onClick={handleClick}
label={label}
content={reasoningText}
/>
</div>
<div
className={cn(
'grid transition-all duration-300 ease-out',
nextType !== ContentTypes.THINK && isExpanded && 'mb-4',
)}
style={{
gridTemplateRows: isExpanded ? '1fr' : '0fr',
}}
>
<div className="overflow-hidden">
<ThinkingContent>{reasoningText}</ThinkingContent>
</div>
</div>
</div>
);

View File

@@ -0,0 +1,172 @@
import { useState, useMemo, memo, useCallback } from 'react';
import { useAtomValue } from 'jotai';
import { Lightbulb, ChevronDown } from 'lucide-react';
import { Clipboard, CheckMark } from '@librechat/client';
import type { MouseEvent, FC } from 'react';
import { showThinkingAtom } from '~/store/showThinking';
import { fontSizeAtom } from '~/store/fontSize';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
/**
* ThinkingContent - Displays the actual thinking/reasoning content
* Used by both legacy text-based messages and modern content parts
*/
export const ThinkingContent: FC<{
children: React.ReactNode;
}> = memo(({ children }) => {
const fontSize = useAtomValue(fontSizeAtom);
return (
<div className="relative rounded-3xl border border-border-medium bg-surface-tertiary p-4 text-text-secondary">
<p className={cn('whitespace-pre-wrap leading-[26px]', fontSize)}>{children}</p>
</div>
);
});
/**
* ThinkingButton - Toggle button for expanding/collapsing thinking content
* Shows lightbulb icon by default, chevron on hover
* Shared between legacy Thinking component and modern ContentParts
*/
export const ThinkingButton = memo(
({
isExpanded,
onClick,
label,
content,
}: {
isExpanded: boolean;
onClick: (e: MouseEvent<HTMLButtonElement>) => void;
label: string;
content?: string;
}) => {
const localize = useLocalize();
const fontSize = useAtomValue(fontSizeAtom);
const [isCopied, setIsCopied] = useState(false);
const handleCopy = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
if (content) {
navigator.clipboard.writeText(content);
setIsCopied(true);
setTimeout(() => setIsCopied(false), 2000);
}
},
[content],
);
return (
<div className="flex w-full items-center justify-between gap-2">
<button
type="button"
onClick={onClick}
className={cn(
'group/button flex flex-1 items-center justify-start rounded-lg leading-[18px]',
fontSize,
)}
>
<span className="relative mr-1.5 inline-flex h-[18px] w-[18px] items-center justify-center">
<Lightbulb className="icon-sm absolute text-text-secondary opacity-100 transition-opacity group-hover/button:opacity-0" />
<ChevronDown
className={cn(
'icon-sm absolute transform-gpu text-text-primary opacity-0 transition-all duration-300 group-hover/button:opacity-100',
isExpanded && 'rotate-180',
)}
/>
</span>
{label}
</button>
{content && (
<button
type="button"
onClick={handleCopy}
title={
isCopied
? localize('com_ui_copied_to_clipboard')
: localize('com_ui_copy_thoughts_to_clipboard')
}
className={cn(
'rounded-lg p-1.5 text-text-secondary-alt transition-colors duration-200',
'hover:bg-surface-hover hover:text-text-primary',
'focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-black dark:focus-visible:ring-white',
)}
>
{isCopied ? <CheckMark className="h-[18px] w-[18px]" /> : <Clipboard size="19" />}
</button>
)}
</div>
);
},
);
/**
* Thinking Component (LEGACY SYSTEM)
*
* Used for simple text-based messages with `:::thinking:::` markers.
* This handles the old message format where text contains embedded thinking blocks.
*
* Pattern: `:::thinking\n{content}\n:::\n{response}`
*
* Used by:
* - MessageContent.tsx for plain text messages
* - Legacy message format compatibility
* - User messages when manually adding thinking content
*
* For modern structured content (agents/assistants), see Reasoning.tsx component.
*/
const Thinking: React.ElementType = memo(({ children }: { children: React.ReactNode }) => {
const localize = useLocalize();
const showThinking = useAtomValue(showThinkingAtom);
const [isExpanded, setIsExpanded] = useState(showThinking);
const handleClick = useCallback((e: MouseEvent<HTMLButtonElement>) => {
e.preventDefault();
setIsExpanded((prev) => !prev);
}, []);
const label = useMemo(() => localize('com_ui_thoughts'), [localize]);
// Extract text content for copy functionality
const textContent = useMemo(() => {
if (typeof children === 'string') {
return children;
}
return '';
}, [children]);
if (children == null) {
return null;
}
return (
<>
<div className="sticky top-0 z-10 mb-4 bg-surface-primary pb-2 pt-2">
<ThinkingButton
isExpanded={isExpanded}
onClick={handleClick}
label={label}
content={textContent}
/>
</div>
<div
className={cn('grid transition-all duration-300 ease-out', isExpanded && 'mb-8')}
style={{
gridTemplateRows: isExpanded ? '1fr' : '0fr',
}}
>
<div className="overflow-hidden">
<ThinkingContent>{children}</ThinkingContent>
</div>
</div>
</>
);
});
ThinkingButton.displayName = 'ThinkingButton';
ThinkingContent.displayName = 'ThinkingContent';
Thinking.displayName = 'Thinking';
export default memo(Thinking);

View File

@@ -1,4 +1,5 @@
import { memo } from 'react';
import { showThinkingAtom } from '~/store/showThinking';
import FontSizeSelector from './FontSizeSelector';
import { ForkSettings } from './ForkSettings';
import ChatDirection from './ChatDirection';
@@ -28,7 +29,7 @@ const toggleSwitchConfigs = [
key: 'centerFormOnLanding',
},
{
stateAtom: store.showThinking,
stateAtom: showThinkingAtom,
localizationKey: 'com_nav_show_thinking',
switchId: 'showThinking',
hoverCardText: undefined,

View File

@@ -1,18 +1,18 @@
import { useRecoilState } from 'recoil';
import { useAtom } from 'jotai';
import { Switch, InfoHoverCard, ESide } from '@librechat/client';
import { showThinkingAtom } from '~/store/showThinking';
import { useLocalize } from '~/hooks';
import store from '~/store';
export default function SaveDraft({
onCheckedChange,
}: {
onCheckedChange?: (value: boolean) => void;
}) {
const [showThinking, setSaveDrafts] = useRecoilState<boolean>(store.showThinking);
const [showThinking, setShowThinking] = useAtom(showThinkingAtom);
const localize = useLocalize();
const handleCheckedChange = (value: boolean) => {
setSaveDrafts(value);
setShowThinking(value);
if (onCheckedChange) {
onCheckedChange(value);
}

View File

@@ -1,3 +1,4 @@
import { WritableAtom, useAtom } from 'jotai';
import { RecoilState, useRecoilState } from 'recoil';
import { Switch, InfoHoverCard, ESide } from '@librechat/client';
import { useLocalize } from '~/hooks';
@@ -6,7 +7,7 @@ type LocalizeFn = ReturnType<typeof useLocalize>;
type LocalizeKey = Parameters<LocalizeFn>[0];
interface ToggleSwitchProps {
stateAtom: RecoilState<boolean>;
stateAtom: RecoilState<boolean> | WritableAtom<boolean, [boolean], void>;
localizationKey: LocalizeKey;
hoverCardText?: LocalizeKey;
switchId: string;
@@ -16,13 +17,18 @@ interface ToggleSwitchProps {
strongLabel?: boolean;
}
const ToggleSwitch: React.FC<ToggleSwitchProps> = ({
function isRecoilState<T>(atom: unknown): atom is RecoilState<T> {
return atom != null && typeof atom === 'object' && 'key' in atom;
}
const RecoilToggle: React.FC<
Omit<ToggleSwitchProps, 'stateAtom'> & { stateAtom: RecoilState<boolean> }
> = ({
stateAtom,
localizationKey,
hoverCardText,
switchId,
onCheckedChange,
showSwitch = true,
disabled = false,
strongLabel = false,
}) => {
@@ -36,9 +42,47 @@ const ToggleSwitch: React.FC<ToggleSwitchProps> = ({
const labelId = `${switchId}-label`;
if (!showSwitch) {
return null;
}
return (
<div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<div id={labelId}>
{strongLabel ? <strong>{localize(localizationKey)}</strong> : localize(localizationKey)}
</div>
{hoverCardText && <InfoHoverCard side={ESide.Bottom} text={localize(hoverCardText)} />}
</div>
<Switch
id={switchId}
checked={switchState}
onCheckedChange={handleCheckedChange}
disabled={disabled}
className="ml-4"
data-testid={switchId}
aria-labelledby={labelId}
/>
</div>
);
};
const JotaiToggle: React.FC<
Omit<ToggleSwitchProps, 'stateAtom'> & { stateAtom: WritableAtom<boolean, [boolean], void> }
> = ({
stateAtom,
localizationKey,
hoverCardText,
switchId,
onCheckedChange,
disabled = false,
strongLabel = false,
}) => {
const [switchState, setSwitchState] = useAtom(stateAtom);
const localize = useLocalize();
const handleCheckedChange = (value: boolean) => {
setSwitchState(value);
onCheckedChange?.(value);
};
const labelId = `${switchId}-label`;
return (
<div className="flex items-center justify-between">
@@ -52,13 +96,29 @@ const ToggleSwitch: React.FC<ToggleSwitchProps> = ({
id={switchId}
checked={switchState}
onCheckedChange={handleCheckedChange}
disabled={disabled}
className="ml-4"
data-testid={switchId}
aria-labelledby={labelId}
disabled={disabled}
/>
</div>
);
};
const ToggleSwitch: React.FC<ToggleSwitchProps> = (props) => {
const { stateAtom, showSwitch = true } = props;
if (!showSwitch) {
return null;
}
const isRecoil = isRecoilState(stateAtom);
if (isRecoil) {
return <RecoilToggle {...props} stateAtom={stateAtom as RecoilState<boolean>} />;
}
return <JotaiToggle {...props} stateAtom={stateAtom as WritableAtom<boolean, [boolean], void>} />;
};
export default ToggleSwitch;

View File

@@ -787,6 +787,7 @@
"com_ui_copy_code": "Copy code",
"com_ui_copy_link": "Copy link",
"com_ui_copy_stack_trace": "Copy stack trace",
"com_ui_copy_thoughts_to_clipboard": "Copy thoughts to clipboard",
"com_ui_copy_to_clipboard": "Copy to clipboard",
"com_ui_copy_url_to_clipboard": "Copy URL to clipboard",
"com_ui_create": "Create",
@@ -1119,6 +1120,7 @@
"com_ui_reset_var": "Reset {{0}}",
"com_ui_reset_zoom": "Reset Zoom",
"com_ui_resource": "resource",
"com_ui_response": "Response",
"com_ui_result": "Result",
"com_ui_revoke": "Revoke",
"com_ui_revoke_info": "Revoke all user provided credentials",

View File

@@ -1,54 +1,21 @@
import { atom } from 'jotai';
import { atomWithStorage } from 'jotai/utils';
import { applyFontSize } from '@librechat/client';
import { createStorageAtomWithEffect, initializeFromStorage } from './jotai-utils';
const DEFAULT_FONT_SIZE = 'text-base';
/**
* Base storage atom for font size
* This atom stores the user's font size preference
*/
const fontSizeStorageAtom = atomWithStorage<string>('fontSize', DEFAULT_FONT_SIZE, undefined, {
getOnInit: true,
});
/**
* Derived atom that applies font size changes to the DOM
* Read: returns the current font size
* Write: updates storage and applies the font size to the DOM
*/
export const fontSizeAtom = atom(
(get) => get(fontSizeStorageAtom),
(get, set, newValue: string) => {
set(fontSizeStorageAtom, newValue);
if (typeof window !== 'undefined' && typeof document !== 'undefined') {
applyFontSize(newValue);
}
},
export const fontSizeAtom = createStorageAtomWithEffect<string>(
'fontSize',
DEFAULT_FONT_SIZE,
applyFontSize,
);
/**
* Initialize font size on app load
* This function applies the saved font size from localStorage to the DOM
*/
export const initializeFontSize = () => {
if (typeof window === 'undefined' || typeof document === 'undefined') {
return;
}
const savedValue = localStorage.getItem('fontSize');
if (savedValue !== null) {
try {
const parsedValue = JSON.parse(savedValue);
applyFontSize(parsedValue);
} catch (error) {
console.error(
'Error parsing localStorage key "fontSize", resetting to default. Error:',
error,
);
localStorage.setItem('fontSize', JSON.stringify(DEFAULT_FONT_SIZE));
applyFontSize(DEFAULT_FONT_SIZE);
}
} else {
applyFontSize(DEFAULT_FONT_SIZE);
}
export const initializeFontSize = (): void => {
initializeFromStorage('fontSize', DEFAULT_FONT_SIZE, applyFontSize);
};

View File

@@ -0,0 +1,88 @@
import { atom } from 'jotai';
import { atomWithStorage } from 'jotai/utils';
/**
* Create a simple atom with localStorage persistence
* Uses Jotai's atomWithStorage with getOnInit for proper SSR support
*
* @param key - localStorage key
* @param defaultValue - default value if no saved value exists
* @returns Jotai atom with localStorage persistence
*/
export function createStorageAtom<T>(key: string, defaultValue: T) {
return atomWithStorage<T>(key, defaultValue, undefined, {
getOnInit: true,
});
}
/**
* Create an atom with localStorage persistence and side effects
* Useful when you need to apply changes to the DOM or trigger other actions
*
* @param key - localStorage key
* @param defaultValue - default value if no saved value exists
* @param onWrite - callback function to run when the value changes
* @returns Jotai atom with localStorage persistence and side effects
*/
export function createStorageAtomWithEffect<T>(
key: string,
defaultValue: T,
onWrite: (value: T) => void,
) {
const baseAtom = createStorageAtom(key, defaultValue);
return atom(
(get) => get(baseAtom),
(get, set, newValue: T) => {
set(baseAtom, newValue);
if (typeof window !== 'undefined') {
onWrite(newValue);
}
},
);
}
/**
* Initialize a value from localStorage and optionally apply it
* Useful for applying saved values on app startup (e.g., theme, fontSize)
*
* @param key - localStorage key
* @param defaultValue - default value if no saved value exists
* @param onInit - optional callback to run with the loaded value
* @returns The loaded value (or default if none exists)
*/
export function initializeFromStorage<T>(
key: string,
defaultValue: T,
onInit?: (value: T) => void,
): T {
if (typeof window === 'undefined' || typeof localStorage === 'undefined') {
return defaultValue;
}
try {
const savedValue = localStorage.getItem(key);
const value = savedValue ? (JSON.parse(savedValue) as T) : defaultValue;
if (onInit) {
onInit(value);
}
return value;
} catch (error) {
console.error(`Error initializing ${key} from localStorage, using default. Error:`, error);
// Reset corrupted value
try {
localStorage.setItem(key, JSON.stringify(defaultValue));
} catch (setError) {
console.error(`Error resetting corrupted ${key} in localStorage:`, setError);
}
if (onInit) {
onInit(defaultValue);
}
return defaultValue;
}
}

View File

@@ -0,0 +1,8 @@
import { createStorageAtom } from './jotai-utils';
const DEFAULT_SHOW_THINKING = false;
/**
* This atom controls whether AI reasoning/thinking content is expanded by default.
*/
export const showThinkingAtom = createStorageAtom<boolean>('showThinking', DEFAULT_SHOW_THINKING);

View File

@@ -136,6 +136,16 @@ registration:
# apiKey: '${TTS_API_KEY}'
# model: ''
# voices: ['']
# azureOpenAI:
# # instanceName: The <NAME> part of your Azure endpoint URL
# # For example, if your endpoint is: https://my-instance.cognitiveservices.azure.com
# # Then instanceName should be: 'my-instance' (not the full URL)
# instanceName: '${AZURE_TTS_INSTANCE_NAME}'
# apiKey: '${AZURE_TTS_API_KEY}'
# deploymentName: '${AZURE_TTS_DEPLOYMENT_NAME}'
# apiVersion: '2024-02-01'
# model: 'tts-1'
# voices: ['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer']
#
# stt:
@@ -143,6 +153,15 @@ registration:
# url: ''
# apiKey: '${STT_API_KEY}'
# model: ''
# azureOpenAI:
# # instanceName: The <NAME> part of your Azure endpoint URL
# # For example, if your endpoint is: https://my-instance.cognitiveservices.azure.com
# # Then instanceName should be: 'my-instance' (not the full URL)
# # Note: The code also supports full domain format for backward compatibility
# instanceName: '${AZURE_STT_INSTANCE_NAME}'
# apiKey: '${AZURE_STT_API_KEY}'
# deploymentName: '${AZURE_STT_DEPLOYMENT_NAME}'
# apiVersion: '2024-02-01'
# rateLimits:
# fileUploads:

View File

@@ -1,7 +1,13 @@
export default {
collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!<rootDir>/node_modules/'],
coveragePathIgnorePatterns: ['/node_modules/', '/dist/'],
testPathIgnorePatterns: ['/node_modules/', '/dist/', '\\.dev\\.ts$'],
testPathIgnorePatterns: [
'/node_modules/',
'/dist/',
'\\.dev\\.ts$',
'\\.helper\\.ts$',
'\\.helper\\.d\\.ts$',
],
coverageReporters: ['text', 'cobertura'],
testResultsProcessor: 'jest-junit',
moduleNameMapper: {
@@ -18,4 +24,4 @@ export default {
// },
restoreMocks: true,
testTimeout: 15000,
};
};

View File

@@ -18,9 +18,11 @@
"build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs",
"build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs",
"build:watch:prod": "rollup -c -w --bundleConfigAsCjs",
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.integration\\.\"",
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.integration\\.\"",
"test:cache:integration": "jest --testPathPattern=\"src/cache/.*\\.integration\\.spec\\.ts$\" --coverage=false",
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
"test:cache-integration:core": "jest --testPathPattern=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
"test:cache-integration:cluster": "jest --testPathPattern=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand",
"test:cache-integration:mcp": "jest --testPathPattern=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
"verify": "npm run test:ci",
"b:clean": "bun run rimraf dist",
"b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs",

View File

@@ -1,11 +1,14 @@
import type { Keyv } from 'keyv';
// Mock GLOBAL_PREFIX_SEPARATOR
jest.mock('../../redisClients', () => {
const originalModule = jest.requireActual('../../redisClients');
// Mock GLOBAL_PREFIX_SEPARATOR from cacheConfig
jest.mock('../../cacheConfig', () => {
const originalModule = jest.requireActual('../../cacheConfig');
return {
...originalModule,
GLOBAL_PREFIX_SEPARATOR: '>>',
cacheConfig: {
...originalModule.cacheConfig,
GLOBAL_PREFIX_SEPARATOR: '>>',
},
};
});

View File

@@ -65,6 +65,7 @@ const cacheConfig = {
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
REDIS_CA: getRedisCA(),
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR ?? ''] || REDIS_KEY_PREFIX || '',
GLOBAL_PREFIX_SEPARATOR: '::',
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
REDIS_PING_INTERVAL: math(process.env.REDIS_PING_INTERVAL, 0),
/** Max delay between reconnection attempts in ms */

View File

@@ -14,7 +14,7 @@ import { logger } from '@librechat/data-schemas';
import session, { MemoryStore } from 'express-session';
import { RedisStore as ConnectRedis } from 'connect-redis';
import type { SendCommandFn } from 'rate-limit-redis';
import { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } from './redisClients';
import { keyvRedisClient, ioredisClient } from './redisClients';
import { cacheConfig } from './cacheConfig';
import { violationFile } from './keyvFiles';
@@ -31,7 +31,7 @@ export const standardCache = (namespace: string, ttl?: number, fallbackStore?: o
const keyvRedis = new KeyvRedis(keyvRedisClient);
const cache = new Keyv(keyvRedis, { namespace, ttl });
keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
keyvRedis.keyPrefixSeparator = cacheConfig.GLOBAL_PREFIX_SEPARATOR;
cache.on('error', (err) => {
logger.error(`Cache error in namespace ${namespace}:`, err);

View File

@@ -5,8 +5,6 @@ import { createClient, createCluster } from '@keyv/redis';
import type { RedisClientType, RedisClusterType } from '@redis/client';
import { cacheConfig } from './cacheConfig';
const GLOBAL_PREFIX_SEPARATOR = '::';
const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri)) || [];
const username = urls?.[0]?.username || cacheConfig.REDIS_USERNAME;
const password = urls?.[0]?.password || cacheConfig.REDIS_PASSWORD;
@@ -18,7 +16,7 @@ if (cacheConfig.USE_REDIS) {
username: username,
password: password,
tls: ca ? { ca } : undefined,
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${cacheConfig.GLOBAL_PREFIX_SEPARATOR}`,
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
retryStrategy: (times: number) => {
if (
@@ -192,4 +190,4 @@ if (cacheConfig.USE_REDIS) {
});
}
export { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR };
export { ioredisClient, keyvRedisClient };

View File

@@ -0,0 +1,180 @@
import { keyvRedisClient } from '~/cache/redisClients';
import { cacheConfig as cache } from '~/cache/cacheConfig';
import { clusterConfig as cluster } from './config';
import { logger } from '@librechat/data-schemas';
/**
* Distributed leader election implementation using Redis for coordination across multiple server instances.
*
* Leadership election:
* - During bootup, every server attempts to become the leader by calling isLeader()
* - Uses atomic Redis SET NX (set if not exists) to ensure only ONE server can claim leadership
* - The first server to successfully set the key becomes the leader; others become followers
* - Works with any number of servers (1 to infinite) - single server always becomes leader
*
* Leadership maintenance:
* - Leader holds a key in Redis with a 25-second lease duration
* - Leader renews this lease every 10 seconds to maintain leadership
* - If leader crashes, the lease eventually expires, and the key disappears
* - On shutdown, leader deletes its key to allow immediate re-election
* - Followers check for leadership and attempt to claim it when the key is empty
*/
export class LeaderElection {
// We can't use Keyv namespace here because we need direct Redis access for atomic operations
static readonly LEADER_KEY = `${cache.REDIS_KEY_PREFIX}${cache.GLOBAL_PREFIX_SEPARATOR}LeadingServerUUID`;
private static _instance = new LeaderElection();
readonly UUID: string = crypto.randomUUID();
private refreshTimer: NodeJS.Timeout | undefined = undefined;
// DO NOT create new instances of this class directly.
// Use the exported isLeader() function which uses a singleton instance.
constructor() {
if (LeaderElection._instance) return LeaderElection._instance;
process.on('SIGTERM', () => this.resign());
process.on('SIGINT', () => this.resign());
LeaderElection._instance = this;
}
/**
* Checks if this instance is the current leader.
* If no leader exists, waits upto 2 seconds (randomized to avoid thundering herd) then attempts self-election.
* Always returns true in non-Redis mode (single-instance deployment).
*/
public async isLeader(): Promise<boolean> {
if (!cache.USE_REDIS) return true;
try {
const currentLeader = await LeaderElection.getLeaderUUID();
// If we own the leadership lock, return true.
// However, in case the leadership refresh retries have been exhausted, something has gone wrong.
// This server is not considered the leader anymore, similar to a crash, to avoid split-brain scenario.
if (currentLeader === this.UUID) return this.refreshTimer != null;
if (currentLeader != null) return false; // someone holds leadership lock
const delay = Math.random() * 2000;
await new Promise((resolve) => setTimeout(resolve, delay));
return await this.electSelf();
} catch (error) {
logger.error('Failed to check leadership status:', error);
return false;
}
}
/**
* Steps down from leadership by stopping the refresh timer and releasing the leader key.
* Atomically deletes the leader key (only if we still own it) so another server can become leader immediately.
*/
public async resign(): Promise<void> {
if (!cache.USE_REDIS) return;
try {
this.clearRefreshTimer();
// Lua script for atomic check-and-delete (only delete if we still own it)
const script = `
if redis.call("get", KEYS[1]) == ARGV[1] then
redis.call("del", KEYS[1])
end
`;
await keyvRedisClient!.eval(script, {
keys: [LeaderElection.LEADER_KEY],
arguments: [this.UUID],
});
} catch (error) {
logger.error('Failed to release leadership lock:', error);
}
}
/**
* Gets the UUID of the current leader from Redis.
* Returns null if no leader exists or in non-Redis mode.
* Useful for testing and observability.
*/
public static async getLeaderUUID(): Promise<string | null> {
if (!cache.USE_REDIS) return null;
return await keyvRedisClient!.get(LeaderElection.LEADER_KEY);
}
/**
* Clears the refresh timer to stop leadership maintenance.
* Called when resigning or failing to refresh leadership.
* Calling this directly to simulate a crash in testing.
*/
public clearRefreshTimer(): void {
clearInterval(this.refreshTimer);
this.refreshTimer = undefined;
}
/**
* Attempts to claim leadership using atomic Redis SET NX (set if not exists).
* If successful, starts a refresh timer to maintain leadership by extending the lease duration.
* The NX flag ensures only one server can become leader even if multiple attempt simultaneously.
*/
private async electSelf(): Promise<boolean> {
try {
const result = await keyvRedisClient!.set(LeaderElection.LEADER_KEY, this.UUID, {
NX: true,
EX: cluster.LEADER_LEASE_DURATION,
});
if (result !== 'OK') return false;
this.clearRefreshTimer();
this.refreshTimer = setInterval(async () => {
await this.renewLeadership();
}, cluster.LEADER_RENEW_INTERVAL * 1000);
this.refreshTimer.unref();
return true;
} catch (error) {
logger.error('Leader election failed:', error);
return false;
}
}
/**
* Renews leadership by extending the lease duration on the leader key.
* Uses Lua script to atomically verify we still own the key before renewing (prevents race conditions).
* If we've lost leadership (key was taken by another server), stops the refresh timer.
* This is called every 10 seconds by the refresh timer.
*/
private async renewLeadership(attempts: number = 1): Promise<void> {
try {
// Lua script for atomic check-and-renew
const script = `
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("expire", KEYS[1], ARGV[2])
else
return 0
end
`;
const result = await keyvRedisClient!.eval(script, {
keys: [LeaderElection.LEADER_KEY],
arguments: [this.UUID, cluster.LEADER_LEASE_DURATION.toString()],
});
if (result === 0) {
logger.warn('Lost leadership, clearing refresh timer');
this.clearRefreshTimer();
}
} catch (error) {
logger.error(`Failed to renew leadership (attempts No.${attempts}):`, error);
if (attempts <= cluster.LEADER_RENEW_ATTEMPTS) {
await new Promise((resolve) =>
setTimeout(resolve, cluster.LEADER_RENEW_RETRY_DELAY * 1000),
);
await this.renewLeadership(attempts + 1);
} else {
logger.error('Exceeded maximum attempts to renew leadership.');
this.clearRefreshTimer();
}
}
}
}
const defaultElection = new LeaderElection();
export const isLeader = (): Promise<boolean> => defaultElection.isLeader();

View File

@@ -0,0 +1,220 @@
import { expect } from '@playwright/test';
describe('LeaderElection with Redis', () => {
let LeaderElection: typeof import('../LeaderElection').LeaderElection;
let instances: InstanceType<typeof import('../LeaderElection').LeaderElection>[] = [];
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
let ioredisClient: Awaited<typeof import('~/cache/redisClients')>['ioredisClient'];
beforeAll(async () => {
// Set up environment variables for Redis
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = 'LeaderElection-IntegrationTest';
// Import modules after setting env vars
const leaderElectionModule = await import('../LeaderElection');
const redisClients = await import('~/cache/redisClients');
LeaderElection = leaderElectionModule.LeaderElection;
keyvRedisClient = redisClients.keyvRedisClient;
ioredisClient = redisClients.ioredisClient;
// Ensure Redis is connected
if (!keyvRedisClient) {
throw new Error('Redis client is not initialized');
}
// Wait for Redis to be ready
if (!keyvRedisClient.isOpen) {
await keyvRedisClient.connect();
}
// Increase max listeners to handle many instances in tests
process.setMaxListeners(200);
});
afterEach(async () => {
await Promise.all(instances.map((instance) => instance.resign()));
instances = [];
// Clean up: clear the leader key directly from Redis
if (keyvRedisClient) {
await keyvRedisClient.del(LeaderElection.LEADER_KEY);
}
});
afterAll(async () => {
// Close both Redis clients to prevent hanging
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
if (ioredisClient?.status === 'ready') await ioredisClient.quit();
});
describe('Test Case 1: Simulate shutdown of the leader', () => {
it('should elect a new leader after the current leader resigns', async () => {
// Create 100 instances
instances = Array.from({ length: 100 }, () => new LeaderElection());
// Call isLeader on all instances and get leadership status
const resultsWithInstances = await Promise.all(
instances.map(async (instance) => ({
instance,
isLeader: await instance.isLeader(),
})),
);
// Find leader and followers
const leaders = resultsWithInstances.filter((r) => r.isLeader);
const followers = resultsWithInstances.filter((r) => !r.isLeader);
const leader = leaders[0].instance;
const nextLeader = followers[0].instance;
// Verify only one is leader
expect(leaders.length).toBe(1);
// Verify getLeaderUUID matches the leader's UUID
expect(await LeaderElection.getLeaderUUID()).toBe(leader.UUID);
// Leader resigns
await leader.resign();
// Verify getLeaderUUID returns null after resignation
expect(await LeaderElection.getLeaderUUID()).toBeNull();
// Next instance to call isLeader should become the new leader
expect(await nextLeader.isLeader()).toBe(true);
}, 30000); // 30 second timeout for 100 instances
});
describe('Test Case 2: Simulate crash of the leader', () => {
it('should allow re-election after leader crashes (lease expires)', async () => {
// Mock config with short lease duration
const clusterConfigModule = await import('../config');
const originalConfig = { ...clusterConfigModule.clusterConfig };
// Override config values for this test
Object.assign(clusterConfigModule.clusterConfig, {
LEADER_LEASE_DURATION: 2,
LEADER_RENEW_INTERVAL: 4,
});
try {
// Create 1 instance with mocked config
const instance = new LeaderElection();
instances.push(instance);
// Become leader
expect(await instance.isLeader()).toBe(true);
// Verify leader UUID is set
expect(await LeaderElection.getLeaderUUID()).toBe(instance.UUID);
// Simulate crash by clearing refresh timer
instance.clearRefreshTimer();
// The instance no longer considers itself leader even though it still holds the key
expect(await LeaderElection.getLeaderUUID()).toBe(instance.UUID);
expect(await instance.isLeader()).toBe(false);
// Wait for lease to expire (3 seconds > 2 second lease)
await new Promise((resolve) => setTimeout(resolve, 3000));
// Verify leader UUID is null after lease expiration
expect(await LeaderElection.getLeaderUUID()).toBeNull();
} finally {
// Restore original config values
Object.assign(clusterConfigModule.clusterConfig, originalConfig);
}
}, 15000); // 15 second timeout
});
describe('Test Case 3: Stress testing', () => {
it('should ensure only one instance becomes leader even when multiple instances call electSelf() at once', async () => {
// Create 10 instances
instances = Array.from({ length: 10 }, () => new LeaderElection());
// Call electSelf on all instances in parallel
const results = await Promise.all(instances.map((instance) => instance['electSelf']()));
// Verify only one returned true
const successCount = results.filter((success) => success).length;
expect(successCount).toBe(1);
// Find the winning instance
const winnerInstance = instances.find((_, index) => results[index]);
// Verify getLeaderUUID matches the winner's UUID
expect(await LeaderElection.getLeaderUUID()).toBe(winnerInstance?.UUID);
}, 15000); // 15 second timeout
});
});
describe('LeaderElection without Redis', () => {
let LeaderElection: typeof import('../LeaderElection').LeaderElection;
let instances: InstanceType<typeof import('../LeaderElection').LeaderElection>[] = [];
beforeAll(async () => {
// Set up environment variables for non-Redis mode
process.env.USE_REDIS = 'false';
// Reset all modules to force re-evaluation with new env vars
jest.resetModules();
// Import modules after setting env vars and resetting modules
const leaderElectionModule = await import('../LeaderElection');
LeaderElection = leaderElectionModule.LeaderElection;
});
afterEach(async () => {
await Promise.all(instances.map((instance) => instance.resign()));
instances = [];
});
afterAll(() => {
// Restore environment variables
process.env.USE_REDIS = 'true';
// Reset all modules to ensure next test runs get fresh imports
jest.resetModules();
});
it('should allow all instances to be leaders when USE_REDIS is false', async () => {
// Create 10 instances
instances = Array.from({ length: 10 }, () => new LeaderElection());
// Call isLeader on all instances
const results = await Promise.all(instances.map((instance) => instance.isLeader()));
// Verify all instances report themselves as leaders
expect(results.every((isLeader) => isLeader)).toBe(true);
expect(results.filter((isLeader) => isLeader).length).toBe(10);
});
it('should return null for getLeaderUUID when USE_REDIS is false', async () => {
// Create a few instances
instances = Array.from({ length: 3 }, () => new LeaderElection());
// Call isLeader on all instances to make them "leaders"
await Promise.all(instances.map((instance) => instance.isLeader()));
// Verify getLeaderUUID returns null in non-Redis mode
expect(await LeaderElection.getLeaderUUID()).toBeNull();
});
it('should allow resign() to be called without throwing errors', async () => {
// Create multiple instances
instances = Array.from({ length: 5 }, () => new LeaderElection());
// Make them all leaders
await Promise.all(instances.map((instance) => instance.isLeader()));
// Call resign on all instances - should not throw
await expect(
Promise.all(instances.map((instance) => instance.resign())),
).resolves.not.toThrow();
// Verify they're still leaders after resigning (since there's no shared state)
const results = await Promise.all(instances.map((instance) => instance.isLeader()));
expect(results.every((isLeader) => isLeader)).toBe(true);
});
});

View File

@@ -0,0 +1,14 @@
import { math } from '~/utils';
const clusterConfig = {
/** Duration in seconds that the leader lease is valid before it expires */
LEADER_LEASE_DURATION: math(process.env.LEADER_LEASE_DURATION, 25),
/** Interval in seconds at which the leader renews its lease */
LEADER_RENEW_INTERVAL: math(process.env.LEADER_RENEW_INTERVAL, 10),
/** Maximum number of retry attempts when renewing the lease fails */
LEADER_RENEW_ATTEMPTS: math(process.env.LEADER_RENEW_ATTEMPTS, 3),
/** Delay in seconds between retry attempts when renewing the lease */
LEADER_RENEW_RETRY_DELAY: math(process.env.LEADER_RENEW_RETRY_DELAY, 0.5),
};
export { clusterConfig };

View File

@@ -0,0 +1 @@
export { isLeader } from './LeaderElection';

View File

@@ -3,6 +3,7 @@ export * from './cdn';
/* Auth */
export * from './auth';
/* MCP */
export * from './mcp/registry/MCPServersRegistry';
export * from './mcp/MCPManager';
export * from './mcp/connection';
export * from './mcp/oauth';

View File

@@ -9,6 +9,7 @@ import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
import { sanitizeUrlForLogging } from './utils';
import { MCPConnection } from './connection';
import { processMCPEnv } from '~/utils';
import { withTimeout } from '~/utils/promise';
/**
* Factory for creating MCP connections with optional OAuth authentication.
@@ -231,14 +232,11 @@ export class MCPConnectionFactory {
/** Attempts to establish connection with timeout handling */
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
const connectionTimeout = new Promise<void>((_, reject) =>
setTimeout(
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
connectTimeout,
),
await withTimeout(
this.connectTo(connection),
connectTimeout,
`Connection timeout after ${connectTimeout}ms`,
);
const connectionAttempt = this.connectTo(connection);
await Promise.race([connectionAttempt, connectionTimeout]);
if (await connection.isConnected()) return;
logger.error(`${this.logPrefix} Failed to establish connection.`);

View File

@@ -5,11 +5,14 @@ import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.j
import type { TokenMethods } from '@librechat/data-schemas';
import type { FlowStateManager } from '~/flow/manager';
import type { TUser } from 'librechat-data-provider';
import type { MCPOAuthTokens } from '~/mcp/oauth';
import type { MCPOAuthTokens } from './oauth';
import type { RequestBody } from '~/types';
import type * as t from './types';
import { UserConnectionManager } from '~/mcp/UserConnectionManager';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { UserConnectionManager } from './UserConnectionManager';
import { ConnectionsRepository } from './ConnectionsRepository';
import { MCPServerInspector } from './registry/MCPServerInspector';
import { MCPServersInitializer } from './registry/MCPServersInitializer';
import { mcpServersRegistry as registry } from './registry/MCPServersRegistry';
import { formatToolContent } from './parsers';
import { MCPConnection } from './connection';
import { processMCPEnv } from '~/utils/env';
@@ -24,8 +27,8 @@ export class MCPManager extends UserConnectionManager {
/** Creates and initializes the singleton MCPManager instance */
public static async createInstance(configs: t.MCPServers): Promise<MCPManager> {
if (MCPManager.instance) throw new Error('MCPManager has already been initialized.');
MCPManager.instance = new MCPManager(configs);
await MCPManager.instance.initialize();
MCPManager.instance = new MCPManager();
await MCPManager.instance.initialize(configs);
return MCPManager.instance;
}
@@ -36,9 +39,10 @@ export class MCPManager extends UserConnectionManager {
}
/** Initializes the MCPManager by setting up server registry and app connections */
public async initialize() {
await this.serversRegistry.initialize();
this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs);
public async initialize(configs: t.MCPServers) {
await MCPServersInitializer.initialize(configs);
const appConfigs = await registry.sharedAppServers.getAll();
this.appConnections = new ConnectionsRepository(appConfigs);
}
/** Retrieves an app-level or user-specific connection based on provided arguments */
@@ -62,36 +66,18 @@ export class MCPManager extends UserConnectionManager {
}
}
/** Get servers that require OAuth */
public getOAuthServers(): Set<string> {
return this.serversRegistry.oauthServers;
}
/** Get all servers */
public getAllServers(): t.MCPServers {
return this.serversRegistry.rawConfigs;
}
/** Returns all available tool functions from app-level connections */
public getAppToolFunctions(): t.LCAvailableTools {
return this.serversRegistry.toolFunctions;
public async getAppToolFunctions(): Promise<t.LCAvailableTools> {
const toolFunctions: t.LCAvailableTools = {};
const configs = await registry.getAllServerConfigs();
for (const config of Object.values(configs)) {
if (config.toolFunctions != null) {
Object.assign(toolFunctions, config.toolFunctions);
}
}
return toolFunctions;
}
/** Returns all available tool functions from all connections available to user */
public async getAllToolFunctions(userId: string): Promise<t.LCAvailableTools | null> {
const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions();
const userConnections = this.getUserConnections(userId);
if (!userConnections || userConnections.size === 0) {
return allToolFunctions;
}
for (const [serverName, connection] of userConnections.entries()) {
const toolFunctions = await this.serversRegistry.getToolFunctions(serverName, connection);
Object.assign(allToolFunctions, toolFunctions);
}
return allToolFunctions;
}
/** Returns all available tool functions from all connections available to user */
public async getServerToolFunctions(
userId: string,
@@ -99,7 +85,7 @@ export class MCPManager extends UserConnectionManager {
): Promise<t.LCAvailableTools | null> {
try {
if (this.appConnections?.has(serverName)) {
return this.serversRegistry.getToolFunctions(
return MCPServerInspector.getToolFunctions(
serverName,
await this.appConnections.get(serverName),
);
@@ -113,7 +99,7 @@ export class MCPManager extends UserConnectionManager {
return null;
}
return this.serversRegistry.getToolFunctions(serverName, userConnections.get(serverName)!);
return MCPServerInspector.getToolFunctions(serverName, userConnections.get(serverName)!);
} catch (error) {
logger.warn(
`[getServerToolFunctions] Error getting tool functions for server ${serverName}`,
@@ -128,8 +114,14 @@ export class MCPManager extends UserConnectionManager {
* @param serverNames Optional array of server names. If not provided or empty, returns all servers.
* @returns Object mapping server names to their instructions
*/
public getInstructions(serverNames?: string[]): Record<string, string> {
const instructions = this.serversRegistry.serverInstructions;
private async getInstructions(serverNames?: string[]): Promise<Record<string, string>> {
const instructions: Record<string, string> = {};
const configs = await registry.getAllServerConfigs();
for (const [serverName, config] of Object.entries(configs)) {
if (config.serverInstructions != null) {
instructions[serverName] = config.serverInstructions as string;
}
}
if (!serverNames) return instructions;
return pick(instructions, serverNames);
}
@@ -139,9 +131,9 @@ export class MCPManager extends UserConnectionManager {
* @param serverNames Optional array of server names to include. If not provided, includes all servers.
* @returns Formatted instructions string ready for context injection
*/
public formatInstructionsForContext(serverNames?: string[]): string {
public async formatInstructionsForContext(serverNames?: string[]): Promise<string> {
/** Instructions for specified servers or all stored instructions */
const instructionsToInclude = this.getInstructions(serverNames);
const instructionsToInclude = await this.getInstructions(serverNames);
if (Object.keys(instructionsToInclude).length === 0) {
return '';
@@ -225,7 +217,7 @@ Please follow these instructions when using tools from the respective MCP server
);
}
const rawConfig = this.getRawConfig(serverName) as t.MCPOptions;
const rawConfig = (await registry.getServerConfig(serverName, userId)) as t.MCPOptions;
const currentOptions = processMCPEnv({
user,
options: rawConfig,

View File

@@ -1,230 +0,0 @@
import mapValues from 'lodash/mapValues';
import { logger } from '@librechat/data-schemas';
import { Constants } from 'librechat-data-provider';
import type { JsonSchemaType } from '@librechat/data-schemas';
import type { MCPConnection } from '~/mcp/connection';
import type * as t from '~/mcp/types';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { sanitizeUrlForLogging } from '~/mcp/utils';
import { processMCPEnv, isEnabled } from '~/utils';
const DEFAULT_MCP_INIT_TIMEOUT_MS = 30_000;
function getMCPInitTimeout(): number {
return process.env.MCP_INIT_TIMEOUT_MS != null
? parseInt(process.env.MCP_INIT_TIMEOUT_MS)
: DEFAULT_MCP_INIT_TIMEOUT_MS;
}
/**
* Manages MCP server configurations and metadata discovery.
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
* Determines which servers are for app-level connections.
* Has its own connections repository. All connections are disconnected after initialization.
*/
export class MCPServersRegistry {
private initialized: boolean = false;
private connections: ConnectionsRepository;
private initTimeoutMs: number;
public readonly rawConfigs: t.MCPServers;
public readonly parsedConfigs: Record<string, t.ParsedServerConfig>;
public oauthServers: Set<string> = new Set();
public serverInstructions: Record<string, string> = {};
public toolFunctions: t.LCAvailableTools = {};
public appServerConfigs: t.MCPServers = {};
constructor(configs: t.MCPServers) {
this.rawConfigs = configs;
this.parsedConfigs = mapValues(configs, (con) => processMCPEnv({ options: con }));
this.connections = new ConnectionsRepository(configs);
this.initTimeoutMs = getMCPInitTimeout();
}
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
public async initialize(): Promise<void> {
if (this.initialized) return;
this.initialized = true;
const serverNames = Object.keys(this.parsedConfigs);
await Promise.allSettled(
serverNames.map((serverName) => this.initializeServerWithTimeout(serverName)),
);
}
/** Wraps server initialization with a timeout to prevent hanging */
private async initializeServerWithTimeout(serverName: string): Promise<void> {
let timeoutId: NodeJS.Timeout | null = null;
try {
await Promise.race([
this.initializeServer(serverName),
new Promise<never>((_, reject) => {
timeoutId = setTimeout(() => {
reject(new Error('Server initialization timed out'));
}, this.initTimeoutMs);
}),
]);
} catch (error) {
logger.warn(`${this.prefix(serverName)} Server initialization failed:`, error);
throw error;
} finally {
if (timeoutId != null) {
clearTimeout(timeoutId);
}
}
}
/** Initializes a single server with all its metadata and adds it to appropriate collections */
private async initializeServer(serverName: string): Promise<void> {
const start = Date.now();
const config = this.parsedConfigs[serverName];
// 1. Detect OAuth requirements if not already specified
try {
await this.fetchOAuthRequirement(serverName);
if (config.startup !== false && !config.requiresOAuth) {
await Promise.allSettled([
this.fetchServerInstructions(serverName).catch((error) =>
logger.warn(`${this.prefix(serverName)} Failed to fetch server instructions:`, error),
),
this.fetchServerCapabilities(serverName).catch((error) =>
logger.warn(`${this.prefix(serverName)} Failed to fetch server capabilities:`, error),
),
]);
}
} catch (error) {
logger.warn(`${this.prefix(serverName)} Failed to initialize server:`, error);
}
// 2. Fetch tool functions for this server if a connection was established
const getToolFunctions = async (): Promise<t.LCAvailableTools | null> => {
try {
const loadedConns = await this.connections.getLoaded();
const conn = loadedConns.get(serverName);
if (conn == null) {
return null;
}
return this.getToolFunctions(serverName, conn);
} catch (error) {
logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error);
return null;
}
};
const toolFunctions = await getToolFunctions();
// 3. Disconnect this server's connection if it was established (fire-and-forget)
void this.connections.disconnect(serverName);
// 4. Side effects
// 4.1 Add to OAuth servers if needed
if (config.requiresOAuth) {
this.oauthServers.add(serverName);
}
// 4.2 Add server instructions if available
if (config.serverInstructions != null) {
this.serverInstructions[serverName] = config.serverInstructions as string;
}
// 4.3 Add to app server configs if eligible (startup enabled, non-OAuth servers)
if (config.startup !== false && config.requiresOAuth === false) {
this.appServerConfigs[serverName] = this.rawConfigs[serverName];
}
// 4.4 Add tool functions if available
if (toolFunctions != null) {
Object.assign(this.toolFunctions, toolFunctions);
}
const duration = Date.now() - start;
this.logUpdatedConfig(serverName, duration);
}
/** Converts server tools to LibreChat-compatible tool functions format */
public async getToolFunctions(
serverName: string,
conn: MCPConnection,
): Promise<t.LCAvailableTools> {
const { tools }: t.MCPToolListResponse = await conn.client.listTools();
const toolFunctions: t.LCAvailableTools = {};
tools.forEach((tool) => {
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
toolFunctions[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema as JsonSchemaType,
},
};
});
return toolFunctions;
}
/** Determines if server requires OAuth if not already specified in the config */
private async fetchOAuthRequirement(serverName: string): Promise<boolean> {
const config = this.parsedConfigs[serverName];
if (config.requiresOAuth != null) return config.requiresOAuth;
if (config.url == null) return (config.requiresOAuth = false);
if (config.startup === false) return (config.requiresOAuth = false);
const result = await detectOAuthRequirement(config.url);
config.requiresOAuth = result.requiresOAuth;
config.oauthMetadata = result.metadata;
return config.requiresOAuth;
}
/** Retrieves server instructions from MCP server if enabled in the config */
private async fetchServerInstructions(serverName: string): Promise<void> {
const config = this.parsedConfigs[serverName];
if (!config.serverInstructions) return;
// If it's a string that's not "true", it's a custom instruction
if (typeof config.serverInstructions === 'string' && !isEnabled(config.serverInstructions)) {
return;
}
// Fetch from server if true (boolean) or "true" (string)
const conn = await this.connections.get(serverName);
config.serverInstructions = conn.client.getInstructions();
if (!config.serverInstructions) {
logger.warn(`${this.prefix(serverName)} No server instructions available`);
}
}
/** Fetches server capabilities and available tools list */
private async fetchServerCapabilities(serverName: string): Promise<void> {
const config = this.parsedConfigs[serverName];
const conn = await this.connections.get(serverName);
const capabilities = conn.client.getServerCapabilities();
if (!capabilities) return;
config.capabilities = JSON.stringify(capabilities);
if (!capabilities.tools) return;
const tools = await conn.client.listTools();
config.tools = tools.tools.map((tool) => tool.name).join(', ');
}
// Logs server configuration summary after initialization
private logUpdatedConfig(serverName: string, initDuration: number): void {
const prefix = this.prefix(serverName);
const config = this.parsedConfigs[serverName];
logger.info(`${prefix} -------------------------------------------------┐`);
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
logger.info(`${prefix} Tools: ${config.tools}`);
logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`);
logger.info(`${prefix} Initialized in: ${initDuration}ms`);
logger.info(`${prefix} -------------------------------------------------┘`);
}
// Returns formatted log prefix for server messages
private prefix(serverName: string): string {
return `[MCP][${serverName}]`;
}
}

View File

@@ -1,7 +1,7 @@
import { logger } from '@librechat/data-schemas';
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
import { mcpServersRegistry as serversRegistry } from '~/mcp/registry/MCPServersRegistry';
import { MCPConnection } from './connection';
import type * as t from './types';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
@@ -14,7 +14,6 @@ import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
* https://github.com/danny-avila/LibreChat/discussions/8790
*/
export abstract class UserConnectionManager {
protected readonly serversRegistry: MCPServersRegistry;
// Connections shared by all users.
public appConnections: ConnectionsRepository | null = null;
// Connections per userId -> serverName -> connection
@@ -23,15 +22,6 @@ export abstract class UserConnectionManager {
protected userLastActivity: Map<string, number> = new Map();
protected readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable)
constructor(serverConfigs: t.MCPServers) {
this.serversRegistry = new MCPServersRegistry(serverConfigs);
}
/** fetches am MCP Server config from the registry */
public getRawConfig(serverName: string): t.MCPOptions | undefined {
return this.serversRegistry.rawConfigs[serverName];
}
/** Updates the last activity timestamp for a user */
protected updateUserLastActivity(userId: string): void {
const now = Date.now();
@@ -106,7 +96,7 @@ export abstract class UserConnectionManager {
logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`);
}
const config = this.serversRegistry.parsedConfigs[serverName];
const config = await serversRegistry.getServerConfig(serverName, userId);
if (!config) {
throw new McpError(
ErrorCode.InvalidRequest,

View File

@@ -1,7 +1,9 @@
import { logger } from '@librechat/data-schemas';
import type * as t from '~/mcp/types';
import { MCPManager } from '~/mcp/MCPManager';
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry';
import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { MCPConnection } from '../connection';
@@ -15,7 +17,24 @@ jest.mock('@librechat/data-schemas', () => ({
},
}));
jest.mock('~/mcp/MCPServersRegistry');
jest.mock('~/mcp/registry/MCPServersRegistry', () => ({
mcpServersRegistry: {
sharedAppServers: {
getAll: jest.fn(),
},
getServerConfig: jest.fn(),
getAllServerConfigs: jest.fn(),
getOAuthServers: jest.fn(),
},
}));
jest.mock('~/mcp/registry/MCPServersInitializer', () => ({
MCPServersInitializer: {
initialize: jest.fn(),
},
}));
jest.mock('~/mcp/registry/MCPServerInspector');
jest.mock('~/mcp/ConnectionsRepository');
const mockLogger = logger as jest.Mocked<typeof logger>;
@@ -28,20 +47,12 @@ describe('MCPManager', () => {
// Reset MCPManager singleton state
(MCPManager as unknown as { instance: null }).instance = null;
jest.clearAllMocks();
});
function mockRegistry(
registryConfig: Partial<MCPServersRegistry>,
): jest.MockedClass<typeof MCPServersRegistry> {
const mock = {
initialize: jest.fn().mockResolvedValue(undefined),
getToolFunctions: jest.fn().mockResolvedValue(null),
...registryConfig,
};
return (MCPServersRegistry as jest.MockedClass<typeof MCPServersRegistry>).mockImplementation(
() => mock as unknown as MCPServersRegistry,
);
}
// Set up default mock implementations
(MCPServersInitializer.initialize as jest.Mock).mockResolvedValue(undefined);
(mcpServersRegistry.sharedAppServers.getAll as jest.Mock).mockResolvedValue({});
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({});
});
function mockAppConnections(
appConnectionsConfig: Partial<ConnectionsRepository>,
@@ -66,12 +77,229 @@ describe('MCPManager', () => {
};
}
describe('getAppToolFunctions', () => {
it('should return empty object when no servers have tool functions', async () => {
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
server1: { type: 'stdio', command: 'test', args: [] },
server2: { type: 'stdio', command: 'test2', args: [] },
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.getAppToolFunctions();
expect(result).toEqual({});
});
it('should collect tool functions from multiple servers', async () => {
const toolFunctions1 = {
tool1_mcp_server1: {
type: 'function' as const,
function: {
name: 'tool1_mcp_server1',
description: 'Tool 1',
parameters: { type: 'object' as const },
},
},
};
const toolFunctions2 = {
tool2_mcp_server2: {
type: 'function' as const,
function: {
name: 'tool2_mcp_server2',
description: 'Tool 2',
parameters: { type: 'object' as const },
},
},
};
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
server1: {
type: 'stdio',
command: 'test',
args: [],
toolFunctions: toolFunctions1,
},
server2: {
type: 'stdio',
command: 'test2',
args: [],
toolFunctions: toolFunctions2,
},
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.getAppToolFunctions();
expect(result).toEqual({
...toolFunctions1,
...toolFunctions2,
});
});
it('should handle servers with null or undefined toolFunctions', async () => {
const toolFunctions1 = {
tool1_mcp_server1: {
type: 'function' as const,
function: {
name: 'tool1_mcp_server1',
description: 'Tool 1',
parameters: { type: 'object' as const },
},
},
};
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
server1: {
type: 'stdio',
command: 'test',
args: [],
toolFunctions: toolFunctions1,
},
server2: {
type: 'stdio',
command: 'test2',
args: [],
toolFunctions: null,
},
server3: {
type: 'stdio',
command: 'test3',
args: [],
},
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.getAppToolFunctions();
expect(result).toEqual(toolFunctions1);
});
});
describe('formatInstructionsForContext', () => {
it('should return empty string when no servers have instructions', async () => {
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
server1: { type: 'stdio', command: 'test', args: [] },
server2: { type: 'stdio', command: 'test2', args: [] },
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.formatInstructionsForContext();
expect(result).toBe('');
});
it('should format instructions from multiple servers', async () => {
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
github: {
type: 'sse',
url: 'https://api.github.com',
serverInstructions: 'Use GitHub API with care',
},
files: {
type: 'stdio',
command: 'node',
args: ['files.js'],
serverInstructions: 'Only read/write files in allowed directories',
},
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.formatInstructionsForContext();
expect(result).toContain('# MCP Server Instructions');
expect(result).toContain('## github MCP Server Instructions');
expect(result).toContain('Use GitHub API with care');
expect(result).toContain('## files MCP Server Instructions');
expect(result).toContain('Only read/write files in allowed directories');
});
it('should filter instructions by server names when provided', async () => {
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
github: {
type: 'sse',
url: 'https://api.github.com',
serverInstructions: 'Use GitHub API with care',
},
files: {
type: 'stdio',
command: 'node',
args: ['files.js'],
serverInstructions: 'Only read/write files in allowed directories',
},
database: {
type: 'stdio',
command: 'node',
args: ['db.js'],
serverInstructions: 'Be careful with database operations',
},
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.formatInstructionsForContext(['github', 'database']);
expect(result).toContain('## github MCP Server Instructions');
expect(result).toContain('Use GitHub API with care');
expect(result).toContain('## database MCP Server Instructions');
expect(result).toContain('Be careful with database operations');
expect(result).not.toContain('files');
expect(result).not.toContain('Only read/write files in allowed directories');
});
it('should handle servers with null or undefined instructions', async () => {
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
github: {
type: 'sse',
url: 'https://api.github.com',
serverInstructions: 'Use GitHub API with care',
},
files: {
type: 'stdio',
command: 'node',
args: ['files.js'],
serverInstructions: null,
},
database: {
type: 'stdio',
command: 'node',
args: ['db.js'],
},
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.formatInstructionsForContext();
expect(result).toContain('## github MCP Server Instructions');
expect(result).toContain('Use GitHub API with care');
expect(result).not.toContain('files');
expect(result).not.toContain('database');
});
it('should return empty string when filtered servers have no instructions', async () => {
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
github: {
type: 'sse',
url: 'https://api.github.com',
serverInstructions: 'Use GitHub API with care',
},
files: {
type: 'stdio',
command: 'node',
args: ['files.js'],
},
});
const manager = await MCPManager.createInstance(newMCPServersConfig());
const result = await manager.formatInstructionsForContext(['files']);
expect(result).toBe('');
});
});
describe('getServerToolFunctions', () => {
it('should catch and handle errors gracefully', async () => {
mockRegistry({
getToolFunctions: jest.fn(() => {
throw new Error('Connection failed');
}),
(MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => {
throw new Error('Connection failed');
});
mockAppConnections({
@@ -90,9 +318,7 @@ describe('MCPManager', () => {
});
it('should catch synchronous errors from getUserConnections', async () => {
mockRegistry({
getToolFunctions: jest.fn().mockResolvedValue({}),
});
(MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn().mockResolvedValue({});
mockAppConnections({
has: jest.fn().mockReturnValue(false),
@@ -126,9 +352,9 @@ describe('MCPManager', () => {
},
};
mockRegistry({
getToolFunctions: jest.fn().mockResolvedValue(expectedTools),
});
(MCPServerInspector.getToolFunctions as jest.Mock) = jest
.fn()
.mockResolvedValue(expectedTools);
mockAppConnections({
has: jest.fn().mockReturnValue(true),
@@ -145,10 +371,8 @@ describe('MCPManager', () => {
it('should include specific server name in error messages', async () => {
const specificServerName = 'github_mcp_server';
mockRegistry({
getToolFunctions: jest.fn(() => {
throw new Error('Server specific error');
}),
(MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => {
throw new Error('Server specific error');
});
mockAppConnections({

View File

@@ -1,595 +0,0 @@
import { join } from 'path';
import { readFileSync } from 'fs';
import { load as yamlLoad } from 'js-yaml';
import { logger } from '@librechat/data-schemas';
import type { OAuthDetectionResult } from '~/mcp/oauth/detectOAuth';
import type * as t from '~/mcp/types';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { MCPConnection } from '~/mcp/connection';
// Mock external dependencies
jest.mock('../oauth/detectOAuth');
jest.mock('../ConnectionsRepository');
jest.mock('../connection');
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
// Mock processMCPEnv to verify it's called and adds a processed marker
jest.mock('~/utils', () => ({
...jest.requireActual('~/utils'),
processMCPEnv: jest.fn(({ options }) => ({
...options,
_processed: true, // Simple marker to verify processing occurred
})),
}));
const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction<
typeof detectOAuthRequirement
>;
const mockLogger = logger as jest.Mocked<typeof logger>;
describe('MCPServersRegistry - Initialize Function', () => {
let rawConfigs: t.MCPServers;
let expectedParsedConfigs: Record<string, t.ParsedServerConfig>;
let mockConnectionsRepo: jest.Mocked<ConnectionsRepository>;
let mockConnections: Map<string, jest.Mocked<MCPConnection>>;
beforeEach(() => {
// Load fixtures
const rawConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.rawConfigs.yml');
const parsedConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.parsedConfigs.yml');
rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers;
expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record<
string,
t.ParsedServerConfig
>;
// Setup mock connections
mockConnections = new Map();
const serverNames = Object.keys(rawConfigs);
serverNames.forEach((serverName) => {
const mockClient = {
listTools: jest.fn(),
getInstructions: jest.fn(),
getServerCapabilities: jest.fn(),
};
const mockConnection = {
client: mockClient,
} as unknown as jest.Mocked<MCPConnection>;
// Setup mock responses based on expected configs
const expectedConfig = expectedParsedConfigs[serverName];
// Mock listTools response
if (expectedConfig.tools) {
const toolNames = expectedConfig.tools.split(', ');
const tools = toolNames.map((name: string) => ({
name,
description: `Description for ${name}`,
inputSchema: {
type: 'object' as const,
properties: {
input: { type: 'string' },
},
},
}));
(mockClient.listTools as jest.Mock).mockResolvedValue({ tools });
} else {
(mockClient.listTools as jest.Mock).mockResolvedValue({ tools: [] });
}
// Mock getInstructions response
if (expectedConfig.serverInstructions) {
(mockClient.getInstructions as jest.Mock).mockReturnValue(
expectedConfig.serverInstructions as string,
);
} else {
(mockClient.getInstructions as jest.Mock).mockReturnValue(undefined);
}
// Mock getServerCapabilities response
if (expectedConfig.capabilities) {
const capabilities = JSON.parse(expectedConfig.capabilities) as Record<string, unknown>;
(mockClient.getServerCapabilities as jest.Mock).mockReturnValue(capabilities);
} else {
(mockClient.getServerCapabilities as jest.Mock).mockReturnValue(undefined);
}
mockConnections.set(serverName, mockConnection);
});
// Setup ConnectionsRepository mock
mockConnectionsRepo = {
get: jest.fn(),
getLoaded: jest.fn(),
disconnectAll: jest.fn(),
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<ConnectionsRepository>;
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
const connection = mockConnections.get(serverName);
if (!connection) {
throw new Error(`Connection not found for server: ${serverName}`);
}
return Promise.resolve(connection);
});
mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections);
(ConnectionsRepository as jest.Mock).mockImplementation(() => mockConnectionsRepo);
// Setup OAuth detection mock with deterministic results
mockDetectOAuthRequirement.mockImplementation((url: string) => {
const oauthResults: Record<string, OAuthDetectionResult> = {
'https://api.github.com/mcp': {
requiresOAuth: true,
method: 'protected-resource-metadata',
metadata: {
authorization_url: 'https://github.com/login/oauth/authorize',
token_url: 'https://github.com/login/oauth/access_token',
},
},
'https://api.disabled.com/mcp': {
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
},
'https://api.public.com/mcp': {
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
},
};
return Promise.resolve(
oauthResults[url] || { requiresOAuth: false, method: 'no-metadata-found', metadata: null },
);
});
// Clear all mocks
jest.clearAllMocks();
});
afterEach(() => {
delete process.env.MCP_INIT_TIMEOUT_MS;
jest.clearAllMocks();
});
describe('initialize() method', () => {
it('should only run initialization once', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
await registry.initialize(); // Second call should not re-run
// Verify that connections are only requested for servers that need them
// (servers with serverInstructions=true or all servers for capabilities)
expect(mockConnectionsRepo.get).toHaveBeenCalled();
});
it('should set all public properties correctly after initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
// Verify initial state
expect(registry.oauthServers.size).toBe(0);
expect(registry.serverInstructions).toEqual({});
expect(registry.toolFunctions).toEqual({});
expect(registry.appServerConfigs).toEqual({});
await registry.initialize();
// Test oauthServers Set
expect(registry.oauthServers).toEqual(
new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']),
);
// Test serverInstructions - OAuth servers keep their original boolean value, non-OAuth fetch actual strings
expect(registry.serverInstructions).toEqual({
stdio_server: 'Follow these instructions for stdio server',
oauth_server: true,
non_oauth_server: 'Public API instructions',
});
// Test appServerConfigs (startup enabled, non-OAuth servers only)
expect(registry.appServerConfigs).toEqual({
stdio_server: rawConfigs.stdio_server,
websocket_server: rawConfigs.websocket_server,
non_oauth_server: rawConfigs.non_oauth_server,
});
// Test toolFunctions (only non-OAuth servers get their tools fetched during initialization)
const expectedToolFunctions = {
file_read_mcp_stdio_server: {
type: 'function',
function: {
name: 'file_read_mcp_stdio_server',
description: 'Description for file_read',
parameters: { type: 'object', properties: { input: { type: 'string' } } },
},
},
file_write_mcp_stdio_server: {
type: 'function',
function: {
name: 'file_write_mcp_stdio_server',
description: 'Description for file_write',
parameters: { type: 'object', properties: { input: { type: 'string' } } },
},
},
};
expect(registry.toolFunctions).toEqual(expectedToolFunctions);
});
it('should handle errors gracefully and continue initialization of other servers', async () => {
const registry = new MCPServersRegistry(rawConfigs);
// Make one specific server throw an error during OAuth detection
mockDetectOAuthRequirement.mockImplementation((url: string) => {
if (url === 'https://api.github.com/mcp') {
return Promise.reject(new Error('OAuth detection failed'));
}
// Return normal responses for other servers
const oauthResults: Record<string, OAuthDetectionResult> = {
'https://api.disabled.com/mcp': {
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
},
'https://api.public.com/mcp': {
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
},
};
return Promise.resolve(
oauthResults[url] ?? {
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
},
);
});
await registry.initialize();
// Should still initialize successfully for other servers
expect(registry.oauthServers).toBeInstanceOf(Set);
expect(registry.toolFunctions).toBeDefined();
// The failed server should not be in oauthServers (since it failed OAuth detection)
expect(registry.oauthServers.has('oauth_server')).toBe(false);
// But other servers should still be processed successfully
expect(registry.appServerConfigs).toHaveProperty('stdio_server');
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
// Error should be logged as a warning at the higher level
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('[MCP][oauth_server] Failed to initialize server:'),
expect.any(Error),
);
});
it('should disconnect individual connections after each server initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
// Verify disconnect was called for each server during initialization
// All servers attempt to connect during initialization for metadata gathering
const serverNames = Object.keys(rawConfigs);
expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length);
});
it('should log configuration updates for each startup-enabled server', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
const serverNames = Object.keys(rawConfigs);
serverNames.forEach((serverName) => {
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] URL:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] OAuth Required:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] Capabilities:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] Tools:`),
);
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining(`[MCP][${serverName}] Server Instructions:`),
);
});
});
it('should have parsedConfigs matching the expected fixture after initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
await registry.initialize();
// Compare the actual parsedConfigs against the expected fixture
expect(registry.parsedConfigs).toEqual(expectedParsedConfigs);
});
it('should handle serverInstructions as string "true" correctly and fetch from server', async () => {
// Create test config with serverInstructions as string "true"
const testConfig: t.MCPServers = {
test_server_string_true: {
type: 'stdio',
args: [],
command: 'test-command',
serverInstructions: 'true', // Simulating string "true" from YAML parsing
},
test_server_custom_string: {
type: 'stdio',
args: [],
command: 'test-command',
serverInstructions: 'Custom instructions here',
},
test_server_bool_true: {
type: 'stdio',
args: [],
command: 'test-command',
serverInstructions: true,
},
};
const registry = new MCPServersRegistry(testConfig);
// Setup mock connection for servers that should fetch
const mockClient = {
listTools: jest.fn().mockResolvedValue({ tools: [] }),
getInstructions: jest.fn().mockReturnValue('Fetched instructions from server'),
getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }),
};
const mockConnection = {
client: mockClient,
} as unknown as jest.Mocked<MCPConnection>;
mockConnectionsRepo.get.mockResolvedValue(mockConnection);
mockConnectionsRepo.getLoaded.mockResolvedValue(
new Map([
['test_server_string_true', mockConnection],
['test_server_bool_true', mockConnection],
]),
);
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
});
await registry.initialize();
// Verify that string "true" was treated as fetch-from-server
expect(registry.parsedConfigs['test_server_string_true'].serverInstructions).toBe(
'Fetched instructions from server',
);
// Verify that custom string was kept as-is
expect(registry.parsedConfigs['test_server_custom_string'].serverInstructions).toBe(
'Custom instructions here',
);
// Verify that boolean true also fetched from server
expect(registry.parsedConfigs['test_server_bool_true'].serverInstructions).toBe(
'Fetched instructions from server',
);
// Verify getInstructions was called for both "true" cases
expect(mockClient.getInstructions).toHaveBeenCalledTimes(2);
});
it('should use Promise.allSettled for individual server initialization', async () => {
const registry = new MCPServersRegistry(rawConfigs);
// Spy on Promise.allSettled to verify it's being used
const allSettledSpy = jest.spyOn(Promise, 'allSettled');
await registry.initialize();
// Verify Promise.allSettled was called with an array of server initialization promises
expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)]));
// Verify it was called with the correct number of server promises
const serverNames = Object.keys(rawConfigs);
expect(allSettledSpy).toHaveBeenCalledWith(
expect.arrayContaining(new Array(serverNames.length).fill(expect.any(Promise))),
);
allSettledSpy.mockRestore();
});
it('should isolate server failures and not affect other servers', async () => {
const registry = new MCPServersRegistry(rawConfigs);
// Make multiple servers fail in different ways
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
if (serverName === 'stdio_server') {
// First server fails
throw new Error('Connection failed for stdio_server');
}
if (serverName === 'websocket_server') {
// Second server fails
throw new Error('Connection failed for websocket_server');
}
// Other servers succeed
const connection = mockConnections.get(serverName);
if (!connection) {
throw new Error(`Connection not found for server: ${serverName}`);
}
return Promise.resolve(connection);
});
await registry.initialize();
// Despite failures, initialization should complete
expect(registry.oauthServers).toBeInstanceOf(Set);
expect(registry.toolFunctions).toBeDefined();
// Successful servers should still be processed
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
// Failed servers should not crash the whole initialization
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('[MCP][stdio_server] Failed to fetch server capabilities:'),
expect.any(Error),
);
expect(mockLogger.warn).toHaveBeenCalledWith(
expect.stringContaining('[MCP][websocket_server] Failed to fetch server capabilities:'),
expect.any(Error),
);
});
it('should properly clean up connections even when some servers fail', async () => {
const registry = new MCPServersRegistry(rawConfigs);
// Track disconnect failures but suppress unhandled rejections
const disconnectErrors: Error[] = [];
mockConnectionsRepo.disconnect.mockImplementation((serverName: string) => {
if (serverName === 'stdio_server') {
const error = new Error('Disconnect failed');
disconnectErrors.push(error);
return Promise.reject(error).catch(() => {}); // Suppress unhandled rejection
}
return Promise.resolve();
});
await registry.initialize();
// Should still attempt to disconnect all servers during initialization
const serverNames = Object.keys(rawConfigs);
expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length);
expect(disconnectErrors).toHaveLength(1);
});
it('should timeout individual server initialization after configured timeout', async () => {
const timeout = 2000;
// Create registry with a short timeout for testing
process.env.MCP_INIT_TIMEOUT_MS = `${timeout}`;
const registry = new MCPServersRegistry(rawConfigs);
// Make one server hang indefinitely during OAuth detection
mockDetectOAuthRequirement.mockImplementation((url: string) => {
if (url === 'https://api.github.com/mcp') {
// Slow init
return new Promise((res) => setTimeout(res, timeout * 2));
}
// Return normal responses for other servers
return Promise.resolve({
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
});
});
const start = Date.now();
await registry.initialize();
const duration = Date.now() - start;
// Should complete within reasonable time despite one server hanging
// Allow some buffer for test execution overhead
expect(duration).toBeLessThan(timeout * 1.5);
// The timeout should prevent the hanging server from blocking initialization
// Other servers should still be processed successfully
expect(registry.appServerConfigs).toHaveProperty('stdio_server');
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
}, 10_000); // 10 second Jest timeout
it('should skip tool function fetching if connection was not established', async () => {
const testConfig: t.MCPServers = {
server_with_connection: {
type: 'stdio',
args: [],
command: 'test-command',
},
server_without_connection: {
type: 'stdio',
args: [],
command: 'failing-command',
},
};
const registry = new MCPServersRegistry(testConfig);
const mockClient = {
listTools: jest.fn().mockResolvedValue({
tools: [
{
name: 'test_tool',
description: 'Test tool',
inputSchema: { type: 'object', properties: {} },
},
],
}),
getInstructions: jest.fn().mockReturnValue(undefined),
getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }),
};
const mockConnection = {
client: mockClient,
} as unknown as jest.Mocked<MCPConnection>;
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
if (serverName === 'server_with_connection') {
return Promise.resolve(mockConnection);
}
throw new Error('Connection failed');
});
// Mock getLoaded to return connections map - the real implementation returns all loaded connections at once
mockConnectionsRepo.getLoaded.mockResolvedValue(
new Map([['server_with_connection', mockConnection]]),
);
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
});
await registry.initialize();
expect(registry.toolFunctions).toHaveProperty('test_tool_mcp_server_with_connection');
expect(Object.keys(registry.toolFunctions)).toHaveLength(1);
});
it('should handle getLoaded returning empty map gracefully', async () => {
const testConfig: t.MCPServers = {
test_server: {
type: 'stdio',
args: [],
command: 'test-command',
},
};
const registry = new MCPServersRegistry(testConfig);
mockConnectionsRepo.get.mockRejectedValue(new Error('All connections failed'));
mockConnectionsRepo.getLoaded.mockResolvedValue(new Map());
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: false,
method: 'no-metadata-found',
metadata: null,
});
await registry.initialize();
expect(registry.toolFunctions).toEqual({});
});
});
});

View File

@@ -1,67 +0,0 @@
# Expected parsed MCP server configurations after running initialize()
# These represent the expected state of parsedConfigs after all fetch functions complete
oauth_server:
_processed: true
type: "streamable-http"
url: "https://api.github.com/mcp"
headers:
Authorization: "Bearer {{GITHUB_TOKEN}}"
serverInstructions: true
requiresOAuth: true
oauthMetadata:
authorization_url: "https://github.com/login/oauth/authorize"
token_url: "https://github.com/login/oauth/access_token"
oauth_predefined:
_processed: true
type: "sse"
url: "https://api.example.com/sse"
requiresOAuth: true
oauthMetadata:
authorization_url: "https://example.com/oauth/authorize"
token_url: "https://example.com/oauth/token"
stdio_server:
_processed: true
command: "node"
args: ["server.js"]
env:
API_KEY: "${TEST_API_KEY}"
startup: true
serverInstructions: "Follow these instructions for stdio server"
requiresOAuth: false
capabilities: '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{}}'
tools: "file_read, file_write"
websocket_server:
_processed: true
type: "websocket"
url: "ws://localhost:3001/mcp"
startup: true
requiresOAuth: false
oauthMetadata: null
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""
disabled_server:
_processed: true
requiresOAuth: false
type: "streamable-http"
url: "https://api.disabled.com/mcp"
startup: false
non_oauth_server:
_processed: true
type: "streamable-http"
url: "https://api.public.com/mcp"
requiresOAuth: false
serverInstructions: "Public API instructions"
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
tools: ""
oauth_startup_enabled:
_processed: true
type: "sse"
url: "https://api.oauth-startup.com/sse"
requiresOAuth: true

View File

@@ -1,53 +0,0 @@
# Raw MCP server configurations used as input to MCPServersRegistry constructor
# These configs test different code paths in the initialization process
# Test OAuth detection with URL - should trigger fetchOAuthRequirement
oauth_server:
type: "streamable-http"
url: "https://api.github.com/mcp"
headers:
Authorization: "Bearer {{GITHUB_TOKEN}}"
serverInstructions: true
# Test OAuth already specified - should skip OAuth detection
oauth_predefined:
type: "sse"
url: "https://api.example.com/sse"
requiresOAuth: true
oauthMetadata:
authorization_url: "https://example.com/oauth/authorize"
token_url: "https://example.com/oauth/token"
# Test stdio server without URL - should set requiresOAuth to false
stdio_server:
command: "node"
args: ["server.js"]
env:
API_KEY: "${TEST_API_KEY}"
startup: true
serverInstructions: "Follow these instructions for stdio server"
# Test websocket server with capabilities but no tools
websocket_server:
type: "websocket"
url: "ws://localhost:3001/mcp"
startup: true
# Test server with startup disabled - should not be included in appServerConfigs
disabled_server:
type: "streamable-http"
url: "https://api.disabled.com/mcp"
startup: false
# Test non-OAuth server - should be included in appServerConfigs
non_oauth_server:
type: "streamable-http"
url: "https://api.public.com/mcp"
requiresOAuth: false
serverInstructions: true
# Test server with OAuth but startup enabled - should not be in appServerConfigs
oauth_startup_enabled:
type: "sse"
url: "https://api.oauth-startup.com/sse"
requiresOAuth: true

View File

@@ -18,6 +18,7 @@ import type {
Response as UndiciResponse,
} from 'undici';
import type { MCPOAuthTokens } from './oauth/types';
import { withTimeout } from '~/utils/promise';
import type * as t from './types';
import { sanitizeUrlForLogging } from './utils';
import { mcpConfig } from './mcpConfig';
@@ -457,15 +458,11 @@ export class MCPConnection extends EventEmitter {
this.setupTransportDebugHandlers();
const connectTimeout = this.options.initTimeout ?? 120000;
await Promise.race([
await withTimeout(
this.client.connect(this.transport),
new Promise((_resolve, reject) =>
setTimeout(
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
connectTimeout,
),
),
]);
connectTimeout,
`Connection timeout after ${connectTimeout}ms`,
);
this.connectionState = 'connected';
this.emit('connectionChange', 'connected');

View File

@@ -1,6 +1,7 @@
import { TokenMethods } from '@librechat/data-schemas';
import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..';
import { MCPManager } from '../MCPManager';
import { mcpServersRegistry } from '../../mcp/registry/MCPServersRegistry';
import { OAuthReconnectionManager } from './OAuthReconnectionManager';
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
@@ -14,6 +15,12 @@ jest.mock('@librechat/data-schemas', () => ({
}));
jest.mock('../MCPManager');
jest.mock('../../mcp/registry/MCPServersRegistry', () => ({
mcpServersRegistry: {
getServerConfig: jest.fn(),
getOAuthServers: jest.fn(),
},
}));
describe('OAuthReconnectionManager', () => {
let flowManager: jest.Mocked<FlowStateManager<null>>;
@@ -51,10 +58,10 @@ describe('OAuthReconnectionManager', () => {
getUserConnection: jest.fn(),
getUserConnections: jest.fn(),
disconnectUserConnection: jest.fn(),
getRawConfig: jest.fn(),
} as unknown as jest.Mocked<MCPManager>;
(MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager);
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({});
});
afterEach(() => {
@@ -152,7 +159,7 @@ describe('OAuthReconnectionManager', () => {
it('should reconnect eligible servers', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1', 'server2', 'server3']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
// server1: has failed reconnection
reconnectionTracker.setFailed(userId, 'server1');
@@ -186,7 +193,9 @@ describe('OAuthReconnectionManager', () => {
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions);
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({
initTimeout: 5000,
} as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
@@ -215,7 +224,7 @@ describe('OAuthReconnectionManager', () => {
it('should handle failed reconnection attempts', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
// server1: has valid token
tokenMethods.findToken.mockResolvedValue({
@@ -226,7 +235,9 @@ describe('OAuthReconnectionManager', () => {
// Mock failed connection
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
@@ -242,7 +253,7 @@ describe('OAuthReconnectionManager', () => {
it('should not reconnect servers with expired tokens', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
// server1: has expired token
tokenMethods.findToken.mockResolvedValue({
@@ -261,7 +272,7 @@ describe('OAuthReconnectionManager', () => {
it('should handle connection that returns but is not connected', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
tokenMethods.findToken.mockResolvedValue({
userId,
@@ -277,7 +288,9 @@ describe('OAuthReconnectionManager', () => {
mockMCPManager.getUserConnection.mockResolvedValue(
mockConnection as unknown as MCPConnection,
);
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
@@ -359,7 +372,7 @@ describe('OAuthReconnectionManager', () => {
it('should not attempt to reconnect servers that have timed out during reconnection', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1', 'server2']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
const now = Date.now();
jest.setSystemTime(now);
@@ -414,7 +427,7 @@ describe('OAuthReconnectionManager', () => {
const userId = 'user-123';
const serverName = 'server1';
const oauthServers = new Set([serverName]);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
const now = Date.now();
jest.setSystemTime(now);
@@ -428,7 +441,9 @@ describe('OAuthReconnectionManager', () => {
// First reconnect attempt - will fail
mockMCPManager.getUserConnection.mockRejectedValueOnce(new Error('Connection failed'));
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
await jest.runAllTimersAsync();

View File

@@ -5,6 +5,7 @@ import type { MCPOAuthTokens } from './types';
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
import { FlowStateManager } from '~/flow/manager';
import { MCPManager } from '~/mcp/MCPManager';
import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry';
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
@@ -72,7 +73,7 @@ export class OAuthReconnectionManager {
// 1. derive the servers to reconnect
const serversToReconnect = [];
for (const serverName of this.mcpManager.getOAuthServers()) {
for (const serverName of await mcpServersRegistry.getOAuthServers()) {
const canReconnect = await this.canReconnect(userId, serverName);
if (canReconnect) {
serversToReconnect.push(serverName);
@@ -104,7 +105,7 @@ export class OAuthReconnectionManager {
logger.info(`${logPrefix} Attempting reconnection`);
const config = this.mcpManager.getRawConfig(serverName);
const config = await mcpServersRegistry.getServerConfig(serverName, userId);
const cleanupOnFailedReconnect = () => {
this.reconnectionsTracker.setFailed(userId, serverName);

View File

@@ -0,0 +1,123 @@
import { Constants } from 'librechat-data-provider';
import type { JsonSchemaType } from '@librechat/data-schemas';
import type { MCPConnection } from '~/mcp/connection';
import type * as t from '~/mcp/types';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { isEnabled } from '~/utils';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
/**
* Inspects MCP servers to discover their metadata, capabilities, and tools.
* Connects to servers and populates configuration with OAuth requirements,
* server instructions, capabilities, and available tools.
*/
export class MCPServerInspector {
private constructor(
private readonly serverName: string,
private readonly config: t.ParsedServerConfig,
private connection: MCPConnection | undefined,
) {}
/**
* Inspects a server and returns an enriched configuration with metadata.
* Detects OAuth requirements and fetches server capabilities.
* @param serverName - The name of the server (used for tool function naming)
* @param rawConfig - The raw server configuration
* @param connection - The MCP connection
* @returns A fully processed and enriched configuration with server metadata
*/
public static async inspect(
serverName: string,
rawConfig: t.MCPOptions,
connection?: MCPConnection,
): Promise<t.ParsedServerConfig> {
const start = Date.now();
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
await inspector.inspectServer();
inspector.config.initDuration = Date.now() - start;
return inspector.config;
}
private async inspectServer(): Promise<void> {
await this.detectOAuth();
if (this.config.startup !== false && !this.config.requiresOAuth) {
let tempConnection = false;
if (!this.connection) {
tempConnection = true;
this.connection = await MCPConnectionFactory.create({
serverName: this.serverName,
serverConfig: this.config,
});
}
await Promise.allSettled([
this.fetchServerInstructions(),
this.fetchServerCapabilities(),
this.fetchToolFunctions(),
]);
if (tempConnection) await this.connection.disconnect();
}
}
private async detectOAuth(): Promise<void> {
if (this.config.requiresOAuth != null) return;
if (this.config.url == null || this.config.startup === false) {
this.config.requiresOAuth = false;
return;
}
const result = await detectOAuthRequirement(this.config.url);
this.config.requiresOAuth = result.requiresOAuth;
this.config.oauthMetadata = result.metadata;
}
private async fetchServerInstructions(): Promise<void> {
if (isEnabled(this.config.serverInstructions)) {
this.config.serverInstructions = this.connection!.client.getInstructions();
}
}
private async fetchServerCapabilities(): Promise<void> {
const capabilities = this.connection!.client.getServerCapabilities();
this.config.capabilities = JSON.stringify(capabilities);
const tools = await this.connection!.client.listTools();
this.config.tools = tools.tools.map((tool) => tool.name).join(', ');
}
private async fetchToolFunctions(): Promise<void> {
this.config.toolFunctions = await MCPServerInspector.getToolFunctions(
this.serverName,
this.connection!,
);
}
/**
* Converts server tools to LibreChat-compatible tool functions format.
* @param serverName - The name of the server
* @param connection - The MCP connection
* @returns Tool functions formatted for LibreChat
*/
public static async getToolFunctions(
serverName: string,
connection: MCPConnection,
): Promise<t.LCAvailableTools> {
const { tools }: t.MCPToolListResponse = await connection.client.listTools();
const toolFunctions: t.LCAvailableTools = {};
tools.forEach((tool) => {
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
toolFunctions[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema as JsonSchemaType,
},
};
});
return toolFunctions;
}
}

View File

@@ -0,0 +1,96 @@
import { registryStatusCache as statusCache } from './cache/RegistryStatusCache';
import { isLeader } from '~/cluster';
import { withTimeout } from '~/utils';
import { logger } from '@librechat/data-schemas';
import { MCPServerInspector } from './MCPServerInspector';
import { ParsedServerConfig } from '~/mcp/types';
import { sanitizeUrlForLogging } from '~/mcp/utils';
import type * as t from '~/mcp/types';
import { mcpServersRegistry as registry } from './MCPServersRegistry';
const MCP_INIT_TIMEOUT_MS =
process.env.MCP_INIT_TIMEOUT_MS != null ? parseInt(process.env.MCP_INIT_TIMEOUT_MS) : 30_000;
/**
* Handles initialization of MCP servers at application startup with distributed coordination.
* In cluster environments, ensures only the leader node performs initialization while followers wait.
* Connects to each configured MCP server, inspects capabilities and tools, then caches the results.
* Categorizes servers as either shared app servers (auto-started) or shared user servers (OAuth/on-demand).
* Uses a timeout mechanism to prevent hanging on unresponsive servers during initialization.
*/
export class MCPServersInitializer {
/**
* Initializes MCP servers with distributed leader-follower coordination.
*
* Design rationale:
* - Handles leader crash scenarios: If the leader crashes during initialization, all followers
* will independently attempt initialization after a 3-second delay. The first to become leader
* will complete the initialization.
* - Only the leader performs the actual initialization work (reset caches, inspect servers).
* When complete, the leader signals completion via `statusCache`, allowing followers to proceed.
* - Followers wait and poll `statusCache` until the leader finishes, ensuring only one node
* performs the expensive initialization operations.
*/
public static async initialize(rawConfigs: t.MCPServers): Promise<void> {
if (await statusCache.isInitialized()) return;
if (await isLeader()) {
// Leader performs initialization
await statusCache.reset();
await registry.reset();
const serverNames = Object.keys(rawConfigs);
await Promise.allSettled(
serverNames.map((serverName) =>
withTimeout(
MCPServersInitializer.initializeServer(serverName, rawConfigs[serverName]),
MCP_INIT_TIMEOUT_MS,
`${MCPServersInitializer.prefix(serverName)} Server initialization timed out`,
logger.error,
),
),
);
await statusCache.setInitialized(true);
} else {
// Followers try again after a delay if not initialized
await new Promise((resolve) => setTimeout(resolve, 3000));
await this.initialize(rawConfigs);
}
}
/** Initializes a single server with all its metadata and adds it to appropriate collections */
private static async initializeServer(
serverName: string,
rawConfig: t.MCPOptions,
): Promise<void> {
try {
const config = await MCPServerInspector.inspect(serverName, rawConfig);
if (config.startup === false || config.requiresOAuth) {
await registry.sharedUserServers.add(serverName, config);
} else {
await registry.sharedAppServers.add(serverName, config);
}
MCPServersInitializer.logParsedConfig(serverName, config);
} catch (error) {
logger.error(`${MCPServersInitializer.prefix(serverName)} Failed to initialize:`, error);
}
}
// Logs server configuration summary after initialization
private static logParsedConfig(serverName: string, config: ParsedServerConfig): void {
const prefix = MCPServersInitializer.prefix(serverName);
logger.info(`${prefix} -------------------------------------------------┐`);
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
logger.info(`${prefix} Tools: ${config.tools}`);
logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`);
logger.info(`${prefix} Initialized in: ${config.initDuration ?? 'N/A'}ms`);
logger.info(`${prefix} -------------------------------------------------┘`);
}
// Returns formatted log prefix for server messages
private static prefix(serverName: string): string {
return `[MCP][${serverName}]`;
}
}

View File

@@ -0,0 +1,91 @@
import type * as t from '~/mcp/types';
import {
ServerConfigsCacheFactory,
type ServerConfigsCache,
} from './cache/ServerConfigsCacheFactory';
/**
* Central registry for managing MCP server configurations across different scopes and users.
* Maintains three categories of server configurations:
* - Shared App Servers: Auto-started servers available to all users (initialized at startup)
* - Shared User Servers: User-scope servers that require OAuth or on-demand startup
* - Private User Servers: Per-user configurations dynamically added during runtime
*
* Provides a unified interface for retrieving server configs with proper fallback hierarchy:
* checks shared app servers first, then shared user servers, then private user servers.
* Handles server lifecycle operations including adding, removing, and querying configurations.
*/
class MCPServersRegistry {
public readonly sharedAppServers = ServerConfigsCacheFactory.create('App', true);
public readonly sharedUserServers = ServerConfigsCacheFactory.create('User', true);
private readonly privateUserServers: Map<string | undefined, ServerConfigsCache> = new Map();
public async addPrivateUserServer(
userId: string,
serverName: string,
config: t.ParsedServerConfig,
): Promise<void> {
if (!this.privateUserServers.has(userId)) {
const cache = ServerConfigsCacheFactory.create(`User(${userId})`, false);
this.privateUserServers.set(userId, cache);
}
await this.privateUserServers.get(userId)!.add(serverName, config);
}
public async updatePrivateUserServer(
userId: string,
serverName: string,
config: t.ParsedServerConfig,
): Promise<void> {
const userCache = this.privateUserServers.get(userId);
if (!userCache) throw new Error(`No private servers found for user "${userId}".`);
await userCache.update(serverName, config);
}
public async removePrivateUserServer(userId: string, serverName: string): Promise<void> {
await this.privateUserServers.get(userId)?.remove(serverName);
}
public async getServerConfig(
serverName: string,
userId?: string,
): Promise<t.ParsedServerConfig | undefined> {
const sharedAppServer = await this.sharedAppServers.get(serverName);
if (sharedAppServer) return sharedAppServer;
const sharedUserServer = await this.sharedUserServers.get(serverName);
if (sharedUserServer) return sharedUserServer;
const privateUserServer = await this.privateUserServers.get(userId)?.get(serverName);
if (privateUserServer) return privateUserServer;
return undefined;
}
public async getAllServerConfigs(userId?: string): Promise<Record<string, t.ParsedServerConfig>> {
return {
...(await this.sharedAppServers.getAll()),
...(await this.sharedUserServers.getAll()),
...((await this.privateUserServers.get(userId)?.getAll()) ?? {}),
};
}
// TODO: This is currently used to determine if a server requires OAuth. However, this info can
// can be determined through config.requiresOAuth. Refactor usages and remove this method.
public async getOAuthServers(userId?: string): Promise<Set<string>> {
const allServers = await this.getAllServerConfigs(userId);
const oauthServers = Object.entries(allServers).filter(([, config]) => config.requiresOAuth);
return new Set(oauthServers.map(([name]) => name));
}
public async reset(): Promise<void> {
await this.sharedAppServers.reset();
await this.sharedUserServers.reset();
for (const cache of this.privateUserServers.values()) {
await cache.reset();
}
this.privateUserServers.clear();
}
}
export const mcpServersRegistry = new MCPServersRegistry();

View File

@@ -0,0 +1,338 @@
import type { MCPConnection } from '~/mcp/connection';
import type * as t from '~/mcp/types';
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { createMockConnection } from './mcpConnectionsMock.helper';
// Mock external dependencies
jest.mock('../../oauth/detectOAuth');
jest.mock('../../MCPConnectionFactory');
const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction<
typeof detectOAuthRequirement
>;
describe('MCPServerInspector', () => {
let mockConnection: jest.Mocked<MCPConnection>;
beforeEach(() => {
mockConnection = createMockConnection('test_server');
jest.clearAllMocks();
});
describe('inspect()', () => {
it('should process env and fetch all metadata for non-OAuth stdio server with serverInstructions=true', async () => {
const rawConfig: t.MCPOptions = {
type: 'stdio',
command: 'node',
args: ['server.js'],
serverInstructions: true,
};
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: false,
method: 'no-metadata-found',
});
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
expect(result).toEqual({
type: 'stdio',
command: 'node',
args: ['server.js'],
serverInstructions: 'instructions for test_server',
requiresOAuth: false,
capabilities:
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
tools: 'listFiles',
toolFunctions: {
listFiles_mcp_test_server: expect.objectContaining({
type: 'function',
function: expect.objectContaining({
name: 'listFiles_mcp_test_server',
}),
}),
},
initDuration: expect.any(Number),
});
});
it('should detect OAuth and skip capabilities fetch for streamable-http server', async () => {
const rawConfig: t.MCPOptions = {
type: 'streamable-http',
url: 'https://api.example.com/mcp',
};
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: true,
method: 'protected-resource-metadata',
});
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
expect(result).toEqual({
type: 'streamable-http',
url: 'https://api.example.com/mcp',
requiresOAuth: true,
oauthMetadata: undefined,
initDuration: expect.any(Number),
});
});
it('should skip capabilities fetch when startup=false', async () => {
const rawConfig: t.MCPOptions = {
type: 'stdio',
command: 'node',
args: ['server.js'],
startup: false,
};
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
expect(result).toEqual({
type: 'stdio',
command: 'node',
args: ['server.js'],
startup: false,
requiresOAuth: false,
initDuration: expect.any(Number),
});
});
it('should keep custom serverInstructions string and not fetch from server', async () => {
const rawConfig: t.MCPOptions = {
type: 'stdio',
command: 'node',
args: ['server.js'],
serverInstructions: 'Custom instructions here',
};
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: false,
method: 'no-metadata-found',
});
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
expect(result).toEqual({
type: 'stdio',
command: 'node',
args: ['server.js'],
serverInstructions: 'Custom instructions here',
requiresOAuth: false,
capabilities:
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
tools: 'listFiles',
toolFunctions: expect.any(Object),
initDuration: expect.any(Number),
});
});
it('should handle serverInstructions as string "true" and fetch from server', async () => {
const rawConfig: t.MCPOptions = {
type: 'stdio',
command: 'node',
args: ['server.js'],
serverInstructions: 'true', // String "true" from YAML
};
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: false,
method: 'no-metadata-found',
});
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
expect(result).toEqual({
type: 'stdio',
command: 'node',
args: ['server.js'],
serverInstructions: 'instructions for test_server',
requiresOAuth: false,
capabilities:
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
tools: 'listFiles',
toolFunctions: expect.any(Object),
initDuration: expect.any(Number),
});
});
it('should handle predefined requiresOAuth without detection', async () => {
const rawConfig: t.MCPOptions = {
type: 'sse',
url: 'https://api.example.com/sse',
requiresOAuth: true,
};
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
expect(result).toEqual({
type: 'sse',
url: 'https://api.example.com/sse',
requiresOAuth: true,
initDuration: expect.any(Number),
});
});
it('should fetch capabilities when server has no tools', async () => {
const rawConfig: t.MCPOptions = {
type: 'stdio',
command: 'node',
args: ['server.js'],
};
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: false,
method: 'no-metadata-found',
});
// Mock server with no tools
mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] });
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
expect(result).toEqual({
type: 'stdio',
command: 'node',
args: ['server.js'],
requiresOAuth: false,
capabilities:
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
tools: '',
toolFunctions: {},
initDuration: expect.any(Number),
});
});
it('should create temporary connection when no connection is provided', async () => {
const rawConfig: t.MCPOptions = {
type: 'stdio',
command: 'node',
args: ['server.js'],
serverInstructions: true,
};
const tempMockConnection = createMockConnection('test_server');
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(tempMockConnection);
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: false,
method: 'no-metadata-found',
});
const result = await MCPServerInspector.inspect('test_server', rawConfig);
// Verify factory was called to create connection
expect(MCPConnectionFactory.create).toHaveBeenCalledWith({
serverName: 'test_server',
serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }),
});
// Verify temporary connection was disconnected
expect(tempMockConnection.disconnect).toHaveBeenCalled();
// Verify result is correct
expect(result).toEqual({
type: 'stdio',
command: 'node',
args: ['server.js'],
serverInstructions: 'instructions for test_server',
requiresOAuth: false,
capabilities:
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
tools: 'listFiles',
toolFunctions: expect.any(Object),
initDuration: expect.any(Number),
});
});
it('should not create temporary connection when connection is provided', async () => {
const rawConfig: t.MCPOptions = {
type: 'stdio',
command: 'node',
args: ['server.js'],
serverInstructions: true,
};
mockDetectOAuthRequirement.mockResolvedValue({
requiresOAuth: false,
method: 'no-metadata-found',
});
await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
// Verify factory was NOT called
expect(MCPConnectionFactory.create).not.toHaveBeenCalled();
// Verify provided connection was NOT disconnected
expect(mockConnection.disconnect).not.toHaveBeenCalled();
});
});
describe('getToolFunctions()', () => {
it('should convert MCP tools to LibreChat tool functions format', async () => {
mockConnection.client.listTools = jest.fn().mockResolvedValue({
tools: [
{
name: 'file_read',
description: 'Read a file',
inputSchema: {
type: 'object',
properties: { path: { type: 'string' } },
},
},
{
name: 'file_write',
description: 'Write a file',
inputSchema: {
type: 'object',
properties: {
path: { type: 'string' },
content: { type: 'string' },
},
},
},
],
});
const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection);
expect(result).toEqual({
file_read_mcp_my_server: {
type: 'function',
function: {
name: 'file_read_mcp_my_server',
description: 'Read a file',
parameters: {
type: 'object',
properties: { path: { type: 'string' } },
},
},
},
file_write_mcp_my_server: {
type: 'function',
function: {
name: 'file_write_mcp_my_server',
description: 'Write a file',
parameters: {
type: 'object',
properties: {
path: { type: 'string' },
content: { type: 'string' },
},
},
},
},
});
});
it('should handle empty tools list', async () => {
mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] });
const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection);
expect(result).toEqual({});
});
});
});

View File

@@ -0,0 +1,301 @@
import { expect } from '@playwright/test';
import type * as t from '~/mcp/types';
import type { MCPConnection } from '~/mcp/connection';
// Mock isLeader to always return true to avoid lock contention during parallel operations
jest.mock('~/cluster', () => ({
...jest.requireActual('~/cluster'),
isLeader: jest.fn().mockResolvedValue(true),
}));
describe('MCPServersInitializer Redis Integration Tests', () => {
let MCPServersInitializer: typeof import('../MCPServersInitializer').MCPServersInitializer;
let registry: typeof import('../MCPServersRegistry').mcpServersRegistry;
let registryStatusCache: typeof import('../cache/RegistryStatusCache').registryStatusCache;
let MCPServerInspector: typeof import('../MCPServerInspector').MCPServerInspector;
let MCPConnectionFactory: typeof import('~/mcp/MCPConnectionFactory').MCPConnectionFactory;
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
let leaderInstance: InstanceType<typeof import('~/cluster/LeaderElection').LeaderElection>;
const testConfigs: t.MCPServers = {
disabled_server: {
type: 'stdio',
command: 'node',
args: ['disabled.js'],
startup: false,
},
oauth_server: {
type: 'streamable-http',
url: 'https://api.example.com/mcp',
},
file_tools_server: {
type: 'stdio',
command: 'node',
args: ['tools.js'],
},
search_tools_server: {
type: 'stdio',
command: 'node',
args: ['instructions.js'],
},
};
const testParsedConfigs: Record<string, t.ParsedServerConfig> = {
disabled_server: {
type: 'stdio',
command: 'node',
args: ['disabled.js'],
startup: false,
requiresOAuth: false,
},
oauth_server: {
type: 'streamable-http',
url: 'https://api.example.com/mcp',
requiresOAuth: true,
},
file_tools_server: {
type: 'stdio',
command: 'node',
args: ['tools.js'],
requiresOAuth: false,
serverInstructions: 'Instructions for file_tools_server',
tools: 'file_read, file_write',
capabilities: '{"tools":{"listChanged":true}}',
toolFunctions: {
file_read_mcp_file_tools_server: {
type: 'function',
function: {
name: 'file_read_mcp_file_tools_server',
description: 'Read a file',
parameters: { type: 'object' },
},
},
},
},
search_tools_server: {
type: 'stdio',
command: 'node',
args: ['instructions.js'],
requiresOAuth: false,
serverInstructions: 'Instructions for search_tools_server',
capabilities: '{"tools":{"listChanged":true}}',
tools: 'search',
toolFunctions: {
search_mcp_search_tools_server: {
type: 'function',
function: {
name: 'search_mcp_search_tools_server',
description: 'Search tool',
parameters: { type: 'object' },
},
},
},
},
};
beforeAll(async () => {
// Set up environment variables for Redis (only if not already set)
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX =
process.env.REDIS_KEY_PREFIX ?? 'MCPServersInitializer-IntegrationTest';
// Import modules after setting env vars
const initializerModule = await import('../MCPServersInitializer');
const registryModule = await import('../MCPServersRegistry');
const statusCacheModule = await import('../cache/RegistryStatusCache');
const inspectorModule = await import('../MCPServerInspector');
const connectionFactoryModule = await import('~/mcp/MCPConnectionFactory');
const redisClients = await import('~/cache/redisClients');
const leaderElectionModule = await import('~/cluster/LeaderElection');
MCPServersInitializer = initializerModule.MCPServersInitializer;
registry = registryModule.mcpServersRegistry;
registryStatusCache = statusCacheModule.registryStatusCache;
MCPServerInspector = inspectorModule.MCPServerInspector;
MCPConnectionFactory = connectionFactoryModule.MCPConnectionFactory;
keyvRedisClient = redisClients.keyvRedisClient;
LeaderElection = leaderElectionModule.LeaderElection;
// Ensure Redis is connected
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
// Wait for Redis to be ready
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
// Become leader so we can perform write operations
leaderInstance = new LeaderElection();
const isLeader = await leaderInstance.isLeader();
expect(isLeader).toBe(true);
});
beforeEach(async () => {
// Ensure we're still the leader
const isLeader = await leaderInstance.isLeader();
if (!isLeader) {
throw new Error('Lost leader status before test');
}
// Mock MCPServerInspector.inspect to return parsed config
jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => {
return {
...testParsedConfigs[serverName],
_processedByInspector: true,
} as unknown as t.ParsedServerConfig;
});
// Mock MCPConnection
const mockConnection = {
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
// Mock MCPConnectionFactory
jest.spyOn(MCPConnectionFactory, 'create').mockResolvedValue(mockConnection);
// Reset caches before each test
await registryStatusCache.reset();
await registry.reset();
});
afterEach(async () => {
// Clean up: clear all test keys from Redis
if (keyvRedisClient) {
const pattern = '*MCPServersInitializer-IntegrationTest*';
if ('scanIterator' in keyvRedisClient) {
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
await keyvRedisClient.del(key);
}
}
}
jest.restoreAllMocks();
});
afterAll(async () => {
// Resign as leader
if (leaderInstance) await leaderInstance.resign();
// Close Redis connection
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
});
describe('initialize()', () => {
it('should reset registry and status cache before initialization', async () => {
// Pre-populate registry with some old servers
await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server);
await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server);
// Initialize with new configs (this should reset first)
await MCPServersInitializer.initialize(testConfigs);
// Verify old servers are gone
expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined();
expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined();
// Verify new servers are present
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
expect(await registryStatusCache.isInitialized()).toBe(true);
});
it('should skip initialization if already initialized', async () => {
// First initialization
await MCPServersInitializer.initialize(testConfigs);
// Clear mock calls
jest.clearAllMocks();
// Second initialization should skip due to static flag
await MCPServersInitializer.initialize(testConfigs);
// Verify inspect was not called again
expect(MCPServerInspector.inspect).not.toHaveBeenCalled();
});
it('should add disabled servers to sharedUserServers', async () => {
await MCPServersInitializer.initialize(testConfigs);
const disabledServer = await registry.sharedUserServers.get('disabled_server');
expect(disabledServer).toBeDefined();
expect(disabledServer).toMatchObject({
...testParsedConfigs.disabled_server,
_processedByInspector: true,
});
});
it('should add OAuth servers to sharedUserServers', async () => {
await MCPServersInitializer.initialize(testConfigs);
const oauthServer = await registry.sharedUserServers.get('oauth_server');
expect(oauthServer).toBeDefined();
expect(oauthServer).toMatchObject({
...testParsedConfigs.oauth_server,
_processedByInspector: true,
});
});
it('should add enabled non-OAuth servers to sharedAppServers', async () => {
await MCPServersInitializer.initialize(testConfigs);
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
expect(fileToolsServer).toBeDefined();
expect(fileToolsServer).toMatchObject({
...testParsedConfigs.file_tools_server,
_processedByInspector: true,
});
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
expect(searchToolsServer).toBeDefined();
expect(searchToolsServer).toMatchObject({
...testParsedConfigs.search_tools_server,
_processedByInspector: true,
});
});
it('should successfully initialize all servers', async () => {
await MCPServersInitializer.initialize(testConfigs);
// Verify all servers were added to appropriate registries
expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined();
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined();
});
it('should handle inspection failures gracefully', async () => {
// Mock inspection failure for one server
jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => {
if (serverName === 'file_tools_server') {
throw new Error('Inspection failed');
}
return {
...testParsedConfigs[serverName],
_processedByInspector: true,
} as unknown as t.ParsedServerConfig;
});
await MCPServersInitializer.initialize(testConfigs);
// Verify other servers were still processed
const disabledServer = await registry.sharedUserServers.get('disabled_server');
expect(disabledServer).toBeDefined();
const oauthServer = await registry.sharedUserServers.get('oauth_server');
expect(oauthServer).toBeDefined();
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
expect(searchToolsServer).toBeDefined();
// Verify file_tools_server was not added (due to inspection failure)
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
expect(fileToolsServer).toBeUndefined();
});
it('should set initialized status after completion', async () => {
await MCPServersInitializer.initialize(testConfigs);
expect(await registryStatusCache.isInitialized()).toBe(true);
});
});
});

View File

@@ -0,0 +1,292 @@
import { logger } from '@librechat/data-schemas';
import * as t from '~/mcp/types';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
import { MCPConnection } from '~/mcp/connection';
import { registryStatusCache } from '~/mcp/registry/cache/RegistryStatusCache';
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry';
// Mock external dependencies
jest.mock('../../MCPConnectionFactory');
jest.mock('../../connection');
jest.mock('../../registry/MCPServerInspector');
jest.mock('~/cluster', () => ({
isLeader: jest.fn().mockResolvedValue(true),
}));
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
const mockLogger = logger as jest.Mocked<typeof logger>;
const mockInspect = MCPServerInspector.inspect as jest.MockedFunction<
typeof MCPServerInspector.inspect
>;
describe('MCPServersInitializer', () => {
let mockConnection: jest.Mocked<MCPConnection>;
const testConfigs: t.MCPServers = {
disabled_server: {
type: 'stdio',
command: 'node',
args: ['disabled.js'],
startup: false,
},
oauth_server: {
type: 'streamable-http',
url: 'https://api.example.com/mcp',
},
file_tools_server: {
type: 'stdio',
command: 'node',
args: ['tools.js'],
},
search_tools_server: {
type: 'stdio',
command: 'node',
args: ['instructions.js'],
},
};
const testParsedConfigs: Record<string, t.ParsedServerConfig> = {
disabled_server: {
type: 'stdio',
command: 'node',
args: ['disabled.js'],
startup: false,
requiresOAuth: false,
},
oauth_server: {
type: 'streamable-http',
url: 'https://api.example.com/mcp',
requiresOAuth: true,
},
file_tools_server: {
type: 'stdio',
command: 'node',
args: ['tools.js'],
requiresOAuth: false,
serverInstructions: 'Instructions for file_tools_server',
tools: 'file_read, file_write',
capabilities: '{"tools":{"listChanged":true}}',
toolFunctions: {
file_read_mcp_file_tools_server: {
type: 'function',
function: {
name: 'file_read_mcp_file_tools_server',
description: 'Read a file',
parameters: { type: 'object' },
},
},
},
},
search_tools_server: {
type: 'stdio',
command: 'node',
args: ['instructions.js'],
requiresOAuth: false,
serverInstructions: 'Instructions for search_tools_server',
capabilities: '{"tools":{"listChanged":true}}',
tools: 'search',
toolFunctions: {
search_mcp_search_tools_server: {
type: 'function',
function: {
name: 'search_mcp_search_tools_server',
description: 'Search tool',
parameters: { type: 'object' },
},
},
},
},
};
beforeEach(async () => {
// Setup MCPConnection mock
mockConnection = {
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
// Setup MCPConnectionFactory mock
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
// Mock MCPServerInspector.inspect to return parsed config
mockInspect.mockImplementation(async (serverName: string) => {
return {
...testParsedConfigs[serverName],
_processedByInspector: true,
} as unknown as t.ParsedServerConfig;
});
// Reset caches before each test
await registryStatusCache.reset();
await registry.sharedAppServers.reset();
await registry.sharedUserServers.reset();
jest.clearAllMocks();
});
afterEach(() => {
jest.restoreAllMocks();
});
describe('initialize()', () => {
it('should reset registry and status cache before initialization', async () => {
// Pre-populate registry with some old servers
await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server);
await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server);
// Initialize with new configs (this should reset first)
await MCPServersInitializer.initialize(testConfigs);
// Verify old servers are gone
expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined();
expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined();
// Verify new servers are present
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
expect(await registryStatusCache.isInitialized()).toBe(true);
});
it('should skip initialization if already initialized (Redis flag)', async () => {
// First initialization
await MCPServersInitializer.initialize(testConfigs);
jest.clearAllMocks();
// Second initialization should skip due to Redis cache flag
await MCPServersInitializer.initialize(testConfigs);
expect(mockInspect).not.toHaveBeenCalled();
});
it('should process all server configs through inspector', async () => {
await MCPServersInitializer.initialize(testConfigs);
// Verify all configs were processed by inspector (without connection parameter)
expect(mockInspect).toHaveBeenCalledTimes(4);
expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server);
expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server);
expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server);
expect(mockInspect).toHaveBeenCalledWith(
'search_tools_server',
testConfigs.search_tools_server,
);
});
it('should add disabled servers to sharedUserServers', async () => {
await MCPServersInitializer.initialize(testConfigs);
const disabledServer = await registry.sharedUserServers.get('disabled_server');
expect(disabledServer).toBeDefined();
expect(disabledServer).toMatchObject({
...testParsedConfigs.disabled_server,
_processedByInspector: true,
});
});
it('should add OAuth servers to sharedUserServers', async () => {
await MCPServersInitializer.initialize(testConfigs);
const oauthServer = await registry.sharedUserServers.get('oauth_server');
expect(oauthServer).toBeDefined();
expect(oauthServer).toMatchObject({
...testParsedConfigs.oauth_server,
_processedByInspector: true,
});
});
it('should add enabled non-OAuth servers to sharedAppServers', async () => {
await MCPServersInitializer.initialize(testConfigs);
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
expect(fileToolsServer).toBeDefined();
expect(fileToolsServer).toMatchObject({
...testParsedConfigs.file_tools_server,
_processedByInspector: true,
});
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
expect(searchToolsServer).toBeDefined();
expect(searchToolsServer).toMatchObject({
...testParsedConfigs.search_tools_server,
_processedByInspector: true,
});
});
it('should successfully initialize all servers', async () => {
await MCPServersInitializer.initialize(testConfigs);
// Verify all servers were added to appropriate registries
expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined();
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined();
});
it('should handle inspection failures gracefully', async () => {
// Mock inspection failure for one server
mockInspect.mockImplementation(async (serverName: string) => {
if (serverName === 'file_tools_server') {
throw new Error('Inspection failed');
}
return {
...testParsedConfigs[serverName],
_processedByInspector: true,
} as unknown as t.ParsedServerConfig;
});
await MCPServersInitializer.initialize(testConfigs);
// Verify other servers were still processed
const disabledServer = await registry.sharedUserServers.get('disabled_server');
expect(disabledServer).toBeDefined();
const oauthServer = await registry.sharedUserServers.get('oauth_server');
expect(oauthServer).toBeDefined();
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
expect(searchToolsServer).toBeDefined();
// Verify file_tools_server was not added (due to inspection failure)
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
expect(fileToolsServer).toBeUndefined();
});
it('should log server configuration after initialization', async () => {
await MCPServersInitializer.initialize(testConfigs);
// Verify logging occurred for each server
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('[MCP][disabled_server]'),
);
expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('[MCP][oauth_server]'));
expect(mockLogger.info).toHaveBeenCalledWith(
expect.stringContaining('[MCP][file_tools_server]'),
);
});
it('should use Promise.allSettled for parallel server initialization', async () => {
const allSettledSpy = jest.spyOn(Promise, 'allSettled');
await MCPServersInitializer.initialize(testConfigs);
expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)]));
expect(allSettledSpy).toHaveBeenCalledTimes(1);
allSettledSpy.mockRestore();
});
it('should set initialized status after completion', async () => {
await MCPServersInitializer.initialize(testConfigs);
expect(await registryStatusCache.isInitialized()).toBe(true);
});
});
});

View File

@@ -0,0 +1,227 @@
import { expect } from '@playwright/test';
import type * as t from '~/mcp/types';
/**
* Integration tests for MCPServersRegistry using Redis-backed cache.
* For unit tests using in-memory cache, see MCPServersRegistry.test.ts
*/
describe('MCPServersRegistry Redis Integration Tests', () => {
let registry: typeof import('../MCPServersRegistry').mcpServersRegistry;
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
let leaderInstance: InstanceType<typeof import('~/cluster/LeaderElection').LeaderElection>;
const testParsedConfig: t.ParsedServerConfig = {
type: 'stdio',
command: 'node',
args: ['tools.js'],
requiresOAuth: false,
serverInstructions: 'Instructions for file_tools_server',
tools: 'file_read, file_write',
capabilities: '{"tools":{"listChanged":true}}',
toolFunctions: {
file_read_mcp_file_tools_server: {
type: 'function',
function: {
name: 'file_read_mcp_file_tools_server',
description: 'Read a file',
parameters: { type: 'object' },
},
},
},
};
beforeAll(async () => {
// Set up environment variables for Redis (only if not already set)
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX =
process.env.REDIS_KEY_PREFIX ?? 'MCPServersRegistry-IntegrationTest';
// Import modules after setting env vars
const registryModule = await import('../MCPServersRegistry');
const redisClients = await import('~/cache/redisClients');
const leaderElectionModule = await import('~/cluster/LeaderElection');
registry = registryModule.mcpServersRegistry;
keyvRedisClient = redisClients.keyvRedisClient;
LeaderElection = leaderElectionModule.LeaderElection;
// Ensure Redis is connected
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
// Wait for Redis to be ready
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
// Become leader so we can perform write operations
leaderInstance = new LeaderElection();
const isLeader = await leaderInstance.isLeader();
expect(isLeader).toBe(true);
});
afterEach(async () => {
// Clean up: reset registry to clear all test data
await registry.reset();
// Also clean up any remaining test keys from Redis
if (keyvRedisClient) {
const pattern = '*MCPServersRegistry-IntegrationTest*';
if ('scanIterator' in keyvRedisClient) {
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
await keyvRedisClient.del(key);
}
}
}
});
afterAll(async () => {
// Resign as leader
if (leaderInstance) await leaderInstance.resign();
// Close Redis connection
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
});
describe('private user servers', () => {
it('should add and remove private user server', async () => {
const userId = 'user123';
const serverName = 'private_server';
// Add private user server
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
// Verify server was added
const retrievedConfig = await registry.getServerConfig(serverName, userId);
expect(retrievedConfig).toEqual(testParsedConfig);
// Remove private user server
await registry.removePrivateUserServer(userId, serverName);
// Verify server was removed
const configAfterRemoval = await registry.getServerConfig(serverName, userId);
expect(configAfterRemoval).toBeUndefined();
});
it('should throw error when adding duplicate private user server', async () => {
const userId = 'user123';
const serverName = 'private_server';
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
await expect(
registry.addPrivateUserServer(userId, serverName, testParsedConfig),
).rejects.toThrow(
'Server "private_server" already exists in cache. Use update() to modify existing configs.',
);
});
it('should update an existing private user server', async () => {
const userId = 'user123';
const serverName = 'private_server';
const updatedConfig: t.ParsedServerConfig = {
type: 'stdio',
command: 'python',
args: ['updated.py'],
requiresOAuth: true,
};
// Add private user server
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
// Update the server config
await registry.updatePrivateUserServer(userId, serverName, updatedConfig);
// Verify server was updated
const retrievedConfig = await registry.getServerConfig(serverName, userId);
expect(retrievedConfig).toEqual(updatedConfig);
});
it('should throw error when updating non-existent server', async () => {
const userId = 'user123';
const serverName = 'private_server';
// Add a user cache first
await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig);
await expect(
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
).rejects.toThrow(
'Server "private_server" does not exist in cache. Use add() to create new configs.',
);
});
it('should throw error when updating server for non-existent user', async () => {
const userId = 'nonexistent_user';
const serverName = 'private_server';
await expect(
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
).rejects.toThrow('No private servers found for user "nonexistent_user".');
});
});
describe('getAllServerConfigs', () => {
it('should return correct servers based on userId', async () => {
// Add servers to all three caches
await registry.sharedAppServers.add('app_server', testParsedConfig);
await registry.sharedUserServers.add('user_server', testParsedConfig);
await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig);
await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig);
// Without userId: should return only shared app + shared user servers
const configsNoUser = await registry.getAllServerConfigs();
expect(Object.keys(configsNoUser)).toHaveLength(2);
expect(configsNoUser).toHaveProperty('app_server');
expect(configsNoUser).toHaveProperty('user_server');
// With userId 'abc': should return shared app + shared user + abc's private servers
const configsAbc = await registry.getAllServerConfigs('abc');
expect(Object.keys(configsAbc)).toHaveLength(3);
expect(configsAbc).toHaveProperty('app_server');
expect(configsAbc).toHaveProperty('user_server');
expect(configsAbc).toHaveProperty('abc_private_server');
// With userId 'xyz': should return shared app + shared user + xyz's private servers
const configsXyz = await registry.getAllServerConfigs('xyz');
expect(Object.keys(configsXyz)).toHaveLength(3);
expect(configsXyz).toHaveProperty('app_server');
expect(configsXyz).toHaveProperty('user_server');
expect(configsXyz).toHaveProperty('xyz_private_server');
});
});
describe('reset', () => {
it('should clear all servers from all caches (shared app, shared user, and private user)', async () => {
const userId = 'user123';
// Add servers to all three caches
await registry.sharedAppServers.add('app_server', testParsedConfig);
await registry.sharedUserServers.add('user_server', testParsedConfig);
await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig);
// Verify all servers are accessible before reset
const appConfigBefore = await registry.getServerConfig('app_server');
const userConfigBefore = await registry.getServerConfig('user_server');
const privateConfigBefore = await registry.getServerConfig('private_server', userId);
const allConfigsBefore = await registry.getAllServerConfigs(userId);
expect(appConfigBefore).toEqual(testParsedConfig);
expect(userConfigBefore).toEqual(testParsedConfig);
expect(privateConfigBefore).toEqual(testParsedConfig);
expect(Object.keys(allConfigsBefore)).toHaveLength(3);
// Reset everything
await registry.reset();
// Verify all servers are cleared after reset
const appConfigAfter = await registry.getServerConfig('app_server');
const userConfigAfter = await registry.getServerConfig('user_server');
const privateConfigAfter = await registry.getServerConfig('private_server', userId);
const allConfigsAfter = await registry.getAllServerConfigs(userId);
expect(appConfigAfter).toBeUndefined();
expect(userConfigAfter).toBeUndefined();
expect(privateConfigAfter).toBeUndefined();
expect(Object.keys(allConfigsAfter)).toHaveLength(0);
});
});
});

View File

@@ -0,0 +1,175 @@
import * as t from '~/mcp/types';
import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry';
/**
* Unit tests for MCPServersRegistry using in-memory cache.
* For integration tests using Redis-backed cache, see MCPServersRegistry.cache_integration.spec.ts
*/
describe('MCPServersRegistry', () => {
const testParsedConfig: t.ParsedServerConfig = {
type: 'stdio',
command: 'node',
args: ['tools.js'],
requiresOAuth: false,
serverInstructions: 'Instructions for file_tools_server',
tools: 'file_read, file_write',
capabilities: '{"tools":{"listChanged":true}}',
toolFunctions: {
file_read_mcp_file_tools_server: {
type: 'function',
function: {
name: 'file_read_mcp_file_tools_server',
description: 'Read a file',
parameters: { type: 'object' },
},
},
},
};
beforeEach(async () => {
await registry.reset();
});
describe('private user servers', () => {
it('should add and remove private user server', async () => {
const userId = 'user123';
const serverName = 'private_server';
// Add private user server
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
// Verify server was added
const retrievedConfig = await registry.getServerConfig(serverName, userId);
expect(retrievedConfig).toEqual(testParsedConfig);
// Remove private user server
await registry.removePrivateUserServer(userId, serverName);
// Verify server was removed
const configAfterRemoval = await registry.getServerConfig(serverName, userId);
expect(configAfterRemoval).toBeUndefined();
});
it('should throw error when adding duplicate private user server', async () => {
const userId = 'user123';
const serverName = 'private_server';
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
await expect(
registry.addPrivateUserServer(userId, serverName, testParsedConfig),
).rejects.toThrow(
'Server "private_server" already exists in cache. Use update() to modify existing configs.',
);
});
it('should update an existing private user server', async () => {
const userId = 'user123';
const serverName = 'private_server';
const updatedConfig: t.ParsedServerConfig = {
type: 'stdio',
command: 'python',
args: ['updated.py'],
requiresOAuth: true,
};
// Add private user server
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
// Update the server config
await registry.updatePrivateUserServer(userId, serverName, updatedConfig);
// Verify server was updated
const retrievedConfig = await registry.getServerConfig(serverName, userId);
expect(retrievedConfig).toEqual(updatedConfig);
});
it('should throw error when updating non-existent server', async () => {
const userId = 'user123';
const serverName = 'private_server';
// Add a user cache first
await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig);
await expect(
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
).rejects.toThrow(
'Server "private_server" does not exist in cache. Use add() to create new configs.',
);
});
it('should throw error when updating server for non-existent user', async () => {
const userId = 'nonexistent_user';
const serverName = 'private_server';
await expect(
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
).rejects.toThrow('No private servers found for user "nonexistent_user".');
});
});
describe('getAllServerConfigs', () => {
it('should return correct servers based on userId', async () => {
// Add servers to all three caches
await registry.sharedAppServers.add('app_server', testParsedConfig);
await registry.sharedUserServers.add('user_server', testParsedConfig);
await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig);
await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig);
// Without userId: should return only shared app + shared user servers
const configsNoUser = await registry.getAllServerConfigs();
expect(Object.keys(configsNoUser)).toHaveLength(2);
expect(configsNoUser).toHaveProperty('app_server');
expect(configsNoUser).toHaveProperty('user_server');
// With userId 'abc': should return shared app + shared user + abc's private servers
const configsAbc = await registry.getAllServerConfigs('abc');
expect(Object.keys(configsAbc)).toHaveLength(3);
expect(configsAbc).toHaveProperty('app_server');
expect(configsAbc).toHaveProperty('user_server');
expect(configsAbc).toHaveProperty('abc_private_server');
// With userId 'xyz': should return shared app + shared user + xyz's private servers
const configsXyz = await registry.getAllServerConfigs('xyz');
expect(Object.keys(configsXyz)).toHaveLength(3);
expect(configsXyz).toHaveProperty('app_server');
expect(configsXyz).toHaveProperty('user_server');
expect(configsXyz).toHaveProperty('xyz_private_server');
});
});
describe('reset', () => {
it('should clear all servers from all caches (shared app, shared user, and private user)', async () => {
const userId = 'user123';
// Add servers to all three caches
await registry.sharedAppServers.add('app_server', testParsedConfig);
await registry.sharedUserServers.add('user_server', testParsedConfig);
await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig);
// Verify all servers are accessible before reset
const appConfigBefore = await registry.getServerConfig('app_server');
const userConfigBefore = await registry.getServerConfig('user_server');
const privateConfigBefore = await registry.getServerConfig('private_server', userId);
const allConfigsBefore = await registry.getAllServerConfigs(userId);
expect(appConfigBefore).toEqual(testParsedConfig);
expect(userConfigBefore).toEqual(testParsedConfig);
expect(privateConfigBefore).toEqual(testParsedConfig);
expect(Object.keys(allConfigsBefore)).toHaveLength(3);
// Reset everything
await registry.reset();
// Verify all servers are cleared after reset
const appConfigAfter = await registry.getServerConfig('app_server');
const userConfigAfter = await registry.getServerConfig('user_server');
const privateConfigAfter = await registry.getServerConfig('private_server', userId);
const allConfigsAfter = await registry.getAllServerConfigs(userId);
expect(appConfigAfter).toBeUndefined();
expect(userConfigAfter).toBeUndefined();
expect(privateConfigAfter).toBeUndefined();
expect(Object.keys(allConfigsAfter)).toHaveLength(0);
});
});
});

View File

@@ -0,0 +1,55 @@
import type { MCPConnection } from '~/mcp/connection';
/**
* Creates a single mock MCP connection for testing.
* The connection has a client with mocked methods that return server-specific data.
* @param serverName - Name of the server to create mock connection for
* @returns Mocked MCPConnection instance
*/
export function createMockConnection(serverName: string): jest.Mocked<MCPConnection> {
const mockClient = {
getInstructions: jest.fn().mockReturnValue(`instructions for ${serverName}`),
getServerCapabilities: jest.fn().mockReturnValue({
tools: { listChanged: true },
resources: { listChanged: true },
prompts: { get: `getPrompts for ${serverName}` },
}),
listTools: jest.fn().mockResolvedValue({
tools: [
{
name: 'listFiles',
description: `Description for ${serverName}'s listFiles tool`,
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' },
},
},
},
],
}),
};
return {
client: mockClient,
disconnect: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<MCPConnection>;
}
/**
* Creates mock MCP connections for testing.
* Each connection has a client with mocked methods that return server-specific data.
* @param serverNames - Array of server names to create mock connections for
* @returns Map of server names to mocked MCPConnection instances
*/
export function createMockConnectionsMap(
serverNames: string[],
): Map<string, jest.Mocked<MCPConnection>> {
const mockConnections = new Map<string, jest.Mocked<MCPConnection>>();
serverNames.forEach((serverName) => {
mockConnections.set(serverName, createMockConnection(serverName));
});
return mockConnections;
}

View File

@@ -0,0 +1,26 @@
import type Keyv from 'keyv';
import { isLeader } from '~/cluster';
/**
* Base class for MCP registry caches that require distributed leader coordination.
* Provides helper methods for leader-only operations and success validation.
* All concrete implementations must provide their own Keyv cache instance.
*/
export abstract class BaseRegistryCache {
protected readonly PREFIX = 'MCP::ServersRegistry';
protected abstract readonly cache: Keyv;
protected async leaderCheck(action: string): Promise<void> {
if (!(await isLeader())) throw new Error(`Only leader can ${action}.`);
}
protected successCheck(action: string, success: boolean): true {
if (!success) throw new Error(`Failed to ${action} in cache.`);
return true;
}
public async reset(): Promise<void> {
await this.leaderCheck(`reset ${this.cache.namespace} cache`);
await this.cache.clear();
}
}

View File

@@ -0,0 +1,37 @@
import { standardCache } from '~/cache';
import { BaseRegistryCache } from './BaseRegistryCache';
// Status keys
const INITIALIZED = 'INITIALIZED';
/**
* Cache for tracking MCP Servers Registry metadata and status across distributed instances.
* Uses Redis-backed storage to coordinate state between leader and follower nodes.
* Currently, tracks initialization status to ensure only the leader performs initialization
* while followers wait for completion. Designed to be extended with additional registry
* metadata as needed (e.g., last update timestamps, version info, health status).
* This cache is only meant to be used internally by registry management components.
*/
class RegistryStatusCache extends BaseRegistryCache {
protected readonly cache = standardCache(`${this.PREFIX}::Status`);
public async isInitialized(): Promise<boolean> {
return (await this.get(INITIALIZED)) === true;
}
public async setInitialized(value: boolean): Promise<void> {
await this.set(INITIALIZED, value);
}
private async get<T = unknown>(key: string): Promise<T | undefined> {
return this.cache.get(key);
}
private async set(key: string, value: string | number | boolean, ttl?: number): Promise<void> {
await this.leaderCheck('set MCP Servers Registry status');
const success = await this.cache.set(key, value, ttl);
this.successCheck(`set status key "${key}"`, success);
}
}
export const registryStatusCache = new RegistryStatusCache();

View File

@@ -0,0 +1,31 @@
import { cacheConfig } from '~/cache';
import { ServerConfigsCacheInMemory } from './ServerConfigsCacheInMemory';
import { ServerConfigsCacheRedis } from './ServerConfigsCacheRedis';
export type ServerConfigsCache = ServerConfigsCacheInMemory | ServerConfigsCacheRedis;
/**
* Factory for creating the appropriate ServerConfigsCache implementation based on deployment mode.
* Automatically selects between in-memory and Redis-backed storage depending on USE_REDIS config.
* In single-instance mode (USE_REDIS=false), returns lightweight in-memory cache.
* In cluster mode (USE_REDIS=true), returns Redis-backed cache with distributed coordination.
* Provides a unified interface regardless of the underlying storage mechanism.
*/
export class ServerConfigsCacheFactory {
/**
* Create a ServerConfigsCache instance.
* Returns Redis implementation if Redis is configured, otherwise in-memory implementation.
*
* @param owner - The owner of the cache (e.g., 'user', 'global') - only used for Redis namespacing
* @param leaderOnly - Whether operations should only be performed by the leader (only applies to Redis)
* @returns ServerConfigsCache instance
*/
static create(owner: string, leaderOnly: boolean): ServerConfigsCache {
if (cacheConfig.USE_REDIS) {
return new ServerConfigsCacheRedis(owner, leaderOnly);
}
// In-memory mode uses a simple Map - doesn't need owner/namespace
return new ServerConfigsCacheInMemory();
}
}

View File

@@ -0,0 +1,46 @@
import { ParsedServerConfig } from '~/mcp/types';
/**
* In-memory implementation of MCP server configurations cache for single-instance deployments.
* Uses a native JavaScript Map for fast, local storage without Redis dependencies.
* Suitable for development environments or single-server production deployments.
* Does not require leader checks or distributed coordination since data is instance-local.
* Data is lost on server restart and not shared across multiple server instances.
*/
export class ServerConfigsCacheInMemory {
private readonly cache: Map<string, ParsedServerConfig> = new Map();
public async add(serverName: string, config: ParsedServerConfig): Promise<void> {
if (this.cache.has(serverName))
throw new Error(
`Server "${serverName}" already exists in cache. Use update() to modify existing configs.`,
);
this.cache.set(serverName, config);
}
public async update(serverName: string, config: ParsedServerConfig): Promise<void> {
if (!this.cache.has(serverName))
throw new Error(
`Server "${serverName}" does not exist in cache. Use add() to create new configs.`,
);
this.cache.set(serverName, config);
}
public async remove(serverName: string): Promise<void> {
if (!this.cache.delete(serverName)) {
throw new Error(`Failed to remove server "${serverName}" in cache.`);
}
}
public async get(serverName: string): Promise<ParsedServerConfig | undefined> {
return this.cache.get(serverName);
}
public async getAll(): Promise<Record<string, ParsedServerConfig>> {
return Object.fromEntries(this.cache);
}
public async reset(): Promise<void> {
this.cache.clear();
}
}

View File

@@ -0,0 +1,80 @@
import type Keyv from 'keyv';
import { fromPairs } from 'lodash';
import { standardCache, keyvRedisClient } from '~/cache';
import { ParsedServerConfig } from '~/mcp/types';
import { BaseRegistryCache } from './BaseRegistryCache';
/**
* Redis-backed implementation of MCP server configurations cache for distributed deployments.
* Stores server configs in Redis with namespace isolation by owner (App, User, or specific user ID).
* Enables data sharing across multiple server instances in a cluster environment.
* Supports optional leader-only write operations to prevent race conditions during initialization.
* Data persists across server restarts and is accessible from any instance in the cluster.
*/
export class ServerConfigsCacheRedis extends BaseRegistryCache {
protected readonly cache: Keyv;
private readonly owner: string;
private readonly leaderOnly: boolean;
constructor(owner: string, leaderOnly: boolean) {
super();
this.owner = owner;
this.leaderOnly = leaderOnly;
this.cache = standardCache(`${this.PREFIX}::Servers::${owner}`);
}
public async add(serverName: string, config: ParsedServerConfig): Promise<void> {
if (this.leaderOnly) await this.leaderCheck(`add ${this.owner} MCP servers`);
const exists = await this.cache.has(serverName);
if (exists)
throw new Error(
`Server "${serverName}" already exists in cache. Use update() to modify existing configs.`,
);
const success = await this.cache.set(serverName, config);
this.successCheck(`add ${this.owner} server "${serverName}"`, success);
}
public async update(serverName: string, config: ParsedServerConfig): Promise<void> {
if (this.leaderOnly) await this.leaderCheck(`update ${this.owner} MCP servers`);
const exists = await this.cache.has(serverName);
if (!exists)
throw new Error(
`Server "${serverName}" does not exist in cache. Use add() to create new configs.`,
);
const success = await this.cache.set(serverName, config);
this.successCheck(`update ${this.owner} server "${serverName}"`, success);
}
public async remove(serverName: string): Promise<void> {
if (this.leaderOnly) await this.leaderCheck(`remove ${this.owner} MCP servers`);
const success = await this.cache.delete(serverName);
this.successCheck(`remove ${this.owner} server "${serverName}"`, success);
}
public async get(serverName: string): Promise<ParsedServerConfig | undefined> {
return this.cache.get(serverName);
}
public async getAll(): Promise<Record<string, ParsedServerConfig>> {
// Use Redis SCAN iterator directly (non-blocking, production-ready)
// Note: Keyv uses a single colon ':' between namespace and key, even if GLOBAL_PREFIX_SEPARATOR is '::'
const pattern = `*${this.cache.namespace}:*`;
const entries: Array<[string, ParsedServerConfig]> = [];
// Use scanIterator from Redis client
if (keyvRedisClient && 'scanIterator' in keyvRedisClient) {
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
// Extract the actual key name (last part after final colon)
// Full key format: "prefix::namespace:keyName"
const lastColonIndex = key.lastIndexOf(':');
const keyName = key.substring(lastColonIndex + 1);
const value = await this.cache.get(keyName);
if (value) {
entries.push([keyName, value as ParsedServerConfig]);
}
}
}
return fromPairs(entries);
}
}

View File

@@ -0,0 +1,73 @@
import { expect } from '@playwright/test';
describe('RegistryStatusCache Integration Tests', () => {
let registryStatusCache: typeof import('../RegistryStatusCache').registryStatusCache;
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
let leaderInstance: InstanceType<typeof import('~/cluster/LeaderElection').LeaderElection>;
beforeAll(async () => {
// Set up environment variables for Redis (only if not already set)
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX =
process.env.REDIS_KEY_PREFIX ?? 'RegistryStatusCache-IntegrationTest';
// Import modules after setting env vars
const statusCacheModule = await import('../RegistryStatusCache');
const redisClients = await import('~/cache/redisClients');
const leaderElectionModule = await import('~/cluster/LeaderElection');
registryStatusCache = statusCacheModule.registryStatusCache;
keyvRedisClient = redisClients.keyvRedisClient;
LeaderElection = leaderElectionModule.LeaderElection;
// Ensure Redis is connected
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
// Wait for Redis to be ready
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
// Become leader so we can perform write operations
leaderInstance = new LeaderElection();
const isLeader = await leaderInstance.isLeader();
expect(isLeader).toBe(true);
});
afterEach(async () => {
// Clean up: clear all test keys from Redis
if (keyvRedisClient) {
const pattern = '*RegistryStatusCache-IntegrationTest*';
if ('scanIterator' in keyvRedisClient) {
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
await keyvRedisClient.del(key);
}
}
}
});
afterAll(async () => {
// Resign as leader
if (leaderInstance) await leaderInstance.resign();
// Close Redis connection
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
});
describe('Initialization status tracking', () => {
it('should return false for isInitialized when not set', async () => {
const initialized = await registryStatusCache.isInitialized();
expect(initialized).toBe(false);
});
it('should set and get initialized status', async () => {
await registryStatusCache.setInitialized(true);
const initialized = await registryStatusCache.isInitialized();
expect(initialized).toBe(true);
await registryStatusCache.setInitialized(false);
const uninitialized = await registryStatusCache.isInitialized();
expect(uninitialized).toBe(false);
});
});
});

View File

@@ -0,0 +1,70 @@
import { ServerConfigsCacheFactory } from '../ServerConfigsCacheFactory';
import { ServerConfigsCacheInMemory } from '../ServerConfigsCacheInMemory';
import { ServerConfigsCacheRedis } from '../ServerConfigsCacheRedis';
import { cacheConfig } from '~/cache';
// Mock the cache implementations
jest.mock('../ServerConfigsCacheInMemory');
jest.mock('../ServerConfigsCacheRedis');
// Mock the cache config module
jest.mock('~/cache', () => ({
cacheConfig: {
USE_REDIS: false,
},
}));
describe('ServerConfigsCacheFactory', () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe('create()', () => {
it('should return ServerConfigsCacheRedis when USE_REDIS is true', () => {
// Arrange
cacheConfig.USE_REDIS = true;
// Act
const cache = ServerConfigsCacheFactory.create('TestOwner', true);
// Assert
expect(cache).toBeInstanceOf(ServerConfigsCacheRedis);
expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('TestOwner', true);
});
it('should return ServerConfigsCacheInMemory when USE_REDIS is false', () => {
// Arrange
cacheConfig.USE_REDIS = false;
// Act
const cache = ServerConfigsCacheFactory.create('TestOwner', false);
// Assert
expect(cache).toBeInstanceOf(ServerConfigsCacheInMemory);
expect(ServerConfigsCacheInMemory).toHaveBeenCalled();
});
it('should pass correct parameters to ServerConfigsCacheRedis', () => {
// Arrange
cacheConfig.USE_REDIS = true;
// Act
ServerConfigsCacheFactory.create('App', true);
// Assert
expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('App', true);
});
it('should create ServerConfigsCacheInMemory without parameters when USE_REDIS is false', () => {
// Arrange
cacheConfig.USE_REDIS = false;
// Act
ServerConfigsCacheFactory.create('User', false);
// Assert
// In-memory cache doesn't use owner/leaderOnly parameters
expect(ServerConfigsCacheInMemory).toHaveBeenCalledWith();
});
});
});

View File

@@ -0,0 +1,173 @@
import { expect } from '@playwright/test';
import { ParsedServerConfig } from '~/mcp/types';
describe('ServerConfigsCacheInMemory Integration Tests', () => {
let ServerConfigsCacheInMemory: typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory;
let cache: InstanceType<
typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory
>;
// Test data
const mockConfig1: ParsedServerConfig = {
command: 'node',
args: ['server1.js'],
env: { TEST: 'value1' },
};
const mockConfig2: ParsedServerConfig = {
command: 'python',
args: ['server2.py'],
env: { TEST: 'value2' },
};
const mockConfig3: ParsedServerConfig = {
command: 'node',
args: ['server3.js'],
url: 'http://localhost:3000',
requiresOAuth: true,
};
beforeAll(async () => {
// Import modules
const cacheModule = await import('../ServerConfigsCacheInMemory');
ServerConfigsCacheInMemory = cacheModule.ServerConfigsCacheInMemory;
});
beforeEach(() => {
// Create a fresh instance for each test
cache = new ServerConfigsCacheInMemory();
});
describe('add and get operations', () => {
it('should add and retrieve a server config', async () => {
await cache.add('server1', mockConfig1);
const result = await cache.get('server1');
expect(result).toEqual(mockConfig1);
});
it('should return undefined for non-existent server', async () => {
const result = await cache.get('non-existent');
expect(result).toBeUndefined();
});
it('should throw error when adding duplicate server', async () => {
await cache.add('server1', mockConfig1);
await expect(cache.add('server1', mockConfig2)).rejects.toThrow(
'Server "server1" already exists in cache. Use update() to modify existing configs.',
);
});
it('should handle multiple server configs', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
await cache.add('server3', mockConfig3);
const result1 = await cache.get('server1');
const result2 = await cache.get('server2');
const result3 = await cache.get('server3');
expect(result1).toEqual(mockConfig1);
expect(result2).toEqual(mockConfig2);
expect(result3).toEqual(mockConfig3);
});
});
describe('getAll operation', () => {
it('should return empty object when no servers exist', async () => {
const result = await cache.getAll();
expect(result).toEqual({});
});
it('should return all server configs', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
await cache.add('server3', mockConfig3);
const result = await cache.getAll();
expect(result).toEqual({
server1: mockConfig1,
server2: mockConfig2,
server3: mockConfig3,
});
});
it('should reflect updates in getAll', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
let result = await cache.getAll();
expect(Object.keys(result).length).toBe(2);
await cache.add('server3', mockConfig3);
result = await cache.getAll();
expect(Object.keys(result).length).toBe(3);
expect(result.server3).toEqual(mockConfig3);
});
});
describe('update operation', () => {
it('should update an existing server config', async () => {
await cache.add('server1', mockConfig1);
expect(await cache.get('server1')).toEqual(mockConfig1);
await cache.update('server1', mockConfig2);
const result = await cache.get('server1');
expect(result).toEqual(mockConfig2);
});
it('should throw error when updating non-existent server', async () => {
await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow(
'Server "non-existent" does not exist in cache. Use add() to create new configs.',
);
});
it('should reflect updates in getAll', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
await cache.update('server1', mockConfig3);
const result = await cache.getAll();
expect(result.server1).toEqual(mockConfig3);
expect(result.server2).toEqual(mockConfig2);
});
});
describe('remove operation', () => {
it('should remove an existing server config', async () => {
await cache.add('server1', mockConfig1);
expect(await cache.get('server1')).toEqual(mockConfig1);
await cache.remove('server1');
expect(await cache.get('server1')).toBeUndefined();
});
it('should throw error when removing non-existent server', async () => {
await expect(cache.remove('non-existent')).rejects.toThrow(
'Failed to remove server "non-existent" in cache.',
);
});
it('should remove server from getAll results', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
let result = await cache.getAll();
expect(Object.keys(result).length).toBe(2);
await cache.remove('server1');
result = await cache.getAll();
expect(Object.keys(result).length).toBe(1);
expect(result.server1).toBeUndefined();
expect(result.server2).toEqual(mockConfig2);
});
it('should allow re-adding a removed server', async () => {
await cache.add('server1', mockConfig1);
await cache.remove('server1');
await cache.add('server1', mockConfig3);
const result = await cache.get('server1');
expect(result).toEqual(mockConfig3);
});
});
});

View File

@@ -0,0 +1,278 @@
import { expect } from '@playwright/test';
import { ParsedServerConfig } from '~/mcp/types';
describe('ServerConfigsCacheRedis Integration Tests', () => {
let ServerConfigsCacheRedis: typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis;
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
let checkIsLeader: () => Promise<boolean>;
let cache: InstanceType<typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis>;
// Test data
const mockConfig1: ParsedServerConfig = {
command: 'node',
args: ['server1.js'],
env: { TEST: 'value1' },
};
const mockConfig2: ParsedServerConfig = {
command: 'python',
args: ['server2.py'],
env: { TEST: 'value2' },
};
const mockConfig3: ParsedServerConfig = {
command: 'node',
args: ['server3.js'],
url: 'http://localhost:3000',
requiresOAuth: true,
};
beforeAll(async () => {
// Set up environment variables for Redis (only if not already set)
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX =
process.env.REDIS_KEY_PREFIX ?? 'ServerConfigsCacheRedis-IntegrationTest';
// Import modules after setting env vars
const cacheModule = await import('../ServerConfigsCacheRedis');
const redisClients = await import('~/cache/redisClients');
const leaderElectionModule = await import('~/cluster/LeaderElection');
const clusterModule = await import('~/cluster');
ServerConfigsCacheRedis = cacheModule.ServerConfigsCacheRedis;
keyvRedisClient = redisClients.keyvRedisClient;
LeaderElection = leaderElectionModule.LeaderElection;
checkIsLeader = clusterModule.isLeader;
// Ensure Redis is connected
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
// Wait for Redis to be ready
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
// Clear any existing leader key to ensure clean state
await keyvRedisClient.del(LeaderElection.LEADER_KEY);
// Become leader so we can perform write operations (using default election instance)
const isLeader = await checkIsLeader();
expect(isLeader).toBe(true);
});
beforeEach(() => {
// Create a fresh instance for each test with leaderOnly=true
cache = new ServerConfigsCacheRedis('test-user', true);
});
afterEach(async () => {
// Clean up: clear all test keys from Redis
if (keyvRedisClient) {
const pattern = '*ServerConfigsCacheRedis-IntegrationTest*';
if ('scanIterator' in keyvRedisClient) {
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
await keyvRedisClient.del(key);
}
}
}
});
afterAll(async () => {
// Clear leader key to allow other tests to become leader
if (keyvRedisClient) await keyvRedisClient.del(LeaderElection.LEADER_KEY);
// Close Redis connection
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
});
describe('add and get operations', () => {
it('should add and retrieve a server config', async () => {
await cache.add('server1', mockConfig1);
const result = await cache.get('server1');
expect(result).toEqual(mockConfig1);
});
it('should return undefined for non-existent server', async () => {
const result = await cache.get('non-existent');
expect(result).toBeUndefined();
});
it('should throw error when adding duplicate server', async () => {
await cache.add('server1', mockConfig1);
await expect(cache.add('server1', mockConfig2)).rejects.toThrow(
'Server "server1" already exists in cache. Use update() to modify existing configs.',
);
});
it('should handle multiple server configs', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
await cache.add('server3', mockConfig3);
const result1 = await cache.get('server1');
const result2 = await cache.get('server2');
const result3 = await cache.get('server3');
expect(result1).toEqual(mockConfig1);
expect(result2).toEqual(mockConfig2);
expect(result3).toEqual(mockConfig3);
});
it('should isolate caches by owner namespace', async () => {
const userCache = new ServerConfigsCacheRedis('user1', true);
const globalCache = new ServerConfigsCacheRedis('global', true);
await userCache.add('server1', mockConfig1);
await globalCache.add('server1', mockConfig2);
const userResult = await userCache.get('server1');
const globalResult = await globalCache.get('server1');
expect(userResult).toEqual(mockConfig1);
expect(globalResult).toEqual(mockConfig2);
});
});
describe('getAll operation', () => {
it('should return empty object when no servers exist', async () => {
const result = await cache.getAll();
expect(result).toEqual({});
});
it('should return all server configs', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
await cache.add('server3', mockConfig3);
const result = await cache.getAll();
expect(result).toEqual({
server1: mockConfig1,
server2: mockConfig2,
server3: mockConfig3,
});
});
it('should reflect updates in getAll', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
let result = await cache.getAll();
expect(Object.keys(result).length).toBe(2);
await cache.add('server3', mockConfig3);
result = await cache.getAll();
expect(Object.keys(result).length).toBe(3);
expect(result.server3).toEqual(mockConfig3);
});
it('should only return configs for the specific owner', async () => {
const userCache = new ServerConfigsCacheRedis('user1', true);
const globalCache = new ServerConfigsCacheRedis('global', true);
await userCache.add('server1', mockConfig1);
await userCache.add('server2', mockConfig2);
await globalCache.add('server3', mockConfig3);
const userResult = await userCache.getAll();
const globalResult = await globalCache.getAll();
expect(Object.keys(userResult).length).toBe(2);
expect(Object.keys(globalResult).length).toBe(1);
expect(userResult.server1).toEqual(mockConfig1);
expect(userResult.server3).toBeUndefined();
expect(globalResult.server3).toEqual(mockConfig3);
});
});
describe('update operation', () => {
it('should update an existing server config', async () => {
await cache.add('server1', mockConfig1);
expect(await cache.get('server1')).toEqual(mockConfig1);
await cache.update('server1', mockConfig2);
const result = await cache.get('server1');
expect(result).toEqual(mockConfig2);
});
it('should throw error when updating non-existent server', async () => {
await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow(
'Server "non-existent" does not exist in cache. Use add() to create new configs.',
);
});
it('should reflect updates in getAll', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
await cache.update('server1', mockConfig3);
const result = await cache.getAll();
expect(result.server1).toEqual(mockConfig3);
expect(result.server2).toEqual(mockConfig2);
});
it('should only update in the specific owner namespace', async () => {
const userCache = new ServerConfigsCacheRedis('user1', true);
const globalCache = new ServerConfigsCacheRedis('global', true);
await userCache.add('server1', mockConfig1);
await globalCache.add('server1', mockConfig2);
await userCache.update('server1', mockConfig3);
expect(await userCache.get('server1')).toEqual(mockConfig3);
expect(await globalCache.get('server1')).toEqual(mockConfig2);
});
});
describe('remove operation', () => {
it('should remove an existing server config', async () => {
await cache.add('server1', mockConfig1);
expect(await cache.get('server1')).toEqual(mockConfig1);
await cache.remove('server1');
expect(await cache.get('server1')).toBeUndefined();
});
it('should throw error when removing non-existent server', async () => {
await expect(cache.remove('non-existent')).rejects.toThrow(
'Failed to remove test-user server "non-existent"',
);
});
it('should remove server from getAll results', async () => {
await cache.add('server1', mockConfig1);
await cache.add('server2', mockConfig2);
let result = await cache.getAll();
expect(Object.keys(result).length).toBe(2);
await cache.remove('server1');
result = await cache.getAll();
expect(Object.keys(result).length).toBe(1);
expect(result.server1).toBeUndefined();
expect(result.server2).toEqual(mockConfig2);
});
it('should allow re-adding a removed server', async () => {
await cache.add('server1', mockConfig1);
await cache.remove('server1');
await cache.add('server1', mockConfig3);
const result = await cache.get('server1');
expect(result).toEqual(mockConfig3);
});
it('should only remove from the specific owner namespace', async () => {
const userCache = new ServerConfigsCacheRedis('user1', true);
const globalCache = new ServerConfigsCacheRedis('global', true);
await userCache.add('server1', mockConfig1);
await globalCache.add('server1', mockConfig2);
await userCache.remove('server1');
expect(await userCache.get('server1')).toBeUndefined();
expect(await globalCache.get('server1')).toEqual(mockConfig2);
});
});
});

View File

@@ -151,6 +151,8 @@ export type ParsedServerConfig = MCPOptions & {
oauthMetadata?: Record<string, unknown> | null;
capabilities?: string;
tools?: string;
toolFunctions?: LCAvailableTools;
initDuration?: number;
};
export interface BasicConnectionOptions {

View File

@@ -25,6 +25,12 @@ export const genAzureEndpoint = ({
azureOpenAIApiInstanceName: string;
azureOpenAIApiDeploymentName: string;
}): string => {
// Support both old (.openai.azure.com) and new (.cognitiveservices.azure.com) endpoint formats
// If instanceName already includes a full domain, use it as-is
if (azureOpenAIApiInstanceName.includes('.azure.com')) {
return `https://${azureOpenAIApiInstanceName}/openai/deployments/${azureOpenAIApiDeploymentName}`;
}
// Legacy format for backward compatibility
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}`;
};

View File

@@ -10,6 +10,7 @@ export * from './key';
export * from './llm';
export * from './math';
export * from './openid';
export * from './promise';
export * from './sanitizeTitle';
export * from './tempChatRetention';
export * from './text';

View File

@@ -0,0 +1,115 @@
import { withTimeout } from './promise';
describe('withTimeout', () => {
beforeEach(() => {
jest.clearAllTimers();
});
it('should resolve when promise completes before timeout', async () => {
const promise = Promise.resolve('success');
const result = await withTimeout(promise, 1000);
expect(result).toBe('success');
});
it('should reject when promise rejects before timeout', async () => {
const promise = Promise.reject(new Error('test error'));
await expect(withTimeout(promise, 1000)).rejects.toThrow('test error');
});
it('should timeout when promise takes too long', async () => {
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
await expect(withTimeout(promise, 100, 'Custom timeout message')).rejects.toThrow(
'Custom timeout message',
);
});
it('should use default error message when none provided', async () => {
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
await expect(withTimeout(promise, 100)).rejects.toThrow('Operation timed out after 100ms');
});
it('should clear timeout when promise resolves', async () => {
const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout');
const promise = Promise.resolve('fast');
await withTimeout(promise, 1000);
expect(clearTimeoutSpy).toHaveBeenCalled();
clearTimeoutSpy.mockRestore();
});
it('should clear timeout when promise rejects', async () => {
const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout');
const promise = Promise.reject(new Error('fail'));
await expect(withTimeout(promise, 1000)).rejects.toThrow('fail');
expect(clearTimeoutSpy).toHaveBeenCalled();
clearTimeoutSpy.mockRestore();
});
it('should handle multiple concurrent timeouts', async () => {
const promise1 = Promise.resolve('first');
const promise2 = new Promise((resolve) => setTimeout(() => resolve('second'), 50));
const promise3 = new Promise((resolve) => setTimeout(() => resolve('third'), 2000));
const [result1, result2] = await Promise.all([
withTimeout(promise1, 1000),
withTimeout(promise2, 1000),
]);
expect(result1).toBe('first');
expect(result2).toBe('second');
await expect(withTimeout(promise3, 100)).rejects.toThrow('Operation timed out after 100ms');
});
it('should work with async functions', async () => {
const asyncFunction = async () => {
await new Promise((resolve) => setTimeout(resolve, 10));
return 'async result';
};
const result = await withTimeout(asyncFunction(), 1000);
expect(result).toBe('async result');
});
it('should work with any return type', async () => {
const numberPromise = Promise.resolve(42);
const objectPromise = Promise.resolve({ key: 'value' });
const arrayPromise = Promise.resolve([1, 2, 3]);
expect(await withTimeout(numberPromise, 1000)).toBe(42);
expect(await withTimeout(objectPromise, 1000)).toEqual({ key: 'value' });
expect(await withTimeout(arrayPromise, 1000)).toEqual([1, 2, 3]);
});
it('should call logger when timeout occurs', async () => {
const loggerMock = jest.fn();
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
const errorMessage = 'Custom timeout with logger';
await expect(withTimeout(promise, 100, errorMessage, loggerMock)).rejects.toThrow(errorMessage);
expect(loggerMock).toHaveBeenCalledTimes(1);
expect(loggerMock).toHaveBeenCalledWith(errorMessage, expect.any(Error));
});
it('should not call logger when promise resolves', async () => {
const loggerMock = jest.fn();
const promise = Promise.resolve('success');
const result = await withTimeout(promise, 1000, 'Should not timeout', loggerMock);
expect(result).toBe('success');
expect(loggerMock).not.toHaveBeenCalled();
});
it('should work without logger parameter', async () => {
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
await expect(withTimeout(promise, 100, 'No logger provided')).rejects.toThrow(
'No logger provided',
);
});
});

View File

@@ -0,0 +1,42 @@
/**
* Wraps a promise with a timeout. If the promise doesn't resolve/reject within
* the specified time, it will be rejected with a timeout error.
*
* @param promise - The promise to wrap with a timeout
* @param timeoutMs - Timeout duration in milliseconds
* @param errorMessage - Custom error message for timeout (optional)
* @param logger - Optional logger function to log timeout errors (e.g., console.warn, logger.warn)
* @returns Promise that resolves/rejects with the original promise or times out
*
* @example
* ```typescript
* const result = await withTimeout(
* fetchData(),
* 5000,
* 'Failed to fetch data within 5 seconds',
* console.warn
* );
* ```
*/
export async function withTimeout<T>(
promise: Promise<T>,
timeoutMs: number,
errorMessage?: string,
logger?: (message: string, error: Error) => void,
): Promise<T> {
let timeoutId: NodeJS.Timeout;
const timeoutPromise = new Promise<never>((_, reject) => {
timeoutId = setTimeout(() => {
const error = new Error(errorMessage ?? `Operation timed out after ${timeoutMs}ms`);
if (logger) logger(error.message, error);
reject(error);
}, timeoutMs);
});
try {
return await Promise.race([promise, timeoutPromise]);
} finally {
clearTimeout(timeoutId!);
}
}