Compare commits

..

20 Commits

Author SHA1 Message Date
Dustin Healy
b034624690 refactor: cost bar uses new cost table 2025-09-13 13:30:00 -07:00
Dustin Healy
adff605c50 feat: add target model to messages so we dont have to peek into child ai message responses every time to know what model the user sent their message to for tokenValue computations 2025-09-13 13:02:04 -07:00
Dustin Healy
465c81adee feat: add addToSet update operation to saveConvo to efficiently update modelHistory whenever a new model/endpoint combo is used 2025-09-13 12:59:33 -07:00
Dustin Healy
cb8e76e27e feat: add modelHistory to conversations
modelHistory keeps track of all the models and their respective endpoints that are used in a conversation, so we can quickly fetch the right values to populate our cost table whenever we load a conversation rather than having to recompute the list every time
2025-09-13 12:49:44 -07:00
Dustin Healy
4d9e17efe1 feat: add modelCosts to QueryKeys 2025-09-13 12:46:02 -07:00
Dustin Healy
95ebef13df refactor: swap conversationCosts query hooks out for new modelCosts query and move to react-query-service 2025-09-13 12:42:43 -07:00
Dustin Healy
4fb9d7bdff refactor: swap data-service and types to use new cost table 2025-09-13 12:26:39 -07:00
Dustin Healy
0edfecf44a refactor: swap old ConversationCosts endpoint out for much simpler 'costs' endpoint
/costs just recieves an array of { model, endpoint }  and returns the related prompt and completion token rates, leaving the actual message processing to be done later on rather than on the backend.
2025-09-13 12:19:26 -07:00
Dustin Healy
ba8c09b361 feat: stop costbox from displaying during responses and make scroll to bottom ux more consistent 2025-09-13 01:51:17 -07:00
Dustin Healy
794fe6fd11 test: add unit tests for costs endpoint with various scenarios 2025-09-13 01:51:17 -07:00
Dustin Healy
97ac52fc6c feat: add toggle for cost tracking 2025-09-13 01:51:17 -07:00
Dustin Healy
1a947607a5 feat: add localization strings for the token abbreviation 't' 2025-09-13 01:51:17 -07:00
Dustin Healy
1745708418 refactor: move arrow svgs out to their own component 2025-09-13 01:51:17 -07:00
Dustin Healy
14aedac1e1 refactor: move CostBar into its own component 2025-09-13 01:51:17 -07:00
Dustin Healy
a820d79bfc feat: pull up the cost hooks into chatView to reduce api calls 2025-09-13 01:51:17 -07:00
Dustin Healy
3b1c07ff46 feat: add frontend component for displaying convo total, per message cost and tokens, and hide total block on scroll 2025-09-13 01:51:17 -07:00
Dustin Healy
c1b0f13360 feat: add queryhook layer and invalidation so we keep tokens fresh during conversation turns 2025-09-13 01:51:17 -07:00
Dustin Healy
637bbd2e29 refactor: move endpoint up so it doesn't get eclipsed by get route for conversationId 2025-09-13 01:51:17 -07:00
Dustin Healy
30e1b421ba feat: add data-service layer 2025-09-13 01:51:17 -07:00
Dustin Healy
fb89f60470 feat: add transaction costs endpoint 2025-09-13 01:51:17 -07:00
241 changed files with 3890 additions and 9468 deletions

View File

@@ -233,6 +233,7 @@ class BaseClient {
sender: 'User',
text,
isCreatedByUser: true,
targetModel: this.modelOptions?.model ?? this.model,
};
}

View File

@@ -68,19 +68,19 @@ const primeFiles = async (options) => {
/**
*
* @param {Object} options
* @param {string} options.userId
* @param {ServerRequest} options.req
* @param {Array<{ file_id: string; filename: string }>} options.files
* @param {string} [options.entity_id]
* @param {boolean} [options.fileCitations=false] - Whether to include citation instructions
* @returns
*/
const createFileSearchTool = async ({ userId, files, entity_id, fileCitations = false }) => {
const createFileSearchTool = async ({ req, files, entity_id, fileCitations = false }) => {
return tool(
async ({ query }) => {
if (files.length === 0) {
return 'No files to search. Instruct the user to add files for the search.';
}
const jwtToken = generateShortLivedToken(userId);
const jwtToken = generateShortLivedToken(req.user.id);
if (!jwtToken) {
return 'There was an error authenticating the file search request.';
}

View File

@@ -33,7 +33,7 @@ const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSe
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { getMCPServerTools } = require('~/server/services/Config');
const { getCachedTools } = require('~/server/services/Config');
const { getRoleByName } = require('~/models/Role');
/**
@@ -250,6 +250,7 @@ const loadTools = async ({
/** @type {Record<string, string>} */
const toolContextMap = {};
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
const requestedMCPTools = {};
for (const tool of tools) {
@@ -306,7 +307,7 @@ const loadTools = async ({
}
return createFileSearchTool({
userId: user,
req: options.req,
files,
entity_id: agent?.id,
fileCitations,
@@ -339,7 +340,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
});
};
continue;
} else if (tool && mcpToolPattern.test(tool)) {
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
if (toolName === Constants.mcp_server) {
/** Placeholder used for UI purposes */
@@ -352,21 +353,33 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
continue;
}
if (toolName === Constants.mcp_all) {
requestedMCPTools[serverName] = [
{
type: 'all',
const currentMCPGenerator = async (index) =>
createMCPTools({
req: options.req,
res: options.res,
index,
serverName,
},
];
userMCPAuthMap,
model: agent?.model ?? model,
provider: agent?.provider ?? endpoint,
signal,
});
requestedMCPTools[serverName] = [currentMCPGenerator];
continue;
}
const currentMCPGenerator = async (index) =>
createMCPTool({
index,
req: options.req,
res: options.res,
toolKey: tool,
userMCPAuthMap,
model: agent?.model ?? model,
provider: agent?.provider ?? endpoint,
signal,
});
requestedMCPTools[serverName] = requestedMCPTools[serverName] || [];
requestedMCPTools[serverName].push({
type: 'single',
toolKey: tool,
serverName,
});
requestedMCPTools[serverName].push(currentMCPGenerator);
continue;
}
@@ -409,64 +422,24 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
const mcpToolPromises = [];
/** MCP server tools are initialized sequentially by server */
let index = -1;
const failedMCPServers = new Set();
for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) {
for (const [serverName, generators] of Object.entries(requestedMCPTools)) {
index++;
/** @type {LCAvailableTools} */
let availableTools;
for (const config of toolConfigs) {
for (const generator of generators) {
try {
if (failedMCPServers.has(serverName)) {
continue;
}
const mcpParams = {
res: options.res,
userId: user,
index,
serverName: config.serverName,
userMCPAuthMap,
model: agent?.model ?? model,
provider: agent?.provider ?? endpoint,
signal,
};
if (config.type === 'all' && toolConfigs.length === 1) {
/** Handle async loading for single 'all' tool config */
if (generator && generators.length === 1) {
mcpToolPromises.push(
createMCPTools(mcpParams).catch((error) => {
generator(index).catch((error) => {
logger.error(`Error loading ${serverName} tools:`, error);
return null;
}),
);
continue;
}
if (!availableTools) {
try {
availableTools = await getMCPServerTools(serverName);
} catch (error) {
logger.error(`Error fetching available tools for MCP server ${serverName}:`, error);
}
}
/** Handle synchronous loading */
const mcpTool =
config.type === 'all'
? await createMCPTools(mcpParams)
: await createMCPTool({
...mcpParams,
availableTools,
toolKey: config.toolKey,
});
const mcpTool = await generator(index);
if (Array.isArray(mcpTool)) {
loadedTools.push(...mcpTool);
} else if (mcpTool) {
loadedTools.push(mcpTool);
} else {
failedMCPServers.add(serverName);
logger.warn(
`MCP tool creation failed for "${config.toolKey}", server may be unavailable or unauthenticated.`,
);
}
} catch (error) {
logger.error(`Error loading MCP tool for server ${serverName}:`, error);

View File

@@ -1,5 +1,4 @@
const fs = require('fs');
const { logger } = require('@librechat/data-schemas');
const { math, isEnabled } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
@@ -35,35 +34,13 @@ if (FORCED_IN_MEMORY_CACHE_NAMESPACES.length > 0) {
}
}
/** Helper function to safely read Redis CA certificate from file
* @returns {string|null} The contents of the CA certificate file, or null if not set or on error
*/
const getRedisCA = () => {
const caPath = process.env.REDIS_CA;
if (!caPath) {
return null;
}
try {
if (fs.existsSync(caPath)) {
return fs.readFileSync(caPath, 'utf8');
} else {
logger.warn(`Redis CA certificate file not found: ${caPath}`);
return null;
}
} catch (error) {
logger.error(`Failed to read Redis CA certificate file '${caPath}':`, error);
return null;
}
};
const cacheConfig = {
FORCED_IN_MEMORY_CACHE_NAMESPACES,
USE_REDIS,
REDIS_URI: process.env.REDIS_URI,
REDIS_USERNAME: process.env.REDIS_USERNAME,
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
REDIS_CA: getRedisCA(),
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
REDIS_PING_INTERVAL: math(process.env.REDIS_PING_INTERVAL, 0),

View File

@@ -1,6 +1,6 @@
const { MCPManager, FlowStateManager } = require('@librechat/api');
const { EventSource } = require('eventsource');
const { Time } = require('librechat-data-provider');
const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api');
const logger = require('./winston');
global.EventSource = EventSource;
@@ -26,6 +26,4 @@ module.exports = {
createMCPManager: MCPManager.createInstance,
getMCPManager: MCPManager.getInstance,
getFlowStateManager,
createOAuthReconnectionManager: OAuthReconnectionManager.createInstance,
getOAuthReconnectionManager: OAuthReconnectionManager.getInstance,
};

View File

@@ -11,7 +11,7 @@ const {
getProjectByName,
} = require('./Project');
const { removeAllPermissions } = require('~/server/services/PermissionService');
const { getMCPServerTools } = require('~/server/services/Config');
const { getCachedTools } = require('~/server/services/Config');
const { getActions } = require('./Action');
const { Agent } = require('~/db/models');
@@ -49,14 +49,6 @@ const createAgent = async (agentData) => {
*/
const getAgent = async (searchParameter) => await Agent.findOne(searchParameter).lean();
/**
* Get multiple agent documents based on the provided search parameters.
*
* @param {Object} searchParameter - The search parameters to find agents.
* @returns {Promise<Agent[]>} Array of agent documents as plain objects.
*/
const getAgents = async (searchParameter) => await Agent.find(searchParameter).lean();
/**
* Load an agent based on the provided ID
*
@@ -69,6 +61,8 @@ const getAgents = async (searchParameter) => await Agent.find(searchParameter).l
*/
const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => {
const { model, ...model_parameters } = _m;
/** @type {Record<string, FunctionTool>} */
const availableTools = await getCachedTools({ userId: req.user.id, includeGlobal: true });
/** @type {TEphemeralAgent | null} */
const ephemeralAgent = req.body.ephemeralAgent;
const mcpServers = new Set(ephemeralAgent?.mcp);
@@ -86,18 +80,22 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
const addedServers = new Set();
if (mcpServers.size > 0) {
for (const toolName of Object.keys(availableTools)) {
if (!toolName.includes(mcp_delimiter)) {
continue;
}
const mcpServer = toolName.split(mcp_delimiter)?.[1];
if (mcpServer && mcpServers.has(mcpServer)) {
addedServers.add(mcpServer);
tools.push(toolName);
}
}
for (const mcpServer of mcpServers) {
if (addedServers.has(mcpServer)) {
continue;
}
const serverTools = await getMCPServerTools(mcpServer);
if (!serverTools) {
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
addedServers.add(mcpServer);
continue;
}
tools.push(...Object.keys(serverTools));
addedServers.add(mcpServer);
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
}
}
@@ -837,7 +835,6 @@ const countPromotedAgents = async () => {
module.exports = {
getAgent,
getAgents,
loadAgent,
createAgent,
updateAgent,

View File

@@ -8,7 +8,6 @@ process.env.CREDS_IV = '0123456789abcdef';
jest.mock('~/server/services/Config', () => ({
getCachedTools: jest.fn(),
getMCPServerTools: jest.fn(),
}));
const mongoose = require('mongoose');
@@ -31,7 +30,7 @@ const {
generateActionMetadataHash,
} = require('./Agent');
const permissionService = require('~/server/services/PermissionService');
const { getCachedTools, getMCPServerTools } = require('~/server/services/Config');
const { getCachedTools } = require('~/server/services/Config');
const { AclEntry } = require('~/db/models');
/**
@@ -1930,16 +1929,6 @@ describe('models/Agent', () => {
another_tool: {},
});
// Mock getMCPServerTools to return tools for each server
getMCPServerTools.mockImplementation(async (server) => {
if (server === 'server1') {
return { tool1_mcp_server1: {} };
} else if (server === 'server2') {
return { tool2_mcp_server2: {} };
}
return null;
});
const mockReq = {
user: { id: 'user123' },
body: {
@@ -2124,14 +2113,6 @@ describe('models/Agent', () => {
getCachedTools.mockResolvedValue(availableTools);
// Mock getMCPServerTools to return all tools for server1
getMCPServerTools.mockImplementation(async (server) => {
if (server === 'server1') {
return availableTools; // All 100 tools belong to server1
}
return null;
});
const mockReq = {
user: { id: 'user123' },
body: {
@@ -2673,17 +2654,6 @@ describe('models/Agent', () => {
tool_mcp_server2: {}, // Different server
});
// Mock getMCPServerTools to return only tools matching the server
getMCPServerTools.mockImplementation(async (server) => {
if (server === 'server1') {
// Only return tool that correctly matches server1 format
return { tool_mcp_server1: {} };
} else if (server === 'server2') {
return { tool_mcp_server2: {} };
}
return null;
});
const mockReq = {
user: { id: 'user123' },
body: {

View File

@@ -112,8 +112,17 @@ module.exports = {
update.expiredAt = null;
}
/** @type {{ $set: Partial<TConversation>; $unset?: Record<keyof TConversation, number> }} */
/** @type {{ $set: Partial<TConversation>; $addToSet?: Record<string, any>; $unset?: Record<keyof TConversation, number> }} */
const updateOperation = { $set: update };
if (convo.model && convo.endpoint) {
updateOperation.$addToSet = {
modelHistory: {
model: convo.model,
endpoint: convo.endpoint,
},
};
}
if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) {
updateOperation.$unset = metadata.unsetFields;
}

View File

@@ -239,46 +239,10 @@ const updateTagsForConversation = async (user, conversationId, tags) => {
}
};
/**
* Increments tag counts for existing tags only.
* @param {string} user - The user ID.
* @param {string[]} tags - Array of tag names to increment
* @returns {Promise<void>}
*/
const bulkIncrementTagCounts = async (user, tags) => {
if (!tags || tags.length === 0) {
return;
}
try {
const uniqueTags = [...new Set(tags.filter(Boolean))];
if (uniqueTags.length === 0) {
return;
}
const bulkOps = uniqueTags.map((tag) => ({
updateOne: {
filter: { user, tag },
update: { $inc: { count: 1 } },
},
}));
const result = await ConversationTag.bulkWrite(bulkOps);
if (result && result.modifiedCount > 0) {
logger.debug(
`user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`,
);
}
} catch (error) {
logger.error('[bulkIncrementTagCounts] Error incrementing tag counts', error);
}
};
module.exports = {
getConversationTags,
createConversationTag,
updateConversationTag,
deleteConversationTag,
bulkIncrementTagCounts,
updateTagsForConversation,
};

View File

@@ -42,7 +42,7 @@ const getToolFilesByIds = async (fileIds, toolResourceSet) => {
$or: [],
};
if (toolResourceSet.has(EToolResources.context)) {
if (toolResourceSet.has(EToolResources.ocr)) {
filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
}
if (toolResourceSet.has(EToolResources.file_search)) {

View File

@@ -49,7 +49,7 @@
"@langchain/google-vertexai": "^0.2.13",
"@langchain/openai": "^0.5.18",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.80",
"@librechat/agents": "^2.4.79",
"@librechat/api": "*",
"@librechat/data-schemas": "*",
"@microsoft/microsoft-graph-client": "^3.0.7",

View File

@@ -1,8 +1,8 @@
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const openIdClient = require('openid-client');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { isEnabled, findOpenIDUser } = require('@librechat/api');
const {
requestPasswordReset,
setOpenIDAuthTokens,
@@ -11,9 +11,8 @@ const {
registerUser,
} = require('~/server/services/AuthService');
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
const { getOAuthReconnectionManager } = require('~/config');
const { getOpenIdConfig } = require('~/strategies');
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
const registrationController = async (req, res) => {
try {
@@ -72,14 +71,8 @@ const refreshController = async (req, res) => {
const openIdConfig = getOpenIdConfig();
const tokenset = await openIdClient.refreshTokenGrant(openIdConfig, refreshToken);
const claims = tokenset.claims();
const { user, error } = await findOpenIDUser({
findUser,
email: claims.email,
openidId: claims.sub,
idOnTheSource: claims.oid,
strategyName: 'refreshController',
});
if (error || !user) {
const user = await findUser({ email: claims.email });
if (!user) {
return res.status(401).redirect('/login');
}
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString());
@@ -103,25 +96,14 @@ const refreshController = async (req, res) => {
return res.status(200).send({ token, user });
}
/** Session with the hashed refresh token */
const session = await findSession(
{
userId: userId,
refreshToken: refreshToken,
},
{ lean: false },
);
// Find the session with the hashed refresh token
const session = await findSession({
userId: userId,
refreshToken: refreshToken,
});
if (session && session.expiration > new Date()) {
const token = await setAuthTokens(userId, res, session);
// trigger OAuth MCP server reconnection asynchronously (best effort)
void getOAuthReconnectionManager()
.reconnectServers(userId)
.catch((err) => {
logger.error('Error reconnecting OAuth MCP servers:', err);
});
const token = await setAuthTokens(userId, res, session._id);
res.status(200).send({ token, user });
} else if (req?.query?.retry) {
// Retrying from a refresh token request that failed (401)
@@ -132,7 +114,7 @@ const refreshController = async (req, res) => {
res.status(401).send('Refresh token expired or not found for this user');
}
} catch (err) {
logger.error(`[refreshController] Invalid refresh token:`, err);
logger.error(`[refreshController] Refresh token: ${refreshToken}`, err);
res.status(403).send('Invalid refresh token');
}
};

View File

@@ -1,9 +1,16 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api');
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { CacheKeys, Constants } = require('librechat-data-provider');
const {
getToolkitKey,
checkPluginAuth,
filterUniquePlugins,
convertMCPToolToPlugin,
convertMCPToolsToPlugins,
} = require('@librechat/api');
const { getCachedTools, setCachedTools, mergeUserTools } = require('~/server/services/Config');
const { availableTools, toolkits } = require('~/app/clients/tools');
const { getAppConfig } = require('~/server/services/Config');
const { getMCPManager } = require('~/config');
const { getLogStores } = require('~/cache');
const getAvailablePluginsController = async (req, res) => {
@@ -65,27 +72,63 @@ const getAvailableTools = async (req, res) => {
}
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
const cachedUserTools = await getCachedTools({ userId });
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
// Return early if we have cached tools
if (cachedToolsArray != null) {
res.status(200).json(cachedToolsArray);
/** @type {TPlugin[]} */
let mcpPlugins;
if (appConfig?.mcpConfig) {
const mcpManager = getMCPManager();
mcpPlugins =
cachedUserTools != null
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
: undefined;
}
if (
cachedToolsArray != null &&
(appConfig?.mcpConfig != null ? mcpPlugins != null && mcpPlugins.length > 0 : true)
) {
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...cachedToolsArray]);
res.status(200).json(dedupedTools);
return;
}
/** @type {Record<string, FunctionTool> | null} Get tool definitions to filter which tools are actually available */
let toolDefinitions = await getCachedTools();
if (toolDefinitions == null && appConfig?.availableTools != null) {
logger.warn('[getAvailableTools] Tool cache was empty, re-initializing from app config');
await setCachedTools(appConfig.availableTools);
toolDefinitions = appConfig.availableTools;
}
let toolDefinitions = await getCachedTools({ includeGlobal: true });
let prelimCachedTools;
/** @type {import('@librechat/api').LCManifestTool[]} */
let pluginManifest = availableTools;
if (appConfig?.mcpConfig != null) {
try {
const mcpManager = getMCPManager();
const mcpTools = await mcpManager.getAllToolFunctions(userId);
prelimCachedTools = prelimCachedTools ?? {};
for (const [toolKey, toolData] of Object.entries(mcpTools)) {
const plugin = convertMCPToolToPlugin({
toolKey,
toolData,
mcpManager,
});
if (plugin) {
pluginManifest.push(plugin);
}
prelimCachedTools[toolKey] = toolData;
}
await mergeUserTools({ userId, cachedUserTools, userTools: prelimCachedTools });
} catch (error) {
logger.error(
'[getAvailableTools] Error loading MCP Tools, servers may still be initializing:',
error,
);
}
} else if (prelimCachedTools != null) {
await setCachedTools(prelimCachedTools, { isGlobal: true });
}
/** @type {TPlugin[]} Deduplicate and authenticate plugins */
const uniquePlugins = filterUniquePlugins(pluginManifest);
const authenticatedPlugins = uniquePlugins.map((plugin) => {
@@ -96,13 +139,13 @@ const getAvailableTools = async (req, res) => {
}
});
/** Filter plugins based on availability */
/** Filter plugins based on availability and add MCP-specific auth config */
const toolsOutput = [];
for (const plugin of authenticatedPlugins) {
const isToolDefined = toolDefinitions?.[plugin.pluginKey] !== undefined;
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
const isToolkit =
plugin.toolkit === true &&
Object.keys(toolDefinitions ?? {}).some(
Object.keys(toolDefinitions).some(
(key) => getToolkitKey({ toolkits, toolName: key }) === plugin.pluginKey,
);
@@ -110,13 +153,39 @@ const getAvailableTools = async (req, res) => {
continue;
}
toolsOutput.push(plugin);
const toolToAdd = { ...plugin };
if (plugin.pluginKey.includes(Constants.mcp_delimiter)) {
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
const serverName = parts[parts.length - 1];
const serverConfig = appConfig?.mcpConfig?.[serverName];
if (serverConfig?.customUserVars) {
const customVarKeys = Object.keys(serverConfig.customUserVars);
if (customVarKeys.length === 0) {
toolToAdd.authConfig = [];
toolToAdd.authenticated = true;
} else {
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(
([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}),
);
toolToAdd.authenticated = false;
}
}
}
toolsOutput.push(toolToAdd);
}
const finalTools = filterUniquePlugins(toolsOutput);
await cache.set(CacheKeys.TOOLS, finalTools);
res.status(200).json(finalTools);
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...finalTools]);
res.status(200).json(dedupedTools);
} catch (error) {
logger.error('[getAvailableTools]', error);
res.status(500).json({ message: error.message });

View File

@@ -1,3 +1,4 @@
const { Constants } = require('librechat-data-provider');
const { getCachedTools, getAppConfig } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
@@ -16,10 +17,18 @@ jest.mock('~/server/services/Config', () => ({
includedTools: [],
}),
setCachedTools: jest.fn(),
mergeUserTools: jest.fn(),
}));
// loadAndFormatTools mock removed - no longer used in PluginController
// getMCPManager mock removed - no longer used in PluginController
jest.mock('~/config', () => ({
getMCPManager: jest.fn(() => ({
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
})),
getFlowStateManager: jest.fn(),
}));
jest.mock('~/app/clients/tools', () => ({
availableTools: [],
@@ -150,6 +159,52 @@ describe('PluginController', () => {
});
describe('getAvailableTools', () => {
it('should use convertMCPToolsToPlugins for user-specific MCP tools', async () => {
const mockUserTools = {
[`tool1${Constants.mcp_delimiter}server1`]: {
type: 'function',
function: {
name: `tool1${Constants.mcp_delimiter}server1`,
description: 'Tool 1',
parameters: { type: 'object', properties: {} },
},
},
};
mockCache.get.mockResolvedValue(null);
getCachedTools.mockResolvedValueOnce(mockUserTools);
mockReq.config = {
mcpConfig: {
server1: {},
},
paths: { structuredTools: '/mock/path' },
};
// Mock MCP manager to return empty tools initially (since getAllToolFunctions is called)
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
// Mock second call to return tool definitions (includeGlobal: true)
getCachedTools.mockResolvedValueOnce(mockUserTools);
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
expect(responseData).toBeDefined();
expect(Array.isArray(responseData)).toBe(true);
expect(responseData.length).toBeGreaterThan(0);
const convertedTool = responseData.find(
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}server1`,
);
expect(convertedTool).toBeDefined();
// The real convertMCPToolsToPlugins extracts the name from the delimiter
expect(convertedTool.name).toBe('tool1');
});
it('should use filterUniquePlugins to deduplicate combined tools', async () => {
const mockUserTools = {
'user-tool': {
@@ -174,6 +229,9 @@ describe('PluginController', () => {
paths: { structuredTools: '/mock/path' },
};
// Mock second call to return tool definitions
getCachedTools.mockResolvedValueOnce(mockUserTools);
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
@@ -196,7 +254,14 @@ describe('PluginController', () => {
require('~/app/clients/tools').availableTools.push(mockPlugin);
mockCache.get.mockResolvedValue(null);
// getCachedTools returns the tool definitions
// First call returns null for user tools
getCachedTools.mockResolvedValueOnce(null);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Second call (with includeGlobal: true) returns the tool definitions
getCachedTools.mockResolvedValueOnce({
tool1: {
type: 'function',
@@ -207,10 +272,6 @@ describe('PluginController', () => {
},
},
});
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
await getAvailableTools(mockReq, mockRes);
@@ -241,7 +302,14 @@ describe('PluginController', () => {
});
mockCache.get.mockResolvedValue(null);
// getCachedTools returns the tool definitions
// First call returns null for user tools
getCachedTools.mockResolvedValueOnce(null);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Second call (with includeGlobal: true) returns the tool definitions
getCachedTools.mockResolvedValueOnce({
toolkit1_function: {
type: 'function',
@@ -252,10 +320,6 @@ describe('PluginController', () => {
},
},
});
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
await getAvailableTools(mockReq, mockRes);
@@ -267,7 +331,126 @@ describe('PluginController', () => {
});
});
describe('plugin.icon behavior', () => {
const callGetAvailableToolsWithMCPServer = async (serverConfig) => {
mockCache.get.mockResolvedValue(null);
const functionTools = {
[`test-tool${Constants.mcp_delimiter}test-server`]: {
type: 'function',
function: {
name: `test-tool${Constants.mcp_delimiter}test-server`,
description: 'A test tool',
parameters: { type: 'object', properties: {} },
},
},
};
// Mock the MCP manager to return tools and server config
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue(functionTools),
getRawConfig: jest.fn().mockReturnValue(serverConfig),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
// First call returns empty user tools
getCachedTools.mockResolvedValueOnce({});
// Mock getAppConfig to return the mcpConfig
mockReq.config = {
mcpConfig: {
'test-server': serverConfig,
},
};
// Second call (with includeGlobal: true) returns the tool definitions
getCachedTools.mockResolvedValueOnce(functionTools);
await getAvailableTools(mockReq, mockRes);
const responseData = mockRes.json.mock.calls[0][0];
return responseData.find(
(tool) => tool.pluginKey === `test-tool${Constants.mcp_delimiter}test-server`,
);
};
it('should set plugin.icon when iconPath is defined', async () => {
const serverConfig = {
iconPath: '/path/to/icon.png',
};
const testTool = await callGetAvailableToolsWithMCPServer(serverConfig);
expect(testTool.icon).toBe('/path/to/icon.png');
});
it('should set plugin.icon to undefined when iconPath is not defined', async () => {
const serverConfig = {};
const testTool = await callGetAvailableToolsWithMCPServer(serverConfig);
expect(testTool.icon).toBeUndefined();
});
});
describe('helper function integration', () => {
it('should properly handle MCP tools with custom user variables', async () => {
const appConfig = {
mcpConfig: {
'test-server': {
customUserVars: {
API_KEY: { title: 'API Key', description: 'Your API key' },
},
},
},
};
// Mock MCP tools returned by getAllToolFunctions
const mcpToolFunctions = {
[`tool1${Constants.mcp_delimiter}test-server`]: {
type: 'function',
function: {
name: `tool1${Constants.mcp_delimiter}test-server`,
description: 'Tool 1',
parameters: {},
},
},
};
// Mock the MCP manager to return tools
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue(mcpToolFunctions),
getRawConfig: jest.fn().mockReturnValue({
customUserVars: {
API_KEY: { title: 'API Key', description: 'Your API key' },
},
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
mockCache.get.mockResolvedValue(null);
mockReq.config = appConfig;
// First call returns user tools (empty in this case)
getCachedTools.mockResolvedValueOnce({});
// Second call (with includeGlobal: true) returns tool definitions including our MCP tool
getCachedTools.mockResolvedValueOnce(mcpToolFunctions);
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
expect(Array.isArray(responseData)).toBe(true);
// Find the MCP tool in the response
const mcpTool = responseData.find(
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}test-server`,
);
// The actual implementation adds authConfig and sets authenticated to false when customUserVars exist
expect(mcpTool).toBeDefined();
expect(mcpTool.authConfig).toEqual([
{ authField: 'API_KEY', label: 'API Key', description: 'Your API key' },
]);
expect(mcpTool.authenticated).toBe(false);
});
it('should handle error cases gracefully', async () => {
mockCache.get.mockRejectedValue(new Error('Cache error'));
@@ -289,13 +472,23 @@ describe('PluginController', () => {
it('should handle null cachedTools and cachedUserTools', async () => {
mockCache.get.mockResolvedValue(null);
// getCachedTools returns empty object instead of null
getCachedTools.mockResolvedValueOnce({});
// First call returns null for user tools
getCachedTools.mockResolvedValueOnce(null);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Mock MCP manager to return no tools
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
// Second call (with includeGlobal: true) returns empty object instead of null
getCachedTools.mockResolvedValueOnce({});
await getAvailableTools(mockReq, mockRes);
// Should handle null values gracefully
@@ -310,9 +503,9 @@ describe('PluginController', () => {
paths: { structuredTools: '/mock/path' },
};
// Mock getCachedTools to return undefined
// Mock getCachedTools to return undefined for both calls
getCachedTools.mockReset();
getCachedTools.mockResolvedValueOnce(undefined);
getCachedTools.mockResolvedValueOnce(undefined).mockResolvedValueOnce(undefined);
await getAvailableTools(mockReq, mockRes);
@@ -321,6 +514,51 @@ describe('PluginController', () => {
expect(mockRes.json).toHaveBeenCalledWith([]);
});
it('should handle `cachedToolsArray` and `mcpPlugins` both being defined', async () => {
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
// Use MCP delimiter for the user tool so convertMCPToolsToPlugins works
const userTools = {
[`user-tool${Constants.mcp_delimiter}server1`]: {
type: 'function',
function: {
name: `user-tool${Constants.mcp_delimiter}server1`,
description: 'User tool',
parameters: {},
},
},
};
mockCache.get.mockResolvedValue(cachedTools);
getCachedTools.mockResolvedValueOnce(userTools);
mockReq.config = {
mcpConfig: {
server1: {},
},
paths: { structuredTools: '/mock/path' },
};
// Mock MCP manager to return empty tools initially
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
// The controller expects a second call to getCachedTools
getCachedTools.mockResolvedValueOnce({
'cached-tool': { type: 'function', function: { name: 'cached-tool' } },
[`user-tool${Constants.mcp_delimiter}server1`]:
userTools[`user-tool${Constants.mcp_delimiter}server1`],
});
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
// Should have both cached and user tools
expect(responseData.length).toBeGreaterThanOrEqual(2);
});
it('should handle empty toolDefinitions object', async () => {
mockCache.get.mockResolvedValue(null);
// Reset getCachedTools to ensure clean state
@@ -331,12 +569,76 @@ describe('PluginController', () => {
// Ensure no plugins are available
require('~/app/clients/tools').availableTools.length = 0;
// Reset MCP manager to default state
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
await getAvailableTools(mockReq, mockRes);
// With empty tool definitions, no tools should be in the final output
expect(mockRes.json).toHaveBeenCalledWith([]);
});
it('should handle MCP tools without customUserVars', async () => {
const appConfig = {
mcpConfig: {
'test-server': {
// No customUserVars defined
},
},
};
const mockUserTools = {
[`tool1${Constants.mcp_delimiter}test-server`]: {
type: 'function',
function: {
name: `tool1${Constants.mcp_delimiter}test-server`,
description: 'Tool 1',
parameters: { type: 'object', properties: {} },
},
},
};
// Mock the MCP manager to return the tools
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue(mockUserTools),
getRawConfig: jest.fn().mockReturnValue({
// No customUserVars defined
}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
mockCache.get.mockResolvedValue(null);
mockReq.config = appConfig;
// First call returns empty user tools
getCachedTools.mockResolvedValueOnce({});
// Second call (with includeGlobal: true) returns the tool definitions
getCachedTools.mockResolvedValueOnce(mockUserTools);
// Ensure no plugins in availableTools for clean test
require('~/app/clients/tools').availableTools.length = 0;
await getAvailableTools(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
expect(Array.isArray(responseData)).toBe(true);
expect(responseData.length).toBeGreaterThan(0);
const mcpTool = responseData.find(
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}test-server`,
);
expect(mcpTool).toBeDefined();
expect(mcpTool.authenticated).toBe(true);
// The actual implementation sets authConfig to empty array when no customUserVars
expect(mcpTool.authConfig).toEqual([]);
});
it('should handle undefined filteredTools and includedTools', async () => {
mockReq.config = {};
mockCache.get.mockResolvedValue(null);
@@ -365,129 +667,20 @@ describe('PluginController', () => {
require('~/app/clients/tools').availableTools.push(mockToolkit);
mockCache.get.mockResolvedValue(null);
// getCachedTools returns empty object to avoid null reference error
// First call returns empty object
getCachedTools.mockResolvedValueOnce({});
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Second call (with includeGlobal: true) returns empty object to avoid null reference error
getCachedTools.mockResolvedValueOnce({});
await getAvailableTools(mockReq, mockRes);
// Should handle null toolDefinitions gracefully
expect(mockRes.status).toHaveBeenCalledWith(200);
});
it('should handle undefined toolDefinitions when checking isToolDefined (traversaal_search bug)', async () => {
// This test reproduces the bug where toolDefinitions is undefined
// and accessing toolDefinitions[plugin.pluginKey] causes a TypeError
const mockPlugin = {
name: 'Traversaal Search',
pluginKey: 'traversaal_search',
description: 'Search plugin',
};
// Add the plugin to availableTools
require('~/app/clients/tools').availableTools.push(mockPlugin);
mockCache.get.mockResolvedValue(null);
mockReq.config = {
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// CRITICAL: getCachedTools returns undefined
// This is what causes the bug when trying to access toolDefinitions[plugin.pluginKey]
getCachedTools.mockResolvedValueOnce(undefined);
// This should not throw an error with the optional chaining fix
await getAvailableTools(mockReq, mockRes);
// Should handle undefined toolDefinitions gracefully and return empty array
expect(mockRes.status).toHaveBeenCalledWith(200);
expect(mockRes.json).toHaveBeenCalledWith([]);
});
it('should re-initialize tools from appConfig when cache returns null', async () => {
// Setup: Initial state with tools in appConfig
const mockAppTools = {
tool1: {
type: 'function',
function: {
name: 'tool1',
description: 'Tool 1',
parameters: {},
},
},
tool2: {
type: 'function',
function: {
name: 'tool2',
description: 'Tool 2',
parameters: {},
},
},
};
// Add matching plugins to availableTools
require('~/app/clients/tools').availableTools.push(
{ name: 'Tool 1', pluginKey: 'tool1', description: 'Tool 1' },
{ name: 'Tool 2', pluginKey: 'tool2', description: 'Tool 2' },
);
// Simulate cache cleared state (returns null)
mockCache.get.mockResolvedValue(null);
getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared)
mockReq.config = {
filteredTools: [],
includedTools: [],
availableTools: mockAppTools,
};
// Mock setCachedTools to verify it's called to re-initialize
const { setCachedTools } = require('~/server/services/Config');
await getAvailableTools(mockReq, mockRes);
// Should have re-initialized the cache with tools from appConfig
expect(setCachedTools).toHaveBeenCalledWith(mockAppTools);
// Should still return tools successfully
expect(mockRes.status).toHaveBeenCalledWith(200);
const responseData = mockRes.json.mock.calls[0][0];
expect(responseData).toHaveLength(2);
expect(responseData.find((t) => t.pluginKey === 'tool1')).toBeDefined();
expect(responseData.find((t) => t.pluginKey === 'tool2')).toBeDefined();
});
it('should handle cache clear without appConfig.availableTools gracefully', async () => {
// Setup: appConfig without availableTools
getAppConfig.mockResolvedValue({
filteredTools: [],
includedTools: [],
// No availableTools property
});
// Clear availableTools array
require('~/app/clients/tools').availableTools.length = 0;
// Cache returns null (cleared state)
mockCache.get.mockResolvedValue(null);
getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared)
mockReq.config = {
filteredTools: [],
includedTools: [],
// No availableTools
};
await getAvailableTools(mockReq, mockRes);
// Should handle gracefully without crashing
expect(mockRes.status).toHaveBeenCalledWith(200);
expect(mockRes.json).toHaveBeenCalledWith([]);
});
});
});

View File

@@ -1,34 +1,37 @@
const { logger } = require('@librechat/data-schemas');
const { Tools, CacheKeys, Constants, FileSources } = require('librechat-data-provider');
const {
webSearchKeys,
MCPOAuthHandler,
MCPTokenStorage,
normalizeHttpError,
extractWebSearchEnvVars,
normalizeHttpError,
MCPTokenStorage,
} = require('@librechat/api');
const {
getFiles,
findToken,
updateUser,
deleteFiles,
deleteConvos,
deletePresets,
deleteMessages,
deleteUserById,
deleteAllSharedLinks,
deleteAllUserSessions,
} = require('~/models');
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { Tools, Constants, FileSources } = require('librechat-data-provider');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { Transaction, Balance, User, Token } = require('~/db/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getAppConfig } = require('~/server/services/Config');
const { deleteToolCalls } = require('~/models/ToolCall');
const { deleteAllSharedLinks } = require('~/models');
const { getMCPManager } = require('~/config');
const { MCPOAuthHandler } = require('@librechat/api');
const { getFlowStateManager } = require('~/config');
const { CacheKeys } = require('librechat-data-provider');
const { getLogStores } = require('~/cache');
const { clearMCPServerTools } = require('~/server/services/Config/mcpToolsCache');
const { findToken } = require('~/models');
const getUserController = async (req, res) => {
const appConfig = await getAppConfig({ role: req.user?.role });
@@ -372,6 +375,9 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
await flowManager.deleteFlow(flowId, 'mcp_oauth');
// 6. clear the tools cache for the server
await clearMCPServerTools({ userId, serverName });
};
module.exports = {

View File

@@ -158,7 +158,7 @@ describe('duplicateAgent', () => {
});
});
it('should convert `tool_resources.ocr` to `tool_resources.context`', async () => {
it('should handle tool_resources.ocr correctly', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
@@ -178,7 +178,7 @@ describe('duplicateAgent', () => {
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
tool_resources: {
context: { enabled: true, config: 'test' },
ocr: { enabled: true, config: 'test' },
},
}),
);

View File

@@ -2,12 +2,7 @@ const { z } = require('zod');
const fs = require('fs').promises;
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const {
agentCreateSchema,
agentUpdateSchema,
mergeAgentOcrConversion,
convertOcrToContextInPlace,
} = require('@librechat/api');
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
const {
Tools,
Constants,
@@ -71,7 +66,7 @@ const createAgentHandler = async (req, res) => {
agentData.author = userId;
agentData.tools = [];
const availableTools = await getCachedTools();
const availableTools = await getCachedTools({ includeGlobal: true });
for (const tool of tools) {
if (availableTools[tool]) {
agentData.tools.push(tool);
@@ -203,32 +198,19 @@ const getAgentHandler = async (req, res, expandProperties = false) => {
* @param {object} req.params - Request params
* @param {string} req.params.id - Agent identifier.
* @param {AgentUpdateParams} req.body - The Agent update parameters.
* @returns {Promise<Agent>} 200 - success response - application/json
* @returns {Agent} 200 - success response - application/json
*/
const updateAgentHandler = async (req, res) => {
try {
const id = req.params.id;
const validatedData = agentUpdateSchema.parse(req.body);
const { _id, ...updateData } = removeNullishValues(validatedData);
// Convert OCR to context in incoming updateData
convertOcrToContextInPlace(updateData);
const existingAgent = await getAgent({ id });
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
// Convert legacy OCR tool resource to context format in existing agent
const ocrConversion = mergeAgentOcrConversion(existingAgent, updateData);
if (ocrConversion.tool_resources) {
updateData.tool_resources = ocrConversion.tool_resources;
}
if (ocrConversion.tools) {
updateData.tools = ocrConversion.tools;
}
let updatedAgent =
Object.keys(updateData).length > 0
? await updateAgent({ id }, updateData, {
@@ -273,7 +255,7 @@ const updateAgentHandler = async (req, res) => {
* @param {object} req - Express Request
* @param {object} req.params - Request params
* @param {string} req.params.id - Agent identifier.
* @returns {Promise<Agent>} 201 - success response - application/json
* @returns {Agent} 201 - success response - application/json
*/
const duplicateAgentHandler = async (req, res) => {
const { id } = req.params;
@@ -306,19 +288,9 @@ const duplicateAgentHandler = async (req, res) => {
hour12: false,
})})`;
if (_tool_resources?.[EToolResources.context]) {
cloneData.tool_resources = {
[EToolResources.context]: _tool_resources[EToolResources.context],
};
}
if (_tool_resources?.[EToolResources.ocr]) {
cloneData.tool_resources = {
/** Legacy conversion from `ocr` to `context` */
[EToolResources.context]: {
...(_tool_resources[EToolResources.context] ?? {}),
..._tool_resources[EToolResources.ocr],
},
[EToolResources.ocr]: _tool_resources[EToolResources.ocr],
};
}
@@ -410,7 +382,7 @@ const duplicateAgentHandler = async (req, res) => {
* @param {object} req - Express Request
* @param {object} req.params - Request params
* @param {string} req.params.id - Agent identifier.
* @returns {Promise<Agent>} 200 - success response - application/json
* @returns {Agent} 200 - success response - application/json
*/
const deleteAgentHandler = async (req, res) => {
try {
@@ -512,7 +484,7 @@ const getListAgentsHandler = async (req, res) => {
* @param {Express.Multer.File} req.file - The avatar image file.
* @param {object} req.body - Request body
* @param {string} [req.body.avatar] - Optional avatar for the agent's avatar.
* @returns {Promise<void>} 200 - success response - application/json
* @returns {Object} 200 - success response - application/json
*/
const uploadAgentAvatarHandler = async (req, res) => {
try {

View File

@@ -512,7 +512,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
mockReq.params.id = existingAgentId;
mockReq.body = {
tool_resources: {
/** Legacy conversion from `ocr` to `context` */
ocr: {
file_ids: ['ocr1', 'ocr2'],
},
@@ -532,8 +531,7 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.tool_resources).toBeDefined();
expect(updatedAgent.tool_resources.ocr).toBeUndefined();
expect(updatedAgent.tool_resources.context).toBeDefined();
expect(updatedAgent.tool_resources.ocr).toBeDefined();
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
});

View File

@@ -31,7 +31,7 @@ const createAssistant = async (req, res) => {
delete assistantData.conversation_starters;
delete assistantData.append_current_datetime;
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ includeGlobal: true });
assistantData.tools = tools
.map((tool) => {
@@ -136,7 +136,7 @@ const patchAssistant = async (req, res) => {
...updateData
} = req.body;
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ includeGlobal: true });
updateData.tools = (updateData.tools ?? [])
.map((tool) => {

View File

@@ -28,7 +28,7 @@ const createAssistant = async (req, res) => {
delete assistantData.conversation_starters;
delete assistantData.append_current_datetime;
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ includeGlobal: true });
assistantData.tools = tools
.map((tool) => {
@@ -125,7 +125,7 @@ const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
let hasFileSearch = false;
for (const tool of updateData.tools ?? []) {
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ includeGlobal: true });
let actualTool = typeof tool === 'string' ? toolDefinitions[tool] : tool;
if (!actualTool && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {

View File

@@ -1,126 +0,0 @@
/**
* MCP Tools Controller
* Handles MCP-specific tool endpoints, decoupled from regular LibreChat tools
*/
const { logger } = require('@librechat/data-schemas');
const { Constants } = require('librechat-data-provider');
const {
cacheMCPServerTools,
getMCPServerTools,
getAppConfig,
} = require('~/server/services/Config');
const { getMCPManager } = require('~/config');
/**
* Get all MCP tools available to the user
*/
const getMCPTools = async (req, res) => {
try {
const userId = req.user?.id;
if (!userId) {
logger.warn('[getMCPTools] User ID not found in request');
return res.status(401).json({ message: 'Unauthorized' });
}
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
if (!appConfig?.mcpConfig) {
return res.status(200).json({ servers: {} });
}
const mcpManager = getMCPManager();
const configuredServers = Object.keys(appConfig.mcpConfig);
const mcpServers = {};
const cachePromises = configuredServers.map((serverName) =>
getMCPServerTools(serverName).then((tools) => ({ serverName, tools })),
);
const cacheResults = await Promise.all(cachePromises);
const serverToolsMap = new Map();
for (const { serverName, tools } of cacheResults) {
if (tools) {
serverToolsMap.set(serverName, tools);
continue;
}
const serverTools = await mcpManager.getServerToolFunctions(userId, serverName);
if (!serverTools) {
logger.debug(`[getMCPTools] No tools found for server ${serverName}`);
continue;
}
serverToolsMap.set(serverName, serverTools);
if (Object.keys(serverTools).length > 0) {
// Cache asynchronously without blocking
cacheMCPServerTools({ serverName, serverTools }).catch((err) =>
logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err),
);
}
}
// Process each configured server
for (const serverName of configuredServers) {
try {
const serverTools = serverToolsMap.get(serverName);
// Get server config once
const serverConfig = appConfig.mcpConfig[serverName];
const rawServerConfig = mcpManager.getRawConfig(serverName);
// Initialize server object with all server-level data
const server = {
name: serverName,
icon: rawServerConfig?.iconPath || '',
authenticated: true,
authConfig: [],
tools: [],
};
// Set authentication config once for the server
if (serverConfig?.customUserVars) {
const customVarKeys = Object.keys(serverConfig.customUserVars);
if (customVarKeys.length > 0) {
server.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}));
server.authenticated = false;
}
}
// Process tools efficiently - no need for convertMCPToolToPlugin
if (serverTools) {
for (const [toolKey, toolData] of Object.entries(serverTools)) {
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
continue;
}
const toolName = toolKey.split(Constants.mcp_delimiter)[0];
server.tools.push({
name: toolName,
pluginKey: toolKey,
description: toolData.function.description || '',
});
}
}
// Only add server if it has tools or is configured
if (server.tools.length > 0 || serverConfig) {
mcpServers[serverName] = server;
}
} catch (error) {
logger.error(`[getMCPTools] Error loading tools for server ${serverName}:`, error);
}
}
res.status(200).json({ servers: mcpServers });
} catch (error) {
logger.error('[getMCPTools]', error);
res.status(500).json({ message: error.message });
}
};
module.exports = {
getMCPTools,
};

View File

@@ -12,7 +12,6 @@ const { logger } = require('@librechat/data-schemas');
const mongoSanitize = require('express-mongo-sanitize');
const { isEnabled, ErrorController } = require('@librechat/api');
const { connectDb, indexSync } = require('~/db');
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
const createValidateImageRequest = require('./middleware/validateImageRequest');
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
const { updateInterfacePermissions } = require('~/models/interface');
@@ -155,7 +154,7 @@ const startServer = async () => {
res.send(updatedIndexHtml);
});
app.listen(port, host, async () => {
app.listen(port, host, () => {
if (host === '0.0.0.0') {
logger.info(
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
@@ -164,9 +163,7 @@ const startServer = async () => {
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
}
await initializeMCPs();
await initializeOAuthReconnectManager();
await checkMigrations();
initializeMCPs().then(() => checkMigrations());
});
};

View File

@@ -1,7 +1,7 @@
const { logger } = require('@librechat/data-schemas');
const { PermissionBits, hasPermissions, ResourceType } = require('librechat-data-provider');
const { getEffectivePermissions } = require('~/server/services/PermissionService');
const { getAgents } = require('~/models/Agent');
const { getAgent } = require('~/models/Agent');
const { getFiles } = require('~/models/File');
/**
@@ -10,12 +10,11 @@ const { getFiles } = require('~/models/File');
*/
const checkAgentBasedFileAccess = async ({ userId, role, fileId }) => {
try {
/** Agents that have this file in their tool_resources */
const agentsWithFile = await getAgents({
// Find agents that have this file in their tool_resources
const agentsWithFile = await getAgent({
$or: [
{ 'tool_resources.execute_code.file_ids': fileId },
{ 'tool_resources.file_search.file_ids': fileId },
{ 'tool_resources.context.file_ids': fileId },
{ 'tool_resources.execute_code.file_ids': fileId },
{ 'tool_resources.ocr.file_ids': fileId },
],
});
@@ -25,7 +24,7 @@ const checkAgentBasedFileAccess = async ({ userId, role, fileId }) => {
}
// Check if user has access to any of these agents
for (const agent of agentsWithFile) {
for (const agent of Array.isArray(agentsWithFile) ? agentsWithFile : [agentsWithFile]) {
// Check if user is the agent author
if (agent.author && agent.author.toString() === userId) {
logger.debug(`[fileAccess] User is author of agent ${agent.id}`);
@@ -84,6 +83,7 @@ const fileAccess = async (req, res, next) => {
});
}
// Get the file
const [file] = await getFiles({ file_id: fileId });
if (!file) {
return res.status(404).json({
@@ -92,18 +92,20 @@ const fileAccess = async (req, res, next) => {
});
}
// Check if user owns the file
if (file.user && file.user.toString() === userId) {
req.fileAccess = { file };
return next();
}
/** Agent-based access (file inherits agent permissions) */
// Check agent-based access (file inherits agent permissions)
const hasAgentAccess = await checkAgentBasedFileAccess({ userId, role: userRole, fileId });
if (hasAgentAccess) {
req.fileAccess = { file };
return next();
}
// No access
logger.warn(`[fileAccess] User ${userId} denied access to file ${fileId}`);
return res.status(403).json({
error: 'Forbidden',

View File

@@ -1,483 +0,0 @@
const mongoose = require('mongoose');
const { ResourceType, PrincipalType, PrincipalModel } = require('librechat-data-provider');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { fileAccess } = require('./fileAccess');
const { User, Role, AclEntry } = require('~/db/models');
const { createAgent } = require('~/models/Agent');
const { createFile } = require('~/models/File');
describe('fileAccess middleware', () => {
let mongoServer;
let req, res, next;
let testUser, otherUser, thirdUser;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await mongoose.connection.dropDatabase();
// Create test role
await Role.create({
name: 'test-role',
permissions: {
AGENTS: {
USE: true,
CREATE: true,
SHARED_GLOBAL: false,
},
},
});
// Create test users
testUser = await User.create({
email: 'test@example.com',
name: 'Test User',
username: 'testuser',
role: 'test-role',
});
otherUser = await User.create({
email: 'other@example.com',
name: 'Other User',
username: 'otheruser',
role: 'test-role',
});
thirdUser = await User.create({
email: 'third@example.com',
name: 'Third User',
username: 'thirduser',
role: 'test-role',
});
// Setup request/response objects
req = {
user: { id: testUser._id.toString(), role: testUser.role },
params: {},
};
res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
};
next = jest.fn();
jest.clearAllMocks();
});
describe('basic file access', () => {
test('should allow access when user owns the file', async () => {
// Create a file owned by testUser
await createFile({
user: testUser._id.toString(),
file_id: 'file_owned_by_user',
filepath: '/test/file.txt',
filename: 'file.txt',
type: 'text/plain',
size: 100,
});
req.params.file_id = 'file_owned_by_user';
await fileAccess(req, res, next);
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
expect(req.fileAccess.file).toBeDefined();
expect(res.status).not.toHaveBeenCalled();
});
test('should deny access when user does not own the file and no agent access', async () => {
// Create a file owned by otherUser
await createFile({
user: otherUser._id.toString(),
file_id: 'file_owned_by_other',
filepath: '/test/file.txt',
filename: 'file.txt',
type: 'text/plain',
size: 100,
});
req.params.file_id = 'file_owned_by_other';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
expect(res.json).toHaveBeenCalledWith({
error: 'Forbidden',
message: 'Insufficient permissions to access this file',
});
});
test('should return 404 when file does not exist', async () => {
req.params.file_id = 'non_existent_file';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(404);
expect(res.json).toHaveBeenCalledWith({
error: 'Not Found',
message: 'File not found',
});
});
test('should return 400 when file_id is missing', async () => {
// Don't set file_id in params
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(400);
expect(res.json).toHaveBeenCalledWith({
error: 'Bad Request',
message: 'file_id is required',
});
});
test('should return 401 when user is not authenticated', async () => {
req.user = null;
req.params.file_id = 'some_file';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(401);
expect(res.json).toHaveBeenCalledWith({
error: 'Unauthorized',
message: 'Authentication required',
});
});
});
describe('agent-based file access', () => {
beforeEach(async () => {
// Create a file owned by otherUser (not testUser)
await createFile({
user: otherUser._id.toString(),
file_id: 'shared_file_via_agent',
filepath: '/test/shared.txt',
filename: 'shared.txt',
type: 'text/plain',
size: 100,
});
});
test('should allow access when user is author of agent with file', async () => {
// Create agent owned by testUser with the file
await createAgent({
id: `agent_${Date.now()}`,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
tool_resources: {
file_search: {
file_ids: ['shared_file_via_agent'],
},
},
});
req.params.file_id = 'shared_file_via_agent';
await fileAccess(req, res, next);
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
expect(req.fileAccess.file).toBeDefined();
});
test('should allow access when user has VIEW permission on agent with file', async () => {
// Create agent owned by otherUser
const agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Shared Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
tool_resources: {
execute_code: {
file_ids: ['shared_file_via_agent'],
},
},
});
// Grant VIEW permission to testUser
await AclEntry.create({
principalType: PrincipalType.USER,
principalId: testUser._id,
principalModel: PrincipalModel.USER,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
permBits: 1, // VIEW permission
grantedBy: otherUser._id,
});
req.params.file_id = 'shared_file_via_agent';
await fileAccess(req, res, next);
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
});
test('should check file in ocr tool_resources', async () => {
await createAgent({
id: `agent_ocr_${Date.now()}`,
name: 'OCR Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
tool_resources: {
ocr: {
file_ids: ['shared_file_via_agent'],
},
},
});
req.params.file_id = 'shared_file_via_agent';
await fileAccess(req, res, next);
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
});
test('should deny access when user has no permission on agent with file', async () => {
// Create agent owned by otherUser without granting permission to testUser
const agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
tool_resources: {
file_search: {
file_ids: ['shared_file_via_agent'],
},
},
});
// Create ACL entry for otherUser only (owner)
await AclEntry.create({
principalType: PrincipalType.USER,
principalId: otherUser._id,
principalModel: PrincipalModel.USER,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
permBits: 15, // All permissions
grantedBy: otherUser._id,
});
req.params.file_id = 'shared_file_via_agent';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
});
});
describe('multiple agents with same file', () => {
/**
* This test suite verifies that when multiple agents have the same file,
* all agents are checked for permissions, not just the first one found.
* This ensures users can access files through any agent they have permission for.
*/
test('should check ALL agents with file, not just first one', async () => {
// Create a file owned by someone else
await createFile({
user: otherUser._id.toString(),
file_id: 'multi_agent_file',
filepath: '/test/multi.txt',
filename: 'multi.txt',
type: 'text/plain',
size: 100,
});
// Create first agent (owned by otherUser, no access for testUser)
const agent1 = await createAgent({
id: 'agent_no_access',
name: 'No Access Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
tool_resources: {
file_search: {
file_ids: ['multi_agent_file'],
},
},
});
// Create ACL for agent1 - only otherUser has access
await AclEntry.create({
principalType: PrincipalType.USER,
principalId: otherUser._id,
principalModel: PrincipalModel.USER,
resourceType: ResourceType.AGENT,
resourceId: agent1._id,
permBits: 15,
grantedBy: otherUser._id,
});
// Create second agent (owned by thirdUser, but testUser has VIEW access)
const agent2 = await createAgent({
id: 'agent_with_access',
name: 'Accessible Agent',
provider: 'openai',
model: 'gpt-4',
author: thirdUser._id,
tool_resources: {
file_search: {
file_ids: ['multi_agent_file'],
},
},
});
// Grant testUser VIEW access to agent2
await AclEntry.create({
principalType: PrincipalType.USER,
principalId: testUser._id,
principalModel: PrincipalModel.USER,
resourceType: ResourceType.AGENT,
resourceId: agent2._id,
permBits: 1, // VIEW permission
grantedBy: thirdUser._id,
});
req.params.file_id = 'multi_agent_file';
await fileAccess(req, res, next);
/**
* Should succeed because testUser has access to agent2,
* even though they don't have access to agent1.
* The fix ensures all agents are checked, not just the first one.
*/
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
expect(res.status).not.toHaveBeenCalled();
});
test('should find file in any agent tool_resources type', async () => {
// Create a file
await createFile({
user: otherUser._id.toString(),
file_id: 'multi_tool_file',
filepath: '/test/tool.txt',
filename: 'tool.txt',
type: 'text/plain',
size: 100,
});
// Agent 1: file in file_search (no access for testUser)
await createAgent({
id: 'agent_file_search',
name: 'File Search Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
tool_resources: {
file_search: {
file_ids: ['multi_tool_file'],
},
},
});
// Agent 2: same file in execute_code (testUser has access)
await createAgent({
id: 'agent_execute_code',
name: 'Execute Code Agent',
provider: 'openai',
model: 'gpt-4',
author: thirdUser._id,
tool_resources: {
execute_code: {
file_ids: ['multi_tool_file'],
},
},
});
// Agent 3: same file in ocr (testUser also has access)
await createAgent({
id: 'agent_ocr',
name: 'OCR Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id, // testUser owns this one
tool_resources: {
ocr: {
file_ids: ['multi_tool_file'],
},
},
});
req.params.file_id = 'multi_tool_file';
await fileAccess(req, res, next);
/**
* Should succeed because testUser owns agent3,
* even if other agents with the file are found first.
*/
expect(next).toHaveBeenCalled();
expect(req.fileAccess).toBeDefined();
});
});
describe('edge cases', () => {
test('should handle agent with empty tool_resources', async () => {
await createFile({
user: otherUser._id.toString(),
file_id: 'orphan_file',
filepath: '/test/orphan.txt',
filename: 'orphan.txt',
type: 'text/plain',
size: 100,
});
// Create agent with no files in tool_resources
await createAgent({
id: `agent_empty_${Date.now()}`,
name: 'Empty Resources Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
tool_resources: {},
});
req.params.file_id = 'orphan_file';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
});
test('should handle agent with null tool_resources', async () => {
await createFile({
user: otherUser._id.toString(),
file_id: 'another_orphan_file',
filepath: '/test/orphan2.txt',
filename: 'orphan2.txt',
type: 'text/plain',
size: 100,
});
// Create agent with null tool_resources
await createAgent({
id: `agent_null_${Date.now()}`,
name: 'Null Resources Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
tool_resources: null,
});
req.params.file_id = 'another_orphan_file';
await fileAccess(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
});
});
});

View File

@@ -0,0 +1,680 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
jest.mock('@librechat/data-schemas', () => ({
logger: {
debug: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
},
createMethods: jest.fn(() => ({})),
createModels: jest.fn(() => ({})),
}));
jest.mock('~/server/middleware', () => ({
requireJwtAuth: (req, res, next) => next(),
validateMessageReq: (req, res, next) => next(),
}));
jest.mock('~/models', () => ({
getConvo: jest.fn(),
saveConvo: jest.fn(),
saveMessage: jest.fn(),
getMessage: jest.fn(),
getMessages: jest.fn(),
updateMessage: jest.fn(),
deleteMessages: jest.fn(),
}));
jest.mock('~/db/models', () => {
let User, Message, Transaction, Conversation;
return {
get User() {
return User;
},
get Message() {
return Message;
},
get Transaction() {
return Transaction;
},
get Conversation() {
return Conversation;
},
setUser: (model) => {
User = model;
},
setMessage: (model) => {
Message = model;
},
setTransaction: (model) => {
Transaction = model;
},
setConversation: (model) => {
Conversation = model;
},
};
});
describe('Costs Endpoint', () => {
let app;
let mongoServer;
let messagesRouter;
let User, Message, Transaction, Conversation;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
const userSchema = new mongoose.Schema({
_id: String,
name: String,
email: String,
});
const conversationSchema = new mongoose.Schema({
conversationId: String,
user: String,
title: String,
createdAt: Date,
});
const messageSchema = new mongoose.Schema({
messageId: String,
conversationId: String,
user: String,
isCreatedByUser: Boolean,
tokenCount: Number,
createdAt: Date,
});
const transactionSchema = new mongoose.Schema({
conversationId: String,
user: String,
tokenType: String,
tokenValue: Number,
createdAt: Date,
});
User = mongoose.model('User', userSchema);
Conversation = mongoose.model('Conversation', conversationSchema);
Message = mongoose.model('Message', messageSchema);
Transaction = mongoose.model('Transaction', transactionSchema);
const dbModels = require('~/db/models');
dbModels.setUser(User);
dbModels.setMessage(Message);
dbModels.setTransaction(Transaction);
dbModels.setConversation(Conversation);
require('~/db/models');
try {
messagesRouter = require('../messages');
} catch (error) {
console.error('Error loading messages router:', error);
throw error;
}
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = { id: 'test-user-id' };
next();
});
app.use('/api/messages', messagesRouter);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await User.deleteMany({});
await Conversation.deleteMany({});
await Message.deleteMany({});
await Transaction.deleteMany({});
});
describe('GET /:conversationId/costs', () => {
const conversationId = 'test-conversation-123';
const userId = 'test-user-id';
it('should return cost data for valid conversation', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const aiMessage = new Message({
messageId: 'ai-msg-1',
conversationId,
user: userId,
isCreatedByUser: false,
tokenCount: 150,
createdAt: new Date('2024-01-01T10:01:00Z'),
});
await Promise.all([userMessage.save(), aiMessage.save()]);
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const completionTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'completion',
tokenValue: 750000,
createdAt: new Date('2024-01-01T10:01:30Z'),
});
await Promise.all([promptTransaction.save(), completionTransaction.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body).toMatchObject({
conversationId,
totals: {
prompt: { usd: 0.5, tokenCount: 100 },
completion: { usd: 0.75, tokenCount: 150 },
total: { usd: 1.25, tokenCount: 250 },
},
perMessage: [
{ messageId: 'user-msg-1', tokenType: 'prompt', tokenCount: 100, usd: 0.5 },
{ messageId: 'ai-msg-1', tokenType: 'completion', tokenCount: 150, usd: 0.75 },
],
});
});
it('should return empty data for conversation with no messages', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body).toMatchObject({
conversationId,
totals: {
prompt: { usd: 0, tokenCount: 0 },
completion: { usd: 0, tokenCount: 0 },
total: { usd: 0, tokenCount: 0 },
},
perMessage: [],
});
});
it('should handle messages without transactions', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const aiMessage = new Message({
messageId: 'ai-msg-1',
conversationId,
user: userId,
isCreatedByUser: false,
tokenCount: 150,
createdAt: new Date('2024-01-01T10:01:00Z'),
});
await Promise.all([userMessage.save(), aiMessage.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0);
expect(response.body.totals.completion.usd).toBe(0);
expect(response.body.totals.total.usd).toBe(0);
});
it('should aggregate multiple transactions correctly', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction1 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 300000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const promptTransaction2 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 200000,
createdAt: new Date('2024-01-01T10:00:45Z'),
});
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
expect(response.body.perMessage[0].usd).toBe(0.5);
});
it('should handle null tokenCount values', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: null,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.tokenCount).toBe(0);
});
it('should handle null tokenValue in transactions', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: null,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await promptTransaction.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0);
});
it('should handle negative tokenValue using Math.abs', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: -500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await promptTransaction.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
});
it('should filter by user correctly', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const otherUserId = 'other-user-id';
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const otherUserMessage = new Message({
messageId: 'other-user-msg-1',
conversationId,
user: otherUserId,
isCreatedByUser: true,
tokenCount: 200,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await Promise.all([userMessage.save(), otherUserMessage.save()]);
const userTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const otherUserTransaction = new Transaction({
conversationId,
user: otherUserId,
tokenType: 'prompt',
tokenValue: 1000000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await Promise.all([userTransaction.save(), otherUserTransaction.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
expect(response.body.perMessage).toHaveLength(1);
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
});
it('should filter transactions by tokenType', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const otherTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'other',
tokenValue: 1000000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await Promise.all([promptTransaction.save(), otherTransaction.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
expect(response.body.totals.completion.usd).toBe(0);
expect(response.body.totals.total.usd).toBe(0.5);
});
it('should map transactions to messages chronologically', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage1 = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const userMessage2 = new Message({
messageId: 'user-msg-2',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 200,
createdAt: new Date('2024-01-01T10:01:00Z'),
});
await Promise.all([userMessage1.save(), userMessage2.save()]);
const promptTransaction1 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const promptTransaction2 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 1000000,
createdAt: new Date('2024-01-01T10:01:30Z'),
});
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.perMessage).toHaveLength(2);
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
expect(response.body.perMessage[0].usd).toBe(0.5);
expect(response.body.perMessage[1].messageId).toBe('user-msg-2');
expect(response.body.perMessage[1].usd).toBe(1.0);
});
it('should handle database errors', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
await mongoose.connection.close();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(500);
expect(response.body).toHaveProperty('error');
});
});
});

View File

@@ -11,9 +11,6 @@ jest.mock('@librechat/api', () => ({
completeOAuthFlow: jest.fn(),
generateFlowId: jest.fn(),
},
MCPTokenStorage: {
storeTokens: jest.fn(),
},
getUserMCPAuthMap: jest.fn(),
}));
@@ -50,8 +47,8 @@ jest.mock('~/server/services/Config', () => ({
loadCustomConfig: jest.fn(),
}));
jest.mock('~/server/services/Config/mcp', () => ({
updateMCPServerTools: jest.fn(),
jest.mock('~/server/services/Config/mcpToolsCache', () => ({
updateMCPUserTools: jest.fn(),
}));
jest.mock('~/server/services/MCP', () => ({
@@ -237,7 +234,7 @@ describe('MCP Routes', () => {
});
describe('GET /:serverName/oauth/callback', () => {
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const { MCPOAuthHandler } = require('@librechat/api');
const { getLogStores } = require('~/cache');
it('should redirect to error page when OAuth error is received', async () => {
@@ -283,7 +280,6 @@ describe('MCP Routes', () => {
it('should handle OAuth callback successfully', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
@@ -299,7 +295,6 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@@ -337,24 +332,11 @@ describe('MCP Routes', () => {
'test-auth-code',
mockFlowManager,
);
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
expect.objectContaining({
userId: 'test-user-id',
serverName: 'test-server',
tokens: mockTokens,
clientInfo: mockFlowState.clientInfo,
metadata: mockFlowState.metadata,
}),
);
const storeInvocation = MCPTokenStorage.storeTokens.mock.invocationCallOrder[0];
const connectInvocation = mockMcpManager.getUserConnection.mock.invocationCallOrder[0];
expect(storeInvocation).toBeLessThan(connectInvocation);
expect(mockFlowManager.completeFlow).toHaveBeenCalledWith(
'tool-flow-123',
'mcp_oauth',
mockTokens,
);
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
it('should redirect to error page when callback processing fails', async () => {
@@ -372,7 +354,6 @@ describe('MCP Routes', () => {
it('should handle system-level OAuth completion', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
@@ -388,7 +369,6 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@@ -399,13 +379,11 @@ describe('MCP Routes', () => {
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
it('should handle reconnection failure after OAuth', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
@@ -421,7 +399,6 @@ describe('MCP Routes', () => {
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
@@ -441,46 +418,6 @@ describe('MCP Routes', () => {
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
});
it('should redirect to error page if token storage fails', async () => {
const mockFlowManager = {
completeFlow: jest.fn().mockResolvedValue(),
deleteFlow: jest.fn().mockResolvedValue(true),
};
const mockFlowState = {
serverName: 'test-server',
userId: 'test-user-id',
metadata: { toolFlowId: 'tool-flow-123' },
clientInfo: {},
codeVerifier: 'test-verifier',
};
const mockTokens = {
access_token: 'test-access-token',
refresh_token: 'test-refresh-token',
};
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed'));
getLogStores.mockReturnValue({});
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
const mockMcpManager = {
getUserConnection: jest.fn(),
};
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
code: 'test-auth-code',
state: 'test-flow-id',
});
expect(response.status).toBe(302);
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled();
});
});
@@ -841,10 +778,10 @@ describe('MCP Routes', () => {
require('~/cache').getLogStores.mockReturnValue({});
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
getCachedTools.mockResolvedValue({});
setCachedTools.mockResolvedValue();
updateMCPServerTools.mockResolvedValue();
updateMCPUserTools.mockResolvedValue();
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
@@ -899,10 +836,10 @@ describe('MCP Routes', () => {
]);
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
getCachedTools.mockResolvedValue({});
setCachedTools.mockResolvedValue();
updateMCPServerTools.mockResolvedValue();
updateMCPUserTools.mockResolvedValue();
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
success: true,
@@ -1206,11 +1143,7 @@ describe('MCP Routes', () => {
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const mockTokens = {
access_token: 'edge-access-token',
refresh_token: 'edge-refresh-token',
};
const { MCPOAuthHandler } = require('@librechat/api');
MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({
id: 'test-flow-id',
userId: 'test-user-id',
@@ -1222,8 +1155,6 @@ describe('MCP Routes', () => {
clientInfo: {},
codeVerifier: 'test-verifier',
});
MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
const mockFlowManager = {
completeFlow: jest.fn(),
@@ -1248,11 +1179,6 @@ describe('MCP Routes', () => {
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
const { getCachedTools } = require('~/server/services/Config');
getCachedTools.mockResolvedValue(null);
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
const mockTokens = {
access_token: 'edge-access-token',
refresh_token: 'edge-refresh-token',
};
const mockFlowManager = {
getFlowState: jest.fn().mockResolvedValue({
@@ -1265,15 +1191,6 @@ describe('MCP Routes', () => {
completeFlow: jest.fn(),
};
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
MCPOAuthHandler.getFlowState.mockResolvedValue({
serverName: 'test-server',
userId: 'test-user-id',
metadata: { serverUrl: 'https://example.com', oauth: {} },
clientInfo: {},
codeVerifier: 'test-verifier',
});
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
MCPTokenStorage.storeTokens.mockResolvedValue();
const mockMcpManager = {
getUserConnection: jest.fn().mockResolvedValue({

View File

@@ -1,28 +1,19 @@
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { MCPOAuthHandler, MCPTokenStorage, getUserMCPAuthMap } = require('@librechat/api');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
const { getMCPTools } = require('~/server/controllers/mcp');
const { requireJwtAuth } = require('~/server/middleware');
const { findPluginAuthsByKeys } = require('~/models');
const { getLogStores } = require('~/cache');
const router = Router();
/**
* Get all MCP tools available to the user
* Returns only MCP tools, completely decoupled from regular LibreChat tools
*/
router.get('/tools', requireJwtAuth, async (req, res) => {
return getMCPTools(req, res);
});
/**
* Initiate OAuth flow
* This endpoint is called when the user clicks the auth link in the UI
@@ -130,41 +121,6 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
/** Persist tokens immediately so reconnection uses fresh credentials */
if (flowState?.userId && tokens) {
try {
await MCPTokenStorage.storeTokens({
userId: flowState.userId,
serverName,
tokens,
createToken,
updateToken,
findToken,
clientInfo: flowState.clientInfo,
metadata: flowState.metadata,
});
logger.debug('[MCP OAuth] Stored OAuth tokens prior to reconnection', {
serverName,
userId: flowState.userId,
});
} catch (error) {
logger.error('[MCP OAuth] Failed to store OAuth tokens after callback', error);
throw error;
}
/**
* Clear any cached `mcp_get_tokens` flow result so subsequent lookups
* re-fetch the freshly stored credentials instead of returning stale nulls.
*/
if (typeof flowManager?.deleteFlow === 'function') {
try {
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
} catch (error) {
logger.warn('[MCP OAuth] Failed to clear cached token flow state', error);
}
}
}
try {
const mcpManager = getMCPManager(flowState.userId);
logger.debug(`[MCP OAuth] Attempting to reconnect ${serverName} with new OAuth tokens`);
@@ -188,12 +144,9 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
);
// clear any reconnection attempts
const oauthReconnectionManager = getOAuthReconnectionManager();
oauthReconnectionManager.clearReconnection(flowState.userId, serverName);
const tools = await userConnection.fetchTools();
await updateMCPServerTools({
await updateMCPUserTools({
userId: flowState.userId,
serverName,
tools,
});
@@ -335,9 +288,9 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
try {
const { serverName } = req.params;
const userId = req.user?.id;
const user = req.user;
if (!userId) {
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
@@ -351,7 +304,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
});
}
await mcpManager.disconnectUserConnection(userId, serverName);
await mcpManager.disconnectUserConnection(user.id, serverName);
logger.info(
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
);
@@ -360,14 +313,14 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
let userMCPAuthMap;
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
userMCPAuthMap = await getUserMCPAuthMap({
userId,
userId: user.id,
servers: [serverName],
findPluginAuthsByKeys,
});
}
const result = await reinitMCPServer({
userId,
req,
serverName,
userMCPAuthMap,
});

View File

@@ -11,6 +11,7 @@ const {
} = require('~/models');
const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/Artifacts/update');
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
const { tokenValues, getValueKey, defaultRate } = require('~/models/tx');
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
const { getConvosQueried } = require('~/models/Conversation');
const { countTokens } = require('~/server/utils');
@@ -160,6 +161,41 @@ router.post('/artifact/:messageId', async (req, res) => {
}
});
/**
* POST /costs
* Get cost information for models in modelHistory array
*/
router.post('/costs', async (req, res) => {
try {
const { modelHistory } = req.body;
if (!Array.isArray(modelHistory)) {
return res.status(400).json({ error: 'modelHistory must be an array' });
}
const modelCostTable = {};
modelHistory.forEach((modelEntry) => {
if (modelEntry && typeof modelEntry === 'object' && modelEntry.model && modelEntry.endpoint) {
const { model, endpoint } = modelEntry;
const valueKey = getValueKey(model, endpoint);
const pricing = tokenValues[valueKey];
modelCostTable[model] = {
prompt: pricing?.prompt ?? defaultRate,
completion: pricing?.completion ?? defaultRate,
};
}
});
res.status(200).json({ modelCostTable });
} catch (error) {
logger.error('Error fetching model costs:', error);
res.status(500).json({ error: 'Internal server error' });
}
});
/* Note: It's necessary to add `validateMessageReq` within route definition for correct params */
router.get('/:conversationId', validateMessageReq, async (req, res) => {
try {

View File

@@ -357,18 +357,23 @@ const resetPassword = async (userId, token, password) => {
/**
* Set Auth Tokens
*
* @param {String | ObjectId} userId
* @param {ServerResponse} res
* @param {ISession | null} [session=null]
* @param {Object} res
* @param {String} sessionId
* @returns
*/
const setAuthTokens = async (userId, res, _session = null) => {
const setAuthTokens = async (userId, res, sessionId = null) => {
try {
let session = _session;
const user = await getUserById(userId);
const token = await generateToken(user);
let session;
let refreshToken;
let refreshTokenExpires;
if (session && session._id && session.expiration != null) {
if (sessionId) {
session = await findSession({ sessionId: sessionId }, { lean: false });
refreshTokenExpires = session.expiration.getTime();
refreshToken = await generateRefreshToken(session);
} else {
@@ -378,9 +383,6 @@ const setAuthTokens = async (userId, res, _session = null) => {
refreshTokenExpires = session.expiration.getTime();
}
const user = await getUserById(userId);
const token = await generateToken(user);
res.cookie('refreshToken', refreshToken, {
expires: new Date(refreshTokenExpires),
httpOnly: true,

View File

@@ -36,7 +36,7 @@ async function getAppConfig(options = {}) {
}
if (baseConfig.availableTools) {
await setCachedTools(baseConfig.availableTools);
await setCachedTools(baseConfig.availableTools, { isGlobal: true });
}
await cache.set(BASE_CONFIG_KEY, baseConfig);

View File

@@ -3,32 +3,89 @@ const getLogStores = require('~/cache/getLogStores');
/**
* Cache key generators for different tool access patterns
* These will support future permission-based caching
*/
const ToolCacheKeys = {
/** Global tools available to all users */
GLOBAL: 'tools:global',
/** MCP tools cached by server name */
MCP_SERVER: (serverName) => `tools:mcp:${serverName}`,
/** Tools available to a specific user */
USER: (userId) => `tools:user:${userId}`,
/** Tools available to a specific role */
ROLE: (roleId) => `tools:role:${roleId}`,
/** Tools available to a specific group */
GROUP: (groupId) => `tools:group:${groupId}`,
/** Combined effective tools for a user (computed from all sources) */
EFFECTIVE: (userId) => `tools:effective:${userId}`,
};
/**
* Retrieves available tools from cache
* @function getCachedTools
* @param {Object} options - Options for retrieving tools
* @param {string} [options.serverName] - MCP server name to get cached tools for
* @param {string} [options.userId] - User ID for user-specific tools
* @param {string[]} [options.roleIds] - Role IDs for role-based tools
* @param {string[]} [options.groupIds] - Group IDs for group-based tools
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
*/
async function getCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName } = options;
const { userId, roleIds = [], groupIds = [], includeGlobal = true } = options;
// Return MCP server-specific tools if requested
if (serverName) {
return await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
// For now, return global tools (current behavior)
// This will be expanded to merge tools from different sources
if (!userId && includeGlobal) {
return await cache.get(ToolCacheKeys.GLOBAL);
}
// Default to global tools
return await cache.get(ToolCacheKeys.GLOBAL);
// Future implementation will merge tools from multiple sources
// based on user permissions, roles, and groups
if (userId) {
/** @type {LCAvailableTools | null} Check if we have pre-computed effective tools for this user */
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
if (effectiveTools) {
return effectiveTools;
}
/** @type {LCAvailableTools | null} Otherwise, compute from individual sources */
const toolSources = [];
if (includeGlobal) {
const globalTools = await cache.get(ToolCacheKeys.GLOBAL);
if (globalTools) {
toolSources.push(globalTools);
}
}
// User-specific tools
const userTools = await cache.get(ToolCacheKeys.USER(userId));
if (userTools) {
toolSources.push(userTools);
}
// Role-based tools
for (const roleId of roleIds) {
const roleTools = await cache.get(ToolCacheKeys.ROLE(roleId));
if (roleTools) {
toolSources.push(roleTools);
}
}
// Group-based tools
for (const groupId of groupIds) {
const groupTools = await cache.get(ToolCacheKeys.GROUP(groupId));
if (groupTools) {
toolSources.push(groupTools);
}
}
// Merge all tool sources (for now, simple merge - future will handle conflicts)
if (toolSources.length > 0) {
return mergeToolSources(toolSources);
}
}
return null;
}
/**
@@ -36,34 +93,49 @@ async function getCachedTools(options = {}) {
* @function setCachedTools
* @param {Object} tools - The tools object to cache
* @param {Object} options - Options for caching tools
* @param {string} [options.serverName] - MCP server name for server-specific tools
* @param {string} [options.userId] - User ID for user-specific tools
* @param {string} [options.roleId] - Role ID for role-based tools
* @param {string} [options.groupId] - Group ID for group-based tools
* @param {boolean} [options.isGlobal=false] - Whether these are global tools
* @param {number} [options.ttl] - Time to live in milliseconds
* @returns {Promise<boolean>} Whether the operation was successful
*/
async function setCachedTools(tools, options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, ttl } = options;
const { userId, roleId, groupId, isGlobal = false, ttl } = options;
// Cache by MCP server if specified
if (serverName) {
return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl);
let cacheKey;
if (isGlobal || (!userId && !roleId && !groupId)) {
cacheKey = ToolCacheKeys.GLOBAL;
} else if (userId) {
cacheKey = ToolCacheKeys.USER(userId);
} else if (roleId) {
cacheKey = ToolCacheKeys.ROLE(roleId);
} else if (groupId) {
cacheKey = ToolCacheKeys.GROUP(groupId);
}
// Default to global cache
return await cache.set(ToolCacheKeys.GLOBAL, tools, ttl);
if (!cacheKey) {
throw new Error('Invalid cache key options provided');
}
return await cache.set(cacheKey, tools, ttl);
}
/**
* Invalidates cached tools
* @function invalidateCachedTools
* @param {Object} options - Options for invalidating tools
* @param {string} [options.serverName] - MCP server name to invalidate
* @param {string} [options.userId] - User ID to invalidate
* @param {string} [options.roleId] - Role ID to invalidate
* @param {string} [options.groupId] - Group ID to invalidate
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
* @param {boolean} [options.invalidateEffective=true] - Whether to invalidate effective tools
* @returns {Promise<void>}
*/
async function invalidateCachedTools(options = {}) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const { serverName, invalidateGlobal = false } = options;
const { userId, roleId, groupId, invalidateGlobal = false, invalidateEffective = true } = options;
const keysToDelete = [];
@@ -71,34 +143,116 @@ async function invalidateCachedTools(options = {}) {
keysToDelete.push(ToolCacheKeys.GLOBAL);
}
if (serverName) {
keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName));
if (userId) {
keysToDelete.push(ToolCacheKeys.USER(userId));
if (invalidateEffective) {
keysToDelete.push(ToolCacheKeys.EFFECTIVE(userId));
}
}
if (roleId) {
keysToDelete.push(ToolCacheKeys.ROLE(roleId));
// TODO: In future, invalidate all users with this role
}
if (groupId) {
keysToDelete.push(ToolCacheKeys.GROUP(groupId));
// TODO: In future, invalidate all users in this group
}
await Promise.all(keysToDelete.map((key) => cache.delete(key)));
}
/**
* Gets MCP tools for a specific server from cache or merges with global tools
* @function getMCPServerTools
* @param {string} serverName - The MCP server name
* @returns {Promise<LCAvailableTools|null>} The available tools for the server
* Computes and caches effective tools for a user
* @function computeEffectiveTools
* @param {string} userId - The user ID
* @param {Object} context - Context containing user's roles and groups
* @param {string[]} [context.roleIds=[]] - User's role IDs
* @param {string[]} [context.groupIds=[]] - User's group IDs
* @param {number} [ttl] - Time to live for the computed result
* @returns {Promise<Object>} The computed effective tools
*/
async function getMCPServerTools(serverName) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
async function computeEffectiveTools(userId, context = {}, ttl) {
const { roleIds = [], groupIds = [] } = context;
if (serverTools) {
return serverTools;
// Get all tool sources
const tools = await getCachedTools({
userId,
roleIds,
groupIds,
includeGlobal: true,
});
if (tools) {
// Cache the computed result
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.set(ToolCacheKeys.EFFECTIVE(userId), tools, ttl);
}
return null;
return tools;
}
/**
* Merges multiple tool sources into a single tools object
* @function mergeToolSources
* @param {Object[]} sources - Array of tool objects to merge
* @returns {Object} Merged tools object
*/
function mergeToolSources(sources) {
// For now, simple merge that combines all tools
// Future implementation will handle:
// - Permission precedence (deny > allow)
// - Tool property conflicts
// - Metadata merging
const merged = {};
for (const source of sources) {
if (!source || typeof source !== 'object') {
continue;
}
for (const [toolId, toolConfig] of Object.entries(source)) {
// Simple last-write-wins for now
// Future: merge based on permission levels
merged[toolId] = toolConfig;
}
}
return merged;
}
/**
* Middleware-friendly function to get tools for a request
* @function getToolsForRequest
* @param {Object} req - Express request object
* @returns {Promise<Object|null>} Available tools for the request
*/
async function getToolsForRequest(req) {
const userId = req.user?.id;
// For now, return global tools if no user
if (!userId) {
return getCachedTools({ includeGlobal: true });
}
// Future: Extract roles and groups from req.user
const roleIds = req.user?.roles || [];
const groupIds = req.user?.groups || [];
return getCachedTools({
userId,
roleIds,
groupIds,
includeGlobal: true,
});
}
module.exports = {
ToolCacheKeys,
getCachedTools,
setCachedTools,
getMCPServerTools,
getToolsForRequest,
invalidateCachedTools,
computeEffectiveTools,
};

View File

@@ -1,7 +1,7 @@
const appConfig = require('./app');
const mcpToolsCache = require('./mcp');
const { config } = require('./EndpointService');
const getCachedTools = require('./getCachedTools');
const mcpToolsCache = require('./mcpToolsCache');
const loadCustomConfig = require('./loadCustomConfig');
const loadConfigModels = require('./loadConfigModels');
const loadDefaultModels = require('./loadDefaultModels');

View File

@@ -1,91 +0,0 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getCachedTools, setCachedTools } = require('./getCachedTools');
const { getLogStores } = require('~/cache');
/**
* Updates MCP tools in the cache for a specific server
* @param {Object} params - Parameters for updating MCP tools
* @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<LCAvailableTools>}
*/
async function updateMCPServerTools({ serverName, tools }) {
try {
const serverTools = {};
const mcpDelimiter = Constants.mcp_delimiter;
for (const tool of tools) {
const name = `${tool.name}${mcpDelimiter}${serverName}`;
serverTools[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema,
},
};
}
await setCachedTools(serverTools, { serverName });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`);
return serverTools;
} catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
throw error;
}
}
/**
* Merges app-level tools with global tools
* @param {import('@librechat/api').LCAvailableTools} appTools
* @returns {Promise<void>}
*/
async function mergeAppTools(appTools) {
try {
const count = Object.keys(appTools).length;
if (!count) {
return;
}
const cachedTools = await getCachedTools();
const mergedTools = { ...cachedTools, ...appTools };
await setCachedTools(mergedTools);
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`Merged ${count} app-level tools`);
} catch (error) {
logger.error('Failed to merge app-level tools:', error);
throw error;
}
}
/**
* Caches MCP server tools (no longer merges with global)
* @param {object} params
* @param {string} params.serverName
* @param {import('@librechat/api').LCAvailableTools} params.serverTools
* @returns {Promise<void>}
*/
async function cacheMCPServerTools({ serverName, serverTools }) {
try {
const count = Object.keys(serverTools).length;
if (!count) {
return;
}
// Only cache server-specific tools, no merging with global
await setCachedTools(serverTools, { serverName });
logger.debug(`Cached ${count} MCP server tools for ${serverName}`);
} catch (error) {
logger.error(`Failed to cache MCP server tools for ${serverName}:`, error);
throw error;
}
}
module.exports = {
mergeAppTools,
cacheMCPServerTools,
updateMCPServerTools,
};

View File

@@ -0,0 +1,143 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getCachedTools, setCachedTools } = require('./getCachedTools');
const { getLogStores } = require('~/cache');
/**
* Updates MCP tools in the cache for a specific server and user
* @param {Object} params - Parameters for updating MCP tools
* @param {string} params.userId - User ID
* @param {string} params.serverName - MCP server name
* @param {Array} params.tools - Array of tool objects from MCP server
* @returns {Promise<LCAvailableTools>}
*/
async function updateMCPUserTools({ userId, serverName, tools }) {
try {
const userTools = await getCachedTools({ userId });
const mcpDelimiter = Constants.mcp_delimiter;
for (const key of Object.keys(userTools)) {
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
delete userTools[key];
}
}
for (const tool of tools) {
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
userTools[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema,
},
};
}
await setCachedTools(userTools, { userId });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`);
return userTools;
} catch (error) {
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
throw error;
}
}
/**
* Merges app-level tools with global tools
* @param {import('@librechat/api').LCAvailableTools} appTools
* @returns {Promise<void>}
*/
async function mergeAppTools(appTools) {
try {
const count = Object.keys(appTools).length;
if (!count) {
return;
}
const cachedTools = await getCachedTools({ includeGlobal: true });
const mergedTools = { ...cachedTools, ...appTools };
await setCachedTools(mergedTools, { isGlobal: true });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`Merged ${count} app-level tools`);
} catch (error) {
logger.error('Failed to merge app-level tools:', error);
throw error;
}
}
/**
* Merges user-level tools with global tools
* @param {object} params
* @param {string} params.userId
* @param {Record<string, FunctionTool>} params.cachedUserTools
* @param {import('@librechat/api').LCAvailableTools} params.userTools
* @returns {Promise<void>}
*/
async function mergeUserTools({ userId, cachedUserTools, userTools }) {
try {
if (!userId) {
return;
}
const count = Object.keys(userTools).length;
if (!count) {
return;
}
const cachedTools = cachedUserTools ?? (await getCachedTools({ userId }));
const mergedTools = { ...cachedTools, ...userTools };
await setCachedTools(mergedTools, { userId });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(`Merged ${count} user-level tools`);
} catch (error) {
logger.error('Failed to merge user-level tools:', error);
throw error;
}
}
/**
* Clears all MCP tools for a specific server
* @param {Object} params - Parameters for clearing MCP tools
* @param {string} [params.userId] - User ID (if clearing user-specific tools)
* @param {string} params.serverName - MCP server name
* @returns {Promise<void>}
*/
async function clearMCPServerTools({ userId, serverName }) {
try {
const tools = await getCachedTools({ userId, includeGlobal: !userId });
// Remove all tools for this server
const mcpDelimiter = Constants.mcp_delimiter;
let removedCount = 0;
for (const key of Object.keys(tools)) {
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
delete tools[key];
removedCount++;
}
}
if (removedCount > 0) {
await setCachedTools(tools, userId ? { userId } : { isGlobal: true });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug(
`[MCP Cache] Removed ${removedCount} tools for ${serverName}${userId ? ` user ${userId}` : ' (global)'}`,
);
}
} catch (error) {
logger.error(`[MCP Cache] Failed to clear tools for ${serverName}:`, error);
throw error;
}
}
module.exports = {
mergeAppTools,
mergeUserTools,
updateMCPUserTools,
clearMCPServerTools,
};

View File

@@ -552,7 +552,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
throw new Error('File search is not enabled for Agents');
}
// Note: File search processing continues to dual storage logic below
} else if (tool_resource === EToolResources.context) {
} else if (tool_resource === EToolResources.ocr) {
const { file_id, temp_file_id = null } = metadata;
/**

View File

@@ -20,10 +20,10 @@ const {
ContentTypes,
isAssistantsEndpoint,
} = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, getAppConfig } = require('./Config');
const { reinitMCPServer } = require('./Tools/mcp');
const { getAppConfig } = require('./Config');
const { getLogStores } = require('~/cache');
/**
@@ -152,8 +152,8 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
/**
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.userId - The user ID from the request object.
* @param {string} params.serverName
* @param {AbortSignal} params.signal
* @param {string} params.model
@@ -161,9 +161,9 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap }) {
async function reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }) {
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
const flowId = `${userId}:${serverName}:${Date.now()}`;
const flowId = `${req.user?.id}:${serverName}:${Date.now()}`;
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const stepId = 'step_oauth_login_' + serverName;
const toolCall = {
@@ -192,7 +192,7 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
flowManager,
});
return await reinitMCPServer({
userId,
req,
signal,
serverName,
oauthStart,
@@ -211,8 +211,8 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
* i.e. `availableTools`, and will reinitialize the MCP server to ensure all tools are generated.
*
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.userId - The user ID from the request object.
* @param {string} params.serverName
* @param {string} params.model
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
@@ -221,16 +221,8 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
*/
async function createMCPTools({
res,
userId,
index,
signal,
serverName,
provider,
userMCPAuthMap,
}) {
const result = await reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap });
async function createMCPTools({ req, res, index, signal, serverName, provider, userMCPAuthMap }) {
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
if (!result || !result.tools) {
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
return;
@@ -239,8 +231,8 @@ async function createMCPTools({
const serverTools = [];
for (const tool of result.tools) {
const toolInstance = await createMCPTool({
req,
res,
userId,
provider,
userMCPAuthMap,
availableTools: result.availableTools,
@@ -257,8 +249,8 @@ async function createMCPTools({
/**
* Creates a single tool from the specified MCP Server via `toolKey`.
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
* @param {ServerResponse} params.res - The Express response object for sending events.
* @param {string} params.userId - The user ID from the request object.
* @param {string} params.toolKey - The toolKey for the tool.
* @param {string} params.model - The model for the tool.
* @param {number} [params.index]
@@ -269,31 +261,26 @@ async function createMCPTools({
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createMCPTool({
req,
res,
userId,
index,
signal,
toolKey,
provider,
userMCPAuthMap,
availableTools,
availableTools: tools,
}) {
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const availableTools =
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
/** @type {LCTool | undefined} */
let toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) {
logger.warn(
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
);
const result = await reconnectServer({
res,
userId,
index,
signal,
serverName,
userMCPAuthMap,
});
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
toolDefinition = result?.availableTools?.[toolKey]?.function;
}
@@ -551,20 +538,13 @@ async function getServerConnectionStatus(
const baseConnectionState = getConnectionState();
let finalConnectionState = baseConnectionState;
// connection state overrides specific to OAuth servers
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
// check if server is actively being reconnected
const oauthReconnectionManager = getOAuthReconnectionManager();
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
finalConnectionState = 'connecting';
} else {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
finalConnectionState = 'connecting';
}
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
finalConnectionState = 'connecting';
}
}

View File

@@ -31,7 +31,6 @@ jest.mock('./Config', () => ({
jest.mock('~/config', () => ({
getMCPManager: jest.fn(),
getFlowStateManager: jest.fn(),
getOAuthReconnectionManager: jest.fn(),
}));
jest.mock('~/cache', () => ({
@@ -49,7 +48,6 @@ describe('tests for the new helper functions used by the MCP connection status e
let mockGetMCPManager;
let mockGetFlowStateManager;
let mockGetLogStores;
let mockGetOAuthReconnectionManager;
beforeEach(() => {
jest.clearAllMocks();
@@ -58,7 +56,6 @@ describe('tests for the new helper functions used by the MCP connection status e
mockGetMCPManager = require('~/config').getMCPManager;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
});
describe('getMCPSetupData', () => {
@@ -357,12 +354,6 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
@@ -379,12 +370,6 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return failed flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@@ -416,12 +401,6 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return active flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@@ -453,12 +432,6 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return no flow
const mockFlowManager = {
getFlowState: jest.fn(() => null),
@@ -481,35 +454,6 @@ describe('tests for the new helper functions used by the MCP connection status e
});
});
it('should return connecting state when OAuth server is reconnecting', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager to return true for isReconnecting
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => true),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: true,
connectionState: 'connecting',
});
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
mockUserId,
mockServerName,
);
});
it('should not check OAuth flow status when server is connected', async () => {
const mockFlowManager = {
getFlowState: jest.fn(),

View File

@@ -313,7 +313,7 @@ const ensurePrincipalExists = async function (principal) {
idOnTheSource: principal.idOnTheSource,
};
const userId = await createUser(userData, true, true);
const userId = await createUser(userData, true, false);
return userId.toString();
}

View File

@@ -88,6 +88,7 @@ async function saveUserMessage(req, params) {
parentMessageId: params.parentMessageId ?? Constants.NO_PARENT,
/* For messages, use the assistant_id instead of model */
model: params.assistant_id,
targetModel: params.model,
thread_id: params.thread_id,
sender: 'User',
text: params.text,

View File

@@ -74,7 +74,7 @@ async function processRequiredActions(client, requiredActions) {
requiredActions,
);
const appConfig = client.req.config;
const toolDefinitions = await getCachedTools();
const toolDefinitions = await getCachedTools({ userId: client.req.user.id, includeGlobal: true });
const seenToolkits = new Set();
const tools = requiredActions
.map((action) => {
@@ -353,12 +353,7 @@ async function processRequiredActions(client, requiredActions) {
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
if (!agent.tools || agent.tools.length === 0) {
return {};
} else if (
agent.tools &&
agent.tools.length === 1 &&
/** Legacy handling for `ocr` as may still exist in existing Agents */
(agent.tools[0] === AgentCapabilities.context || agent.tools[0] === AgentCapabilities.ocr)
) {
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
return {};
}

View File

@@ -2,12 +2,12 @@ const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { findToken, createToken, updateToken, deleteTokens } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { updateMCPServerTools } = require('~/server/services/Config');
const { updateMCPUserTools } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
/**
* @param {Object} params
* @param {string} params.userId
* @param {ServerRequest} params.req
* @param {string} params.serverName - The name of the MCP server
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
@@ -18,7 +18,7 @@ const { getLogStores } = require('~/cache');
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
*/
async function reinitMCPServer({
userId,
req,
signal,
forceNew,
serverName,
@@ -44,14 +44,14 @@ async function reinitMCPServer({
const oauthStart =
_oauthStart ??
(async (authURL) => {
logger.info(`[MCP Reinitialize] OAuth URL received for ${serverName}`);
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
oauthUrl = authURL;
oauthRequired = true;
});
try {
userConnection = await mcpManager.getUserConnection({
user: { id: userId },
user: req.user,
signal,
forceNew,
oauthStart,
@@ -97,7 +97,8 @@ async function reinitMCPServer({
if (userConnection && !oauthRequired) {
tools = await userConnection.fetchTools();
availableTools = await updateMCPServerTools({
availableTools = await updateMCPUserTools({
userId: req.user.id,
serverName,
tools,
});

View File

@@ -1,26 +0,0 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { getLogStores } = require('~/cache');
/**
* Initialize OAuth reconnect manager
*/
async function initializeOAuthReconnectManager() {
try {
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const tokenMethods = {
findToken,
updateToken,
createToken,
deleteTokens,
};
await createOAuthReconnectionManager(flowManager, tokenMethods);
logger.info(`OAuth reconnect manager initialized successfully.`);
} catch (error) {
logger.error('Failed to initialize OAuth reconnect manager:', error);
}
}
module.exports = initializeOAuthReconnectManager;

View File

@@ -10,10 +10,6 @@ jest.mock('~/models/Message', () => ({
bulkSaveMessages: jest.fn(),
}));
jest.mock('~/models/ConversationTag', () => ({
bulkIncrementTagCounts: jest.fn(),
}));
let mockIdCounter = 0;
jest.mock('uuid', () => {
return {
@@ -26,13 +22,11 @@ jest.mock('uuid', () => {
const {
forkConversation,
duplicateConversation,
splitAtTargetLevel,
getAllMessagesUpToParent,
getMessagesUpToTargetLevel,
cloneMessagesWithTimestamps,
} = require('./fork');
const { bulkIncrementTagCounts } = require('~/models/ConversationTag');
const { getConvo, bulkSaveConvos } = require('~/models/Conversation');
const { getMessages, bulkSaveMessages } = require('~/models/Message');
const { createImportBatchBuilder } = require('./importBatchBuilder');
@@ -187,120 +181,6 @@ describe('forkConversation', () => {
}),
).rejects.toThrow('Failed to fetch messages');
});
test('should increment tag counts when forking conversation with tags', async () => {
const mockConvoWithTags = {
...mockConversation,
tags: ['bookmark1', 'bookmark2'],
};
getConvo.mockResolvedValue(mockConvoWithTags);
await forkConversation({
originalConvoId: 'abc123',
targetMessageId: '3',
requestUserId: 'user1',
option: ForkOptions.DIRECT_PATH,
});
// Verify that bulkIncrementTagCounts was called with correct tags
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', ['bookmark1', 'bookmark2']);
});
test('should handle conversation without tags when forking', async () => {
const mockConvoWithoutTags = {
...mockConversation,
// No tags field
};
getConvo.mockResolvedValue(mockConvoWithoutTags);
await forkConversation({
originalConvoId: 'abc123',
targetMessageId: '3',
requestUserId: 'user1',
option: ForkOptions.DIRECT_PATH,
});
// bulkIncrementTagCounts will be called with array containing undefined
expect(bulkIncrementTagCounts).toHaveBeenCalled();
});
test('should handle empty tags array when forking', async () => {
const mockConvoWithEmptyTags = {
...mockConversation,
tags: [],
};
getConvo.mockResolvedValue(mockConvoWithEmptyTags);
await forkConversation({
originalConvoId: 'abc123',
targetMessageId: '3',
requestUserId: 'user1',
option: ForkOptions.DIRECT_PATH,
});
// bulkIncrementTagCounts will be called with empty array
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', []);
});
});
describe('duplicateConversation', () => {
beforeEach(() => {
jest.clearAllMocks();
mockIdCounter = 0;
getConvo.mockResolvedValue(mockConversation);
getMessages.mockResolvedValue(mockMessages);
bulkSaveConvos.mockResolvedValue(null);
bulkSaveMessages.mockResolvedValue(null);
bulkIncrementTagCounts.mockResolvedValue(null);
});
test('should duplicate conversation and increment tag counts', async () => {
const mockConvoWithTags = {
...mockConversation,
tags: ['important', 'work', 'project'],
};
getConvo.mockResolvedValue(mockConvoWithTags);
await duplicateConversation({
userId: 'user1',
conversationId: 'abc123',
});
// Verify that bulkIncrementTagCounts was called with correct tags
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', ['important', 'work', 'project']);
});
test('should duplicate conversation without tags', async () => {
const mockConvoWithoutTags = {
...mockConversation,
// No tags field
};
getConvo.mockResolvedValue(mockConvoWithoutTags);
await duplicateConversation({
userId: 'user1',
conversationId: 'abc123',
});
// bulkIncrementTagCounts will be called with array containing undefined
expect(bulkIncrementTagCounts).toHaveBeenCalled();
});
test('should handle empty tags array when duplicating', async () => {
const mockConvoWithEmptyTags = {
...mockConversation,
tags: [],
};
getConvo.mockResolvedValue(mockConvoWithEmptyTags);
await duplicateConversation({
userId: 'user1',
conversationId: 'abc123',
});
// bulkIncrementTagCounts will be called with empty array
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', []);
});
});
const mockMessagesComplex = [

View File

@@ -1,6 +1,5 @@
const { v4: uuidv4 } = require('uuid');
const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider');
const { bulkIncrementTagCounts } = require('~/models/ConversationTag');
const { bulkSaveConvos } = require('~/models/Conversation');
const { bulkSaveMessages } = require('~/models/Message');
const { logger } = require('~/config');
@@ -94,22 +93,13 @@ class ImportBatchBuilder {
/**
* Saves the batch of conversations and messages to the DB.
* Also increments tag counts for any existing tags.
* @returns {Promise<void>} A promise that resolves when the batch is saved.
* @throws {Error} If there is an error saving the batch.
*/
async saveBatch() {
try {
const promises = [];
promises.push(bulkSaveConvos(this.conversations));
promises.push(bulkSaveMessages(this.messages, true));
promises.push(
bulkIncrementTagCounts(
this.requestUserId,
this.conversations.flatMap((convo) => convo.tags),
),
);
await Promise.all(promises);
await bulkSaveConvos(this.conversations);
await bulkSaveMessages(this.messages, true);
logger.debug(
`user: ${this.requestUserId} | Added ${this.conversations.length} conversations and ${this.messages.length} messages to the DB.`,
);

View File

@@ -1,10 +1,11 @@
const fs = require('fs');
const path = require('path');
const axios = require('axios');
const FormData = require('form-data');
const nodemailer = require('nodemailer');
const handlebars = require('handlebars');
const { logger } = require('@librechat/data-schemas');
const { logAxiosError, isEnabled, readFileAsString } = require('@librechat/api');
const { logAxiosError, isEnabled } = require('@librechat/api');
/**
* Sends an email using Mailgun API.
@@ -92,7 +93,8 @@ const sendEmailViaSMTP = async ({ transporterOptions, mailOptions }) => {
*/
const sendEmail = async ({ email, subject, payload, template, throwError = true }) => {
try {
const { content: source } = await readFileAsString(path.join(__dirname, 'emails', template));
// Read and compile the email template
const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8');
const compiledTemplate = handlebars.compile(source);
const html = compiledTemplate(payload);

View File

@@ -109,8 +109,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
const username =
(LDAP_USERNAME && userinfo[LDAP_USERNAME]) || userinfo.givenName || userinfo.mail;
let mail = (LDAP_EMAIL && userinfo[LDAP_EMAIL]) || userinfo.mail || username + '@ldap.local';
mail = Array.isArray(mail) ? mail[0] : mail;
const mail = (LDAP_EMAIL && userinfo[LDAP_EMAIL]) || userinfo.mail || username + '@ldap.local';
if (!userinfo.mail && !(LDAP_EMAIL && userinfo[LDAP_EMAIL])) {
logger.warn(

View File

@@ -1,186 +0,0 @@
// --- Mocks ---
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
error: jest.fn(),
},
}));
jest.mock('@librechat/api', () => ({
// isEnabled used for TLS flags
isEnabled: jest.fn(() => false),
getBalanceConfig: jest.fn(() => ({ enabled: false })),
}));
jest.mock('~/models', () => ({
findUser: jest.fn(),
createUser: jest.fn(),
updateUser: jest.fn(),
countUsers: jest.fn(),
}));
jest.mock('~/server/services/Config', () => ({
getAppConfig: jest.fn().mockResolvedValue({}),
}));
jest.mock('~/server/services/domains', () => ({
isEmailDomainAllowed: jest.fn(() => true),
}));
// Mock passport-ldapauth to capture verify callback
let verifyCallback;
jest.mock('passport-ldapauth', () => {
return jest.fn().mockImplementation((options, verify) => {
verifyCallback = verify; // capture the strategy verify function
return { name: 'ldap', options, verify };
});
});
const { ErrorTypes } = require('librechat-data-provider');
const { findUser, createUser, updateUser, countUsers } = require('~/models');
const { isEmailDomainAllowed } = require('~/server/services/domains');
// Helper to call the verify callback and wrap in a Promise for convenience
const callVerify = (userinfo) =>
new Promise((resolve, reject) => {
verifyCallback(userinfo, (err, user, info) => {
if (err) return reject(err);
resolve({ user, info });
});
});
describe('ldapStrategy', () => {
beforeEach(() => {
jest.clearAllMocks();
// minimal required env for ldapStrategy module to export
process.env.LDAP_URL = 'ldap://example.com';
process.env.LDAP_USER_SEARCH_BASE = 'ou=users,dc=example,dc=com';
// Unset optional envs to exercise defaults
delete process.env.LDAP_CA_CERT_PATH;
delete process.env.LDAP_FULL_NAME;
delete process.env.LDAP_ID;
delete process.env.LDAP_USERNAME;
delete process.env.LDAP_EMAIL;
delete process.env.LDAP_TLS_REJECT_UNAUTHORIZED;
delete process.env.LDAP_STARTTLS;
// Default model/domain mocks
findUser.mockReset().mockResolvedValue(null);
createUser.mockReset().mockResolvedValue('newUserId');
updateUser.mockReset().mockImplementation(async (id, user) => ({ _id: id, ...user }));
countUsers.mockReset().mockResolvedValue(0);
isEmailDomainAllowed.mockReset().mockReturnValue(true);
// Ensure requiring the strategy sets up the verify callback
jest.isolateModules(() => {
require('./ldapStrategy');
});
});
it('uses the first email when LDAP returns multiple emails (array)', async () => {
const userinfo = {
uid: 'uid123',
givenName: 'Alice',
cn: 'Alice Doe',
mail: ['first@example.com', 'second@example.com'],
};
const { user } = await callVerify(userinfo);
expect(user.email).toBe('first@example.com');
expect(createUser).toHaveBeenCalledWith(
expect.objectContaining({
provider: 'ldap',
ldapId: 'uid123',
username: 'Alice',
email: 'first@example.com',
emailVerified: true,
name: 'Alice Doe',
}),
expect.any(Object),
);
});
it('blocks login if an existing user has a different provider', async () => {
findUser.mockResolvedValue({ _id: 'u1', email: 'first@example.com', provider: 'google' });
const userinfo = {
uid: 'uid123',
mail: 'first@example.com',
givenName: 'Alice',
cn: 'Alice Doe',
};
const { user, info } = await callVerify(userinfo);
expect(user).toBe(false);
expect(info).toEqual({ message: ErrorTypes.AUTH_FAILED });
expect(createUser).not.toHaveBeenCalled();
});
it('updates an existing ldap user with current LDAP info', async () => {
const existing = {
_id: 'u2',
provider: 'ldap',
email: 'old@example.com',
ldapId: 'uid123',
username: 'olduser',
name: 'Old Name',
};
findUser.mockResolvedValue(existing);
const userinfo = {
uid: 'uid123',
mail: 'new@example.com',
givenName: 'NewFirst',
cn: 'NewFirst NewLast',
};
const { user } = await callVerify(userinfo);
expect(createUser).not.toHaveBeenCalled();
expect(updateUser).toHaveBeenCalledWith(
'u2',
expect.objectContaining({
provider: 'ldap',
ldapId: 'uid123',
email: 'new@example.com',
username: 'NewFirst',
name: 'NewFirst NewLast',
}),
);
expect(user.email).toBe('new@example.com');
});
it('falls back to username@ldap.local when no email attributes are present', async () => {
const userinfo = {
uid: 'uid999',
givenName: 'John',
cn: 'John Doe',
// no mail and no custom LDAP_EMAIL
};
const { user } = await callVerify(userinfo);
expect(user.email).toBe('John@ldap.local');
});
it('denies login if email domain is not allowed', async () => {
isEmailDomainAllowed.mockReturnValue(false);
const userinfo = {
uid: 'uid123',
mail: 'notallowed@blocked.com',
givenName: 'Alice',
cn: 'Alice Doe',
};
const { user, info } = await callVerify(userinfo);
expect(user).toBe(false);
expect(info).toEqual({ message: 'Email domain not allowed' });
});
});

View File

@@ -41,18 +41,13 @@ const openIdJwtLogin = (openIdConfig) => {
jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
secretOrKeyProvider: jwksRsa.passportJwtSecret(jwksRsaOptions),
},
/**
* @param {import('openid-client').IDToken} payload
* @param {import('passport-jwt').VerifyCallback} done
*/
async (payload, done) => {
try {
const { user, error, migration } = await findOpenIDUser({
findUser,
email: payload?.email,
openidId: payload?.sub,
idOnTheSource: payload?.oid,
email: payload?.email,
strategyName: 'openIdJwtLogin',
findUser,
});
if (error) {

View File

@@ -337,10 +337,6 @@ async function setupOpenId() {
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
usePKCE,
},
/**
* @param {import('openid-client').TokenEndpointResponseHelpers} tokenset
* @param {import('passport-jwt').VerifyCallback} done
*/
async (tokenset, done) => {
try {
const claims = tokenset.claims();
@@ -358,11 +354,10 @@ async function setupOpenId() {
}
const result = await findOpenIDUser({
findUser,
email: claims.email,
openidId: claims.sub,
idOnTheSource: claims.oid,
email: claims.email,
strategyName: 'openidStrategy',
findUser,
});
let user = result.user;
const error = result.error;
@@ -376,10 +371,6 @@ async function setupOpenId() {
const fullName = getFullName(userinfo);
if (requiredRole) {
const requiredRoles = requiredRole
.split(',')
.map((role) => role.trim())
.filter(Boolean);
let decodedToken = '';
if (requiredRoleTokenKind === 'access') {
decodedToken = jwtDecode(tokenset.access_token);
@@ -402,13 +393,9 @@ async function setupOpenId() {
);
}
if (!requiredRoles.some((role) => roles.includes(role))) {
const rolesList =
requiredRoles.length === 1
? `"${requiredRoles[0]}"`
: `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`;
if (!roles.includes(requiredRole)) {
return done(null, false, {
message: `You must have ${rolesList} role to log in.`,
message: `You must have the "${requiredRole}" role to log in.`,
});
}
}
@@ -441,10 +428,6 @@ async function setupOpenId() {
user.username = username;
user.name = fullName;
user.idOnTheSource = userinfo.oid;
if (userinfo.email && userinfo.email !== user.email) {
user.email = userinfo.email;
user.emailVerified = userinfo.email_verified || false;
}
}
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {

View File

@@ -274,7 +274,10 @@ describe('setupOpenId', () => {
name: '',
};
findUser.mockImplementation(async (query) => {
if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) {
if (
query.openidId === tokenset.claims().sub ||
(query.email === tokenset.claims().email && query.provider === 'openid')
) {
return existingUser;
}
return null;
@@ -335,25 +338,7 @@ describe('setupOpenId', () => {
// Assert verify that the strategy rejects login
expect(user).toBe(false);
expect(details.message).toBe('You must have "requiredRole" role to log in.');
});
it('should allow login when single required role is present (backward compatibility)', async () => {
// Arrange ensure single role configuration (as set in beforeEach)
// OPENID_REQUIRED_ROLE = 'requiredRole'
// Default jwtDecode mock in beforeEach already returns this role
jwtDecode.mockReturnValue({
roles: ['requiredRole', 'anotherRole'],
});
// Act
const { user } = await validate(tokenset);
// Assert verify that login succeeds with single role configuration
expect(user).toBeTruthy();
expect(user.email).toBe(tokenset.claims().email);
expect(user.username).toBe(tokenset.claims().preferred_username);
expect(createUser).toHaveBeenCalled();
expect(details.message).toBe('You must have the "requiredRole" role to log in.');
});
it('should attempt to download and save the avatar if picture is provided', async () => {
@@ -379,58 +364,6 @@ describe('setupOpenId', () => {
// Depending on your implementation, user.avatar may be undefined or an empty string.
});
it('should support comma-separated multiple roles', async () => {
// Arrange
process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin';
await setupOpenId(); // Re-initialize the strategy
verifyCallback = require('openid-client/passport').__getVerifyCallback();
jwtDecode.mockReturnValue({
roles: ['anotherRole', 'aThirdRole'],
});
// Act
const { user } = await validate(tokenset);
// Assert
expect(user).toBeTruthy();
expect(user.email).toBe(tokenset.claims().email);
});
it('should reject login when user has none of the required multiple roles', async () => {
// Arrange
process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin';
await setupOpenId(); // Re-initialize the strategy
verifyCallback = require('openid-client/passport').__getVerifyCallback();
jwtDecode.mockReturnValue({
roles: ['aThirdRole', 'aFourthRole'],
});
// Act
const { user, details } = await validate(tokenset);
// Assert
expect(user).toBe(false);
expect(details.message).toBe(
'You must have one of: "someRole", "anotherRole", "admin" role to log in.',
);
});
it('should handle spaces in comma-separated roles', async () => {
// Arrange
process.env.OPENID_REQUIRED_ROLE = ' someRole , anotherRole , admin ';
await setupOpenId(); // Re-initialize the strategy
verifyCallback = require('openid-client/passport').__getVerifyCallback();
jwtDecode.mockReturnValue({
roles: ['someRole'],
});
// Act
const { user } = await validate(tokenset);
// Assert
expect(user).toBeTruthy();
});
it('should default to usePKCE false when OPENID_USE_PKCE is not defined', async () => {
const OpenIDStrategy = require('openid-client/passport').Strategy;

View File

@@ -46,7 +46,7 @@ describe('fileSearch.js - test only new file_id and page additions', () => {
queryVectors.mockResolvedValue(mockResults);
const fileSearchTool = await createFileSearchTool({
userId: 'user1',
req: { user: { id: 'user1' } },
files: mockFiles,
entity_id: 'agent-123',
});

View File

@@ -873,13 +873,6 @@
* @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile
* @memberof typedefs
*/
/**
* @exports ISession
* @typedef {import('@librechat/data-schemas').ISession} ISession
* @memberof typedefs
*/
/**
* @exports IBalance
* @typedef {import('@librechat/data-schemas').IBalance} IBalance

View File

@@ -1,15 +1,13 @@
import React, { createContext, useContext, useState, useMemo } from 'react';
import { EModelEndpoint } from 'librechat-data-provider';
import type { MCP, Action, TPlugin } from 'librechat-data-provider';
import { Constants, EModelEndpoint } from 'librechat-data-provider';
import type { MCP, Action, TPlugin, AgentToolType } from 'librechat-data-provider';
import type { AgentPanelContextType, MCPServerInfo } from '~/common';
import {
useAvailableToolsQuery,
useGetActionsQuery,
useGetStartupConfig,
useMCPToolsQuery,
} from '~/data-provider';
import { useAvailableToolsQuery, useGetActionsQuery, useGetStartupConfig } from '~/data-provider';
import { useLocalize, useGetAgentsConfig, useMCPConnectionStatus } from '~/hooks';
import { Panel, isEphemeralAgent } from '~/common';
import { Panel } from '~/common';
type GroupedToolType = AgentToolType & { tools?: AgentToolType[] };
type GroupedToolsRecord = Record<string, GroupedToolType>;
const AgentPanelContext = createContext<AgentPanelContextType | undefined>(undefined);
@@ -30,67 +28,79 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
const [activePanel, setActivePanel] = useState<Panel>(Panel.builder);
const [agent_id, setCurrentAgentId] = useState<string | undefined>(undefined);
const { data: startupConfig } = useGetStartupConfig();
const { data: actions } = useGetActionsQuery(EModelEndpoint.agents, {
enabled: !isEphemeralAgent(agent_id),
enabled: !!agent_id,
});
const { data: regularTools } = useAvailableToolsQuery(EModelEndpoint.agents, {
enabled: !isEphemeralAgent(agent_id),
const { data: pluginTools } = useAvailableToolsQuery(EModelEndpoint.agents, {
enabled: !!agent_id,
});
const { data: mcpData } = useMCPToolsQuery({
enabled: !isEphemeralAgent(agent_id) && startupConfig?.mcpServers != null,
});
const { agentsConfig, endpointsConfig } = useGetAgentsConfig();
const { data: startupConfig } = useGetStartupConfig();
const mcpServerNames = useMemo(
() => Object.keys(startupConfig?.mcpServers ?? {}),
[startupConfig],
);
const { connectionStatus } = useMCPConnectionStatus({
enabled: !isEphemeralAgent(agent_id) && mcpServerNames.length > 0,
enabled: !!agent_id && mcpServerNames.length > 0,
});
const mcpServersMap = useMemo(() => {
const processedData = useMemo(() => {
if (!pluginTools) {
return {
tools: [],
groupedTools: {},
mcpServersMap: new Map<string, MCPServerInfo>(),
};
}
const tools: AgentToolType[] = [];
const groupedTools: GroupedToolsRecord = {};
const configuredServers = new Set(mcpServerNames);
const serversMap = new Map<string, MCPServerInfo>();
const mcpServersMap = new Map<string, MCPServerInfo>();
if (mcpData?.servers) {
for (const [serverName, serverData] of Object.entries(mcpData.servers)) {
const metadata = {
name: serverName,
pluginKey: serverName,
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
icon: serverData.icon || '',
authConfig: serverData.authConfig,
authenticated: serverData.authenticated,
} as TPlugin;
for (const pluginTool of pluginTools) {
const tool: AgentToolType = {
tool_id: pluginTool.pluginKey,
metadata: pluginTool as TPlugin,
};
const tools = serverData.tools.map((tool) => ({
tool_id: tool.pluginKey,
metadata: {
...tool,
icon: serverData.icon,
authConfig: serverData.authConfig,
authenticated: serverData.authenticated,
} as TPlugin,
}));
tools.push(tool);
serversMap.set(serverName, {
serverName,
tools,
isConfigured: configuredServers.has(serverName),
isConnected: connectionStatus?.[serverName]?.connectionState === 'connected',
metadata,
});
if (tool.tool_id.includes(Constants.mcp_delimiter)) {
const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter);
if (!mcpServersMap.has(serverName)) {
const metadata = {
name: serverName,
pluginKey: serverName,
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
icon: pluginTool.icon || '',
} as TPlugin;
mcpServersMap.set(serverName, {
serverName,
tools: [],
isConfigured: configuredServers.has(serverName),
isConnected: connectionStatus?.[serverName]?.connectionState === 'connected',
metadata,
});
}
mcpServersMap.get(serverName)!.tools.push(tool);
} else {
// Non-MCP tool
groupedTools[tool.tool_id] = {
tool_id: tool.tool_id,
metadata: tool.metadata,
};
}
}
// Add configured servers that don't have tools yet
for (const mcpServerName of mcpServerNames) {
if (serversMap.has(mcpServerName)) {
if (mcpServersMap.has(mcpServerName)) {
continue;
}
const metadata = {
@@ -100,7 +110,7 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
description: `${localize('com_ui_tool_collection_prefix')} ${mcpServerName}`,
} as TPlugin;
serversMap.set(mcpServerName, {
mcpServersMap.set(mcpServerName, {
tools: [],
metadata,
isConfigured: true,
@@ -109,8 +119,14 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
});
}
return serversMap;
}, [mcpData, localize, mcpServerNames, connectionStatus]);
return {
tools,
groupedTools,
mcpServersMap,
};
}, [pluginTools, localize, mcpServerNames, connectionStatus]);
const { agentsConfig, endpointsConfig } = useGetAgentsConfig();
const value: AgentPanelContextType = {
mcp,
@@ -121,14 +137,16 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
setMcps,
agent_id,
setAction,
pluginTools,
activePanel,
regularTools,
agentsConfig,
startupConfig,
mcpServersMap,
setActivePanel,
endpointsConfig,
setCurrentAgentId,
tools: processedData.tools,
groupedTools: processedData.groupedTools,
mcpServersMap: processedData.mcpServersMap,
};
return <AgentPanelContext.Provider value={value}>{children}</AgentPanelContext.Provider>;

View File

@@ -1,32 +0,0 @@
import React, { createContext, useContext, useMemo } from 'react';
import { useChatContext } from './ChatContext';
interface DragDropContextValue {
conversationId: string | null | undefined;
agentId: string | null | undefined;
}
const DragDropContext = createContext<DragDropContextValue | undefined>(undefined);
export function DragDropProvider({ children }: { children: React.ReactNode }) {
const { conversation } = useChatContext();
/** Context value only created when conversation fields change */
const contextValue = useMemo<DragDropContextValue>(
() => ({
conversationId: conversation?.conversationId,
agentId: conversation?.agent_id,
}),
[conversation?.conversationId, conversation?.agent_id],
);
return <DragDropContext.Provider value={contextValue}>{children}</DragDropContext.Provider>;
}
export function useDragDropContext() {
const context = useContext(DragDropContext);
if (!context) {
throw new Error('useDragDropContext must be used within DragDropProvider');
}
return context;
}

View File

@@ -1,15 +1,10 @@
import { createContext, useContext } from 'react';
type MessageContext = {
messageId: string;
nextType?: string;
partIndex?: number;
isExpanded: boolean;
conversationId?: string | null;
/** Submission state for cursor display - only true for latest message when submitting */
isSubmitting?: boolean;
/** Whether this is the latest message in the conversation */
isLatestMessage?: boolean;
};
export const MessageContext = createContext<MessageContext>({} as MessageContext);

View File

@@ -1,150 +0,0 @@
import React, { createContext, useContext, useMemo } from 'react';
import { useAddedChatContext } from './AddedChatContext';
import { useChatContext } from './ChatContext';
interface MessagesViewContextValue {
/** Core conversation data */
conversation: ReturnType<typeof useChatContext>['conversation'];
conversationId: string | null | undefined;
/** Submission and control states */
isSubmitting: ReturnType<typeof useChatContext>['isSubmitting'];
isSubmittingFamily: boolean;
abortScroll: ReturnType<typeof useChatContext>['abortScroll'];
setAbortScroll: ReturnType<typeof useChatContext>['setAbortScroll'];
/** Message operations */
ask: ReturnType<typeof useChatContext>['ask'];
regenerate: ReturnType<typeof useChatContext>['regenerate'];
handleContinue: ReturnType<typeof useChatContext>['handleContinue'];
/** Message state management */
index: ReturnType<typeof useChatContext>['index'];
latestMessage: ReturnType<typeof useChatContext>['latestMessage'];
setLatestMessage: ReturnType<typeof useChatContext>['setLatestMessage'];
getMessages: ReturnType<typeof useChatContext>['getMessages'];
setMessages: ReturnType<typeof useChatContext>['setMessages'];
}
const MessagesViewContext = createContext<MessagesViewContextValue | undefined>(undefined);
export function MessagesViewProvider({ children }: { children: React.ReactNode }) {
const chatContext = useChatContext();
const addedChatContext = useAddedChatContext();
const {
ask,
index,
regenerate,
isSubmitting: isSubmittingRoot,
conversation,
latestMessage,
setAbortScroll,
handleContinue,
setLatestMessage,
abortScroll,
getMessages,
setMessages,
} = chatContext;
const { isSubmitting: isSubmittingAdditional } = addedChatContext;
/** Memoize conversation-related values */
const conversationValues = useMemo(
() => ({
conversation,
conversationId: conversation?.conversationId,
}),
[conversation],
);
/** Memoize submission states */
const submissionStates = useMemo(
() => ({
isSubmitting: isSubmittingRoot,
isSubmittingFamily: isSubmittingRoot || isSubmittingAdditional,
abortScroll,
setAbortScroll,
}),
[isSubmittingRoot, isSubmittingAdditional, abortScroll, setAbortScroll],
);
/** Memoize message operations (these are typically stable references) */
const messageOperations = useMemo(
() => ({
ask,
regenerate,
getMessages,
setMessages,
handleContinue,
}),
[ask, regenerate, handleContinue, getMessages, setMessages],
);
/** Memoize message state values */
const messageState = useMemo(
() => ({
index,
latestMessage,
setLatestMessage,
}),
[index, latestMessage, setLatestMessage],
);
/** Combine all values into final context value */
const contextValue = useMemo<MessagesViewContextValue>(
() => ({
...conversationValues,
...submissionStates,
...messageOperations,
...messageState,
}),
[conversationValues, submissionStates, messageOperations, messageState],
);
return (
<MessagesViewContext.Provider value={contextValue}>{children}</MessagesViewContext.Provider>
);
}
export function useMessagesViewContext() {
const context = useContext(MessagesViewContext);
if (!context) {
throw new Error('useMessagesViewContext must be used within MessagesViewProvider');
}
return context;
}
/** Hook for components that only need conversation data */
export function useMessagesConversation() {
const { conversation, conversationId } = useMessagesViewContext();
return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]);
}
/** Hook for components that only need submission states */
export function useMessagesSubmission() {
const { isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll } =
useMessagesViewContext();
return useMemo(
() => ({ isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll }),
[isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll],
);
}
/** Hook for components that only need message operations */
export function useMessagesOperations() {
const { ask, regenerate, handleContinue, getMessages, setMessages } = useMessagesViewContext();
return useMemo(
() => ({ ask, regenerate, handleContinue, getMessages, setMessages }),
[ask, regenerate, handleContinue, getMessages, setMessages],
);
}
/** Hook for components that only need message state */
export function useMessagesState() {
const { index, latestMessage, setLatestMessage } = useMessagesViewContext();
return useMemo(
() => ({ index, latestMessage, setLatestMessage }),
[index, latestMessage, setLatestMessage],
);
}

View File

@@ -1,10 +1,9 @@
import React, { createContext, useContext, ReactNode, useMemo } from 'react';
import { PermissionTypes, Permissions } from 'librechat-data-provider';
import type { TPromptGroup } from 'librechat-data-provider';
import type { PromptOption } from '~/common';
import CategoryIcon from '~/components/Prompts/Groups/CategoryIcon';
import { usePromptGroupsNav, useHasAccess } from '~/hooks';
import { useGetAllPromptGroups } from '~/data-provider';
import { usePromptGroupsNav } from '~/hooks';
import { mapPromptGroups } from '~/utils';
type AllPromptGroupsData =
@@ -20,21 +19,14 @@ type PromptGroupsContextType =
data: AllPromptGroupsData;
isLoading: boolean;
};
hasAccess: boolean;
})
| null;
const PromptGroupsContext = createContext<PromptGroupsContextType>(null);
export const PromptGroupsProvider = ({ children }: { children: ReactNode }) => {
const hasAccess = useHasAccess({
permissionType: PermissionTypes.PROMPTS,
permission: Permissions.USE,
});
const promptGroupsNav = usePromptGroupsNav(hasAccess);
const promptGroupsNav = usePromptGroupsNav();
const { data: allGroupsData, isLoading: isLoadingAll } = useGetAllPromptGroups(undefined, {
enabled: hasAccess,
select: (data) => {
const mappedArray: PromptOption[] = data.map((group) => ({
id: group._id ?? '',
@@ -63,12 +55,11 @@ export const PromptGroupsProvider = ({ children }: { children: ReactNode }) => {
() => ({
...promptGroupsNav,
allPromptGroups: {
data: hasAccess ? allGroupsData : undefined,
isLoading: hasAccess ? isLoadingAll : false,
data: allGroupsData,
isLoading: isLoadingAll,
},
hasAccess,
}),
[promptGroupsNav, allGroupsData, isLoadingAll, hasAccess],
[promptGroupsNav, allGroupsData, isLoadingAll],
);
return (

View File

@@ -23,9 +23,7 @@ export * from './SetConvoContext';
export * from './SearchContext';
export * from './BadgeRowContext';
export * from './SidePanelContext';
export * from './DragDropContext';
export * from './MCPPanelContext';
export * from './ArtifactsContext';
export * from './PromptGroupsContext';
export * from './MessagesViewContext';
export { default as BadgeRowProvider } from './BadgeRowContext';

View File

@@ -1,5 +1,5 @@
import React from 'react';
import { TStartupConfig } from 'librechat-data-provider';
import { TModelSpec, TStartupConfig } from 'librechat-data-provider';
export interface Endpoint {
value: string;

View File

@@ -1,5 +1,5 @@
import { RefObject } from 'react';
import { Constants, FileSources, EModelEndpoint } from 'librechat-data-provider';
import { FileSources, EModelEndpoint } from 'librechat-data-provider';
import type { UseMutationResult } from '@tanstack/react-query';
import type * as InputNumberPrimitive from 'rc-input-number';
import type { SetterOrUpdater, RecoilState } from 'recoil';
@@ -8,10 +8,6 @@ import type * as t from 'librechat-data-provider';
import type { LucideIcon } from 'lucide-react';
import type { TranslationKeys } from '~/hooks';
export function isEphemeralAgent(agentId: string | null | undefined): boolean {
return agentId == null || agentId === '' || agentId === Constants.EPHEMERAL_AGENT_ID;
}
export interface ConfigFieldDetail {
title: string;
description: string;
@@ -236,8 +232,10 @@ export type AgentPanelContextType = {
mcps?: t.MCP[];
setMcp: React.Dispatch<React.SetStateAction<t.MCP | undefined>>;
setMcps: React.Dispatch<React.SetStateAction<t.MCP[] | undefined>>;
groupedTools: Record<string, t.AgentToolType & { tools?: t.AgentToolType[] }>;
activePanel?: string;
regularTools?: t.TPlugin[];
tools: t.AgentToolType[];
pluginTools?: t.TPlugin[];
setActivePanel: React.Dispatch<React.SetStateAction<Panel>>;
setCurrentAgentId: React.Dispatch<React.SetStateAction<string | undefined>>;
agent_id?: string;

View File

@@ -11,9 +11,9 @@ import {
AgentListResponse,
} from 'librechat-data-provider';
import type t from 'librechat-data-provider';
import { useLocalize, useDefaultConvo } from '~/hooks';
import { useChatContext } from '~/Providers';
import { renderAgentAvatar } from '~/utils';
import { useLocalize } from '~/hooks';
interface SupportContact {
name?: string;
@@ -34,11 +34,11 @@ interface AgentDetailProps {
*/
const AgentDetail: React.FC<AgentDetailProps> = ({ agent, isOpen, onClose }) => {
const localize = useLocalize();
const queryClient = useQueryClient();
// const navigate = useNavigate();
const { conversation, newConversation } = useChatContext();
const { showToast } = useToastContext();
const dialogRef = useRef<HTMLDivElement>(null);
const getDefaultConversation = useDefaultConvo();
const { conversation, newConversation } = useChatContext();
const queryClient = useQueryClient();
/**
* Navigate to chat with the selected agent
@@ -62,22 +62,13 @@ const AgentDetail: React.FC<AgentDetailProps> = ({ agent, isOpen, onClose }) =>
);
queryClient.invalidateQueries([QueryKeys.messages]);
/** Template with agent configuration */
const template = {
conversationId: Constants.NEW_CONVO as string,
endpoint: EModelEndpoint.agents,
agent_id: agent.id,
title: localize('com_agents_chat_with', { name: agent.name || localize('com_ui_agent') }),
};
const currentConvo = getDefaultConversation({
conversation: { ...(conversation ?? {}), ...template },
preset: template,
});
newConversation({
template: currentConvo,
preset: template,
template: {
conversationId: Constants.NEW_CONVO as string,
endpoint: EModelEndpoint.agents,
agent_id: agent.id,
title: `Chat with ${agent.name || 'Agent'}`,
},
});
}
};

View File

@@ -20,7 +20,6 @@ jest.mock('react-router-dom', () => ({
jest.mock('~/hooks', () => ({
useMediaQuery: jest.fn(() => false), // Mock as desktop by default
useLocalize: jest.fn(),
useDefaultConvo: jest.fn(),
}));
jest.mock('@librechat/client', () => ({
@@ -48,12 +47,7 @@ const mockWriteText = jest.fn();
const mockNavigate = jest.fn();
const mockShowToast = jest.fn();
const mockLocalize = jest.fn((key: string, values?: Record<string, any>) => {
if (key === 'com_agents_chat_with' && values?.name) {
return `Chat with ${values.name}`;
}
return key;
});
const mockLocalize = jest.fn((key: string) => key);
const mockAgent: t.Agent = {
id: 'test-agent-id',
@@ -112,12 +106,8 @@ describe('AgentDetail', () => {
(useNavigate as jest.Mock).mockReturnValue(mockNavigate);
const { useToastContext } = require('@librechat/client');
(useToastContext as jest.Mock).mockReturnValue({ showToast: mockShowToast });
const { useLocalize, useDefaultConvo } = require('~/hooks');
const { useLocalize } = require('~/hooks');
(useLocalize as jest.Mock).mockReturnValue(mockLocalize);
(useDefaultConvo as jest.Mock).mockReturnValue(() => ({
conversationId: Constants.NEW_CONVO,
endpoint: EModelEndpoint.agents,
}));
// Mock useChatContext
const { useChatContext } = require('~/Providers');
@@ -237,10 +227,6 @@ describe('AgentDetail', () => {
template: {
conversationId: Constants.NEW_CONVO,
endpoint: EModelEndpoint.agents,
},
preset: {
conversationId: Constants.NEW_CONVO,
endpoint: EModelEndpoint.agents,
agent_id: 'test-agent-id',
title: 'Chat with Test Agent',
},

View File

@@ -4,7 +4,7 @@ import {
SandpackProvider,
SandpackProviderProps,
} from '@codesandbox/sandpack-react/unstyled';
import type { SandpackPreviewRef, PreviewProps } from '@codesandbox/sandpack-react/unstyled';
import type { SandpackPreviewRef } from '@codesandbox/sandpack-react/unstyled';
import type { TStartupConfig } from 'librechat-data-provider';
import type { ArtifactFiles } from '~/common';
import { sharedFiles, sharedOptions } from '~/utils/artifacts';
@@ -13,7 +13,6 @@ export const ArtifactPreview = memo(function ({
files,
fileKey,
template,
isMermaid,
sharedProps,
previewRef,
currentCode,
@@ -21,7 +20,6 @@ export const ArtifactPreview = memo(function ({
}: {
files: ArtifactFiles;
fileKey: string;
isMermaid: boolean;
template: SandpackProviderProps['template'];
sharedProps: Partial<SandpackProviderProps>;
previewRef: React.MutableRefObject<SandpackPreviewRef>;
@@ -56,15 +54,6 @@ export const ArtifactPreview = memo(function ({
return _options;
}, [startupConfig, template]);
const style: PreviewProps['style'] | undefined = useMemo(() => {
if (isMermaid) {
return {
backgroundColor: '#282C34',
};
}
return;
}, [isMermaid]);
if (Object.keys(artifactFiles).length === 0) {
return null;
}
@@ -84,7 +73,6 @@ export const ArtifactPreview = memo(function ({
showRefreshButton={false}
tabIndex={0}
ref={previewRef}
style={style}
/>
</SandpackProvider>
);

View File

@@ -8,7 +8,6 @@ import { useAutoScroll } from '~/hooks/Artifacts/useAutoScroll';
import { ArtifactCodeEditor } from './ArtifactCodeEditor';
import { useGetStartupConfig } from '~/data-provider';
import { ArtifactPreview } from './ArtifactPreview';
import { MermaidMarkdown } from './MermaidMarkdown';
import { cn } from '~/utils';
export default function ArtifactTabs({
@@ -45,25 +44,23 @@ export default function ArtifactTabs({
id="artifacts-code"
className={cn('flex-grow overflow-auto')}
>
{isMermaid ? (
<MermaidMarkdown content={content} isSubmitting={isSubmitting} />
) : (
<ArtifactCodeEditor
files={files}
fileKey={fileKey}
template={template}
artifact={artifact}
editorRef={editorRef}
sharedProps={sharedProps}
/>
)}
<ArtifactCodeEditor
files={files}
fileKey={fileKey}
template={template}
artifact={artifact}
editorRef={editorRef}
sharedProps={sharedProps}
/>
</Tabs.Content>
<Tabs.Content value="preview" className="flex-grow overflow-auto">
<Tabs.Content
value="preview"
className={cn('flex-grow overflow-auto', isMermaid ? 'bg-[#282C34]' : 'bg-white')}
>
<ArtifactPreview
files={files}
fileKey={fileKey}
template={template}
isMermaid={isMermaid}
previewRef={previewRef}
sharedProps={sharedProps}
currentCode={currentCode}

View File

@@ -9,7 +9,6 @@ import { useEditorContext } from '~/Providers';
import ArtifactTabs from './ArtifactTabs';
import { CopyCodeButton } from './Code';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
import store from '~/store';
export default function Artifacts() {
@@ -59,10 +58,9 @@ export default function Artifacts() {
<div className="flex h-full w-full items-center justify-center">
{/* Main Container */}
<div
className={cn(
`flex h-full w-full flex-col overflow-hidden border border-border-medium bg-surface-primary text-xl text-text-primary shadow-xl transition-all duration-500 ease-in-out`,
isVisible ? 'scale-100 opacity-100 blur-0' : 'scale-105 opacity-0 blur-sm',
)}
className={`flex h-full w-full flex-col overflow-hidden border border-border-medium bg-surface-primary text-xl text-text-primary shadow-xl transition-all duration-500 ease-in-out ${
isVisible ? 'scale-100 opacity-100 blur-0' : 'scale-105 opacity-0 blur-sm'
}`}
>
{/* Header */}
<div className="flex items-center justify-between border-b border-border-medium bg-surface-primary-alt p-2">
@@ -76,17 +74,16 @@ export default function Artifacts() {
{/* Refresh button */}
{activeTab === 'preview' && (
<button
className={cn(
'mr-2 text-text-secondary transition-transform duration-500 ease-in-out',
isRefreshing ? 'rotate-180' : '',
)}
className={`mr-2 text-text-secondary transition-transform duration-500 ease-in-out ${
isRefreshing ? 'rotate-180' : ''
}`}
onClick={handleRefresh}
disabled={isRefreshing}
aria-label="Refresh"
>
<RefreshCw
size={16}
className={cn('transform', isRefreshing ? 'animate-spin' : '')}
className={`transform ${isRefreshing ? 'animate-spin' : ''}`}
/>
</button>
)}

View File

@@ -1,11 +0,0 @@
import { CodeMarkdown } from './Code';
export function MermaidMarkdown({
content,
isSubmitting,
}: {
content: string;
isSubmitting: boolean;
}) {
return <CodeMarkdown content={`\`\`\`mermaid\n${content}\`\`\``} isSubmitting={isSubmitting} />;
}

View File

@@ -1,4 +1,4 @@
import { memo, useCallback } from 'react';
import { memo, useCallback, useState, useEffect, useRef } from 'react';
import { useRecoilValue } from 'recoil';
import { useForm } from 'react-hook-form';
import { Spinner } from '@librechat/client';
@@ -13,6 +13,7 @@ import { useGetMessagesByConvoId } from '~/data-provider';
import MessagesView from './Messages/MessagesView';
import Presentation from './Presentation';
import ChatForm from './Input/ChatForm';
import CostBar from './CostBar';
import Landing from './Landing';
import Header from './Header';
import Footer from './Footer';
@@ -29,7 +30,13 @@ function LoadingSpinner() {
);
}
function ChatView({ index = 0 }: { index?: number }) {
function ChatView({
index = 0,
modelCosts,
}: {
index?: number;
modelCosts?: { modelCostTable: Record<string, { prompt: number; completion: number }> };
}) {
const { conversationId } = useParams();
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
@@ -37,6 +44,9 @@ function ChatView({ index = 0 }: { index?: number }) {
const fileMap = useFileMapContext();
const [showCostBar, setShowCostBar] = useState(false);
const lastScrollY = useRef(0);
const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', {
select: useCallback(
(data: TMessage[]) => {
@@ -54,6 +64,58 @@ function ChatView({ index = 0 }: { index?: number }) {
useSSE(rootSubmission, chatHelpers, false);
useSSE(addedSubmission, addedChatHelpers, true);
const checkIfAtBottom = useCallback(
(container: HTMLElement) => {
const currentScrollY = container.scrollTop;
const scrollHeight = container.scrollHeight;
const clientHeight = container.clientHeight;
const distanceFromBottom = scrollHeight - currentScrollY - clientHeight;
const isAtBottom = distanceFromBottom < 10;
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
setShowCostBar(isAtBottom && !isStreaming);
lastScrollY.current = currentScrollY;
},
[chatHelpers.isSubmitting, addedChatHelpers.isSubmitting],
);
useEffect(() => {
const handleScroll = (event: Event) => {
const target = event.target as HTMLElement;
checkIfAtBottom(target);
};
const findAndAttachScrollListener = () => {
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
if (messagesContainer) {
checkIfAtBottom(messagesContainer as HTMLElement);
messagesContainer.addEventListener('scroll', handleScroll, { passive: true });
return () => {
messagesContainer.removeEventListener('scroll', handleScroll);
};
}
setTimeout(findAndAttachScrollListener, 100);
};
const cleanup = findAndAttachScrollListener();
return cleanup;
}, [messagesTree, checkIfAtBottom]);
useEffect(() => {
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
if (isStreaming) {
setShowCostBar(false);
} else {
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
if (messagesContainer) {
checkIfAtBottom(messagesContainer as HTMLElement);
}
}
}, [chatHelpers.isSubmitting, addedChatHelpers.isSubmitting, checkIfAtBottom]);
const methods = useForm<ChatFormValues>({
defaultValues: { text: '' },
});
@@ -69,7 +131,22 @@ function ChatView({ index = 0 }: { index?: number }) {
} else if ((isLoading || isNavigating) && !isLandingPage) {
content = <LoadingSpinner />;
} else if (!isLandingPage) {
content = <MessagesView messagesTree={messagesTree} />;
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
content = (
<MessagesView
messagesTree={messagesTree}
costBar={
!isLandingPage &&
modelCosts && (
<CostBar
messagesTree={messagesTree}
modelCosts={modelCosts}
showCostBar={showCostBar && !isStreaming}
/>
)
}
/>
);
} else {
content = <Landing centerFormOnLanding={centerFormOnLanding} />;
}

View File

@@ -0,0 +1,112 @@
import { useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import { ArrowIcon } from '@librechat/client';
import { TModelCosts, TMessage } from 'librechat-data-provider';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
import store from '~/store';
interface CostBarProps {
messagesTree: TMessage[];
modelCosts: TModelCosts;
showCostBar: boolean;
}
export default function CostBar({ messagesTree, modelCosts, showCostBar }: CostBarProps) {
const localize = useLocalize();
const showCostTracking = useRecoilValue(store.showCostTracking);
const conversationCosts = useMemo(() => {
if (!modelCosts?.modelCostTable || !messagesTree) {
return null;
}
let totalPromptTokens = 0;
let totalCompletionTokens = 0;
let totalPromptUSD = 0;
let totalCompletionUSD = 0;
const flattenMessages = (messages: TMessage[]) => {
const flattened: TMessage[] = [];
messages.forEach((message: TMessage) => {
flattened.push(message);
if (message.children && message.children.length > 0) {
flattened.push(...flattenMessages(message.children));
}
});
return flattened;
};
const allMessages = flattenMessages(messagesTree);
allMessages.forEach((message) => {
if (!message.tokenCount) {
return null;
}
const modelToUse = message.isCreatedByUser ? message.targetModel : message.model;
const modelPricing = modelCosts.modelCostTable[modelToUse];
if (message.isCreatedByUser) {
totalPromptTokens += message.tokenCount;
totalPromptUSD += (message.tokenCount / 1000000) * modelPricing.prompt;
} else {
totalCompletionTokens += message.tokenCount;
totalCompletionUSD += (message.tokenCount / 1000000) * modelPricing.completion;
}
});
const totalTokens = totalPromptTokens + totalCompletionTokens;
const totalUSD = totalPromptUSD + totalCompletionUSD;
return {
totals: {
prompt: { tokenCount: totalPromptTokens, usd: totalPromptUSD },
completion: { tokenCount: totalCompletionTokens, usd: totalCompletionUSD },
total: { tokenCount: totalTokens, usd: totalUSD },
},
};
}, [modelCosts, messagesTree]);
if (!showCostTracking || !conversationCosts || !conversationCosts.totals) {
return null;
}
return (
<div
className={cn(
'mx-auto w-full max-w-md px-4 text-xs text-muted-foreground transition-all duration-300 ease-in-out',
showCostBar ? 'opacity-100' : 'opacity-0',
)}
>
<div className="grid grid-cols-3 gap-2 text-center">
<div>
<div>
<ArrowIcon direction="up" />
{localize('com_ui_token_abbreviation', {
0: conversationCosts.totals.prompt.tokenCount,
})}
</div>
<div>${Math.abs(conversationCosts.totals.prompt.usd).toFixed(6)}</div>
</div>
<div>
<div>
{localize('com_ui_token_abbreviation', {
0: conversationCosts.totals.total.tokenCount,
})}
</div>
<div>${Math.abs(conversationCosts.totals.total.usd).toFixed(6)}</div>
</div>
<div>
<div>
<ArrowIcon direction="down" />
{localize('com_ui_token_abbreviation', {
0: conversationCosts.totals.completion.tokenCount,
})}
</div>
<div>${Math.abs(conversationCosts.totals.completion.usd).toFixed(6)}</div>
</div>
</div>
</div>
);
}

View File

@@ -1,6 +1,6 @@
import React, { useRef, useState, useMemo } from 'react';
import * as Ariakit from '@ariakit/react';
import { useRecoilState } from 'recoil';
import { useSetRecoilState } from 'recoil';
import { FileSearch, ImageUpIcon, TerminalSquareIcon, FileType2Icon } from 'lucide-react';
import { EToolResources, EModelEndpoint, defaultAgentCapabilities } from 'librechat-data-provider';
import {
@@ -42,9 +42,7 @@ const AttachFileMenu = ({
const isUploadDisabled = disabled ?? false;
const inputRef = useRef<HTMLInputElement>(null);
const [isPopoverActive, setIsPopoverActive] = useState(false);
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(
ephemeralAgentByConvoId(conversationId),
);
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId));
const [toolResource, setToolResource] = useState<EToolResources | undefined>();
const { handleFileChange } = useFileHandling({
overrideEndpoint: EModelEndpoint.agents,
@@ -66,10 +64,7 @@ const AttachFileMenu = ({
* */
const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
const { fileSearchAllowedByAgent, codeAllowedByAgent } = useAgentToolPermissions(
agentId,
ephemeralAgent,
);
const { fileSearchAllowedByAgent, codeAllowedByAgent } = useAgentToolPermissions(agentId);
const handleUploadClick = (isImage?: boolean) => {
if (!inputRef.current) {
@@ -94,11 +89,11 @@ const AttachFileMenu = ({
},
];
if (capabilities.contextEnabled) {
if (capabilities.ocrEnabled) {
items.push({
label: localize('com_ui_upload_ocr_text'),
onClick: () => {
setToolResource(EToolResources.context);
setToolResource(EToolResources.ocr);
onAction();
},
icon: <FileType2Icon className="icon-md" />,

View File

@@ -1,16 +1,14 @@
import React, { useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import { OGDialog, OGDialogTemplate } from '@librechat/client';
import { EToolResources, defaultAgentCapabilities } from 'librechat-data-provider';
import { ImageUpIcon, FileSearch, TerminalSquareIcon, FileType2Icon } from 'lucide-react';
import { EToolResources, defaultAgentCapabilities } from 'librechat-data-provider';
import {
useAgentToolPermissions,
useAgentCapabilities,
useGetAgentsConfig,
useLocalize,
} from '~/hooks';
import { ephemeralAgentByConvoId } from '~/store';
import { useDragDropContext } from '~/Providers';
import { useChatContext } from '~/Providers';
interface DragDropModalProps {
onOptionSelect: (option: EToolResources | undefined) => void;
@@ -34,11 +32,9 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
* Use definition for agents endpoint for ephemeral agents
* */
const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
const { conversationId, agentId } = useDragDropContext();
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(conversationId ?? ''));
const { conversation } = useChatContext();
const { fileSearchAllowedByAgent, codeAllowedByAgent } = useAgentToolPermissions(
agentId,
ephemeralAgent,
conversation?.agent_id,
);
const options = useMemo(() => {
@@ -64,10 +60,10 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
icon: <TerminalSquareIcon className="icon-md" />,
});
}
if (capabilities.contextEnabled) {
if (capabilities.ocrEnabled) {
_options.push({
label: localize('com_ui_upload_ocr_text'),
value: EToolResources.context,
value: EToolResources.ocr,
icon: <FileType2Icon className="icon-md" />,
});
}

View File

@@ -1,7 +1,6 @@
import { useDragHelpers } from '~/hooks';
import DragDropOverlay from '~/components/Chat/Input/Files/DragDropOverlay';
import DragDropModal from '~/components/Chat/Input/Files/DragDropModal';
import { DragDropProvider } from '~/Providers';
import { cn } from '~/utils';
interface DragDropWrapperProps {
@@ -20,14 +19,12 @@ export default function DragDropWrapper({ children, className }: DragDropWrapper
{children}
{/** Always render overlay to avoid mount/unmount overhead */}
<DragDropOverlay isActive={isActive} />
<DragDropProvider>
<DragDropModal
files={draggedFiles}
isVisible={showModal}
setShowModal={setShowModal}
onOptionSelect={handleOptionSelect}
/>
</DragDropProvider>
<DragDropModal
files={draggedFiles}
isVisible={showModal}
setShowModal={setShowModal}
onOptionSelect={handleOptionSelect}
/>
</div>
);
}

View File

@@ -2,13 +2,14 @@ import { useState, useRef, useEffect, useMemo, memo, useCallback } from 'react';
import { AutoSizer, List } from 'react-virtualized';
import { Spinner, useCombobox } from '@librechat/client';
import { useSetRecoilState, useRecoilValue } from 'recoil';
import { PermissionTypes, Permissions } from 'librechat-data-provider';
import type { TPromptGroup } from 'librechat-data-provider';
import type { PromptOption } from '~/common';
import { removeCharIfLast, detectVariables } from '~/utils';
import VariableDialog from '~/components/Prompts/Groups/VariableDialog';
import { usePromptGroupsContext } from '~/Providers';
import { useLocalize, useHasAccess } from '~/hooks';
import MentionItem from './MentionItem';
import { useLocalize } from '~/hooks';
import store from '~/store';
const commandChar = '/';
@@ -53,7 +54,12 @@ function PromptsCommand({
submitPrompt: (textPrompt: string) => void;
}) {
const localize = useLocalize();
const { allPromptGroups, hasAccess } = usePromptGroupsContext();
const hasAccess = useHasAccess({
permissionType: PermissionTypes.PROMPTS,
permission: Permissions.USE,
});
const { allPromptGroups } = usePromptGroupsContext();
const { data, isLoading } = allPromptGroups;
const [activeIndex, setActiveIndex] = useState(0);

View File

@@ -26,7 +26,6 @@ type ContentPartsProps = {
isCreatedByUser: boolean;
isLast: boolean;
isSubmitting: boolean;
isLatestMessage?: boolean;
edit?: boolean;
enterEdit?: (cancel?: boolean) => void | null | undefined;
siblingIdx?: number;
@@ -46,7 +45,6 @@ const ContentParts = memo(
isCreatedByUser,
isLast,
isSubmitting,
isLatestMessage,
edit,
enterEdit,
siblingIdx,
@@ -57,8 +55,6 @@ const ContentParts = memo(
const [isExpanded, setIsExpanded] = useState(showThinking);
const attachmentMap = useMemo(() => mapAttachments(attachments ?? []), [attachments]);
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const hasReasoningParts = useMemo(() => {
const hasThinkPart = content?.some((part) => part?.type === ContentTypes.THINK) ?? false;
const allThinkPartsHaveContent =
@@ -138,9 +134,7 @@ const ContentParts = memo(
})
}
label={
effectiveIsSubmitting && isLast
? localize('com_ui_thinking')
: localize('com_ui_thoughts')
isSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts')
}
/>
</div>
@@ -161,14 +155,12 @@ const ContentParts = memo(
conversationId,
partIndex: idx,
nextType: content[idx + 1]?.type,
isSubmitting: effectiveIsSubmitting,
isLatestMessage,
}}
>
<Part
part={part}
attachments={attachments}
isSubmitting={effectiveIsSubmitting}
isSubmitting={isSubmitting}
key={`part-${messageId}-${idx}`}
isCreatedByUser={isCreatedByUser}
isLast={idx === content.length - 1}

View File

@@ -4,7 +4,7 @@ import { useRecoilState, useRecoilValue } from 'recoil';
import { TextareaAutosize, TooltipAnchor } from '@librechat/client';
import { useUpdateMessageMutation } from 'librechat-data-provider/react-query';
import type { TEditProps } from '~/common';
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
import { useChatContext, useAddedChatContext } from '~/Providers';
import { cn, removeFocusRings } from '~/utils';
import { useLocalize } from '~/hooks';
import Container from './Container';
@@ -22,8 +22,7 @@ const EditMessage = ({
const { addedIndex } = useAddedChatContext();
const saveButtonRef = useRef<HTMLButtonElement | null>(null);
const submitButtonRef = useRef<HTMLButtonElement | null>(null);
const { conversation } = useMessagesConversation();
const { getMessages, setMessages } = useMessagesOperations();
const { getMessages, setMessages, conversation } = useChatContext();
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
store.latestMessageFamily(addedIndex),
);

View File

@@ -5,7 +5,7 @@ import type { TMessage } from 'librechat-data-provider';
import type { TMessageContentProps, TDisplayProps } from '~/common';
import Error from '~/components/Messages/Content/Error';
import Thinking from '~/components/Artifacts/Thinking';
import { useMessageContext } from '~/Providers';
import { useChatContext } from '~/Providers';
import MarkdownLite from './MarkdownLite';
import EditMessage from './EditMessage';
import { useLocalize } from '~/hooks';
@@ -70,12 +70,16 @@ export const ErrorMessage = ({
};
const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => {
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
const { isSubmitting, latestMessage } = useChatContext();
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
const showCursorState = useMemo(
() => showCursor === true && isSubmitting,
[showCursor, isSubmitting],
);
const isLatestMessage = useMemo(
() => message.messageId === latestMessage?.messageId,
[message.messageId, latestMessage?.messageId],
);
let content: React.ReactElement;
if (!isCreatedByUser) {

View File

@@ -85,14 +85,13 @@ const Part = memo(
const isToolCall =
'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL);
if (isToolCall && toolCall.name === Tools.execute_code) {
if (isToolCall && toolCall.name === Tools.execute_code && toolCall.args) {
return (
<ExecuteCode
attachments={attachments}
isSubmitting={isSubmitting}
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
output={toolCall.output ?? ''}
initialProgress={toolCall.progress ?? 0.1}
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
attachments={attachments}
/>
);
} else if (

View File

@@ -70,7 +70,7 @@ const ImageAttachment = memo(({ attachment }: { attachment: TAttachment }) => {
}}
>
<Image
altText={attachment.filename || 'attachment image'}
altText={attachment.filename}
imagePath={filepath ?? ''}
height={height ?? 0}
width={width ?? 0}
@@ -89,9 +89,8 @@ export default function Attachment({ attachment }: { attachment?: TAttachment })
}
const { width, height, filepath = null } = attachment as TFile & TAttachmentMetadata;
const isImage = attachment.filename
? imageExtRegex.test(attachment.filename) && width != null && height != null && filepath != null
: false;
const isImage =
imageExtRegex.test(attachment.filename) && width != null && height != null && filepath != null;
if (isImage) {
return <ImageAttachment attachment={attachment} />;
@@ -111,12 +110,11 @@ export function AttachmentGroup({ attachments }: { attachments?: TAttachment[] }
attachments.forEach((attachment) => {
const { width, height, filepath = null } = attachment as TFile & TAttachmentMetadata;
const isImage = attachment.filename
? imageExtRegex.test(attachment.filename) &&
width != null &&
height != null &&
filepath != null
: false;
const isImage =
imageExtRegex.test(attachment.filename) &&
width != null &&
height != null &&
filepath != null;
if (isImage) {
imageAttachments.push(attachment);

View File

@@ -6,8 +6,8 @@ import { useRecoilState, useRecoilValue } from 'recoil';
import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query';
import type { Agents } from 'librechat-data-provider';
import type { TEditProps } from '~/common';
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
import Container from '~/components/Chat/Messages/Content/Container';
import { useChatContext, useAddedChatContext } from '~/Providers';
import { cn, removeFocusRings } from '~/utils';
import { useLocalize } from '~/hooks';
import store from '~/store';
@@ -25,8 +25,7 @@ const EditTextPart = ({
}) => {
const localize = useLocalize();
const { addedIndex } = useAddedChatContext();
const { conversation } = useMessagesConversation();
const { ask, getMessages, setMessages } = useMessagesOperations();
const { ask, getMessages, setMessages, conversation } = useChatContext();
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
store.latestMessageFamily(addedIndex),
);

View File

@@ -45,28 +45,26 @@ export function useParseArgs(args?: string): ParsedArgs | null {
}
export default function ExecuteCode({
isSubmitting,
initialProgress = 0.1,
args,
output = '',
attachments,
}: {
initialProgress: number;
isSubmitting: boolean;
args?: string;
output?: string;
attachments?: TAttachment[];
}) {
const localize = useLocalize();
const hasOutput = output.length > 0;
const outputRef = useRef<string>(output);
const codeContentRef = useRef<HTMLDivElement>(null);
const [isAnimating, setIsAnimating] = useState(false);
const showAnalysisCode = useRecoilValue(store.showCode);
const [showCode, setShowCode] = useState(showAnalysisCode);
const codeContentRef = useRef<HTMLDivElement>(null);
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
const [isAnimating, setIsAnimating] = useState(false);
const hasOutput = output.length > 0;
const outputRef = useRef<string>(output);
const prevShowCodeRef = useRef<boolean>(showCode);
const { lang, code } = useParseArgs(args) ?? ({} as ParsedArgs);
const progress = useProgress(initialProgress);
@@ -138,8 +136,6 @@ export default function ExecuteCode({
};
}, [showCode, isAnimating]);
const cancelled = !isSubmitting && progress < 1;
return (
<>
<div className="relative my-2.5 flex size-5 shrink-0 items-center gap-2.5">
@@ -147,12 +143,9 @@ export default function ExecuteCode({
progress={progress}
onClick={() => setShowCode((prev) => !prev)}
inProgressText={localize('com_ui_analyzing')}
finishedText={
cancelled ? localize('com_ui_cancelled') : localize('com_ui_analyzing_finished')
}
finishedText={localize('com_ui_analyzing_finished')}
hasInput={!!code?.length}
isExpanded={showCode}
error={cancelled}
/>
</div>
<div

View File

@@ -2,7 +2,7 @@ import { memo, useMemo, ReactElement } from 'react';
import { useRecoilValue } from 'recoil';
import MarkdownLite from '~/components/Chat/Messages/Content/MarkdownLite';
import Markdown from '~/components/Chat/Messages/Content/Markdown';
import { useMessageContext } from '~/Providers';
import { useChatContext, useMessageContext } from '~/Providers';
import { cn } from '~/utils';
import store from '~/store';
@@ -18,9 +18,14 @@ type ContentType =
| ReactElement;
const TextPart = memo(({ text, isCreatedByUser, showCursor }: TextPartProps) => {
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
const { messageId } = useMessageContext();
const { isSubmitting, latestMessage } = useChatContext();
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
const showCursorState = useMemo(() => showCursor && isSubmitting, [showCursor, isSubmitting]);
const isLatestMessage = useMemo(
() => messageId === latestMessage?.messageId,
[messageId, latestMessage?.messageId],
);
const content: ContentType = useMemo(() => {
if (!isCreatedByUser) {

View File

@@ -21,7 +21,7 @@ type THoverButtons = {
latestMessage: TMessage | null;
isLast: boolean;
index: number;
handleFeedback?: ({ feedback }: { feedback: TFeedback | undefined }) => void;
handleFeedback: ({ feedback }: { feedback: TFeedback | undefined }) => void;
};
type HoverButtonProps = {
@@ -238,7 +238,7 @@ const HoverButtons = ({
/>
{/* Feedback Buttons */}
{!isCreatedByUser && handleFeedback != null && (
{!isCreatedByUser && (
<Feedback handleFeedback={handleFeedback} feedback={message.feedback} isLast={isLast} />
)}

View File

@@ -1,6 +1,7 @@
import React from 'react';
import { useRecoilValue } from 'recoil';
import { useMessageProcess } from '~/hooks';
import type { TConversationCosts } from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
import MessageRender from './ui/MessageRender';
// eslint-disable-next-line import/no-cycle
@@ -28,7 +29,7 @@ const MessageContainer = React.memo(
},
);
export default function Message(props: TMessageProps) {
export default function Message(props: TMessageProps & { costs?: TConversationCosts }) {
const {
showSibling,
conversation,
@@ -37,7 +38,7 @@ export default function Message(props: TMessageProps) {
latestMultiMessage,
isSubmittingFamily,
} = useMessageProcess({ message: props.message });
const { message, currentEditId, setCurrentEditId } = props;
const { message, currentEditId, setCurrentEditId, costs } = props;
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
if (!message || typeof message !== 'object') {
@@ -62,6 +63,7 @@ export default function Message(props: TMessageProps) {
message={message}
isSubmittingFamily={isSubmittingFamily}
isCard
costs={costs}
/>
<MessageRender
{...props}
@@ -69,12 +71,13 @@ export default function Message(props: TMessageProps) {
isCard
message={siblingMessage ?? latestMultiMessage ?? undefined}
isSubmittingFamily={isSubmittingFamily}
costs={costs}
/>
</div>
</div>
) : (
<div className="m-auto justify-center p-4 py-2 md:gap-6">
<MessageRender {...props} />
<MessageRender {...props} costs={costs} />
</div>
)}
</MessageContainer>
@@ -85,6 +88,7 @@ export default function Message(props: TMessageProps) {
messagesTree={children ?? []}
currentEditId={currentEditId}
setCurrentEditId={setCurrentEditId}
costs={costs}
/>
</>
);

View File

@@ -1,6 +1,6 @@
import React, { useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import type { TMessageContentParts } from 'librechat-data-provider';
import type { TMessageContentParts, TConversationCosts } from 'librechat-data-provider';
import type { TMessageProps, TMessageIcon } from '~/common';
import { useMessageHelpers, useLocalize, useAttachments } from '~/hooks';
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
@@ -12,10 +12,17 @@ import SubRow from './SubRow';
import { cn } from '~/utils';
import store from '~/store';
export default function Message(props: TMessageProps) {
export default function Message(props: TMessageProps & { costs?: TConversationCosts }) {
const localize = useLocalize();
const { message, siblingIdx, siblingCount, setSiblingIdx, currentEditId, setCurrentEditId } =
props;
const {
message,
siblingIdx,
siblingCount,
setSiblingIdx,
currentEditId,
setCurrentEditId,
costs,
} = props;
const { attachments, searchResults } = useAttachments({
messageId: message?.messageId,
attachments: message?.attachments,
@@ -125,7 +132,6 @@ export default function Message(props: TMessageProps) {
setSiblingIdx={setSiblingIdx}
isCreatedByUser={message.isCreatedByUser}
conversationId={conversation?.conversationId}
isLatestMessage={messageId === latestMessage?.messageId}
content={message.content as Array<TMessageContentParts | undefined>}
/>
</div>
@@ -165,6 +171,7 @@ export default function Message(props: TMessageProps) {
messagesTree={children ?? []}
currentEditId={currentEditId}
setCurrentEditId={setCurrentEditId}
costs={costs}
/>
</>
);

View File

@@ -1,18 +1,21 @@
import { useState } from 'react';
import { useRecoilValue } from 'recoil';
import { CSSTransition } from 'react-transition-group';
import type { TMessage } from 'librechat-data-provider';
import type { TMessage, TConversationCosts } from 'librechat-data-provider';
import { useScreenshot, useMessageScrolling, useLocalize } from '~/hooks';
import ScrollToBottom from '~/components/Messages/ScrollToBottom';
import { MessagesViewProvider } from '~/Providers';
import MultiMessage from './MultiMessage';
import { cn } from '~/utils';
import store from '~/store';
function MessagesViewContent({
export default function MessagesView({
messagesTree: _messagesTree,
costBar,
costs,
}: {
messagesTree?: TMessage[] | null;
costBar?: React.ReactNode;
costs?: TConversationCosts;
}) {
const localize = useLocalize();
const fontSize = useRecoilValue(store.fontSize);
@@ -45,7 +48,7 @@ function MessagesViewContent({
width: '100%',
}}
>
<div className="flex flex-col pb-9 dark:bg-transparent">
<div className="flex flex-col dark:bg-transparent">
{(_messagesTree && _messagesTree.length == 0) || _messagesTree === null ? (
<div
className={cn(
@@ -64,18 +67,25 @@ function MessagesViewContent({
messageId={conversationId ?? null}
setCurrentEditId={setCurrentEditId}
currentEditId={currentEditId ?? null}
costs={costs}
/>
</div>
</>
)}
<div
id="messages-end"
className="group h-0 w-full flex-shrink-0"
className="group h-1 w-full flex-shrink-0 pb-7"
ref={messagesEndRef}
/>
</div>
</div>
{costBar && (
<div className="pointer-events-none absolute bottom-2 left-1/2 z-10 -translate-x-1/2">
{costBar}
</div>
)}
<CSSTransition
in={showScrollButton && scrollButtonPreference}
timeout={{
@@ -93,11 +103,3 @@ function MessagesViewContent({
</>
);
}
export default function MessagesView({ messagesTree }: { messagesTree?: TMessage[] | null }) {
return (
<MessagesViewProvider>
<MessagesViewContent messagesTree={messagesTree} />
</MessagesViewProvider>
);
}

View File

@@ -1,7 +1,7 @@
import { useRecoilState } from 'recoil';
import { useEffect, useCallback } from 'react';
import { isAssistantsEndpoint } from 'librechat-data-provider';
import type { TMessage } from 'librechat-data-provider';
import type { TMessage, TConversationCosts } from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
import MessageContent from '~/components/Messages/MessageContent';
import MessageParts from './MessageParts';
@@ -14,7 +14,8 @@ export default function MultiMessage({
messagesTree,
currentEditId,
setCurrentEditId,
}: TMessageProps) {
costs,
}: TMessageProps & { costs?: TConversationCosts }) {
const [siblingIdx, setSiblingIdx] = useRecoilState(store.messagesSiblingIdxFamily(messageId));
const setSiblingIdxRev = useCallback(
@@ -27,7 +28,7 @@ export default function MultiMessage({
useEffect(() => {
// reset siblingIdx when the tree changes, mostly when a new message is submitting.
setSiblingIdx(0);
}, [messagesTree?.length, setSiblingIdx]);
}, [messagesTree?.length]);
useEffect(() => {
if (messagesTree?.length && siblingIdx >= messagesTree.length) {
@@ -55,6 +56,7 @@ export default function MultiMessage({
siblingIdx={messagesTree.length - siblingIdx - 1}
siblingCount={messagesTree.length}
setSiblingIdx={setSiblingIdxRev}
costs={costs}
/>
);
} else if (message.content) {
@@ -67,6 +69,7 @@ export default function MultiMessage({
siblingIdx={messagesTree.length - siblingIdx - 1}
siblingCount={messagesTree.length}
setSiblingIdx={setSiblingIdxRev}
costs={costs}
/>
);
}
@@ -80,6 +83,7 @@ export default function MultiMessage({
siblingIdx={messagesTree.length - siblingIdx - 1}
siblingCount={messagesTree.length}
setSiblingIdx={setSiblingIdxRev}
costs={costs}
/>
);
}

View File

@@ -1,16 +1,17 @@
import React, { useCallback, useMemo, memo } from 'react';
import { useRecoilValue } from 'recoil';
import { type TMessage } from 'librechat-data-provider';
import { ArrowIcon } from '@librechat/client';
import { type TMessage, TConversationCosts } from 'librechat-data-provider';
import type { TMessageProps, TMessageIcon } from '~/common';
import MessageContent from '~/components/Chat/Messages/Content/MessageContent';
import PlaceholderRow from '~/components/Chat/Messages/ui/PlaceholderRow';
import SiblingSwitch from '~/components/Chat/Messages/SiblingSwitch';
import HoverButtons from '~/components/Chat/Messages/HoverButtons';
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
import { useMessageActions, useLocalize } from '~/hooks';
import { Plugin } from '~/components/Messages/Content';
import SubRow from '~/components/Chat/Messages/SubRow';
import { MessageContext } from '~/Providers';
import { useMessageActions } from '~/hooks';
import { cn, logger } from '~/utils';
import store from '~/store';
@@ -19,6 +20,7 @@ type MessageRenderProps = {
isCard?: boolean;
isMultiMessage?: boolean;
isSubmittingFamily?: boolean;
costs?: TConversationCosts;
} & Pick<
TMessageProps,
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
@@ -35,7 +37,9 @@ const MessageRender = memo(
isMultiMessage = false,
setCurrentEditId,
isSubmittingFamily = false,
costs,
}: MessageRenderProps) => {
const localize = useLocalize();
const {
ask,
edit,
@@ -60,6 +64,18 @@ const MessageRender = memo(
});
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
const fontSize = useRecoilValue(store.fontSize);
const showCostTracking = useRecoilValue(store.showCostTracking);
const perMessageCost = useMemo(() => {
if (!showCostTracking || !costs || !costs.perMessage || !msg?.messageId) {
return null;
}
const entry = costs.perMessage.find((p) => p.messageId === msg.messageId);
if (!entry) {
return null;
}
return entry;
}, [showCostTracking, costs, msg?.messageId]);
const handleRegenerateMessage = useCallback(() => regenerateMessage(), [regenerateMessage]);
const hasNoChildren = !(msg?.children?.length ?? 0);
@@ -71,9 +87,6 @@ const MessageRender = memo(
const showCardRender = isLast && !isSubmittingFamily && isCard;
const isLatestCard = isCard && !isSubmittingFamily && isLatestMessage;
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const iconData: TMessageIcon = useMemo(
() => ({
endpoint: msg?.endpoint ?? conversation?.endpoint,
@@ -160,7 +173,26 @@ const MessageRender = memo(
msg.isCreatedByUser ? 'user-turn' : 'agent-turn',
)}
>
<h2 className={cn('select-none font-semibold', fontSize)}>{messageLabel}</h2>
<h2 className={cn('select-none font-semibold', fontSize)}>
{messageLabel}
{perMessageCost && (
<span className="ml-2 inline-flex items-center gap-2 px-2 py-0.5 text-xs text-muted-foreground">
{perMessageCost.tokenCount > 0 && (
<span>
{perMessageCost.tokenType === 'prompt' ? (
<ArrowIcon direction="up" className="inline" />
) : (
<ArrowIcon direction="down" className="inline" />
)}
{localize('com_ui_token_abbreviation', {
0: perMessageCost.tokenCount,
})}
</span>
)}
<span className="whitespace-pre">${Math.abs(perMessageCost.usd).toFixed(6)}</span>
</span>
)}
</h2>
<div className="flex flex-col gap-1">
<div className="flex max-w-full flex-grow flex-col gap-0">
@@ -169,8 +201,6 @@ const MessageRender = memo(
messageId: msg.messageId,
conversationId: conversation?.conversationId,
isExpanded: false,
isSubmitting: effectiveIsSubmitting,
isLatestMessage,
}}
>
{msg.plugin && <Plugin plugin={msg.plugin} />}
@@ -182,7 +212,7 @@ const MessageRender = memo(
message={msg}
enterEdit={enterEdit}
error={!!(msg.error ?? false)}
isSubmitting={effectiveIsSubmitting}
isSubmitting={isSubmitting}
unfinished={msg.unfinished ?? false}
isCreatedByUser={msg.isCreatedByUser ?? true}
siblingIdx={siblingIdx ?? 0}
@@ -191,7 +221,7 @@ const MessageRender = memo(
</MessageContext.Provider>
</div>
{hasNoChildren && (isSubmittingFamily === true || effectiveIsSubmitting) ? (
{hasNoChildren && (isSubmittingFamily === true || isSubmitting) ? (
<PlaceholderRow isCard={isCard} />
) : (
<SubRow classes="text-xs">

View File

@@ -1,5 +1,6 @@
import { useMemo, memo, type FC, useCallback } from 'react';
import throttle from 'lodash/throttle';
import { parseISO, isToday } from 'date-fns';
import { Spinner, useMediaQuery } from '@librechat/client';
import { List, AutoSizer, CellMeasurer, CellMeasurerCache } from 'react-virtualized';
import { TConversation } from 'librechat-data-provider';
@@ -49,17 +50,27 @@ const MemoizedConvo = memo(
conversation,
retainView,
toggleNav,
isLatestConvo,
}: {
conversation: TConversation;
retainView: () => void;
toggleNav: () => void;
isLatestConvo: boolean;
}) => {
return <Convo conversation={conversation} retainView={retainView} toggleNav={toggleNav} />;
return (
<Convo
conversation={conversation}
retainView={retainView}
toggleNav={toggleNav}
isLatestConvo={isLatestConvo}
/>
);
},
(prevProps, nextProps) => {
return (
prevProps.conversation.conversationId === nextProps.conversation.conversationId &&
prevProps.conversation.title === nextProps.conversation.title &&
prevProps.isLatestConvo === nextProps.isLatestConvo &&
prevProps.conversation.endpoint === nextProps.conversation.endpoint
);
},
@@ -87,6 +98,13 @@ const Conversations: FC<ConversationsProps> = ({
[filteredConversations],
);
const firstTodayConvoId = useMemo(
() =>
filteredConversations.find((convo) => convo.updatedAt && isToday(parseISO(convo.updatedAt)))
?.conversationId ?? undefined,
[filteredConversations],
);
const flattenedItems = useMemo(() => {
const items: FlattenedItem[] = [];
groupedConversations.forEach(([groupName, convos]) => {
@@ -136,25 +154,26 @@ const Conversations: FC<ConversationsProps> = ({
</CellMeasurer>
);
}
let rendering: JSX.Element;
if (item.type === 'header') {
rendering = <DateLabel groupName={item.groupName} />;
} else if (item.type === 'convo') {
rendering = (
<MemoizedConvo conversation={item.convo} retainView={moveToTop} toggleNav={toggleNav} />
);
}
return (
<CellMeasurer cache={cache} columnIndex={0} key={key} parent={parent} rowIndex={index}>
{({ registerChild }) => (
<div ref={registerChild} style={style}>
{rendering}
{item.type === 'header' ? (
<DateLabel groupName={item.groupName} />
) : item.type === 'convo' ? (
<MemoizedConvo
conversation={item.convo}
retainView={moveToTop}
toggleNav={toggleNav}
isLatestConvo={item.convo.conversationId === firstTodayConvoId}
/>
) : null}
</div>
)}
</CellMeasurer>
);
},
[cache, flattenedItems, moveToTop, toggleNav],
[cache, flattenedItems, firstTodayConvoId, moveToTop, toggleNav],
);
const getRowHeight = useCallback(

View File

@@ -11,17 +11,23 @@ import { useGetEndpointsQuery } from '~/data-provider';
import { NotificationSeverity } from '~/common';
import { ConvoOptions } from './ConvoOptions';
import RenameForm from './RenameForm';
import { cn, logger } from '~/utils';
import ConvoLink from './ConvoLink';
import { cn } from '~/utils';
import store from '~/store';
interface ConversationProps {
conversation: TConversation;
retainView: () => void;
toggleNav: () => void;
isLatestConvo: boolean;
}
export default function Conversation({ conversation, retainView, toggleNav }: ConversationProps) {
export default function Conversation({
conversation,
retainView,
toggleNav,
isLatestConvo,
}: ConversationProps) {
const params = useParams();
const localize = useLocalize();
const { showToast } = useToastContext();
@@ -78,7 +84,6 @@ export default function Conversation({ conversation, retainView, toggleNav }: Co
});
setRenaming(false);
} catch (error) {
logger.error('Error renaming conversation', error);
setTitleInput(title as string);
showToast({
message: localize('com_ui_rename_failed'),

View File

@@ -66,7 +66,7 @@ function AuthField({ name, config, hasValue, control, errors }: AuthFieldProps)
placeholder={
hasValue
? localize('com_ui_mcp_update_var', { 0: config.title })
: `${localize('com_ui_mcp_enter_var', { 0: config.title })} ${localize('com_ui_optional')}`
: localize('com_ui_mcp_enter_var', { 0: config.title })
}
className="w-full rounded border border-border-medium bg-transparent px-2 py-1 text-text-primary placeholder:text-text-secondary focus:outline-none sm:text-sm"
/>

View File

@@ -1,13 +1,14 @@
import { useRecoilValue } from 'recoil';
import { ArrowIcon } from '@librechat/client';
import { useCallback, useMemo, memo } from 'react';
import type { TMessage, TMessageContentParts } from 'librechat-data-provider';
import type { TMessage, TMessageContentParts, TConversationCosts } from 'librechat-data-provider';
import type { TMessageProps, TMessageIcon } from '~/common';
import ContentParts from '~/components/Chat/Messages/Content/ContentParts';
import PlaceholderRow from '~/components/Chat/Messages/ui/PlaceholderRow';
import { useAttachments, useMessageActions, useLocalize } from '~/hooks';
import SiblingSwitch from '~/components/Chat/Messages/SiblingSwitch';
import HoverButtons from '~/components/Chat/Messages/HoverButtons';
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
import { useAttachments, useMessageActions } from '~/hooks';
import SubRow from '~/components/Chat/Messages/SubRow';
import { cn, logger } from '~/utils';
import store from '~/store';
@@ -17,6 +18,7 @@ type ContentRenderProps = {
isCard?: boolean;
isMultiMessage?: boolean;
isSubmittingFamily?: boolean;
costs?: TConversationCosts;
} & Pick<
TMessageProps,
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
@@ -33,7 +35,9 @@ const ContentRender = memo(
isMultiMessage = false,
setCurrentEditId,
isSubmittingFamily = false,
costs,
}: ContentRenderProps) => {
const localize = useLocalize();
const { attachments, searchResults } = useAttachments({
messageId: msg?.messageId,
attachments: msg?.attachments,
@@ -62,6 +66,14 @@ const ContentRender = memo(
});
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
const fontSize = useRecoilValue(store.fontSize);
const showCostTracking = useRecoilValue(store.showCostTracking);
const perMessageCost = useMemo(() => {
if (!showCostTracking || !costs || !costs.perMessage || !msg?.messageId) {
return null;
}
return costs.perMessage.find((p) => p.messageId === msg.messageId) ?? null;
}, [showCostTracking, costs, msg?.messageId]);
const handleRegenerateMessage = useCallback(() => regenerateMessage(), [regenerateMessage]);
const isLast = useMemo(
@@ -159,7 +171,26 @@ const ContentRender = memo(
msg.isCreatedByUser ? 'user-turn' : 'agent-turn',
)}
>
<h2 className={cn('select-none font-semibold', fontSize)}>{messageLabel}</h2>
<h2 className={cn('select-none font-semibold', fontSize)}>
{messageLabel}
{perMessageCost && (
<span className="ml-2 inline-flex items-center gap-2 px-2 py-0.5 text-xs text-muted-foreground">
{perMessageCost.tokenCount > 0 && (
<span className="mr-2">
{perMessageCost.tokenType === 'prompt' ? (
<ArrowIcon direction="up" className="inline" />
) : (
<ArrowIcon direction="down" className="inline" />
)}
{localize('com_ui_token_abbreviation', {
0: perMessageCost.tokenCount,
})}
</span>
)}
<span className="whitespace-pre">${Math.abs(perMessageCost.usd).toFixed(6)}</span>
</span>
)}
</h2>
<div className="flex flex-col gap-1">
<div className="flex max-w-full flex-grow flex-col gap-0">
@@ -173,7 +204,6 @@ const ContentRender = memo(
isSubmitting={isSubmitting}
searchResults={searchResults}
setSiblingIdx={setSiblingIdx}
isLatestMessage={isLatestMessage}
isCreatedByUser={msg.isCreatedByUser}
conversationId={conversation?.conversationId}
content={msg.content as Array<TMessageContentParts | undefined>}

View File

@@ -1,5 +1,6 @@
import React from 'react';
import { useMessageProcess } from '~/hooks';
import type { TConversationCosts } from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
// eslint-disable-next-line import/no-cycle
import MultiMessage from '~/components/Chat/Messages/MultiMessage';
@@ -25,7 +26,7 @@ const MessageContainer = React.memo(
},
);
export default function MessageContent(props: TMessageProps) {
export default function MessageContent(props: TMessageProps & { costs?: TConversationCosts }) {
const {
showSibling,
conversation,
@@ -34,7 +35,7 @@ export default function MessageContent(props: TMessageProps) {
latestMultiMessage,
isSubmittingFamily,
} = useMessageProcess({ message: props.message });
const { message, currentEditId, setCurrentEditId } = props;
const { message, currentEditId, setCurrentEditId, costs } = props;
if (!message || typeof message !== 'object') {
return null;
@@ -53,6 +54,7 @@ export default function MessageContent(props: TMessageProps) {
message={message}
isSubmittingFamily={isSubmittingFamily}
isCard
costs={costs}
/>
<ContentRender
{...props}
@@ -60,12 +62,13 @@ export default function MessageContent(props: TMessageProps) {
isCard
message={siblingMessage ?? latestMultiMessage ?? undefined}
isSubmittingFamily={isSubmittingFamily}
costs={costs}
/>
</div>
</div>
) : (
<div className="m-auto justify-center p-4 py-2 md:gap-6 ">
<ContentRender {...props} />
<div className="m-auto justify-center p-4 py-2 md:gap-6">
<ContentRender {...props} costs={costs} />
</div>
)}
</MessageContainer>
@@ -76,6 +79,7 @@ export default function MessageContent(props: TMessageProps) {
messagesTree={children ?? []}
currentEditId={currentEditId}
setCurrentEditId={setCurrentEditId}
costs={costs}
/>
</>
);

View File

@@ -24,45 +24,35 @@ const SearchBar = forwardRef((props: SearchBarProps, ref: React.Ref<HTMLDivEleme
const inputRef = useRef<HTMLInputElement>(null);
const [showClearIcon, setShowClearIcon] = useState(false);
const { newConversation: newConvo } = useNewConvo();
const { newConversation } = useNewConvo();
const [search, setSearchState] = useRecoilState(store.search);
const clearSearch = useCallback(
(pathname?: string) => {
if (pathname?.includes('/search') || pathname === '/c/new') {
queryClient.removeQueries([QueryKeys.messages]);
newConvo({ disableFocus: true });
navigate('/c/new');
}
},
[newConvo, navigate, queryClient],
);
const clearSearch = useCallback(() => {
if (location.pathname.includes('/search')) {
newConversation({ disableFocus: true });
navigate('/c/new', { replace: true });
}
}, [newConversation, location.pathname, navigate]);
const clearText = useCallback(
(pathname?: string) => {
setShowClearIcon(false);
setText('');
setSearchState((prev) => ({
...prev,
query: '',
debouncedQuery: '',
isTyping: false,
}));
clearSearch(pathname);
inputRef.current?.focus();
},
[setSearchState, clearSearch],
);
const clearText = useCallback(() => {
setShowClearIcon(false);
setText('');
setSearchState((prev) => ({
...prev,
query: '',
debouncedQuery: '',
isTyping: false,
}));
clearSearch();
inputRef.current?.focus();
}, [setSearchState, clearSearch]);
const handleKeyUp = useCallback(
(e: React.KeyboardEvent<HTMLInputElement>) => {
const { value } = e.target as HTMLInputElement;
if (e.key === 'Backspace' && value === '') {
clearText(location.pathname);
}
},
[clearText, location.pathname],
);
const handleKeyUp = (e: React.KeyboardEvent<HTMLInputElement>) => {
const { value } = e.target as HTMLInputElement;
if (e.key === 'Backspace' && value === '') {
clearText();
}
};
const sendRequest = useCallback(
(value: string) => {
@@ -95,6 +85,8 @@ const SearchBar = forwardRef((props: SearchBarProps, ref: React.Ref<HTMLDivEleme
debouncedSetDebouncedQuery(value);
if (value.length > 0 && location.pathname !== '/search') {
navigate('/search', { replace: true });
} else if (value.length === 0 && location.pathname === '/search') {
navigate('/c/new', { replace: true });
}
};
@@ -140,7 +132,7 @@ const SearchBar = forwardRef((props: SearchBarProps, ref: React.Ref<HTMLDivEleme
showClearIcon ? 'opacity-100' : 'opacity-0',
isSmallScreen === true ? 'right-[16px]' : '',
)}
onClick={() => clearText(location.pathname)}
onClick={clearText}
tabIndex={showClearIcon ? 0 : -1}
disabled={!showClearIcon}
>

View File

@@ -76,6 +76,13 @@ const toggleSwitchConfigs = [
hoverCardText: undefined,
key: 'modularChat',
},
{
stateAtom: store.showCostTracking,
localizationKey: 'com_nav_show_cost_tracking',
switchId: 'showCostTracking',
hoverCardText: 'com_nav_info_show_cost_tracking',
key: 'showCostTracking',
},
];
function Chat() {

View File

@@ -6,7 +6,6 @@ import { LocalStorageKeys } from 'librechat-data-provider';
import { useFormContext, Controller } from 'react-hook-form';
import type { MenuItemProps } from '@librechat/client';
import type { ReactNode } from 'react';
import { usePromptGroupsContext } from '~/Providers';
import { useCategories } from '~/hooks';
import { cn } from '~/utils';
@@ -23,9 +22,8 @@ const CategorySelector: React.FC<CategorySelectorProps> = ({
}) => {
const { t } = useTranslation();
const formContext = useFormContext();
const { categories, emptyCategory } = useCategories();
const [isOpen, setIsOpen] = useState(false);
const { hasAccess } = usePromptGroupsContext();
const { categories, emptyCategory } = useCategories({ hasAccess });
const control = formContext?.control;
const watch = formContext?.watch;

View File

@@ -7,7 +7,6 @@ import CategorySelector from '~/components/Prompts/Groups/CategorySelector';
import VariablesDropdown from '~/components/Prompts/VariablesDropdown';
import PromptVariables from '~/components/Prompts/PromptVariables';
import Description from '~/components/Prompts/Description';
import { usePromptGroupsContext } from '~/Providers';
import { useLocalize, useHasAccess } from '~/hooks';
import Command from '~/components/Prompts/Command';
import { useCreatePrompt } from '~/data-provider';
@@ -38,12 +37,10 @@ const CreatePromptForm = ({
}) => {
const localize = useLocalize();
const navigate = useNavigate();
const { hasAccess: hasUseAccess } = usePromptGroupsContext();
const hasCreateAccess = useHasAccess({
const hasAccess = useHasAccess({
permissionType: PermissionTypes.PROMPTS,
permission: Permissions.CREATE,
});
const hasAccess = hasUseAccess && hasCreateAccess;
useEffect(() => {
let timeoutId: ReturnType<typeof setTimeout>;

View File

@@ -11,8 +11,8 @@ import store from '~/store';
export default function FilterPrompts({ className = '' }: { className?: string }) {
const localize = useLocalize();
const { name, setName, hasAccess } = usePromptGroupsContext();
const { categories } = useCategories({ className: 'h-4 w-4', hasAccess });
const { name, setName } = usePromptGroupsContext();
const { categories } = useCategories('h-4 w-4');
const [displayName, setDisplayName] = useState(name || '');
const [isSearching, setIsSearching] = useState(false);
const [categoryFilter, setCategory] = useRecoilState(store.promptsCategory);

View File

@@ -167,7 +167,6 @@ const PromptForm = () => {
const params = useParams();
const localize = useLocalize();
const { showToast } = useToastContext();
const { hasAccess } = usePromptGroupsContext();
const alwaysMakeProd = useRecoilValue(store.alwaysMakeProd);
const promptId = params.promptId || '';
@@ -180,12 +179,10 @@ const PromptForm = () => {
const [showSidePanel, setShowSidePanel] = useState(false);
const sidePanelWidth = '320px';
const { data: group, isLoading: isLoadingGroup } = useGetPromptGroup(promptId, {
enabled: hasAccess && !!promptId,
});
const { data: group, isLoading: isLoadingGroup } = useGetPromptGroup(promptId);
const { data: prompts = [], isLoading: isLoadingPrompts } = useGetPrompts(
{ groupId: promptId },
{ enabled: hasAccess && !!promptId },
{ enabled: !!promptId },
);
const { hasPermission, isLoading: permissionsLoading } = useResourcePermissions(

View File

@@ -76,8 +76,6 @@ export default function Message(props: TMessageProps) {
messageId,
isExpanded: false,
conversationId: conversation?.conversationId,
isSubmitting: false, // Share view is always read-only
isLatestMessage: false, // No concept of latest message in share view
}}
>
{/* Legacy Plugins */}

View File

@@ -1,26 +1,27 @@
import { useEffect } from 'react';
import { ChevronLeft } from 'lucide-react';
import { useForm, FormProvider } from 'react-hook-form';
import {
AuthTypeEnum,
AuthorizationTypeEnum,
TokenExchangeMethodEnum,
} from 'librechat-data-provider';
import {
Label,
OGDialog,
TrashIcon,
OGDialogTrigger,
useToastContext,
OGDialogTemplate,
TrashIcon,
OGDialog,
OGDialogTrigger,
Label,
useToastContext,
} from '@librechat/client';
import type { ActionAuthForm } from '~/common';
import ActionsAuth from '~/components/SidePanel/Builder/ActionsAuth';
import { useAgentPanelContext } from '~/Providers/AgentPanelContext';
import { useDeleteAgentAction } from '~/data-provider';
import { Panel, isEphemeralAgent } from '~/common';
import type { ActionAuthForm } from '~/common';
import ActionsInput from './ActionsInput';
import { useLocalize } from '~/hooks';
import { Panel } from '~/common';
export default function ActionsPanel() {
const localize = useLocalize();
@@ -108,7 +109,7 @@ export default function ActionsPanel() {
<div className="absolute right-0 top-6">
<button
type="button"
disabled={isEphemeralAgent(agent_id) || !action.action_id}
disabled={!agent_id || !action.action_id}
className="btn btn-neutral border-token-border-light relative h-9 rounded-lg font-medium"
>
<TrashIcon className="text-red-500" />
@@ -126,7 +127,7 @@ export default function ActionsPanel() {
}
selection={{
selectHandler: () => {
if (isEphemeralAgent(agent_id)) {
if (!agent_id) {
return showToast({
message: localize('com_agents_no_agent_id_error'),
status: 'error',
@@ -134,7 +135,7 @@ export default function ActionsPanel() {
}
deleteAgentAction.mutate({
action_id: action.action_id,
agent_id: agent_id || '',
agent_id,
});
},
selectClasses:

View File

@@ -18,7 +18,6 @@ import { useFileMapContext, useAgentPanelContext } from '~/Providers';
import AgentCategorySelector from './AgentCategorySelector';
import Action from '~/components/SidePanel/Builder/Action';
import { useLocalize, useVisibleTools } from '~/hooks';
import { Panel, isEphemeralAgent } from '~/common';
import { useGetAgentFiles } from '~/data-provider';
import { icons } from '~/hooks/Endpoint/Icons';
import Instructions from './Instructions';
@@ -30,6 +29,7 @@ import Artifacts from './Artifacts';
import AgentTool from './AgentTool';
import CodeForm from './Code/Form';
import MCPTools from './MCPTools';
import { Panel } from '~/common';
const labelClass = 'mb-2 text-token-text-primary block font-medium';
const inputClass = cn(
@@ -48,12 +48,12 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
const {
actions,
setAction,
regularTools,
agentsConfig,
startupConfig,
mcpServersMap,
setActivePanel,
endpointsConfig,
groupedTools: allTools,
} = useAgentPanelContext();
const {
@@ -79,9 +79,9 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
}, [fileMap, agentFiles]);
const {
ocrEnabled,
codeEnabled,
toolsEnabled,
contextEnabled,
actionsEnabled,
artifactsEnabled,
webSearchEnabled,
@@ -149,7 +149,7 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
}, [agent, agent_id, mergedFileMap]);
const handleAddActions = useCallback(() => {
if (isEphemeralAgent(agent_id)) {
if (!agent_id) {
showToast({
message: localize('com_assistants_actions_disabled'),
status: 'warning',
@@ -177,7 +177,7 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
Icon = icons[iconKey];
}
const { toolIds, mcpServerNames } = useVisibleTools(tools, regularTools, mcpServersMap);
const { toolIds, mcpServerNames } = useVisibleTools(tools, allTools, mcpServersMap);
return (
<>
@@ -291,7 +291,7 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
{(codeEnabled ||
fileSearchEnabled ||
artifactsEnabled ||
contextEnabled ||
ocrEnabled ||
webSearchEnabled) && (
<div className="mb-4 flex w-full flex-col items-start gap-3">
<label className="text-token-text-primary block font-medium">
@@ -301,8 +301,8 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
{codeEnabled && <CodeForm agent_id={agent_id} files={code_files} />}
{/* Web Search */}
{webSearchEnabled && <SearchForm />}
{/* File Context */}
{contextEnabled && <FileContext agent_id={agent_id} files={context_files} />}
{/* File Context (OCR) */}
{ocrEnabled && <FileContext agent_id={agent_id} files={context_files} />}
{/* Artifacts */}
{artifactsEnabled && <Artifacts />}
{/* File Search */}
@@ -326,15 +326,16 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
</label>
<div>
<div className="mb-1">
{/* Render all visible IDs */}
{/* Render all visible IDs (including groups with subtools selected) */}
{toolIds.map((toolId, i) => {
const tool = regularTools?.find((t) => t.pluginKey === toolId);
if (!allTools) return null;
const tool = allTools[toolId];
if (!tool) return null;
return (
<AgentTool
key={`${toolId}-${i}-${agent_id}`}
tool={toolId}
regularTools={regularTools}
allTools={allTools}
agent_id={agent_id}
/>
);
@@ -370,7 +371,7 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
{(actionsEnabled ?? false) && (
<button
type="button"
disabled={isEphemeralAgent(agent_id)}
disabled={!agent_id}
onClick={handleAddActions}
className="btn btn-neutral border-token-border-light relative h-9 w-full rounded-lg font-medium"
aria-haspopup="dialog"
@@ -473,15 +474,13 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
setIsOpen={setShowToolDialog}
endpoint={EModelEndpoint.agents}
/>
{startupConfig?.mcpServers != null && (
<MCPToolSelectDialog
agentId={agent_id}
isOpen={showMCPToolDialog}
mcpServerNames={mcpServerNames}
setIsOpen={setShowMCPToolDialog}
endpoint={EModelEndpoint.agents}
/>
)}
<MCPToolSelectDialog
agentId={agent_id}
isOpen={showMCPToolDialog}
mcpServerNames={mcpServerNames}
setIsOpen={setShowMCPToolDialog}
endpoint={EModelEndpoint.agents}
/>
</>
);
}

Some files were not shown because too many files have changed in this diff Show More