Compare commits

..

1 Commits

Author SHA1 Message Date
Danny Avila
6900fa73c6 fix: Add useBackToNewChat Hook for Back Navigation Handling
- Introduced a new hook `useBackToNewChat` to manage state when navigating back to the `/c/new` route.
- The hook listens for browser back/forward navigation events and resets the conversation state appropriately.
- Integrated the hook into `ChatRoute.tsx` to ensure proper functionality during navigation.
- Added documentation detailing the problem, solution, and implementation of the new hook.
2025-08-27 19:09:00 -04:00
441 changed files with 6271 additions and 29982 deletions

View File

@@ -163,10 +163,10 @@ GOOGLE_KEY=user_provided
# GOOGLE_AUTH_HEADER=true
# Gemini API (AI Studio)
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite,gemini-2.0-flash,gemini-2.0-flash-lite
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite
# Vertex AI
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
@@ -690,8 +690,8 @@ HELP_AND_FAQ_URL=https://librechat.ai
# REDIS_PING_INTERVAL=300
# Force specific cache namespaces to use in-memory storage even when Redis is enabled
# Comma-separated list of CacheKeys (e.g., ROLES,MESSAGES)
# FORCED_IN_MEMORY_CACHE_NAMESPACES=ROLES,MESSAGES
# Comma-separated list of CacheKeys (e.g., STATIC_CONFIG,ROLES,MESSAGES)
# FORCED_IN_MEMORY_CACHE_NAMESPACES=STATIC_CONFIG,ROLES
#==================================================#
# Others #

View File

@@ -1,4 +1,4 @@
# v0.8.0
# v0.8.0-rc3
# Base node image
FROM node:20-alpine AS node
@@ -30,7 +30,7 @@ RUN \
# Allow mounting of these files, which have no default
touch .env ; \
# Create directories for the volumes to inherit the correct permissions
mkdir -p /app/client/public/images /app/api/logs /app/uploads ; \
mkdir -p /app/client/public/images /app/api/logs ; \
npm config set fetch-retry-maxtimeout 600000 ; \
npm config set fetch-retries 5 ; \
npm config set fetch-retry-mintimeout 15000 ; \
@@ -44,6 +44,8 @@ RUN \
npm prune --production; \
npm cache clean --force
RUN mkdir -p /app/client/public/images /app/api/logs
# Node API setup
EXPOSE 3080
ENV HOST=0.0.0.0

View File

@@ -1,5 +1,5 @@
# Dockerfile.multi
# v0.8.0
# v0.8.0-rc3
# Base for all builds
FROM node:20-alpine AS base-min

View File

@@ -75,7 +75,6 @@
- 🔍 **Web Search**:
- Search the internet and retrieve relevant information to enhance your AI context
- Combines search providers, content scrapers, and result rerankers for optimal results
- **Customizable Jina Reranking**: Configure custom Jina API URLs for reranking services
- **[Learn More →](https://www.librechat.ai/docs/features/web_search)**
- 🪄 **Generative UI with Code Artifacts**:

View File

@@ -10,17 +10,7 @@ const {
validateVisionModel,
} = require('librechat-data-provider');
const { SplitStreamHandler: _Handler } = require('@librechat/agents');
const {
Tokenizer,
createFetch,
matchModelName,
getClaudeHeaders,
getModelMaxTokens,
configureReasoning,
checkPromptCacheSupport,
getModelMaxOutputTokens,
createStreamEventHandlers,
} = require('@librechat/api');
const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api');
const {
truncateText,
formatMessage,
@@ -29,6 +19,12 @@ const {
parseParamFromPrompt,
createContextHandlers,
} = require('./prompts');
const {
getClaudeHeaders,
configureReasoning,
checkPromptCacheSupport,
} = require('~/server/services/Endpoints/anthropic/helpers');
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { sleep } = require('~/server/utils');

View File

@@ -1,5 +1,4 @@
const { google } = require('googleapis');
const { getModelMaxTokens } = require('@librechat/api');
const { concat } = require('@langchain/core/utils/stream');
const { ChatVertexAI } = require('@langchain/google-vertexai');
const { Tokenizer, getSafetySettings } = require('@librechat/api');
@@ -22,6 +21,7 @@ const {
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { spendTokens } = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
const {

View File

@@ -7,9 +7,7 @@ const {
createFetch,
resolveHeaders,
constructAzureURL,
getModelMaxTokens,
genAzureChatCompletion,
getModelMaxOutputTokens,
createStreamEventHandlers,
} = require('@librechat/api');
const {
@@ -33,13 +31,13 @@ const {
titleInstruction,
createContextHandlers,
} = require('./prompts');
const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { addSpaceIfNeeded, sleep } = require('~/server/utils');
const { spendTokens } = require('~/models/spendTokens');
const { handleOpenAIErrors } = require('./tools/util');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { extractBaseURL } = require('~/utils');
const { tokenSplit } = require('./document');
const BaseClient = require('./BaseClient');
const { createLLM } = require('./llm');

View File

@@ -1,5 +1,5 @@
const { getModelMaxTokens } = require('@librechat/api');
const BaseClient = require('../BaseClient');
const { getModelMaxTokens } = require('../../../utils');
class FakeClient extends BaseClient {
constructor(apiKey, options = {}) {

View File

@@ -68,19 +68,18 @@ const primeFiles = async (options) => {
/**
*
* @param {Object} options
* @param {string} options.userId
* @param {ServerRequest} options.req
* @param {Array<{ file_id: string; filename: string }>} options.files
* @param {string} [options.entity_id]
* @param {boolean} [options.fileCitations=false] - Whether to include citation instructions
* @returns
*/
const createFileSearchTool = async ({ userId, files, entity_id, fileCitations = false }) => {
const createFileSearchTool = async ({ req, files, entity_id }) => {
return tool(
async ({ query }) => {
if (files.length === 0) {
return 'No files to search. Instruct the user to add files for the search.';
}
const jwtToken = generateShortLivedToken(userId);
const jwtToken = generateShortLivedToken(req.user.id);
if (!jwtToken) {
return 'There was an error authenticating the file search request.';
}
@@ -143,9 +142,9 @@ const createFileSearchTool = async ({ userId, files, entity_id, fileCitations =
const formattedString = formattedResults
.map(
(result, index) =>
`File: ${result.filename}${
fileCitations ? `\nAnchor: \\ue202turn0file${index} (${result.filename})` : ''
}\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${result.content}\n`,
`File: ${result.filename}\nAnchor: \\ue202turn0file${index} (${result.filename})\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${
result.content
}\n`,
)
.join('\n---\n');
@@ -159,14 +158,12 @@ const createFileSearchTool = async ({ userId, files, entity_id, fileCitations =
pageRelevance: result.page ? { [result.page]: 1.0 - result.distance } : {},
}));
return [formattedString, { [Tools.file_search]: { sources, fileCitations } }];
return [formattedString, { [Tools.file_search]: { sources } }];
},
{
name: Tools.file_search,
responseFormat: 'content_and_artifact',
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.${
fileCitations
? `
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.
**CITE FILE SEARCH RESULTS:**
Use anchor markers immediately after statements derived from file content. Reference the filename in your text:
@@ -174,9 +171,7 @@ Use anchor markers immediately after statements derived from file content. Refer
- Page reference: "According to report.docx... \\ue202turn0file1"
- Multi-file: "Multiple sources confirm... \\ue200\\ue202turn0file0\\ue202turn0file1\\ue201"
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`
: ''
}`,
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`,
schema: z.object({
query: z
.string()

View File

@@ -1,21 +1,9 @@
const { logger } = require('@librechat/data-schemas');
const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator');
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
const {
checkAccess,
createSafeUser,
mcpToolPattern,
loadWebSearchAuth,
} = require('@librechat/api');
const {
Tools,
Constants,
Permissions,
EToolResources,
PermissionTypes,
replaceSpecialVars,
} = require('librechat-data-provider');
const { Tools, Constants, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
const {
availableTools,
manifestToolMap,
@@ -38,8 +26,7 @@ const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSe
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { getMCPServerTools } = require('~/server/services/Config');
const { getRoleByName } = require('~/models/Role');
const { getCachedTools } = require('~/server/services/Config');
/**
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
@@ -255,6 +242,7 @@ const loadTools = async ({
/** @type {Record<string, string>} */
const toolContextMap = {};
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
const requestedMCPTools = {};
for (const tool of tools) {
@@ -293,29 +281,7 @@ const loadTools = async ({
if (toolContext) {
toolContextMap[tool] = toolContext;
}
/** @type {boolean | undefined} Check if user has FILE_CITATIONS permission */
let fileCitations;
if (fileCitations == null && options.req?.user != null) {
try {
fileCitations = await checkAccess({
user: options.req.user,
permissionType: PermissionTypes.FILE_CITATIONS,
permissions: [Permissions.USE],
getRoleByName,
});
} catch (error) {
logger.error('[handleTools] FILE_CITATIONS permission check failed:', error);
fileCitations = false;
}
}
return createFileSearchTool({
userId: user,
files,
entity_id: agent?.id,
fileCitations,
});
return createFileSearchTool({ req: options.req, files, entity_id: agent?.id });
};
continue;
} else if (tool === Tools.web_search) {
@@ -344,34 +310,36 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
});
};
continue;
} else if (tool && mcpToolPattern.test(tool)) {
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
if (toolName === Constants.mcp_server) {
/** Placeholder used for UI purposes */
continue;
}
if (serverName && options.req?.config?.mcpConfig?.[serverName] == null) {
logger.warn(
`MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`,
);
continue;
}
if (toolName === Constants.mcp_all) {
requestedMCPTools[serverName] = [
{
type: 'all',
const currentMCPGenerator = async (index) =>
createMCPTools({
req: options.req,
res: options.res,
index,
serverName,
},
];
userMCPAuthMap,
model: agent?.model ?? model,
provider: agent?.provider ?? endpoint,
signal,
});
requestedMCPTools[serverName] = [currentMCPGenerator];
continue;
}
const currentMCPGenerator = async (index) =>
createMCPTool({
index,
req: options.req,
res: options.res,
toolKey: tool,
userMCPAuthMap,
model: agent?.model ?? model,
provider: agent?.provider ?? endpoint,
signal,
});
requestedMCPTools[serverName] = requestedMCPTools[serverName] || [];
requestedMCPTools[serverName].push({
type: 'single',
toolKey: tool,
serverName,
});
requestedMCPTools[serverName].push(currentMCPGenerator);
continue;
}
@@ -414,65 +382,24 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
const mcpToolPromises = [];
/** MCP server tools are initialized sequentially by server */
let index = -1;
const failedMCPServers = new Set();
const safeUser = createSafeUser(options.req?.user);
for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) {
for (const [serverName, generators] of Object.entries(requestedMCPTools)) {
index++;
/** @type {LCAvailableTools} */
let availableTools;
for (const config of toolConfigs) {
for (const generator of generators) {
try {
if (failedMCPServers.has(serverName)) {
continue;
}
const mcpParams = {
index,
signal,
user: safeUser,
userMCPAuthMap,
res: options.res,
model: agent?.model ?? model,
serverName: config.serverName,
provider: agent?.provider ?? endpoint,
};
if (config.type === 'all' && toolConfigs.length === 1) {
/** Handle async loading for single 'all' tool config */
if (generator && generators.length === 1) {
mcpToolPromises.push(
createMCPTools(mcpParams).catch((error) => {
generator(index).catch((error) => {
logger.error(`Error loading ${serverName} tools:`, error);
return null;
}),
);
continue;
}
if (!availableTools) {
try {
availableTools = await getMCPServerTools(serverName);
} catch (error) {
logger.error(`Error fetching available tools for MCP server ${serverName}:`, error);
}
}
/** Handle synchronous loading */
const mcpTool =
config.type === 'all'
? await createMCPTools(mcpParams)
: await createMCPTool({
...mcpParams,
availableTools,
toolKey: config.toolKey,
});
const mcpTool = await generator(index);
if (Array.isArray(mcpTool)) {
loadedTools.push(...mcpTool);
} else if (mcpTool) {
loadedTools.push(mcpTool);
} else {
failedMCPServers.add(serverName);
logger.warn(
`MCP tool creation failed for "${config.toolKey}", server may be unavailable or unauthenticated.`,
);
}
} catch (error) {
logger.error(`Error loading MCP tool for server ${serverName}:`, error);

View File

@@ -1,5 +1,4 @@
const fs = require('fs');
const { logger } = require('@librechat/data-schemas');
const { math, isEnabled } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
@@ -35,35 +34,13 @@ if (FORCED_IN_MEMORY_CACHE_NAMESPACES.length > 0) {
}
}
/** Helper function to safely read Redis CA certificate from file
* @returns {string|null} The contents of the CA certificate file, or null if not set or on error
*/
const getRedisCA = () => {
const caPath = process.env.REDIS_CA;
if (!caPath) {
return null;
}
try {
if (fs.existsSync(caPath)) {
return fs.readFileSync(caPath, 'utf8');
} else {
logger.warn(`Redis CA certificate file not found: ${caPath}`);
return null;
}
} catch (error) {
logger.error(`Failed to read Redis CA certificate file '${caPath}':`, error);
return null;
}
};
const cacheConfig = {
FORCED_IN_MEMORY_CACHE_NAMESPACES,
USE_REDIS,
REDIS_URI: process.env.REDIS_URI,
REDIS_USERNAME: process.env.REDIS_USERNAME,
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
REDIS_CA: getRedisCA(),
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
REDIS_PING_INTERVAL: math(process.env.REDIS_PING_INTERVAL, 0),

View File

@@ -157,11 +157,12 @@ describe('cacheConfig', () => {
describe('FORCED_IN_MEMORY_CACHE_NAMESPACES validation', () => {
test('should parse comma-separated cache keys correctly', () => {
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, MESSAGES ';
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, STATIC_CONFIG ,MESSAGES ';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([
'ROLES',
'STATIC_CONFIG',
'MESSAGES',
]);
});

View File

@@ -31,8 +31,8 @@ const namespaces = {
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
[CacheKeys.APP_CONFIG]: standardCache(CacheKeys.APP_CONFIG),
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
[CacheKeys.STATIC_CONFIG]: standardCache(CacheKeys.STATIC_CONFIG),
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),

View File

@@ -1,6 +1,6 @@
const { MCPManager, FlowStateManager } = require('@librechat/api');
const { EventSource } = require('eventsource');
const { Time } = require('librechat-data-provider');
const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api');
const logger = require('./winston');
global.EventSource = EventSource;
@@ -26,6 +26,4 @@ module.exports = {
createMCPManager: MCPManager.createInstance,
getMCPManager: MCPManager.getInstance,
getFlowStateManager,
createOAuthReconnectionManager: OAuthReconnectionManager.createInstance,
getOAuthReconnectionManager: OAuthReconnectionManager.getInstance,
};

View File

@@ -1,8 +1,10 @@
const mongoose = require('mongoose');
const { MeiliSearch } = require('meilisearch');
const { logger } = require('@librechat/data-schemas');
const { FlowStateManager } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const { isEnabled, FlowStateManager } = require('@librechat/api');
const { isEnabled } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const Conversation = mongoose.models.Conversation;
@@ -29,81 +31,6 @@ class MeiliSearchClient {
}
}
/**
* Ensures indexes have proper filterable attributes configured and checks if documents have user field
* @param {MeiliSearch} client - MeiliSearch client instance
* @returns {Promise<boolean>} - true if configuration was updated or re-sync is needed
*/
async function ensureFilterableAttributes(client) {
try {
// Check and update messages index
try {
const messagesIndex = client.index('messages');
const settings = await messagesIndex.getSettings();
if (!settings.filterableAttributes || !settings.filterableAttributes.includes('user')) {
logger.info('[indexSync] Configuring messages index to filter by user...');
await messagesIndex.updateSettings({
filterableAttributes: ['user'],
});
logger.info('[indexSync] Messages index configured for user filtering');
logger.info('[indexSync] Index configuration updated. Full re-sync will be triggered.');
return true;
}
// Check if existing documents have user field indexed
try {
const searchResult = await messagesIndex.search('', { limit: 1 });
if (searchResult.hits.length > 0 && !searchResult.hits[0].user) {
logger.info('[indexSync] Existing messages missing user field, re-sync needed');
return true;
}
} catch (searchError) {
logger.debug('[indexSync] Could not check message documents:', searchError.message);
}
} catch (error) {
if (error.code !== 'index_not_found') {
logger.warn('[indexSync] Could not check/update messages index settings:', error.message);
}
}
// Check and update conversations index
try {
const convosIndex = client.index('convos');
const settings = await convosIndex.getSettings();
if (!settings.filterableAttributes || !settings.filterableAttributes.includes('user')) {
logger.info('[indexSync] Configuring convos index to filter by user...');
await convosIndex.updateSettings({
filterableAttributes: ['user'],
});
logger.info('[indexSync] Convos index configured for user filtering');
logger.info('[indexSync] Index configuration updated. Full re-sync will be triggered.');
return true;
}
// Check if existing documents have user field indexed
try {
const searchResult = await convosIndex.search('', { limit: 1 });
if (searchResult.hits.length > 0 && !searchResult.hits[0].user) {
logger.info('[indexSync] Existing conversations missing user field, re-sync needed');
return true;
}
} catch (searchError) {
logger.debug('[indexSync] Could not check conversation documents:', searchError.message);
}
} catch (error) {
if (error.code !== 'index_not_found') {
logger.warn('[indexSync] Could not check/update convos index settings:', error.message);
}
}
} catch (error) {
logger.error('[indexSync] Error ensuring filterable attributes:', error);
}
return false;
}
/**
* Performs the actual sync operations for messages and conversations
*/
@@ -120,27 +47,12 @@ async function performSync() {
return { messagesSync: false, convosSync: false };
}
/** Ensures indexes have proper filterable attributes configured */
const configUpdated = await ensureFilterableAttributes(client);
let messagesSync = false;
let convosSync = false;
// If configuration was just updated or documents are missing user field, force a full re-sync
if (configUpdated) {
logger.info('[indexSync] Forcing full re-sync to ensure user field is properly indexed...');
// Reset sync flags to force full re-sync
await Message.collection.updateMany({ _meiliIndex: true }, { $set: { _meiliIndex: false } });
await Conversation.collection.updateMany(
{ _meiliIndex: true },
{ $set: { _meiliIndex: false } },
);
}
// Check if we need to sync messages
const messageProgress = await Message.getSyncProgress();
if (!messageProgress.isComplete || configUpdated) {
if (!messageProgress.isComplete) {
logger.info(
`[indexSync] Messages need syncing: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments} indexed`,
);
@@ -167,7 +79,7 @@ async function performSync() {
// Check if we need to sync conversations
const convoProgress = await Conversation.getSyncProgress();
if (!convoProgress.isComplete || configUpdated) {
if (!convoProgress.isComplete) {
logger.info(
`[indexSync] Conversations need syncing: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments} indexed`,
);

View File

@@ -11,7 +11,7 @@ const {
getProjectByName,
} = require('./Project');
const { removeAllPermissions } = require('~/server/services/PermissionService');
const { getMCPServerTools } = require('~/server/services/Config');
const { getCachedTools } = require('~/server/services/Config');
const { getActions } = require('./Action');
const { Agent } = require('~/db/models');
@@ -49,14 +49,6 @@ const createAgent = async (agentData) => {
*/
const getAgent = async (searchParameter) => await Agent.findOne(searchParameter).lean();
/**
* Get multiple agent documents based on the provided search parameters.
*
* @param {Object} searchParameter - The search parameters to find agents.
* @returns {Promise<Agent[]>} Array of agent documents as plain objects.
*/
const getAgents = async (searchParameter) => await Agent.find(searchParameter).lean();
/**
* Load an agent based on the provided ID
*
@@ -69,6 +61,8 @@ const getAgents = async (searchParameter) => await Agent.find(searchParameter).l
*/
const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => {
const { model, ...model_parameters } = _m;
/** @type {Record<string, FunctionTool>} */
const availableTools = await getCachedTools({ userId: req.user.id, includeGlobal: true });
/** @type {TEphemeralAgent | null} */
const ephemeralAgent = req.body.ephemeralAgent;
const mcpServers = new Set(ephemeralAgent?.mcp);
@@ -86,18 +80,22 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
const addedServers = new Set();
if (mcpServers.size > 0) {
for (const toolName of Object.keys(availableTools)) {
if (!toolName.includes(mcp_delimiter)) {
continue;
}
const mcpServer = toolName.split(mcp_delimiter)?.[1];
if (mcpServer && mcpServers.has(mcpServer)) {
addedServers.add(mcpServer);
tools.push(toolName);
}
}
for (const mcpServer of mcpServers) {
if (addedServers.has(mcpServer)) {
continue;
}
const serverTools = await getMCPServerTools(mcpServer);
if (!serverTools) {
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
addedServers.add(mcpServer);
continue;
}
tools.push(...Object.keys(serverTools));
addedServers.add(mcpServer);
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
}
}
@@ -837,7 +835,6 @@ const countPromotedAgents = async () => {
module.exports = {
getAgent,
getAgents,
loadAgent,
createAgent,
updateAgent,

View File

@@ -8,7 +8,6 @@ process.env.CREDS_IV = '0123456789abcdef';
jest.mock('~/server/services/Config', () => ({
getCachedTools: jest.fn(),
getMCPServerTools: jest.fn(),
}));
const mongoose = require('mongoose');
@@ -31,7 +30,7 @@ const {
generateActionMetadataHash,
} = require('./Agent');
const permissionService = require('~/server/services/PermissionService');
const { getCachedTools, getMCPServerTools } = require('~/server/services/Config');
const { getCachedTools } = require('~/server/services/Config');
const { AclEntry } = require('~/db/models');
/**
@@ -1930,16 +1929,6 @@ describe('models/Agent', () => {
another_tool: {},
});
// Mock getMCPServerTools to return tools for each server
getMCPServerTools.mockImplementation(async (server) => {
if (server === 'server1') {
return { tool1_mcp_server1: {} };
} else if (server === 'server2') {
return { tool2_mcp_server2: {} };
}
return null;
});
const mockReq = {
user: { id: 'user123' },
body: {
@@ -2124,14 +2113,6 @@ describe('models/Agent', () => {
getCachedTools.mockResolvedValue(availableTools);
// Mock getMCPServerTools to return all tools for server1
getMCPServerTools.mockImplementation(async (server) => {
if (server === 'server1') {
return availableTools; // All 100 tools belong to server1
}
return null;
});
const mockReq = {
user: { id: 'user123' },
body: {
@@ -2673,17 +2654,6 @@ describe('models/Agent', () => {
tool_mcp_server2: {}, // Different server
});
// Mock getMCPServerTools to return only tools matching the server
getMCPServerTools.mockImplementation(async (server) => {
if (server === 'server1') {
// Only return tool that correctly matches server1 format
return { tool_mcp_server1: {} };
} else if (server === 'server2') {
return { tool_mcp_server2: {} };
}
return null;
});
const mockReq = {
user: { id: 'user123' },
body: {

View File

@@ -174,7 +174,7 @@ module.exports = {
if (search) {
try {
const meiliResults = await Conversation.meiliSearch(search, { filter: `user = "${user}"` });
const meiliResults = await Conversation.meiliSearch(search);
const matchingIds = Array.isArray(meiliResults.hits)
? meiliResults.hits.map((result) => result.conversationId)
: [];

View File

@@ -239,46 +239,10 @@ const updateTagsForConversation = async (user, conversationId, tags) => {
}
};
/**
* Increments tag counts for existing tags only.
* @param {string} user - The user ID.
* @param {string[]} tags - Array of tag names to increment
* @returns {Promise<void>}
*/
const bulkIncrementTagCounts = async (user, tags) => {
if (!tags || tags.length === 0) {
return;
}
try {
const uniqueTags = [...new Set(tags.filter(Boolean))];
if (uniqueTags.length === 0) {
return;
}
const bulkOps = uniqueTags.map((tag) => ({
updateOne: {
filter: { user, tag },
update: { $inc: { count: 1 } },
},
}));
const result = await ConversationTag.bulkWrite(bulkOps);
if (result && result.modifiedCount > 0) {
logger.debug(
`user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`,
);
}
} catch (error) {
logger.error('[bulkIncrementTagCounts] Error incrementing tag counts', error);
}
};
module.exports = {
getConversationTags,
createConversationTag,
updateConversationTag,
deleteConversationTag,
bulkIncrementTagCounts,
updateTagsForConversation,
};

View File

@@ -42,7 +42,7 @@ const getToolFilesByIds = async (fileIds, toolResourceSet) => {
$or: [],
};
if (toolResourceSet.has(EToolResources.context)) {
if (toolResourceSet.has(EToolResources.ocr)) {
filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
}
if (toolResourceSet.has(EToolResources.file_search)) {

View File

@@ -211,7 +211,7 @@ describe('File Access Control', () => {
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should deny access when user only has VIEW permission and needs access for deletion', async () => {
it('should deny access when user only has VIEW permission', async () => {
const userId = new mongoose.Types.ObjectId();
const authorId = new mongoose.Types.ObjectId();
const agentId = uuidv4();
@@ -263,71 +263,12 @@ describe('File Access Control', () => {
role: SystemRoles.USER,
fileIds,
agentId,
isDelete: true,
});
// Should have no access to any files when only VIEW permission
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should grant access when user has VIEW permission', async () => {
const userId = new mongoose.Types.ObjectId();
const authorId = new mongoose.Types.ObjectId();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4()];
// Create users
await User.create({
_id: userId,
email: 'user@example.com',
emailVerified: true,
provider: 'local',
});
await User.create({
_id: authorId,
email: 'author@example.com',
emailVerified: true,
provider: 'local',
});
// Create agent with files
const agent = await createAgent({
id: agentId,
name: 'View-Only Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: fileIds,
},
},
});
// Grant only VIEW permission to user on the agent
await grantPermission({
principalType: PrincipalType.USER,
principalId: userId,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
accessRoleId: AccessRoleIds.AGENT_VIEWER,
grantedBy: authorId,
});
// Check access for files
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
const accessMap = await hasAccessToFilesViaAgent({
userId: userId,
role: SystemRoles.USER,
fileIds,
agentId,
});
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
});
});
describe('getFiles with agent access control', () => {

View File

@@ -269,7 +269,7 @@ async function getListPromptGroupsByAccess({
const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };
// Add cursor condition
if (after && typeof after === 'string' && after !== 'undefined' && after !== 'null') {
if (after) {
try {
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
const { updatedAt, _id } = cursor;

View File

@@ -189,15 +189,11 @@ async function createAutoRefillTransaction(txData) {
* @param {txData} _txData - Transaction data.
*/
async function createTransaction(_txData) {
const { balance, transactions, ...txData } = _txData;
const { balance, ...txData } = _txData;
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return;
}
if (transactions?.enabled === false) {
return;
}
const transaction = new Transaction(txData);
transaction.endpointTokenConfig = txData.endpointTokenConfig;
calculateTokenValue(transaction);
@@ -226,11 +222,7 @@ async function createTransaction(_txData) {
* @param {txData} _txData - Transaction data.
*/
async function createStructuredTransaction(_txData) {
const { balance, transactions, ...txData } = _txData;
if (transactions?.enabled === false) {
return;
}
const { balance, ...txData } = _txData;
const transaction = new Transaction({
...txData,
endpointTokenConfig: txData.endpointTokenConfig,

View File

@@ -1,9 +1,10 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { createTransaction, createStructuredTransaction } = require('./Transaction');
const { Balance, Transaction } = require('~/db/models');
const { createTransaction } = require('./Transaction');
const { Balance } = require('~/db/models');
let mongoServer;
beforeAll(async () => {
@@ -379,188 +380,3 @@ describe('NaN Handling Tests', () => {
expect(balance.tokenCredits).toBe(initialBalance);
});
});
describe('Transactions Config Tests', () => {
test('createTransaction should not save when transactions.enabled is false', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
rawAmount: -100,
tokenType: 'prompt',
transactions: { enabled: false },
};
// Act
const result = await createTransaction(txData);
// Assert: No transaction should be created
expect(result).toBeUndefined();
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(0);
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});
test('createTransaction should save when transactions.enabled is true', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
rawAmount: -100,
tokenType: 'prompt',
transactions: { enabled: true },
balance: { enabled: true },
};
// Act
const result = await createTransaction(txData);
// Assert: Transaction should be created
expect(result).toBeDefined();
expect(result.balance).toBeLessThan(initialBalance);
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(1);
expect(transactions[0].rawAmount).toBe(-100);
});
test('createTransaction should save when balance.enabled is true even if transactions config is missing', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
rawAmount: -100,
tokenType: 'prompt',
balance: { enabled: true },
// No transactions config provided
};
// Act
const result = await createTransaction(txData);
// Assert: Transaction should be created (backward compatibility)
expect(result).toBeDefined();
expect(result.balance).toBeLessThan(initialBalance);
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(1);
});
test('createTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
rawAmount: -100,
tokenType: 'prompt',
transactions: { enabled: true },
balance: { enabled: false },
};
// Act
const result = await createTransaction(txData);
// Assert: Transaction should be created but balance unchanged
expect(result).toBeUndefined();
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(1);
expect(transactions[0].rawAmount).toBe(-100);
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});
test('createStructuredTransaction should not save when transactions.enabled is false', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'claude-3-5-sonnet';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'message',
tokenType: 'prompt',
inputTokens: -10,
writeTokens: -100,
readTokens: -5,
transactions: { enabled: false },
};
// Act
const result = await createStructuredTransaction(txData);
// Assert: No transaction should be created
expect(result).toBeUndefined();
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(0);
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});
test('createStructuredTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'claude-3-5-sonnet';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'message',
tokenType: 'prompt',
inputTokens: -10,
writeTokens: -100,
readTokens: -5,
transactions: { enabled: true },
balance: { enabled: false },
};
// Act
const result = await createStructuredTransaction(txData);
// Assert: Transaction should be created but balance unchanged
expect(result).toBeUndefined();
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(1);
expect(transactions[0].inputTokens).toBe(-10);
expect(transactions[0].writeTokens).toBe(-100);
expect(transactions[0].readTokens).toBe(-5);
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});
});

View File

@@ -1,4 +1,4 @@
const { matchModelName } = require('@librechat/api');
const { matchModelName } = require('../utils/tokens');
const defaultRate = 6;
/**
@@ -111,8 +111,8 @@ const tokenValues = Object.assign(
'claude-': { prompt: 0.8, completion: 2.4 },
'command-r-plus': { prompt: 3, completion: 15 },
'command-r': { prompt: 0.5, completion: 1.5 },
'deepseek-reasoner': { prompt: 0.28, completion: 0.42 },
deepseek: { prompt: 0.28, completion: 0.42 },
'deepseek-reasoner': { prompt: 0.55, completion: 2.19 },
deepseek: { prompt: 0.14, completion: 0.28 },
/* cohere doesn't have rates for the older command models,
so this was from https://artificialanalysis.ai/models/command-light/providers */
command: { prompt: 0.38, completion: 0.38 },
@@ -124,8 +124,7 @@ const tokenValues = Object.assign(
'gemini-2.0-flash': { prompt: 0.1, completion: 0.4 },
'gemini-2.0': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing
'gemini-2.5-pro': { prompt: 1.25, completion: 10 },
'gemini-2.5-flash': { prompt: 0.3, completion: 2.5 },
'gemini-2.5-flash-lite': { prompt: 0.075, completion: 0.4 },
'gemini-2.5-flash': { prompt: 0.15, completion: 3.5 },
'gemini-2.5': { prompt: 0, completion: 0 }, // Free for a period of time
'gemini-1.5-flash-8b': { prompt: 0.075, completion: 0.3 },
'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 },

View File

@@ -571,9 +571,6 @@ describe('getCacheMultiplier', () => {
describe('Google Model Tests', () => {
const googleModels = [
'gemini-2.5-pro',
'gemini-2.5-flash',
'gemini-2.5-flash-lite',
'gemini-2.5-pro-preview-05-06',
'gemini-2.5-flash-preview-04-17',
'gemini-2.5-exp',
@@ -614,9 +611,6 @@ describe('Google Model Tests', () => {
it('should map to the correct model keys', () => {
const expected = {
'gemini-2.5-pro': 'gemini-2.5-pro',
'gemini-2.5-flash': 'gemini-2.5-flash',
'gemini-2.5-flash-lite': 'gemini-2.5-flash-lite',
'gemini-2.5-pro-preview-05-06': 'gemini-2.5-pro',
'gemini-2.5-flash-preview-04-17': 'gemini-2.5-flash',
'gemini-2.5-exp': 'gemini-2.5',

View File

@@ -1,6 +1,6 @@
{
"name": "@librechat/backend",
"version": "v0.8.0",
"version": "v0.8.0-rc3",
"description": "",
"scripts": {
"start": "echo 'please run this from the root directory'",
@@ -49,14 +49,14 @@
"@langchain/google-vertexai": "^0.2.13",
"@langchain/openai": "^0.5.18",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.82",
"@librechat/agents": "^2.4.76",
"@librechat/api": "*",
"@librechat/data-schemas": "*",
"@microsoft/microsoft-graph-client": "^3.0.7",
"@modelcontextprotocol/sdk": "^1.17.1",
"@node-saml/passport-saml": "^5.1.0",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.12.1",
"axios": "^1.8.2",
"bcryptjs": "^2.4.3",
"compression": "^1.8.1",
"connect-redis": "^8.1.0",

View File

@@ -1,8 +1,8 @@
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const openIdClient = require('openid-client');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { isEnabled, findOpenIDUser } = require('@librechat/api');
const {
requestPasswordReset,
setOpenIDAuthTokens,
@@ -11,9 +11,8 @@ const {
registerUser,
} = require('~/server/services/AuthService');
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
const { getOAuthReconnectionManager } = require('~/config');
const { getOpenIdConfig } = require('~/strategies');
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
const registrationController = async (req, res) => {
try {
@@ -72,17 +71,11 @@ const refreshController = async (req, res) => {
const openIdConfig = getOpenIdConfig();
const tokenset = await openIdClient.refreshTokenGrant(openIdConfig, refreshToken);
const claims = tokenset.claims();
const { user, error } = await findOpenIDUser({
findUser,
email: claims.email,
openidId: claims.sub,
idOnTheSource: claims.oid,
strategyName: 'refreshController',
});
if (error || !user) {
const user = await findUser({ email: claims.email });
if (!user) {
return res.status(401).redirect('/login');
}
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString());
const token = setOpenIDAuthTokens(tokenset, res);
return res.status(200).send({ token, user });
} catch (error) {
logger.error('[refreshController] OpenID token refresh error', error);
@@ -103,25 +96,14 @@ const refreshController = async (req, res) => {
return res.status(200).send({ token, user });
}
/** Session with the hashed refresh token */
const session = await findSession(
{
userId: userId,
refreshToken: refreshToken,
},
{ lean: false },
);
// Find the session with the hashed refresh token
const session = await findSession({
userId: userId,
refreshToken: refreshToken,
});
if (session && session.expiration > new Date()) {
const token = await setAuthTokens(userId, res, session);
// trigger OAuth MCP server reconnection asynchronously (best effort)
void getOAuthReconnectionManager()
.reconnectServers(userId)
.catch((err) => {
logger.error('Error reconnecting OAuth MCP servers:', err);
});
const token = await setAuthTokens(userId, res, session._id);
res.status(200).send({ token, user });
} else if (req?.query?.retry) {
// Retrying from a refresh token request that failed (401)
@@ -132,7 +114,7 @@ const refreshController = async (req, res) => {
res.status(401).send('Refresh token expired or not found for this user');
}
} catch (err) {
logger.error(`[refreshController] Invalid refresh token:`, err);
logger.error(`[refreshController] Refresh token: ${refreshToken}`, err);
res.status(403).send('Invalid refresh token');
}
};

View File

@@ -1,9 +1,16 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api');
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { CacheKeys, Constants } = require('librechat-data-provider');
const {
getToolkitKey,
checkPluginAuth,
filterUniquePlugins,
convertMCPToolToPlugin,
convertMCPToolsToPlugins,
} = require('@librechat/api');
const { getCachedTools, setCachedTools, mergeUserTools } = require('~/server/services/Config');
const { availableTools, toolkits } = require('~/app/clients/tools');
const { getAppConfig } = require('~/server/services/Config');
const { getMCPManager } = require('~/config');
const { getLogStores } = require('~/cache');
const getAvailablePluginsController = async (req, res) => {
@@ -65,27 +72,54 @@ const getAvailableTools = async (req, res) => {
}
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
const cachedUserTools = await getCachedTools({ userId });
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
const mcpManager = getMCPManager();
const userPlugins =
cachedUserTools != null
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
: undefined;
// Return early if we have cached tools
if (cachedToolsArray != null) {
res.status(200).json(cachedToolsArray);
if (cachedToolsArray != null && userPlugins != null) {
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
res.status(200).json(dedupedTools);
return;
}
/** @type {Record<string, FunctionTool> | null} Get tool definitions to filter which tools are actually available */
let toolDefinitions = await getCachedTools();
if (toolDefinitions == null && appConfig?.availableTools != null) {
logger.warn('[getAvailableTools] Tool cache was empty, re-initializing from app config');
await setCachedTools(appConfig.availableTools);
toolDefinitions = appConfig.availableTools;
}
let toolDefinitions = await getCachedTools({ includeGlobal: true });
let prelimCachedTools;
/** @type {import('@librechat/api').LCManifestTool[]} */
let pluginManifest = availableTools;
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
if (appConfig?.mcpConfig != null) {
try {
const mcpTools = await mcpManager.getAllToolFunctions(userId);
prelimCachedTools = prelimCachedTools ?? {};
for (const [toolKey, toolData] of Object.entries(mcpTools)) {
const plugin = convertMCPToolToPlugin({
toolKey,
toolData,
mcpManager,
});
if (plugin) {
pluginManifest.push(plugin);
}
prelimCachedTools[toolKey] = toolData;
}
await mergeUserTools({ userId, cachedUserTools, userTools: prelimCachedTools });
} catch (error) {
logger.error(
'[getAvailableTools] Error loading MCP Tools, servers may still be initializing:',
error,
);
}
} else if (prelimCachedTools != null) {
await setCachedTools(prelimCachedTools, { isGlobal: true });
}
/** @type {TPlugin[]} Deduplicate and authenticate plugins */
const uniquePlugins = filterUniquePlugins(pluginManifest);
const authenticatedPlugins = uniquePlugins.map((plugin) => {
@@ -96,13 +130,13 @@ const getAvailableTools = async (req, res) => {
}
});
/** Filter plugins based on availability */
/** Filter plugins based on availability and add MCP-specific auth config */
const toolsOutput = [];
for (const plugin of authenticatedPlugins) {
const isToolDefined = toolDefinitions?.[plugin.pluginKey] !== undefined;
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
const isToolkit =
plugin.toolkit === true &&
Object.keys(toolDefinitions ?? {}).some(
Object.keys(toolDefinitions).some(
(key) => getToolkitKey({ toolkits, toolName: key }) === plugin.pluginKey,
);
@@ -110,13 +144,39 @@ const getAvailableTools = async (req, res) => {
continue;
}
toolsOutput.push(plugin);
const toolToAdd = { ...plugin };
if (plugin.pluginKey.includes(Constants.mcp_delimiter)) {
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
const serverName = parts[parts.length - 1];
const serverConfig = appConfig?.mcpConfig?.[serverName];
if (serverConfig?.customUserVars) {
const customVarKeys = Object.keys(serverConfig.customUserVars);
if (customVarKeys.length === 0) {
toolToAdd.authConfig = [];
toolToAdd.authenticated = true;
} else {
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(
([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}),
);
toolToAdd.authenticated = false;
}
}
}
toolsOutput.push(toolToAdd);
}
const finalTools = filterUniquePlugins(toolsOutput);
await cache.set(CacheKeys.TOOLS, finalTools);
res.status(200).json(finalTools);
const dedupedTools = filterUniquePlugins([...(userPlugins ?? []), ...finalTools]);
res.status(200).json(dedupedTools);
} catch (error) {
logger.error('[getAvailableTools]', error);
res.status(500).json({ message: error.message });

View File

@@ -1,3 +1,4 @@
const { Constants } = require('librechat-data-provider');
const { getCachedTools, getAppConfig } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
@@ -16,10 +17,18 @@ jest.mock('~/server/services/Config', () => ({
includedTools: [],
}),
setCachedTools: jest.fn(),
mergeUserTools: jest.fn(),
}));
// loadAndFormatTools mock removed - no longer used in PluginController
// getMCPManager mock removed - no longer used in PluginController
jest.mock('~/config', () => ({
getMCPManager: jest.fn(() => ({
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
})),
getFlowStateManager: jest.fn(),
}));
jest.mock('~/app/clients/tools', () => ({
availableTools: [],
@@ -150,6 +159,43 @@ describe('PluginController', () => {
});
describe('getAvailableTools', () => {
it('should use convertMCPToolsToPlugins for user-specific MCP tools', async () => {
const mockUserTools = {
[`tool1${Constants.mcp_delimiter}server1`]: {
type: 'function',
function: {
name: `tool1${Constants.mcp_delimiter}server1`,
description: 'Tool 1',
parameters: { type: 'object', properties: {} },
},
},
};
mockCache.get.mockResolvedValue(null);
getCachedTools.mockResolvedValueOnce(mockUserTools);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Mock second call to return tool definitions (includeGlobal: true)
getCachedTools.mockResolvedValueOnce(mockUserTools);
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
expect(responseData).toBeDefined();
expect(Array.isArray(responseData)).toBe(true);
expect(responseData.length).toBeGreaterThan(0);
const convertedTool = responseData.find(
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}server1`,
);
expect(convertedTool).toBeDefined();
// The real convertMCPToolsToPlugins extracts the name from the delimiter
expect(convertedTool.name).toBe('tool1');
});
it('should use filterUniquePlugins to deduplicate combined tools', async () => {
const mockUserTools = {
'user-tool': {
@@ -174,6 +220,9 @@ describe('PluginController', () => {
paths: { structuredTools: '/mock/path' },
};
// Mock second call to return tool definitions
getCachedTools.mockResolvedValueOnce(mockUserTools);
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
@@ -196,7 +245,14 @@ describe('PluginController', () => {
require('~/app/clients/tools').availableTools.push(mockPlugin);
mockCache.get.mockResolvedValue(null);
// getCachedTools returns the tool definitions
// First call returns null for user tools
getCachedTools.mockResolvedValueOnce(null);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Second call (with includeGlobal: true) returns the tool definitions
getCachedTools.mockResolvedValueOnce({
tool1: {
type: 'function',
@@ -207,10 +263,6 @@ describe('PluginController', () => {
},
},
});
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
await getAvailableTools(mockReq, mockRes);
@@ -241,7 +293,14 @@ describe('PluginController', () => {
});
mockCache.get.mockResolvedValue(null);
// getCachedTools returns the tool definitions
// First call returns null for user tools
getCachedTools.mockResolvedValueOnce(null);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Second call (with includeGlobal: true) returns the tool definitions
getCachedTools.mockResolvedValueOnce({
toolkit1_function: {
type: 'function',
@@ -252,10 +311,6 @@ describe('PluginController', () => {
},
},
});
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
await getAvailableTools(mockReq, mockRes);
@@ -267,7 +322,126 @@ describe('PluginController', () => {
});
});
describe('plugin.icon behavior', () => {
const callGetAvailableToolsWithMCPServer = async (serverConfig) => {
mockCache.get.mockResolvedValue(null);
const functionTools = {
[`test-tool${Constants.mcp_delimiter}test-server`]: {
type: 'function',
function: {
name: `test-tool${Constants.mcp_delimiter}test-server`,
description: 'A test tool',
parameters: { type: 'object', properties: {} },
},
},
};
// Mock the MCP manager to return tools and server config
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue(functionTools),
getRawConfig: jest.fn().mockReturnValue(serverConfig),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
// First call returns empty user tools
getCachedTools.mockResolvedValueOnce({});
// Mock getAppConfig to return the mcpConfig
mockReq.config = {
mcpConfig: {
'test-server': serverConfig,
},
};
// Second call (with includeGlobal: true) returns the tool definitions
getCachedTools.mockResolvedValueOnce(functionTools);
await getAvailableTools(mockReq, mockRes);
const responseData = mockRes.json.mock.calls[0][0];
return responseData.find(
(tool) => tool.pluginKey === `test-tool${Constants.mcp_delimiter}test-server`,
);
};
it('should set plugin.icon when iconPath is defined', async () => {
const serverConfig = {
iconPath: '/path/to/icon.png',
};
const testTool = await callGetAvailableToolsWithMCPServer(serverConfig);
expect(testTool.icon).toBe('/path/to/icon.png');
});
it('should set plugin.icon to undefined when iconPath is not defined', async () => {
const serverConfig = {};
const testTool = await callGetAvailableToolsWithMCPServer(serverConfig);
expect(testTool.icon).toBeUndefined();
});
});
describe('helper function integration', () => {
it('should properly handle MCP tools with custom user variables', async () => {
const appConfig = {
mcpConfig: {
'test-server': {
customUserVars: {
API_KEY: { title: 'API Key', description: 'Your API key' },
},
},
},
};
// Mock MCP tools returned by getAllToolFunctions
const mcpToolFunctions = {
[`tool1${Constants.mcp_delimiter}test-server`]: {
type: 'function',
function: {
name: `tool1${Constants.mcp_delimiter}test-server`,
description: 'Tool 1',
parameters: {},
},
},
};
// Mock the MCP manager to return tools
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue(mcpToolFunctions),
getRawConfig: jest.fn().mockReturnValue({
customUserVars: {
API_KEY: { title: 'API Key', description: 'Your API key' },
},
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
mockCache.get.mockResolvedValue(null);
mockReq.config = appConfig;
// First call returns user tools (empty in this case)
getCachedTools.mockResolvedValueOnce({});
// Second call (with includeGlobal: true) returns tool definitions including our MCP tool
getCachedTools.mockResolvedValueOnce(mcpToolFunctions);
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
expect(Array.isArray(responseData)).toBe(true);
// Find the MCP tool in the response
const mcpTool = responseData.find(
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}test-server`,
);
// The actual implementation adds authConfig and sets authenticated to false when customUserVars exist
expect(mcpTool).toBeDefined();
expect(mcpTool.authConfig).toEqual([
{ authField: 'API_KEY', label: 'API Key', description: 'Your API key' },
]);
expect(mcpTool.authenticated).toBe(false);
});
it('should handle error cases gracefully', async () => {
mockCache.get.mockRejectedValue(new Error('Cache error'));
@@ -289,13 +463,23 @@ describe('PluginController', () => {
it('should handle null cachedTools and cachedUserTools', async () => {
mockCache.get.mockResolvedValue(null);
// getCachedTools returns empty object instead of null
getCachedTools.mockResolvedValueOnce({});
// First call returns null for user tools
getCachedTools.mockResolvedValueOnce(null);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Mock MCP manager to return no tools
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
// Second call (with includeGlobal: true) returns empty object instead of null
getCachedTools.mockResolvedValueOnce({});
await getAvailableTools(mockReq, mockRes);
// Should handle null values gracefully
@@ -310,9 +494,9 @@ describe('PluginController', () => {
paths: { structuredTools: '/mock/path' },
};
// Mock getCachedTools to return undefined
// Mock getCachedTools to return undefined for both calls
getCachedTools.mockReset();
getCachedTools.mockResolvedValueOnce(undefined);
getCachedTools.mockResolvedValueOnce(undefined).mockResolvedValueOnce(undefined);
await getAvailableTools(mockReq, mockRes);
@@ -321,6 +505,42 @@ describe('PluginController', () => {
expect(mockRes.json).toHaveBeenCalledWith([]);
});
it('should handle cachedToolsArray and userPlugins both being defined', async () => {
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
// Use MCP delimiter for the user tool so convertMCPToolsToPlugins works
const userTools = {
[`user-tool${Constants.mcp_delimiter}server1`]: {
type: 'function',
function: {
name: `user-tool${Constants.mcp_delimiter}server1`,
description: 'User tool',
parameters: {},
},
},
};
mockCache.get.mockResolvedValue(cachedTools);
getCachedTools.mockResolvedValueOnce(userTools);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// The controller expects a second call to getCachedTools
getCachedTools.mockResolvedValueOnce({
'cached-tool': { type: 'function', function: { name: 'cached-tool' } },
[`user-tool${Constants.mcp_delimiter}server1`]:
userTools[`user-tool${Constants.mcp_delimiter}server1`],
});
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
// Should have both cached and user tools
expect(responseData.length).toBeGreaterThanOrEqual(2);
});
it('should handle empty toolDefinitions object', async () => {
mockCache.get.mockResolvedValue(null);
// Reset getCachedTools to ensure clean state
@@ -331,12 +551,76 @@ describe('PluginController', () => {
// Ensure no plugins are available
require('~/app/clients/tools').availableTools.length = 0;
// Reset MCP manager to default state
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
await getAvailableTools(mockReq, mockRes);
// With empty tool definitions, no tools should be in the final output
expect(mockRes.json).toHaveBeenCalledWith([]);
});
it('should handle MCP tools without customUserVars', async () => {
const appConfig = {
mcpConfig: {
'test-server': {
// No customUserVars defined
},
},
};
const mockUserTools = {
[`tool1${Constants.mcp_delimiter}test-server`]: {
type: 'function',
function: {
name: `tool1${Constants.mcp_delimiter}test-server`,
description: 'Tool 1',
parameters: { type: 'object', properties: {} },
},
},
};
// Mock the MCP manager to return the tools
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue(mockUserTools),
getRawConfig: jest.fn().mockReturnValue({
// No customUserVars defined
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
mockCache.get.mockResolvedValue(null);
mockReq.config = appConfig;
// First call returns empty user tools
getCachedTools.mockResolvedValueOnce({});
// Second call (with includeGlobal: true) returns the tool definitions
getCachedTools.mockResolvedValueOnce(mockUserTools);
// Ensure no plugins in availableTools for clean test
require('~/app/clients/tools').availableTools.length = 0;
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
expect(Array.isArray(responseData)).toBe(true);
expect(responseData.length).toBeGreaterThan(0);
const mcpTool = responseData.find(
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}test-server`,
);
expect(mcpTool).toBeDefined();
expect(mcpTool.authenticated).toBe(true);
// The actual implementation sets authConfig to empty array when no customUserVars
expect(mcpTool.authConfig).toEqual([]);
});
it('should handle undefined filteredTools and includedTools', async () => {
mockReq.config = {};
mockCache.get.mockResolvedValue(null);
@@ -365,129 +649,20 @@ describe('PluginController', () => {
require('~/app/clients/tools').availableTools.push(mockToolkit);
mockCache.get.mockResolvedValue(null);
// getCachedTools returns empty object to avoid null reference error
// First call returns empty object
getCachedTools.mockResolvedValueOnce({});
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Second call (with includeGlobal: true) returns empty object to avoid null reference error
getCachedTools.mockResolvedValueOnce({});
await getAvailableTools(mockReq, mockRes);
// Should handle null toolDefinitions gracefully
expect(mockRes.status).toHaveBeenCalledWith(200);
});
it('should handle undefined toolDefinitions when checking isToolDefined (traversaal_search bug)', async () => {
// This test reproduces the bug where toolDefinitions is undefined
// and accessing toolDefinitions[plugin.pluginKey] causes a TypeError
const mockPlugin = {
name: 'Traversaal Search',
pluginKey: 'traversaal_search',
description: 'Search plugin',
};
// Add the plugin to availableTools
require('~/app/clients/tools').availableTools.push(mockPlugin);
mockCache.get.mockResolvedValue(null);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// CRITICAL: getCachedTools returns undefined
// This is what causes the bug when trying to access toolDefinitions[plugin.pluginKey]
getCachedTools.mockResolvedValueOnce(undefined);
// This should not throw an error with the optional chaining fix
await getAvailableTools(mockReq, mockRes);
// Should handle undefined toolDefinitions gracefully and return empty array
expect(mockRes.status).toHaveBeenCalledWith(200);
expect(mockRes.json).toHaveBeenCalledWith([]);
});
it('should re-initialize tools from appConfig when cache returns null', async () => {
// Setup: Initial state with tools in appConfig
const mockAppTools = {
tool1: {
type: 'function',
function: {
name: 'tool1',
description: 'Tool 1',
parameters: {},
},
},
tool2: {
type: 'function',
function: {
name: 'tool2',
description: 'Tool 2',
parameters: {},
},
},
};
// Add matching plugins to availableTools
require('~/app/clients/tools').availableTools.push(
{ name: 'Tool 1', pluginKey: 'tool1', description: 'Tool 1' },
{ name: 'Tool 2', pluginKey: 'tool2', description: 'Tool 2' },
);
// Simulate cache cleared state (returns null)
mockCache.get.mockResolvedValue(null);
getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared)
mockReq.config = {
filteredTools: [],
includedTools: [],
availableTools: mockAppTools,
};
// Mock setCachedTools to verify it's called to re-initialize
const { setCachedTools } = require('~/server/services/Config');
await getAvailableTools(mockReq, mockRes);
// Should have re-initialized the cache with tools from appConfig
expect(setCachedTools).toHaveBeenCalledWith(mockAppTools);
// Should still return tools successfully
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
expect(responseData).toHaveLength(2);
expect(responseData.find((t) => t.pluginKey === 'tool1')).toBeDefined();
expect(responseData.find((t) => t.pluginKey === 'tool2')).toBeDefined();
});
it('should handle cache clear without appConfig.availableTools gracefully', async () => {
// Setup: appConfig without availableTools
getAppConfig.mockResolvedValue({
filteredTools: [],
includedTools: [],
// No availableTools property
});
// Clear availableTools array
require('~/app/clients/tools').availableTools.length = 0;
// Cache returns null (cleared state)
mockCache.get.mockResolvedValue(null);
getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared)
mockReq.config = {
filteredTools: [],
includedTools: [],
// No availableTools
};
await getAvailableTools(mockReq, mockRes);
// Should handle gracefully without crashing
expect(mockRes.status).toHaveBeenCalledWith(200);
expect(mockRes.json).toHaveBeenCalledWith([]);
});
});
});

View File

@@ -1,34 +1,26 @@
const { logger } = require('@librechat/data-schemas');
const { Tools, CacheKeys, Constants, FileSources } = require('librechat-data-provider');
const {
webSearchKeys,
MCPOAuthHandler,
MCPTokenStorage,
normalizeHttpError,
extractWebSearchEnvVars,
} = require('@librechat/api');
const { webSearchKeys, extractWebSearchEnvVars, normalizeHttpError } = require('@librechat/api');
const {
getFiles,
findToken,
updateUser,
deleteFiles,
deleteConvos,
deletePresets,
deleteMessages,
deleteUserById,
deleteAllSharedLinks,
deleteAllUserSessions,
} = require('~/models');
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { Tools, Constants, FileSources } = require('librechat-data-provider');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { Transaction, Balance, User, Token } = require('~/db/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { Transaction, Balance, User } = require('~/db/models');
const { getAppConfig } = require('~/server/services/Config');
const { deleteToolCalls } = require('~/models/ToolCall');
const { getLogStores } = require('~/cache');
const { deleteAllSharedLinks } = require('~/models');
const { getMCPManager } = require('~/config');
const getUserController = async (req, res) => {
const appConfig = await getAppConfig({ role: req.user?.role });
@@ -170,15 +162,6 @@ const updateUserPluginsController = async (req, res) => {
);
({ status, message } = normalizeHttpError(authService));
}
try {
// if the MCP server uses OAuth, perform a full cleanup and token revocation
await maybeUninstallOAuthMCP(user.id, pluginKey, appConfig);
} catch (error) {
logger.error(
`[updateUserPluginsController] Error uninstalling OAuth MCP for ${pluginKey}:`,
error,
);
}
} else {
// This handles:
// 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}).
@@ -204,7 +187,7 @@ const updateUserPluginsController = async (req, res) => {
// Extract server name from pluginKey (format: "mcp_<serverName>")
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
logger.info(
`[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`,
`[updateUserPluginsController] Disconnecting MCP server ${serverName} for user ${user.id} after plugin auth update for ${pluginKey}.`,
);
await mcpManager.disconnectUserConnection(user.id, serverName);
}
@@ -286,94 +269,6 @@ const resendVerificationController = async (req, res) => {
}
};
/**
* OAuth MCP specific uninstall logic
*/
const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
if (!pluginKey.startsWith(Constants.mcp_prefix)) {
// this is not an MCP server, so nothing to do here
return;
}
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
const mcpManager = getMCPManager(userId);
const serverConfig = mcpManager.getRawConfig(serverName) ?? appConfig?.mcpServers?.[serverName];
if (!mcpManager.getOAuthServers().has(serverName)) {
// this server does not use OAuth, so nothing to do here as well
return;
}
// 1. get client info used for revocation (client id, secret)
const clientTokenData = await MCPTokenStorage.getClientInfoAndMetadata({
userId,
serverName,
findToken,
});
if (clientTokenData == null) {
return;
}
const { clientInfo, clientMetadata } = clientTokenData;
// 2. get decrypted tokens before deletion
const tokens = await MCPTokenStorage.getTokens({
userId,
serverName,
findToken,
});
// 3. revoke OAuth tokens at the provider
const revocationEndpoint =
serverConfig.oauth?.revocation_endpoint ?? clientMetadata.revocation_endpoint;
const revocationEndpointAuthMethodsSupported =
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
clientMetadata.revocation_endpoint_auth_methods_supported;
if (tokens?.access_token) {
try {
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.access_token, 'access', {
serverUrl: serverConfig.url,
clientId: clientInfo.client_id,
clientSecret: clientInfo.client_secret ?? '',
revocationEndpoint,
revocationEndpointAuthMethodsSupported,
});
} catch (error) {
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
}
}
if (tokens?.refresh_token) {
try {
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.refresh_token, 'refresh', {
serverUrl: serverConfig.url,
clientId: clientInfo.client_id,
clientSecret: clientInfo.client_secret ?? '',
revocationEndpoint,
revocationEndpointAuthMethodsSupported,
});
} catch (error) {
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
}
}
// 4. delete tokens from the DB after revocation attempts
await MCPTokenStorage.deleteUserTokens({
userId,
serverName,
deleteToken: async (filter) => {
await Token.deleteOne(filter);
},
});
// 5. clear the flow state for the OAuth tokens
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
await flowManager.deleteFlow(flowId, 'mcp_oauth');
};
module.exports = {
getUserController,
getTermsStatusController,

View File

@@ -1,342 +0,0 @@
const { Tools } = require('librechat-data-provider');
// Mock all dependencies before requiring the module
jest.mock('nanoid', () => ({
nanoid: jest.fn(() => 'mock-id'),
}));
jest.mock('@librechat/api', () => ({
sendEvent: jest.fn(),
}));
jest.mock('@librechat/data-schemas', () => ({
logger: {
error: jest.fn(),
},
}));
jest.mock('@librechat/agents', () => ({
EnvVar: { CODE_API_KEY: 'CODE_API_KEY' },
Providers: { GOOGLE: 'google' },
GraphEvents: {},
getMessageId: jest.fn(),
ToolEndHandler: jest.fn(),
handleToolCalls: jest.fn(),
ChatModelStreamHandler: jest.fn(),
}));
jest.mock('~/server/services/Files/Citations', () => ({
processFileCitations: jest.fn(),
}));
jest.mock('~/server/services/Files/Code/process', () => ({
processCodeOutput: jest.fn(),
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
saveBase64Image: jest.fn(),
}));
describe('createToolEndCallback', () => {
let req, res, artifactPromises, createToolEndCallback;
let logger;
beforeEach(() => {
jest.clearAllMocks();
// Get the mocked logger
logger = require('@librechat/data-schemas').logger;
// Now require the module after all mocks are set up
const callbacks = require('../callbacks');
createToolEndCallback = callbacks.createToolEndCallback;
req = {
user: { id: 'user123' },
};
res = {
headersSent: false,
write: jest.fn(),
};
artifactPromises = [];
});
describe('ui_resources artifact handling', () => {
it('should process ui_resources artifact and return attachment when headers not sent', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {
0: { type: 'button', label: 'Click me' },
1: { type: 'input', placeholder: 'Enter text' },
},
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
// Wait for all promises to resolve
const results = await Promise.all(artifactPromises);
// When headers are not sent, it returns attachment without writing
expect(res.write).not.toHaveBeenCalled();
const attachment = results[0];
expect(attachment).toEqual({
type: Tools.ui_resources,
messageId: 'run456',
toolCallId: 'tool123',
conversationId: 'thread789',
[Tools.ui_resources]: {
0: { type: 'button', label: 'Click me' },
1: { type: 'input', placeholder: 'Enter text' },
},
});
});
it('should write to response when headers are already sent', async () => {
res.headersSent = true;
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {
0: { type: 'carousel', items: [] },
},
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
expect(res.write).toHaveBeenCalled();
expect(results[0]).toEqual({
type: Tools.ui_resources,
messageId: 'run456',
toolCallId: 'tool123',
conversationId: 'thread789',
[Tools.ui_resources]: {
0: { type: 'carousel', items: [] },
},
});
});
it('should handle errors when processing ui_resources', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
// Mock res.write to throw an error
res.headersSent = true;
res.write.mockImplementation(() => {
throw new Error('Write failed');
});
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {
0: { type: 'test' },
},
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
expect(logger.error).toHaveBeenCalledWith(
'Error processing artifact content:',
expect.any(Error),
);
expect(results[0]).toBeNull();
});
it('should handle multiple artifacts including ui_resources', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {
0: { type: 'chart', data: [] },
},
},
[Tools.web_search]: {
results: ['result1', 'result2'],
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
// Both ui_resources and web_search should be processed
expect(artifactPromises).toHaveLength(2);
expect(results).toHaveLength(2);
// Check ui_resources attachment
const uiResourceAttachment = results.find((r) => r?.type === Tools.ui_resources);
expect(uiResourceAttachment).toBeTruthy();
expect(uiResourceAttachment[Tools.ui_resources]).toEqual({
0: { type: 'chart', data: [] },
});
// Check web_search attachment
const webSearchAttachment = results.find((r) => r?.type === Tools.web_search);
expect(webSearchAttachment).toBeTruthy();
expect(webSearchAttachment[Tools.web_search]).toEqual({
results: ['result1', 'result2'],
});
});
it('should not process artifacts when output has no artifacts', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
content: 'Some regular content',
// No artifact property
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
expect(artifactPromises).toHaveLength(0);
expect(res.write).not.toHaveBeenCalled();
});
});
describe('edge cases', () => {
it('should handle empty ui_resources data object', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {},
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
expect(results[0]).toEqual({
type: Tools.ui_resources,
messageId: 'run456',
toolCallId: 'tool123',
conversationId: 'thread789',
[Tools.ui_resources]: {},
});
});
it('should handle ui_resources with complex nested data', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const complexData = {
0: {
type: 'form',
fields: [
{ name: 'field1', type: 'text', required: true },
{ name: 'field2', type: 'select', options: ['a', 'b', 'c'] },
],
nested: {
deep: {
value: 123,
array: [1, 2, 3],
},
},
},
};
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: complexData,
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
expect(results[0][Tools.ui_resources]).toEqual(complexData);
});
it('should handle when output is undefined', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output: undefined }, metadata);
expect(artifactPromises).toHaveLength(0);
expect(res.write).not.toHaveBeenCalled();
});
it('should handle when data parameter is undefined', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback(undefined, metadata);
expect(artifactPromises).toHaveLength(0);
expect(res.write).not.toHaveBeenCalled();
});
});
});

View File

@@ -158,7 +158,7 @@ describe('duplicateAgent', () => {
});
});
it('should convert `tool_resources.ocr` to `tool_resources.context`', async () => {
it('should handle tool_resources.ocr correctly', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
@@ -178,7 +178,7 @@ describe('duplicateAgent', () => {
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
tool_resources: {
context: { enabled: true, config: 'test' },
ocr: { enabled: true, config: 'test' },
},
}),
);

View File

@@ -265,30 +265,6 @@ function createToolEndCallback({ req, res, artifactPromises }) {
);
}
// TODO: a lot of duplicated code in createToolEndCallback
// we should refactor this to use a helper function in a follow-up PR
if (output.artifact[Tools.ui_resources]) {
artifactPromises.push(
(async () => {
const attachment = {
type: Tools.ui_resources,
messageId: metadata.run_id,
toolCallId: output.tool_call_id,
conversationId: metadata.thread_id,
[Tools.ui_resources]: output.artifact[Tools.ui_resources].data,
};
if (!res.headersSent) {
return attachment;
}
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
return attachment;
})().catch((error) => {
logger.error('Error processing artifact content:', error);
return null;
}),
);
}
if (output.artifact[Tools.web_search]) {
artifactPromises.push(
(async () => {

View File

@@ -7,12 +7,9 @@ const {
createRun,
Tokenizer,
checkAccess,
logAxiosError,
resolveHeaders,
getBalanceConfig,
memoryInstructions,
formatContentStrings,
getTransactionsConfig,
createMemoryProcessor,
} = require('@librechat/api');
const {
@@ -89,10 +86,11 @@ function createTokenCounter(encoding) {
}
function logToolError(graph, error, toolId) {
logAxiosError({
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
error,
message: `[api/server/controllers/agents/client.js #chatCompletion] Tool Error "${toolId}"`,
});
toolId,
);
}
class AgentClient extends BaseClient {
@@ -624,13 +622,11 @@ class AgentClient extends BaseClient {
* @param {string} [params.model]
* @param {string} [params.context='message']
* @param {AppConfig['balance']} [params.balance]
* @param {AppConfig['transactions']} [params.transactions]
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
*/
async recordCollectedUsage({
model,
balance,
transactions,
context = 'message',
collectedUsage = this.collectedUsage,
}) {
@@ -656,7 +652,6 @@ class AgentClient extends BaseClient {
const txMetadata = {
context,
balance,
transactions,
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
@@ -872,10 +867,11 @@ class AgentClient extends BaseClient {
if (agent.useLegacyContent === true) {
messages = formatContentStrings(messages);
}
const defaultHeaders =
agent.model_parameters?.clientOptions?.defaultHeaders ??
agent.model_parameters?.configuration?.defaultHeaders;
if (defaultHeaders?.['anthropic-beta']?.includes('prompt-caching')) {
if (
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
'prompt-caching',
)
) {
messages = addCacheControl(messages);
}
@@ -883,16 +879,6 @@ class AgentClient extends BaseClient {
memoryPromise = this.runMemory(messages);
}
/** Resolve request-based headers for Custom Endpoints. Note: if this is added to
* non-custom endpoints, needs consideration of varying provider header configs.
*/
if (agent.model_parameters?.configuration?.defaultHeaders != null) {
agent.model_parameters.configuration.defaultHeaders = resolveHeaders({
headers: agent.model_parameters.configuration.defaultHeaders,
body: config.configurable.requestBody,
});
}
run = await createRun({
agent,
req: this.options.req,
@@ -1054,12 +1040,7 @@ class AgentClient extends BaseClient {
}
const balanceConfig = getBalanceConfig(appConfig);
const transactionsConfig = getTransactionsConfig(appConfig);
await this.recordCollectedUsage({
context: 'message',
balance: balanceConfig,
transactions: transactionsConfig,
});
await this.recordCollectedUsage({ context: 'message', balance: balanceConfig });
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
@@ -1121,13 +1102,6 @@ class AgentClient extends BaseClient {
);
}
if (endpointConfig?.titleConvo === false) {
logger.debug(
`[api/server/controllers/agents/client.js #titleConvo] Title generation disabled for endpoint "${endpoint}"`,
);
return;
}
if (endpointConfig?.titleEndpoint && endpointConfig.titleEndpoint !== endpoint) {
try {
titleProviderConfig = getProviderConfig({
@@ -1137,7 +1111,7 @@ class AgentClient extends BaseClient {
endpoint = endpointConfig.titleEndpoint;
} catch (error) {
logger.warn(
`[api/server/controllers/agents/client.js #titleConvo] Error getting title endpoint config for "${endpointConfig.titleEndpoint}", falling back to default`,
`[api/server/controllers/agents/client.js #titleConvo] Error getting title endpoint config for ${endpointConfig.titleEndpoint}, falling back to default`,
error,
);
// Fall back to original provider config
@@ -1207,20 +1181,6 @@ class AgentClient extends BaseClient {
clientOptions.json = true;
}
/** Resolve request-based headers for Custom Endpoints. Note: if this is added to
* non-custom endpoints, needs consideration of varying provider header configs.
*/
if (clientOptions?.configuration?.defaultHeaders != null) {
clientOptions.configuration.defaultHeaders = resolveHeaders({
headers: clientOptions.configuration.defaultHeaders,
body: {
messageId: this.responseMessageId,
conversationId: this.conversationId,
parentMessageId: this.parentMessageId,
},
});
}
try {
const titleResult = await this.run.generateTitle({
provider,
@@ -1260,13 +1220,11 @@ class AgentClient extends BaseClient {
});
const balanceConfig = getBalanceConfig(appConfig);
const transactionsConfig = getTransactionsConfig(appConfig);
await this.recordCollectedUsage({
collectedUsage,
context: 'title',
model: clientOptions.model,
balance: balanceConfig,
transactions: transactionsConfig,
}).catch((err) => {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',

View File

@@ -237,9 +237,6 @@ describe('AgentClient - titleConvo', () => {
balance: {
enabled: false,
},
transactions: {
enabled: true,
},
});
});
@@ -263,125 +260,6 @@ describe('AgentClient - titleConvo', () => {
expect(result).toBeUndefined();
});
it('should skip title generation when titleConvo is set to false', async () => {
// Set titleConvo to false in endpoint config
mockReq.config = {
endpoints: {
[EModelEndpoint.openAI]: {
titleConvo: false,
titleModel: 'gpt-3.5-turbo',
titlePrompt: 'Custom title prompt',
titleMethod: 'structured',
titlePromptTemplate: 'Template: {{content}}',
},
},
};
const text = 'Test conversation text';
const abortController = new AbortController();
const result = await client.titleConvo({ text, abortController });
// Should return undefined without generating title
expect(result).toBeUndefined();
// generateTitle should NOT have been called
expect(mockRun.generateTitle).not.toHaveBeenCalled();
// recordCollectedUsage should NOT have been called
expect(client.recordCollectedUsage).not.toHaveBeenCalled();
});
it('should skip title generation when titleConvo is false in all config', async () => {
// Set titleConvo to false in "all" config
mockReq.config = {
endpoints: {
all: {
titleConvo: false,
titleModel: 'gpt-4o-mini',
titlePrompt: 'All config title prompt',
titleMethod: 'completion',
titlePromptTemplate: 'All config template',
},
},
};
const text = 'Test conversation text';
const abortController = new AbortController();
const result = await client.titleConvo({ text, abortController });
// Should return undefined without generating title
expect(result).toBeUndefined();
// generateTitle should NOT have been called
expect(mockRun.generateTitle).not.toHaveBeenCalled();
// recordCollectedUsage should NOT have been called
expect(client.recordCollectedUsage).not.toHaveBeenCalled();
});
it('should skip title generation when titleConvo is false for custom endpoint scenario', async () => {
// This test validates the behavior when customEndpointConfig (retrieved via
// getProviderConfig for custom endpoints) has titleConvo: false.
//
// The code path is:
// 1. endpoints?.all is checked (undefined in this test)
// 2. endpoints?.[endpoint] is checked (our test config)
// 3. Would fall back to titleProviderConfig.customEndpointConfig (for real custom endpoints)
//
// We simulate a custom endpoint scenario using a dynamically named endpoint config
// Create a unique endpoint name that represents a custom endpoint
const customEndpointName = 'customEndpoint';
// Configure the endpoint to have titleConvo: false
// This simulates what would be in customEndpointConfig for a real custom endpoint
mockReq.config = {
endpoints: {
// No 'all' config - so it will check endpoints[endpoint]
// This config represents what customEndpointConfig would contain
[customEndpointName]: {
titleConvo: false,
titleModel: 'custom-model-v1',
titlePrompt: 'Custom endpoint title prompt',
titleMethod: 'completion',
titlePromptTemplate: 'Custom template: {{content}}',
baseURL: 'https://api.custom-llm.com/v1',
apiKey: 'test-custom-key',
// Additional custom endpoint properties
models: {
default: ['custom-model-v1', 'custom-model-v2'],
},
},
},
};
// Set up agent to use our custom endpoint
// Use openAI as base but override with custom endpoint name for this test
mockAgent.endpoint = EModelEndpoint.openAI;
mockAgent.provider = EModelEndpoint.openAI;
// Override the endpoint in the config to point to our custom config
mockReq.config.endpoints[EModelEndpoint.openAI] =
mockReq.config.endpoints[customEndpointName];
delete mockReq.config.endpoints[customEndpointName];
const text = 'Test custom endpoint conversation';
const abortController = new AbortController();
const result = await client.titleConvo({ text, abortController });
// Should return undefined without generating title because titleConvo is false
expect(result).toBeUndefined();
// generateTitle should NOT have been called
expect(mockRun.generateTitle).not.toHaveBeenCalled();
// recordCollectedUsage should NOT have been called
expect(client.recordCollectedUsage).not.toHaveBeenCalled();
});
it('should pass titleEndpoint configuration to generateTitle', async () => {
// Mock the API key just for this test
const originalApiKey = process.env.ANTHROPIC_API_KEY;

View File

@@ -2,15 +2,9 @@ const { z } = require('zod');
const fs = require('fs').promises;
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const {
agentCreateSchema,
agentUpdateSchema,
mergeAgentOcrConversion,
convertOcrToContextInPlace,
} = require('@librechat/api');
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
const {
Tools,
Constants,
SystemRoles,
FileSources,
ResourceType,
@@ -71,13 +65,13 @@ const createAgentHandler = async (req, res) => {
agentData.author = userId;
agentData.tools = [];
const availableTools = await getCachedTools();
const availableTools = await getCachedTools({ includeGlobal: true });
for (const tool of tools) {
if (availableTools[tool]) {
agentData.tools.push(tool);
} else if (systemTools[tool]) {
agentData.tools.push(tool);
} else if (tool.includes(Constants.mcp_delimiter)) {
}
if (systemTools[tool]) {
agentData.tools.push(tool);
}
}
@@ -203,32 +197,19 @@ const getAgentHandler = async (req, res, expandProperties = false) => {
* @param {object} req.params - Request params
* @param {string} req.params.id - Agent identifier.
* @param {AgentUpdateParams} req.body - The Agent update parameters.
* @returns {Promise<Agent>} 200 - success response - application/json
* @returns {Agent} 200 - success response - application/json
*/
const updateAgentHandler = async (req, res) => {
try {
const id = req.params.id;
const validatedData = agentUpdateSchema.parse(req.body);
const { _id, ...updateData } = removeNullishValues(validatedData);
// Convert OCR to context in incoming updateData
convertOcrToContextInPlace(updateData);
const existingAgent = await getAgent({ id });
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
// Convert legacy OCR tool resource to context format in existing agent
const ocrConversion = mergeAgentOcrConversion(existingAgent, updateData);
if (ocrConversion.tool_resources) {
updateData.tool_resources = ocrConversion.tool_resources;
}
if (ocrConversion.tools) {
updateData.tools = ocrConversion.tools;
}
let updatedAgent =
Object.keys(updateData).length > 0
? await updateAgent({ id }, updateData, {
@@ -273,7 +254,7 @@ const updateAgentHandler = async (req, res) => {
* @param {object} req - Express Request
* @param {object} req.params - Request params
* @param {string} req.params.id - Agent identifier.
* @returns {Promise<Agent>} 201 - success response - application/json
* @returns {Agent} 201 - success response - application/json
*/
const duplicateAgentHandler = async (req, res) => {
const { id } = req.params;
@@ -306,19 +287,9 @@ const duplicateAgentHandler = async (req, res) => {
hour12: false,
})})`;
if (_tool_resources?.[EToolResources.context]) {
cloneData.tool_resources = {
[EToolResources.context]: _tool_resources[EToolResources.context],
};
}
if (_tool_resources?.[EToolResources.ocr]) {
cloneData.tool_resources = {
/** Legacy conversion from `ocr` to `context` */
[EToolResources.context]: {
...(_tool_resources[EToolResources.context] ?? {}),
..._tool_resources[EToolResources.ocr],
},
[EToolResources.ocr]: _tool_resources[EToolResources.ocr],
};
}
@@ -410,7 +381,7 @@ const duplicateAgentHandler = async (req, res) => {
* @param {object} req - Express Request
* @param {object} req.params - Request params
* @param {string} req.params.id - Agent identifier.
* @returns {Promise<Agent>} 200 - success response - application/json
* @returns {Agent} 200 - success response - application/json
*/
const deleteAgentHandler = async (req, res) => {
try {
@@ -512,7 +483,7 @@ const getListAgentsHandler = async (req, res) => {
* @param {Express.Multer.File} req.file - The avatar image file.
* @param {object} req.body - Request body
* @param {string} [req.body.avatar] - Optional avatar for the agent's avatar.
* @returns {Promise<void>} 200 - success response - application/json
* @returns {Object} 200 - success response - application/json
*/
const uploadAgentAvatarHandler = async (req, res) => {
try {

View File

@@ -512,7 +512,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
mockReq.params.id = existingAgentId;
mockReq.body = {
tool_resources: {
/** Legacy conversion from `ocr` to `context` */
ocr: {
file_ids: ['ocr1', 'ocr2'],
},
@@ -532,8 +531,7 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.tool_resources).toBeDefined();
expect(updatedAgent.tool_resources.ocr).toBeUndefined();
expect(updatedAgent.tool_resources.context).toBeDefined();
expect(updatedAgent.tool_resources.ocr).toBeDefined();
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
});

View File

@@ -1,7 +1,7 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
const { sendEvent, getBalanceConfig } = require('@librechat/api');
const {
Time,
Constants,
@@ -34,6 +34,7 @@ const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
/**

View File

@@ -1,7 +1,7 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
const { sendEvent, getBalanceConfig } = require('@librechat/api');
const {
Time,
Constants,
@@ -31,6 +31,7 @@ const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
/**

View File

@@ -31,7 +31,7 @@ const createAssistant = async (req, res) => {
delete assistantData.conversation_starters;
delete assistantData.append_current_datetime;
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ includeGlobal: true });
assistantData.tools = tools
.map((tool) => {
@@ -136,7 +136,7 @@ const patchAssistant = async (req, res) => {
...updateData
} = req.body;
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ includeGlobal: true });
updateData.tools = (updateData.tools ?? [])
.map((tool) => {

View File

@@ -28,7 +28,7 @@ const createAssistant = async (req, res) => {
delete assistantData.conversation_starters;
delete assistantData.append_current_datetime;
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ includeGlobal: true });
assistantData.tools = tools
.map((tool) => {
@@ -125,7 +125,7 @@ const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
let hasFileSearch = false;
for (const tool of updateData.tools ?? []) {
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ includeGlobal: true });
let actualTool = typeof tool === 'string' ? toolDefinitions[tool] : tool;
if (!actualTool && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {

View File

@@ -1,126 +0,0 @@
/**
* MCP Tools Controller
* Handles MCP-specific tool endpoints, decoupled from regular LibreChat tools
*/
const { logger } = require('@librechat/data-schemas');
const { Constants } = require('librechat-data-provider');
const {
cacheMCPServerTools,
getMCPServerTools,
getAppConfig,
} = require('~/server/services/Config');
const { getMCPManager } = require('~/config');
/**
* Get all MCP tools available to the user
*/
const getMCPTools = async (req, res) => {
try {
const userId = req.user?.id;
if (!userId) {
logger.warn('[getMCPTools] User ID not found in request');
return res.status(401).json({ message: 'Unauthorized' });
}
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
if (!appConfig?.mcpConfig) {
return res.status(200).json({ servers: {} });
}
const mcpManager = getMCPManager();
const configuredServers = Object.keys(appConfig.mcpConfig);
const mcpServers = {};
const cachePromises = configuredServers.map((serverName) =>
getMCPServerTools(serverName).then((tools) => ({ serverName, tools })),
);
const cacheResults = await Promise.all(cachePromises);
const serverToolsMap = new Map();
for (const { serverName, tools } of cacheResults) {
if (tools) {
serverToolsMap.set(serverName, tools);
continue;
}
const serverTools = await mcpManager.getServerToolFunctions(userId, serverName);
if (!serverTools) {
logger.debug(`[getMCPTools] No tools found for server ${serverName}`);
continue;
}
serverToolsMap.set(serverName, serverTools);
if (Object.keys(serverTools).length > 0) {
// Cache asynchronously without blocking
cacheMCPServerTools({ serverName, serverTools }).catch((err) =>
logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err),
);
}
}
// Process each configured server
for (const serverName of configuredServers) {
try {
const serverTools = serverToolsMap.get(serverName);
// Get server config once
const serverConfig = appConfig.mcpConfig[serverName];
const rawServerConfig = mcpManager.getRawConfig(serverName);
// Initialize server object with all server-level data
const server = {
name: serverName,
icon: rawServerConfig?.iconPath || '',
authenticated: true,
authConfig: [],
tools: [],
};
// Set authentication config once for the server
if (serverConfig?.customUserVars) {
const customVarKeys = Object.keys(serverConfig.customUserVars);
if (customVarKeys.length > 0) {
server.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}));
server.authenticated = false;
}
}
// Process tools efficiently - no need for convertMCPToolToPlugin
if (serverTools) {
for (const [toolKey, toolData] of Object.entries(serverTools)) {
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
continue;
}
const toolName = toolKey.split(Constants.mcp_delimiter)[0];
server.tools.push({
name: toolName,
pluginKey: toolKey,
description: toolData.function.description || '',
});
}
}
// Only add server if it has tools or is configured
if (server.tools.length > 0 || serverConfig) {
mcpServers[serverName] = server;
}
} catch (error) {
logger.error(`[getMCPTools] Error loading tools for server ${serverName}:`, error);
}
}
res.status(200).json({ servers: mcpServers });
} catch (error) {
logger.error('[getMCPTools]', error);
res.status(500).json({ message: error.message });
}
};
module.exports = {
getMCPTools,
};

View File

@@ -12,8 +12,7 @@ const { logger } = require('@librechat/data-schemas');
const mongoSanitize = require('express-mongo-sanitize');
const { isEnabled, ErrorController } = require('@librechat/api');
const { connectDb, indexSync } = require('~/db');
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
const createValidateImageRequest = require('./middleware/validateImageRequest');
const validateImageRequest = require('./middleware/validateImageRequest');
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
const { updateInterfacePermissions } = require('~/models/interface');
const { checkMigrations } = require('./services/start/migration');
@@ -127,7 +126,7 @@ const startServer = async () => {
app.use('/api/config', routes.config);
app.use('/api/assistants', routes.assistants);
app.use('/api/files', await routes.files.initialize());
app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute);
app.use('/images/', validateImageRequest, routes.staticRoute);
app.use('/api/share', routes.share);
app.use('/api/roles', routes.roles);
app.use('/api/agents', routes.agents);
@@ -155,7 +154,7 @@ const startServer = async () => {
res.send(updatedIndexHtml);
});
app.listen(port, host, async () => {
app.listen(port, host, () => {
if (host === '0.0.0.0') {
logger.info(
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
@@ -164,9 +163,7 @@ const startServer = async () => {
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
}
await initializeMCPs();
await initializeOAuthReconnectManager();
await checkMigrations();
initializeMCPs().then(() => checkMigrations());
});
};

View File

@@ -1,7 +1,7 @@
const { logger } = require('@librechat/data-schemas');
const { PermissionBits, hasPermissions, ResourceType } = require('librechat-data-provider');
const { getEffectivePermissions } = require('~/server/services/PermissionService');
const { getAgents } = require('~/models/Agent');
const { getAgent } = require('~/models/Agent');
const { getFiles } = require('~/models/File');
/**
@@ -10,12 +10,11 @@ const { getFiles } = require('~/models/File');
*/
const checkAgentBasedFileAccess = async ({ userId, role, fileId }) => {
try {
/** Agents that have this file in their tool_resources */
const agentsWithFile = await getAgents({
// Find agents that have this file in their tool_resources
const agentsWithFile = await getAgent({
$or: [
{ 'tool_resources.execute_code.file_ids': fileId },
{ 'tool_resources.file_search.file_ids': fileId },
{ 'tool_resources.context.file_ids': fileId },
{ 'tool_resources.execute_code.file_ids': fileId },
{ 'tool_resources.ocr.file_ids': fileId },
],
});
@@ -25,7 +24,7 @@ const checkAgentBasedFileAccess = async ({ userId, role, fileId }) => {
}
// Check if user has access to any of these agents
for (const agent of agentsWithFile) {
for (const agent of Array.isArray(agentsWithFile) ? agentsWithFile : [agentsWithFile]) {
// Check if user is the agent author
if (agent.author && agent.author.toString() === userId) {
logger.debug(`[fileAccess] User is author of agent ${agent.id}`);
@@ -84,6 +83,7 @@ const fileAccess = async (req, res, next) => {
});
}
// Get the file
const [file] = await getFiles({ file_id: fileId });
if (!file) {
return res.status(404).json({
@@ -92,18 +92,20 @@ const fileAccess = async (req, res, next) => {
});
}
// Check if user owns the file
if (file.user && file.user.toString() === userId) {
req.fileAccess = { file };
return next();
}
/** Agent-based access (file inherits agent permissions) */
// Check agent-based access (file inherits agent permissions)
const hasAgentAccess = await checkAgentBasedFileAccess({ userId, role: userRole, fileId });
if (hasAgentAccess) {
req.fileAccess = { file };
return next();
}
// No access
logger.warn(`[fileAccess] User ${userId} denied access to file ${fileId}`);
return res.status(403).json({
error: 'Forbidden',

View File

@@ -1,483 +0,0 @@
const mongoose = require('mongoose');
const { ResourceType, PrincipalType, PrincipalModel } = require('librechat-data-provider');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { fileAccess } = require('./fileAccess');
const { User, Role, AclEntry } = require('~/db/models');
const { createAgent } = require('~/models/Agent');
const { createFile } = require('~/models/File');
describe('fileAccess middleware', () => {
let mongoServer;
let req, res, next;
let testUser, otherUser, thirdUser;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await mongoose.connection.dropDatabase();
// Create test role
await Role.create({
name: 'test-role',
permissions: {
AGENTS: {
USE: true,
CREATE: true,
SHARED_GLOBAL: false,
},
},
});
// Create test users
testUser = await User.create({
email: 'test@example.com',
name: 'Test User',
username: 'testuser',
role: 'test-role',
});
otherUser = await User.create({
email: 'other@example.com',
name: 'Other User',
username: 'otheruser',
role: 'test-role',
});
thirdUser = await User.create({
email: 'third@example.com',
name: 'Third User',
username: 'thirduser',
role: 'test-role',
});
// Setup request/response objects
req = {
user: { id: testUser._id.toString(), role: testUser.role },
params: {},
};
res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
};
next = jest.fn();
jest.clearAllMocks();
});
describe('basic file access', () => {
test('should allow access when user owns the file', async () => {
// Create a file owned by testUser
await createFile({
user: testUser._id.toString(),
file_id: 'file_owned_by_user',
filepath: '/test/file.txt',
filename: 'file.txt',
type: 'text/plain',
size: 100,
});
req.params.file_id = 'file_owned_by_user';
await fileAccess(req, res, next);
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
expect(req.fileAccess.file).toBeDefined();
expect(res.status).not.toHaveBeenCalled();
});
test('should deny access when user does not own the file and no agent access', async () => {
// Create a file owned by otherUser
await createFile({
user: otherUser._id.toString(),
file_id: 'file_owned_by_other',
filepath: '/test/file.txt',
filename: 'file.txt',
type: 'text/plain',
size: 100,
});
req.params.file_id = 'file_owned_by_other';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
expect(res.json).toHaveBeenCalledWith({
error: 'Forbidden',
message: 'Insufficient permissions to access this file',
});
});
test('should return 404 when file does not exist', async () => {
req.params.file_id = 'non_existent_file';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(404);
expect(res.json).toHaveBeenCalledWith({
error: 'Not Found',
message: 'File not found',
});
});
test('should return 400 when file_id is missing', async () => {
// Don't set file_id in params
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(400);
expect(res.json).toHaveBeenCalledWith({
error: 'Bad Request',
message: 'file_id is required',
});
});
test('should return 401 when user is not authenticated', async () => {
req.user = null;
req.params.file_id = 'some_file';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(401);
expect(res.json).toHaveBeenCalledWith({
error: 'Unauthorized',
message: 'Authentication required',
});
});
});
describe('agent-based file access', () => {
beforeEach(async () => {
// Create a file owned by otherUser (not testUser)
await createFile({
user: otherUser._id.toString(),
file_id: 'shared_file_via_agent',
filepath: '/test/shared.txt',
filename: 'shared.txt',
type: 'text/plain',
size: 100,
});
});
test('should allow access when user is author of agent with file', async () => {
// Create agent owned by testUser with the file
await createAgent({
id: `agent_${Date.now()}`,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
tool_resources: {
file_search: {
file_ids: ['shared_file_via_agent'],
},
},
});
req.params.file_id = 'shared_file_via_agent';
await fileAccess(req, res, next);
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
expect(req.fileAccess.file).toBeDefined();
});
test('should allow access when user has VIEW permission on agent with file', async () => {
// Create agent owned by otherUser
const agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Shared Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
tool_resources: {
execute_code: {
file_ids: ['shared_file_via_agent'],
},
},
});
// Grant VIEW permission to testUser
await AclEntry.create({
principalType: PrincipalType.USER,
principalId: testUser._id,
principalModel: PrincipalModel.USER,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
permBits: 1, // VIEW permission
grantedBy: otherUser._id,
});
req.params.file_id = 'shared_file_via_agent';
await fileAccess(req, res, next);
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
});
test('should check file in ocr tool_resources', async () => {
await createAgent({
id: `agent_ocr_${Date.now()}`,
name: 'OCR Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
tool_resources: {
ocr: {
file_ids: ['shared_file_via_agent'],
},
},
});
req.params.file_id = 'shared_file_via_agent';
await fileAccess(req, res, next);
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
});
test('should deny access when user has no permission on agent with file', async () => {
// Create agent owned by otherUser without granting permission to testUser
const agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
tool_resources: {
file_search: {
file_ids: ['shared_file_via_agent'],
},
},
});
// Create ACL entry for otherUser only (owner)
await AclEntry.create({
principalType: PrincipalType.USER,
principalId: otherUser._id,
principalModel: PrincipalModel.USER,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
permBits: 15, // All permissions
grantedBy: otherUser._id,
});
req.params.file_id = 'shared_file_via_agent';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
});
});
describe('multiple agents with same file', () => {
/**
* This test suite verifies that when multiple agents have the same file,
* all agents are checked for permissions, not just the first one found.
* This ensures users can access files through any agent they have permission for.
*/
test('should check ALL agents with file, not just first one', async () => {
// Create a file owned by someone else
await createFile({
user: otherUser._id.toString(),
file_id: 'multi_agent_file',
filepath: '/test/multi.txt',
filename: 'multi.txt',
type: 'text/plain',
size: 100,
});
// Create first agent (owned by otherUser, no access for testUser)
const agent1 = await createAgent({
id: 'agent_no_access',
name: 'No Access Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
tool_resources: {
file_search: {
file_ids: ['multi_agent_file'],
},
},
});
// Create ACL for agent1 - only otherUser has access
await AclEntry.create({
principalType: PrincipalType.USER,
principalId: otherUser._id,
principalModel: PrincipalModel.USER,
resourceType: ResourceType.AGENT,
resourceId: agent1._id,
permBits: 15,
grantedBy: otherUser._id,
});
// Create second agent (owned by thirdUser, but testUser has VIEW access)
const agent2 = await createAgent({
id: 'agent_with_access',
name: 'Accessible Agent',
provider: 'openai',
model: 'gpt-4',
author: thirdUser._id,
tool_resources: {
file_search: {
file_ids: ['multi_agent_file'],
},
},
});
// Grant testUser VIEW access to agent2
await AclEntry.create({
principalType: PrincipalType.USER,
principalId: testUser._id,
principalModel: PrincipalModel.USER,
resourceType: ResourceType.AGENT,
resourceId: agent2._id,
permBits: 1, // VIEW permission
grantedBy: thirdUser._id,
});
req.params.file_id = 'multi_agent_file';
await fileAccess(req, res, next);
/**
* Should succeed because testUser has access to agent2,
* even though they don't have access to agent1.
* The fix ensures all agents are checked, not just the first one.
*/
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
expect(res.status).not.toHaveBeenCalled();
});
test('should find file in any agent tool_resources type', async () => {
// Create a file
await createFile({
user: otherUser._id.toString(),
file_id: 'multi_tool_file',
filepath: '/test/tool.txt',
filename: 'tool.txt',
type: 'text/plain',
size: 100,
});
// Agent 1: file in file_search (no access for testUser)
await createAgent({
id: 'agent_file_search',
name: 'File Search Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
tool_resources: {
file_search: {
file_ids: ['multi_tool_file'],
},
},
});
// Agent 2: same file in execute_code (testUser has access)
await createAgent({
id: 'agent_execute_code',
name: 'Execute Code Agent',
provider: 'openai',
model: 'gpt-4',
author: thirdUser._id,
tool_resources: {
execute_code: {
file_ids: ['multi_tool_file'],
},
},
});
// Agent 3: same file in ocr (testUser also has access)
await createAgent({
id: 'agent_ocr',
name: 'OCR Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id, // testUser owns this one
tool_resources: {
ocr: {
file_ids: ['multi_tool_file'],
},
},
});
req.params.file_id = 'multi_tool_file';
await fileAccess(req, res, next);
/**
* Should succeed because testUser owns agent3,
* even if other agents with the file are found first.
*/
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
});
});
describe('edge cases', () => {
test('should handle agent with empty tool_resources', async () => {
await createFile({
user: otherUser._id.toString(),
file_id: 'orphan_file',
filepath: '/test/orphan.txt',
filename: 'orphan.txt',
type: 'text/plain',
size: 100,
});
// Create agent with no files in tool_resources
await createAgent({
id: `agent_empty_${Date.now()}`,
name: 'Empty Resources Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
tool_resources: {},
});
req.params.file_id = 'orphan_file';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
});
test('should handle agent with null tool_resources', async () => {
await createFile({
user: otherUser._id.toString(),
file_id: 'another_orphan_file',
filepath: '/test/orphan2.txt',
filename: 'orphan2.txt',
type: 'text/plain',
size: 100,
});
// Create agent with null tool_resources
await createAgent({
id: `agent_null_${Date.now()}`,
name: 'Null Resources Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
tool_resources: null,
});
req.params.file_id = 'another_orphan_file';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
});
});
});

View File

@@ -1,5 +1,5 @@
const { logger } = require('@librechat/data-schemas');
const { isEmailDomainAllowed } = require('@librechat/api');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { getAppConfig } = require('~/server/services/Config');
/**
@@ -11,25 +11,18 @@ const { getAppConfig } = require('~/server/services/Config');
* @param {Object} res - Express response object.
* @param {Function} next - Next middleware function.
*
* @returns {Promise<void>} - Calls next middleware if the domain's email is allowed, otherwise redirects to login
* @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if the domain's email is allowed
*/
const checkDomainAllowed = async (req, res, next) => {
try {
const email = req?.user?.email;
const appConfig = await getAppConfig({
role: req?.user?.role,
});
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
res.redirect('/login');
return;
}
next();
} catch (error) {
logger.error('[checkDomainAllowed] Error checking domain:', error);
res.redirect('/login');
const checkDomainAllowed = async (req, res, next = () => {}) => {
const email = req?.user?.email;
const appConfig = await getAppConfig({
role: req?.user?.role,
});
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
return res.redirect('/login');
} else {
return next();
}
};

View File

@@ -1,5 +1,6 @@
const validatePasswordReset = require('./validatePasswordReset');
const validateRegistration = require('./validateRegistration');
const validateImageRequest = require('./validateImageRequest');
const buildEndpointOption = require('./buildEndpointOption');
const validateMessageReq = require('./validateMessageReq');
const checkDomainAllowed = require('./checkDomainAllowed');
@@ -49,5 +50,6 @@ module.exports = {
validateMessageReq,
buildEndpointOption,
validateRegistration,
validateImageRequest,
validatePasswordReset,
};

View File

@@ -1,14 +1,14 @@
const jwt = require('jsonwebtoken');
const { isEnabled } = require('@librechat/api');
const createValidateImageRequest = require('~/server/middleware/validateImageRequest');
const validateImageRequest = require('~/server/middleware/validateImageRequest');
jest.mock('@librechat/api', () => ({
isEnabled: jest.fn(),
jest.mock('~/server/services/Config/app', () => ({
getAppConfig: jest.fn(),
}));
describe('validateImageRequest middleware', () => {
let req, res, next, validateImageRequest;
let req, res, next;
const validObjectId = '65cfb246f7ecadb8b1e8036b';
const { getAppConfig } = require('~/server/services/Config/app');
beforeEach(() => {
jest.clearAllMocks();
@@ -22,278 +22,116 @@ describe('validateImageRequest middleware', () => {
};
next = jest.fn();
process.env.JWT_REFRESH_SECRET = 'test-secret';
process.env.OPENID_REUSE_TOKENS = 'false';
// Default: OpenID token reuse disabled
isEnabled.mockReturnValue(false);
// Mock getAppConfig to return secureImageLinks: true by default
getAppConfig.mockResolvedValue({
secureImageLinks: true,
});
});
afterEach(() => {
jest.clearAllMocks();
});
describe('Factory function', () => {
test('should return a pass-through middleware if secureImageLinks is false', async () => {
const middleware = createValidateImageRequest(false);
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
expect(res.status).not.toHaveBeenCalled();
});
test('should return validation middleware if secureImageLinks is true', async () => {
validateImageRequest = createValidateImageRequest(true);
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(401);
expect(res.send).toHaveBeenCalledWith('Unauthorized');
test('should call next() if secureImageLinks is false', async () => {
getAppConfig.mockResolvedValue({
secureImageLinks: false,
});
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
describe('Standard LibreChat token flow', () => {
beforeEach(() => {
validateImageRequest = createValidateImageRequest(true);
});
test('should return 401 if refresh token is not provided', async () => {
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(401);
expect(res.send).toHaveBeenCalledWith('Unauthorized');
});
test('should return 403 if refresh token is invalid', async () => {
req.headers.cookie = 'refreshToken=invalid-token';
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should return 403 if refresh token is expired', async () => {
const expiredToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${expiredToken}`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should call next() for valid image path', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/example.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 for invalid image path', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should allow agent avatar pattern for any valid ObjectId', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should prevent file traversal attempts', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
const traversalAttempts = [
`/images/${validObjectId}/../../../etc/passwd`,
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
];
for (const attempt of traversalAttempts) {
req.originalUrl = attempt;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
jest.clearAllMocks();
// Reset mocks for next iteration
res.status = jest.fn().mockReturnThis();
res.send = jest.fn();
}
});
test('should handle URL encoded characters in valid paths', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 401 if refresh token is not provided', async () => {
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(401);
expect(res.send).toHaveBeenCalledWith('Unauthorized');
});
describe('OpenID token flow', () => {
beforeEach(() => {
validateImageRequest = createValidateImageRequest(true);
// Enable OpenID token reuse
isEnabled.mockReturnValue(true);
process.env.OPENID_REUSE_TOKENS = 'true';
});
test('should return 403 if no OpenID user ID cookie when token_provider is openid', async () => {
req.headers.cookie = 'refreshToken=dummy-token; token_provider=openid';
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should validate JWT-signed user ID for OpenID flow', async () => {
const signedUserId = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
req.originalUrl = `/images/${validObjectId}/example.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 for invalid JWT-signed user ID', async () => {
req.headers.cookie =
'refreshToken=dummy-token; token_provider=openid; openid_user_id=invalid-jwt';
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should return 403 for expired JWT-signed user ID', async () => {
const expiredSignedUserId = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${expiredSignedUserId}`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should validate image path against JWT-signed user ID', async () => {
const signedUserId = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
const differentObjectId = '65cfb246f7ecadb8b1e8036c';
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
req.originalUrl = `/images/${differentObjectId}/example.jpg`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should allow agent avatars in OpenID flow', async () => {
const signedUserId = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 if refresh token is invalid', async () => {
req.headers.cookie = 'refreshToken=invalid-token';
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
describe('Security edge cases', () => {
let validToken;
test('should return 403 if refresh token is expired', async () => {
const expiredToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${expiredToken}`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
beforeEach(() => {
validateImageRequest = createValidateImageRequest(true);
validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
});
test('should call next() for valid image path', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/example.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should handle very long image filenames', async () => {
const longFilename = 'a'.repeat(1000) + '.jpg';
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 for invalid image path', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should handle URLs with maximum practical length', async () => {
// Most browsers support URLs up to ~2000 characters
const longFilename = 'x'.repeat(1900) + '.jpg';
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 for invalid ObjectId format', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/123/example.jpg'; // Invalid ObjectId
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should accept URLs just under the 2048 limit', async () => {
// Create a URL exactly 2047 characters long
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
const filenameLength = 2047 - baseLength;
const filename = 'a'.repeat(filenameLength) + '.jpg';
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/${filename}`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
// File traversal tests
test('should prevent file traversal attempts', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
test('should handle malformed URL encoding gracefully', async () => {
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/test%ZZinvalid.jpg`;
const traversalAttempts = [
`/images/${validObjectId}/../../../etc/passwd`,
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
];
for (const attempt of traversalAttempts) {
req.originalUrl = attempt;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
jest.clearAllMocks();
}
});
test('should reject URLs with null bytes', async () => {
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/test\x00.jpg`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should handle URLs with repeated slashes', async () => {
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}//test.jpg`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should reject extremely long URLs as potential DoS', async () => {
// Create a URL longer than 2048 characters
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
const filenameLength = 2049 - baseLength; // Ensure total length exceeds 2048
const extremelyLongFilename = 'x'.repeat(filenameLength) + '.jpg';
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/${extremelyLongFilename}`;
// Verify our test URL is actually too long
expect(req.originalUrl.length).toBeGreaterThan(2048);
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should handle URL encoded characters in valid paths', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
});

View File

@@ -1,7 +1,7 @@
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { getAppConfig } = require('~/server/services/Config/app');
const OBJECT_ID_LENGTH = 24;
const OBJECT_ID_PATTERN = /^[0-9a-f]{24}$/i;
@@ -22,129 +22,50 @@ function isValidObjectId(id) {
}
/**
* Validates a LibreChat refresh token
* @param {string} refreshToken - The refresh token to validate
* @returns {{valid: boolean, userId?: string, error?: string}} - Validation result
* Middleware to validate image request.
* Must be set by `secureImageLinks` via custom config file.
*/
function validateToken(refreshToken) {
async function validateImageRequest(req, res, next) {
const appConfig = await getAppConfig({ role: req.user?.role });
if (!appConfig.secureImageLinks) {
return next();
}
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
if (!refreshToken) {
logger.warn('[validateImageRequest] Refresh token not provided');
return res.status(401).send('Unauthorized');
}
let payload;
try {
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
if (!isValidObjectId(payload.id)) {
return { valid: false, error: 'Invalid User ID' };
}
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
if (payload.exp < currentTimeInSeconds) {
return { valid: false, error: 'Refresh token expired' };
}
return { valid: true, userId: payload.id };
payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
} catch (err) {
logger.warn('[validateToken]', err);
return { valid: false, error: 'Invalid token' };
logger.warn('[validateImageRequest]', err);
return res.status(403).send('Access Denied');
}
if (!isValidObjectId(payload.id)) {
logger.warn('[validateImageRequest] Invalid User ID');
return res.status(403).send('Access Denied');
}
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
if (payload.exp < currentTimeInSeconds) {
logger.warn('[validateImageRequest] Refresh token expired');
return res.status(403).send('Access Denied');
}
const fullPath = decodeURIComponent(req.originalUrl);
const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`);
if (pathPattern.test(fullPath)) {
logger.debug('[validateImageRequest] Image request validated');
next();
} else {
logger.warn('[validateImageRequest] Invalid image path');
res.status(403).send('Access Denied');
}
}
/**
* Factory to create the `validateImageRequest` middleware with configured secureImageLinks
* @param {boolean} [secureImageLinks] - Whether secure image links are enabled
*/
function createValidateImageRequest(secureImageLinks) {
if (!secureImageLinks) {
return (_req, _res, next) => next();
}
/**
* Middleware to validate image request.
* Supports both LibreChat refresh tokens and OpenID JWT tokens.
* Must be set by `secureImageLinks` via custom config file.
*/
return async function validateImageRequest(req, res, next) {
try {
const cookieHeader = req.headers.cookie;
if (!cookieHeader) {
logger.warn('[validateImageRequest] No cookies provided');
return res.status(401).send('Unauthorized');
}
const parsedCookies = cookies.parse(cookieHeader);
const refreshToken = parsedCookies.refreshToken;
if (!refreshToken) {
logger.warn('[validateImageRequest] Token not provided');
return res.status(401).send('Unauthorized');
}
const tokenProvider = parsedCookies.token_provider;
let userIdForPath;
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
const openidUserId = parsedCookies.openid_user_id;
if (!openidUserId) {
logger.warn('[validateImageRequest] No OpenID user ID cookie found');
return res.status(403).send('Access Denied');
}
const validationResult = validateToken(openidUserId);
if (!validationResult.valid) {
logger.warn(`[validateImageRequest] ${validationResult.error}`);
return res.status(403).send('Access Denied');
}
userIdForPath = validationResult.userId;
} else {
const validationResult = validateToken(refreshToken);
if (!validationResult.valid) {
logger.warn(`[validateImageRequest] ${validationResult.error}`);
return res.status(403).send('Access Denied');
}
userIdForPath = validationResult.userId;
}
if (!userIdForPath) {
logger.warn('[validateImageRequest] No user ID available for path validation');
return res.status(403).send('Access Denied');
}
const MAX_URL_LENGTH = 2048;
if (req.originalUrl.length > MAX_URL_LENGTH) {
logger.warn('[validateImageRequest] URL too long');
return res.status(403).send('Access Denied');
}
if (req.originalUrl.includes('\x00')) {
logger.warn('[validateImageRequest] URL contains null byte');
return res.status(403).send('Access Denied');
}
let fullPath;
try {
fullPath = decodeURIComponent(req.originalUrl);
} catch {
logger.warn('[validateImageRequest] Invalid URL encoding');
return res.status(403).send('Access Denied');
}
const agentAvatarPattern = /^\/images\/[a-f0-9]{24}\/agent-[^/]*$/;
if (agentAvatarPattern.test(fullPath)) {
logger.debug('[validateImageRequest] Image request validated');
return next();
}
const escapedUserId = userIdForPath.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const pathPattern = new RegExp(`^/images/${escapedUserId}/[^/]+$`);
if (pathPattern.test(fullPath)) {
logger.debug('[validateImageRequest] Image request validated');
next();
} else {
logger.warn('[validateImageRequest] Invalid image path');
res.status(403).send('Access Denied');
}
} catch (error) {
logger.error('[validateImageRequest] Error:', error);
res.status(500).send('Internal Server Error');
}
};
}
module.exports = createValidateImageRequest;
module.exports = validateImageRequest;

View File

@@ -11,9 +11,6 @@ jest.mock('@librechat/api', () => ({
completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(),
},
MCPTokenStorage: {
storeTokens: jest.fn(),
},
getUserMCPAuthMap: jest.fn(),
}));
@@ -50,8 +47,8 @@ jest.mock('~/server/services/Config', () => ({
loadCustomConfig: jest.fn(),
}));
jest.mock('~/server/services/Config/mcp', () => ({
updateMCPServerTools: jest.fn(),
jest.mock('~/server/services/Config/mcpToolsCache', () => ({
updateMCPUserTools: jest.fn(),
}));
jest.mock('~/server/services/MCP', () => ({
@@ -237,7 +234,7 @@ describe('MCP Routes', () => {
});
describe('GET /:serverName/oauth/callback', () => {
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const { MCPOAuthHandler } = require('@librechat/api');
const { getLogStores } = require('~/cache');
it('should redirect to error page when OAuth error is received', async () => {
@@ -283,7 +280,6 @@ describe('MCP Routes', () => {
it('should handle OAuth callback successfully', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
@@ -299,7 +295,6 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@@ -337,24 +332,11 @@ describe('MCP Routes', () => {
'test-auth-code',
mockFlowManager,
);
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
expect.objectContaining({
userId: 'test-user-id',
serverName: 'test-server',
tokens: mockTokens,
clientInfo: mockFlowState.clientInfo,
metadata: mockFlowState.metadata,
}),
);
const storeInvocation = MCPTokenStorage.storeTokens.mock.invocationCallOrder[0];
const connectInvocation = mockMcpManager.getUserConnection.mock.invocationCallOrder[0];
expect(storeInvocation).toBeLessThan(connectInvocation);
expect(mockFlowManager.completeFlow).toHaveBeenCalledWith(
'tool-flow-123',
'mcp_oauth',
mockTokens,
);
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
it('should redirect to error page when callback processing fails', async () => {
@@ -372,7 +354,6 @@ describe('MCP Routes', () => {
it('should handle system-level OAuth completion', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
@@ -388,7 +369,6 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@@ -399,13 +379,11 @@ describe('MCP Routes', () => {
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
it('should handle reconnection failure after OAuth', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
@@ -421,7 +399,6 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@@ -441,46 +418,6 @@ describe('MCP Routes', () => {
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
it('should redirect to error page if token storage fails', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
userId: 'test-user-id',
metadata: { toolFlowId: 'tool-flow-123' },
clientInfo: {},
codeVerifier: 'test-verifier',
};
const mockTokens = {
access_token: 'test-access-token',
refresh_token: 'test-refresh-token',
};
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed'));
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const mockMcpManager = {
getUserConnection: jest.fn(),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled();
});
});
@@ -841,10 +778,10 @@ describe('MCP Routes', () => {
require('~/cache').getLogStores.mockReturnValue({});
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
getCachedTools.mockResolvedValue({});
setCachedTools.mockResolvedValue();
updateMCPServerTools.mockResolvedValue();
updateMCPUserTools.mockResolvedValue();
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
@@ -899,10 +836,10 @@ describe('MCP Routes', () => {
]);
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
getCachedTools.mockResolvedValue({});
setCachedTools.mockResolvedValue();
updateMCPServerTools.mockResolvedValue();
updateMCPUserTools.mockResolvedValue();
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
@@ -1206,11 +1143,7 @@ describe('MCP Routes', () => {
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const mockTokens = {
access_token: 'edge-access-token',
refresh_token: 'edge-refresh-token',
};
const { MCPOAuthHandler } = require('@librechat/api');
MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({
id: 'test-flow-id',
userId: 'test-user-id',
@@ -1222,8 +1155,6 @@ describe('MCP Routes', () => {
clientInfo: {},
codeVerifier: 'test-verifier',
});
MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
const mockFlowManager = {
completeFlow: jest.fn(),
@@ -1248,11 +1179,6 @@ describe('MCP Routes', () => {
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
const { getCachedTools } = require('~/server/services/Config');
getCachedTools.mockResolvedValue(null);
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const mockTokens = {
access_token: 'edge-access-token',
refresh_token: 'edge-refresh-token',
};
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
@@ -1265,15 +1191,6 @@ describe('MCP Routes', () => {
completeFlow: jest.fn(),
};
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
MCPOAuthHandler.getFlowState.mockResolvedValue({
serverName: 'test-server',
userId: 'test-user-id',
metadata: { serverUrl: 'https://example.com', oauth: {} },
clientInfo: {},
codeVerifier: 'test-verifier',
});
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue({

View File

@@ -1,19 +1,20 @@
const express = require('express');
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const { generateCheckAccess, isActionDomainAllowed } = require('@librechat/api');
const { generateCheckAccess } = require('@librechat/api');
const {
Permissions,
ResourceType,
PermissionBits,
PermissionTypes,
actionDelimiter,
PermissionBits,
removeNullishValues,
} = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { findAccessibleResources } = require('~/server/services/PermissionService');
const { getAgent, updateAgent, getListAgentsByAccess } = require('~/models/Agent');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { canAccessAgentResource } = require('~/server/middleware');
const { getRoleByName } = require('~/models/Role');

View File

@@ -1,12 +1,12 @@
const express = require('express');
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const { isActionDomainAllowed } = require('@librechat/api');
const { actionDelimiter, EModelEndpoint, removeNullishValues } = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { updateAssistantDoc, getAssistant } = require('~/models/Assistant');
const { isActionDomainAllowed } = require('~/server/services/domains');
const router = express.Router();

View File

@@ -122,11 +122,9 @@ router.get('/', async function (req, res) {
payload.minPasswordLength = minPasswordLength;
}
payload.mcpServers = {};
const getMCPServers = () => {
try {
if (appConfig?.mcpConfig == null) {
return;
}
const mcpManager = getMCPManager();
if (!mcpManager) {
return;
@@ -135,9 +133,6 @@ router.get('/', async function (req, res) {
if (!mcpServers) return;
const oauthServers = mcpManager.getOAuthServers();
for (const serverName in mcpServers) {
if (!payload.mcpServers) {
payload.mcpServers = {};
}
const serverConfig = mcpServers[serverName];
payload.mcpServers[serverName] = removeNullishValues({
startup: serverConfig?.startup,

View File

@@ -4,13 +4,9 @@ const { sleep } = require('@librechat/agents');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const {
createImportLimiters,
createForkLimiters,
configMiddleware,
} = require('~/server/middleware');
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
const { storage, importFileFilter } = require('~/server/routes/files/multer');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { importConversations } = require('~/server/utils/import');
@@ -175,7 +171,6 @@ router.post(
'/import',
importIpLimiter,
importUserLimiter,
configMiddleware,
upload.single('file'),
async (req, res) => {
try {

View File

@@ -31,7 +31,6 @@ const { getAssistant } = require('~/models/Assistant');
const { getAgent } = require('~/models/Agent');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
const { Readable } = require('stream');
const router = express.Router();
@@ -185,7 +184,6 @@ router.delete('/', async (req, res) => {
role: req.user.role,
fileIds: nonOwnedFileIds,
agentId: req.body.agent_id,
isDelete: true,
});
for (const file of nonOwnedFiles) {
@@ -327,6 +325,11 @@ router.get('/download/:userId/:file_id', fileAccess, async (req, res) => {
res.setHeader('X-File-Metadata', JSON.stringify(file));
};
/** @type {{ body: import('stream').PassThrough } | undefined} */
let passThrough;
/** @type {ReadableStream | undefined} */
let fileStream;
if (checkOpenAIStorage(file.source)) {
req.body = { model: file.model };
const endpointMap = {
@@ -339,19 +342,12 @@ router.get('/download/:userId/:file_id', fileAccess, async (req, res) => {
overrideEndpoint: endpointMap[file.source],
});
logger.debug(`Downloading file ${file_id} from OpenAI`);
const passThrough = await getDownloadStream(file_id, openai);
passThrough = await getDownloadStream(file_id, openai);
setHeaders();
logger.debug(`File ${file_id} downloaded from OpenAI`);
// Handle both Node.js and Web streams
const stream =
passThrough.body && typeof passThrough.body.getReader === 'function'
? Readable.fromWeb(passThrough.body)
: passThrough.body;
stream.pipe(res);
passThrough.body.pipe(res);
} else {
const fileStream = await getDownloadStream(req, file.filepath);
fileStream = await getDownloadStream(req, file.filepath);
fileStream.on('error', (streamError) => {
logger.error('[DOWNLOAD ROUTE] Stream error:', streamError);

View File

@@ -1,33 +1,19 @@
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const {
createSafeUser,
MCPOAuthHandler,
MCPTokenStorage,
getUserMCPAuthMap,
} = require('@librechat/api');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
const { getMCPTools } = require('~/server/controllers/mcp');
const { requireJwtAuth } = require('~/server/middleware');
const { findPluginAuthsByKeys } = require('~/models');
const { getLogStores } = require('~/cache');
const router = Router();
/**
* Get all MCP tools available to the user
* Returns only MCP tools, completely decoupled from regular LibreChat tools
*/
router.get('/tools', requireJwtAuth, async (req, res) => {
return getMCPTools(req, res);
});
/**
* Initiate OAuth flow
* This endpoint is called when the user clicks the auth link in the UI
@@ -135,41 +121,6 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
/** Persist tokens immediately so reconnection uses fresh credentials */
if (flowState?.userId && tokens) {
try {
await MCPTokenStorage.storeTokens({
userId: flowState.userId,
serverName,
tokens,
createToken,
updateToken,
findToken,
clientInfo: flowState.clientInfo,
metadata: flowState.metadata,
});
logger.debug('[MCP OAuth] Stored OAuth tokens prior to reconnection', {
serverName,
userId: flowState.userId,
});
} catch (error) {
logger.error('[MCP OAuth] Failed to store OAuth tokens after callback', error);
throw error;
}
/**
* Clear any cached `mcp_get_tokens` flow result so subsequent lookups
* re-fetch the freshly stored credentials instead of returning stale nulls.
*/
if (typeof flowManager?.deleteFlow === 'function') {
try {
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
} catch (error) {
logger.warn('[MCP OAuth] Failed to clear cached token flow state', error);
}
}
}
try {
const mcpManager = getMCPManager(flowState.userId);
logger.debug(`[MCP OAuth] Attempting to reconnect ${serverName} with new OAuth tokens`);
@@ -193,12 +144,9 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
);
// clear any reconnection attempts
const oauthReconnectionManager = getOAuthReconnectionManager();
oauthReconnectionManager.clearReconnection(flowState.userId, serverName);
const tools = await userConnection.fetchTools();
await updateMCPServerTools({
await updateMCPUserTools({
userId: flowState.userId,
serverName,
tools,
});
@@ -340,9 +288,9 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
try {
const { serverName } = req.params;
const user = createSafeUser(req.user);
const user = req.user;
if (!user.id) {
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
@@ -372,7 +320,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
}
const result = await reinitMCPServer({
user,
req,
serverName,
userMCPAuthMap,
});

View File

@@ -3,8 +3,8 @@ const { logger } = require('@librechat/data-schemas');
const { ContentTypes } = require('librechat-data-provider');
const {
saveConvo,
getMessage,
saveMessage,
getMessage,
getMessages,
updateMessage,
deleteMessages,
@@ -58,51 +58,34 @@ router.get('/', async (req, res) => {
const nextCursor = messages.length > pageSize ? messages.pop()[sortField] : null;
response = { messages, nextCursor };
} else if (search) {
const searchResults = await Message.meiliSearch(search, { filter: `user = "${user}"` }, true);
const searchResults = await Message.meiliSearch(search, undefined, true);
const messages = searchResults.hits || [];
const result = await getConvosQueried(req.user.id, messages, cursor);
const messageIds = [];
const cleanedMessages = [];
const activeMessages = [];
for (let i = 0; i < messages.length; i++) {
let message = messages[i];
if (message.conversationId.includes('--')) {
message.conversationId = cleanUpPrimaryKeyValue(message.conversationId);
}
if (result.convoMap[message.conversationId]) {
messageIds.push(message.messageId);
cleanedMessages.push(message);
const convo = result.convoMap[message.conversationId];
const dbMessage = await getMessage({ user, messageId: message.messageId });
activeMessages.push({
...message,
title: convo.title,
conversationId: message.conversationId,
model: convo.model,
isCreatedByUser: dbMessage?.isCreatedByUser,
endpoint: dbMessage?.endpoint,
iconURL: dbMessage?.iconURL,
});
}
}
const dbMessages = await getMessages({
user,
messageId: { $in: messageIds },
});
const dbMessageMap = {};
for (const dbMessage of dbMessages) {
dbMessageMap[dbMessage.messageId] = dbMessage;
}
const activeMessages = [];
for (const message of cleanedMessages) {
const convo = result.convoMap[message.conversationId];
const dbMessage = dbMessageMap[message.messageId];
activeMessages.push({
...message,
title: convo.title,
conversationId: message.conversationId,
model: convo.model,
isCreatedByUser: dbMessage?.isCreatedByUser,
endpoint: dbMessage?.endpoint,
iconURL: dbMessage?.iconURL,
});
}
response = { messages: activeMessages, nextCursor: null };
} else {
response = { messages: [], nextCursor: null };

View File

@@ -26,12 +26,9 @@ const domains = {
router.use(logHeaders);
router.use(loginLimiter);
const oauthHandler = async (req, res, next) => {
const oauthHandler = async (req, res) => {
try {
if (res.headersSent) {
return;
}
await checkDomainAllowed(req, res);
await checkBan(req, res);
if (req.banned) {
return;
@@ -42,14 +39,13 @@ const oauthHandler = async (req, res, next) => {
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
) {
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
setOpenIDAuthTokens(req.user.tokenset, res);
} else {
await setAuthTokens(req.user._id, res);
}
res.redirect(domains.client);
} catch (err) {
logger.error('Error in setting authentication tokens:', err);
next(err);
}
};
@@ -83,7 +79,6 @@ router.get(
scope: ['openid', 'profile', 'email'],
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -109,7 +104,6 @@ router.get(
profileFields: ['id', 'email', 'name'],
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -131,7 +125,6 @@ router.get(
session: false,
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -155,7 +148,6 @@ router.get(
scope: ['user:email', 'read:user'],
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -179,7 +171,6 @@ router.get(
scope: ['identify', 'email'],
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -201,7 +192,6 @@ router.post(
session: false,
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);

View File

@@ -156,7 +156,7 @@ router.get('/all', async (req, res) => {
router.get('/groups', async (req, res) => {
try {
const userId = req.user.id;
const { pageSize, limit, cursor, name, category, ...otherFilters } = req.query;
const { pageSize, pageNumber, limit, cursor, name, category, ...otherFilters } = req.query;
const { filter, searchShared, searchSharedOnly } = buildPromptGroupFilter({
name,
@@ -171,13 +171,6 @@ router.get('/groups', async (req, res) => {
actualLimit = parseInt(pageSize, 10);
}
if (
actualCursor &&
(actualCursor === 'undefined' || actualCursor === 'null' || actualCursor.length === 0)
) {
actualCursor = null;
}
let accessibleIds = await findAccessibleResources({
userId,
role: req.user.role,
@@ -197,7 +190,6 @@ router.get('/groups', async (req, res) => {
publicPromptGroupIds: publiclyAccessibleIds,
});
// Cursor-based pagination only
const result = await getListPromptGroupsByAccess({
accessibleIds: filteredAccessibleIds,
otherParams: filter,
@@ -206,21 +198,19 @@ router.get('/groups', async (req, res) => {
});
if (!result) {
const emptyResponse = createEmptyPromptGroupsResponse({
pageNumber: '1',
pageSize: actualLimit,
actualLimit,
});
const emptyResponse = createEmptyPromptGroupsResponse({ pageNumber, pageSize, actualLimit });
return res.status(200).send(emptyResponse);
}
const { data: promptGroups = [], has_more = false, after = null } = result;
const groupsWithPublicFlag = markPublicPromptGroups(promptGroups, publiclyAccessibleIds);
const response = formatPromptGroupsResponse({
promptGroups: groupsWithPublicFlag,
pageNumber: '1', // Always 1 for cursor-based pagination
pageSize: actualLimit.toString(),
pageNumber,
pageSize,
actualLimit,
hasMore: has_more,
after,
});

View File

@@ -33,11 +33,22 @@ let promptRoutes;
let Prompt, PromptGroup, AclEntry, AccessRole, User;
let testUsers, testRoles;
let grantPermission;
let currentTestUser; // Track current user for middleware
// Helper function to set user in middleware
function setTestUser(app, user) {
currentTestUser = user;
app.use((req, res, next) => {
req.user = {
...(user.toObject ? user.toObject() : user),
id: user.id || user._id.toString(),
_id: user._id,
name: user.name,
role: user.role,
};
if (user.role === SystemRoles.ADMIN) {
console.log('Setting admin user with role:', req.user.role);
}
next();
});
}
beforeAll(async () => {
@@ -64,35 +75,14 @@ beforeAll(async () => {
app = express();
app.use(express.json());
// Add user middleware before routes
app.use((req, res, next) => {
if (currentTestUser) {
req.user = {
...(currentTestUser.toObject ? currentTestUser.toObject() : currentTestUser),
id: currentTestUser._id.toString(),
_id: currentTestUser._id,
name: currentTestUser.name,
role: currentTestUser.role,
};
}
next();
});
// Mock authentication middleware - default to owner
setTestUser(app, testUsers.owner);
// Set default user
currentTestUser = testUsers.owner;
// Import routes after middleware is set up
// Import routes after mocks are set up
promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
});
afterEach(() => {
// Always reset to owner user after each test for isolation
if (currentTestUser !== testUsers.owner) {
currentTestUser = testUsers.owner;
}
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
@@ -126,26 +116,36 @@ async function setupTestData() {
// Create test users
testUsers = {
owner: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'Prompt Owner',
email: 'owner@example.com',
role: SystemRoles.USER,
}),
viewer: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'Prompt Viewer',
email: 'viewer@example.com',
role: SystemRoles.USER,
}),
editor: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'Prompt Editor',
email: 'editor@example.com',
role: SystemRoles.USER,
}),
noAccess: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'No Access',
email: 'noaccess@example.com',
role: SystemRoles.USER,
}),
admin: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'Admin',
email: 'admin@example.com',
role: SystemRoles.ADMIN,
@@ -181,7 +181,8 @@ describe('Prompt Routes - ACL Permissions', () => {
it('should have routes loaded', async () => {
// This should at least not crash
const response = await request(app).get('/api/prompts/test-404');
console.log('Test 404 response status:', response.status);
console.log('Test 404 response body:', response.body);
// We expect a 401 or 404, not 500
expect(response.status).not.toBe(500);
});
@@ -206,6 +207,12 @@ describe('Prompt Routes - ACL Permissions', () => {
const response = await request(app).post('/api/prompts').send(promptData);
if (response.status !== 200) {
console.log('POST /api/prompts error status:', response.status);
console.log('POST /api/prompts error body:', response.body);
console.log('Console errors:', consoleErrorSpy.mock.calls);
}
expect(response.status).toBe(200);
expect(response.body.prompt).toBeDefined();
expect(response.body.prompt.prompt).toBe(promptData.prompt.prompt);
@@ -311,8 +318,29 @@ describe('Prompt Routes - ACL Permissions', () => {
});
it('should allow admin access without explicit permissions', async () => {
// Set admin user
setTestUser(app, testUsers.admin);
// First, reset the app to remove previous middleware
app = express();
app.use(express.json());
// Set admin user BEFORE adding routes
app.use((req, res, next) => {
req.user = {
...testUsers.admin.toObject(),
id: testUsers.admin._id.toString(),
_id: testUsers.admin._id,
name: testUsers.admin.name,
role: testUsers.admin.role,
};
next();
});
// Now add the routes
const promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
console.log('Admin user:', testUsers.admin);
console.log('Admin role:', testUsers.admin.role);
console.log('SystemRoles.ADMIN:', SystemRoles.ADMIN);
const response = await request(app).get(`/api/prompts/${testPrompt._id}`).expect(200);
@@ -404,8 +432,21 @@ describe('Prompt Routes - ACL Permissions', () => {
grantedBy: testUsers.editor._id,
});
// Set viewer user
setTestUser(app, testUsers.viewer);
// Recreate app with viewer user
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = {
...testUsers.viewer.toObject(),
id: testUsers.viewer._id.toString(),
_id: testUsers.viewer._id,
name: testUsers.viewer.name,
role: testUsers.viewer.role,
};
next();
});
const promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
await request(app)
.delete(`/api/prompts/${authorPrompt._id}`)
@@ -458,8 +499,21 @@ describe('Prompt Routes - ACL Permissions', () => {
grantedBy: testUsers.owner._id,
});
// Ensure owner user
setTestUser(app, testUsers.owner);
// Recreate app to ensure fresh middleware
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = {
...testUsers.owner.toObject(),
id: testUsers.owner._id.toString(),
_id: testUsers.owner._id,
name: testUsers.owner.name,
role: testUsers.owner.role,
};
next();
});
const promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
const response = await request(app)
.patch(`/api/prompts/${testPrompt._id}/tags/production`)
@@ -483,8 +537,21 @@ describe('Prompt Routes - ACL Permissions', () => {
grantedBy: testUsers.owner._id,
});
// Set viewer user
setTestUser(app, testUsers.viewer);
// Recreate app with viewer user
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = {
...testUsers.viewer.toObject(),
id: testUsers.viewer._id.toString(),
_id: testUsers.viewer._id,
name: testUsers.viewer.name,
role: testUsers.viewer.role,
};
next();
});
const promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
await request(app).patch(`/api/prompts/${testPrompt._id}/tags/production`).expect(403);
@@ -543,305 +610,4 @@ describe('Prompt Routes - ACL Permissions', () => {
expect(response.body._id).toBe(publicPrompt._id.toString());
});
});
describe('Pagination', () => {
beforeEach(async () => {
// Create multiple prompt groups for pagination testing
const groups = [];
for (let i = 0; i < 15; i++) {
const group = await PromptGroup.create({
name: `Test Group ${i + 1}`,
category: 'pagination-test',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - i * 1000), // Stagger updatedAt for consistent ordering
});
groups.push(group);
// Grant owner permissions on each group
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
});
afterEach(async () => {
await PromptGroup.deleteMany({});
await AclEntry.deleteMany({});
});
it('should correctly indicate hasMore when there are more pages', async () => {
const response = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10' })
.expect(200);
expect(response.body.promptGroups).toHaveLength(10);
expect(response.body.has_more).toBe(true);
expect(response.body.after).toBeTruthy();
// Since has_more is true, pages should be a high number (9999 in our fix)
expect(parseInt(response.body.pages)).toBeGreaterThan(1);
});
it('should correctly indicate no more pages on the last page', async () => {
// First get the cursor for page 2
const firstPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10' })
.expect(200);
expect(firstPage.body.has_more).toBe(true);
expect(firstPage.body.after).toBeTruthy();
// Now fetch the second page using the cursor
const response = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10', cursor: firstPage.body.after })
.expect(200);
expect(response.body.promptGroups).toHaveLength(5); // 15 total, 10 on page 1, 5 on page 2
expect(response.body.has_more).toBe(false);
});
it('should support cursor-based pagination', async () => {
// First page
const firstPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5' })
.expect(200);
expect(firstPage.body.promptGroups).toHaveLength(5);
expect(firstPage.body.has_more).toBe(true);
expect(firstPage.body.after).toBeTruthy();
// Second page using cursor
const secondPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', cursor: firstPage.body.after })
.expect(200);
expect(secondPage.body.promptGroups).toHaveLength(5);
expect(secondPage.body.has_more).toBe(true);
expect(secondPage.body.after).toBeTruthy();
// Verify different groups
const firstPageIds = firstPage.body.promptGroups.map((g) => g._id);
const secondPageIds = secondPage.body.promptGroups.map((g) => g._id);
expect(firstPageIds).not.toEqual(secondPageIds);
});
it('should paginate correctly with category filtering', async () => {
// Create groups with different categories
await PromptGroup.deleteMany({}); // Clear existing groups
await AclEntry.deleteMany({});
// Create 8 groups with category 'test-cat-1'
for (let i = 0; i < 8; i++) {
const group = await PromptGroup.create({
name: `Category 1 Group ${i + 1}`,
category: 'test-cat-1',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - i * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Create 7 groups with category 'test-cat-2'
for (let i = 0; i < 7; i++) {
const group = await PromptGroup.create({
name: `Category 2 Group ${i + 1}`,
category: 'test-cat-2',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - (i + 8) * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Test pagination with category filter
const firstPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', category: 'test-cat-1' })
.expect(200);
expect(firstPage.body.promptGroups).toHaveLength(5);
expect(firstPage.body.promptGroups.every((g) => g.category === 'test-cat-1')).toBe(true);
expect(firstPage.body.has_more).toBe(true);
expect(firstPage.body.after).toBeTruthy();
const secondPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', cursor: firstPage.body.after, category: 'test-cat-1' })
.expect(200);
expect(secondPage.body.promptGroups).toHaveLength(3); // 8 total, 5 on page 1, 3 on page 2
expect(secondPage.body.promptGroups.every((g) => g.category === 'test-cat-1')).toBe(true);
expect(secondPage.body.has_more).toBe(false);
});
it('should paginate correctly with name/keyword filtering', async () => {
// Create groups with specific names
await PromptGroup.deleteMany({}); // Clear existing groups
await AclEntry.deleteMany({});
// Create 12 groups with 'Search' in the name
for (let i = 0; i < 12; i++) {
const group = await PromptGroup.create({
name: `Search Test Group ${i + 1}`,
category: 'search-test',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - i * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Create 5 groups without 'Search' in the name
for (let i = 0; i < 5; i++) {
const group = await PromptGroup.create({
name: `Other Group ${i + 1}`,
category: 'other-test',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - (i + 12) * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Test pagination with name filter
const firstPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10', name: 'Search' })
.expect(200);
expect(firstPage.body.promptGroups).toHaveLength(10);
expect(firstPage.body.promptGroups.every((g) => g.name.includes('Search'))).toBe(true);
expect(firstPage.body.has_more).toBe(true);
expect(firstPage.body.after).toBeTruthy();
const secondPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10', cursor: firstPage.body.after, name: 'Search' })
.expect(200);
expect(secondPage.body.promptGroups).toHaveLength(2); // 12 total, 10 on page 1, 2 on page 2
expect(secondPage.body.promptGroups.every((g) => g.name.includes('Search'))).toBe(true);
expect(secondPage.body.has_more).toBe(false);
});
it('should paginate correctly with combined filters', async () => {
// Create groups with various combinations
await PromptGroup.deleteMany({}); // Clear existing groups
await AclEntry.deleteMany({});
// Create 6 groups matching both category and name filters
for (let i = 0; i < 6; i++) {
const group = await PromptGroup.create({
name: `API Test Group ${i + 1}`,
category: 'api-category',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - i * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Create groups that only match one filter
for (let i = 0; i < 4; i++) {
const group = await PromptGroup.create({
name: `API Other Group ${i + 1}`,
category: 'other-category',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - (i + 6) * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Test pagination with both filters
const response = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', name: 'API', category: 'api-category' })
.expect(200);
expect(response.body.promptGroups).toHaveLength(5);
expect(
response.body.promptGroups.every(
(g) => g.name.includes('API') && g.category === 'api-category',
),
).toBe(true);
expect(response.body.has_more).toBe(true);
expect(response.body.after).toBeTruthy();
// Page 2
const page2 = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', cursor: response.body.after, name: 'API', category: 'api-category' })
.expect(200);
expect(page2.body.promptGroups).toHaveLength(1); // 6 total, 5 on page 1, 1 on page 2
expect(page2.body.has_more).toBe(false);
});
});
});

View File

@@ -1,12 +1,16 @@
const { FileSources, EModelEndpoint, getConfigDefaults } = require('librechat-data-provider');
const {
isEnabled,
loadOCRConfig,
loadMemoryConfig,
agentsConfigSetup,
loadWebSearchConfig,
loadDefaultInterface,
} = require('@librechat/api');
const {
FileSources,
loadOCRConfig,
EModelEndpoint,
getConfigDefaults,
} = require('librechat-data-provider');
const {
checkWebSearchConfig,
checkVariables,
@@ -45,7 +49,6 @@ const AppService = async () => {
enabled: isEnabled(process.env.CHECK_BALANCE),
startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
};
const transactions = config.transactions ?? configDefaults.transactions;
const imageOutputType = config?.imageOutputType ?? configDefaults.imageOutputType;
process.env.CDN_PROVIDER = fileStrategy;
@@ -81,7 +84,6 @@ const AppService = async () => {
memory,
speech,
balance,
transactions,
mcpConfig,
webSearch,
fileStrategy,

View File

@@ -142,6 +142,7 @@ describe('AppService', () => {
turnstileConfig: mockedTurnstileConfig,
modelSpecs: undefined,
paths: expect.anything(),
ocr: expect.anything(),
imageOutputType: expect.any(String),
fileConfig: undefined,
secureImageLinks: undefined,
@@ -151,7 +152,6 @@ describe('AppService', () => {
webSearch: expect.objectContaining({
safeSearch: 1,
jinaApiKey: '${JINA_API_KEY}',
jinaApiUrl: '${JINA_API_URL}',
cohereApiKey: '${COHERE_API_KEY}',
serperApiKey: '${SERPER_API_KEY}',
searxngApiKey: '${SEARXNG_API_KEY}',

View File

@@ -2,13 +2,13 @@ const bcrypt = require('bcryptjs');
const jwt = require('jsonwebtoken');
const { webcrypto } = require('node:crypto');
const { logger } = require('@librechat/data-schemas');
const { isEnabled, checkEmailConfig, isEmailDomainAllowed } = require('@librechat/api');
const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider');
const { isEnabled, checkEmailConfig } = require('@librechat/api');
const { SystemRoles, errorsToString } = require('librechat-data-provider');
const {
findUser,
findToken,
createUser,
updateUser,
findToken,
countUsers,
getUserById,
findSession,
@@ -20,6 +20,7 @@ const {
deleteUserById,
generateRefreshToken,
} = require('~/models');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { registerSchema } = require('~/strategies/validators');
const { getAppConfig } = require('~/server/services/Config');
const { sendEmail } = require('~/server/utils');
@@ -129,7 +130,7 @@ const verifyEmail = async (req) => {
return { message: 'Email already verified', status: 'success' };
}
let emailVerificationData = await findToken({ email: decodedEmail }, { sort: { createdAt: -1 } });
let emailVerificationData = await findToken({ email: decodedEmail });
if (!emailVerificationData) {
logger.warn(`[verifyEmail] [No email verification data found] [Email: ${decodedEmail}]`);
@@ -180,14 +181,6 @@ const registerUser = async (user, additionalData = {}) => {
let newUserId;
try {
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const errorMessage =
'The email address provided cannot be used. Please use a different email address.';
logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`);
return { status: 403, message: errorMessage };
}
const existingUser = await findUser({ email }, 'email _id');
if (existingUser) {
@@ -202,6 +195,14 @@ const registerUser = async (user, additionalData = {}) => {
return { status: 200, message: genericVerificationMessage };
}
const appConfig = await getAppConfig({ role: user.role });
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const errorMessage =
'The email address provided cannot be used. Please use a different email address.';
logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`);
return { status: 403, message: errorMessage };
}
//determine if this is the first registered user (not counting anonymous_user)
const isFirstRegisteredUser = (await countUsers()) === 0;
@@ -251,13 +252,6 @@ const registerUser = async (user, additionalData = {}) => {
*/
const requestPasswordReset = async (req) => {
const { email } = req.body;
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const error = new Error(ErrorTypes.AUTH_FAILED);
error.code = ErrorTypes.AUTH_FAILED;
error.message = 'Email domain not allowed';
return error;
}
const user = await findUser({ email }, 'email _id');
const emailEnabled = checkEmailConfig();
@@ -319,12 +313,9 @@ const requestPasswordReset = async (req) => {
* @returns
*/
const resetPassword = async (userId, token, password) => {
let passwordResetToken = await findToken(
{
userId,
},
{ sort: { createdAt: -1 } },
);
let passwordResetToken = await findToken({
userId,
});
if (!passwordResetToken) {
return new Error('Invalid or expired password reset token');
@@ -359,18 +350,23 @@ const resetPassword = async (userId, token, password) => {
/**
* Set Auth Tokens
*
* @param {String | ObjectId} userId
* @param {ServerResponse} res
* @param {ISession | null} [session=null]
* @param {Object} res
* @param {String} sessionId
* @returns
*/
const setAuthTokens = async (userId, res, _session = null) => {
const setAuthTokens = async (userId, res, sessionId = null) => {
try {
let session = _session;
const user = await getUserById(userId);
const token = await generateToken(user);
let session;
let refreshToken;
let refreshTokenExpires;
if (session && session._id && session.expiration != null) {
if (sessionId) {
session = await findSession({ sessionId: sessionId }, { lean: false });
refreshTokenExpires = session.expiration.getTime();
refreshToken = await generateRefreshToken(session);
} else {
@@ -380,9 +376,6 @@ const setAuthTokens = async (userId, res, _session = null) => {
refreshTokenExpires = session.expiration.getTime();
}
const user = await getUserById(userId);
const token = await generateToken(user);
res.cookie('refreshToken', refreshToken, {
expires: new Date(refreshTokenExpires),
httpOnly: true,
@@ -409,10 +402,9 @@ const setAuthTokens = async (userId, res, _session = null) => {
* @param {import('openid-client').TokenEndpointResponse & import('openid-client').TokenEndpointResponseHelpers} tokenset
* - The tokenset object containing access and refresh tokens
* @param {Object} res - response object
* @param {string} [userId] - Optional MongoDB user ID for image path validation
* @returns {String} - access token
*/
const setOpenIDAuthTokens = (tokenset, res, userId) => {
const setOpenIDAuthTokens = (tokenset, res) => {
try {
if (!tokenset) {
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
@@ -443,18 +435,6 @@ const setOpenIDAuthTokens = (tokenset, res, userId) => {
secure: isProduction,
sameSite: 'strict',
});
if (userId && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
/** JWT-signed user ID cookie for image path validation when OPENID_REUSE_TOKENS is enabled */
const signedUserId = jwt.sign({ id: userId }, process.env.JWT_REFRESH_SECRET, {
expiresIn: expiryInMilliseconds / 1000,
});
res.cookie('openid_user_id', signedUserId, {
expires: expirationDate,
httpOnly: true,
secure: isProduction,
sameSite: 'strict',
});
}
return tokenset.access_token;
} catch (error) {
logger.error('[setOpenIDAuthTokens] Error in setting authentication tokens:', error);
@@ -472,7 +452,7 @@ const setOpenIDAuthTokens = (tokenset, res, userId) => {
const resendVerificationEmail = async (req) => {
try {
const { email } = req.body;
await deleteTokens({ email });
await deleteTokens(email);
const user = await findUser({ email }, 'email _id name');
if (!user) {

View File

@@ -4,8 +4,6 @@ const AppService = require('~/server/services/AppService');
const { setCachedTools } = require('./getCachedTools');
const getLogStores = require('~/cache/getLogStores');
const BASE_CONFIG_KEY = '_BASE_';
/**
* Get the app configuration based on user context
* @param {Object} [options]
@@ -16,8 +14,8 @@ const BASE_CONFIG_KEY = '_BASE_';
async function getAppConfig(options = {}) {
const { role, refresh } = options;
const cache = getLogStores(CacheKeys.APP_CONFIG);
const cacheKey = role ? role : BASE_CONFIG_KEY;
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cacheKey = role ? `${CacheKeys.APP_CONFIG}:${role}` : CacheKeys.APP_CONFIG;
if (!refresh) {
const cached = await cache.get(cacheKey);
@@ -26,7 +24,7 @@ async function getAppConfig(options = {}) {
}
}
let baseConfig = await cache.get(BASE_CONFIG_KEY);
let baseConfig = await cache.get(CacheKeys.APP_CONFIG);
if (!baseConfig) {
logger.info('[getAppConfig] App configuration not initialized. Initializing AppService...');
baseConfig = await AppService();
@@ -36,10 +34,10 @@ async function getAppConfig(options = {}) {
}
if (baseConfig.availableTools) {
await setCachedTools(baseConfig.availableTools);
await setCachedTools(baseConfig.availableTools, { isGlobal: true });
}
await cache.set(BASE_CONFIG_KEY, baseConfig);
await cache.set(CacheKeys.APP_CONFIG, baseConfig);
}
// For now, return the base config

View File

@@ -3,32 +3,89 @@ const getLogStores = require('~/cache/getLogStores');
/**
* Cache key generators for different tool access patterns
* These will support future permission-based caching
*/
const ToolCacheKeys = {
/** Global tools available to all users */
GLOBAL: 'tools:global',
/** MCP tools cached by server name */
MCP_SERVER: (serverName) => `tools:mcp:${serverName}`,
/** Tools available to a specific user */
USER: (userId) => `tools:user:${userId}`,
/** Tools available to a specific role */
ROLE: (roleId) => `tools:role:${roleId}`,
/** Tools available to a specific group */
GROUP: (groupId) => `tools:group:${groupId}`,
/** Combined effective tools for a user (computed from all sources) */
EFFECTIVE: (userId) => `tools:effective:${userId}`,
};
/**
* Retrieves available tools from cache
* @function getCachedTools
* @param {Object} options - Options for retrieving tools
* @param {string} [options.serverName] - MCP server name to get cached tools for
* @param {string} [options.userId] - User ID for user-specific tools
* @param {string[]} [options.roleIds] - Role IDs for role-based tools
* @param {string[]} [options.groupIds] - Group IDs for group-based tools
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
*/
async function getCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName } = options;
const { userId, roleIds = [], groupIds = [], includeGlobal = true } = options;
// Return MCP server-specific tools if requested
if (serverName) {
return await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
// For now, return global tools (current behavior)
// This will be expanded to merge tools from different sources
if (!userId && includeGlobal) {
return await cache.get(ToolCacheKeys.GLOBAL);
}
// Default to global tools
return await cache.get(ToolCacheKeys.GLOBAL);
// Future implementation will merge tools from multiple sources
// based on user permissions, roles, and groups
if (userId) {
/** @type {LCAvailableTools | null} Check if we have pre-computed effective tools for this user */
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
if (effectiveTools) {
return effectiveTools;
}
/** @type {LCAvailableTools | null} Otherwise, compute from individual sources */
const toolSources = [];
if (includeGlobal) {
const globalTools = await cache.get(ToolCacheKeys.GLOBAL);
if (globalTools) {
toolSources.push(globalTools);
}
}
// User-specific tools
const userTools = await cache.get(ToolCacheKeys.USER(userId));
if (userTools) {
toolSources.push(userTools);
}
// Role-based tools
for (const roleId of roleIds) {
const roleTools = await cache.get(ToolCacheKeys.ROLE(roleId));
if (roleTools) {
toolSources.push(roleTools);
}
}
// Group-based tools
for (const groupId of groupIds) {
const groupTools = await cache.get(ToolCacheKeys.GROUP(groupId));
if (groupTools) {
toolSources.push(groupTools);
}
}
// Merge all tool sources (for now, simple merge - future will handle conflicts)
if (toolSources.length > 0) {
return mergeToolSources(toolSources);
}
}
return null;
}
/**
@@ -36,34 +93,49 @@ async function getCachedTools(options = {}) {
* @function setCachedTools
* @param {Object} tools - The tools object to cache
* @param {Object} options - Options for caching tools
* @param {string} [options.serverName] - MCP server name for server-specific tools
* @param {string} [options.userId] - User ID for user-specific tools
* @param {string} [options.roleId] - Role ID for role-based tools
* @param {string} [options.groupId] - Group ID for group-based tools
* @param {boolean} [options.isGlobal=false] - Whether these are global tools
* @param {number} [options.ttl] - Time to live in milliseconds
* @returns {Promise<boolean>} Whether the operation was successful
*/
async function setCachedTools(tools, options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, ttl } = options;
const { userId, roleId, groupId, isGlobal = false, ttl } = options;
// Cache by MCP server if specified
if (serverName) {
return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl);
let cacheKey;
if (isGlobal || (!userId && !roleId && !groupId)) {
cacheKey = ToolCacheKeys.GLOBAL;
} else if (userId) {
cacheKey = ToolCacheKeys.USER(userId);
} else if (roleId) {
cacheKey = ToolCacheKeys.ROLE(roleId);
} else if (groupId) {
cacheKey = ToolCacheKeys.GROUP(groupId);
}
// Default to global cache
return await cache.set(ToolCacheKeys.GLOBAL, tools, ttl);
if (!cacheKey) {
throw new Error('Invalid cache key options provided');
}
return await cache.set(cacheKey, tools, ttl);
}
/**
* Invalidates cached tools
* @function invalidateCachedTools
* @param {Object} options - Options for invalidating tools
* @param {string} [options.serverName] - MCP server name to invalidate
* @param {string} [options.userId] - User ID to invalidate
* @param {string} [options.roleId] - Role ID to invalidate
* @param {string} [options.groupId] - Group ID to invalidate
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
* @param {boolean} [options.invalidateEffective=true] - Whether to invalidate effective tools
* @returns {Promise<void>}
*/
async function invalidateCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, invalidateGlobal = false } = options;
const { userId, roleId, groupId, invalidateGlobal = false, invalidateEffective = true } = options;
const keysToDelete = [];
@@ -71,34 +143,116 @@ async function invalidateCachedTools(options = {}) {
keysToDelete.push(ToolCacheKeys.GLOBAL);
}
if (serverName) {
keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName));
if (userId) {
keysToDelete.push(ToolCacheKeys.USER(userId));
if (invalidateEffective) {
keysToDelete.push(ToolCacheKeys.EFFECTIVE(userId));
}
}
if (roleId) {
keysToDelete.push(ToolCacheKeys.ROLE(roleId));
// TODO: In future, invalidate all users with this role
}
if (groupId) {
keysToDelete.push(ToolCacheKeys.GROUP(groupId));
// TODO: In future, invalidate all users in this group
}
await Promise.all(keysToDelete.map((key) => cache.delete(key)));
}
/**
* Gets MCP tools for a specific server from cache or merges with global tools
* @function getMCPServerTools
* @param {string} serverName - The MCP server name
* @returns {Promise<LCAvailableTools|null>} The available tools for the server
* Computes and caches effective tools for a user
* @function computeEffectiveTools
* @param {string} userId - The user ID
* @param {Object} context - Context containing user's roles and groups
* @param {string[]} [context.roleIds=[]] - User's role IDs
* @param {string[]} [context.groupIds=[]] - User's group IDs
* @param {number} [ttl] - Time to live for the computed result
* @returns {Promise<Object>} The computed effective tools
*/
async function getMCPServerTools(serverName) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
async function computeEffectiveTools(userId, context = {}, ttl) {
const { roleIds = [], groupIds = [] } = context;
if (serverTools) {
return serverTools;
// Get all tool sources
const tools = await getCachedTools({
userId,
roleIds,
groupIds,
includeGlobal: true,
});
if (tools) {
// Cache the computed result
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.set(ToolCacheKeys.EFFECTIVE(userId), tools, ttl);
}
return null;
return tools;
}
/**
* Merges multiple tool sources into a single tools object
* @function mergeToolSources
* @param {Object[]} sources - Array of tool objects to merge
* @returns {Object} Merged tools object
*/
function mergeToolSources(sources) {
// For now, simple merge that combines all tools
// Future implementation will handle:
// - Permission precedence (deny > allow)
// - Tool property conflicts
// - Metadata merging
const merged = {};
for (const source of sources) {
if (!source || typeof source !== 'object') {
continue;
}
for (const [toolId, toolConfig] of Object.entries(source)) {
// Simple last-write-wins for now
// Future: merge based on permission levels
merged[toolId] = toolConfig;
}
}
return merged;
}
/**
* Middleware-friendly function to get tools for a request
* @function getToolsForRequest
* @param {Object} req - Express request object
* @returns {Promise<Object|null>} Available tools for the request
*/
async function getToolsForRequest(req) {
const userId = req.user?.id;
// For now, return global tools if no user
if (!userId) {
return getCachedTools({ includeGlobal: true });
}
// Future: Extract roles and groups from req.user
const roleIds = req.user?.roles || [];
const groupIds = req.user?.groups || [];
return getCachedTools({
userId,
roleIds,
groupIds,
includeGlobal: true,
});
}
module.exports = {
ToolCacheKeys,
getCachedTools,
setCachedTools,
getMCPServerTools,
getToolsForRequest,
invalidateCachedTools,
computeEffectiveTools,
};

View File

@@ -1,7 +1,7 @@
const appConfig = require('./app');
const mcpToolsCache = require('./mcp');
const { config } = require('./EndpointService');
const getCachedTools = require('./getCachedTools');
const mcpToolsCache = require('./mcpToolsCache');
const loadCustomConfig = require('./loadCustomConfig');
const loadConfigModels = require('./loadConfigModels');
const loadDefaultModels = require('./loadDefaultModels');

View File

@@ -119,6 +119,10 @@ https://www.librechat.ai/docs/configuration/stt_tts`);
.filter((endpoint) => endpoint.customParams)
.forEach((endpoint) => parseCustomParams(endpoint.name, endpoint.customParams));
if (customConfig.cache) {
const cache = getLogStores(CacheKeys.STATIC_CONFIG);
await cache.set(CacheKeys.LIBRECHAT_YAML_CONFIG, customConfig);
}
if (result.data.modelSpecs) {
customConfig.modelSpecs = result.data.modelSpecs;

View File

@@ -48,11 +48,16 @@ const axios = require('axios');
const { loadYaml } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const loadCustomConfig = require('./loadCustomConfig');
const getLogStores = require('~/cache/getLogStores');
describe('loadCustomConfig', () => {
const mockSet = jest.fn();
const mockCache = { set: mockSet };
beforeEach(() => {
jest.resetAllMocks();
delete process.env.CONFIG_PATH;
getLogStores.mockReturnValue(mockCache);
});
it('should return null and log error if remote config fetch fails', async () => {
@@ -89,6 +94,7 @@ describe('loadCustomConfig', () => {
const result = await loadCustomConfig();
expect(result).toEqual(mockConfig);
expect(mockSet).toHaveBeenCalledWith(expect.anything(), mockConfig);
});
it('should return null and log if config schema validation fails', async () => {
@@ -128,6 +134,7 @@ describe('loadCustomConfig', () => {
axios.get.mockResolvedValue({ data: mockConfig });
const result = await loadCustomConfig();
expect(result).toEqual(mockConfig);
expect(mockSet).toHaveBeenCalledWith(expect.anything(), mockConfig);
});
it('should return null if the remote config file is not found', async () => {
@@ -161,6 +168,7 @@ describe('loadCustomConfig', () => {
process.env.CONFIG_PATH = 'validConfig.yaml';
loadYaml.mockReturnValueOnce(mockConfig);
await loadCustomConfig();
expect(mockSet).not.toHaveBeenCalled();
});
it('should log the loaded custom config', async () => {

View File

@@ -1,91 +0,0 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getCachedTools, setCachedTools } = require('./getCachedTools');
const { getLogStores } = require('~/cache');
/**
* Updates MCP tools in the cache for a specific server
* @param {Object} params - Parameters for updating MCP tools
* @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<LCAvailableTools>}
*/
async function updateMCPServerTools({ serverName, tools }) {
try {
const serverTools = {};
const mcpDelimiter = Constants.mcp_delimiter;
for (const tool of tools) {
const name = `${tool.name}${mcpDelimiter}${serverName}`;
serverTools[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema,
},
};
}
await setCachedTools(serverTools, { serverName });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`);
return serverTools;
} catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
throw error;
}
}
/**
* Merges app-level tools with global tools
* @param {import('@librechat/api').LCAvailableTools} appTools
* @returns {Promise<void>}
*/
async function mergeAppTools(appTools) {
try {
const count = Object.keys(appTools).length;
if (!count) {
return;
}
const cachedTools = await getCachedTools();
const mergedTools = { ...cachedTools, ...appTools };
await setCachedTools(mergedTools);
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`Merged ${count} app-level tools`);
} catch (error) {
logger.error('Failed to merge app-level tools:', error);
throw error;
}
}
/**
* Caches MCP server tools (no longer merges with global)
* @param {object} params
* @param {string} params.serverName
* @param {import('@librechat/api').LCAvailableTools} params.serverTools
* @returns {Promise<void>}
*/
async function cacheMCPServerTools({ serverName, serverTools }) {
try {
const count = Object.keys(serverTools).length;
if (!count) {
return;
}
// Only cache server-specific tools, no merging with global
await setCachedTools(serverTools, { serverName });
logger.debug(`Cached ${count} MCP server tools for ${serverName}`);
} catch (error) {
logger.error(`Failed to cache MCP server tools for ${serverName}:`, error);
throw error;
}
}
module.exports = {
mergeAppTools,
cacheMCPServerTools,
updateMCPServerTools,
};

View File

@@ -0,0 +1,143 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getCachedTools, setCachedTools } = require('./getCachedTools');
const { getLogStores } = require('~/cache');
/**
* Updates MCP tools in the cache for a specific server and user
* @param {Object} params - Parameters for updating MCP tools
* @param {string} params.userId - User ID
* @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<LCAvailableTools>}
*/
async function updateMCPUserTools({ userId, serverName, tools }) {
try {
const userTools = await getCachedTools({ userId });
const mcpDelimiter = Constants.mcp_delimiter;
for (const key of Object.keys(userTools)) {
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
delete userTools[key];
}
}
for (const tool of tools) {
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
userTools[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema,
},
};
}
await setCachedTools(userTools, { userId });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`);
return userTools;
} catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
throw error;
}
}
/**
* Merges app-level tools with global tools
* @param {import('@librechat/api').LCAvailableTools} appTools
* @returns {Promise<void>}
*/
async function mergeAppTools(appTools) {
try {
const count = Object.keys(appTools).length;
if (!count) {
return;
}
const cachedTools = await getCachedTools({ includeGlobal: true });
const mergedTools = { ...cachedTools, ...appTools };
await setCachedTools(mergedTools, { isGlobal: true });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`Merged ${count} app-level tools`);
} catch (error) {
logger.error('Failed to merge app-level tools:', error);
throw error;
}
}
/**
* Merges user-level tools with global tools
* @param {object} params
* @param {string} params.userId
* @param {Record<string, FunctionTool>} params.cachedUserTools
* @param {import('@librechat/api').LCAvailableTools} params.userTools
* @returns {Promise<void>}
*/
async function mergeUserTools({ userId, cachedUserTools, userTools }) {
try {
if (!userId) {
return;
}
const count = Object.keys(userTools).length;
if (!count) {
return;
}
const cachedTools = cachedUserTools ?? (await getCachedTools({ userId }));
const mergedTools = { ...cachedTools, ...userTools };
await setCachedTools(mergedTools, { userId });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`Merged ${count} user-level tools`);
} catch (error) {
logger.error('Failed to merge user-level tools:', error);
throw error;
}
}
/**
* Clears all MCP tools for a specific server
* @param {Object} params - Parameters for clearing MCP tools
* @param {string} [params.userId] - User ID (if clearing user-specific tools)
* @param {string} params.serverName - MCP server name
* @returns {Promise<void>}
*/
async function clearMCPServerTools({ userId, serverName }) {
try {
const tools = await getCachedTools({ userId, includeGlobal: !userId });
// Remove all tools for this server
const mcpDelimiter = Constants.mcp_delimiter;
let removedCount = 0;
for (const key of Object.keys(tools)) {
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
delete tools[key];
removedCount++;
}
}
if (removedCount > 0) {
await setCachedTools(tools, userId ? { userId } : { isGlobal: true });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(
`[MCP Cache] Removed ${removedCount} tools for ${serverName}${userId ? ` user ${userId}` : ' (global)'}`,
);
}
} catch (error) {
logger.error(`[MCP Cache] Failed to clear tools for ${serverName}:`, error);
throw error;
}
}
module.exports = {
mergeAppTools,
mergeUserTools,
updateMCPUserTools,
clearMCPServerTools,
};

View File

@@ -1,7 +1,6 @@
const { Providers } = require('@librechat/agents');
const {
primeResources,
getModelMaxTokens,
extractLibreChatParams,
optionalChainWithEmptyCheck,
} = require('@librechat/api');
@@ -18,6 +17,7 @@ const { getProviderConfig } = require('~/server/services/Endpoints');
const { processFiles } = require('~/server/services/Files/process');
const { getFiles, getToolFilesByIds } = require('~/models/File');
const { getConvoFiles } = require('~/models/Conversation');
const { getModelMaxTokens } = require('~/utils');
/**
* @param {object} params

View File

@@ -54,11 +54,6 @@ const addTitle = async (req, { text, response, client }) => {
clearTimeout(timeoutId);
}
if (!title) {
logger.debug(`[${key}] No title generated`);
return;
}
await titleCache.set(key, title, 120000);
await saveConvo(
req,

View File

@@ -1,14 +1,13 @@
import { logger } from '@librechat/data-schemas';
import { AnthropicClientOptions } from '@librechat/agents';
import { EModelEndpoint, anthropicSettings } from 'librechat-data-provider';
import { matchModelName } from '~/utils/tokens';
const { EModelEndpoint, anthropicSettings } = require('librechat-data-provider');
const { matchModelName } = require('~/utils');
const { logger } = require('~/config');
/**
* @param {string} modelName
* @returns {boolean}
*/
function checkPromptCacheSupport(modelName: string): boolean {
const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic) ?? '';
function checkPromptCacheSupport(modelName) {
const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic);
if (
modelMatch.includes('claude-3-5-sonnet-latest') ||
modelMatch.includes('claude-3.5-sonnet-latest')
@@ -32,10 +31,7 @@ function checkPromptCacheSupport(modelName: string): boolean {
* @param {boolean} supportsCacheControl Whether the model supports cache control
* @returns {AnthropicClientOptions['extendedOptions']['defaultHeaders']|undefined} The headers object or undefined if not applicable
*/
function getClaudeHeaders(
model: string,
supportsCacheControl: boolean,
): Record<string, string> | undefined {
function getClaudeHeaders(model, supportsCacheControl) {
if (!supportsCacheControl) {
return undefined;
}
@@ -76,13 +72,9 @@ function getClaudeHeaders(
* @param {number|null} extendedOptions.thinkingBudget The token budget for thinking
* @returns {Object} Updated request options
*/
function configureReasoning(
anthropicInput: AnthropicClientOptions & { max_tokens?: number },
extendedOptions: { thinking?: boolean; thinkingBudget?: number | null } = {},
): AnthropicClientOptions & { max_tokens?: number } {
function configureReasoning(anthropicInput, extendedOptions = {}) {
const updatedOptions = { ...anthropicInput };
const currentMaxTokens = updatedOptions.max_tokens ?? updatedOptions.maxTokens;
if (
extendedOptions.thinking &&
updatedOptions?.model &&
@@ -90,16 +82,11 @@ function configureReasoning(
/claude-(?:sonnet|opus|haiku)-[4-9]/.test(updatedOptions.model))
) {
updatedOptions.thinking = {
...updatedOptions.thinking,
type: 'enabled',
} as { type: 'enabled'; budget_tokens: number };
};
}
if (
updatedOptions.thinking != null &&
extendedOptions.thinkingBudget != null &&
updatedOptions.thinking.type === 'enabled'
) {
if (updatedOptions.thinking != null && extendedOptions.thinkingBudget != null) {
updatedOptions.thinking = {
...updatedOptions.thinking,
budget_tokens: extendedOptions.thinkingBudget,
@@ -108,10 +95,9 @@ function configureReasoning(
if (
updatedOptions.thinking != null &&
updatedOptions.thinking.type === 'enabled' &&
(currentMaxTokens == null || updatedOptions.thinking.budget_tokens > currentMaxTokens)
) {
const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model ?? '');
const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model);
updatedOptions.max_tokens = currentMaxTokens ?? maxTokens;
logger.warn(
@@ -129,4 +115,4 @@ function configureReasoning(
return updatedOptions;
}
export { checkPromptCacheSupport, getClaudeHeaders, configureReasoning };
module.exports = { checkPromptCacheSupport, getClaudeHeaders, configureReasoning };

View File

@@ -1,6 +1,6 @@
const { getLLMConfig } = require('@librechat/api');
const { EModelEndpoint } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
const AnthropicClient = require('~/app/clients/AnthropicClient');
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
@@ -40,6 +40,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
clientOptions = Object.assign(
{
proxy: PROXY ?? null,
userId: req.user.id,
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
modelOptions: endpointOption?.model_parameters ?? {},
},
@@ -48,7 +49,6 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
if (overrideModel) {
clientOptions.modelOptions.model = overrideModel;
}
clientOptions.modelOptions.user = req.user.id;
return getLLMConfig(anthropicApiKey, clientOptions);
}

View File

@@ -0,0 +1,103 @@
const { ProxyAgent } = require('undici');
const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');
/**
* Generates configuration options for creating an Anthropic language model (LLM) instance.
*
* @param {string} apiKey - The API key for authentication with Anthropic.
* @param {Object} [options={}] - Additional options for configuring the LLM.
* @param {Object} [options.modelOptions] - Model-specific options.
* @param {string} [options.modelOptions.model] - The name of the model to use.
* @param {number} [options.modelOptions.maxOutputTokens] - The maximum number of tokens to generate.
* @param {number} [options.modelOptions.temperature] - Controls randomness in output generation.
* @param {number} [options.modelOptions.topP] - Controls diversity of output generation.
* @param {number} [options.modelOptions.topK] - Controls the number of top tokens to consider.
* @param {string[]} [options.modelOptions.stop] - Sequences where the API will stop generating further tokens.
* @param {boolean} [options.modelOptions.stream] - Whether to stream the response.
* @param {string} options.userId - The user ID for tracking and personalization.
* @param {string} [options.proxy] - Proxy server URL.
* @param {string} [options.reverseProxyUrl] - URL for a reverse proxy, if used.
*
* @returns {Object} Configuration options for creating an Anthropic LLM instance, with null and undefined values removed.
*/
function getLLMConfig(apiKey, options = {}) {
const systemOptions = {
thinking: options.modelOptions.thinking ?? anthropicSettings.thinking.default,
promptCache: options.modelOptions.promptCache ?? anthropicSettings.promptCache.default,
thinkingBudget: options.modelOptions.thinkingBudget ?? anthropicSettings.thinkingBudget.default,
};
for (let key in systemOptions) {
delete options.modelOptions[key];
}
const defaultOptions = {
model: anthropicSettings.model.default,
maxOutputTokens: anthropicSettings.maxOutputTokens.default,
stream: true,
};
const mergedOptions = Object.assign(defaultOptions, options.modelOptions);
/** @type {AnthropicClientOptions} */
let requestOptions = {
apiKey,
model: mergedOptions.model,
stream: mergedOptions.stream,
temperature: mergedOptions.temperature,
stopSequences: mergedOptions.stop,
maxTokens:
mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model),
clientOptions: {},
invocationKwargs: {
metadata: {
user_id: options.userId,
},
},
};
requestOptions = configureReasoning(requestOptions, systemOptions);
if (!/claude-3[-.]7/.test(mergedOptions.model)) {
requestOptions.topP = mergedOptions.topP;
requestOptions.topK = mergedOptions.topK;
} else if (requestOptions.thinking == null) {
requestOptions.topP = mergedOptions.topP;
requestOptions.topK = mergedOptions.topK;
}
const supportsCacheControl =
systemOptions.promptCache === true && checkPromptCacheSupport(requestOptions.model);
const headers = getClaudeHeaders(requestOptions.model, supportsCacheControl);
if (headers) {
requestOptions.clientOptions.defaultHeaders = headers;
}
if (options.proxy) {
const proxyAgent = new ProxyAgent(options.proxy);
requestOptions.clientOptions.fetchOptions = {
dispatcher: proxyAgent,
};
}
if (options.reverseProxyUrl) {
requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
requestOptions.anthropicApiUrl = options.reverseProxyUrl;
}
const tools = [];
if (mergedOptions.web_search) {
tools.push({
type: 'web_search_20250305',
name: 'web_search',
});
}
return {
tools,
/** @type {AnthropicClientOptions} */
llmConfig: removeNullishValues(requestOptions),
};
}
module.exports = { getLLMConfig };

View File

@@ -0,0 +1,341 @@
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
jest.mock('https-proxy-agent', () => ({
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
}));
describe('getLLMConfig', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('should create a basic configuration with default values', () => {
const result = getLLMConfig('test-api-key', { modelOptions: {} });
expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key');
expect(result.llmConfig).toHaveProperty('model', 'claude-3-5-sonnet-latest');
expect(result.llmConfig).toHaveProperty('stream', true);
expect(result.llmConfig).toHaveProperty('maxTokens');
});
it('should include proxy settings when provided', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {},
proxy: 'http://proxy:8080',
});
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
'ProxyAgent',
);
});
it('should include reverse proxy URL when provided', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {},
reverseProxyUrl: 'http://reverse-proxy',
});
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy');
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'http://reverse-proxy');
});
it('should include topK and topP for non-Claude-3.7 models', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-opus',
topK: 10,
topP: 0.9,
},
});
expect(result.llmConfig).toHaveProperty('topK', 10);
expect(result.llmConfig).toHaveProperty('topP', 0.9);
});
it('should include topK and topP for Claude-3.5 models', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-5-sonnet',
topK: 10,
topP: 0.9,
},
});
expect(result.llmConfig).toHaveProperty('topK', 10);
expect(result.llmConfig).toHaveProperty('topP', 0.9);
});
it('should NOT include topK and topP for Claude-3-7 models with thinking enabled (hyphen notation)', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
topK: 10,
topP: 0.9,
thinking: true,
},
});
expect(result.llmConfig).not.toHaveProperty('topK');
expect(result.llmConfig).not.toHaveProperty('topP');
expect(result.llmConfig).toHaveProperty('thinking');
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
// When thinking is enabled, it uses the default thinkingBudget of 2000
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
});
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model', () => {
const modelOptions = {
model: 'claude-sonnet-4-20250514',
promptCache: true,
};
const result = getLLMConfig('test-key', { modelOptions });
const clientOptions = result.llmConfig.clientOptions;
expect(clientOptions.defaultHeaders).toBeDefined();
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
'prompt-caching-2024-07-31,context-1m-2025-08-07',
);
});
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model formats', () => {
const modelVariations = [
'claude-sonnet-4-20250514',
'claude-sonnet-4-latest',
'anthropic/claude-sonnet-4-20250514',
];
modelVariations.forEach((model) => {
const modelOptions = { model, promptCache: true };
const result = getLLMConfig('test-key', { modelOptions });
const clientOptions = result.llmConfig.clientOptions;
expect(clientOptions.defaultHeaders).toBeDefined();
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
'prompt-caching-2024-07-31,context-1m-2025-08-07',
);
});
});
it('should NOT include topK and topP for Claude-3.7 models with thinking enabled (decimal notation)', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3.7-sonnet',
topK: 10,
topP: 0.9,
thinking: true,
},
});
expect(result.llmConfig).not.toHaveProperty('topK');
expect(result.llmConfig).not.toHaveProperty('topP');
expect(result.llmConfig).toHaveProperty('thinking');
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
// When thinking is enabled, it uses the default thinkingBudget of 2000
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
});
it('should handle custom maxOutputTokens', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-opus',
maxOutputTokens: 2048,
},
});
expect(result.llmConfig).toHaveProperty('maxTokens', 2048);
});
it('should handle promptCache setting', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-5-sonnet',
promptCache: true,
},
});
// We're not checking specific header values since that depends on the actual helper function
// Just verifying that the promptCache setting is processed
expect(result.llmConfig).toBeDefined();
});
it('should include topK and topP for Claude-3.7 models when thinking is not enabled', () => {
// Test with thinking explicitly set to null/undefined
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
topK: 10,
topP: 0.9,
thinking: false,
},
});
expect(result.llmConfig).toHaveProperty('topK', 10);
expect(result.llmConfig).toHaveProperty('topP', 0.9);
// Test with thinking explicitly set to false
const result2 = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
topK: 10,
topP: 0.9,
thinking: false,
},
});
expect(result2.llmConfig).toHaveProperty('topK', 10);
expect(result2.llmConfig).toHaveProperty('topP', 0.9);
// Test with decimal notation as well
const result3 = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3.7-sonnet',
topK: 10,
topP: 0.9,
thinking: false,
},
});
expect(result3.llmConfig).toHaveProperty('topK', 10);
expect(result3.llmConfig).toHaveProperty('topP', 0.9);
});
describe('Edge cases', () => {
it('should handle missing apiKey', () => {
const result = getLLMConfig(undefined, { modelOptions: {} });
expect(result.llmConfig).not.toHaveProperty('apiKey');
});
it('should handle empty modelOptions', () => {
expect(() => {
getLLMConfig('test-api-key', {});
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
});
it('should handle no options parameter', () => {
expect(() => {
getLLMConfig('test-api-key');
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
});
it('should handle temperature, stop sequences, and stream settings', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
temperature: 0.7,
stop: ['\n\n', 'END'],
stream: false,
},
});
expect(result.llmConfig).toHaveProperty('temperature', 0.7);
expect(result.llmConfig).toHaveProperty('stopSequences', ['\n\n', 'END']);
expect(result.llmConfig).toHaveProperty('stream', false);
});
it('should handle maxOutputTokens when explicitly set to falsy value', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-opus',
maxOutputTokens: null,
},
});
// The actual anthropicSettings.maxOutputTokens.reset('claude-3-opus') returns 4096
expect(result.llmConfig).toHaveProperty('maxTokens', 4096);
});
it('should handle both proxy and reverseProxyUrl', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {},
proxy: 'http://proxy:8080',
reverseProxyUrl: 'https://reverse-proxy.com',
});
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
'ProxyAgent',
);
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'https://reverse-proxy.com');
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'https://reverse-proxy.com');
});
it('should handle prompt cache with supported model', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-5-sonnet',
promptCache: true,
},
});
// claude-3-5-sonnet supports prompt caching and should get the appropriate headers
expect(result.llmConfig.clientOptions.defaultHeaders).toEqual({
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31',
});
});
it('should handle thinking and thinkingBudget options', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
thinking: true,
thinkingBudget: 10000, // This exceeds the default max_tokens of 8192
},
});
// The function should add thinking configuration for claude-3-7 models
expect(result.llmConfig).toHaveProperty('thinking');
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
// With claude-3-7-sonnet, the max_tokens default is 8192
// Budget tokens gets adjusted to 90% of max_tokens (8192 * 0.9 = 7372) when it exceeds max_tokens
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 7372);
// Test with budget_tokens within max_tokens limit
const result2 = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
thinking: true,
thinkingBudget: 2000,
},
});
expect(result2.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
});
it('should remove system options from modelOptions', () => {
const modelOptions = {
model: 'claude-3-opus',
thinking: true,
promptCache: true,
thinkingBudget: 1000,
temperature: 0.5,
};
getLLMConfig('test-api-key', { modelOptions });
expect(modelOptions).not.toHaveProperty('thinking');
expect(modelOptions).not.toHaveProperty('promptCache');
expect(modelOptions).not.toHaveProperty('thinkingBudget');
expect(modelOptions).toHaveProperty('temperature', 0.5);
});
it('should handle all nullish values removal', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
temperature: null,
topP: undefined,
topK: 0,
stop: [],
},
});
expect(result.llmConfig).not.toHaveProperty('temperature');
expect(result.llmConfig).not.toHaveProperty('topP');
expect(result.llmConfig).toHaveProperty('topK', 0);
expect(result.llmConfig).toHaveProperty('stopSequences', []);
});
});
});

View File

@@ -1,4 +1,3 @@
const { getModelMaxTokens } = require('@librechat/api');
const { createContentAggregator } = require('@librechat/agents');
const {
EModelEndpoint,
@@ -8,6 +7,7 @@ const {
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
const getOptions = require('~/server/services/Endpoints/bedrock/options');
const AgentClient = require('~/server/controllers/agents/client');
const { getModelMaxTokens } = require('~/utils');
const initializeClient = async ({ req, res, endpointOption }) => {
if (!endpointOption) {

View File

@@ -36,12 +36,10 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
/** Intentionally excludes passing `body`, i.e. `req.body`, as
* values may not be accurate until `AgentClient` is initialized
*/
let resolvedHeaders = resolveHeaders({
headers: endpointConfig.headers,
user: req.user,
body: req.body,
});
if (CUSTOM_API_KEY.match(envVarRegex)) {

View File

@@ -76,10 +76,7 @@ describe('custom/initializeClient', () => {
expect(resolveHeaders).toHaveBeenCalledWith({
headers: { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
user: { id: 'user-123', email: 'test@example.com', role: 'user' },
/**
* Note: Request-based Header Resolution is deferred until right before LLM request is made
body: { endpoint: 'test-endpoint' }, // body - supports {{LIBRECHAT_BODY_*}} placeholders
*/
});
});

View File

@@ -159,11 +159,9 @@ class STTService {
* Prepares the request for the OpenAI STT provider.
* @param {Object} sttSchema - The STT schema for OpenAI.
* @param {Stream} audioReadStream - The audio data to be transcribed.
* @param {Object} audioFile - The audio file object (unused in OpenAI provider).
* @param {string} language - The language code for the transcription.
* @returns {Array} An array containing the URL, data, and headers for the request.
*/
openAIProvider(sttSchema, audioReadStream, audioFile, language) {
openAIProvider(sttSchema, audioReadStream) {
const url = sttSchema?.url || 'https://api.openai.com/v1/audio/transcriptions';
const apiKey = extractEnvVariable(sttSchema.apiKey) || '';
@@ -172,12 +170,6 @@ class STTService {
model: sttSchema.model,
};
if (language) {
/** Converted locale code (e.g., "en-US") to ISO-639-1 format (e.g., "en") */
const isoLanguage = language.split('-')[0];
data.language = isoLanguage;
}
const headers = {
'Content-Type': 'multipart/form-data',
...(apiKey && { Authorization: `Bearer ${apiKey}` }),
@@ -192,11 +184,10 @@ class STTService {
* @param {Object} sttSchema - The STT schema for Azure OpenAI.
* @param {Buffer} audioBuffer - The audio data to be transcribed.
* @param {Object} audioFile - The audio file object containing originalname, mimetype, and size.
* @param {string} language - The language code for the transcription.
* @returns {Array} An array containing the URL, data, and headers for the request.
* @throws {Error} If the audio file size exceeds 25MB or the audio file format is not accepted.
*/
azureOpenAIProvider(sttSchema, audioBuffer, audioFile, language) {
azureOpenAIProvider(sttSchema, audioBuffer, audioFile) {
const url = `${genAzureEndpoint({
azureOpenAIApiInstanceName: extractEnvVariable(sttSchema?.instanceName),
azureOpenAIApiDeploymentName: extractEnvVariable(sttSchema?.deploymentName),
@@ -220,12 +211,6 @@ class STTService {
contentType: audioFile.mimetype,
});
if (language) {
/** Converted locale code (e.g., "en-US") to ISO-639-1 format (e.g., "en") */
const isoLanguage = language.split('-')[0];
formData.append('language', isoLanguage);
}
const headers = {
'Content-Type': 'multipart/form-data',
...(apiKey && { 'api-key': apiKey }),
@@ -244,11 +229,10 @@ class STTService {
* @param {Object} requestData - The data required for the STT request.
* @param {Buffer} requestData.audioBuffer - The audio data to be transcribed.
* @param {Object} requestData.audioFile - The audio file object containing originalname, mimetype, and size.
* @param {string} requestData.language - The language code for the transcription.
* @returns {Promise<string>} A promise that resolves to the transcribed text.
* @throws {Error} If the provider is invalid, the response status is not 200, or the response data is missing.
*/
async sttRequest(provider, sttSchema, { audioBuffer, audioFile, language }) {
async sttRequest(provider, sttSchema, { audioBuffer, audioFile }) {
const strategy = this.providerStrategies[provider];
if (!strategy) {
throw new Error('Invalid provider');
@@ -259,13 +243,7 @@ class STTService {
const audioReadStream = Readable.from(audioBuffer);
audioReadStream.path = `audio.${fileExtension}`;
const [url, data, headers] = strategy.call(
this,
sttSchema,
audioReadStream,
audioFile,
language,
);
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);
try {
const response = await axios.post(url, data, { headers });
@@ -306,8 +284,7 @@ class STTService {
try {
const [provider, sttSchema] = await this.getProviderSchema(req);
const language = req.body?.language || '';
const text = await this.sttRequest(provider, sttSchema, { audioBuffer, audioFile, language });
const text = await this.sttRequest(provider, sttSchema, { audioBuffer, audioFile });
res.json({ text });
} catch (error) {
logger.error('An error occurred while processing the audio:', error);

View File

@@ -14,18 +14,16 @@ const { getProvider } = require('./TTSService');
*/
async function getVoices(req, res) {
try {
const appConfig =
req.config ??
(await getAppConfig({
role: req.user?.role,
}));
const appConfig = await getAppConfig({
role: req.user?.role,
});
const ttsSchema = appConfig?.speech?.tts;
if (!ttsSchema) {
if (!appConfig || !appConfig?.speech?.tts) {
throw new Error('Configuration or TTS schema is missing');
}
const provider = await getProvider(appConfig);
const ttsSchema = appConfig?.speech?.tts;
const provider = await getProvider(ttsSchema);
let voices;
switch (provider) {

View File

@@ -17,7 +17,7 @@ const { Files } = require('~/models');
* @param {IUser} options.user - The user object
* @param {AppConfig} options.appConfig - The app configuration object
* @param {GraphRunnableConfig['configurable']} options.metadata - The metadata
* @param {{ [Tools.file_search]: { sources: Object[]; fileCitations: boolean } }} options.toolArtifact - The tool artifact containing structured data
* @param {any} options.toolArtifact - The tool artifact containing structured data
* @param {string} options.toolCallId - The tool call ID
* @returns {Promise<Object|null>} The file search attachment or null
*/
@@ -29,14 +29,12 @@ async function processFileCitations({ user, appConfig, toolArtifact, toolCallId,
if (user) {
try {
const hasFileCitationsAccess =
toolArtifact?.[Tools.file_search]?.fileCitations ??
(await checkAccess({
user,
permissionType: PermissionTypes.FILE_CITATIONS,
permissions: [Permissions.USE],
getRoleByName,
}));
const hasFileCitationsAccess = await checkAccess({
user,
permissionType: PermissionTypes.FILE_CITATIONS,
permissions: [Permissions.USE],
getRoleByName,
});
if (!hasFileCitationsAccess) {
logger.debug(

View File

@@ -10,10 +10,9 @@ const { getAgent } = require('~/models/Agent');
* @param {string} [params.role] - Optional user role to avoid DB query
* @param {string[]} params.fileIds - Array of file IDs to check
* @param {string} params.agentId - The agent ID that might grant access
* @param {boolean} [params.isDelete] - Whether the operation is a delete operation
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
*/
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => {
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId }) => {
const accessMap = new Map();
// Initialize all files as no access
@@ -45,23 +44,22 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
return accessMap;
}
if (isDelete) {
// Check if user has EDIT permission (which would indicate collaborative access)
const hasEditPermission = await checkPermission({
userId,
role,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
requiredPermission: PermissionBits.EDIT,
});
// Check if user has EDIT permission (which would indicate collaborative access)
const hasEditPermission = await checkPermission({
userId,
role,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
requiredPermission: PermissionBits.EDIT,
});
// If user only has VIEW permission, they can't access files
// Only users with EDIT permission or higher can access agent files
if (!hasEditPermission) {
return accessMap;
}
// If user only has VIEW permission, they can't access files
// Only users with EDIT permission or higher can access agent files
if (!hasEditPermission) {
return accessMap;
}
// User has edit permissions - check which files are actually attached
const attachedFileIds = new Set();
if (agent.tool_resources) {
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {

View File

@@ -552,7 +552,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
throw new Error('File search is not enabled for Agents');
}
// Note: File search processing continues to dual storage logic below
} else if (tool_resource === EToolResources.context) {
} else if (tool_resource === EToolResources.ocr) {
const { file_id, temp_file_id = null } = metadata;
/**
@@ -594,9 +594,10 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
const fileConfig = mergeFileConfig(appConfig.fileConfig);
const shouldUseOCR =
appConfig?.ocr != null &&
fileConfig.checkType(file.mimetype, fileConfig.ocr?.supportedMimeTypes || []);
const shouldUseOCR = fileConfig.checkType(
file.mimetype,
fileConfig.ocr?.supportedMimeTypes || [],
);
if (shouldUseOCR && !(await checkCapability(req, AgentCapabilities.ocr))) {
throw new Error('OCR capability is not enabled for Agents');
@@ -615,7 +616,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
if (shouldUseSTT) {
const sttService = await STTService.getInstance();
const { text, bytes } = await processAudioFile({ req, file, sttService });
const { text, bytes } = await processAudioFile({ file, sttService });
return await createTextFile({ text, bytes });
}
@@ -625,7 +626,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
);
if (!shouldUseText) {
throw new Error(`File type ${file.mimetype} is not supported for text parsing.`);
throw new Error(`File type ${file.mimetype} is not supported for OCR or text parsing`);
}
const { text, bytes } = await parseText({ req, file, file_id });
@@ -645,8 +646,8 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
req,
file,
file_id,
basePath,
entity_id,
basePath,
});
// SECOND: Upload to Vector DB
@@ -669,18 +670,17 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
req,
file,
file_id,
basePath,
entity_id,
basePath,
});
}
let { bytes, filename, filepath: _filepath, height, width } = storageResult;
const { bytes, filename, filepath: _filepath, height, width } = storageResult;
// For RAG files, use embedding result; for others, use storage result
let embedded = storageResult.embedded;
if (tool_resource === EToolResources.file_search) {
embedded = embeddingResult?.embedded;
filename = embeddingResult?.filename || filename;
}
const embedded =
tool_resource === EToolResources.file_search
? embeddingResult?.embedded
: storageResult.embedded;
let filepath = _filepath;
@@ -929,7 +929,6 @@ async function saveBase64Image(
url,
{ req, file_id: _file_id, filename: _filename, endpoint, context, resolution },
) {
const appConfig = req.config;
const effectiveResolution = resolution ?? appConfig.fileConfig?.imageGeneration ?? 'high';
const file_id = _file_id ?? v4();
let filename = `${file_id}-${_filename}`;
@@ -944,6 +943,7 @@ async function saveBase64Image(
}
const image = await resizeImageBuffer(inputBuffer, effectiveResolution, endpoint);
const appConfig = req.config;
const source = getFileStrategy(appConfig, { isImage: true });
const { saveBuffer } = getStrategyFunctions(source);
const filepath = await saveBuffer({

View File

@@ -20,10 +20,10 @@ const {
ContentTypes,
isAssistantsEndpoint,
} = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, getAppConfig } = require('./Config');
const { reinitMCPServer } = require('./Tools/mcp');
const { getAppConfig } = require('./Config');
const { getLogStores } = require('~/cache');
/**
@@ -152,8 +152,8 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
/**
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {IUser} params.user - The user from the request object.
* @param {string} params.serverName
* @param {AbortSignal} params.signal
* @param {string} params.model
@@ -161,9 +161,9 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap }) {
async function reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }) {
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
const flowId = `${user.id}:${serverName}:${Date.now()}`;
const flowId = `${req.user?.id}:${serverName}:${Date.now()}`;
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
@@ -192,7 +192,7 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
flowManager,
});
return await reinitMCPServer({
user,
req,
signal,
serverName,
oauthStart,
@@ -211,8 +211,8 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
* i.e. `availableTools`, and will reinitialize the MCP server to ensure all tools are generated.
*
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {IUser} params.user - The user from the request object.
* @param {string} params.serverName
* @param {string} params.model
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
@@ -221,8 +221,8 @@ async function reconnectServer({ res, user, index, signal, serverName, userMCPAu
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function createMCPTools({ res, user, index, signal, serverName, provider, userMCPAuthMap }) {
const result = await reconnectServer({ res, user, index, signal, serverName, userMCPAuthMap });
async function createMCPTools({ req, res, index, signal, serverName, provider, userMCPAuthMap }) {
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
return;
@@ -231,8 +231,8 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
const serverTools = [];
for (const tool of result.tools) {
const toolInstance = await createMCPTool({
req,
res,
user,
provider,
userMCPAuthMap,
availableTools: result.availableTools,
@@ -249,8 +249,8 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
/**
* Creates a single tool from the specified MCP Server via `toolKey`.
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {IUser} params.user - The user from the request object.
* @param {string} params.toolKey - The toolKey for the tool.
* @param {string} params.model - The model for the tool.
* @param {number} [params.index]
@@ -261,31 +261,25 @@ async function createMCPTools({ res, user, index, signal, serverName, provider,
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createMCPTool({
req,
res,
user,
index,
signal,
toolKey,
provider,
userMCPAuthMap,
availableTools,
availableTools: tools,
}) {
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const availableTools =
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
/** @type {LCTool | undefined} */
let toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) {
logger.warn(
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
);
const result = await reconnectServer({
res,
user,
index,
signal,
serverName,
userMCPAuthMap,
});
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
toolDefinition = result?.availableTools?.[toolKey]?.function;
}
@@ -442,10 +436,10 @@ async function getMCPSetupData(userId) {
}
const mcpManager = getMCPManager(userId);
/** @type {Map<string, import('@librechat/api').MCPConnection>} */
/** @type {ReturnType<MCPManager['getAllConnections']>} */
let appConnections = new Map();
try {
appConnections = (await mcpManager.appConnections?.getAll()) || new Map();
appConnections = (await mcpManager.getAllConnections()) || new Map();
} catch (error) {
logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error);
}
@@ -543,20 +537,13 @@ async function getServerConnectionStatus(
const baseConnectionState = getConnectionState();
let finalConnectionState = baseConnectionState;
// connection state overrides specific to OAuth servers
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
// check if server is actively being reconnected
const oauthReconnectionManager = getOAuthReconnectionManager();
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
finalConnectionState = 'connecting';
} else {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
finalConnectionState = 'connecting';
}
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
finalConnectionState = 'connecting';
}
}

View File

@@ -1,45 +1,13 @@
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const {
createMCPTool,
createMCPTools,
getMCPSetupData,
checkOAuthFlowStatus,
getServerConnectionStatus,
} = require('./MCP');
const { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus } = require('./MCP');
// Mock all dependencies
jest.mock('@librechat/data-schemas', () => ({
logger: {
debug: jest.fn(),
error: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
},
}));
jest.mock('@langchain/core/tools', () => ({
tool: jest.fn((fn, config) => {
const toolInstance = { _call: fn, ...config };
return toolInstance;
}),
}));
jest.mock('@librechat/agents', () => ({
Providers: {
VERTEXAI: 'vertexai',
GOOGLE: 'google',
},
StepTypes: {
TOOL_CALLS: 'tool_calls',
},
GraphEvents: {
ON_RUN_STEP_DELTA: 'on_run_step_delta',
ON_RUN_STEP: 'on_run_step',
},
Constants: {
CONTENT_AND_ARTIFACT: 'content_and_artifact',
},
}));
@@ -47,27 +15,12 @@ jest.mock('@librechat/api', () => ({
MCPOAuthHandler: {
generateFlowId: jest.fn(),
},
sendEvent: jest.fn(),
normalizeServerName: jest.fn((name) => name),
convertWithResolvedRefs: jest.fn((params) => params),
}));
jest.mock('librechat-data-provider', () => ({
CacheKeys: {
FLOWS: 'flows',
},
Constants: {
USE_PRELIM_RESPONSE_MESSAGE_ID: 'prelim_response_id',
mcp_delimiter: '::',
mcp_prefix: 'mcp_',
},
ContentTypes: {
TEXT: 'text',
},
isAssistantsEndpoint: jest.fn(() => false),
Time: {
TWO_MINUTES: 120000,
},
}));
jest.mock('./Config', () => ({
@@ -78,7 +31,6 @@ jest.mock('./Config', () => ({
jest.mock('~/config', () => ({
getMCPManager: jest.fn(),
getFlowStateManager: jest.fn(),
getOAuthReconnectionManager: jest.fn(),
}));
jest.mock('~/cache', () => ({
@@ -91,23 +43,19 @@ jest.mock('~/models', () => ({
updateToken: jest.fn(),
}));
jest.mock('./Tools/mcp', () => ({
reinitMCPServer: jest.fn(),
}));
describe('tests for the new helper functions used by the MCP connection status endpoints', () => {
let mockLoadCustomConfig;
let mockGetMCPManager;
let mockGetFlowStateManager;
let mockGetLogStores;
let mockGetOAuthReconnectionManager;
beforeEach(() => {
jest.clearAllMocks();
mockLoadCustomConfig = require('./Config').loadCustomConfig;
mockGetMCPManager = require('~/config').getMCPManager;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
});
describe('getMCPSetupData', () => {
@@ -123,7 +71,7 @@ describe('tests for the new helper functions used by the MCP connection status e
beforeEach(() => {
mockGetAppConfig = require('./Config').getAppConfig;
mockGetMCPManager.mockReturnValue({
appConnections: { getAll: jest.fn(() => new Map()) },
getAllConnections: jest.fn(() => new Map()),
getUserConnections: jest.fn(() => new Map()),
getOAuthServers: jest.fn(() => new Set()),
});
@@ -137,7 +85,7 @@ describe('tests for the new helper functions used by the MCP connection status e
const mockOAuthServers = new Set(['server2']);
const mockMCPManager = {
appConnections: { getAll: jest.fn(() => mockAppConnections) },
getAllConnections: jest.fn(() => mockAppConnections),
getUserConnections: jest.fn(() => mockUserConnections),
getOAuthServers: jest.fn(() => mockOAuthServers),
};
@@ -147,7 +95,7 @@ describe('tests for the new helper functions used by the MCP connection status e
expect(mockGetAppConfig).toHaveBeenCalled();
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled();
expect(mockMCPManager.getAllConnections).toHaveBeenCalled();
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
expect(mockMCPManager.getOAuthServers).toHaveBeenCalled();
@@ -168,7 +116,7 @@ describe('tests for the new helper functions used by the MCP connection status e
mockGetAppConfig.mockResolvedValue({ mcpConfig: mockConfig.mcpServers });
const mockMCPManager = {
appConnections: { getAll: jest.fn(() => null) },
getAllConnections: jest.fn(() => null),
getUserConnections: jest.fn(() => null),
getOAuthServers: jest.fn(() => null),
};
@@ -406,12 +354,6 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
@@ -428,12 +370,6 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return failed flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@@ -465,12 +401,6 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return active flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@@ -502,12 +432,6 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return no flow
const mockFlowManager = {
getFlowState: jest.fn(() => null),
@@ -530,35 +454,6 @@ describe('tests for the new helper functions used by the MCP connection status e
});
});
it('should return connecting state when OAuth server is reconnecting', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager to return true for isReconnecting
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => true),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: true,
connectionState: 'connecting',
});
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
mockUserId,
mockServerName,
);
});
it('should not check OAuth flow status when server is connected', async () => {
const mockFlowManager = {
getFlowState: jest.fn(),
@@ -616,275 +511,3 @@ describe('tests for the new helper functions used by the MCP connection status e
});
});
});
describe('User parameter passing tests', () => {
let mockReinitMCPServer;
let mockGetFlowStateManager;
let mockGetLogStores;
beforeEach(() => {
jest.clearAllMocks();
mockReinitMCPServer = require('./Tools/mcp').reinitMCPServer;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
// Setup default mocks
mockGetLogStores.mockReturnValue({});
mockGetFlowStateManager.mockReturnValue({
createFlowWithHandler: jest.fn(),
failFlow: jest.fn(),
});
});
describe('createMCPTools', () => {
it('should pass user parameter to reinitMCPServer when calling reconnectServer internally', async () => {
const mockUser = { id: 'test-user-123', name: 'Test User' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
const mockSignal = new AbortController().signal;
mockReinitMCPServer.mockResolvedValue({
tools: [{ name: 'test-tool' }],
availableTools: {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
},
});
await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'test-server',
provider: 'openai',
signal: mockSignal,
userMCPAuthMap: {},
});
// Verify reinitMCPServer was called with the user
expect(mockReinitMCPServer).toHaveBeenCalledWith(
expect.objectContaining({
user: mockUser,
serverName: 'test-server',
}),
);
expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser);
});
it('should throw error if user is not provided', async () => {
const mockRes = { write: jest.fn(), flush: jest.fn() };
mockReinitMCPServer.mockResolvedValue({
tools: [],
availableTools: {},
});
// Call without user should throw error
await expect(
createMCPTools({
res: mockRes,
user: undefined,
serverName: 'test-server',
provider: 'openai',
userMCPAuthMap: {},
}),
).rejects.toThrow("Cannot read properties of undefined (reading 'id')");
// Verify reinitMCPServer was not called due to early error
expect(mockReinitMCPServer).not.toHaveBeenCalled();
});
});
describe('createMCPTool', () => {
it('should pass user parameter to reinitMCPServer when tool not in cache', async () => {
const mockUser = { id: 'test-user-456', email: 'test@example.com' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
const mockSignal = new AbortController().signal;
mockReinitMCPServer.mockResolvedValue({
availableTools: {
'test-tool::test-server': {
function: {
description: 'Test tool',
parameters: { type: 'object', properties: {} },
},
},
},
});
// Call without availableTools to trigger reinit
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
signal: mockSignal,
userMCPAuthMap: {},
availableTools: undefined, // Force reinit
});
// Verify reinitMCPServer was called with the user
expect(mockReinitMCPServer).toHaveBeenCalledWith(
expect.objectContaining({
user: mockUser,
serverName: 'test-server',
}),
);
expect(mockReinitMCPServer.mock.calls[0][0].user).toBe(mockUser);
});
it('should not call reinitMCPServer when tool is in cache', async () => {
const mockUser = { id: 'test-user-789' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
const availableTools = {
'test-tool::test-server': {
function: {
description: 'Cached tool',
parameters: { type: 'object', properties: {} },
},
},
};
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'test-tool::test-server',
provider: 'openai',
userMCPAuthMap: {},
availableTools: availableTools,
});
// Verify reinitMCPServer was NOT called since tool was in cache
expect(mockReinitMCPServer).not.toHaveBeenCalled();
});
});
describe('reinitMCPServer (via reconnectServer)', () => {
it('should always receive user parameter when called from createMCPTools', async () => {
const mockUser = { id: 'user-001', role: 'admin' };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Track all calls to reinitMCPServer
const reinitCalls = [];
mockReinitMCPServer.mockImplementation((params) => {
reinitCalls.push(params);
return Promise.resolve({
tools: [{ name: 'tool1' }, { name: 'tool2' }],
availableTools: {
'tool1::server1': { function: { description: 'Tool 1', parameters: {} } },
'tool2::server1': { function: { description: 'Tool 2', parameters: {} } },
},
});
});
await createMCPTools({
res: mockRes,
user: mockUser,
serverName: 'server1',
provider: 'anthropic',
userMCPAuthMap: {},
});
// Verify all calls to reinitMCPServer had the user
expect(reinitCalls.length).toBeGreaterThan(0);
reinitCalls.forEach((call) => {
expect(call.user).toBe(mockUser);
expect(call.user.id).toBe('user-001');
});
});
it('should always receive user parameter when called from createMCPTool', async () => {
const mockUser = { id: 'user-002', permissions: ['read', 'write'] };
const mockRes = { write: jest.fn(), flush: jest.fn() };
// Track all calls to reinitMCPServer
const reinitCalls = [];
mockReinitMCPServer.mockImplementation((params) => {
reinitCalls.push(params);
return Promise.resolve({
availableTools: {
'my-tool::my-server': {
function: { description: 'My Tool', parameters: {} },
},
},
});
});
await createMCPTool({
res: mockRes,
user: mockUser,
toolKey: 'my-tool::my-server',
provider: 'google',
userMCPAuthMap: {},
availableTools: undefined, // Force reinit
});
// Verify the call to reinitMCPServer had the user
expect(reinitCalls.length).toBe(1);
expect(reinitCalls[0].user).toBe(mockUser);
expect(reinitCalls[0].user.id).toBe('user-002');
});
});
describe('User parameter integrity', () => {
it('should preserve user object properties through the call chain', async () => {
const complexUser = {
id: 'complex-user',
name: 'John Doe',
email: 'john@example.com',
metadata: { subscription: 'premium', settings: { theme: 'dark' } },
};
const mockRes = { write: jest.fn(), flush: jest.fn() };
let capturedUser = null;
mockReinitMCPServer.mockImplementation((params) => {
capturedUser = params.user;
return Promise.resolve({
tools: [{ name: 'test' }],
availableTools: {
'test::server': { function: { description: 'Test', parameters: {} } },
},
});
});
await createMCPTools({
res: mockRes,
user: complexUser,
serverName: 'server',
provider: 'openai',
userMCPAuthMap: {},
});
// Verify the complete user object was passed
expect(capturedUser).toEqual(complexUser);
expect(capturedUser.id).toBe('complex-user');
expect(capturedUser.metadata.subscription).toBe('premium');
expect(capturedUser.metadata.settings.theme).toBe('dark');
});
it('should throw error when user is null', async () => {
const mockRes = { write: jest.fn(), flush: jest.fn() };
mockReinitMCPServer.mockResolvedValue({
tools: [],
availableTools: {},
});
await expect(
createMCPTools({
res: mockRes,
user: null,
serverName: 'test-server',
provider: 'openai',
userMCPAuthMap: {},
}),
).rejects.toThrow("Cannot read properties of null (reading 'id')");
// Verify reinitMCPServer was not called due to early error
expect(mockReinitMCPServer).not.toHaveBeenCalled();
});
});
});

View File

@@ -1,13 +1,13 @@
const axios = require('axios');
const { Providers } = require('@librechat/agents');
const { logAxiosError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { logAxiosError, inputSchema, processModelData } = require('@librechat/api');
const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
const { inputSchema, extractBaseURL, processModelData } = require('~/utils');
const { OllamaClient } = require('~/app/clients/OllamaClient');
const { isUserProvided } = require('~/server/utils');
const getLogStores = require('~/cache/getLogStores');
const { extractBaseURL } = require('~/utils');
/**
* Splits a string by commas and trims each resulting value.

View File

@@ -11,8 +11,8 @@ const {
getAnthropicModels,
} = require('./ModelService');
jest.mock('@librechat/api', () => {
const originalUtils = jest.requireActual('@librechat/api');
jest.mock('~/utils', () => {
const originalUtils = jest.requireActual('~/utils');
return {
...originalUtils,
processModelData: jest.fn((...args) => {
@@ -108,7 +108,7 @@ describe('fetchModels with createTokenConfig true', () => {
beforeEach(() => {
// Clears the mock's history before each test
const _utils = require('@librechat/api');
const _utils = require('~/utils');
axios.get.mockResolvedValue({ data });
});
@@ -120,7 +120,7 @@ describe('fetchModels with createTokenConfig true', () => {
createTokenConfig: true,
});
const { processModelData } = require('@librechat/api');
const { processModelData } = require('~/utils');
expect(processModelData).toHaveBeenCalled();
expect(processModelData).toHaveBeenCalledWith(data);
});

View File

@@ -313,7 +313,7 @@ const ensurePrincipalExists = async function (principal) {
idOnTheSource: principal.idOnTheSource,
};
const userId = await createUser(userData, true, true);
const userId = await createUser(userData, true, false);
return userId.toString();
}

View File

@@ -1,12 +1,7 @@
const { sleep } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { tool: toolFn, DynamicStructuredTool } = require('@langchain/core/tools');
const {
getToolkitKey,
hasCustomUserVars,
getUserMCPAuthMap,
isActionDomainAllowed,
} = require('@librechat/api');
const { getToolkitKey, hasCustomUserVars, getUserMCPAuthMap } = require('@librechat/api');
const {
Tools,
Constants,
@@ -31,6 +26,7 @@ const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/p
const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config');
const { manifestToolMap, toolkits } = require('~/app/clients/tools/manifest');
const { createOnSearchResults } = require('~/server/services/Tools/search');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util');
const { redactMessage } = require('~/config/parsers');
@@ -78,7 +74,7 @@ async function processRequiredActions(client, requiredActions) {
requiredActions,
);
const appConfig = client.req.config;
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ userId: client.req.user.id, includeGlobal: true });
const seenToolkits = new Set();
const tools = requiredActions
.map((action) => {
@@ -357,12 +353,7 @@ async function processRequiredActions(client, requiredActions) {
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
if (!agent.tools || agent.tools.length === 0) {
return {};
} else if (
agent.tools &&
agent.tools.length === 1 &&
/** Legacy handling for `ocr` as may still exist in existing Agents */
(agent.tools[0] === AgentCapabilities.context || agent.tools[0] === AgentCapabilities.ocr)
) {
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
return {};
}

View File

@@ -2,12 +2,12 @@ const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { findToken, createToken, updateToken, deleteTokens } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { updateMCPServerTools } = require('~/server/services/Config');
const { updateMCPUserTools } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
/**
* @param {Object} params
* @param {IUser} params.user - The user from the request object.
* @param {ServerRequest} params.req
* @param {string} params.serverName - The name of the MCP server
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
@@ -18,7 +18,7 @@ const { getLogStores } = require('~/cache');
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
*/
async function reinitMCPServer({
user,
req,
signal,
forceNew,
serverName,
@@ -29,7 +29,7 @@ async function reinitMCPServer({
flowManager: _flowManager,
}) {
/** @type {MCPConnection | null} */
let connection = null;
let userConnection = null;
/** @type {LCAvailableTools | null} */
let availableTools = null;
/** @type {ReturnType<MCPConnection['fetchTools']> | null} */
@@ -44,14 +44,14 @@ async function reinitMCPServer({
const oauthStart =
_oauthStart ??
(async (authURL) => {
logger.info(`[MCP Reinitialize] OAuth URL received for ${serverName}`);
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
oauthUrl = authURL;
oauthRequired = true;
});
try {
connection = await mcpManager.getConnection({
user,
userConnection = await mcpManager.getUserConnection({
user: req.user,
signal,
forceNew,
oauthStart,
@@ -70,7 +70,7 @@ async function reinitMCPServer({
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
} catch (err) {
logger.info(`[MCP Reinitialize] getConnection threw error: ${err.message}`);
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
logger.info(
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
);
@@ -95,9 +95,10 @@ async function reinitMCPServer({
}
}
if (connection && !oauthRequired) {
tools = await connection.fetchTools();
availableTools = await updateMCPServerTools({
if (userConnection && !oauthRequired) {
tools = await userConnection.fetchTools();
availableTools = await updateMCPUserTools({
userId: req.user.id,
serverName,
tools,
});
@@ -111,7 +112,7 @@ async function reinitMCPServer({
if (oauthRequired) {
return `MCP server '${serverName}' ready for OAuth authentication`;
}
if (connection) {
if (userConnection) {
return `MCP server '${serverName}' reinitialized successfully`;
}
return `Failed to reinitialize MCP server '${serverName}'`;
@@ -119,7 +120,7 @@ async function reinitMCPServer({
const result = {
availableTools,
success: Boolean((connection && !oauthRequired) || (oauthRequired && oauthUrl)),
success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)),
message: getResponseMessage(),
oauthRequired,
serverName,

View File

@@ -1,13 +1,14 @@
/**
* @param email
* @param allowedDomains
* @param {string} email
* @param {string[]} [allowedDomains]
* @returns {boolean}
*/
export function isEmailDomainAllowed(email: string, allowedDomains?: string[] | null): boolean {
function isEmailDomainAllowed(email, allowedDomains) {
if (!email) {
return false;
}
const domain = email.split('@')[1]?.toLowerCase();
const domain = email.split('@')[1];
if (!domain) {
return false;
@@ -19,15 +20,21 @@ export function isEmailDomainAllowed(email: string, allowedDomains?: string[] |
return true;
}
return allowedDomains.some((allowedDomain) => allowedDomain?.toLowerCase() === domain);
return allowedDomains.includes(domain);
}
/**
* Normalizes a domain string
* @param {string} domain
* @returns {string|null}
*/
/**
* Normalizes a domain string. If the domain is invalid, returns null.
* Normalized === lowercase, trimmed, and protocol added if missing.
* @param domain
* @param {string} domain
* @returns {string|null}
*/
function normalizeDomain(domain: string): string | null {
function normalizeDomain(domain) {
try {
let normalizedDomain = domain.toLowerCase().trim();
@@ -55,13 +62,11 @@ function normalizeDomain(domain: string): string | null {
/**
* Checks if the given domain is allowed. If no restrictions are set, allows all domains.
* @param domain
* @param allowedDomains
* @param {string} [domain]
* @param {string[]} [allowedDomains]
* @returns {Promise<boolean>}
*/
export async function isActionDomainAllowed(
domain?: string | null,
allowedDomains?: string[] | null,
): Promise<boolean> {
async function isActionDomainAllowed(domain, allowedDomains) {
if (!domain || typeof domain !== 'string') {
return false;
}
@@ -96,3 +101,5 @@ export async function isActionDomainAllowed(
return false;
}
module.exports = { isEmailDomainAllowed, isActionDomainAllowed };

View File

@@ -1,5 +1,9 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { isEmailDomainAllowed, isActionDomainAllowed } from './domain';
const { isEmailDomainAllowed, isActionDomainAllowed } = require('~/server/services/domains');
const { getAppConfig } = require('~/server/services/Config');
jest.mock('~/server/services/Config', () => ({
getAppConfig: jest.fn(),
}));
describe('isEmailDomainAllowed', () => {
afterEach(() => {
@@ -20,72 +24,39 @@ describe('isEmailDomainAllowed', () => {
it('should return true if customConfig is not available', async () => {
const email = 'test@domain1.com';
getAppConfig.mockResolvedValue(null);
const result = isEmailDomainAllowed(email, null);
expect(result).toBe(true);
});
it('should return true if allowedDomains is not defined in customConfig', async () => {
const email = 'test@domain1.com';
getAppConfig.mockResolvedValue({});
const result = isEmailDomainAllowed(email, undefined);
expect(result).toBe(true);
});
it('should return true if domain is included in the allowedDomains', async () => {
const email = 'user@domain1.com';
getAppConfig.mockResolvedValue({
registration: {
allowedDomains: ['domain1.com', 'domain2.com'],
},
});
const result = isEmailDomainAllowed(email, ['domain1.com', 'domain2.com']);
expect(result).toBe(true);
});
it('should return false if domain is not included in the allowedDomains', async () => {
const email = 'user@domain3.com';
getAppConfig.mockResolvedValue({
registration: {
allowedDomains: ['domain1.com', 'domain2.com'],
},
});
const result = isEmailDomainAllowed(email, ['domain1.com', 'domain2.com']);
expect(result).toBe(false);
});
describe('case-insensitive domain matching', () => {
it('should match domains case-insensitively when email has uppercase domain', () => {
const email = 'user@DOMAIN1.COM';
const result = isEmailDomainAllowed(email, ['domain1.com', 'domain2.com']);
expect(result).toBe(true);
});
it('should match domains case-insensitively when allowedDomains has uppercase', () => {
const email = 'user@domain1.com';
const result = isEmailDomainAllowed(email, ['DOMAIN1.COM', 'DOMAIN2.COM']);
expect(result).toBe(true);
});
it('should match domains with mixed case in email', () => {
const email = 'user@Example.Com';
const result = isEmailDomainAllowed(email, ['example.com', 'domain2.com']);
expect(result).toBe(true);
});
it('should match domains with mixed case in allowedDomains', () => {
const email = 'user@example.com';
const result = isEmailDomainAllowed(email, ['Example.Com', 'Domain2.Com']);
expect(result).toBe(true);
});
it('should match when both email and allowedDomains have different mixed cases', () => {
const email = 'user@ExAmPlE.cOm';
const result = isEmailDomainAllowed(email, ['eXaMpLe.CoM', 'domain2.com']);
expect(result).toBe(true);
});
it('should still return false for non-matching domains regardless of case', () => {
const email = 'user@DOMAIN3.COM';
const result = isEmailDomainAllowed(email, ['domain1.com', 'DOMAIN2.COM']);
expect(result).toBe(false);
});
it('should handle null/undefined entries in allowedDomains gracefully', () => {
const email = 'user@domain1.com';
// @ts-expect-error Testing invalid input
const result = isEmailDomainAllowed(email, [null, 'DOMAIN1.COM', undefined]);
expect(result).toBe(true);
});
});
});
describe('isActionDomainAllowed', () => {
@@ -103,15 +74,15 @@ describe('isActionDomainAllowed', () => {
});
it('should return false for non-string inputs', async () => {
/** @ts-expect-error */
expect(await isActionDomainAllowed(123)).toBe(false);
/** @ts-expect-error */
expect(await isActionDomainAllowed({})).toBe(false);
/** @ts-expect-error */
expect(await isActionDomainAllowed([])).toBe(false);
});
it('should return false for invalid domain formats', async () => {
getAppConfig.mockResolvedValue({
actions: { allowedDomains: ['http://', 'https://'] },
});
expect(await isActionDomainAllowed('http://', ['http://', 'https://'])).toBe(false);
expect(await isActionDomainAllowed('https://', ['http://', 'https://'])).toBe(false);
});
@@ -120,14 +91,19 @@ describe('isActionDomainAllowed', () => {
// Configuration Tests
describe('configuration handling', () => {
it('should return true if customConfig is null', async () => {
getAppConfig.mockResolvedValue(null);
expect(await isActionDomainAllowed('example.com', null)).toBe(true);
});
it('should return true if actions.allowedDomains is not defined', async () => {
getAppConfig.mockResolvedValue({});
expect(await isActionDomainAllowed('example.com', undefined)).toBe(true);
});
it('should return true if allowedDomains is empty array', async () => {
getAppConfig.mockResolvedValue({
actions: { allowedDomains: [] },
});
expect(await isActionDomainAllowed('example.com', [])).toBe(true);
});
});
@@ -142,6 +118,14 @@ describe('isActionDomainAllowed', () => {
'swapi.dev',
];
beforeEach(() => {
getAppConfig.mockResolvedValue({
actions: {
allowedDomains,
},
});
});
it('should match exact domains', async () => {
expect(await isActionDomainAllowed('example.com', allowedDomains)).toBe(true);
expect(await isActionDomainAllowed('other.com', allowedDomains)).toBe(false);
@@ -175,6 +159,14 @@ describe('isActionDomainAllowed', () => {
describe('edge cases', () => {
const edgeAllowedDomains = ['example.com', '*.test.com'];
beforeEach(() => {
getAppConfig.mockResolvedValue({
actions: {
allowedDomains: edgeAllowedDomains,
},
});
});
it('should handle domains with query parameters', async () => {
expect(await isActionDomainAllowed('example.com?param=value', edgeAllowedDomains)).toBe(true);
});
@@ -194,9 +186,12 @@ describe('isActionDomainAllowed', () => {
it('should handle invalid entries in allowedDomains', async () => {
const invalidAllowedDomains = ['example.com', null, undefined, '', 'test.com'];
/** @ts-expect-error */
getAppConfig.mockResolvedValue({
actions: {
allowedDomains: invalidAllowedDomains,
},
});
expect(await isActionDomainAllowed('example.com', invalidAllowedDomains)).toBe(true);
/** @ts-expect-error */
expect(await isActionDomainAllowed('test.com', invalidAllowedDomains)).toBe(true);
});
});

View File

@@ -1,26 +0,0 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { getLogStores } = require('~/cache');
/**
* Initialize OAuth reconnect manager
*/
async function initializeOAuthReconnectManager() {
try {
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const tokenMethods = {
findToken,
updateToken,
createToken,
deleteTokens,
};
await createOAuthReconnectionManager(flowManager, tokenMethods);
logger.info(`OAuth reconnect manager initialized successfully.`);
} catch (error) {
logger.error('Failed to initialize OAuth reconnect manager:', error);
}
}
module.exports = initializeOAuthReconnectManager;

View File

@@ -229,7 +229,7 @@
>
<!--[if mso]><style>.v-button {background: transparent !important;}</style><![endif]-->
<div align='left'>
<!--[if mso]><v:roundrect xmlns:v="urn:schemas-microsoft-com:vml" xmlns:w="urn:schemas-microsoft-com:office:word" href="{{verificationLink}}" style="height:37px; v-text-anchor:middle; width:114px;" arcsize="11%" stroke="f" fillcolor="#10a37f"><w:anchorlock/><center style="color:#FFFFFF;"><![endif]-->
<!--[if mso]><v:roundrect xmlns:v="urn:schemas-microsoft-com:vml" xmlns:w="urn:schemas-microsoft-com:office:word" href="href=&quot;{{verificationLink}}&quot;" style="height:37px; v-text-anchor:middle; width:114px;" arcsize="11%" stroke="f" fillcolor="#10a37f"><w:anchorlock/><center style="color:#FFFFFF;"><![endif]-->
<a
href='{{verificationLink}}'
target='_blank'

View File

@@ -10,10 +10,6 @@ jest.mock('~/models/Message', () => ({
bulkSaveMessages: jest.fn(),
}));
jest.mock('~/models/ConversationTag', () => ({
bulkIncrementTagCounts: jest.fn(),
}));
let mockIdCounter = 0;
jest.mock('uuid', () => {
return {
@@ -26,13 +22,11 @@ jest.mock('uuid', () => {
const {
forkConversation,
duplicateConversation,
splitAtTargetLevel,
getAllMessagesUpToParent,
getMessagesUpToTargetLevel,
cloneMessagesWithTimestamps,
} = require('./fork');
const { bulkIncrementTagCounts } = require('~/models/ConversationTag');
const { getConvo, bulkSaveConvos } = require('~/models/Conversation');
const { getMessages, bulkSaveMessages } = require('~/models/Message');
const { createImportBatchBuilder } = require('./importBatchBuilder');
@@ -187,120 +181,6 @@ describe('forkConversation', () => {
}),
).rejects.toThrow('Failed to fetch messages');
});
test('should increment tag counts when forking conversation with tags', async () => {
const mockConvoWithTags = {
...mockConversation,
tags: ['bookmark1', 'bookmark2'],
};
getConvo.mockResolvedValue(mockConvoWithTags);
await forkConversation({
originalConvoId: 'abc123',
targetMessageId: '3',
requestUserId: 'user1',
option: ForkOptions.DIRECT_PATH,
});
// Verify that bulkIncrementTagCounts was called with correct tags
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', ['bookmark1', 'bookmark2']);
});
test('should handle conversation without tags when forking', async () => {
const mockConvoWithoutTags = {
...mockConversation,
// No tags field
};
getConvo.mockResolvedValue(mockConvoWithoutTags);
await forkConversation({
originalConvoId: 'abc123',
targetMessageId: '3',
requestUserId: 'user1',
option: ForkOptions.DIRECT_PATH,
});
// bulkIncrementTagCounts will be called with array containing undefined
expect(bulkIncrementTagCounts).toHaveBeenCalled();
});
test('should handle empty tags array when forking', async () => {
const mockConvoWithEmptyTags = {
...mockConversation,
tags: [],
};
getConvo.mockResolvedValue(mockConvoWithEmptyTags);
await forkConversation({
originalConvoId: 'abc123',
targetMessageId: '3',
requestUserId: 'user1',
option: ForkOptions.DIRECT_PATH,
});
// bulkIncrementTagCounts will be called with empty array
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', []);
});
});
describe('duplicateConversation', () => {
beforeEach(() => {
jest.clearAllMocks();
mockIdCounter = 0;
getConvo.mockResolvedValue(mockConversation);
getMessages.mockResolvedValue(mockMessages);
bulkSaveConvos.mockResolvedValue(null);
bulkSaveMessages.mockResolvedValue(null);
bulkIncrementTagCounts.mockResolvedValue(null);
});
test('should duplicate conversation and increment tag counts', async () => {
const mockConvoWithTags = {
...mockConversation,
tags: ['important', 'work', 'project'],
};
getConvo.mockResolvedValue(mockConvoWithTags);
await duplicateConversation({
userId: 'user1',
conversationId: 'abc123',
});
// Verify that bulkIncrementTagCounts was called with correct tags
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', ['important', 'work', 'project']);
});
test('should duplicate conversation without tags', async () => {
const mockConvoWithoutTags = {
...mockConversation,
// No tags field
};
getConvo.mockResolvedValue(mockConvoWithoutTags);
await duplicateConversation({
userId: 'user1',
conversationId: 'abc123',
});
// bulkIncrementTagCounts will be called with array containing undefined
expect(bulkIncrementTagCounts).toHaveBeenCalled();
});
test('should handle empty tags array when duplicating', async () => {
const mockConvoWithEmptyTags = {
...mockConversation,
tags: [],
};
getConvo.mockResolvedValue(mockConvoWithEmptyTags);
await duplicateConversation({
userId: 'user1',
conversationId: 'abc123',
});
// bulkIncrementTagCounts will be called with empty array
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', []);
});
});
const mockMessagesComplex = [

View File

@@ -1,6 +1,5 @@
const { v4: uuidv4 } = require('uuid');
const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider');
const { bulkIncrementTagCounts } = require('~/models/ConversationTag');
const { bulkSaveConvos } = require('~/models/Conversation');
const { bulkSaveMessages } = require('~/models/Message');
const { logger } = require('~/config');
@@ -94,22 +93,13 @@ class ImportBatchBuilder {
/**
* Saves the batch of conversations and messages to the DB.
* Also increments tag counts for any existing tags.
* @returns {Promise<void>} A promise that resolves when the batch is saved.
* @throws {Error} If there is an error saving the batch.
*/
async saveBatch() {
try {
const promises = [];
promises.push(bulkSaveConvos(this.conversations));
promises.push(bulkSaveMessages(this.messages, true));
promises.push(
bulkIncrementTagCounts(
this.requestUserId,
this.conversations.flatMap((convo) => convo.tags),
),
);
await Promise.all(promises);
await bulkSaveConvos(this.conversations);
await bulkSaveMessages(this.messages, true);
logger.debug(
`user: ${this.requestUserId} | Added ${this.conversations.length} conversations and ${this.messages.length} messages to the DB.`,
);

View File

@@ -1,6 +1,6 @@
const fs = require('fs').promises;
const { logger } = require('@librechat/data-schemas');
const { getImporter } = require('./importers');
const { logger } = require('~/config');
/**
* Job definition for importing a conversation.

View File

@@ -1,9 +1,9 @@
const { v4: uuidv4 } = require('uuid');
const { logger } = require('@librechat/data-schemas');
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
const { createImportBatchBuilder } = require('./importBatchBuilder');
const { cloneMessagesWithTimestamps } = require('./fork');
const getLogStores = require('~/cache/getLogStores');
const logger = require('~/config/winston');
/**
* Returns the appropriate importer function based on the provided JSON data.
@@ -212,29 +212,6 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
}
}
/**
* Helper function to find the nearest non-system parent
* @param {string} parentId - The ID of the parent message.
* @returns {string} The ID of the nearest non-system parent message.
*/
const findNonSystemParent = (parentId) => {
if (!parentId || !messageMap.has(parentId)) {
return Constants.NO_PARENT;
}
const parentMapping = conv.mapping[parentId];
if (!parentMapping?.message) {
return Constants.NO_PARENT;
}
/* If parent is a system message, traverse up to find the nearest non-system parent */
if (parentMapping.message.author?.role === 'system') {
return findNonSystemParent(parentMapping.parent);
}
return messageMap.get(parentId);
};
// Create and save messages using the mapped IDs
const messages = [];
for (const [id, mapping] of Object.entries(conv.mapping)) {
@@ -243,27 +220,23 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
messageMap.delete(id);
continue;
} else if (role === 'system') {
// Skip system messages but keep their ID in messageMap for parent references
messageMap.delete(id);
continue;
}
const newMessageId = messageMap.get(id);
const parentMessageId = findNonSystemParent(mapping.parent);
const parentMessageId =
mapping.parent && messageMap.has(mapping.parent)
? messageMap.get(mapping.parent)
: Constants.NO_PARENT;
const messageText = formatMessageText(mapping.message);
const isCreatedByUser = role === 'user';
let sender = isCreatedByUser ? 'user' : 'assistant';
let sender = isCreatedByUser ? 'user' : 'GPT-3.5';
const model = mapping.message.metadata.model_slug || openAISettings.model.default;
if (!isCreatedByUser) {
/** Extracted model name from model slug */
const gptMatch = model.match(/gpt-(.+)/i);
if (gptMatch) {
sender = `GPT-${gptMatch[1]}`;
} else {
sender = model || 'assistant';
}
if (model.includes('gpt-4')) {
sender = 'GPT-4';
}
messages.push({

Some files were not shown because too many files have changed in this diff Show More