Compare commits
5 Commits
docs/azure
...
fix/avatar
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9ef26ddd5 | ||
|
|
c63f2a634c | ||
|
|
39d83b705b | ||
|
|
e5a5931818 | ||
|
|
41380d9cb9 |
@@ -254,10 +254,6 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
|
||||
|
||||
# OpenAI Image Tools Customization
|
||||
#----------------
|
||||
# IMAGE_GEN_OAI_API_KEY= # Create or reuse OpenAI API key for image generation tool
|
||||
# IMAGE_GEN_OAI_BASEURL= # Custom OpenAI base URL for image generation tool
|
||||
# IMAGE_GEN_OAI_AZURE_API_VERSION= # Custom Azure OpenAI deployments
|
||||
# IMAGE_GEN_OAI_DESCRIPTION=
|
||||
# IMAGE_GEN_OAI_DESCRIPTION_WITH_FILES=Custom description for image generation tool when files are present
|
||||
# IMAGE_GEN_OAI_DESCRIPTION_NO_FILES=Custom description for image generation tool when no files are present
|
||||
# IMAGE_EDIT_OAI_DESCRIPTION=Custom description for image editing tool
|
||||
|
||||
@@ -9,7 +9,6 @@ on:
|
||||
paths:
|
||||
- 'packages/api/src/cache/**'
|
||||
- 'packages/api/src/cluster/**'
|
||||
- 'packages/api/src/mcp/**'
|
||||
- 'redis-config/**'
|
||||
- '.github/workflows/cache-integration-tests.yml'
|
||||
|
||||
@@ -78,14 +77,6 @@ jobs:
|
||||
REDIS_URI: redis://127.0.0.1:6379
|
||||
run: npm run test:cache-integration:cluster
|
||||
|
||||
- name: Run mcp integration tests
|
||||
working-directory: packages/api
|
||||
env:
|
||||
NODE_ENV: test
|
||||
USE_REDIS: true
|
||||
REDIS_URI: redis://127.0.0.1:6379
|
||||
run: npm run test:cache-integration:mcp
|
||||
|
||||
- name: Stop Redis Cluster
|
||||
if: always()
|
||||
working-directory: redis-config
|
||||
|
||||
@@ -28,7 +28,6 @@ const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
@@ -199,7 +198,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
// If auth was updated successfully, disconnect MCP sessions as they might use these credentials
|
||||
if (pluginKey.startsWith(Constants.mcp_prefix)) {
|
||||
try {
|
||||
const mcpManager = getMCPManager();
|
||||
const mcpManager = getMCPManager(user.id);
|
||||
if (mcpManager) {
|
||||
// Extract server name from pluginKey (format: "mcp_<serverName>")
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
@@ -296,11 +295,10 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||
}
|
||||
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
const serverConfig =
|
||||
(await mcpServersRegistry.getServerConfig(serverName, userId)) ??
|
||||
appConfig?.mcpServers?.[serverName];
|
||||
const oauthServers = await mcpServersRegistry.getOAuthServers();
|
||||
if (!oauthServers.has(serverName)) {
|
||||
const mcpManager = getMCPManager(userId);
|
||||
const serverConfig = mcpManager.getRawConfig(serverName) ?? appConfig?.mcpServers?.[serverName];
|
||||
|
||||
if (!mcpManager.getOAuthServers().has(serverName)) {
|
||||
// this server does not use OAuth, so nothing to do here as well
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ const {
|
||||
getAppConfig,
|
||||
} = require('~/server/services/Config');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
@@ -66,7 +65,7 @@ const getMCPTools = async (req, res) => {
|
||||
|
||||
// Get server config once
|
||||
const serverConfig = appConfig.mcpConfig[serverName];
|
||||
const rawServerConfig = await mcpServersRegistry.getServerConfig(serverName, userId);
|
||||
const rawServerConfig = mcpManager.getRawConfig(serverName);
|
||||
|
||||
// Initialize server object with all server-level data
|
||||
const server = {
|
||||
|
||||
@@ -15,10 +15,6 @@ jest.mock('@librechat/api', () => ({
|
||||
storeTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
mcpServersRegistry: {
|
||||
getServerConfig: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
@@ -119,7 +115,7 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
describe('GET /:serverName/oauth/initiate', () => {
|
||||
const { MCPOAuthHandler, mcpServersRegistry } = require('@librechat/api');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
it('should initiate OAuth flow successfully', async () => {
|
||||
@@ -132,9 +128,13 @@ describe('MCP Routes', () => {
|
||||
}),
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({});
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({
|
||||
authorizationUrl: 'https://oauth.example.com/auth',
|
||||
@@ -288,7 +288,6 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle OAuth callback successfully', async () => {
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
@@ -308,7 +307,6 @@ describe('MCP Routes', () => {
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({});
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
@@ -323,6 +321,7 @@ describe('MCP Routes', () => {
|
||||
};
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
@@ -380,7 +379,6 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle system-level OAuth completion', async () => {
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
@@ -400,10 +398,14 @@ describe('MCP Routes', () => {
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({});
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
@@ -415,7 +417,6 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle reconnection failure after OAuth', async () => {
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
@@ -435,12 +436,12 @@ describe('MCP Routes', () => {
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({});
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockRejectedValue(new Error('Reconnection failed')),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
@@ -460,7 +461,6 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should redirect to error page if token storage fails', async () => {
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
@@ -480,7 +480,6 @@ describe('MCP Routes', () => {
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed'));
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({});
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
@@ -731,14 +730,12 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
describe('POST /:serverName/reinitialize', () => {
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
|
||||
it('should return 404 when server is not found in configuration', async () => {
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue(null),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue(null);
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
@@ -753,6 +750,9 @@ describe('MCP Routes', () => {
|
||||
|
||||
it('should handle OAuth requirement during reinitialize', async () => {
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {},
|
||||
}),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => {
|
||||
@@ -763,9 +763,6 @@ describe('MCP Routes', () => {
|
||||
}),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({
|
||||
customUserVars: {},
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
@@ -791,12 +788,12 @@ describe('MCP Routes', () => {
|
||||
|
||||
it('should return 500 when reinitialize fails with non-OAuth error', async () => {
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({});
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
@@ -812,12 +809,11 @@ describe('MCP Routes', () => {
|
||||
|
||||
it('should return 500 when unexpected error occurs', async () => {
|
||||
const mockMcpManager = {
|
||||
disconnectUserConnection: jest.fn(),
|
||||
getRawConfig: jest.fn().mockImplementation(() => {
|
||||
throw new Error('Config loading failed');
|
||||
}),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockImplementation(() => {
|
||||
throw new Error('Config loading failed');
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||
@@ -850,11 +846,11 @@ describe('MCP Routes', () => {
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({ endpoint: 'http://test-server.com' });
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
@@ -895,16 +891,16 @@ describe('MCP Routes', () => {
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
endpoint: 'http://test-server.com',
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
}),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({
|
||||
endpoint: 'http://test-server.com',
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
@@ -1109,17 +1105,17 @@ describe('MCP Routes', () => {
|
||||
|
||||
describe('GET /:serverName/auth-values', () => {
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
|
||||
it('should return auth value flags for server', async () => {
|
||||
const mockMcpManager = {};
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
SECRET_TOKEN: 'another-env-var',
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
SECRET_TOKEN: 'another-env-var',
|
||||
},
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce('');
|
||||
|
||||
@@ -1139,9 +1135,10 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 404 when server is not found in configuration', async () => {
|
||||
const mockMcpManager = {};
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue(null),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue(null);
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/non-existent-server/auth-values');
|
||||
@@ -1153,13 +1150,14 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle errors when checking auth values', async () => {
|
||||
const mockMcpManager = {};
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
getUserPluginAuthValue.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
@@ -1176,11 +1174,12 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 500 when auth values check throws unexpected error', async () => {
|
||||
const mockMcpManager = {};
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockImplementation(() => {
|
||||
throw new Error('Config loading failed');
|
||||
}),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockImplementation(() => {
|
||||
throw new Error('Config loading failed');
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
@@ -1190,11 +1189,12 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle customUserVars that is not an object', async () => {
|
||||
const mockMcpManager = {};
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: 'not-an-object',
|
||||
}),
|
||||
};
|
||||
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({
|
||||
customUserVars: 'not-an-object',
|
||||
});
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
@@ -1221,7 +1221,7 @@ describe('MCP Routes', () => {
|
||||
|
||||
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
|
||||
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
|
||||
const { MCPOAuthHandler, MCPTokenStorage, mcpServersRegistry } = require('@librechat/api');
|
||||
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
|
||||
const mockTokens = {
|
||||
access_token: 'edge-access-token',
|
||||
refresh_token: 'edge-refresh-token',
|
||||
@@ -1239,7 +1239,6 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({});
|
||||
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn(),
|
||||
@@ -1250,6 +1249,7 @@ describe('MCP Routes', () => {
|
||||
getUserConnection: jest.fn().mockResolvedValue({
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
@@ -1264,7 +1264,7 @@ describe('MCP Routes', () => {
|
||||
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
const { MCPOAuthHandler, MCPTokenStorage, mcpServersRegistry } = require('@librechat/api');
|
||||
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
|
||||
const mockTokens = {
|
||||
access_token: 'edge-access-token',
|
||||
refresh_token: 'edge-refresh-token',
|
||||
@@ -1290,7 +1290,6 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
mcpServersRegistry.getServerConfig.mockResolvedValue({});
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockResolvedValue({
|
||||
@@ -1298,6 +1297,7 @@ describe('MCP Routes', () => {
|
||||
.fn()
|
||||
.mockResolvedValue([{ name: 'test-tool', description: 'Test tool' }]),
|
||||
}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ const { getAppConfig } = require('~/server/services/Config/app');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
|
||||
const router = express.Router();
|
||||
const emailLoginEnabled =
|
||||
@@ -126,7 +125,7 @@ router.get('/', async function (req, res) {
|
||||
payload.minPasswordLength = minPasswordLength;
|
||||
}
|
||||
|
||||
const getMCPServers = async () => {
|
||||
const getMCPServers = () => {
|
||||
try {
|
||||
if (appConfig?.mcpConfig == null) {
|
||||
return;
|
||||
@@ -135,8 +134,9 @@ router.get('/', async function (req, res) {
|
||||
if (!mcpManager) {
|
||||
return;
|
||||
}
|
||||
const mcpServers = await mcpServersRegistry.getAllServerConfigs();
|
||||
const mcpServers = mcpManager.getAllServers();
|
||||
if (!mcpServers) return;
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
for (const serverName in mcpServers) {
|
||||
if (!payload.mcpServers) {
|
||||
payload.mcpServers = {};
|
||||
@@ -145,7 +145,7 @@ router.get('/', async function (req, res) {
|
||||
payload.mcpServers[serverName] = removeNullishValues({
|
||||
startup: serverConfig?.startup,
|
||||
chatMenu: serverConfig?.chatMenu,
|
||||
isOAuth: serverConfig.requiresOAuth,
|
||||
isOAuth: oauthServers?.has(serverName),
|
||||
customUserVars: serverConfig?.customUserVars,
|
||||
});
|
||||
}
|
||||
@@ -154,7 +154,7 @@ router.get('/', async function (req, res) {
|
||||
}
|
||||
};
|
||||
|
||||
await getMCPServers();
|
||||
getMCPServers();
|
||||
const webSearchConfig = appConfig?.webSearch;
|
||||
if (
|
||||
webSearchConfig != null &&
|
||||
|
||||
@@ -6,7 +6,6 @@ const {
|
||||
MCPOAuthHandler,
|
||||
MCPTokenStorage,
|
||||
getUserMCPAuthMap,
|
||||
mcpServersRegistry,
|
||||
} = require('@librechat/api');
|
||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
@@ -62,12 +61,11 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => {
|
||||
return res.status(400).json({ error: 'Invalid flow state' });
|
||||
}
|
||||
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, userId);
|
||||
const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
serverName,
|
||||
serverUrl,
|
||||
userId,
|
||||
oauthHeaders,
|
||||
getOAuthHeaders(serverName),
|
||||
oauthConfig,
|
||||
);
|
||||
|
||||
@@ -135,8 +133,12 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
});
|
||||
|
||||
logger.debug('[MCP OAuth] Completing OAuth flow');
|
||||
const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId);
|
||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders);
|
||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(
|
||||
flowId,
|
||||
code,
|
||||
flowManager,
|
||||
getOAuthHeaders(serverName),
|
||||
);
|
||||
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
||||
|
||||
/** Persist tokens immediately so reconnection uses fresh credentials */
|
||||
@@ -354,7 +356,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
|
||||
|
||||
const mcpManager = getMCPManager();
|
||||
const serverConfig = await mcpServersRegistry.getServerConfig(serverName, user.id);
|
||||
const serverConfig = mcpManager.getRawConfig(serverName);
|
||||
if (!serverConfig) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
@@ -503,7 +505,8 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
const serverConfig = await mcpServersRegistry.getServerConfig(serverName, user.id);
|
||||
const mcpManager = getMCPManager();
|
||||
const serverConfig = mcpManager.getRawConfig(serverName);
|
||||
if (!serverConfig) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
@@ -542,8 +545,9 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
async function getOAuthHeaders(serverName, userId) {
|
||||
const serverConfig = await mcpServersRegistry.getServerConfig(serverName, userId);
|
||||
function getOAuthHeaders(serverName) {
|
||||
const mcpManager = getMCPManager();
|
||||
const serverConfig = mcpManager.getRawConfig(serverName);
|
||||
return serverConfig?.oauth_headers ?? {};
|
||||
}
|
||||
|
||||
|
||||
@@ -227,6 +227,7 @@ class STTService {
|
||||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
...(apiKey && { 'api-key': apiKey }),
|
||||
};
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getAppConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { mcpServersRegistry } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
@@ -451,7 +450,7 @@ async function getMCPSetupData(userId) {
|
||||
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
|
||||
}
|
||||
const userConnections = mcpManager.getUserConnections(userId) || new Map();
|
||||
const oauthServers = await mcpServersRegistry.getOAuthServers();
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
|
||||
return {
|
||||
mcpConfig,
|
||||
|
||||
@@ -50,9 +50,6 @@ jest.mock('@librechat/api', () => ({
|
||||
sendEvent: jest.fn(),
|
||||
normalizeServerName: jest.fn((name) => name),
|
||||
convertWithResolvedRefs: jest.fn((params) => params),
|
||||
mcpServersRegistry: {
|
||||
getOAuthServers: jest.fn(() => Promise.resolve(new Set())),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
@@ -103,7 +100,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
let mockGetFlowStateManager;
|
||||
let mockGetLogStores;
|
||||
let mockGetOAuthReconnectionManager;
|
||||
let mockMcpServersRegistry;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
@@ -112,7 +108,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||
mockGetLogStores = require('~/cache').getLogStores;
|
||||
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
|
||||
mockMcpServersRegistry = require('@librechat/api').mcpServersRegistry;
|
||||
});
|
||||
|
||||
describe('getMCPSetupData', () => {
|
||||
@@ -130,8 +125,8 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
mockGetMCPManager.mockReturnValue({
|
||||
appConnections: { getAll: jest.fn(() => new Map()) },
|
||||
getUserConnections: jest.fn(() => new Map()),
|
||||
getOAuthServers: jest.fn(() => new Set()),
|
||||
});
|
||||
mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set());
|
||||
});
|
||||
|
||||
it('should successfully return MCP setup data', async () => {
|
||||
@@ -144,9 +139,9 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const mockMCPManager = {
|
||||
appConnections: { getAll: jest.fn(() => mockAppConnections) },
|
||||
getUserConnections: jest.fn(() => mockUserConnections),
|
||||
getOAuthServers: jest.fn(() => mockOAuthServers),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
mockMcpServersRegistry.getOAuthServers.mockResolvedValue(mockOAuthServers);
|
||||
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
|
||||
@@ -154,7 +149,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled();
|
||||
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMcpServersRegistry.getOAuthServers).toHaveBeenCalled();
|
||||
expect(mockMCPManager.getOAuthServers).toHaveBeenCalled();
|
||||
|
||||
expect(result).toEqual({
|
||||
mcpConfig: mockConfig.mcpServers,
|
||||
@@ -175,9 +170,9 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const mockMCPManager = {
|
||||
appConnections: { getAll: jest.fn(() => null) },
|
||||
getUserConnections: jest.fn(() => null),
|
||||
getOAuthServers: jest.fn(() => new Set()),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set());
|
||||
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ async function initializeMCPs() {
|
||||
const mcpManager = await createMCPManager(mcpServers);
|
||||
|
||||
try {
|
||||
const mcpTools = (await mcpManager.getAppToolFunctions()) || {};
|
||||
const mcpTools = mcpManager.getAppToolFunctions() || {};
|
||||
await mergeAppTools(mcpTools);
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -304,7 +304,6 @@ describe('Apple Login Strategy', () => {
|
||||
fileStrategy: 'local',
|
||||
balance: { enabled: false },
|
||||
}),
|
||||
'jane.doe@example.com',
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -5,25 +5,22 @@ const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
||||
const { updateUser, createUser, getUserById } = require('~/models');
|
||||
|
||||
/**
|
||||
* Updates the avatar URL and email of an existing user. If the user's avatar URL does not include the query parameter
|
||||
* Updates the avatar URL of an existing user. If the user's avatar URL does not include the query parameter
|
||||
* '?manual=true', it updates the user's avatar with the provided URL. For local file storage, it directly updates
|
||||
* the avatar URL, while for other storage types, it processes the avatar URL using the specified file strategy.
|
||||
* Also updates the email if it has changed (e.g., when a Google Workspace email is updated).
|
||||
*
|
||||
* @param {IUser} oldUser - The existing user object that needs to be updated.
|
||||
* @param {string} avatarUrl - The new avatar URL to be set for the user.
|
||||
* @param {AppConfig} appConfig - The application configuration object.
|
||||
* @param {string} [email] - Optional. The new email address to update if it has changed.
|
||||
*
|
||||
* @returns {Promise<void>}
|
||||
* The function updates the user's avatar and/or email and saves the user object. It does not return any value.
|
||||
* The function updates the user's avatar and saves the user object. It does not return any value.
|
||||
*
|
||||
* @throws {Error} Throws an error if there's an issue saving the updated user object.
|
||||
*/
|
||||
const handleExistingUser = async (oldUser, avatarUrl, appConfig, email) => {
|
||||
const handleExistingUser = async (oldUser, avatarUrl, appConfig) => {
|
||||
const fileStrategy = appConfig?.fileStrategy ?? process.env.CDN_PROVIDER;
|
||||
const isLocal = fileStrategy === FileSources.local;
|
||||
const updates = {};
|
||||
|
||||
let updatedAvatar = false;
|
||||
const hasManualFlag =
|
||||
@@ -42,16 +39,7 @@ const handleExistingUser = async (oldUser, avatarUrl, appConfig, email) => {
|
||||
}
|
||||
|
||||
if (updatedAvatar) {
|
||||
updates.avatar = updatedAvatar;
|
||||
}
|
||||
|
||||
/** Update email if it has changed */
|
||||
if (email && email.trim() !== oldUser.email) {
|
||||
updates.email = email.trim();
|
||||
}
|
||||
|
||||
if (Object.keys(updates).length > 0) {
|
||||
await updateUser(oldUser._id, updates);
|
||||
await updateUser(oldUser._id, { avatar: updatedAvatar });
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -167,76 +167,4 @@ describe('handleExistingUser', () => {
|
||||
// This should throw an error when trying to access oldUser._id
|
||||
await expect(handleExistingUser(null, avatarUrl)).rejects.toThrow();
|
||||
});
|
||||
|
||||
it('should update email when it has changed', async () => {
|
||||
const oldUser = {
|
||||
_id: 'user123',
|
||||
email: 'old@example.com',
|
||||
avatar: 'https://example.com/avatar.png?manual=true',
|
||||
};
|
||||
const avatarUrl = 'https://example.com/avatar.png';
|
||||
const newEmail = 'new@example.com';
|
||||
|
||||
await handleExistingUser(oldUser, avatarUrl, {}, newEmail);
|
||||
|
||||
expect(updateUser).toHaveBeenCalledWith('user123', { email: 'new@example.com' });
|
||||
});
|
||||
|
||||
it('should update both avatar and email when both have changed', async () => {
|
||||
const oldUser = {
|
||||
_id: 'user123',
|
||||
email: 'old@example.com',
|
||||
avatar: null,
|
||||
};
|
||||
const avatarUrl = 'https://example.com/new-avatar.png';
|
||||
const newEmail = 'new@example.com';
|
||||
|
||||
await handleExistingUser(oldUser, avatarUrl, {}, newEmail);
|
||||
|
||||
expect(updateUser).toHaveBeenCalledWith('user123', {
|
||||
avatar: avatarUrl,
|
||||
email: 'new@example.com',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not update email when it has not changed', async () => {
|
||||
const oldUser = {
|
||||
_id: 'user123',
|
||||
email: 'same@example.com',
|
||||
avatar: 'https://example.com/avatar.png?manual=true',
|
||||
};
|
||||
const avatarUrl = 'https://example.com/avatar.png';
|
||||
const sameEmail = 'same@example.com';
|
||||
|
||||
await handleExistingUser(oldUser, avatarUrl, {}, sameEmail);
|
||||
|
||||
expect(updateUser).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should trim email before comparison and update', async () => {
|
||||
const oldUser = {
|
||||
_id: 'user123',
|
||||
email: 'test@example.com',
|
||||
avatar: 'https://example.com/avatar.png?manual=true',
|
||||
};
|
||||
const avatarUrl = 'https://example.com/avatar.png';
|
||||
const newEmailWithSpaces = ' newemail@example.com ';
|
||||
|
||||
await handleExistingUser(oldUser, avatarUrl, {}, newEmailWithSpaces);
|
||||
|
||||
expect(updateUser).toHaveBeenCalledWith('user123', { email: 'newemail@example.com' });
|
||||
});
|
||||
|
||||
it('should not update when email parameter is not provided', async () => {
|
||||
const oldUser = {
|
||||
_id: 'user123',
|
||||
email: 'test@example.com',
|
||||
avatar: 'https://example.com/avatar.png?manual=true',
|
||||
};
|
||||
const avatarUrl = 'https://example.com/avatar.png';
|
||||
|
||||
await handleExistingUser(oldUser, avatarUrl, {});
|
||||
|
||||
expect(updateUser).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -25,24 +25,10 @@ const socialLogin =
|
||||
return cb(error);
|
||||
}
|
||||
|
||||
const providerKey = `${provider}Id`;
|
||||
let existingUser = null;
|
||||
|
||||
/** First try to find user by provider ID (e.g., googleId, facebookId) */
|
||||
if (id && typeof id === 'string') {
|
||||
existingUser = await findUser({ [providerKey]: id });
|
||||
}
|
||||
|
||||
/** If not found by provider ID, try finding by email */
|
||||
if (!existingUser) {
|
||||
existingUser = await findUser({ email: email?.trim() });
|
||||
if (existingUser) {
|
||||
logger.warn(`[${provider}Login] User found by email: ${email} but not by ${providerKey}`);
|
||||
}
|
||||
}
|
||||
const existingUser = await findUser({ email: email.trim() });
|
||||
|
||||
if (existingUser?.provider === provider) {
|
||||
await handleExistingUser(existingUser, avatarUrl, appConfig, email);
|
||||
await handleExistingUser(existingUser, avatarUrl, appConfig);
|
||||
return cb(null, existingUser);
|
||||
} else if (existingUser) {
|
||||
logger.info(
|
||||
|
||||
@@ -1,276 +0,0 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { createSocialUser, handleExistingUser } = require('./process');
|
||||
const socialLogin = require('./socialLogin');
|
||||
const { findUser } = require('~/models');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
const actualModule = jest.requireActual('@librechat/data-schemas');
|
||||
return {
|
||||
...actualModule,
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('./process', () => ({
|
||||
createSocialUser: jest.fn(),
|
||||
handleExistingUser: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
isEnabled: jest.fn().mockReturnValue(true),
|
||||
isEmailDomainAllowed: jest.fn().mockReturnValue(true),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
findUser: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn().mockResolvedValue({
|
||||
fileStrategy: 'local',
|
||||
balance: { enabled: false },
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('socialLogin', () => {
|
||||
const mockGetProfileDetails = ({ profile }) => ({
|
||||
email: profile.emails[0].value,
|
||||
id: profile.id,
|
||||
avatarUrl: profile.photos?.[0]?.value || null,
|
||||
username: profile.name?.givenName || 'user',
|
||||
name: `${profile.name?.givenName || ''} ${profile.name?.familyName || ''}`.trim(),
|
||||
emailVerified: profile.emails[0].verified || false,
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Finding users by provider ID', () => {
|
||||
it('should find user by provider ID (googleId) when email has changed', async () => {
|
||||
const provider = 'google';
|
||||
const googleId = 'google-user-123';
|
||||
const oldEmail = 'old@example.com';
|
||||
const newEmail = 'new@example.com';
|
||||
|
||||
const existingUser = {
|
||||
_id: 'user123',
|
||||
email: oldEmail,
|
||||
provider: 'google',
|
||||
googleId: googleId,
|
||||
};
|
||||
|
||||
/** Mock findUser to return user on first call (by googleId), null on second call */
|
||||
findUser
|
||||
.mockResolvedValueOnce(existingUser) // First call: finds by googleId
|
||||
.mockResolvedValueOnce(null); // Second call would be by email, but won't be reached
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
emails: [{ value: newEmail, verified: true }],
|
||||
photos: [{ value: 'https://example.com/avatar.png' }],
|
||||
name: { givenName: 'John', familyName: 'Doe' },
|
||||
};
|
||||
|
||||
const loginFn = socialLogin(provider, mockGetProfileDetails);
|
||||
const callback = jest.fn();
|
||||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify it searched by googleId first */
|
||||
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId });
|
||||
|
||||
/** Verify it did NOT search by email (because it found user by googleId) */
|
||||
expect(findUser).toHaveBeenCalledTimes(1);
|
||||
|
||||
/** Verify handleExistingUser was called with the new email */
|
||||
expect(handleExistingUser).toHaveBeenCalledWith(
|
||||
existingUser,
|
||||
'https://example.com/avatar.png',
|
||||
expect.any(Object),
|
||||
newEmail,
|
||||
);
|
||||
|
||||
/** Verify callback was called with success */
|
||||
expect(callback).toHaveBeenCalledWith(null, existingUser);
|
||||
});
|
||||
|
||||
it('should find user by provider ID (facebookId) when using Facebook', async () => {
|
||||
const provider = 'facebook';
|
||||
const facebookId = 'fb-user-456';
|
||||
const email = 'user@example.com';
|
||||
|
||||
const existingUser = {
|
||||
_id: 'user456',
|
||||
email: email,
|
||||
provider: 'facebook',
|
||||
facebookId: facebookId,
|
||||
};
|
||||
|
||||
findUser.mockResolvedValue(existingUser); // Always returns user
|
||||
|
||||
const mockProfile = {
|
||||
id: facebookId,
|
||||
emails: [{ value: email, verified: true }],
|
||||
photos: [{ value: 'https://example.com/fb-avatar.png' }],
|
||||
name: { givenName: 'Jane', familyName: 'Smith' },
|
||||
};
|
||||
|
||||
const loginFn = socialLogin(provider, mockGetProfileDetails);
|
||||
const callback = jest.fn();
|
||||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify it searched by facebookId first */
|
||||
expect(findUser).toHaveBeenCalledWith({ facebookId: facebookId });
|
||||
expect(findUser.mock.calls[0]).toEqual([{ facebookId: facebookId }]);
|
||||
|
||||
expect(handleExistingUser).toHaveBeenCalledWith(
|
||||
existingUser,
|
||||
'https://example.com/fb-avatar.png',
|
||||
expect.any(Object),
|
||||
email,
|
||||
);
|
||||
|
||||
expect(callback).toHaveBeenCalledWith(null, existingUser);
|
||||
});
|
||||
|
||||
it('should fallback to finding user by email if not found by provider ID', async () => {
|
||||
const provider = 'google';
|
||||
const googleId = 'google-user-789';
|
||||
const email = 'user@example.com';
|
||||
|
||||
const existingUser = {
|
||||
_id: 'user789',
|
||||
email: email,
|
||||
provider: 'google',
|
||||
googleId: 'old-google-id', // Different googleId (edge case)
|
||||
};
|
||||
|
||||
/** First call (by googleId) returns null, second call (by email) returns user */
|
||||
findUser
|
||||
.mockResolvedValueOnce(null) // By googleId
|
||||
.mockResolvedValueOnce(existingUser); // By email
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
emails: [{ value: email, verified: true }],
|
||||
photos: [{ value: 'https://example.com/avatar.png' }],
|
||||
name: { givenName: 'Bob', familyName: 'Johnson' },
|
||||
};
|
||||
|
||||
const loginFn = socialLogin(provider, mockGetProfileDetails);
|
||||
const callback = jest.fn();
|
||||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify both searches happened */
|
||||
expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId });
|
||||
expect(findUser).toHaveBeenNthCalledWith(2, { email: email });
|
||||
expect(findUser).toHaveBeenCalledTimes(2);
|
||||
|
||||
/** Verify warning log */
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
`[${provider}Login] User found by email: ${email} but not by ${provider}Id`,
|
||||
);
|
||||
|
||||
expect(handleExistingUser).toHaveBeenCalled();
|
||||
expect(callback).toHaveBeenCalledWith(null, existingUser);
|
||||
});
|
||||
|
||||
it('should create new user if not found by provider ID or email', async () => {
|
||||
const provider = 'google';
|
||||
const googleId = 'google-new-user';
|
||||
const email = 'newuser@example.com';
|
||||
|
||||
const newUser = {
|
||||
_id: 'newuser123',
|
||||
email: email,
|
||||
provider: 'google',
|
||||
googleId: googleId,
|
||||
};
|
||||
|
||||
/** Both searches return null */
|
||||
findUser.mockResolvedValue(null);
|
||||
createSocialUser.mockResolvedValue(newUser);
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
emails: [{ value: email, verified: true }],
|
||||
photos: [{ value: 'https://example.com/avatar.png' }],
|
||||
name: { givenName: 'New', familyName: 'User' },
|
||||
};
|
||||
|
||||
const loginFn = socialLogin(provider, mockGetProfileDetails);
|
||||
const callback = jest.fn();
|
||||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify both searches happened */
|
||||
expect(findUser).toHaveBeenCalledTimes(2);
|
||||
|
||||
/** Verify createSocialUser was called */
|
||||
expect(createSocialUser).toHaveBeenCalledWith({
|
||||
email: email,
|
||||
avatarUrl: 'https://example.com/avatar.png',
|
||||
provider: provider,
|
||||
providerKey: 'googleId',
|
||||
providerId: googleId,
|
||||
username: 'New',
|
||||
name: 'New User',
|
||||
emailVerified: true,
|
||||
appConfig: expect.any(Object),
|
||||
});
|
||||
|
||||
expect(callback).toHaveBeenCalledWith(null, newUser);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error handling', () => {
|
||||
it('should return error if user exists with different provider', async () => {
|
||||
const provider = 'google';
|
||||
const googleId = 'google-user-123';
|
||||
const email = 'user@example.com';
|
||||
|
||||
const existingUser = {
|
||||
_id: 'user123',
|
||||
email: email,
|
||||
provider: 'local', // Different provider
|
||||
};
|
||||
|
||||
findUser
|
||||
.mockResolvedValueOnce(null) // By googleId
|
||||
.mockResolvedValueOnce(existingUser); // By email
|
||||
|
||||
const mockProfile = {
|
||||
id: googleId,
|
||||
emails: [{ value: email, verified: true }],
|
||||
photos: [{ value: 'https://example.com/avatar.png' }],
|
||||
name: { givenName: 'John', familyName: 'Doe' },
|
||||
};
|
||||
|
||||
const loginFn = socialLogin(provider, mockGetProfileDetails);
|
||||
const callback = jest.fn();
|
||||
|
||||
await loginFn(null, null, null, mockProfile, callback);
|
||||
|
||||
/** Verify error callback */
|
||||
expect(callback).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
code: ErrorTypes.AUTH_FAILED,
|
||||
provider: 'local',
|
||||
}),
|
||||
);
|
||||
|
||||
expect(logger.info).toHaveBeenCalledWith(
|
||||
`[${provider}Login] User ${email} already exists with provider local`,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -117,10 +117,8 @@ const AttachFileMenu = ({
|
||||
const items: MenuItemProps[] = [];
|
||||
|
||||
const currentProvider = provider || endpoint;
|
||||
if (
|
||||
isDocumentSupportedProvider(endpointType) ||
|
||||
isDocumentSupportedProvider(currentProvider)
|
||||
) {
|
||||
|
||||
if (isDocumentSupportedProvider(currentProvider || endpointType)) {
|
||||
items.push({
|
||||
label: localize('com_ui_upload_provider'),
|
||||
onClick: () => {
|
||||
|
||||
@@ -57,7 +57,7 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
|
||||
const currentProvider = provider || endpoint;
|
||||
|
||||
// Check if provider supports document upload
|
||||
if (isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider)) {
|
||||
if (isDocumentSupportedProvider(currentProvider || endpointType)) {
|
||||
const isGoogleProvider = currentProvider === EModelEndpoint.google;
|
||||
const validFileTypes = isGoogleProvider
|
||||
? files.every(
|
||||
|
||||
@@ -133,7 +133,7 @@ export default function FileRow({
|
||||
>
|
||||
{isImage ? (
|
||||
<Image
|
||||
url={file.progress === 1 ? file.filepath : (file.preview ?? file.filepath)}
|
||||
url={file.preview ?? file.filepath}
|
||||
onDelete={handleDelete}
|
||||
progress={file.progress}
|
||||
source={file.source}
|
||||
|
||||
@@ -1,602 +0,0 @@
|
||||
import React from 'react';
|
||||
import { render, screen, fireEvent } from '@testing-library/react';
|
||||
import '@testing-library/jest-dom';
|
||||
import { RecoilRoot } from 'recoil';
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
|
||||
import { EModelEndpoint } from 'librechat-data-provider';
|
||||
import AttachFileMenu from '../AttachFileMenu';
|
||||
|
||||
// Mock all the hooks
|
||||
jest.mock('~/hooks', () => ({
|
||||
useAgentToolPermissions: jest.fn(),
|
||||
useAgentCapabilities: jest.fn(),
|
||||
useGetAgentsConfig: jest.fn(),
|
||||
useFileHandling: jest.fn(),
|
||||
useLocalize: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/hooks/Files/useSharePointFileHandling', () => ({
|
||||
__esModule: true,
|
||||
default: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/data-provider', () => ({
|
||||
useGetStartupConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/components/SharePoint', () => ({
|
||||
SharePointPickerDialog: jest.fn(() => null),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/client', () => {
|
||||
const React = jest.requireActual('react');
|
||||
return {
|
||||
FileUpload: React.forwardRef(({ children, handleFileChange }: any, ref: any) => (
|
||||
<div data-testid="file-upload">
|
||||
<input ref={ref} type="file" onChange={handleFileChange} data-testid="file-input" />
|
||||
{children}
|
||||
</div>
|
||||
)),
|
||||
TooltipAnchor: ({ render }: any) => render,
|
||||
DropdownPopup: ({ trigger, items, isOpen, setIsOpen }: any) => {
|
||||
const handleTriggerClick = () => {
|
||||
if (setIsOpen) {
|
||||
setIsOpen(!isOpen);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div onClick={handleTriggerClick}>{trigger}</div>
|
||||
{isOpen && (
|
||||
<div data-testid="dropdown-menu">
|
||||
{items.map((item: any, idx: number) => (
|
||||
<button key={idx} onClick={item.onClick} data-testid={`menu-item-${idx}`}>
|
||||
{item.label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
AttachmentIcon: () => <span data-testid="attachment-icon">📎</span>,
|
||||
SharePointIcon: () => <span data-testid="sharepoint-icon">SP</span>,
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('@ariakit/react', () => ({
|
||||
MenuButton: ({ children, onClick, disabled, ...props }: any) => (
|
||||
<button onClick={onClick} disabled={disabled} {...props}>
|
||||
{children}
|
||||
</button>
|
||||
),
|
||||
}));
|
||||
|
||||
const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions;
|
||||
const mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities;
|
||||
const mockUseGetAgentsConfig = jest.requireMock('~/hooks').useGetAgentsConfig;
|
||||
const mockUseFileHandling = jest.requireMock('~/hooks').useFileHandling;
|
||||
const mockUseLocalize = jest.requireMock('~/hooks').useLocalize;
|
||||
const mockUseSharePointFileHandling = jest.requireMock(
|
||||
'~/hooks/Files/useSharePointFileHandling',
|
||||
).default;
|
||||
const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig;
|
||||
|
||||
describe('AttachFileMenu', () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
});
|
||||
|
||||
const mockHandleFileChange = jest.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Default mock implementations
|
||||
mockUseLocalize.mockReturnValue((key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
com_ui_upload_provider: 'Upload to Provider',
|
||||
com_ui_upload_image_input: 'Upload Image',
|
||||
com_ui_upload_ocr_text: 'Upload OCR Text',
|
||||
com_ui_upload_file_search: 'Upload for File Search',
|
||||
com_ui_upload_code_files: 'Upload Code Files',
|
||||
com_sidepanel_attach_files: 'Attach Files',
|
||||
com_files_upload_sharepoint: 'Upload from SharePoint',
|
||||
};
|
||||
return translations[key] || key;
|
||||
});
|
||||
|
||||
mockUseAgentCapabilities.mockReturnValue({
|
||||
contextEnabled: false,
|
||||
fileSearchEnabled: false,
|
||||
codeEnabled: false,
|
||||
});
|
||||
|
||||
mockUseGetAgentsConfig.mockReturnValue({
|
||||
agentsConfig: {
|
||||
capabilities: {
|
||||
contextEnabled: false,
|
||||
fileSearchEnabled: false,
|
||||
codeEnabled: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
mockUseFileHandling.mockReturnValue({
|
||||
handleFileChange: mockHandleFileChange,
|
||||
});
|
||||
|
||||
mockUseSharePointFileHandling.mockReturnValue({
|
||||
handleSharePointFiles: jest.fn(),
|
||||
isProcessing: false,
|
||||
downloadProgress: 0,
|
||||
});
|
||||
|
||||
mockUseGetStartupConfig.mockReturnValue({
|
||||
data: {
|
||||
sharePointFilePickerEnabled: false,
|
||||
},
|
||||
});
|
||||
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
const renderAttachFileMenu = (props: any = {}) => {
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<RecoilRoot>
|
||||
<AttachFileMenu conversationId="test-conversation" {...props} />
|
||||
</RecoilRoot>
|
||||
</QueryClientProvider>,
|
||||
);
|
||||
};
|
||||
|
||||
describe('Basic Rendering', () => {
|
||||
it('should render the attachment button', () => {
|
||||
renderAttachFileMenu();
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
expect(button).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should be disabled when disabled prop is true', () => {
|
||||
renderAttachFileMenu({ disabled: true });
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
expect(button).toBeDisabled();
|
||||
});
|
||||
|
||||
it('should not be disabled when disabled prop is false', () => {
|
||||
renderAttachFileMenu({ disabled: false });
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
expect(button).not.toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Provider Detection Fix - endpointType Priority', () => {
|
||||
it('should prioritize endpointType over currentProvider for LiteLLM gateway', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: 'litellm', // Custom gateway name NOT in documentSupportedProviders
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: 'litellm',
|
||||
endpointType: EModelEndpoint.openAI, // Backend override IS in documentSupportedProviders
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
// With the fix, should show "Upload to Provider" because endpointType is checked first
|
||||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
expect(screen.queryByText('Upload Image')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should show Upload to Provider for custom endpoints with OpenAI endpointType', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: 'my-custom-gateway',
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: 'my-custom-gateway',
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should show Upload Image when neither endpointType nor provider support documents', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: 'unsupported-provider',
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: 'unsupported-provider',
|
||||
endpointType: 'unsupported-endpoint' as any,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload Image')).toBeInTheDocument();
|
||||
expect(screen.queryByText('Upload to Provider')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should fallback to currentProvider when endpointType is undefined', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
endpointType: undefined,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should fallback to currentProvider when endpointType is null', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: EModelEndpoint.anthropic,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: EModelEndpoint.anthropic,
|
||||
endpointType: null,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Supported Providers', () => {
|
||||
const supportedProviders = [
|
||||
{ name: 'OpenAI', endpoint: EModelEndpoint.openAI },
|
||||
{ name: 'Anthropic', endpoint: EModelEndpoint.anthropic },
|
||||
{ name: 'Google', endpoint: EModelEndpoint.google },
|
||||
{ name: 'Azure OpenAI', endpoint: EModelEndpoint.azureOpenAI },
|
||||
{ name: 'Custom', endpoint: EModelEndpoint.custom },
|
||||
];
|
||||
|
||||
supportedProviders.forEach(({ name, endpoint }) => {
|
||||
it(`should show Upload to Provider for ${name}`, () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: endpoint,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint,
|
||||
endpointType: endpoint,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Agent Capabilities', () => {
|
||||
it('should show OCR Text option when context is enabled', () => {
|
||||
mockUseAgentCapabilities.mockReturnValue({
|
||||
contextEnabled: true,
|
||||
fileSearchEnabled: false,
|
||||
codeEnabled: false,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload OCR Text')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should show File Search option when enabled and allowed by agent', () => {
|
||||
mockUseAgentCapabilities.mockReturnValue({
|
||||
contextEnabled: false,
|
||||
fileSearchEnabled: true,
|
||||
codeEnabled: false,
|
||||
});
|
||||
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: true,
|
||||
codeAllowedByAgent: false,
|
||||
provider: undefined,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload for File Search')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should NOT show File Search when enabled but not allowed by agent', () => {
|
||||
mockUseAgentCapabilities.mockReturnValue({
|
||||
contextEnabled: false,
|
||||
fileSearchEnabled: true,
|
||||
codeEnabled: false,
|
||||
});
|
||||
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: undefined,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.queryByText('Upload for File Search')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should show Code Files option when enabled and allowed by agent', () => {
|
||||
mockUseAgentCapabilities.mockReturnValue({
|
||||
contextEnabled: false,
|
||||
fileSearchEnabled: false,
|
||||
codeEnabled: true,
|
||||
});
|
||||
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: true,
|
||||
provider: undefined,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload Code Files')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should show all options when all capabilities are enabled', () => {
|
||||
mockUseAgentCapabilities.mockReturnValue({
|
||||
contextEnabled: true,
|
||||
fileSearchEnabled: true,
|
||||
codeEnabled: true,
|
||||
});
|
||||
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: true,
|
||||
codeAllowedByAgent: true,
|
||||
provider: undefined,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
expect(screen.getByText('Upload OCR Text')).toBeInTheDocument();
|
||||
expect(screen.getByText('Upload for File Search')).toBeInTheDocument();
|
||||
expect(screen.getByText('Upload Code Files')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('SharePoint Integration', () => {
|
||||
it('should show SharePoint option when enabled', () => {
|
||||
mockUseGetStartupConfig.mockReturnValue({
|
||||
data: {
|
||||
sharePointFilePickerEnabled: true,
|
||||
},
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload from SharePoint')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should NOT show SharePoint option when disabled', () => {
|
||||
mockUseGetStartupConfig.mockReturnValue({
|
||||
data: {
|
||||
sharePointFilePickerEnabled: false,
|
||||
},
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.queryByText('Upload from SharePoint')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle undefined endpoint and provider gracefully', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: undefined,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: undefined,
|
||||
endpointType: undefined,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
expect(button).toBeInTheDocument();
|
||||
fireEvent.click(button);
|
||||
|
||||
// Should show Upload Image as fallback
|
||||
expect(screen.getByText('Upload Image')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should handle null endpoint and provider gracefully', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: null,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: null,
|
||||
endpointType: null,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
expect(button).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should handle missing agentId gracefully', () => {
|
||||
renderAttachFileMenu({
|
||||
agentId: undefined,
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
expect(button).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should handle empty string agentId', () => {
|
||||
renderAttachFileMenu({
|
||||
agentId: '',
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
expect(button).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Google Provider Special Case', () => {
|
||||
it('should use google_multimodal file type for Google provider', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: EModelEndpoint.google,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: EModelEndpoint.google,
|
||||
endpointType: EModelEndpoint.google,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
const uploadProviderButton = screen.getByText('Upload to Provider');
|
||||
expect(uploadProviderButton).toBeInTheDocument();
|
||||
|
||||
// Click the upload to provider option
|
||||
fireEvent.click(uploadProviderButton);
|
||||
|
||||
// The file input should have been clicked (indirectly tested through the implementation)
|
||||
});
|
||||
|
||||
it('should use multimodal file type for non-Google providers', () => {
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
const uploadProviderButton = screen.getByText('Upload to Provider');
|
||||
expect(uploadProviderButton).toBeInTheDocument();
|
||||
fireEvent.click(uploadProviderButton);
|
||||
|
||||
// Implementation detail - multimodal type is used
|
||||
});
|
||||
});
|
||||
|
||||
describe('Regression Tests', () => {
|
||||
it('should not break the previous behavior for direct provider attachments', () => {
|
||||
// When using a direct supported provider (not through a gateway)
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: EModelEndpoint.anthropic,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: EModelEndpoint.anthropic,
|
||||
endpointType: EModelEndpoint.anthropic,
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should maintain correct priority when both are supported', () => {
|
||||
// Both endpointType and provider are supported, endpointType should be checked first
|
||||
mockUseAgentToolPermissions.mockReturnValue({
|
||||
fileSearchAllowedByAgent: false,
|
||||
codeAllowedByAgent: false,
|
||||
provider: EModelEndpoint.google,
|
||||
});
|
||||
|
||||
renderAttachFileMenu({
|
||||
endpoint: EModelEndpoint.google,
|
||||
endpointType: EModelEndpoint.openAI, // Different but both supported
|
||||
});
|
||||
|
||||
const button = screen.getByRole('button', { name: /attach file options/i });
|
||||
fireEvent.click(button);
|
||||
|
||||
// Should still work because endpointType (openAI) is supported
|
||||
expect(screen.getByText('Upload to Provider')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,121 +0,0 @@
|
||||
import { EModelEndpoint, isDocumentSupportedProvider } from 'librechat-data-provider';
|
||||
|
||||
describe('DragDropModal - Provider Detection', () => {
|
||||
describe('endpointType priority over currentProvider', () => {
|
||||
it('should show upload option for LiteLLM with OpenAI endpointType', () => {
|
||||
const currentProvider = 'litellm'; // NOT in documentSupportedProviders
|
||||
const endpointType = EModelEndpoint.openAI; // IS in documentSupportedProviders
|
||||
|
||||
// With fix: endpointType checked
|
||||
const withFix =
|
||||
isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider);
|
||||
expect(withFix).toBe(true);
|
||||
|
||||
// Without fix: only currentProvider checked = false
|
||||
const withoutFix = isDocumentSupportedProvider(currentProvider || endpointType);
|
||||
expect(withoutFix).toBe(false);
|
||||
});
|
||||
|
||||
it('should show upload option for any custom gateway with OpenAI endpointType', () => {
|
||||
const currentProvider = 'my-custom-gateway';
|
||||
const endpointType = EModelEndpoint.openAI;
|
||||
|
||||
const result =
|
||||
isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should fallback to currentProvider when endpointType is undefined', () => {
|
||||
const currentProvider = EModelEndpoint.openAI;
|
||||
const endpointType = undefined;
|
||||
|
||||
const result =
|
||||
isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should fallback to currentProvider when endpointType is null', () => {
|
||||
const currentProvider = EModelEndpoint.anthropic;
|
||||
const endpointType = null;
|
||||
|
||||
const result =
|
||||
isDocumentSupportedProvider(endpointType as any) ||
|
||||
isDocumentSupportedProvider(currentProvider);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false when neither provider supports documents', () => {
|
||||
const currentProvider = 'unsupported-provider';
|
||||
const endpointType = 'unsupported-endpoint' as any;
|
||||
|
||||
const result =
|
||||
isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('supported providers', () => {
|
||||
const supportedProviders = [
|
||||
{ name: 'OpenAI', value: EModelEndpoint.openAI },
|
||||
{ name: 'Anthropic', value: EModelEndpoint.anthropic },
|
||||
{ name: 'Google', value: EModelEndpoint.google },
|
||||
{ name: 'Azure OpenAI', value: EModelEndpoint.azureOpenAI },
|
||||
{ name: 'Custom', value: EModelEndpoint.custom },
|
||||
];
|
||||
|
||||
supportedProviders.forEach(({ name, value }) => {
|
||||
it(`should recognize ${name} as supported`, () => {
|
||||
expect(isDocumentSupportedProvider(value)).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('real-world scenarios', () => {
|
||||
it('should handle LiteLLM gateway pointing to OpenAI', () => {
|
||||
const scenario = {
|
||||
currentProvider: 'litellm',
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
};
|
||||
|
||||
expect(
|
||||
isDocumentSupportedProvider(scenario.endpointType) ||
|
||||
isDocumentSupportedProvider(scenario.currentProvider),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle direct OpenAI connection', () => {
|
||||
const scenario = {
|
||||
currentProvider: EModelEndpoint.openAI,
|
||||
endpointType: EModelEndpoint.openAI,
|
||||
};
|
||||
|
||||
expect(
|
||||
isDocumentSupportedProvider(scenario.endpointType) ||
|
||||
isDocumentSupportedProvider(scenario.currentProvider),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle unsupported custom endpoint without override', () => {
|
||||
const scenario = {
|
||||
currentProvider: 'my-unsupported-endpoint',
|
||||
endpointType: undefined,
|
||||
};
|
||||
|
||||
expect(
|
||||
isDocumentSupportedProvider(scenario.endpointType) ||
|
||||
isDocumentSupportedProvider(scenario.currentProvider),
|
||||
).toBe(false);
|
||||
});
|
||||
it('should handle agents endpoints with document supported providers', () => {
|
||||
const scenario = {
|
||||
currentProvider: EModelEndpoint.google,
|
||||
endpointType: EModelEndpoint.agents,
|
||||
};
|
||||
|
||||
expect(
|
||||
isDocumentSupportedProvider(scenario.endpointType) ||
|
||||
isDocumentSupportedProvider(scenario.currentProvider),
|
||||
).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,347 +0,0 @@
|
||||
import React from 'react';
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import '@testing-library/jest-dom';
|
||||
import { FileSources } from 'librechat-data-provider';
|
||||
import type { ExtendedFile } from '~/common';
|
||||
import FileRow from '../FileRow';
|
||||
|
||||
jest.mock('~/hooks', () => ({
|
||||
useLocalize: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/data-provider', () => ({
|
||||
useDeleteFilesMutation: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/hooks/Files', () => ({
|
||||
useFileDeletion: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/utils', () => ({
|
||||
logger: {
|
||||
log: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('../Image', () => {
|
||||
return function MockImage({ url, progress, source }: any) {
|
||||
return (
|
||||
<div data-testid="mock-image">
|
||||
<span data-testid="image-url">{url}</span>
|
||||
<span data-testid="image-progress">{progress}</span>
|
||||
<span data-testid="image-source">{source}</span>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('../FileContainer', () => {
|
||||
return function MockFileContainer({ file }: any) {
|
||||
return (
|
||||
<div data-testid="mock-file-container">
|
||||
<span data-testid="file-name">{file.filename}</span>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
const mockUseLocalize = jest.requireMock('~/hooks').useLocalize;
|
||||
const mockUseDeleteFilesMutation = jest.requireMock('~/data-provider').useDeleteFilesMutation;
|
||||
const mockUseFileDeletion = jest.requireMock('~/hooks/Files').useFileDeletion;
|
||||
|
||||
describe('FileRow', () => {
|
||||
const mockSetFiles = jest.fn();
|
||||
const mockSetFilesLoading = jest.fn();
|
||||
const mockDeleteFile = jest.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
mockUseLocalize.mockReturnValue((key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
com_ui_deleting_file: 'Deleting file...',
|
||||
};
|
||||
return translations[key] || key;
|
||||
});
|
||||
|
||||
mockUseDeleteFilesMutation.mockReturnValue({
|
||||
mutateAsync: jest.fn(),
|
||||
});
|
||||
|
||||
mockUseFileDeletion.mockReturnValue({
|
||||
deleteFile: mockDeleteFile,
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* Creates a mock ExtendedFile with sensible defaults
|
||||
*/
|
||||
const createMockFile = (overrides: Partial<ExtendedFile> = {}): ExtendedFile => ({
|
||||
file_id: 'test-file-id',
|
||||
type: 'image/png',
|
||||
preview: 'blob:http://localhost:3080/preview-blob-url',
|
||||
filepath: '/images/user123/test-file-id__image.png',
|
||||
filename: 'test-image.png',
|
||||
progress: 1,
|
||||
size: 1024,
|
||||
source: FileSources.local,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const renderFileRow = (files: Map<string, ExtendedFile>) => {
|
||||
return render(
|
||||
<FileRow files={files} setFiles={mockSetFiles} setFilesLoading={mockSetFilesLoading} />,
|
||||
);
|
||||
};
|
||||
|
||||
describe('Image URL Selection Logic', () => {
|
||||
it('should use filepath instead of preview when progress is 1 (upload complete)', () => {
|
||||
const file = createMockFile({
|
||||
file_id: 'uploaded-file',
|
||||
preview: 'blob:http://localhost:3080/temp-preview',
|
||||
filepath: '/images/user123/uploaded-file__image.png',
|
||||
progress: 1,
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const imageUrl = screen.getByTestId('image-url').textContent;
|
||||
expect(imageUrl).toBe('/images/user123/uploaded-file__image.png');
|
||||
expect(imageUrl).not.toContain('blob:');
|
||||
});
|
||||
|
||||
it('should use preview when progress is less than 1 (uploading)', () => {
|
||||
const file = createMockFile({
|
||||
file_id: 'uploading-file',
|
||||
preview: 'blob:http://localhost:3080/temp-preview',
|
||||
filepath: undefined,
|
||||
progress: 0.5,
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const imageUrl = screen.getByTestId('image-url').textContent;
|
||||
expect(imageUrl).toBe('blob:http://localhost:3080/temp-preview');
|
||||
});
|
||||
|
||||
it('should fallback to filepath when preview is undefined and progress is less than 1', () => {
|
||||
const file = createMockFile({
|
||||
file_id: 'file-without-preview',
|
||||
preview: undefined,
|
||||
filepath: '/images/user123/file-without-preview__image.png',
|
||||
progress: 0.7,
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const imageUrl = screen.getByTestId('image-url').textContent;
|
||||
expect(imageUrl).toBe('/images/user123/file-without-preview__image.png');
|
||||
});
|
||||
|
||||
it('should use filepath when both preview and filepath exist and progress is exactly 1', () => {
|
||||
const file = createMockFile({
|
||||
file_id: 'complete-file',
|
||||
preview: 'blob:http://localhost:3080/old-blob',
|
||||
filepath: '/images/user123/complete-file__image.png',
|
||||
progress: 1.0,
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const imageUrl = screen.getByTestId('image-url').textContent;
|
||||
expect(imageUrl).toBe('/images/user123/complete-file__image.png');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Progress States', () => {
|
||||
it('should pass correct progress value during upload', () => {
|
||||
const file = createMockFile({
|
||||
progress: 0.65,
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const progress = screen.getByTestId('image-progress').textContent;
|
||||
expect(progress).toBe('0.65');
|
||||
});
|
||||
|
||||
it('should pass progress value of 1 when upload is complete', () => {
|
||||
const file = createMockFile({
|
||||
progress: 1,
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const progress = screen.getByTestId('image-progress').textContent;
|
||||
expect(progress).toBe('1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('File Source', () => {
|
||||
it('should pass local source to Image component', () => {
|
||||
const file = createMockFile({
|
||||
source: FileSources.local,
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const source = screen.getByTestId('image-source').textContent;
|
||||
expect(source).toBe(FileSources.local);
|
||||
});
|
||||
|
||||
it('should pass openai source to Image component', () => {
|
||||
const file = createMockFile({
|
||||
source: FileSources.openai,
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const source = screen.getByTestId('image-source').textContent;
|
||||
expect(source).toBe(FileSources.openai);
|
||||
});
|
||||
});
|
||||
|
||||
describe('File Type Detection', () => {
|
||||
it('should render Image component for image files', () => {
|
||||
const file = createMockFile({
|
||||
type: 'image/jpeg',
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
expect(screen.getByTestId('mock-image')).toBeInTheDocument();
|
||||
expect(screen.queryByTestId('mock-file-container')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should render FileContainer for non-image files', () => {
|
||||
const file = createMockFile({
|
||||
type: 'application/pdf',
|
||||
filename: 'document.pdf',
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
expect(screen.getByTestId('mock-file-container')).toBeInTheDocument();
|
||||
expect(screen.queryByTestId('mock-image')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Multiple Files', () => {
|
||||
it('should render multiple image files with correct URLs based on their progress', () => {
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
|
||||
const uploadingFile = createMockFile({
|
||||
file_id: 'file-1',
|
||||
preview: 'blob:http://localhost:3080/preview-1',
|
||||
filepath: undefined,
|
||||
progress: 0.3,
|
||||
});
|
||||
|
||||
const completedFile = createMockFile({
|
||||
file_id: 'file-2',
|
||||
preview: 'blob:http://localhost:3080/preview-2',
|
||||
filepath: '/images/user123/file-2__image.png',
|
||||
progress: 1,
|
||||
});
|
||||
|
||||
filesMap.set(uploadingFile.file_id, uploadingFile);
|
||||
filesMap.set(completedFile.file_id, completedFile);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const images = screen.getAllByTestId('mock-image');
|
||||
expect(images).toHaveLength(2);
|
||||
|
||||
const urls = screen.getAllByTestId('image-url').map((el) => el.textContent);
|
||||
expect(urls).toContain('blob:http://localhost:3080/preview-1');
|
||||
expect(urls).toContain('/images/user123/file-2__image.png');
|
||||
});
|
||||
|
||||
it('should deduplicate files with the same file_id', () => {
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
|
||||
const file1 = createMockFile({ file_id: 'duplicate-id' });
|
||||
const file2 = createMockFile({ file_id: 'duplicate-id' });
|
||||
|
||||
filesMap.set('key-1', file1);
|
||||
filesMap.set('key-2', file2);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const images = screen.getAllByTestId('mock-image');
|
||||
expect(images).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Empty State', () => {
|
||||
it('should render nothing when files map is empty', () => {
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
|
||||
const { container } = renderFileRow(filesMap);
|
||||
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it('should render nothing when files is undefined', () => {
|
||||
const { container } = render(
|
||||
<FileRow files={undefined} setFiles={mockSetFiles} setFilesLoading={mockSetFilesLoading} />,
|
||||
);
|
||||
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Regression: Blob URL Bug Fix', () => {
|
||||
it('should NOT use revoked blob URL after upload completes', () => {
|
||||
const file = createMockFile({
|
||||
file_id: 'regression-test',
|
||||
preview: 'blob:http://localhost:3080/d25f730c-152d-41f7-8d79-c9fa448f606b',
|
||||
filepath:
|
||||
'/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png',
|
||||
progress: 1,
|
||||
});
|
||||
|
||||
const filesMap = new Map<string, ExtendedFile>();
|
||||
filesMap.set(file.file_id, file);
|
||||
|
||||
renderFileRow(filesMap);
|
||||
|
||||
const imageUrl = screen.getByTestId('image-url').textContent;
|
||||
|
||||
expect(imageUrl).not.toContain('blob:');
|
||||
expect(imageUrl).toBe(
|
||||
'/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { memo, useState } from 'react';
|
||||
import React, { memo, useState, useMemo, useRef, useCallback, useEffect } from 'react';
|
||||
import { UserIcon, useAvatar } from '@librechat/client';
|
||||
import type { TUser } from 'librechat-data-provider';
|
||||
import type { IconProps } from '~/common';
|
||||
@@ -15,26 +15,49 @@ type UserAvatarProps = {
|
||||
className?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Default avatar component - memoized outside to prevent recreation on every render
|
||||
*/
|
||||
const DefaultAvatar = memo(() => (
|
||||
<div
|
||||
style={{
|
||||
backgroundColor: 'rgb(121, 137, 255)',
|
||||
width: '20px',
|
||||
height: '20px',
|
||||
boxShadow: 'rgba(240, 246, 252, 0.1) 0px 0px 0px 1px',
|
||||
}}
|
||||
className="relative flex h-9 w-9 items-center justify-center rounded-sm p-1 text-white"
|
||||
>
|
||||
<UserIcon />
|
||||
</div>
|
||||
));
|
||||
|
||||
DefaultAvatar.displayName = 'DefaultAvatar';
|
||||
|
||||
const UserAvatar = memo(({ size, user, avatarSrc, username, className }: UserAvatarProps) => {
|
||||
const [imageError, setImageError] = useState(false);
|
||||
const imageLoadedRef = useRef(false);
|
||||
|
||||
const handleImageError = () => {
|
||||
const imageSrc = useMemo(() => (user?.avatar ?? '') || avatarSrc, [user?.avatar, avatarSrc]);
|
||||
|
||||
/** Reset loaded state and error state if image source changes */
|
||||
useEffect(() => {
|
||||
imageLoadedRef.current = false;
|
||||
setImageError(false);
|
||||
}, [imageSrc]);
|
||||
|
||||
const handleImageError = useCallback(() => {
|
||||
setImageError(true);
|
||||
};
|
||||
imageLoadedRef.current = false;
|
||||
}, []);
|
||||
|
||||
const renderDefaultAvatar = () => (
|
||||
<div
|
||||
style={{
|
||||
backgroundColor: 'rgb(121, 137, 255)',
|
||||
width: '20px',
|
||||
height: '20px',
|
||||
boxShadow: 'rgba(240, 246, 252, 0.1) 0px 0px 0px 1px',
|
||||
}}
|
||||
className="relative flex h-9 w-9 items-center justify-center rounded-sm p-1 text-white"
|
||||
>
|
||||
<UserIcon />
|
||||
</div>
|
||||
);
|
||||
const handleImageLoad = useCallback(() => {
|
||||
imageLoadedRef.current = true;
|
||||
setImageError(false);
|
||||
}, []);
|
||||
|
||||
const hasAvatar = useMemo(() => imageSrc !== '', [imageSrc]);
|
||||
const showImage = useMemo(() => hasAvatar && !imageError, [hasAvatar, imageError]);
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -45,14 +68,14 @@ const UserAvatar = memo(({ size, user, avatarSrc, username, className }: UserAva
|
||||
}}
|
||||
className={cn('relative flex items-center justify-center', className ?? '')}
|
||||
>
|
||||
{(!(user?.avatar ?? '') && (!(user?.username ?? '') || user?.username.trim() === '')) ||
|
||||
imageError ? (
|
||||
renderDefaultAvatar()
|
||||
{!showImage ? (
|
||||
<DefaultAvatar />
|
||||
) : (
|
||||
<img
|
||||
className="rounded-full"
|
||||
src={(user?.avatar ?? '') || avatarSrc}
|
||||
src={imageSrc}
|
||||
alt="avatar"
|
||||
onLoad={handleImageLoad}
|
||||
onError={handleImageError}
|
||||
/>
|
||||
)}
|
||||
@@ -69,8 +92,12 @@ const Icon: React.FC<IconProps> = memo((props) => {
|
||||
const avatarSrc = useAvatar(user);
|
||||
const localize = useLocalize();
|
||||
|
||||
const username = useMemo(
|
||||
() => user?.name ?? user?.username ?? localize('com_nav_user'),
|
||||
[user?.name, user?.username, localize],
|
||||
);
|
||||
|
||||
if (isCreatedByUser) {
|
||||
const username = user?.name ?? user?.username ?? localize('com_nav_user');
|
||||
return (
|
||||
<UserAvatar
|
||||
size={size}
|
||||
|
||||
@@ -787,9 +787,10 @@
|
||||
"com_ui_copy_code": "Copy code",
|
||||
"com_ui_copy_link": "Copy link",
|
||||
"com_ui_copy_stack_trace": "Copy stack trace",
|
||||
"com_ui_copy_thoughts_to_clipboard": "Copy thoughts to clipboard",
|
||||
"com_ui_copy_to_clipboard": "Copy to clipboard",
|
||||
"com_ui_copy_url_to_clipboard": "Copy URL to clipboard",
|
||||
"com_ui_copy_stack_trace": "Copy stack trace",
|
||||
"com_ui_copy_thoughts_to_clipboard": "Copy thoughts to clipboard",
|
||||
"com_ui_create": "Create",
|
||||
"com_ui_create_link": "Create link",
|
||||
"com_ui_create_memory": "Create Memory",
|
||||
@@ -1120,7 +1121,6 @@
|
||||
"com_ui_reset_var": "Reset {{0}}",
|
||||
"com_ui_reset_zoom": "Reset Zoom",
|
||||
"com_ui_resource": "resource",
|
||||
"com_ui_response": "Response",
|
||||
"com_ui_result": "Result",
|
||||
"com_ui_revoke": "Revoke",
|
||||
"com_ui_revoke_info": "Revoke all user provided credentials",
|
||||
@@ -1224,6 +1224,7 @@
|
||||
"com_ui_terms_of_service": "Terms of service",
|
||||
"com_ui_thinking": "Thinking...",
|
||||
"com_ui_thoughts": "Thoughts",
|
||||
"com_ui_response": "Response",
|
||||
"com_ui_token": "token",
|
||||
"com_ui_token_exchange_method": "Token Exchange Method",
|
||||
"com_ui_token_url": "Token URL",
|
||||
|
||||
@@ -136,16 +136,6 @@ registration:
|
||||
# apiKey: '${TTS_API_KEY}'
|
||||
# model: ''
|
||||
# voices: ['']
|
||||
# azureOpenAI:
|
||||
# # instanceName: The <NAME> part of your Azure endpoint URL
|
||||
# # For example, if your endpoint is: https://my-instance.cognitiveservices.azure.com
|
||||
# # Then instanceName should be: 'my-instance' (not the full URL)
|
||||
# instanceName: '${AZURE_TTS_INSTANCE_NAME}'
|
||||
# apiKey: '${AZURE_TTS_API_KEY}'
|
||||
# deploymentName: '${AZURE_TTS_DEPLOYMENT_NAME}'
|
||||
# apiVersion: '2024-02-01'
|
||||
# model: 'tts-1'
|
||||
# voices: ['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer']
|
||||
|
||||
#
|
||||
# stt:
|
||||
@@ -153,15 +143,6 @@ registration:
|
||||
# url: ''
|
||||
# apiKey: '${STT_API_KEY}'
|
||||
# model: ''
|
||||
# azureOpenAI:
|
||||
# # instanceName: The <NAME> part of your Azure endpoint URL
|
||||
# # For example, if your endpoint is: https://my-instance.cognitiveservices.azure.com
|
||||
# # Then instanceName should be: 'my-instance' (not the full URL)
|
||||
# # Note: The code also supports full domain format for backward compatibility
|
||||
# instanceName: '${AZURE_STT_INSTANCE_NAME}'
|
||||
# apiKey: '${AZURE_STT_API_KEY}'
|
||||
# deploymentName: '${AZURE_STT_DEPLOYMENT_NAME}'
|
||||
# apiVersion: '2024-02-01'
|
||||
|
||||
# rateLimits:
|
||||
# fileUploads:
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
export default {
|
||||
collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!<rootDir>/node_modules/'],
|
||||
coveragePathIgnorePatterns: ['/node_modules/', '/dist/'],
|
||||
testPathIgnorePatterns: [
|
||||
'/node_modules/',
|
||||
'/dist/',
|
||||
'\\.dev\\.ts$',
|
||||
'\\.helper\\.ts$',
|
||||
'\\.helper\\.d\\.ts$',
|
||||
],
|
||||
testPathIgnorePatterns: ['/node_modules/', '/dist/', '\\.dev\\.ts$'],
|
||||
coverageReporters: ['text', 'cobertura'],
|
||||
testResultsProcessor: 'jest-junit',
|
||||
moduleNameMapper: {
|
||||
@@ -24,4 +18,4 @@ export default {
|
||||
// },
|
||||
restoreMocks: true,
|
||||
testTimeout: 15000,
|
||||
};
|
||||
};
|
||||
@@ -18,11 +18,10 @@
|
||||
"build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs",
|
||||
"build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs",
|
||||
"build:watch:prod": "rollup -c -w --bundleConfigAsCjs",
|
||||
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
|
||||
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
|
||||
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.\"",
|
||||
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.\"",
|
||||
"test:cache-integration:core": "jest --testPathPattern=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
|
||||
"test:cache-integration:cluster": "jest --testPathPattern=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand",
|
||||
"test:cache-integration:mcp": "jest --testPathPattern=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
|
||||
"verify": "npm run test:ci",
|
||||
"b:clean": "bun run rimraf dist",
|
||||
"b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs",
|
||||
|
||||
@@ -3,7 +3,6 @@ export * from './cdn';
|
||||
/* Auth */
|
||||
export * from './auth';
|
||||
/* MCP */
|
||||
export * from './mcp/registry/MCPServersRegistry';
|
||||
export * from './mcp/MCPManager';
|
||||
export * from './mcp/connection';
|
||||
export * from './mcp/oauth';
|
||||
|
||||
@@ -9,7 +9,6 @@ import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { MCPConnection } from './connection';
|
||||
import { processMCPEnv } from '~/utils';
|
||||
import { withTimeout } from '~/utils/promise';
|
||||
|
||||
/**
|
||||
* Factory for creating MCP connections with optional OAuth authentication.
|
||||
@@ -232,11 +231,14 @@ export class MCPConnectionFactory {
|
||||
/** Attempts to establish connection with timeout handling */
|
||||
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
|
||||
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
|
||||
await withTimeout(
|
||||
this.connectTo(connection),
|
||||
connectTimeout,
|
||||
`Connection timeout after ${connectTimeout}ms`,
|
||||
const connectionTimeout = new Promise<void>((_, reject) =>
|
||||
setTimeout(
|
||||
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
|
||||
connectTimeout,
|
||||
),
|
||||
);
|
||||
const connectionAttempt = this.connectTo(connection);
|
||||
await Promise.race([connectionAttempt, connectionTimeout]);
|
||||
|
||||
if (await connection.isConnected()) return;
|
||||
logger.error(`${this.logPrefix} Failed to establish connection.`);
|
||||
|
||||
@@ -5,14 +5,11 @@ import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.j
|
||||
import type { TokenMethods } from '@librechat/data-schemas';
|
||||
import type { FlowStateManager } from '~/flow/manager';
|
||||
import type { TUser } from 'librechat-data-provider';
|
||||
import type { MCPOAuthTokens } from './oauth';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
import type { RequestBody } from '~/types';
|
||||
import type * as t from './types';
|
||||
import { UserConnectionManager } from './UserConnectionManager';
|
||||
import { ConnectionsRepository } from './ConnectionsRepository';
|
||||
import { MCPServerInspector } from './registry/MCPServerInspector';
|
||||
import { MCPServersInitializer } from './registry/MCPServersInitializer';
|
||||
import { mcpServersRegistry as registry } from './registry/MCPServersRegistry';
|
||||
import { UserConnectionManager } from '~/mcp/UserConnectionManager';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { formatToolContent } from './parsers';
|
||||
import { MCPConnection } from './connection';
|
||||
import { processMCPEnv } from '~/utils/env';
|
||||
@@ -27,8 +24,8 @@ export class MCPManager extends UserConnectionManager {
|
||||
/** Creates and initializes the singleton MCPManager instance */
|
||||
public static async createInstance(configs: t.MCPServers): Promise<MCPManager> {
|
||||
if (MCPManager.instance) throw new Error('MCPManager has already been initialized.');
|
||||
MCPManager.instance = new MCPManager();
|
||||
await MCPManager.instance.initialize(configs);
|
||||
MCPManager.instance = new MCPManager(configs);
|
||||
await MCPManager.instance.initialize();
|
||||
return MCPManager.instance;
|
||||
}
|
||||
|
||||
@@ -39,10 +36,9 @@ export class MCPManager extends UserConnectionManager {
|
||||
}
|
||||
|
||||
/** Initializes the MCPManager by setting up server registry and app connections */
|
||||
public async initialize(configs: t.MCPServers) {
|
||||
await MCPServersInitializer.initialize(configs);
|
||||
const appConfigs = await registry.sharedAppServers.getAll();
|
||||
this.appConnections = new ConnectionsRepository(appConfigs);
|
||||
public async initialize() {
|
||||
await this.serversRegistry.initialize();
|
||||
this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs);
|
||||
}
|
||||
|
||||
/** Retrieves an app-level or user-specific connection based on provided arguments */
|
||||
@@ -66,18 +62,36 @@ export class MCPManager extends UserConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns all available tool functions from app-level connections */
|
||||
public async getAppToolFunctions(): Promise<t.LCAvailableTools> {
|
||||
const toolFunctions: t.LCAvailableTools = {};
|
||||
const configs = await registry.getAllServerConfigs();
|
||||
for (const config of Object.values(configs)) {
|
||||
if (config.toolFunctions != null) {
|
||||
Object.assign(toolFunctions, config.toolFunctions);
|
||||
}
|
||||
}
|
||||
return toolFunctions;
|
||||
/** Get servers that require OAuth */
|
||||
public getOAuthServers(): Set<string> {
|
||||
return this.serversRegistry.oauthServers;
|
||||
}
|
||||
|
||||
/** Get all servers */
|
||||
public getAllServers(): t.MCPServers {
|
||||
return this.serversRegistry.rawConfigs;
|
||||
}
|
||||
|
||||
/** Returns all available tool functions from app-level connections */
|
||||
public getAppToolFunctions(): t.LCAvailableTools {
|
||||
return this.serversRegistry.toolFunctions;
|
||||
}
|
||||
|
||||
/** Returns all available tool functions from all connections available to user */
|
||||
public async getAllToolFunctions(userId: string): Promise<t.LCAvailableTools | null> {
|
||||
const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions();
|
||||
const userConnections = this.getUserConnections(userId);
|
||||
if (!userConnections || userConnections.size === 0) {
|
||||
return allToolFunctions;
|
||||
}
|
||||
|
||||
for (const [serverName, connection] of userConnections.entries()) {
|
||||
const toolFunctions = await this.serversRegistry.getToolFunctions(serverName, connection);
|
||||
Object.assign(allToolFunctions, toolFunctions);
|
||||
}
|
||||
|
||||
return allToolFunctions;
|
||||
}
|
||||
/** Returns all available tool functions from all connections available to user */
|
||||
public async getServerToolFunctions(
|
||||
userId: string,
|
||||
@@ -85,7 +99,7 @@ export class MCPManager extends UserConnectionManager {
|
||||
): Promise<t.LCAvailableTools | null> {
|
||||
try {
|
||||
if (this.appConnections?.has(serverName)) {
|
||||
return MCPServerInspector.getToolFunctions(
|
||||
return this.serversRegistry.getToolFunctions(
|
||||
serverName,
|
||||
await this.appConnections.get(serverName),
|
||||
);
|
||||
@@ -99,7 +113,7 @@ export class MCPManager extends UserConnectionManager {
|
||||
return null;
|
||||
}
|
||||
|
||||
return MCPServerInspector.getToolFunctions(serverName, userConnections.get(serverName)!);
|
||||
return this.serversRegistry.getToolFunctions(serverName, userConnections.get(serverName)!);
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`[getServerToolFunctions] Error getting tool functions for server ${serverName}`,
|
||||
@@ -114,14 +128,8 @@ export class MCPManager extends UserConnectionManager {
|
||||
* @param serverNames Optional array of server names. If not provided or empty, returns all servers.
|
||||
* @returns Object mapping server names to their instructions
|
||||
*/
|
||||
private async getInstructions(serverNames?: string[]): Promise<Record<string, string>> {
|
||||
const instructions: Record<string, string> = {};
|
||||
const configs = await registry.getAllServerConfigs();
|
||||
for (const [serverName, config] of Object.entries(configs)) {
|
||||
if (config.serverInstructions != null) {
|
||||
instructions[serverName] = config.serverInstructions as string;
|
||||
}
|
||||
}
|
||||
public getInstructions(serverNames?: string[]): Record<string, string> {
|
||||
const instructions = this.serversRegistry.serverInstructions;
|
||||
if (!serverNames) return instructions;
|
||||
return pick(instructions, serverNames);
|
||||
}
|
||||
@@ -131,9 +139,9 @@ export class MCPManager extends UserConnectionManager {
|
||||
* @param serverNames Optional array of server names to include. If not provided, includes all servers.
|
||||
* @returns Formatted instructions string ready for context injection
|
||||
*/
|
||||
public async formatInstructionsForContext(serverNames?: string[]): Promise<string> {
|
||||
public formatInstructionsForContext(serverNames?: string[]): string {
|
||||
/** Instructions for specified servers or all stored instructions */
|
||||
const instructionsToInclude = await this.getInstructions(serverNames);
|
||||
const instructionsToInclude = this.getInstructions(serverNames);
|
||||
|
||||
if (Object.keys(instructionsToInclude).length === 0) {
|
||||
return '';
|
||||
@@ -217,7 +225,7 @@ Please follow these instructions when using tools from the respective MCP server
|
||||
);
|
||||
}
|
||||
|
||||
const rawConfig = (await registry.getServerConfig(serverName, userId)) as t.MCPOptions;
|
||||
const rawConfig = this.getRawConfig(serverName) as t.MCPOptions;
|
||||
const currentOptions = processMCPEnv({
|
||||
user,
|
||||
options: rawConfig,
|
||||
|
||||
230
packages/api/src/mcp/MCPServersRegistry.ts
Normal file
230
packages/api/src/mcp/MCPServersRegistry.ts
Normal file
@@ -0,0 +1,230 @@
|
||||
import mapValues from 'lodash/mapValues';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import type { JsonSchemaType } from '@librechat/data-schemas';
|
||||
import type { MCPConnection } from '~/mcp/connection';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
import { processMCPEnv, isEnabled } from '~/utils';
|
||||
|
||||
const DEFAULT_MCP_INIT_TIMEOUT_MS = 30_000;
|
||||
|
||||
function getMCPInitTimeout(): number {
|
||||
return process.env.MCP_INIT_TIMEOUT_MS != null
|
||||
? parseInt(process.env.MCP_INIT_TIMEOUT_MS)
|
||||
: DEFAULT_MCP_INIT_TIMEOUT_MS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages MCP server configurations and metadata discovery.
|
||||
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
|
||||
* Determines which servers are for app-level connections.
|
||||
* Has its own connections repository. All connections are disconnected after initialization.
|
||||
*/
|
||||
export class MCPServersRegistry {
|
||||
private initialized: boolean = false;
|
||||
private connections: ConnectionsRepository;
|
||||
private initTimeoutMs: number;
|
||||
|
||||
public readonly rawConfigs: t.MCPServers;
|
||||
public readonly parsedConfigs: Record<string, t.ParsedServerConfig>;
|
||||
|
||||
public oauthServers: Set<string> = new Set();
|
||||
public serverInstructions: Record<string, string> = {};
|
||||
public toolFunctions: t.LCAvailableTools = {};
|
||||
public appServerConfigs: t.MCPServers = {};
|
||||
|
||||
constructor(configs: t.MCPServers) {
|
||||
this.rawConfigs = configs;
|
||||
this.parsedConfigs = mapValues(configs, (con) => processMCPEnv({ options: con }));
|
||||
this.connections = new ConnectionsRepository(configs);
|
||||
this.initTimeoutMs = getMCPInitTimeout();
|
||||
}
|
||||
|
||||
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
|
||||
public async initialize(): Promise<void> {
|
||||
if (this.initialized) return;
|
||||
this.initialized = true;
|
||||
|
||||
const serverNames = Object.keys(this.parsedConfigs);
|
||||
|
||||
await Promise.allSettled(
|
||||
serverNames.map((serverName) => this.initializeServerWithTimeout(serverName)),
|
||||
);
|
||||
}
|
||||
|
||||
/** Wraps server initialization with a timeout to prevent hanging */
|
||||
private async initializeServerWithTimeout(serverName: string): Promise<void> {
|
||||
let timeoutId: NodeJS.Timeout | null = null;
|
||||
|
||||
try {
|
||||
await Promise.race([
|
||||
this.initializeServer(serverName),
|
||||
new Promise<never>((_, reject) => {
|
||||
timeoutId = setTimeout(() => {
|
||||
reject(new Error('Server initialization timed out'));
|
||||
}, this.initTimeoutMs);
|
||||
}),
|
||||
]);
|
||||
} catch (error) {
|
||||
logger.warn(`${this.prefix(serverName)} Server initialization failed:`, error);
|
||||
throw error;
|
||||
} finally {
|
||||
if (timeoutId != null) {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Initializes a single server with all its metadata and adds it to appropriate collections */
|
||||
private async initializeServer(serverName: string): Promise<void> {
|
||||
const start = Date.now();
|
||||
|
||||
const config = this.parsedConfigs[serverName];
|
||||
|
||||
// 1. Detect OAuth requirements if not already specified
|
||||
try {
|
||||
await this.fetchOAuthRequirement(serverName);
|
||||
|
||||
if (config.startup !== false && !config.requiresOAuth) {
|
||||
await Promise.allSettled([
|
||||
this.fetchServerInstructions(serverName).catch((error) =>
|
||||
logger.warn(`${this.prefix(serverName)} Failed to fetch server instructions:`, error),
|
||||
),
|
||||
this.fetchServerCapabilities(serverName).catch((error) =>
|
||||
logger.warn(`${this.prefix(serverName)} Failed to fetch server capabilities:`, error),
|
||||
),
|
||||
]);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`${this.prefix(serverName)} Failed to initialize server:`, error);
|
||||
}
|
||||
|
||||
// 2. Fetch tool functions for this server if a connection was established
|
||||
const getToolFunctions = async (): Promise<t.LCAvailableTools | null> => {
|
||||
try {
|
||||
const loadedConns = await this.connections.getLoaded();
|
||||
const conn = loadedConns.get(serverName);
|
||||
if (conn == null) {
|
||||
return null;
|
||||
}
|
||||
return this.getToolFunctions(serverName, conn);
|
||||
} catch (error) {
|
||||
logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
const toolFunctions = await getToolFunctions();
|
||||
|
||||
// 3. Disconnect this server's connection if it was established (fire-and-forget)
|
||||
void this.connections.disconnect(serverName);
|
||||
|
||||
// 4. Side effects
|
||||
// 4.1 Add to OAuth servers if needed
|
||||
if (config.requiresOAuth) {
|
||||
this.oauthServers.add(serverName);
|
||||
}
|
||||
// 4.2 Add server instructions if available
|
||||
if (config.serverInstructions != null) {
|
||||
this.serverInstructions[serverName] = config.serverInstructions as string;
|
||||
}
|
||||
// 4.3 Add to app server configs if eligible (startup enabled, non-OAuth servers)
|
||||
if (config.startup !== false && config.requiresOAuth === false) {
|
||||
this.appServerConfigs[serverName] = this.rawConfigs[serverName];
|
||||
}
|
||||
// 4.4 Add tool functions if available
|
||||
if (toolFunctions != null) {
|
||||
Object.assign(this.toolFunctions, toolFunctions);
|
||||
}
|
||||
|
||||
const duration = Date.now() - start;
|
||||
this.logUpdatedConfig(serverName, duration);
|
||||
}
|
||||
|
||||
/** Converts server tools to LibreChat-compatible tool functions format */
|
||||
public async getToolFunctions(
|
||||
serverName: string,
|
||||
conn: MCPConnection,
|
||||
): Promise<t.LCAvailableTools> {
|
||||
const { tools }: t.MCPToolListResponse = await conn.client.listTools();
|
||||
|
||||
const toolFunctions: t.LCAvailableTools = {};
|
||||
tools.forEach((tool) => {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
toolFunctions[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema as JsonSchemaType,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
return toolFunctions;
|
||||
}
|
||||
|
||||
/** Determines if server requires OAuth if not already specified in the config */
|
||||
private async fetchOAuthRequirement(serverName: string): Promise<boolean> {
|
||||
const config = this.parsedConfigs[serverName];
|
||||
if (config.requiresOAuth != null) return config.requiresOAuth;
|
||||
if (config.url == null) return (config.requiresOAuth = false);
|
||||
if (config.startup === false) return (config.requiresOAuth = false);
|
||||
|
||||
const result = await detectOAuthRequirement(config.url);
|
||||
config.requiresOAuth = result.requiresOAuth;
|
||||
config.oauthMetadata = result.metadata;
|
||||
return config.requiresOAuth;
|
||||
}
|
||||
|
||||
/** Retrieves server instructions from MCP server if enabled in the config */
|
||||
private async fetchServerInstructions(serverName: string): Promise<void> {
|
||||
const config = this.parsedConfigs[serverName];
|
||||
if (!config.serverInstructions) return;
|
||||
|
||||
// If it's a string that's not "true", it's a custom instruction
|
||||
if (typeof config.serverInstructions === 'string' && !isEnabled(config.serverInstructions)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Fetch from server if true (boolean) or "true" (string)
|
||||
const conn = await this.connections.get(serverName);
|
||||
config.serverInstructions = conn.client.getInstructions();
|
||||
if (!config.serverInstructions) {
|
||||
logger.warn(`${this.prefix(serverName)} No server instructions available`);
|
||||
}
|
||||
}
|
||||
|
||||
/** Fetches server capabilities and available tools list */
|
||||
private async fetchServerCapabilities(serverName: string): Promise<void> {
|
||||
const config = this.parsedConfigs[serverName];
|
||||
const conn = await this.connections.get(serverName);
|
||||
const capabilities = conn.client.getServerCapabilities();
|
||||
if (!capabilities) return;
|
||||
config.capabilities = JSON.stringify(capabilities);
|
||||
if (!capabilities.tools) return;
|
||||
const tools = await conn.client.listTools();
|
||||
config.tools = tools.tools.map((tool) => tool.name).join(', ');
|
||||
}
|
||||
|
||||
// Logs server configuration summary after initialization
|
||||
private logUpdatedConfig(serverName: string, initDuration: number): void {
|
||||
const prefix = this.prefix(serverName);
|
||||
const config = this.parsedConfigs[serverName];
|
||||
logger.info(`${prefix} -------------------------------------------------┐`);
|
||||
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
|
||||
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
|
||||
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
|
||||
logger.info(`${prefix} Tools: ${config.tools}`);
|
||||
logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`);
|
||||
logger.info(`${prefix} Initialized in: ${initDuration}ms`);
|
||||
logger.info(`${prefix} -------------------------------------------------┘`);
|
||||
}
|
||||
|
||||
// Returns formatted log prefix for server messages
|
||||
private prefix(serverName: string): string {
|
||||
return `[MCP][${serverName}]`;
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { mcpServersRegistry as serversRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||
import { MCPConnection } from './connection';
|
||||
import type * as t from './types';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
@@ -14,6 +14,7 @@ import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
* https://github.com/danny-avila/LibreChat/discussions/8790
|
||||
*/
|
||||
export abstract class UserConnectionManager {
|
||||
protected readonly serversRegistry: MCPServersRegistry;
|
||||
// Connections shared by all users.
|
||||
public appConnections: ConnectionsRepository | null = null;
|
||||
// Connections per userId -> serverName -> connection
|
||||
@@ -22,6 +23,15 @@ export abstract class UserConnectionManager {
|
||||
protected userLastActivity: Map<string, number> = new Map();
|
||||
protected readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable)
|
||||
|
||||
constructor(serverConfigs: t.MCPServers) {
|
||||
this.serversRegistry = new MCPServersRegistry(serverConfigs);
|
||||
}
|
||||
|
||||
/** fetches am MCP Server config from the registry */
|
||||
public getRawConfig(serverName: string): t.MCPOptions | undefined {
|
||||
return this.serversRegistry.rawConfigs[serverName];
|
||||
}
|
||||
|
||||
/** Updates the last activity timestamp for a user */
|
||||
protected updateUserLastActivity(userId: string): void {
|
||||
const now = Date.now();
|
||||
@@ -96,7 +106,7 @@ export abstract class UserConnectionManager {
|
||||
logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`);
|
||||
}
|
||||
|
||||
const config = await serversRegistry.getServerConfig(serverName, userId);
|
||||
const config = this.serversRegistry.parsedConfigs[serverName];
|
||||
if (!config) {
|
||||
throw new McpError(
|
||||
ErrorCode.InvalidRequest,
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { MCPManager } from '~/mcp/MCPManager';
|
||||
import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
|
||||
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
|
||||
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { MCPConnection } from '../connection';
|
||||
|
||||
@@ -17,24 +15,7 @@ jest.mock('@librechat/data-schemas', () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/registry/MCPServersRegistry', () => ({
|
||||
mcpServersRegistry: {
|
||||
sharedAppServers: {
|
||||
getAll: jest.fn(),
|
||||
},
|
||||
getServerConfig: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/registry/MCPServersInitializer', () => ({
|
||||
MCPServersInitializer: {
|
||||
initialize: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/registry/MCPServerInspector');
|
||||
jest.mock('~/mcp/MCPServersRegistry');
|
||||
jest.mock('~/mcp/ConnectionsRepository');
|
||||
|
||||
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||
@@ -47,13 +28,21 @@ describe('MCPManager', () => {
|
||||
// Reset MCPManager singleton state
|
||||
(MCPManager as unknown as { instance: null }).instance = null;
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Set up default mock implementations
|
||||
(MCPServersInitializer.initialize as jest.Mock).mockResolvedValue(undefined);
|
||||
(mcpServersRegistry.sharedAppServers.getAll as jest.Mock).mockResolvedValue({});
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({});
|
||||
});
|
||||
|
||||
function mockRegistry(
|
||||
registryConfig: Partial<MCPServersRegistry>,
|
||||
): jest.MockedClass<typeof MCPServersRegistry> {
|
||||
const mock = {
|
||||
initialize: jest.fn().mockResolvedValue(undefined),
|
||||
getToolFunctions: jest.fn().mockResolvedValue(null),
|
||||
...registryConfig,
|
||||
};
|
||||
return (MCPServersRegistry as jest.MockedClass<typeof MCPServersRegistry>).mockImplementation(
|
||||
() => mock as unknown as MCPServersRegistry,
|
||||
);
|
||||
}
|
||||
|
||||
function mockAppConnections(
|
||||
appConnectionsConfig: Partial<ConnectionsRepository>,
|
||||
): jest.MockedClass<typeof ConnectionsRepository> {
|
||||
@@ -77,229 +66,12 @@ describe('MCPManager', () => {
|
||||
};
|
||||
}
|
||||
|
||||
describe('getAppToolFunctions', () => {
|
||||
it('should return empty object when no servers have tool functions', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
server1: { type: 'stdio', command: 'test', args: [] },
|
||||
server2: { type: 'stdio', command: 'test2', args: [] },
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.getAppToolFunctions();
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('should collect tool functions from multiple servers', async () => {
|
||||
const toolFunctions1 = {
|
||||
tool1_mcp_server1: {
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: 'tool1_mcp_server1',
|
||||
description: 'Tool 1',
|
||||
parameters: { type: 'object' as const },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const toolFunctions2 = {
|
||||
tool2_mcp_server2: {
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: 'tool2_mcp_server2',
|
||||
description: 'Tool 2',
|
||||
parameters: { type: 'object' as const },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
server1: {
|
||||
type: 'stdio',
|
||||
command: 'test',
|
||||
args: [],
|
||||
toolFunctions: toolFunctions1,
|
||||
},
|
||||
server2: {
|
||||
type: 'stdio',
|
||||
command: 'test2',
|
||||
args: [],
|
||||
toolFunctions: toolFunctions2,
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.getAppToolFunctions();
|
||||
|
||||
expect(result).toEqual({
|
||||
...toolFunctions1,
|
||||
...toolFunctions2,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle servers with null or undefined toolFunctions', async () => {
|
||||
const toolFunctions1 = {
|
||||
tool1_mcp_server1: {
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: 'tool1_mcp_server1',
|
||||
description: 'Tool 1',
|
||||
parameters: { type: 'object' as const },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
server1: {
|
||||
type: 'stdio',
|
||||
command: 'test',
|
||||
args: [],
|
||||
toolFunctions: toolFunctions1,
|
||||
},
|
||||
server2: {
|
||||
type: 'stdio',
|
||||
command: 'test2',
|
||||
args: [],
|
||||
toolFunctions: null,
|
||||
},
|
||||
server3: {
|
||||
type: 'stdio',
|
||||
command: 'test3',
|
||||
args: [],
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.getAppToolFunctions();
|
||||
|
||||
expect(result).toEqual(toolFunctions1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatInstructionsForContext', () => {
|
||||
it('should return empty string when no servers have instructions', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
server1: { type: 'stdio', command: 'test', args: [] },
|
||||
server2: { type: 'stdio', command: 'test2', args: [] },
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext();
|
||||
|
||||
expect(result).toBe('');
|
||||
});
|
||||
|
||||
it('should format instructions from multiple servers', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
github: {
|
||||
type: 'sse',
|
||||
url: 'https://api.github.com',
|
||||
serverInstructions: 'Use GitHub API with care',
|
||||
},
|
||||
files: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['files.js'],
|
||||
serverInstructions: 'Only read/write files in allowed directories',
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext();
|
||||
|
||||
expect(result).toContain('# MCP Server Instructions');
|
||||
expect(result).toContain('## github MCP Server Instructions');
|
||||
expect(result).toContain('Use GitHub API with care');
|
||||
expect(result).toContain('## files MCP Server Instructions');
|
||||
expect(result).toContain('Only read/write files in allowed directories');
|
||||
});
|
||||
|
||||
it('should filter instructions by server names when provided', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
github: {
|
||||
type: 'sse',
|
||||
url: 'https://api.github.com',
|
||||
serverInstructions: 'Use GitHub API with care',
|
||||
},
|
||||
files: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['files.js'],
|
||||
serverInstructions: 'Only read/write files in allowed directories',
|
||||
},
|
||||
database: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['db.js'],
|
||||
serverInstructions: 'Be careful with database operations',
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext(['github', 'database']);
|
||||
|
||||
expect(result).toContain('## github MCP Server Instructions');
|
||||
expect(result).toContain('Use GitHub API with care');
|
||||
expect(result).toContain('## database MCP Server Instructions');
|
||||
expect(result).toContain('Be careful with database operations');
|
||||
expect(result).not.toContain('files');
|
||||
expect(result).not.toContain('Only read/write files in allowed directories');
|
||||
});
|
||||
|
||||
it('should handle servers with null or undefined instructions', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
github: {
|
||||
type: 'sse',
|
||||
url: 'https://api.github.com',
|
||||
serverInstructions: 'Use GitHub API with care',
|
||||
},
|
||||
files: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['files.js'],
|
||||
serverInstructions: null,
|
||||
},
|
||||
database: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['db.js'],
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext();
|
||||
|
||||
expect(result).toContain('## github MCP Server Instructions');
|
||||
expect(result).toContain('Use GitHub API with care');
|
||||
expect(result).not.toContain('files');
|
||||
expect(result).not.toContain('database');
|
||||
});
|
||||
|
||||
it('should return empty string when filtered servers have no instructions', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
github: {
|
||||
type: 'sse',
|
||||
url: 'https://api.github.com',
|
||||
serverInstructions: 'Use GitHub API with care',
|
||||
},
|
||||
files: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['files.js'],
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext(['files']);
|
||||
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getServerToolFunctions', () => {
|
||||
it('should catch and handle errors gracefully', async () => {
|
||||
(MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => {
|
||||
throw new Error('Connection failed');
|
||||
mockRegistry({
|
||||
getToolFunctions: jest.fn(() => {
|
||||
throw new Error('Connection failed');
|
||||
}),
|
||||
});
|
||||
|
||||
mockAppConnections({
|
||||
@@ -318,7 +90,9 @@ describe('MCPManager', () => {
|
||||
});
|
||||
|
||||
it('should catch synchronous errors from getUserConnections', async () => {
|
||||
(MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn().mockResolvedValue({});
|
||||
mockRegistry({
|
||||
getToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
});
|
||||
|
||||
mockAppConnections({
|
||||
has: jest.fn().mockReturnValue(false),
|
||||
@@ -352,9 +126,9 @@ describe('MCPManager', () => {
|
||||
},
|
||||
};
|
||||
|
||||
(MCPServerInspector.getToolFunctions as jest.Mock) = jest
|
||||
.fn()
|
||||
.mockResolvedValue(expectedTools);
|
||||
mockRegistry({
|
||||
getToolFunctions: jest.fn().mockResolvedValue(expectedTools),
|
||||
});
|
||||
|
||||
mockAppConnections({
|
||||
has: jest.fn().mockReturnValue(true),
|
||||
@@ -371,8 +145,10 @@ describe('MCPManager', () => {
|
||||
it('should include specific server name in error messages', async () => {
|
||||
const specificServerName = 'github_mcp_server';
|
||||
|
||||
(MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => {
|
||||
throw new Error('Server specific error');
|
||||
mockRegistry({
|
||||
getToolFunctions: jest.fn(() => {
|
||||
throw new Error('Server specific error');
|
||||
}),
|
||||
});
|
||||
|
||||
mockAppConnections({
|
||||
|
||||
595
packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts
Normal file
595
packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts
Normal file
@@ -0,0 +1,595 @@
|
||||
import { join } from 'path';
|
||||
import { readFileSync } from 'fs';
|
||||
import { load as yamlLoad } from 'js-yaml';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { OAuthDetectionResult } from '~/mcp/oauth/detectOAuth';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
|
||||
// Mock external dependencies
|
||||
jest.mock('../oauth/detectOAuth');
|
||||
jest.mock('../ConnectionsRepository');
|
||||
jest.mock('../connection');
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock processMCPEnv to verify it's called and adds a processed marker
|
||||
jest.mock('~/utils', () => ({
|
||||
...jest.requireActual('~/utils'),
|
||||
processMCPEnv: jest.fn(({ options }) => ({
|
||||
...options,
|
||||
_processed: true, // Simple marker to verify processing occurred
|
||||
})),
|
||||
}));
|
||||
|
||||
const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction<
|
||||
typeof detectOAuthRequirement
|
||||
>;
|
||||
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||
|
||||
describe('MCPServersRegistry - Initialize Function', () => {
|
||||
let rawConfigs: t.MCPServers;
|
||||
let expectedParsedConfigs: Record<string, t.ParsedServerConfig>;
|
||||
let mockConnectionsRepo: jest.Mocked<ConnectionsRepository>;
|
||||
let mockConnections: Map<string, jest.Mocked<MCPConnection>>;
|
||||
|
||||
beforeEach(() => {
|
||||
// Load fixtures
|
||||
const rawConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.rawConfigs.yml');
|
||||
const parsedConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.parsedConfigs.yml');
|
||||
|
||||
rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers;
|
||||
expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record<
|
||||
string,
|
||||
t.ParsedServerConfig
|
||||
>;
|
||||
|
||||
// Setup mock connections
|
||||
mockConnections = new Map();
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
|
||||
serverNames.forEach((serverName) => {
|
||||
const mockClient = {
|
||||
listTools: jest.fn(),
|
||||
getInstructions: jest.fn(),
|
||||
getServerCapabilities: jest.fn(),
|
||||
};
|
||||
const mockConnection = {
|
||||
client: mockClient,
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
// Setup mock responses based on expected configs
|
||||
const expectedConfig = expectedParsedConfigs[serverName];
|
||||
|
||||
// Mock listTools response
|
||||
if (expectedConfig.tools) {
|
||||
const toolNames = expectedConfig.tools.split(', ');
|
||||
const tools = toolNames.map((name: string) => ({
|
||||
name,
|
||||
description: `Description for ${name}`,
|
||||
inputSchema: {
|
||||
type: 'object' as const,
|
||||
properties: {
|
||||
input: { type: 'string' },
|
||||
},
|
||||
},
|
||||
}));
|
||||
(mockClient.listTools as jest.Mock).mockResolvedValue({ tools });
|
||||
} else {
|
||||
(mockClient.listTools as jest.Mock).mockResolvedValue({ tools: [] });
|
||||
}
|
||||
|
||||
// Mock getInstructions response
|
||||
if (expectedConfig.serverInstructions) {
|
||||
(mockClient.getInstructions as jest.Mock).mockReturnValue(
|
||||
expectedConfig.serverInstructions as string,
|
||||
);
|
||||
} else {
|
||||
(mockClient.getInstructions as jest.Mock).mockReturnValue(undefined);
|
||||
}
|
||||
|
||||
// Mock getServerCapabilities response
|
||||
if (expectedConfig.capabilities) {
|
||||
const capabilities = JSON.parse(expectedConfig.capabilities) as Record<string, unknown>;
|
||||
(mockClient.getServerCapabilities as jest.Mock).mockReturnValue(capabilities);
|
||||
} else {
|
||||
(mockClient.getServerCapabilities as jest.Mock).mockReturnValue(undefined);
|
||||
}
|
||||
|
||||
mockConnections.set(serverName, mockConnection);
|
||||
});
|
||||
|
||||
// Setup ConnectionsRepository mock
|
||||
mockConnectionsRepo = {
|
||||
get: jest.fn(),
|
||||
getLoaded: jest.fn(),
|
||||
disconnectAll: jest.fn(),
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
} as unknown as jest.Mocked<ConnectionsRepository>;
|
||||
|
||||
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
||||
const connection = mockConnections.get(serverName);
|
||||
if (!connection) {
|
||||
throw new Error(`Connection not found for server: ${serverName}`);
|
||||
}
|
||||
return Promise.resolve(connection);
|
||||
});
|
||||
|
||||
mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections);
|
||||
|
||||
(ConnectionsRepository as jest.Mock).mockImplementation(() => mockConnectionsRepo);
|
||||
|
||||
// Setup OAuth detection mock with deterministic results
|
||||
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||
const oauthResults: Record<string, OAuthDetectionResult> = {
|
||||
'https://api.github.com/mcp': {
|
||||
requiresOAuth: true,
|
||||
method: 'protected-resource-metadata',
|
||||
metadata: {
|
||||
authorization_url: 'https://github.com/login/oauth/authorize',
|
||||
token_url: 'https://github.com/login/oauth/access_token',
|
||||
},
|
||||
},
|
||||
'https://api.disabled.com/mcp': {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
'https://api.public.com/mcp': {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
};
|
||||
|
||||
return Promise.resolve(
|
||||
oauthResults[url] || { requiresOAuth: false, method: 'no-metadata-found', metadata: null },
|
||||
);
|
||||
});
|
||||
|
||||
// Clear all mocks
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
delete process.env.MCP_INIT_TIMEOUT_MS;
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('initialize() method', () => {
|
||||
it('should only run initialization once', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
await registry.initialize();
|
||||
await registry.initialize(); // Second call should not re-run
|
||||
|
||||
// Verify that connections are only requested for servers that need them
|
||||
// (servers with serverInstructions=true or all servers for capabilities)
|
||||
expect(mockConnectionsRepo.get).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should set all public properties correctly after initialization', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Verify initial state
|
||||
expect(registry.oauthServers.size).toBe(0);
|
||||
expect(registry.serverInstructions).toEqual({});
|
||||
expect(registry.toolFunctions).toEqual({});
|
||||
expect(registry.appServerConfigs).toEqual({});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Test oauthServers Set
|
||||
expect(registry.oauthServers).toEqual(
|
||||
new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']),
|
||||
);
|
||||
|
||||
// Test serverInstructions - OAuth servers keep their original boolean value, non-OAuth fetch actual strings
|
||||
expect(registry.serverInstructions).toEqual({
|
||||
stdio_server: 'Follow these instructions for stdio server',
|
||||
oauth_server: true,
|
||||
non_oauth_server: 'Public API instructions',
|
||||
});
|
||||
|
||||
// Test appServerConfigs (startup enabled, non-OAuth servers only)
|
||||
expect(registry.appServerConfigs).toEqual({
|
||||
stdio_server: rawConfigs.stdio_server,
|
||||
websocket_server: rawConfigs.websocket_server,
|
||||
non_oauth_server: rawConfigs.non_oauth_server,
|
||||
});
|
||||
|
||||
// Test toolFunctions (only non-OAuth servers get their tools fetched during initialization)
|
||||
const expectedToolFunctions = {
|
||||
file_read_mcp_stdio_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_stdio_server',
|
||||
description: 'Description for file_read',
|
||||
parameters: { type: 'object', properties: { input: { type: 'string' } } },
|
||||
},
|
||||
},
|
||||
file_write_mcp_stdio_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_write_mcp_stdio_server',
|
||||
description: 'Description for file_write',
|
||||
parameters: { type: 'object', properties: { input: { type: 'string' } } },
|
||||
},
|
||||
},
|
||||
};
|
||||
expect(registry.toolFunctions).toEqual(expectedToolFunctions);
|
||||
});
|
||||
|
||||
it('should handle errors gracefully and continue initialization of other servers', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Make one specific server throw an error during OAuth detection
|
||||
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||
if (url === 'https://api.github.com/mcp') {
|
||||
return Promise.reject(new Error('OAuth detection failed'));
|
||||
}
|
||||
// Return normal responses for other servers
|
||||
const oauthResults: Record<string, OAuthDetectionResult> = {
|
||||
'https://api.disabled.com/mcp': {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
'https://api.public.com/mcp': {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
};
|
||||
return Promise.resolve(
|
||||
oauthResults[url] ?? {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Should still initialize successfully for other servers
|
||||
expect(registry.oauthServers).toBeInstanceOf(Set);
|
||||
expect(registry.toolFunctions).toBeDefined();
|
||||
|
||||
// The failed server should not be in oauthServers (since it failed OAuth detection)
|
||||
expect(registry.oauthServers.has('oauth_server')).toBe(false);
|
||||
|
||||
// But other servers should still be processed successfully
|
||||
expect(registry.appServerConfigs).toHaveProperty('stdio_server');
|
||||
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
|
||||
|
||||
// Error should be logged as a warning at the higher level
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][oauth_server] Failed to initialize server:'),
|
||||
expect.any(Error),
|
||||
);
|
||||
});
|
||||
|
||||
it('should disconnect individual connections after each server initialization', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Verify disconnect was called for each server during initialization
|
||||
// All servers attempt to connect during initialization for metadata gathering
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length);
|
||||
});
|
||||
|
||||
it('should log configuration updates for each startup-enabled server', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
serverNames.forEach((serverName) => {
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] URL:`),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] OAuth Required:`),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] Capabilities:`),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] Tools:`),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] Server Instructions:`),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should have parsedConfigs matching the expected fixture after initialization', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Compare the actual parsedConfigs against the expected fixture
|
||||
expect(registry.parsedConfigs).toEqual(expectedParsedConfigs);
|
||||
});
|
||||
|
||||
it('should handle serverInstructions as string "true" correctly and fetch from server', async () => {
|
||||
// Create test config with serverInstructions as string "true"
|
||||
const testConfig: t.MCPServers = {
|
||||
test_server_string_true: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
serverInstructions: 'true', // Simulating string "true" from YAML parsing
|
||||
},
|
||||
test_server_custom_string: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
serverInstructions: 'Custom instructions here',
|
||||
},
|
||||
test_server_bool_true: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
serverInstructions: true,
|
||||
},
|
||||
};
|
||||
|
||||
const registry = new MCPServersRegistry(testConfig);
|
||||
|
||||
// Setup mock connection for servers that should fetch
|
||||
const mockClient = {
|
||||
listTools: jest.fn().mockResolvedValue({ tools: [] }),
|
||||
getInstructions: jest.fn().mockReturnValue('Fetched instructions from server'),
|
||||
getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }),
|
||||
};
|
||||
const mockConnection = {
|
||||
client: mockClient,
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
mockConnectionsRepo.get.mockResolvedValue(mockConnection);
|
||||
mockConnectionsRepo.getLoaded.mockResolvedValue(
|
||||
new Map([
|
||||
['test_server_string_true', mockConnection],
|
||||
['test_server_bool_true', mockConnection],
|
||||
]),
|
||||
);
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Verify that string "true" was treated as fetch-from-server
|
||||
expect(registry.parsedConfigs['test_server_string_true'].serverInstructions).toBe(
|
||||
'Fetched instructions from server',
|
||||
);
|
||||
|
||||
// Verify that custom string was kept as-is
|
||||
expect(registry.parsedConfigs['test_server_custom_string'].serverInstructions).toBe(
|
||||
'Custom instructions here',
|
||||
);
|
||||
|
||||
// Verify that boolean true also fetched from server
|
||||
expect(registry.parsedConfigs['test_server_bool_true'].serverInstructions).toBe(
|
||||
'Fetched instructions from server',
|
||||
);
|
||||
|
||||
// Verify getInstructions was called for both "true" cases
|
||||
expect(mockClient.getInstructions).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should use Promise.allSettled for individual server initialization', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Spy on Promise.allSettled to verify it's being used
|
||||
const allSettledSpy = jest.spyOn(Promise, 'allSettled');
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Verify Promise.allSettled was called with an array of server initialization promises
|
||||
expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)]));
|
||||
|
||||
// Verify it was called with the correct number of server promises
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
expect(allSettledSpy).toHaveBeenCalledWith(
|
||||
expect.arrayContaining(new Array(serverNames.length).fill(expect.any(Promise))),
|
||||
);
|
||||
|
||||
allSettledSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should isolate server failures and not affect other servers', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Make multiple servers fail in different ways
|
||||
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
||||
if (serverName === 'stdio_server') {
|
||||
// First server fails
|
||||
throw new Error('Connection failed for stdio_server');
|
||||
}
|
||||
if (serverName === 'websocket_server') {
|
||||
// Second server fails
|
||||
throw new Error('Connection failed for websocket_server');
|
||||
}
|
||||
// Other servers succeed
|
||||
const connection = mockConnections.get(serverName);
|
||||
if (!connection) {
|
||||
throw new Error(`Connection not found for server: ${serverName}`);
|
||||
}
|
||||
return Promise.resolve(connection);
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Despite failures, initialization should complete
|
||||
expect(registry.oauthServers).toBeInstanceOf(Set);
|
||||
expect(registry.toolFunctions).toBeDefined();
|
||||
|
||||
// Successful servers should still be processed
|
||||
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
|
||||
|
||||
// Failed servers should not crash the whole initialization
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][stdio_server] Failed to fetch server capabilities:'),
|
||||
expect.any(Error),
|
||||
);
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][websocket_server] Failed to fetch server capabilities:'),
|
||||
expect.any(Error),
|
||||
);
|
||||
});
|
||||
|
||||
it('should properly clean up connections even when some servers fail', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Track disconnect failures but suppress unhandled rejections
|
||||
const disconnectErrors: Error[] = [];
|
||||
mockConnectionsRepo.disconnect.mockImplementation((serverName: string) => {
|
||||
if (serverName === 'stdio_server') {
|
||||
const error = new Error('Disconnect failed');
|
||||
disconnectErrors.push(error);
|
||||
return Promise.reject(error).catch(() => {}); // Suppress unhandled rejection
|
||||
}
|
||||
return Promise.resolve();
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Should still attempt to disconnect all servers during initialization
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length);
|
||||
expect(disconnectErrors).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('should timeout individual server initialization after configured timeout', async () => {
|
||||
const timeout = 2000;
|
||||
// Create registry with a short timeout for testing
|
||||
process.env.MCP_INIT_TIMEOUT_MS = `${timeout}`;
|
||||
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Make one server hang indefinitely during OAuth detection
|
||||
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||
if (url === 'https://api.github.com/mcp') {
|
||||
// Slow init
|
||||
return new Promise((res) => setTimeout(res, timeout * 2));
|
||||
}
|
||||
// Return normal responses for other servers
|
||||
return Promise.resolve({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
});
|
||||
});
|
||||
|
||||
const start = Date.now();
|
||||
await registry.initialize();
|
||||
const duration = Date.now() - start;
|
||||
|
||||
// Should complete within reasonable time despite one server hanging
|
||||
// Allow some buffer for test execution overhead
|
||||
expect(duration).toBeLessThan(timeout * 1.5);
|
||||
|
||||
// The timeout should prevent the hanging server from blocking initialization
|
||||
// Other servers should still be processed successfully
|
||||
expect(registry.appServerConfigs).toHaveProperty('stdio_server');
|
||||
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
|
||||
}, 10_000); // 10 second Jest timeout
|
||||
|
||||
it('should skip tool function fetching if connection was not established', async () => {
|
||||
const testConfig: t.MCPServers = {
|
||||
server_with_connection: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
},
|
||||
server_without_connection: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'failing-command',
|
||||
},
|
||||
};
|
||||
|
||||
const registry = new MCPServersRegistry(testConfig);
|
||||
|
||||
const mockClient = {
|
||||
listTools: jest.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'test_tool',
|
||||
description: 'Test tool',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
}),
|
||||
getInstructions: jest.fn().mockReturnValue(undefined),
|
||||
getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }),
|
||||
};
|
||||
const mockConnection = {
|
||||
client: mockClient,
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
||||
if (serverName === 'server_with_connection') {
|
||||
return Promise.resolve(mockConnection);
|
||||
}
|
||||
throw new Error('Connection failed');
|
||||
});
|
||||
|
||||
// Mock getLoaded to return connections map - the real implementation returns all loaded connections at once
|
||||
mockConnectionsRepo.getLoaded.mockResolvedValue(
|
||||
new Map([['server_with_connection', mockConnection]]),
|
||||
);
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
expect(registry.toolFunctions).toHaveProperty('test_tool_mcp_server_with_connection');
|
||||
expect(Object.keys(registry.toolFunctions)).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('should handle getLoaded returning empty map gracefully', async () => {
|
||||
const testConfig: t.MCPServers = {
|
||||
test_server: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
},
|
||||
};
|
||||
|
||||
const registry = new MCPServersRegistry(testConfig);
|
||||
|
||||
mockConnectionsRepo.get.mockRejectedValue(new Error('All connections failed'));
|
||||
mockConnectionsRepo.getLoaded.mockResolvedValue(new Map());
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
expect(registry.toolFunctions).toEqual({});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,67 @@
|
||||
# Expected parsed MCP server configurations after running initialize()
|
||||
# These represent the expected state of parsedConfigs after all fetch functions complete
|
||||
|
||||
oauth_server:
|
||||
_processed: true
|
||||
type: "streamable-http"
|
||||
url: "https://api.github.com/mcp"
|
||||
headers:
|
||||
Authorization: "Bearer {{GITHUB_TOKEN}}"
|
||||
serverInstructions: true
|
||||
requiresOAuth: true
|
||||
oauthMetadata:
|
||||
authorization_url: "https://github.com/login/oauth/authorize"
|
||||
token_url: "https://github.com/login/oauth/access_token"
|
||||
|
||||
oauth_predefined:
|
||||
_processed: true
|
||||
type: "sse"
|
||||
url: "https://api.example.com/sse"
|
||||
requiresOAuth: true
|
||||
oauthMetadata:
|
||||
authorization_url: "https://example.com/oauth/authorize"
|
||||
token_url: "https://example.com/oauth/token"
|
||||
|
||||
stdio_server:
|
||||
_processed: true
|
||||
command: "node"
|
||||
args: ["server.js"]
|
||||
env:
|
||||
API_KEY: "${TEST_API_KEY}"
|
||||
startup: true
|
||||
serverInstructions: "Follow these instructions for stdio server"
|
||||
requiresOAuth: false
|
||||
capabilities: '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{}}'
|
||||
tools: "file_read, file_write"
|
||||
|
||||
websocket_server:
|
||||
_processed: true
|
||||
type: "websocket"
|
||||
url: "ws://localhost:3001/mcp"
|
||||
startup: true
|
||||
requiresOAuth: false
|
||||
oauthMetadata: null
|
||||
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
|
||||
tools: ""
|
||||
|
||||
disabled_server:
|
||||
_processed: true
|
||||
requiresOAuth: false
|
||||
type: "streamable-http"
|
||||
url: "https://api.disabled.com/mcp"
|
||||
startup: false
|
||||
|
||||
non_oauth_server:
|
||||
_processed: true
|
||||
type: "streamable-http"
|
||||
url: "https://api.public.com/mcp"
|
||||
requiresOAuth: false
|
||||
serverInstructions: "Public API instructions"
|
||||
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
|
||||
tools: ""
|
||||
|
||||
oauth_startup_enabled:
|
||||
_processed: true
|
||||
type: "sse"
|
||||
url: "https://api.oauth-startup.com/sse"
|
||||
requiresOAuth: true
|
||||
@@ -0,0 +1,53 @@
|
||||
# Raw MCP server configurations used as input to MCPServersRegistry constructor
|
||||
# These configs test different code paths in the initialization process
|
||||
|
||||
# Test OAuth detection with URL - should trigger fetchOAuthRequirement
|
||||
oauth_server:
|
||||
type: "streamable-http"
|
||||
url: "https://api.github.com/mcp"
|
||||
headers:
|
||||
Authorization: "Bearer {{GITHUB_TOKEN}}"
|
||||
serverInstructions: true
|
||||
|
||||
# Test OAuth already specified - should skip OAuth detection
|
||||
oauth_predefined:
|
||||
type: "sse"
|
||||
url: "https://api.example.com/sse"
|
||||
requiresOAuth: true
|
||||
oauthMetadata:
|
||||
authorization_url: "https://example.com/oauth/authorize"
|
||||
token_url: "https://example.com/oauth/token"
|
||||
|
||||
# Test stdio server without URL - should set requiresOAuth to false
|
||||
stdio_server:
|
||||
command: "node"
|
||||
args: ["server.js"]
|
||||
env:
|
||||
API_KEY: "${TEST_API_KEY}"
|
||||
startup: true
|
||||
serverInstructions: "Follow these instructions for stdio server"
|
||||
|
||||
# Test websocket server with capabilities but no tools
|
||||
websocket_server:
|
||||
type: "websocket"
|
||||
url: "ws://localhost:3001/mcp"
|
||||
startup: true
|
||||
|
||||
# Test server with startup disabled - should not be included in appServerConfigs
|
||||
disabled_server:
|
||||
type: "streamable-http"
|
||||
url: "https://api.disabled.com/mcp"
|
||||
startup: false
|
||||
|
||||
# Test non-OAuth server - should be included in appServerConfigs
|
||||
non_oauth_server:
|
||||
type: "streamable-http"
|
||||
url: "https://api.public.com/mcp"
|
||||
requiresOAuth: false
|
||||
serverInstructions: true
|
||||
|
||||
# Test server with OAuth but startup enabled - should not be in appServerConfigs
|
||||
oauth_startup_enabled:
|
||||
type: "sse"
|
||||
url: "https://api.oauth-startup.com/sse"
|
||||
requiresOAuth: true
|
||||
@@ -18,7 +18,6 @@ import type {
|
||||
Response as UndiciResponse,
|
||||
} from 'undici';
|
||||
import type { MCPOAuthTokens } from './oauth/types';
|
||||
import { withTimeout } from '~/utils/promise';
|
||||
import type * as t from './types';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
@@ -458,11 +457,15 @@ export class MCPConnection extends EventEmitter {
|
||||
this.setupTransportDebugHandlers();
|
||||
|
||||
const connectTimeout = this.options.initTimeout ?? 120000;
|
||||
await withTimeout(
|
||||
await Promise.race([
|
||||
this.client.connect(this.transport),
|
||||
connectTimeout,
|
||||
`Connection timeout after ${connectTimeout}ms`,
|
||||
);
|
||||
new Promise((_resolve, reject) =>
|
||||
setTimeout(
|
||||
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
|
||||
connectTimeout,
|
||||
),
|
||||
),
|
||||
]);
|
||||
|
||||
this.connectionState = 'connected';
|
||||
this.emit('connectionChange', 'connected');
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { TokenMethods } from '@librechat/data-schemas';
|
||||
import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..';
|
||||
import { MCPManager } from '../MCPManager';
|
||||
import { mcpServersRegistry } from '../../mcp/registry/MCPServersRegistry';
|
||||
import { OAuthReconnectionManager } from './OAuthReconnectionManager';
|
||||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
|
||||
@@ -15,12 +14,6 @@ jest.mock('@librechat/data-schemas', () => ({
|
||||
}));
|
||||
|
||||
jest.mock('../MCPManager');
|
||||
jest.mock('../../mcp/registry/MCPServersRegistry', () => ({
|
||||
mcpServersRegistry: {
|
||||
getServerConfig: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('OAuthReconnectionManager', () => {
|
||||
let flowManager: jest.Mocked<FlowStateManager<null>>;
|
||||
@@ -58,10 +51,10 @@ describe('OAuthReconnectionManager', () => {
|
||||
getUserConnection: jest.fn(),
|
||||
getUserConnections: jest.fn(),
|
||||
disconnectUserConnection: jest.fn(),
|
||||
getRawConfig: jest.fn(),
|
||||
} as unknown as jest.Mocked<MCPManager>;
|
||||
|
||||
(MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager);
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -159,7 +152,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
it('should reconnect eligible servers', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1', 'server2', 'server3']);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
// server1: has failed reconnection
|
||||
reconnectionTracker.setFailed(userId, 'server1');
|
||||
@@ -193,9 +186,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({
|
||||
initTimeout: 5000,
|
||||
} as unknown as MCPOptions);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
@@ -224,7 +215,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
it('should handle failed reconnection attempts', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
// server1: has valid token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
@@ -235,9 +226,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
|
||||
// Mock failed connection
|
||||
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
@@ -253,7 +242,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
it('should not reconnect servers with expired tokens', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
// server1: has expired token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
@@ -272,7 +261,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
it('should handle connection that returns but is not connected', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
@@ -288,9 +277,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
@@ -372,7 +359,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
it('should not attempt to reconnect servers that have timed out during reconnection', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1', 'server2']);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
@@ -427,7 +414,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'server1';
|
||||
const oauthServers = new Set([serverName]);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
@@ -441,9 +428,7 @@ describe('OAuthReconnectionManager', () => {
|
||||
|
||||
// First reconnect attempt - will fail
|
||||
mockMCPManager.getUserConnection.mockRejectedValueOnce(new Error('Connection failed'));
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
await jest.runAllTimersAsync();
|
||||
|
||||
@@ -5,7 +5,6 @@ import type { MCPOAuthTokens } from './types';
|
||||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
import { FlowStateManager } from '~/flow/manager';
|
||||
import { MCPManager } from '~/mcp/MCPManager';
|
||||
import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
|
||||
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
|
||||
|
||||
@@ -73,7 +72,7 @@ export class OAuthReconnectionManager {
|
||||
|
||||
// 1. derive the servers to reconnect
|
||||
const serversToReconnect = [];
|
||||
for (const serverName of await mcpServersRegistry.getOAuthServers()) {
|
||||
for (const serverName of this.mcpManager.getOAuthServers()) {
|
||||
const canReconnect = await this.canReconnect(userId, serverName);
|
||||
if (canReconnect) {
|
||||
serversToReconnect.push(serverName);
|
||||
@@ -105,7 +104,7 @@ export class OAuthReconnectionManager {
|
||||
|
||||
logger.info(`${logPrefix} Attempting reconnection`);
|
||||
|
||||
const config = await mcpServersRegistry.getServerConfig(serverName, userId);
|
||||
const config = this.mcpManager.getRawConfig(serverName);
|
||||
|
||||
const cleanupOnFailedReconnect = () => {
|
||||
this.reconnectionsTracker.setFailed(userId, serverName);
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import type { JsonSchemaType } from '@librechat/data-schemas';
|
||||
import type { MCPConnection } from '~/mcp/connection';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { isEnabled } from '~/utils';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
|
||||
/**
|
||||
* Inspects MCP servers to discover their metadata, capabilities, and tools.
|
||||
* Connects to servers and populates configuration with OAuth requirements,
|
||||
* server instructions, capabilities, and available tools.
|
||||
*/
|
||||
export class MCPServerInspector {
|
||||
private constructor(
|
||||
private readonly serverName: string,
|
||||
private readonly config: t.ParsedServerConfig,
|
||||
private connection: MCPConnection | undefined,
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Inspects a server and returns an enriched configuration with metadata.
|
||||
* Detects OAuth requirements and fetches server capabilities.
|
||||
* @param serverName - The name of the server (used for tool function naming)
|
||||
* @param rawConfig - The raw server configuration
|
||||
* @param connection - The MCP connection
|
||||
* @returns A fully processed and enriched configuration with server metadata
|
||||
*/
|
||||
public static async inspect(
|
||||
serverName: string,
|
||||
rawConfig: t.MCPOptions,
|
||||
connection?: MCPConnection,
|
||||
): Promise<t.ParsedServerConfig> {
|
||||
const start = Date.now();
|
||||
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
|
||||
await inspector.inspectServer();
|
||||
inspector.config.initDuration = Date.now() - start;
|
||||
return inspector.config;
|
||||
}
|
||||
|
||||
private async inspectServer(): Promise<void> {
|
||||
await this.detectOAuth();
|
||||
|
||||
if (this.config.startup !== false && !this.config.requiresOAuth) {
|
||||
let tempConnection = false;
|
||||
if (!this.connection) {
|
||||
tempConnection = true;
|
||||
this.connection = await MCPConnectionFactory.create({
|
||||
serverName: this.serverName,
|
||||
serverConfig: this.config,
|
||||
});
|
||||
}
|
||||
|
||||
await Promise.allSettled([
|
||||
this.fetchServerInstructions(),
|
||||
this.fetchServerCapabilities(),
|
||||
this.fetchToolFunctions(),
|
||||
]);
|
||||
|
||||
if (tempConnection) await this.connection.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
private async detectOAuth(): Promise<void> {
|
||||
if (this.config.requiresOAuth != null) return;
|
||||
if (this.config.url == null || this.config.startup === false) {
|
||||
this.config.requiresOAuth = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const result = await detectOAuthRequirement(this.config.url);
|
||||
this.config.requiresOAuth = result.requiresOAuth;
|
||||
this.config.oauthMetadata = result.metadata;
|
||||
}
|
||||
|
||||
private async fetchServerInstructions(): Promise<void> {
|
||||
if (isEnabled(this.config.serverInstructions)) {
|
||||
this.config.serverInstructions = this.connection!.client.getInstructions();
|
||||
}
|
||||
}
|
||||
|
||||
private async fetchServerCapabilities(): Promise<void> {
|
||||
const capabilities = this.connection!.client.getServerCapabilities();
|
||||
this.config.capabilities = JSON.stringify(capabilities);
|
||||
const tools = await this.connection!.client.listTools();
|
||||
this.config.tools = tools.tools.map((tool) => tool.name).join(', ');
|
||||
}
|
||||
|
||||
private async fetchToolFunctions(): Promise<void> {
|
||||
this.config.toolFunctions = await MCPServerInspector.getToolFunctions(
|
||||
this.serverName,
|
||||
this.connection!,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts server tools to LibreChat-compatible tool functions format.
|
||||
* @param serverName - The name of the server
|
||||
* @param connection - The MCP connection
|
||||
* @returns Tool functions formatted for LibreChat
|
||||
*/
|
||||
public static async getToolFunctions(
|
||||
serverName: string,
|
||||
connection: MCPConnection,
|
||||
): Promise<t.LCAvailableTools> {
|
||||
const { tools }: t.MCPToolListResponse = await connection.client.listTools();
|
||||
|
||||
const toolFunctions: t.LCAvailableTools = {};
|
||||
tools.forEach((tool) => {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
toolFunctions[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema as JsonSchemaType,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
return toolFunctions;
|
||||
}
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
import { registryStatusCache as statusCache } from './cache/RegistryStatusCache';
|
||||
import { isLeader } from '~/cluster';
|
||||
import { withTimeout } from '~/utils';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { MCPServerInspector } from './MCPServerInspector';
|
||||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { mcpServersRegistry as registry } from './MCPServersRegistry';
|
||||
|
||||
const MCP_INIT_TIMEOUT_MS =
|
||||
process.env.MCP_INIT_TIMEOUT_MS != null ? parseInt(process.env.MCP_INIT_TIMEOUT_MS) : 30_000;
|
||||
|
||||
/**
|
||||
* Handles initialization of MCP servers at application startup with distributed coordination.
|
||||
* In cluster environments, ensures only the leader node performs initialization while followers wait.
|
||||
* Connects to each configured MCP server, inspects capabilities and tools, then caches the results.
|
||||
* Categorizes servers as either shared app servers (auto-started) or shared user servers (OAuth/on-demand).
|
||||
* Uses a timeout mechanism to prevent hanging on unresponsive servers during initialization.
|
||||
*/
|
||||
export class MCPServersInitializer {
|
||||
/**
|
||||
* Initializes MCP servers with distributed leader-follower coordination.
|
||||
*
|
||||
* Design rationale:
|
||||
* - Handles leader crash scenarios: If the leader crashes during initialization, all followers
|
||||
* will independently attempt initialization after a 3-second delay. The first to become leader
|
||||
* will complete the initialization.
|
||||
* - Only the leader performs the actual initialization work (reset caches, inspect servers).
|
||||
* When complete, the leader signals completion via `statusCache`, allowing followers to proceed.
|
||||
* - Followers wait and poll `statusCache` until the leader finishes, ensuring only one node
|
||||
* performs the expensive initialization operations.
|
||||
*/
|
||||
public static async initialize(rawConfigs: t.MCPServers): Promise<void> {
|
||||
if (await statusCache.isInitialized()) return;
|
||||
|
||||
if (await isLeader()) {
|
||||
// Leader performs initialization
|
||||
await statusCache.reset();
|
||||
await registry.reset();
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
await Promise.allSettled(
|
||||
serverNames.map((serverName) =>
|
||||
withTimeout(
|
||||
MCPServersInitializer.initializeServer(serverName, rawConfigs[serverName]),
|
||||
MCP_INIT_TIMEOUT_MS,
|
||||
`${MCPServersInitializer.prefix(serverName)} Server initialization timed out`,
|
||||
logger.error,
|
||||
),
|
||||
),
|
||||
);
|
||||
await statusCache.setInitialized(true);
|
||||
} else {
|
||||
// Followers try again after a delay if not initialized
|
||||
await new Promise((resolve) => setTimeout(resolve, 3000));
|
||||
await this.initialize(rawConfigs);
|
||||
}
|
||||
}
|
||||
|
||||
/** Initializes a single server with all its metadata and adds it to appropriate collections */
|
||||
private static async initializeServer(
|
||||
serverName: string,
|
||||
rawConfig: t.MCPOptions,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const config = await MCPServerInspector.inspect(serverName, rawConfig);
|
||||
|
||||
if (config.startup === false || config.requiresOAuth) {
|
||||
await registry.sharedUserServers.add(serverName, config);
|
||||
} else {
|
||||
await registry.sharedAppServers.add(serverName, config);
|
||||
}
|
||||
MCPServersInitializer.logParsedConfig(serverName, config);
|
||||
} catch (error) {
|
||||
logger.error(`${MCPServersInitializer.prefix(serverName)} Failed to initialize:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
// Logs server configuration summary after initialization
|
||||
private static logParsedConfig(serverName: string, config: ParsedServerConfig): void {
|
||||
const prefix = MCPServersInitializer.prefix(serverName);
|
||||
logger.info(`${prefix} -------------------------------------------------┐`);
|
||||
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
|
||||
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
|
||||
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
|
||||
logger.info(`${prefix} Tools: ${config.tools}`);
|
||||
logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`);
|
||||
logger.info(`${prefix} Initialized in: ${config.initDuration ?? 'N/A'}ms`);
|
||||
logger.info(`${prefix} -------------------------------------------------┘`);
|
||||
}
|
||||
|
||||
// Returns formatted log prefix for server messages
|
||||
private static prefix(serverName: string): string {
|
||||
return `[MCP][${serverName}]`;
|
||||
}
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
import type * as t from '~/mcp/types';
|
||||
import {
|
||||
ServerConfigsCacheFactory,
|
||||
type ServerConfigsCache,
|
||||
} from './cache/ServerConfigsCacheFactory';
|
||||
|
||||
/**
|
||||
* Central registry for managing MCP server configurations across different scopes and users.
|
||||
* Maintains three categories of server configurations:
|
||||
* - Shared App Servers: Auto-started servers available to all users (initialized at startup)
|
||||
* - Shared User Servers: User-scope servers that require OAuth or on-demand startup
|
||||
* - Private User Servers: Per-user configurations dynamically added during runtime
|
||||
*
|
||||
* Provides a unified interface for retrieving server configs with proper fallback hierarchy:
|
||||
* checks shared app servers first, then shared user servers, then private user servers.
|
||||
* Handles server lifecycle operations including adding, removing, and querying configurations.
|
||||
*/
|
||||
class MCPServersRegistry {
|
||||
public readonly sharedAppServers = ServerConfigsCacheFactory.create('App', true);
|
||||
public readonly sharedUserServers = ServerConfigsCacheFactory.create('User', true);
|
||||
private readonly privateUserServers: Map<string | undefined, ServerConfigsCache> = new Map();
|
||||
|
||||
public async addPrivateUserServer(
|
||||
userId: string,
|
||||
serverName: string,
|
||||
config: t.ParsedServerConfig,
|
||||
): Promise<void> {
|
||||
if (!this.privateUserServers.has(userId)) {
|
||||
const cache = ServerConfigsCacheFactory.create(`User(${userId})`, false);
|
||||
this.privateUserServers.set(userId, cache);
|
||||
}
|
||||
await this.privateUserServers.get(userId)!.add(serverName, config);
|
||||
}
|
||||
|
||||
public async updatePrivateUserServer(
|
||||
userId: string,
|
||||
serverName: string,
|
||||
config: t.ParsedServerConfig,
|
||||
): Promise<void> {
|
||||
const userCache = this.privateUserServers.get(userId);
|
||||
if (!userCache) throw new Error(`No private servers found for user "${userId}".`);
|
||||
await userCache.update(serverName, config);
|
||||
}
|
||||
|
||||
public async removePrivateUserServer(userId: string, serverName: string): Promise<void> {
|
||||
await this.privateUserServers.get(userId)?.remove(serverName);
|
||||
}
|
||||
|
||||
public async getServerConfig(
|
||||
serverName: string,
|
||||
userId?: string,
|
||||
): Promise<t.ParsedServerConfig | undefined> {
|
||||
const sharedAppServer = await this.sharedAppServers.get(serverName);
|
||||
if (sharedAppServer) return sharedAppServer;
|
||||
|
||||
const sharedUserServer = await this.sharedUserServers.get(serverName);
|
||||
if (sharedUserServer) return sharedUserServer;
|
||||
|
||||
const privateUserServer = await this.privateUserServers.get(userId)?.get(serverName);
|
||||
if (privateUserServer) return privateUserServer;
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
public async getAllServerConfigs(userId?: string): Promise<Record<string, t.ParsedServerConfig>> {
|
||||
return {
|
||||
...(await this.sharedAppServers.getAll()),
|
||||
...(await this.sharedUserServers.getAll()),
|
||||
...((await this.privateUserServers.get(userId)?.getAll()) ?? {}),
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: This is currently used to determine if a server requires OAuth. However, this info can
|
||||
// can be determined through config.requiresOAuth. Refactor usages and remove this method.
|
||||
public async getOAuthServers(userId?: string): Promise<Set<string>> {
|
||||
const allServers = await this.getAllServerConfigs(userId);
|
||||
const oauthServers = Object.entries(allServers).filter(([, config]) => config.requiresOAuth);
|
||||
return new Set(oauthServers.map(([name]) => name));
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
await this.sharedAppServers.reset();
|
||||
await this.sharedUserServers.reset();
|
||||
for (const cache of this.privateUserServers.values()) {
|
||||
await cache.reset();
|
||||
}
|
||||
this.privateUserServers.clear();
|
||||
}
|
||||
}
|
||||
|
||||
export const mcpServersRegistry = new MCPServersRegistry();
|
||||
@@ -1,338 +0,0 @@
|
||||
import type { MCPConnection } from '~/mcp/connection';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { createMockConnection } from './mcpConnectionsMock.helper';
|
||||
|
||||
// Mock external dependencies
|
||||
jest.mock('../../oauth/detectOAuth');
|
||||
jest.mock('../../MCPConnectionFactory');
|
||||
|
||||
const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction<
|
||||
typeof detectOAuthRequirement
|
||||
>;
|
||||
|
||||
describe('MCPServerInspector', () => {
|
||||
let mockConnection: jest.Mocked<MCPConnection>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockConnection = createMockConnection('test_server');
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('inspect()', () => {
|
||||
it('should process env and fetch all metadata for non-OAuth stdio server with serverInstructions=true', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: true,
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'instructions for test_server',
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: 'listFiles',
|
||||
toolFunctions: {
|
||||
listFiles_mcp_test_server: expect.objectContaining({
|
||||
type: 'function',
|
||||
function: expect.objectContaining({
|
||||
name: 'listFiles_mcp_test_server',
|
||||
}),
|
||||
}),
|
||||
},
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should detect OAuth and skip capabilities fetch for streamable-http server', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: true,
|
||||
method: 'protected-resource-metadata',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
requiresOAuth: true,
|
||||
oauthMetadata: undefined,
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should skip capabilities fetch when startup=false', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
startup: false,
|
||||
};
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
startup: false,
|
||||
requiresOAuth: false,
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should keep custom serverInstructions string and not fetch from server', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'Custom instructions here',
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'Custom instructions here',
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: 'listFiles',
|
||||
toolFunctions: expect.any(Object),
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle serverInstructions as string "true" and fetch from server', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'true', // String "true" from YAML
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'instructions for test_server',
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: 'listFiles',
|
||||
toolFunctions: expect.any(Object),
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle predefined requiresOAuth without detection', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://api.example.com/sse',
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'sse',
|
||||
url: 'https://api.example.com/sse',
|
||||
requiresOAuth: true,
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should fetch capabilities when server has no tools', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
// Mock server with no tools
|
||||
mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] });
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: '',
|
||||
toolFunctions: {},
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should create temporary connection when no connection is provided', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: true,
|
||||
};
|
||||
|
||||
const tempMockConnection = createMockConnection('test_server');
|
||||
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(tempMockConnection);
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig);
|
||||
|
||||
// Verify factory was called to create connection
|
||||
expect(MCPConnectionFactory.create).toHaveBeenCalledWith({
|
||||
serverName: 'test_server',
|
||||
serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }),
|
||||
});
|
||||
|
||||
// Verify temporary connection was disconnected
|
||||
expect(tempMockConnection.disconnect).toHaveBeenCalled();
|
||||
|
||||
// Verify result is correct
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'instructions for test_server',
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: 'listFiles',
|
||||
toolFunctions: expect.any(Object),
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should not create temporary connection when connection is provided', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: true,
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
// Verify factory was NOT called
|
||||
expect(MCPConnectionFactory.create).not.toHaveBeenCalled();
|
||||
|
||||
// Verify provided connection was NOT disconnected
|
||||
expect(mockConnection.disconnect).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getToolFunctions()', () => {
|
||||
it('should convert MCP tools to LibreChat tool functions format', async () => {
|
||||
mockConnection.client.listTools = jest.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'file_read',
|
||||
description: 'Read a file',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: { path: { type: 'string' } },
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'file_write',
|
||||
description: 'Write a file',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
path: { type: 'string' },
|
||||
content: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
file_read_mcp_my_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_my_server',
|
||||
description: 'Read a file',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: { path: { type: 'string' } },
|
||||
},
|
||||
},
|
||||
},
|
||||
file_write_mcp_my_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_write_mcp_my_server',
|
||||
description: 'Write a file',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
path: { type: 'string' },
|
||||
content: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle empty tools list', async () => {
|
||||
mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] });
|
||||
|
||||
const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection);
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,301 +0,0 @@
|
||||
import { expect } from '@playwright/test';
|
||||
import type * as t from '~/mcp/types';
|
||||
import type { MCPConnection } from '~/mcp/connection';
|
||||
|
||||
// Mock isLeader to always return true to avoid lock contention during parallel operations
|
||||
jest.mock('~/cluster', () => ({
|
||||
...jest.requireActual('~/cluster'),
|
||||
isLeader: jest.fn().mockResolvedValue(true),
|
||||
}));
|
||||
|
||||
describe('MCPServersInitializer Redis Integration Tests', () => {
|
||||
let MCPServersInitializer: typeof import('../MCPServersInitializer').MCPServersInitializer;
|
||||
let registry: typeof import('../MCPServersRegistry').mcpServersRegistry;
|
||||
let registryStatusCache: typeof import('../cache/RegistryStatusCache').registryStatusCache;
|
||||
let MCPServerInspector: typeof import('../MCPServerInspector').MCPServerInspector;
|
||||
let MCPConnectionFactory: typeof import('~/mcp/MCPConnectionFactory').MCPConnectionFactory;
|
||||
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
|
||||
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
|
||||
let leaderInstance: InstanceType<typeof import('~/cluster/LeaderElection').LeaderElection>;
|
||||
|
||||
const testConfigs: t.MCPServers = {
|
||||
disabled_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['disabled.js'],
|
||||
startup: false,
|
||||
},
|
||||
oauth_server: {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
},
|
||||
file_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
},
|
||||
search_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['instructions.js'],
|
||||
},
|
||||
};
|
||||
|
||||
const testParsedConfigs: Record<string, t.ParsedServerConfig> = {
|
||||
disabled_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['disabled.js'],
|
||||
startup: false,
|
||||
requiresOAuth: false,
|
||||
},
|
||||
oauth_server: {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
requiresOAuth: true,
|
||||
},
|
||||
file_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for file_tools_server',
|
||||
tools: 'file_read, file_write',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
toolFunctions: {
|
||||
file_read_mcp_file_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_file_tools_server',
|
||||
description: 'Read a file',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
search_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['instructions.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for search_tools_server',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
tools: 'search',
|
||||
toolFunctions: {
|
||||
search_mcp_search_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'search_mcp_search_tools_server',
|
||||
description: 'Search tool',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for Redis (only if not already set)
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX =
|
||||
process.env.REDIS_KEY_PREFIX ?? 'MCPServersInitializer-IntegrationTest';
|
||||
|
||||
// Import modules after setting env vars
|
||||
const initializerModule = await import('../MCPServersInitializer');
|
||||
const registryModule = await import('../MCPServersRegistry');
|
||||
const statusCacheModule = await import('../cache/RegistryStatusCache');
|
||||
const inspectorModule = await import('../MCPServerInspector');
|
||||
const connectionFactoryModule = await import('~/mcp/MCPConnectionFactory');
|
||||
const redisClients = await import('~/cache/redisClients');
|
||||
const leaderElectionModule = await import('~/cluster/LeaderElection');
|
||||
|
||||
MCPServersInitializer = initializerModule.MCPServersInitializer;
|
||||
registry = registryModule.mcpServersRegistry;
|
||||
registryStatusCache = statusCacheModule.registryStatusCache;
|
||||
MCPServerInspector = inspectorModule.MCPServerInspector;
|
||||
MCPConnectionFactory = connectionFactoryModule.MCPConnectionFactory;
|
||||
keyvRedisClient = redisClients.keyvRedisClient;
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
|
||||
// Ensure Redis is connected
|
||||
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
|
||||
|
||||
// Wait for Redis to be ready
|
||||
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
|
||||
|
||||
// Become leader so we can perform write operations
|
||||
leaderInstance = new LeaderElection();
|
||||
const isLeader = await leaderInstance.isLeader();
|
||||
expect(isLeader).toBe(true);
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Ensure we're still the leader
|
||||
const isLeader = await leaderInstance.isLeader();
|
||||
if (!isLeader) {
|
||||
throw new Error('Lost leader status before test');
|
||||
}
|
||||
|
||||
// Mock MCPServerInspector.inspect to return parsed config
|
||||
jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => {
|
||||
return {
|
||||
...testParsedConfigs[serverName],
|
||||
_processedByInspector: true,
|
||||
} as unknown as t.ParsedServerConfig;
|
||||
});
|
||||
|
||||
// Mock MCPConnection
|
||||
const mockConnection = {
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
// Mock MCPConnectionFactory
|
||||
jest.spyOn(MCPConnectionFactory, 'create').mockResolvedValue(mockConnection);
|
||||
|
||||
// Reset caches before each test
|
||||
await registryStatusCache.reset();
|
||||
await registry.reset();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up: clear all test keys from Redis
|
||||
if (keyvRedisClient) {
|
||||
const pattern = '*MCPServersInitializer-IntegrationTest*';
|
||||
if ('scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
await keyvRedisClient.del(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Resign as leader
|
||||
if (leaderInstance) await leaderInstance.resign();
|
||||
|
||||
// Close Redis connection
|
||||
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
|
||||
});
|
||||
|
||||
describe('initialize()', () => {
|
||||
it('should reset registry and status cache before initialization', async () => {
|
||||
// Pre-populate registry with some old servers
|
||||
await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server);
|
||||
await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server);
|
||||
|
||||
// Initialize with new configs (this should reset first)
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify old servers are gone
|
||||
expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined();
|
||||
expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined();
|
||||
|
||||
// Verify new servers are present
|
||||
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
|
||||
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
|
||||
expect(await registryStatusCache.isInitialized()).toBe(true);
|
||||
});
|
||||
|
||||
it('should skip initialization if already initialized', async () => {
|
||||
// First initialization
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Clear mock calls
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Second initialization should skip due to static flag
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify inspect was not called again
|
||||
expect(MCPServerInspector.inspect).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should add disabled servers to sharedUserServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const disabledServer = await registry.sharedUserServers.get('disabled_server');
|
||||
expect(disabledServer).toBeDefined();
|
||||
expect(disabledServer).toMatchObject({
|
||||
...testParsedConfigs.disabled_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should add OAuth servers to sharedUserServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const oauthServer = await registry.sharedUserServers.get('oauth_server');
|
||||
expect(oauthServer).toBeDefined();
|
||||
expect(oauthServer).toMatchObject({
|
||||
...testParsedConfigs.oauth_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should add enabled non-OAuth servers to sharedAppServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
|
||||
expect(fileToolsServer).toBeDefined();
|
||||
expect(fileToolsServer).toMatchObject({
|
||||
...testParsedConfigs.file_tools_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
|
||||
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
|
||||
expect(searchToolsServer).toBeDefined();
|
||||
expect(searchToolsServer).toMatchObject({
|
||||
...testParsedConfigs.search_tools_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should successfully initialize all servers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify all servers were added to appropriate registries
|
||||
expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined();
|
||||
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
|
||||
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
|
||||
expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle inspection failures gracefully', async () => {
|
||||
// Mock inspection failure for one server
|
||||
jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => {
|
||||
if (serverName === 'file_tools_server') {
|
||||
throw new Error('Inspection failed');
|
||||
}
|
||||
return {
|
||||
...testParsedConfigs[serverName],
|
||||
_processedByInspector: true,
|
||||
} as unknown as t.ParsedServerConfig;
|
||||
});
|
||||
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify other servers were still processed
|
||||
const disabledServer = await registry.sharedUserServers.get('disabled_server');
|
||||
expect(disabledServer).toBeDefined();
|
||||
|
||||
const oauthServer = await registry.sharedUserServers.get('oauth_server');
|
||||
expect(oauthServer).toBeDefined();
|
||||
|
||||
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
|
||||
expect(searchToolsServer).toBeDefined();
|
||||
|
||||
// Verify file_tools_server was not added (due to inspection failure)
|
||||
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
|
||||
expect(fileToolsServer).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should set initialized status after completion', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
expect(await registryStatusCache.isInitialized()).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,292 +0,0 @@
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import * as t from '~/mcp/types';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { registryStatusCache } from '~/mcp/registry/cache/RegistryStatusCache';
|
||||
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
|
||||
import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry';
|
||||
|
||||
// Mock external dependencies
|
||||
jest.mock('../../MCPConnectionFactory');
|
||||
jest.mock('../../connection');
|
||||
jest.mock('../../registry/MCPServerInspector');
|
||||
jest.mock('~/cluster', () => ({
|
||||
isLeader: jest.fn().mockResolvedValue(true),
|
||||
}));
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||
const mockInspect = MCPServerInspector.inspect as jest.MockedFunction<
|
||||
typeof MCPServerInspector.inspect
|
||||
>;
|
||||
|
||||
describe('MCPServersInitializer', () => {
|
||||
let mockConnection: jest.Mocked<MCPConnection>;
|
||||
|
||||
const testConfigs: t.MCPServers = {
|
||||
disabled_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['disabled.js'],
|
||||
startup: false,
|
||||
},
|
||||
oauth_server: {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
},
|
||||
file_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
},
|
||||
search_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['instructions.js'],
|
||||
},
|
||||
};
|
||||
|
||||
const testParsedConfigs: Record<string, t.ParsedServerConfig> = {
|
||||
disabled_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['disabled.js'],
|
||||
startup: false,
|
||||
requiresOAuth: false,
|
||||
},
|
||||
oauth_server: {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
requiresOAuth: true,
|
||||
},
|
||||
file_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for file_tools_server',
|
||||
tools: 'file_read, file_write',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
toolFunctions: {
|
||||
file_read_mcp_file_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_file_tools_server',
|
||||
description: 'Read a file',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
search_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['instructions.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for search_tools_server',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
tools: 'search',
|
||||
toolFunctions: {
|
||||
search_mcp_search_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'search_mcp_search_tools_server',
|
||||
description: 'Search tool',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
// Setup MCPConnection mock
|
||||
mockConnection = {
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
// Setup MCPConnectionFactory mock
|
||||
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
|
||||
|
||||
// Mock MCPServerInspector.inspect to return parsed config
|
||||
mockInspect.mockImplementation(async (serverName: string) => {
|
||||
return {
|
||||
...testParsedConfigs[serverName],
|
||||
_processedByInspector: true,
|
||||
} as unknown as t.ParsedServerConfig;
|
||||
});
|
||||
|
||||
// Reset caches before each test
|
||||
await registryStatusCache.reset();
|
||||
await registry.sharedAppServers.reset();
|
||||
await registry.sharedUserServers.reset();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('initialize()', () => {
|
||||
it('should reset registry and status cache before initialization', async () => {
|
||||
// Pre-populate registry with some old servers
|
||||
await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server);
|
||||
await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server);
|
||||
|
||||
// Initialize with new configs (this should reset first)
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify old servers are gone
|
||||
expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined();
|
||||
expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined();
|
||||
|
||||
// Verify new servers are present
|
||||
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
|
||||
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
|
||||
expect(await registryStatusCache.isInitialized()).toBe(true);
|
||||
});
|
||||
|
||||
it('should skip initialization if already initialized (Redis flag)', async () => {
|
||||
// First initialization
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Second initialization should skip due to Redis cache flag
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
expect(mockInspect).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should process all server configs through inspector', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify all configs were processed by inspector (without connection parameter)
|
||||
expect(mockInspect).toHaveBeenCalledTimes(4);
|
||||
expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server);
|
||||
expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server);
|
||||
expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server);
|
||||
expect(mockInspect).toHaveBeenCalledWith(
|
||||
'search_tools_server',
|
||||
testConfigs.search_tools_server,
|
||||
);
|
||||
});
|
||||
|
||||
it('should add disabled servers to sharedUserServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const disabledServer = await registry.sharedUserServers.get('disabled_server');
|
||||
expect(disabledServer).toBeDefined();
|
||||
expect(disabledServer).toMatchObject({
|
||||
...testParsedConfigs.disabled_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should add OAuth servers to sharedUserServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const oauthServer = await registry.sharedUserServers.get('oauth_server');
|
||||
expect(oauthServer).toBeDefined();
|
||||
expect(oauthServer).toMatchObject({
|
||||
...testParsedConfigs.oauth_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should add enabled non-OAuth servers to sharedAppServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
|
||||
expect(fileToolsServer).toBeDefined();
|
||||
expect(fileToolsServer).toMatchObject({
|
||||
...testParsedConfigs.file_tools_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
|
||||
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
|
||||
expect(searchToolsServer).toBeDefined();
|
||||
expect(searchToolsServer).toMatchObject({
|
||||
...testParsedConfigs.search_tools_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should successfully initialize all servers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify all servers were added to appropriate registries
|
||||
expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined();
|
||||
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
|
||||
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
|
||||
expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle inspection failures gracefully', async () => {
|
||||
// Mock inspection failure for one server
|
||||
mockInspect.mockImplementation(async (serverName: string) => {
|
||||
if (serverName === 'file_tools_server') {
|
||||
throw new Error('Inspection failed');
|
||||
}
|
||||
return {
|
||||
...testParsedConfigs[serverName],
|
||||
_processedByInspector: true,
|
||||
} as unknown as t.ParsedServerConfig;
|
||||
});
|
||||
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify other servers were still processed
|
||||
const disabledServer = await registry.sharedUserServers.get('disabled_server');
|
||||
expect(disabledServer).toBeDefined();
|
||||
|
||||
const oauthServer = await registry.sharedUserServers.get('oauth_server');
|
||||
expect(oauthServer).toBeDefined();
|
||||
|
||||
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
|
||||
expect(searchToolsServer).toBeDefined();
|
||||
|
||||
// Verify file_tools_server was not added (due to inspection failure)
|
||||
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
|
||||
expect(fileToolsServer).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should log server configuration after initialization', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify logging occurred for each server
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][disabled_server]'),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('[MCP][oauth_server]'));
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][file_tools_server]'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use Promise.allSettled for parallel server initialization', async () => {
|
||||
const allSettledSpy = jest.spyOn(Promise, 'allSettled');
|
||||
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)]));
|
||||
expect(allSettledSpy).toHaveBeenCalledTimes(1);
|
||||
|
||||
allSettledSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should set initialized status after completion', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
expect(await registryStatusCache.isInitialized()).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,227 +0,0 @@
|
||||
import { expect } from '@playwright/test';
|
||||
import type * as t from '~/mcp/types';
|
||||
|
||||
/**
|
||||
* Integration tests for MCPServersRegistry using Redis-backed cache.
|
||||
* For unit tests using in-memory cache, see MCPServersRegistry.test.ts
|
||||
*/
|
||||
describe('MCPServersRegistry Redis Integration Tests', () => {
|
||||
let registry: typeof import('../MCPServersRegistry').mcpServersRegistry;
|
||||
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
|
||||
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
|
||||
let leaderInstance: InstanceType<typeof import('~/cluster/LeaderElection').LeaderElection>;
|
||||
|
||||
const testParsedConfig: t.ParsedServerConfig = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for file_tools_server',
|
||||
tools: 'file_read, file_write',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
toolFunctions: {
|
||||
file_read_mcp_file_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_file_tools_server',
|
||||
description: 'Read a file',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for Redis (only if not already set)
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX =
|
||||
process.env.REDIS_KEY_PREFIX ?? 'MCPServersRegistry-IntegrationTest';
|
||||
|
||||
// Import modules after setting env vars
|
||||
const registryModule = await import('../MCPServersRegistry');
|
||||
const redisClients = await import('~/cache/redisClients');
|
||||
const leaderElectionModule = await import('~/cluster/LeaderElection');
|
||||
|
||||
registry = registryModule.mcpServersRegistry;
|
||||
keyvRedisClient = redisClients.keyvRedisClient;
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
|
||||
// Ensure Redis is connected
|
||||
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
|
||||
|
||||
// Wait for Redis to be ready
|
||||
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
|
||||
|
||||
// Become leader so we can perform write operations
|
||||
leaderInstance = new LeaderElection();
|
||||
const isLeader = await leaderInstance.isLeader();
|
||||
expect(isLeader).toBe(true);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up: reset registry to clear all test data
|
||||
await registry.reset();
|
||||
|
||||
// Also clean up any remaining test keys from Redis
|
||||
if (keyvRedisClient) {
|
||||
const pattern = '*MCPServersRegistry-IntegrationTest*';
|
||||
if ('scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
await keyvRedisClient.del(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Resign as leader
|
||||
if (leaderInstance) await leaderInstance.resign();
|
||||
|
||||
// Close Redis connection
|
||||
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
|
||||
});
|
||||
|
||||
describe('private user servers', () => {
|
||||
it('should add and remove private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
// Add private user server
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
|
||||
// Verify server was added
|
||||
const retrievedConfig = await registry.getServerConfig(serverName, userId);
|
||||
expect(retrievedConfig).toEqual(testParsedConfig);
|
||||
|
||||
// Remove private user server
|
||||
await registry.removePrivateUserServer(userId, serverName);
|
||||
|
||||
// Verify server was removed
|
||||
const configAfterRemoval = await registry.getServerConfig(serverName, userId);
|
||||
expect(configAfterRemoval).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when adding duplicate private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
await expect(
|
||||
registry.addPrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow(
|
||||
'Server "private_server" already exists in cache. Use update() to modify existing configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should update an existing private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
const updatedConfig: t.ParsedServerConfig = {
|
||||
type: 'stdio',
|
||||
command: 'python',
|
||||
args: ['updated.py'],
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
// Add private user server
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
|
||||
// Update the server config
|
||||
await registry.updatePrivateUserServer(userId, serverName, updatedConfig);
|
||||
|
||||
// Verify server was updated
|
||||
const retrievedConfig = await registry.getServerConfig(serverName, userId);
|
||||
expect(retrievedConfig).toEqual(updatedConfig);
|
||||
});
|
||||
|
||||
it('should throw error when updating non-existent server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
// Add a user cache first
|
||||
await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig);
|
||||
|
||||
await expect(
|
||||
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow(
|
||||
'Server "private_server" does not exist in cache. Use add() to create new configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error when updating server for non-existent user', async () => {
|
||||
const userId = 'nonexistent_user';
|
||||
const serverName = 'private_server';
|
||||
|
||||
await expect(
|
||||
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow('No private servers found for user "nonexistent_user".');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllServerConfigs', () => {
|
||||
it('should return correct servers based on userId', async () => {
|
||||
// Add servers to all three caches
|
||||
await registry.sharedAppServers.add('app_server', testParsedConfig);
|
||||
await registry.sharedUserServers.add('user_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig);
|
||||
|
||||
// Without userId: should return only shared app + shared user servers
|
||||
const configsNoUser = await registry.getAllServerConfigs();
|
||||
expect(Object.keys(configsNoUser)).toHaveLength(2);
|
||||
expect(configsNoUser).toHaveProperty('app_server');
|
||||
expect(configsNoUser).toHaveProperty('user_server');
|
||||
|
||||
// With userId 'abc': should return shared app + shared user + abc's private servers
|
||||
const configsAbc = await registry.getAllServerConfigs('abc');
|
||||
expect(Object.keys(configsAbc)).toHaveLength(3);
|
||||
expect(configsAbc).toHaveProperty('app_server');
|
||||
expect(configsAbc).toHaveProperty('user_server');
|
||||
expect(configsAbc).toHaveProperty('abc_private_server');
|
||||
|
||||
// With userId 'xyz': should return shared app + shared user + xyz's private servers
|
||||
const configsXyz = await registry.getAllServerConfigs('xyz');
|
||||
expect(Object.keys(configsXyz)).toHaveLength(3);
|
||||
expect(configsXyz).toHaveProperty('app_server');
|
||||
expect(configsXyz).toHaveProperty('user_server');
|
||||
expect(configsXyz).toHaveProperty('xyz_private_server');
|
||||
});
|
||||
});
|
||||
|
||||
describe('reset', () => {
|
||||
it('should clear all servers from all caches (shared app, shared user, and private user)', async () => {
|
||||
const userId = 'user123';
|
||||
|
||||
// Add servers to all three caches
|
||||
await registry.sharedAppServers.add('app_server', testParsedConfig);
|
||||
await registry.sharedUserServers.add('user_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig);
|
||||
|
||||
// Verify all servers are accessible before reset
|
||||
const appConfigBefore = await registry.getServerConfig('app_server');
|
||||
const userConfigBefore = await registry.getServerConfig('user_server');
|
||||
const privateConfigBefore = await registry.getServerConfig('private_server', userId);
|
||||
const allConfigsBefore = await registry.getAllServerConfigs(userId);
|
||||
|
||||
expect(appConfigBefore).toEqual(testParsedConfig);
|
||||
expect(userConfigBefore).toEqual(testParsedConfig);
|
||||
expect(privateConfigBefore).toEqual(testParsedConfig);
|
||||
expect(Object.keys(allConfigsBefore)).toHaveLength(3);
|
||||
|
||||
// Reset everything
|
||||
await registry.reset();
|
||||
|
||||
// Verify all servers are cleared after reset
|
||||
const appConfigAfter = await registry.getServerConfig('app_server');
|
||||
const userConfigAfter = await registry.getServerConfig('user_server');
|
||||
const privateConfigAfter = await registry.getServerConfig('private_server', userId);
|
||||
const allConfigsAfter = await registry.getAllServerConfigs(userId);
|
||||
|
||||
expect(appConfigAfter).toBeUndefined();
|
||||
expect(userConfigAfter).toBeUndefined();
|
||||
expect(privateConfigAfter).toBeUndefined();
|
||||
expect(Object.keys(allConfigsAfter)).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,175 +0,0 @@
|
||||
import * as t from '~/mcp/types';
|
||||
import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry';
|
||||
|
||||
/**
|
||||
* Unit tests for MCPServersRegistry using in-memory cache.
|
||||
* For integration tests using Redis-backed cache, see MCPServersRegistry.cache_integration.spec.ts
|
||||
*/
|
||||
describe('MCPServersRegistry', () => {
|
||||
const testParsedConfig: t.ParsedServerConfig = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for file_tools_server',
|
||||
tools: 'file_read, file_write',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
toolFunctions: {
|
||||
file_read_mcp_file_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_file_tools_server',
|
||||
description: 'Read a file',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
await registry.reset();
|
||||
});
|
||||
|
||||
describe('private user servers', () => {
|
||||
it('should add and remove private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
// Add private user server
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
|
||||
// Verify server was added
|
||||
const retrievedConfig = await registry.getServerConfig(serverName, userId);
|
||||
expect(retrievedConfig).toEqual(testParsedConfig);
|
||||
|
||||
// Remove private user server
|
||||
await registry.removePrivateUserServer(userId, serverName);
|
||||
|
||||
// Verify server was removed
|
||||
const configAfterRemoval = await registry.getServerConfig(serverName, userId);
|
||||
expect(configAfterRemoval).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when adding duplicate private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
await expect(
|
||||
registry.addPrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow(
|
||||
'Server "private_server" already exists in cache. Use update() to modify existing configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should update an existing private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
const updatedConfig: t.ParsedServerConfig = {
|
||||
type: 'stdio',
|
||||
command: 'python',
|
||||
args: ['updated.py'],
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
// Add private user server
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
|
||||
// Update the server config
|
||||
await registry.updatePrivateUserServer(userId, serverName, updatedConfig);
|
||||
|
||||
// Verify server was updated
|
||||
const retrievedConfig = await registry.getServerConfig(serverName, userId);
|
||||
expect(retrievedConfig).toEqual(updatedConfig);
|
||||
});
|
||||
|
||||
it('should throw error when updating non-existent server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
// Add a user cache first
|
||||
await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig);
|
||||
|
||||
await expect(
|
||||
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow(
|
||||
'Server "private_server" does not exist in cache. Use add() to create new configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error when updating server for non-existent user', async () => {
|
||||
const userId = 'nonexistent_user';
|
||||
const serverName = 'private_server';
|
||||
|
||||
await expect(
|
||||
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow('No private servers found for user "nonexistent_user".');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllServerConfigs', () => {
|
||||
it('should return correct servers based on userId', async () => {
|
||||
// Add servers to all three caches
|
||||
await registry.sharedAppServers.add('app_server', testParsedConfig);
|
||||
await registry.sharedUserServers.add('user_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig);
|
||||
|
||||
// Without userId: should return only shared app + shared user servers
|
||||
const configsNoUser = await registry.getAllServerConfigs();
|
||||
expect(Object.keys(configsNoUser)).toHaveLength(2);
|
||||
expect(configsNoUser).toHaveProperty('app_server');
|
||||
expect(configsNoUser).toHaveProperty('user_server');
|
||||
|
||||
// With userId 'abc': should return shared app + shared user + abc's private servers
|
||||
const configsAbc = await registry.getAllServerConfigs('abc');
|
||||
expect(Object.keys(configsAbc)).toHaveLength(3);
|
||||
expect(configsAbc).toHaveProperty('app_server');
|
||||
expect(configsAbc).toHaveProperty('user_server');
|
||||
expect(configsAbc).toHaveProperty('abc_private_server');
|
||||
|
||||
// With userId 'xyz': should return shared app + shared user + xyz's private servers
|
||||
const configsXyz = await registry.getAllServerConfigs('xyz');
|
||||
expect(Object.keys(configsXyz)).toHaveLength(3);
|
||||
expect(configsXyz).toHaveProperty('app_server');
|
||||
expect(configsXyz).toHaveProperty('user_server');
|
||||
expect(configsXyz).toHaveProperty('xyz_private_server');
|
||||
});
|
||||
});
|
||||
|
||||
describe('reset', () => {
|
||||
it('should clear all servers from all caches (shared app, shared user, and private user)', async () => {
|
||||
const userId = 'user123';
|
||||
|
||||
// Add servers to all three caches
|
||||
await registry.sharedAppServers.add('app_server', testParsedConfig);
|
||||
await registry.sharedUserServers.add('user_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig);
|
||||
|
||||
// Verify all servers are accessible before reset
|
||||
const appConfigBefore = await registry.getServerConfig('app_server');
|
||||
const userConfigBefore = await registry.getServerConfig('user_server');
|
||||
const privateConfigBefore = await registry.getServerConfig('private_server', userId);
|
||||
const allConfigsBefore = await registry.getAllServerConfigs(userId);
|
||||
|
||||
expect(appConfigBefore).toEqual(testParsedConfig);
|
||||
expect(userConfigBefore).toEqual(testParsedConfig);
|
||||
expect(privateConfigBefore).toEqual(testParsedConfig);
|
||||
expect(Object.keys(allConfigsBefore)).toHaveLength(3);
|
||||
|
||||
// Reset everything
|
||||
await registry.reset();
|
||||
|
||||
// Verify all servers are cleared after reset
|
||||
const appConfigAfter = await registry.getServerConfig('app_server');
|
||||
const userConfigAfter = await registry.getServerConfig('user_server');
|
||||
const privateConfigAfter = await registry.getServerConfig('private_server', userId);
|
||||
const allConfigsAfter = await registry.getAllServerConfigs(userId);
|
||||
|
||||
expect(appConfigAfter).toBeUndefined();
|
||||
expect(userConfigAfter).toBeUndefined();
|
||||
expect(privateConfigAfter).toBeUndefined();
|
||||
expect(Object.keys(allConfigsAfter)).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,55 +0,0 @@
|
||||
import type { MCPConnection } from '~/mcp/connection';
|
||||
|
||||
/**
|
||||
* Creates a single mock MCP connection for testing.
|
||||
* The connection has a client with mocked methods that return server-specific data.
|
||||
* @param serverName - Name of the server to create mock connection for
|
||||
* @returns Mocked MCPConnection instance
|
||||
*/
|
||||
export function createMockConnection(serverName: string): jest.Mocked<MCPConnection> {
|
||||
const mockClient = {
|
||||
getInstructions: jest.fn().mockReturnValue(`instructions for ${serverName}`),
|
||||
getServerCapabilities: jest.fn().mockReturnValue({
|
||||
tools: { listChanged: true },
|
||||
resources: { listChanged: true },
|
||||
prompts: { get: `getPrompts for ${serverName}` },
|
||||
}),
|
||||
listTools: jest.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'listFiles',
|
||||
description: `Description for ${serverName}'s listFiles tool`,
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
};
|
||||
|
||||
return {
|
||||
client: mockClient,
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates mock MCP connections for testing.
|
||||
* Each connection has a client with mocked methods that return server-specific data.
|
||||
* @param serverNames - Array of server names to create mock connections for
|
||||
* @returns Map of server names to mocked MCPConnection instances
|
||||
*/
|
||||
export function createMockConnectionsMap(
|
||||
serverNames: string[],
|
||||
): Map<string, jest.Mocked<MCPConnection>> {
|
||||
const mockConnections = new Map<string, jest.Mocked<MCPConnection>>();
|
||||
|
||||
serverNames.forEach((serverName) => {
|
||||
mockConnections.set(serverName, createMockConnection(serverName));
|
||||
});
|
||||
|
||||
return mockConnections;
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
import type Keyv from 'keyv';
|
||||
import { isLeader } from '~/cluster';
|
||||
|
||||
/**
|
||||
* Base class for MCP registry caches that require distributed leader coordination.
|
||||
* Provides helper methods for leader-only operations and success validation.
|
||||
* All concrete implementations must provide their own Keyv cache instance.
|
||||
*/
|
||||
export abstract class BaseRegistryCache {
|
||||
protected readonly PREFIX = 'MCP::ServersRegistry';
|
||||
protected abstract readonly cache: Keyv;
|
||||
|
||||
protected async leaderCheck(action: string): Promise<void> {
|
||||
if (!(await isLeader())) throw new Error(`Only leader can ${action}.`);
|
||||
}
|
||||
|
||||
protected successCheck(action: string, success: boolean): true {
|
||||
if (!success) throw new Error(`Failed to ${action} in cache.`);
|
||||
return true;
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
await this.leaderCheck(`reset ${this.cache.namespace} cache`);
|
||||
await this.cache.clear();
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
import { standardCache } from '~/cache';
|
||||
import { BaseRegistryCache } from './BaseRegistryCache';
|
||||
|
||||
// Status keys
|
||||
const INITIALIZED = 'INITIALIZED';
|
||||
|
||||
/**
|
||||
* Cache for tracking MCP Servers Registry metadata and status across distributed instances.
|
||||
* Uses Redis-backed storage to coordinate state between leader and follower nodes.
|
||||
* Currently, tracks initialization status to ensure only the leader performs initialization
|
||||
* while followers wait for completion. Designed to be extended with additional registry
|
||||
* metadata as needed (e.g., last update timestamps, version info, health status).
|
||||
* This cache is only meant to be used internally by registry management components.
|
||||
*/
|
||||
class RegistryStatusCache extends BaseRegistryCache {
|
||||
protected readonly cache = standardCache(`${this.PREFIX}::Status`);
|
||||
|
||||
public async isInitialized(): Promise<boolean> {
|
||||
return (await this.get(INITIALIZED)) === true;
|
||||
}
|
||||
|
||||
public async setInitialized(value: boolean): Promise<void> {
|
||||
await this.set(INITIALIZED, value);
|
||||
}
|
||||
|
||||
private async get<T = unknown>(key: string): Promise<T | undefined> {
|
||||
return this.cache.get(key);
|
||||
}
|
||||
|
||||
private async set(key: string, value: string | number | boolean, ttl?: number): Promise<void> {
|
||||
await this.leaderCheck('set MCP Servers Registry status');
|
||||
const success = await this.cache.set(key, value, ttl);
|
||||
this.successCheck(`set status key "${key}"`, success);
|
||||
}
|
||||
}
|
||||
|
||||
export const registryStatusCache = new RegistryStatusCache();
|
||||
@@ -1,31 +0,0 @@
|
||||
import { cacheConfig } from '~/cache';
|
||||
import { ServerConfigsCacheInMemory } from './ServerConfigsCacheInMemory';
|
||||
import { ServerConfigsCacheRedis } from './ServerConfigsCacheRedis';
|
||||
|
||||
export type ServerConfigsCache = ServerConfigsCacheInMemory | ServerConfigsCacheRedis;
|
||||
|
||||
/**
|
||||
* Factory for creating the appropriate ServerConfigsCache implementation based on deployment mode.
|
||||
* Automatically selects between in-memory and Redis-backed storage depending on USE_REDIS config.
|
||||
* In single-instance mode (USE_REDIS=false), returns lightweight in-memory cache.
|
||||
* In cluster mode (USE_REDIS=true), returns Redis-backed cache with distributed coordination.
|
||||
* Provides a unified interface regardless of the underlying storage mechanism.
|
||||
*/
|
||||
export class ServerConfigsCacheFactory {
|
||||
/**
|
||||
* Create a ServerConfigsCache instance.
|
||||
* Returns Redis implementation if Redis is configured, otherwise in-memory implementation.
|
||||
*
|
||||
* @param owner - The owner of the cache (e.g., 'user', 'global') - only used for Redis namespacing
|
||||
* @param leaderOnly - Whether operations should only be performed by the leader (only applies to Redis)
|
||||
* @returns ServerConfigsCache instance
|
||||
*/
|
||||
static create(owner: string, leaderOnly: boolean): ServerConfigsCache {
|
||||
if (cacheConfig.USE_REDIS) {
|
||||
return new ServerConfigsCacheRedis(owner, leaderOnly);
|
||||
}
|
||||
|
||||
// In-memory mode uses a simple Map - doesn't need owner/namespace
|
||||
return new ServerConfigsCacheInMemory();
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
|
||||
/**
|
||||
* In-memory implementation of MCP server configurations cache for single-instance deployments.
|
||||
* Uses a native JavaScript Map for fast, local storage without Redis dependencies.
|
||||
* Suitable for development environments or single-server production deployments.
|
||||
* Does not require leader checks or distributed coordination since data is instance-local.
|
||||
* Data is lost on server restart and not shared across multiple server instances.
|
||||
*/
|
||||
export class ServerConfigsCacheInMemory {
|
||||
private readonly cache: Map<string, ParsedServerConfig> = new Map();
|
||||
|
||||
public async add(serverName: string, config: ParsedServerConfig): Promise<void> {
|
||||
if (this.cache.has(serverName))
|
||||
throw new Error(
|
||||
`Server "${serverName}" already exists in cache. Use update() to modify existing configs.`,
|
||||
);
|
||||
this.cache.set(serverName, config);
|
||||
}
|
||||
|
||||
public async update(serverName: string, config: ParsedServerConfig): Promise<void> {
|
||||
if (!this.cache.has(serverName))
|
||||
throw new Error(
|
||||
`Server "${serverName}" does not exist in cache. Use add() to create new configs.`,
|
||||
);
|
||||
this.cache.set(serverName, config);
|
||||
}
|
||||
|
||||
public async remove(serverName: string): Promise<void> {
|
||||
if (!this.cache.delete(serverName)) {
|
||||
throw new Error(`Failed to remove server "${serverName}" in cache.`);
|
||||
}
|
||||
}
|
||||
|
||||
public async get(serverName: string): Promise<ParsedServerConfig | undefined> {
|
||||
return this.cache.get(serverName);
|
||||
}
|
||||
|
||||
public async getAll(): Promise<Record<string, ParsedServerConfig>> {
|
||||
return Object.fromEntries(this.cache);
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
this.cache.clear();
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
import type Keyv from 'keyv';
|
||||
import { fromPairs } from 'lodash';
|
||||
import { standardCache, keyvRedisClient } from '~/cache';
|
||||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
import { BaseRegistryCache } from './BaseRegistryCache';
|
||||
|
||||
/**
|
||||
* Redis-backed implementation of MCP server configurations cache for distributed deployments.
|
||||
* Stores server configs in Redis with namespace isolation by owner (App, User, or specific user ID).
|
||||
* Enables data sharing across multiple server instances in a cluster environment.
|
||||
* Supports optional leader-only write operations to prevent race conditions during initialization.
|
||||
* Data persists across server restarts and is accessible from any instance in the cluster.
|
||||
*/
|
||||
export class ServerConfigsCacheRedis extends BaseRegistryCache {
|
||||
protected readonly cache: Keyv;
|
||||
private readonly owner: string;
|
||||
private readonly leaderOnly: boolean;
|
||||
|
||||
constructor(owner: string, leaderOnly: boolean) {
|
||||
super();
|
||||
this.owner = owner;
|
||||
this.leaderOnly = leaderOnly;
|
||||
this.cache = standardCache(`${this.PREFIX}::Servers::${owner}`);
|
||||
}
|
||||
|
||||
public async add(serverName: string, config: ParsedServerConfig): Promise<void> {
|
||||
if (this.leaderOnly) await this.leaderCheck(`add ${this.owner} MCP servers`);
|
||||
const exists = await this.cache.has(serverName);
|
||||
if (exists)
|
||||
throw new Error(
|
||||
`Server "${serverName}" already exists in cache. Use update() to modify existing configs.`,
|
||||
);
|
||||
const success = await this.cache.set(serverName, config);
|
||||
this.successCheck(`add ${this.owner} server "${serverName}"`, success);
|
||||
}
|
||||
|
||||
public async update(serverName: string, config: ParsedServerConfig): Promise<void> {
|
||||
if (this.leaderOnly) await this.leaderCheck(`update ${this.owner} MCP servers`);
|
||||
const exists = await this.cache.has(serverName);
|
||||
if (!exists)
|
||||
throw new Error(
|
||||
`Server "${serverName}" does not exist in cache. Use add() to create new configs.`,
|
||||
);
|
||||
const success = await this.cache.set(serverName, config);
|
||||
this.successCheck(`update ${this.owner} server "${serverName}"`, success);
|
||||
}
|
||||
|
||||
public async remove(serverName: string): Promise<void> {
|
||||
if (this.leaderOnly) await this.leaderCheck(`remove ${this.owner} MCP servers`);
|
||||
const success = await this.cache.delete(serverName);
|
||||
this.successCheck(`remove ${this.owner} server "${serverName}"`, success);
|
||||
}
|
||||
|
||||
public async get(serverName: string): Promise<ParsedServerConfig | undefined> {
|
||||
return this.cache.get(serverName);
|
||||
}
|
||||
|
||||
public async getAll(): Promise<Record<string, ParsedServerConfig>> {
|
||||
// Use Redis SCAN iterator directly (non-blocking, production-ready)
|
||||
// Note: Keyv uses a single colon ':' between namespace and key, even if GLOBAL_PREFIX_SEPARATOR is '::'
|
||||
const pattern = `*${this.cache.namespace}:*`;
|
||||
const entries: Array<[string, ParsedServerConfig]> = [];
|
||||
|
||||
// Use scanIterator from Redis client
|
||||
if (keyvRedisClient && 'scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
// Extract the actual key name (last part after final colon)
|
||||
// Full key format: "prefix::namespace:keyName"
|
||||
const lastColonIndex = key.lastIndexOf(':');
|
||||
const keyName = key.substring(lastColonIndex + 1);
|
||||
const value = await this.cache.get(keyName);
|
||||
if (value) {
|
||||
entries.push([keyName, value as ParsedServerConfig]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fromPairs(entries);
|
||||
}
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
import { expect } from '@playwright/test';
|
||||
|
||||
describe('RegistryStatusCache Integration Tests', () => {
|
||||
let registryStatusCache: typeof import('../RegistryStatusCache').registryStatusCache;
|
||||
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
|
||||
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
|
||||
let leaderInstance: InstanceType<typeof import('~/cluster/LeaderElection').LeaderElection>;
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for Redis (only if not already set)
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX =
|
||||
process.env.REDIS_KEY_PREFIX ?? 'RegistryStatusCache-IntegrationTest';
|
||||
|
||||
// Import modules after setting env vars
|
||||
const statusCacheModule = await import('../RegistryStatusCache');
|
||||
const redisClients = await import('~/cache/redisClients');
|
||||
const leaderElectionModule = await import('~/cluster/LeaderElection');
|
||||
|
||||
registryStatusCache = statusCacheModule.registryStatusCache;
|
||||
keyvRedisClient = redisClients.keyvRedisClient;
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
|
||||
// Ensure Redis is connected
|
||||
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
|
||||
|
||||
// Wait for Redis to be ready
|
||||
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
|
||||
|
||||
// Become leader so we can perform write operations
|
||||
leaderInstance = new LeaderElection();
|
||||
const isLeader = await leaderInstance.isLeader();
|
||||
expect(isLeader).toBe(true);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up: clear all test keys from Redis
|
||||
if (keyvRedisClient) {
|
||||
const pattern = '*RegistryStatusCache-IntegrationTest*';
|
||||
if ('scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
await keyvRedisClient.del(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Resign as leader
|
||||
if (leaderInstance) await leaderInstance.resign();
|
||||
|
||||
// Close Redis connection
|
||||
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
|
||||
});
|
||||
|
||||
describe('Initialization status tracking', () => {
|
||||
it('should return false for isInitialized when not set', async () => {
|
||||
const initialized = await registryStatusCache.isInitialized();
|
||||
expect(initialized).toBe(false);
|
||||
});
|
||||
|
||||
it('should set and get initialized status', async () => {
|
||||
await registryStatusCache.setInitialized(true);
|
||||
const initialized = await registryStatusCache.isInitialized();
|
||||
expect(initialized).toBe(true);
|
||||
|
||||
await registryStatusCache.setInitialized(false);
|
||||
const uninitialized = await registryStatusCache.isInitialized();
|
||||
expect(uninitialized).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,70 +0,0 @@
|
||||
import { ServerConfigsCacheFactory } from '../ServerConfigsCacheFactory';
|
||||
import { ServerConfigsCacheInMemory } from '../ServerConfigsCacheInMemory';
|
||||
import { ServerConfigsCacheRedis } from '../ServerConfigsCacheRedis';
|
||||
import { cacheConfig } from '~/cache';
|
||||
|
||||
// Mock the cache implementations
|
||||
jest.mock('../ServerConfigsCacheInMemory');
|
||||
jest.mock('../ServerConfigsCacheRedis');
|
||||
|
||||
// Mock the cache config module
|
||||
jest.mock('~/cache', () => ({
|
||||
cacheConfig: {
|
||||
USE_REDIS: false,
|
||||
},
|
||||
}));
|
||||
|
||||
describe('ServerConfigsCacheFactory', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('create()', () => {
|
||||
it('should return ServerConfigsCacheRedis when USE_REDIS is true', () => {
|
||||
// Arrange
|
||||
cacheConfig.USE_REDIS = true;
|
||||
|
||||
// Act
|
||||
const cache = ServerConfigsCacheFactory.create('TestOwner', true);
|
||||
|
||||
// Assert
|
||||
expect(cache).toBeInstanceOf(ServerConfigsCacheRedis);
|
||||
expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('TestOwner', true);
|
||||
});
|
||||
|
||||
it('should return ServerConfigsCacheInMemory when USE_REDIS is false', () => {
|
||||
// Arrange
|
||||
cacheConfig.USE_REDIS = false;
|
||||
|
||||
// Act
|
||||
const cache = ServerConfigsCacheFactory.create('TestOwner', false);
|
||||
|
||||
// Assert
|
||||
expect(cache).toBeInstanceOf(ServerConfigsCacheInMemory);
|
||||
expect(ServerConfigsCacheInMemory).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should pass correct parameters to ServerConfigsCacheRedis', () => {
|
||||
// Arrange
|
||||
cacheConfig.USE_REDIS = true;
|
||||
|
||||
// Act
|
||||
ServerConfigsCacheFactory.create('App', true);
|
||||
|
||||
// Assert
|
||||
expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('App', true);
|
||||
});
|
||||
|
||||
it('should create ServerConfigsCacheInMemory without parameters when USE_REDIS is false', () => {
|
||||
// Arrange
|
||||
cacheConfig.USE_REDIS = false;
|
||||
|
||||
// Act
|
||||
ServerConfigsCacheFactory.create('User', false);
|
||||
|
||||
// Assert
|
||||
// In-memory cache doesn't use owner/leaderOnly parameters
|
||||
expect(ServerConfigsCacheInMemory).toHaveBeenCalledWith();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,173 +0,0 @@
|
||||
import { expect } from '@playwright/test';
|
||||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
|
||||
describe('ServerConfigsCacheInMemory Integration Tests', () => {
|
||||
let ServerConfigsCacheInMemory: typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory;
|
||||
let cache: InstanceType<
|
||||
typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory
|
||||
>;
|
||||
|
||||
// Test data
|
||||
const mockConfig1: ParsedServerConfig = {
|
||||
command: 'node',
|
||||
args: ['server1.js'],
|
||||
env: { TEST: 'value1' },
|
||||
};
|
||||
|
||||
const mockConfig2: ParsedServerConfig = {
|
||||
command: 'python',
|
||||
args: ['server2.py'],
|
||||
env: { TEST: 'value2' },
|
||||
};
|
||||
|
||||
const mockConfig3: ParsedServerConfig = {
|
||||
command: 'node',
|
||||
args: ['server3.js'],
|
||||
url: 'http://localhost:3000',
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
beforeAll(async () => {
|
||||
// Import modules
|
||||
const cacheModule = await import('../ServerConfigsCacheInMemory');
|
||||
ServerConfigsCacheInMemory = cacheModule.ServerConfigsCacheInMemory;
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
// Create a fresh instance for each test
|
||||
cache = new ServerConfigsCacheInMemory();
|
||||
});
|
||||
|
||||
describe('add and get operations', () => {
|
||||
it('should add and retrieve a server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig1);
|
||||
});
|
||||
|
||||
it('should return undefined for non-existent server', async () => {
|
||||
const result = await cache.get('non-existent');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when adding duplicate server', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await expect(cache.add('server1', mockConfig2)).rejects.toThrow(
|
||||
'Server "server1" already exists in cache. Use update() to modify existing configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple server configs', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
await cache.add('server3', mockConfig3);
|
||||
|
||||
const result1 = await cache.get('server1');
|
||||
const result2 = await cache.get('server2');
|
||||
const result3 = await cache.get('server3');
|
||||
|
||||
expect(result1).toEqual(mockConfig1);
|
||||
expect(result2).toEqual(mockConfig2);
|
||||
expect(result3).toEqual(mockConfig3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAll operation', () => {
|
||||
it('should return empty object when no servers exist', async () => {
|
||||
const result = await cache.getAll();
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('should return all server configs', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
await cache.add('server3', mockConfig3);
|
||||
|
||||
const result = await cache.getAll();
|
||||
expect(result).toEqual({
|
||||
server1: mockConfig1,
|
||||
server2: mockConfig2,
|
||||
server3: mockConfig3,
|
||||
});
|
||||
});
|
||||
|
||||
it('should reflect updates in getAll', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
let result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(2);
|
||||
|
||||
await cache.add('server3', mockConfig3);
|
||||
result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(3);
|
||||
expect(result.server3).toEqual(mockConfig3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('update operation', () => {
|
||||
it('should update an existing server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
expect(await cache.get('server1')).toEqual(mockConfig1);
|
||||
|
||||
await cache.update('server1', mockConfig2);
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should throw error when updating non-existent server', async () => {
|
||||
await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow(
|
||||
'Server "non-existent" does not exist in cache. Use add() to create new configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should reflect updates in getAll', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
await cache.update('server1', mockConfig3);
|
||||
const result = await cache.getAll();
|
||||
expect(result.server1).toEqual(mockConfig3);
|
||||
expect(result.server2).toEqual(mockConfig2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('remove operation', () => {
|
||||
it('should remove an existing server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
expect(await cache.get('server1')).toEqual(mockConfig1);
|
||||
|
||||
await cache.remove('server1');
|
||||
expect(await cache.get('server1')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when removing non-existent server', async () => {
|
||||
await expect(cache.remove('non-existent')).rejects.toThrow(
|
||||
'Failed to remove server "non-existent" in cache.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should remove server from getAll results', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
let result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(2);
|
||||
|
||||
await cache.remove('server1');
|
||||
result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(1);
|
||||
expect(result.server1).toBeUndefined();
|
||||
expect(result.server2).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should allow re-adding a removed server', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.remove('server1');
|
||||
await cache.add('server1', mockConfig3);
|
||||
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig3);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,278 +0,0 @@
|
||||
import { expect } from '@playwright/test';
|
||||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
|
||||
describe('ServerConfigsCacheRedis Integration Tests', () => {
|
||||
let ServerConfigsCacheRedis: typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis;
|
||||
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
|
||||
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
|
||||
let checkIsLeader: () => Promise<boolean>;
|
||||
let cache: InstanceType<typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis>;
|
||||
|
||||
// Test data
|
||||
const mockConfig1: ParsedServerConfig = {
|
||||
command: 'node',
|
||||
args: ['server1.js'],
|
||||
env: { TEST: 'value1' },
|
||||
};
|
||||
|
||||
const mockConfig2: ParsedServerConfig = {
|
||||
command: 'python',
|
||||
args: ['server2.py'],
|
||||
env: { TEST: 'value2' },
|
||||
};
|
||||
|
||||
const mockConfig3: ParsedServerConfig = {
|
||||
command: 'node',
|
||||
args: ['server3.js'],
|
||||
url: 'http://localhost:3000',
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for Redis (only if not already set)
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX =
|
||||
process.env.REDIS_KEY_PREFIX ?? 'ServerConfigsCacheRedis-IntegrationTest';
|
||||
|
||||
// Import modules after setting env vars
|
||||
const cacheModule = await import('../ServerConfigsCacheRedis');
|
||||
const redisClients = await import('~/cache/redisClients');
|
||||
const leaderElectionModule = await import('~/cluster/LeaderElection');
|
||||
const clusterModule = await import('~/cluster');
|
||||
|
||||
ServerConfigsCacheRedis = cacheModule.ServerConfigsCacheRedis;
|
||||
keyvRedisClient = redisClients.keyvRedisClient;
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
checkIsLeader = clusterModule.isLeader;
|
||||
|
||||
// Ensure Redis is connected
|
||||
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
|
||||
|
||||
// Wait for Redis to be ready
|
||||
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
|
||||
|
||||
// Clear any existing leader key to ensure clean state
|
||||
await keyvRedisClient.del(LeaderElection.LEADER_KEY);
|
||||
|
||||
// Become leader so we can perform write operations (using default election instance)
|
||||
const isLeader = await checkIsLeader();
|
||||
expect(isLeader).toBe(true);
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
// Create a fresh instance for each test with leaderOnly=true
|
||||
cache = new ServerConfigsCacheRedis('test-user', true);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up: clear all test keys from Redis
|
||||
if (keyvRedisClient) {
|
||||
const pattern = '*ServerConfigsCacheRedis-IntegrationTest*';
|
||||
if ('scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
await keyvRedisClient.del(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Clear leader key to allow other tests to become leader
|
||||
if (keyvRedisClient) await keyvRedisClient.del(LeaderElection.LEADER_KEY);
|
||||
|
||||
// Close Redis connection
|
||||
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
|
||||
});
|
||||
|
||||
describe('add and get operations', () => {
|
||||
it('should add and retrieve a server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig1);
|
||||
});
|
||||
|
||||
it('should return undefined for non-existent server', async () => {
|
||||
const result = await cache.get('non-existent');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when adding duplicate server', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await expect(cache.add('server1', mockConfig2)).rejects.toThrow(
|
||||
'Server "server1" already exists in cache. Use update() to modify existing configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple server configs', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
await cache.add('server3', mockConfig3);
|
||||
|
||||
const result1 = await cache.get('server1');
|
||||
const result2 = await cache.get('server2');
|
||||
const result3 = await cache.get('server3');
|
||||
|
||||
expect(result1).toEqual(mockConfig1);
|
||||
expect(result2).toEqual(mockConfig2);
|
||||
expect(result3).toEqual(mockConfig3);
|
||||
});
|
||||
|
||||
it('should isolate caches by owner namespace', async () => {
|
||||
const userCache = new ServerConfigsCacheRedis('user1', true);
|
||||
const globalCache = new ServerConfigsCacheRedis('global', true);
|
||||
|
||||
await userCache.add('server1', mockConfig1);
|
||||
await globalCache.add('server1', mockConfig2);
|
||||
|
||||
const userResult = await userCache.get('server1');
|
||||
const globalResult = await globalCache.get('server1');
|
||||
|
||||
expect(userResult).toEqual(mockConfig1);
|
||||
expect(globalResult).toEqual(mockConfig2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAll operation', () => {
|
||||
it('should return empty object when no servers exist', async () => {
|
||||
const result = await cache.getAll();
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('should return all server configs', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
await cache.add('server3', mockConfig3);
|
||||
|
||||
const result = await cache.getAll();
|
||||
expect(result).toEqual({
|
||||
server1: mockConfig1,
|
||||
server2: mockConfig2,
|
||||
server3: mockConfig3,
|
||||
});
|
||||
});
|
||||
|
||||
it('should reflect updates in getAll', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
let result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(2);
|
||||
|
||||
await cache.add('server3', mockConfig3);
|
||||
result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(3);
|
||||
expect(result.server3).toEqual(mockConfig3);
|
||||
});
|
||||
|
||||
it('should only return configs for the specific owner', async () => {
|
||||
const userCache = new ServerConfigsCacheRedis('user1', true);
|
||||
const globalCache = new ServerConfigsCacheRedis('global', true);
|
||||
|
||||
await userCache.add('server1', mockConfig1);
|
||||
await userCache.add('server2', mockConfig2);
|
||||
await globalCache.add('server3', mockConfig3);
|
||||
|
||||
const userResult = await userCache.getAll();
|
||||
const globalResult = await globalCache.getAll();
|
||||
|
||||
expect(Object.keys(userResult).length).toBe(2);
|
||||
expect(Object.keys(globalResult).length).toBe(1);
|
||||
expect(userResult.server1).toEqual(mockConfig1);
|
||||
expect(userResult.server3).toBeUndefined();
|
||||
expect(globalResult.server3).toEqual(mockConfig3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('update operation', () => {
|
||||
it('should update an existing server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
expect(await cache.get('server1')).toEqual(mockConfig1);
|
||||
|
||||
await cache.update('server1', mockConfig2);
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should throw error when updating non-existent server', async () => {
|
||||
await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow(
|
||||
'Server "non-existent" does not exist in cache. Use add() to create new configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should reflect updates in getAll', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
await cache.update('server1', mockConfig3);
|
||||
const result = await cache.getAll();
|
||||
expect(result.server1).toEqual(mockConfig3);
|
||||
expect(result.server2).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should only update in the specific owner namespace', async () => {
|
||||
const userCache = new ServerConfigsCacheRedis('user1', true);
|
||||
const globalCache = new ServerConfigsCacheRedis('global', true);
|
||||
|
||||
await userCache.add('server1', mockConfig1);
|
||||
await globalCache.add('server1', mockConfig2);
|
||||
|
||||
await userCache.update('server1', mockConfig3);
|
||||
|
||||
expect(await userCache.get('server1')).toEqual(mockConfig3);
|
||||
expect(await globalCache.get('server1')).toEqual(mockConfig2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('remove operation', () => {
|
||||
it('should remove an existing server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
expect(await cache.get('server1')).toEqual(mockConfig1);
|
||||
|
||||
await cache.remove('server1');
|
||||
expect(await cache.get('server1')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when removing non-existent server', async () => {
|
||||
await expect(cache.remove('non-existent')).rejects.toThrow(
|
||||
'Failed to remove test-user server "non-existent"',
|
||||
);
|
||||
});
|
||||
|
||||
it('should remove server from getAll results', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
let result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(2);
|
||||
|
||||
await cache.remove('server1');
|
||||
result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(1);
|
||||
expect(result.server1).toBeUndefined();
|
||||
expect(result.server2).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should allow re-adding a removed server', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.remove('server1');
|
||||
await cache.add('server1', mockConfig3);
|
||||
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig3);
|
||||
});
|
||||
|
||||
it('should only remove from the specific owner namespace', async () => {
|
||||
const userCache = new ServerConfigsCacheRedis('user1', true);
|
||||
const globalCache = new ServerConfigsCacheRedis('global', true);
|
||||
|
||||
await userCache.add('server1', mockConfig1);
|
||||
await globalCache.add('server1', mockConfig2);
|
||||
|
||||
await userCache.remove('server1');
|
||||
|
||||
expect(await userCache.get('server1')).toBeUndefined();
|
||||
expect(await globalCache.get('server1')).toEqual(mockConfig2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -151,8 +151,6 @@ export type ParsedServerConfig = MCPOptions & {
|
||||
oauthMetadata?: Record<string, unknown> | null;
|
||||
capabilities?: string;
|
||||
tools?: string;
|
||||
toolFunctions?: LCAvailableTools;
|
||||
initDuration?: number;
|
||||
};
|
||||
|
||||
export interface BasicConnectionOptions {
|
||||
|
||||
@@ -25,12 +25,6 @@ export const genAzureEndpoint = ({
|
||||
azureOpenAIApiInstanceName: string;
|
||||
azureOpenAIApiDeploymentName: string;
|
||||
}): string => {
|
||||
// Support both old (.openai.azure.com) and new (.cognitiveservices.azure.com) endpoint formats
|
||||
// If instanceName already includes a full domain, use it as-is
|
||||
if (azureOpenAIApiInstanceName.includes('.azure.com')) {
|
||||
return `https://${azureOpenAIApiInstanceName}/openai/deployments/${azureOpenAIApiDeploymentName}`;
|
||||
}
|
||||
// Legacy format for backward compatibility
|
||||
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}`;
|
||||
};
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ export * from './key';
|
||||
export * from './llm';
|
||||
export * from './math';
|
||||
export * from './openid';
|
||||
export * from './promise';
|
||||
export * from './sanitizeTitle';
|
||||
export * from './tempChatRetention';
|
||||
export * from './text';
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
import { withTimeout } from './promise';
|
||||
|
||||
describe('withTimeout', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllTimers();
|
||||
});
|
||||
|
||||
it('should resolve when promise completes before timeout', async () => {
|
||||
const promise = Promise.resolve('success');
|
||||
const result = await withTimeout(promise, 1000);
|
||||
expect(result).toBe('success');
|
||||
});
|
||||
|
||||
it('should reject when promise rejects before timeout', async () => {
|
||||
const promise = Promise.reject(new Error('test error'));
|
||||
await expect(withTimeout(promise, 1000)).rejects.toThrow('test error');
|
||||
});
|
||||
|
||||
it('should timeout when promise takes too long', async () => {
|
||||
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
|
||||
await expect(withTimeout(promise, 100, 'Custom timeout message')).rejects.toThrow(
|
||||
'Custom timeout message',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use default error message when none provided', async () => {
|
||||
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
|
||||
await expect(withTimeout(promise, 100)).rejects.toThrow('Operation timed out after 100ms');
|
||||
});
|
||||
|
||||
it('should clear timeout when promise resolves', async () => {
|
||||
const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout');
|
||||
const promise = Promise.resolve('fast');
|
||||
|
||||
await withTimeout(promise, 1000);
|
||||
|
||||
expect(clearTimeoutSpy).toHaveBeenCalled();
|
||||
clearTimeoutSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should clear timeout when promise rejects', async () => {
|
||||
const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout');
|
||||
const promise = Promise.reject(new Error('fail'));
|
||||
|
||||
await expect(withTimeout(promise, 1000)).rejects.toThrow('fail');
|
||||
|
||||
expect(clearTimeoutSpy).toHaveBeenCalled();
|
||||
clearTimeoutSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should handle multiple concurrent timeouts', async () => {
|
||||
const promise1 = Promise.resolve('first');
|
||||
const promise2 = new Promise((resolve) => setTimeout(() => resolve('second'), 50));
|
||||
const promise3 = new Promise((resolve) => setTimeout(() => resolve('third'), 2000));
|
||||
|
||||
const [result1, result2] = await Promise.all([
|
||||
withTimeout(promise1, 1000),
|
||||
withTimeout(promise2, 1000),
|
||||
]);
|
||||
|
||||
expect(result1).toBe('first');
|
||||
expect(result2).toBe('second');
|
||||
|
||||
await expect(withTimeout(promise3, 100)).rejects.toThrow('Operation timed out after 100ms');
|
||||
});
|
||||
|
||||
it('should work with async functions', async () => {
|
||||
const asyncFunction = async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
return 'async result';
|
||||
};
|
||||
|
||||
const result = await withTimeout(asyncFunction(), 1000);
|
||||
expect(result).toBe('async result');
|
||||
});
|
||||
|
||||
it('should work with any return type', async () => {
|
||||
const numberPromise = Promise.resolve(42);
|
||||
const objectPromise = Promise.resolve({ key: 'value' });
|
||||
const arrayPromise = Promise.resolve([1, 2, 3]);
|
||||
|
||||
expect(await withTimeout(numberPromise, 1000)).toBe(42);
|
||||
expect(await withTimeout(objectPromise, 1000)).toEqual({ key: 'value' });
|
||||
expect(await withTimeout(arrayPromise, 1000)).toEqual([1, 2, 3]);
|
||||
});
|
||||
|
||||
it('should call logger when timeout occurs', async () => {
|
||||
const loggerMock = jest.fn();
|
||||
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
|
||||
const errorMessage = 'Custom timeout with logger';
|
||||
|
||||
await expect(withTimeout(promise, 100, errorMessage, loggerMock)).rejects.toThrow(errorMessage);
|
||||
|
||||
expect(loggerMock).toHaveBeenCalledTimes(1);
|
||||
expect(loggerMock).toHaveBeenCalledWith(errorMessage, expect.any(Error));
|
||||
});
|
||||
|
||||
it('should not call logger when promise resolves', async () => {
|
||||
const loggerMock = jest.fn();
|
||||
const promise = Promise.resolve('success');
|
||||
|
||||
const result = await withTimeout(promise, 1000, 'Should not timeout', loggerMock);
|
||||
|
||||
expect(result).toBe('success');
|
||||
expect(loggerMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should work without logger parameter', async () => {
|
||||
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
|
||||
|
||||
await expect(withTimeout(promise, 100, 'No logger provided')).rejects.toThrow(
|
||||
'No logger provided',
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -1,42 +0,0 @@
|
||||
/**
|
||||
* Wraps a promise with a timeout. If the promise doesn't resolve/reject within
|
||||
* the specified time, it will be rejected with a timeout error.
|
||||
*
|
||||
* @param promise - The promise to wrap with a timeout
|
||||
* @param timeoutMs - Timeout duration in milliseconds
|
||||
* @param errorMessage - Custom error message for timeout (optional)
|
||||
* @param logger - Optional logger function to log timeout errors (e.g., console.warn, logger.warn)
|
||||
* @returns Promise that resolves/rejects with the original promise or times out
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const result = await withTimeout(
|
||||
* fetchData(),
|
||||
* 5000,
|
||||
* 'Failed to fetch data within 5 seconds',
|
||||
* console.warn
|
||||
* );
|
||||
* ```
|
||||
*/
|
||||
export async function withTimeout<T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
errorMessage?: string,
|
||||
logger?: (message: string, error: Error) => void,
|
||||
): Promise<T> {
|
||||
let timeoutId: NodeJS.Timeout;
|
||||
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
timeoutId = setTimeout(() => {
|
||||
const error = new Error(errorMessage ?? `Operation timed out after ${timeoutMs}ms`);
|
||||
if (logger) logger(error.message, error);
|
||||
reject(error);
|
||||
}, timeoutMs);
|
||||
});
|
||||
|
||||
try {
|
||||
return await Promise.race([promise, timeoutPromise]);
|
||||
} finally {
|
||||
clearTimeout(timeoutId!);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user