Compare commits
20 Commits
rel/v0.8.0
...
feat/price
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b034624690 | ||
|
|
adff605c50 | ||
|
|
465c81adee | ||
|
|
cb8e76e27e | ||
|
|
4d9e17efe1 | ||
|
|
95ebef13df | ||
|
|
4fb9d7bdff | ||
|
|
0edfecf44a | ||
|
|
ba8c09b361 | ||
|
|
794fe6fd11 | ||
|
|
97ac52fc6c | ||
|
|
1a947607a5 | ||
|
|
1745708418 | ||
|
|
14aedac1e1 | ||
|
|
a820d79bfc | ||
|
|
3b1c07ff46 | ||
|
|
c1b0f13360 | ||
|
|
637bbd2e29 | ||
|
|
30e1b421ba | ||
|
|
fb89f60470 |
@@ -233,6 +233,7 @@ class BaseClient {
|
||||
sender: 'User',
|
||||
text,
|
||||
isCreatedByUser: true,
|
||||
targetModel: this.modelOptions?.model ?? this.model,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -68,19 +68,19 @@ const primeFiles = async (options) => {
|
||||
/**
|
||||
*
|
||||
* @param {Object} options
|
||||
* @param {string} options.userId
|
||||
* @param {ServerRequest} options.req
|
||||
* @param {Array<{ file_id: string; filename: string }>} options.files
|
||||
* @param {string} [options.entity_id]
|
||||
* @param {boolean} [options.fileCitations=false] - Whether to include citation instructions
|
||||
* @returns
|
||||
*/
|
||||
const createFileSearchTool = async ({ userId, files, entity_id, fileCitations = false }) => {
|
||||
const createFileSearchTool = async ({ req, files, entity_id, fileCitations = false }) => {
|
||||
return tool(
|
||||
async ({ query }) => {
|
||||
if (files.length === 0) {
|
||||
return 'No files to search. Instruct the user to add files for the search.';
|
||||
}
|
||||
const jwtToken = generateShortLivedToken(userId);
|
||||
const jwtToken = generateShortLivedToken(req.user.id);
|
||||
if (!jwtToken) {
|
||||
return 'There was an error authenticating the file search request.';
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSe
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
/**
|
||||
@@ -250,6 +250,7 @@ const loadTools = async ({
|
||||
|
||||
/** @type {Record<string, string>} */
|
||||
const toolContextMap = {};
|
||||
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
|
||||
const requestedMCPTools = {};
|
||||
|
||||
for (const tool of tools) {
|
||||
@@ -306,7 +307,7 @@ const loadTools = async ({
|
||||
}
|
||||
|
||||
return createFileSearchTool({
|
||||
userId: user,
|
||||
req: options.req,
|
||||
files,
|
||||
entity_id: agent?.id,
|
||||
fileCitations,
|
||||
@@ -339,7 +340,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
});
|
||||
};
|
||||
continue;
|
||||
} else if (tool && mcpToolPattern.test(tool)) {
|
||||
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
|
||||
const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
|
||||
if (toolName === Constants.mcp_server) {
|
||||
/** Placeholder used for UI purposes */
|
||||
@@ -352,21 +353,33 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
continue;
|
||||
}
|
||||
if (toolName === Constants.mcp_all) {
|
||||
requestedMCPTools[serverName] = [
|
||||
{
|
||||
type: 'all',
|
||||
const currentMCPGenerator = async (index) =>
|
||||
createMCPTools({
|
||||
req: options.req,
|
||||
res: options.res,
|
||||
index,
|
||||
serverName,
|
||||
},
|
||||
];
|
||||
userMCPAuthMap,
|
||||
model: agent?.model ?? model,
|
||||
provider: agent?.provider ?? endpoint,
|
||||
signal,
|
||||
});
|
||||
requestedMCPTools[serverName] = [currentMCPGenerator];
|
||||
continue;
|
||||
}
|
||||
|
||||
const currentMCPGenerator = async (index) =>
|
||||
createMCPTool({
|
||||
index,
|
||||
req: options.req,
|
||||
res: options.res,
|
||||
toolKey: tool,
|
||||
userMCPAuthMap,
|
||||
model: agent?.model ?? model,
|
||||
provider: agent?.provider ?? endpoint,
|
||||
signal,
|
||||
});
|
||||
requestedMCPTools[serverName] = requestedMCPTools[serverName] || [];
|
||||
requestedMCPTools[serverName].push({
|
||||
type: 'single',
|
||||
toolKey: tool,
|
||||
serverName,
|
||||
});
|
||||
requestedMCPTools[serverName].push(currentMCPGenerator);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -409,64 +422,24 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
const mcpToolPromises = [];
|
||||
/** MCP server tools are initialized sequentially by server */
|
||||
let index = -1;
|
||||
const failedMCPServers = new Set();
|
||||
for (const [serverName, toolConfigs] of Object.entries(requestedMCPTools)) {
|
||||
for (const [serverName, generators] of Object.entries(requestedMCPTools)) {
|
||||
index++;
|
||||
/** @type {LCAvailableTools} */
|
||||
let availableTools;
|
||||
for (const config of toolConfigs) {
|
||||
for (const generator of generators) {
|
||||
try {
|
||||
if (failedMCPServers.has(serverName)) {
|
||||
continue;
|
||||
}
|
||||
const mcpParams = {
|
||||
res: options.res,
|
||||
userId: user,
|
||||
index,
|
||||
serverName: config.serverName,
|
||||
userMCPAuthMap,
|
||||
model: agent?.model ?? model,
|
||||
provider: agent?.provider ?? endpoint,
|
||||
signal,
|
||||
};
|
||||
|
||||
if (config.type === 'all' && toolConfigs.length === 1) {
|
||||
/** Handle async loading for single 'all' tool config */
|
||||
if (generator && generators.length === 1) {
|
||||
mcpToolPromises.push(
|
||||
createMCPTools(mcpParams).catch((error) => {
|
||||
generator(index).catch((error) => {
|
||||
logger.error(`Error loading ${serverName} tools:`, error);
|
||||
return null;
|
||||
}),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (!availableTools) {
|
||||
try {
|
||||
availableTools = await getMCPServerTools(serverName);
|
||||
} catch (error) {
|
||||
logger.error(`Error fetching available tools for MCP server ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
/** Handle synchronous loading */
|
||||
const mcpTool =
|
||||
config.type === 'all'
|
||||
? await createMCPTools(mcpParams)
|
||||
: await createMCPTool({
|
||||
...mcpParams,
|
||||
availableTools,
|
||||
toolKey: config.toolKey,
|
||||
});
|
||||
|
||||
const mcpTool = await generator(index);
|
||||
if (Array.isArray(mcpTool)) {
|
||||
loadedTools.push(...mcpTool);
|
||||
} else if (mcpTool) {
|
||||
loadedTools.push(mcpTool);
|
||||
} else {
|
||||
failedMCPServers.add(serverName);
|
||||
logger.warn(
|
||||
`MCP tool creation failed for "${config.toolKey}", server may be unavailable or unauthenticated.`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error loading MCP tool for server ${serverName}:`, error);
|
||||
|
||||
25
api/cache/cacheConfig.js
vendored
25
api/cache/cacheConfig.js
vendored
@@ -1,5 +1,4 @@
|
||||
const fs = require('fs');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { math, isEnabled } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
|
||||
@@ -35,35 +34,13 @@ if (FORCED_IN_MEMORY_CACHE_NAMESPACES.length > 0) {
|
||||
}
|
||||
}
|
||||
|
||||
/** Helper function to safely read Redis CA certificate from file
|
||||
* @returns {string|null} The contents of the CA certificate file, or null if not set or on error
|
||||
*/
|
||||
const getRedisCA = () => {
|
||||
const caPath = process.env.REDIS_CA;
|
||||
if (!caPath) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
if (fs.existsSync(caPath)) {
|
||||
return fs.readFileSync(caPath, 'utf8');
|
||||
} else {
|
||||
logger.warn(`Redis CA certificate file not found: ${caPath}`);
|
||||
return null;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to read Redis CA certificate file '${caPath}':`, error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
const cacheConfig = {
|
||||
FORCED_IN_MEMORY_CACHE_NAMESPACES,
|
||||
USE_REDIS,
|
||||
REDIS_URI: process.env.REDIS_URI,
|
||||
REDIS_USERNAME: process.env.REDIS_USERNAME,
|
||||
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
|
||||
REDIS_CA: getRedisCA(),
|
||||
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
|
||||
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
|
||||
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
|
||||
REDIS_PING_INTERVAL: math(process.env.REDIS_PING_INTERVAL, 0),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||
const { EventSource } = require('eventsource');
|
||||
const { Time } = require('librechat-data-provider');
|
||||
const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api');
|
||||
const logger = require('./winston');
|
||||
|
||||
global.EventSource = EventSource;
|
||||
@@ -26,6 +26,4 @@ module.exports = {
|
||||
createMCPManager: MCPManager.createInstance,
|
||||
getMCPManager: MCPManager.getInstance,
|
||||
getFlowStateManager,
|
||||
createOAuthReconnectionManager: OAuthReconnectionManager.createInstance,
|
||||
getOAuthReconnectionManager: OAuthReconnectionManager.getInstance,
|
||||
};
|
||||
|
||||
@@ -11,7 +11,7 @@ const {
|
||||
getProjectByName,
|
||||
} = require('./Project');
|
||||
const { removeAllPermissions } = require('~/server/services/PermissionService');
|
||||
const { getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { getActions } = require('./Action');
|
||||
const { Agent } = require('~/db/models');
|
||||
|
||||
@@ -49,14 +49,6 @@ const createAgent = async (agentData) => {
|
||||
*/
|
||||
const getAgent = async (searchParameter) => await Agent.findOne(searchParameter).lean();
|
||||
|
||||
/**
|
||||
* Get multiple agent documents based on the provided search parameters.
|
||||
*
|
||||
* @param {Object} searchParameter - The search parameters to find agents.
|
||||
* @returns {Promise<Agent[]>} Array of agent documents as plain objects.
|
||||
*/
|
||||
const getAgents = async (searchParameter) => await Agent.find(searchParameter).lean();
|
||||
|
||||
/**
|
||||
* Load an agent based on the provided ID
|
||||
*
|
||||
@@ -69,6 +61,8 @@ const getAgents = async (searchParameter) => await Agent.find(searchParameter).l
|
||||
*/
|
||||
const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => {
|
||||
const { model, ...model_parameters } = _m;
|
||||
/** @type {Record<string, FunctionTool>} */
|
||||
const availableTools = await getCachedTools({ userId: req.user.id, includeGlobal: true });
|
||||
/** @type {TEphemeralAgent | null} */
|
||||
const ephemeralAgent = req.body.ephemeralAgent;
|
||||
const mcpServers = new Set(ephemeralAgent?.mcp);
|
||||
@@ -86,18 +80,22 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
|
||||
const addedServers = new Set();
|
||||
if (mcpServers.size > 0) {
|
||||
for (const toolName of Object.keys(availableTools)) {
|
||||
if (!toolName.includes(mcp_delimiter)) {
|
||||
continue;
|
||||
}
|
||||
const mcpServer = toolName.split(mcp_delimiter)?.[1];
|
||||
if (mcpServer && mcpServers.has(mcpServer)) {
|
||||
addedServers.add(mcpServer);
|
||||
tools.push(toolName);
|
||||
}
|
||||
}
|
||||
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
const serverTools = await getMCPServerTools(mcpServer);
|
||||
if (!serverTools) {
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
addedServers.add(mcpServer);
|
||||
continue;
|
||||
}
|
||||
tools.push(...Object.keys(serverTools));
|
||||
addedServers.add(mcpServer);
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -837,7 +835,6 @@ const countPromotedAgents = async () => {
|
||||
|
||||
module.exports = {
|
||||
getAgent,
|
||||
getAgents,
|
||||
loadAgent,
|
||||
createAgent,
|
||||
updateAgent,
|
||||
|
||||
@@ -8,7 +8,6 @@ process.env.CREDS_IV = '0123456789abcdef';
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCachedTools: jest.fn(),
|
||||
getMCPServerTools: jest.fn(),
|
||||
}));
|
||||
|
||||
const mongoose = require('mongoose');
|
||||
@@ -31,7 +30,7 @@ const {
|
||||
generateActionMetadataHash,
|
||||
} = require('./Agent');
|
||||
const permissionService = require('~/server/services/PermissionService');
|
||||
const { getCachedTools, getMCPServerTools } = require('~/server/services/Config');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { AclEntry } = require('~/db/models');
|
||||
|
||||
/**
|
||||
@@ -1930,16 +1929,6 @@ describe('models/Agent', () => {
|
||||
another_tool: {},
|
||||
});
|
||||
|
||||
// Mock getMCPServerTools to return tools for each server
|
||||
getMCPServerTools.mockImplementation(async (server) => {
|
||||
if (server === 'server1') {
|
||||
return { tool1_mcp_server1: {} };
|
||||
} else if (server === 'server2') {
|
||||
return { tool2_mcp_server2: {} };
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
@@ -2124,14 +2113,6 @@ describe('models/Agent', () => {
|
||||
|
||||
getCachedTools.mockResolvedValue(availableTools);
|
||||
|
||||
// Mock getMCPServerTools to return all tools for server1
|
||||
getMCPServerTools.mockImplementation(async (server) => {
|
||||
if (server === 'server1') {
|
||||
return availableTools; // All 100 tools belong to server1
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
@@ -2673,17 +2654,6 @@ describe('models/Agent', () => {
|
||||
tool_mcp_server2: {}, // Different server
|
||||
});
|
||||
|
||||
// Mock getMCPServerTools to return only tools matching the server
|
||||
getMCPServerTools.mockImplementation(async (server) => {
|
||||
if (server === 'server1') {
|
||||
// Only return tool that correctly matches server1 format
|
||||
return { tool_mcp_server1: {} };
|
||||
} else if (server === 'server2') {
|
||||
return { tool_mcp_server2: {} };
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
const mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {
|
||||
|
||||
@@ -112,8 +112,17 @@ module.exports = {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
/** @type {{ $set: Partial<TConversation>; $unset?: Record<keyof TConversation, number> }} */
|
||||
/** @type {{ $set: Partial<TConversation>; $addToSet?: Record<string, any>; $unset?: Record<keyof TConversation, number> }} */
|
||||
const updateOperation = { $set: update };
|
||||
|
||||
if (convo.model && convo.endpoint) {
|
||||
updateOperation.$addToSet = {
|
||||
modelHistory: {
|
||||
model: convo.model,
|
||||
endpoint: convo.endpoint,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) {
|
||||
updateOperation.$unset = metadata.unsetFields;
|
||||
}
|
||||
|
||||
@@ -239,46 +239,10 @@ const updateTagsForConversation = async (user, conversationId, tags) => {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Increments tag counts for existing tags only.
|
||||
* @param {string} user - The user ID.
|
||||
* @param {string[]} tags - Array of tag names to increment
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const bulkIncrementTagCounts = async (user, tags) => {
|
||||
if (!tags || tags.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const uniqueTags = [...new Set(tags.filter(Boolean))];
|
||||
if (uniqueTags.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const bulkOps = uniqueTags.map((tag) => ({
|
||||
updateOne: {
|
||||
filter: { user, tag },
|
||||
update: { $inc: { count: 1 } },
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await ConversationTag.bulkWrite(bulkOps);
|
||||
if (result && result.modifiedCount > 0) {
|
||||
logger.debug(
|
||||
`user: ${user} | Incremented tag counts - modified ${result.modifiedCount} tags`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[bulkIncrementTagCounts] Error incrementing tag counts', error);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getConversationTags,
|
||||
createConversationTag,
|
||||
updateConversationTag,
|
||||
deleteConversationTag,
|
||||
bulkIncrementTagCounts,
|
||||
updateTagsForConversation,
|
||||
};
|
||||
|
||||
@@ -42,7 +42,7 @@ const getToolFilesByIds = async (fileIds, toolResourceSet) => {
|
||||
$or: [],
|
||||
};
|
||||
|
||||
if (toolResourceSet.has(EToolResources.context)) {
|
||||
if (toolResourceSet.has(EToolResources.ocr)) {
|
||||
filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
|
||||
}
|
||||
if (toolResourceSet.has(EToolResources.file_search)) {
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
"@langchain/google-vertexai": "^0.2.13",
|
||||
"@langchain/openai": "^0.5.18",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.4.80",
|
||||
"@librechat/agents": "^2.4.79",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
const cookies = require('cookie');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const openIdClient = require('openid-client');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, findOpenIDUser } = require('@librechat/api');
|
||||
const {
|
||||
requestPasswordReset,
|
||||
setOpenIDAuthTokens,
|
||||
@@ -11,9 +11,8 @@ const {
|
||||
registerUser,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||
const { getOAuthReconnectionManager } = require('~/config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
@@ -72,14 +71,8 @@ const refreshController = async (req, res) => {
|
||||
const openIdConfig = getOpenIdConfig();
|
||||
const tokenset = await openIdClient.refreshTokenGrant(openIdConfig, refreshToken);
|
||||
const claims = tokenset.claims();
|
||||
const { user, error } = await findOpenIDUser({
|
||||
findUser,
|
||||
email: claims.email,
|
||||
openidId: claims.sub,
|
||||
idOnTheSource: claims.oid,
|
||||
strategyName: 'refreshController',
|
||||
});
|
||||
if (error || !user) {
|
||||
const user = await findUser({ email: claims.email });
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString());
|
||||
@@ -103,25 +96,14 @@ const refreshController = async (req, res) => {
|
||||
return res.status(200).send({ token, user });
|
||||
}
|
||||
|
||||
/** Session with the hashed refresh token */
|
||||
const session = await findSession(
|
||||
{
|
||||
userId: userId,
|
||||
refreshToken: refreshToken,
|
||||
},
|
||||
{ lean: false },
|
||||
);
|
||||
// Find the session with the hashed refresh token
|
||||
const session = await findSession({
|
||||
userId: userId,
|
||||
refreshToken: refreshToken,
|
||||
});
|
||||
|
||||
if (session && session.expiration > new Date()) {
|
||||
const token = await setAuthTokens(userId, res, session);
|
||||
|
||||
// trigger OAuth MCP server reconnection asynchronously (best effort)
|
||||
void getOAuthReconnectionManager()
|
||||
.reconnectServers(userId)
|
||||
.catch((err) => {
|
||||
logger.error('Error reconnecting OAuth MCP servers:', err);
|
||||
});
|
||||
|
||||
const token = await setAuthTokens(userId, res, session._id);
|
||||
res.status(200).send({ token, user });
|
||||
} else if (req?.query?.retry) {
|
||||
// Retrying from a refresh token request that failed (401)
|
||||
@@ -132,7 +114,7 @@ const refreshController = async (req, res) => {
|
||||
res.status(401).send('Refresh token expired or not found for this user');
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(`[refreshController] Invalid refresh token:`, err);
|
||||
logger.error(`[refreshController] Refresh token: ${refreshToken}`, err);
|
||||
res.status(403).send('Invalid refresh token');
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getToolkitKey, checkPluginAuth, filterUniquePlugins } = require('@librechat/api');
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
getToolkitKey,
|
||||
checkPluginAuth,
|
||||
filterUniquePlugins,
|
||||
convertMCPToolToPlugin,
|
||||
convertMCPToolsToPlugins,
|
||||
} = require('@librechat/api');
|
||||
const { getCachedTools, setCachedTools, mergeUserTools } = require('~/server/services/Config');
|
||||
const { availableTools, toolkits } = require('~/app/clients/tools');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const getAvailablePluginsController = async (req, res) => {
|
||||
@@ -65,27 +72,63 @@ const getAvailableTools = async (req, res) => {
|
||||
}
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
const cachedUserTools = await getCachedTools({ userId });
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
|
||||
// Return early if we have cached tools
|
||||
if (cachedToolsArray != null) {
|
||||
res.status(200).json(cachedToolsArray);
|
||||
/** @type {TPlugin[]} */
|
||||
let mcpPlugins;
|
||||
if (appConfig?.mcpConfig) {
|
||||
const mcpManager = getMCPManager();
|
||||
mcpPlugins =
|
||||
cachedUserTools != null
|
||||
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
|
||||
: undefined;
|
||||
}
|
||||
|
||||
if (
|
||||
cachedToolsArray != null &&
|
||||
(appConfig?.mcpConfig != null ? mcpPlugins != null && mcpPlugins.length > 0 : true)
|
||||
) {
|
||||
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...cachedToolsArray]);
|
||||
res.status(200).json(dedupedTools);
|
||||
return;
|
||||
}
|
||||
|
||||
/** @type {Record<string, FunctionTool> | null} Get tool definitions to filter which tools are actually available */
|
||||
let toolDefinitions = await getCachedTools();
|
||||
|
||||
if (toolDefinitions == null && appConfig?.availableTools != null) {
|
||||
logger.warn('[getAvailableTools] Tool cache was empty, re-initializing from app config');
|
||||
await setCachedTools(appConfig.availableTools);
|
||||
toolDefinitions = appConfig.availableTools;
|
||||
}
|
||||
let toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
let prelimCachedTools;
|
||||
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
let pluginManifest = availableTools;
|
||||
|
||||
if (appConfig?.mcpConfig != null) {
|
||||
try {
|
||||
const mcpManager = getMCPManager();
|
||||
const mcpTools = await mcpManager.getAllToolFunctions(userId);
|
||||
prelimCachedTools = prelimCachedTools ?? {};
|
||||
for (const [toolKey, toolData] of Object.entries(mcpTools)) {
|
||||
const plugin = convertMCPToolToPlugin({
|
||||
toolKey,
|
||||
toolData,
|
||||
mcpManager,
|
||||
});
|
||||
if (plugin) {
|
||||
pluginManifest.push(plugin);
|
||||
}
|
||||
prelimCachedTools[toolKey] = toolData;
|
||||
}
|
||||
await mergeUserTools({ userId, cachedUserTools, userTools: prelimCachedTools });
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'[getAvailableTools] Error loading MCP Tools, servers may still be initializing:',
|
||||
error,
|
||||
);
|
||||
}
|
||||
} else if (prelimCachedTools != null) {
|
||||
await setCachedTools(prelimCachedTools, { isGlobal: true });
|
||||
}
|
||||
|
||||
/** @type {TPlugin[]} Deduplicate and authenticate plugins */
|
||||
const uniquePlugins = filterUniquePlugins(pluginManifest);
|
||||
const authenticatedPlugins = uniquePlugins.map((plugin) => {
|
||||
@@ -96,13 +139,13 @@ const getAvailableTools = async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
/** Filter plugins based on availability */
|
||||
/** Filter plugins based on availability and add MCP-specific auth config */
|
||||
const toolsOutput = [];
|
||||
for (const plugin of authenticatedPlugins) {
|
||||
const isToolDefined = toolDefinitions?.[plugin.pluginKey] !== undefined;
|
||||
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
|
||||
const isToolkit =
|
||||
plugin.toolkit === true &&
|
||||
Object.keys(toolDefinitions ?? {}).some(
|
||||
Object.keys(toolDefinitions).some(
|
||||
(key) => getToolkitKey({ toolkits, toolName: key }) === plugin.pluginKey,
|
||||
);
|
||||
|
||||
@@ -110,13 +153,39 @@ const getAvailableTools = async (req, res) => {
|
||||
continue;
|
||||
}
|
||||
|
||||
toolsOutput.push(plugin);
|
||||
const toolToAdd = { ...plugin };
|
||||
|
||||
if (plugin.pluginKey.includes(Constants.mcp_delimiter)) {
|
||||
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
const serverConfig = appConfig?.mcpConfig?.[serverName];
|
||||
|
||||
if (serverConfig?.customUserVars) {
|
||||
const customVarKeys = Object.keys(serverConfig.customUserVars);
|
||||
if (customVarKeys.length === 0) {
|
||||
toolToAdd.authConfig = [];
|
||||
toolToAdd.authenticated = true;
|
||||
} else {
|
||||
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(
|
||||
([key, value]) => ({
|
||||
authField: key,
|
||||
label: value.title || key,
|
||||
description: value.description || '',
|
||||
}),
|
||||
);
|
||||
toolToAdd.authenticated = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolsOutput.push(toolToAdd);
|
||||
}
|
||||
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
|
||||
res.status(200).json(finalTools);
|
||||
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...finalTools]);
|
||||
res.status(200).json(dedupedTools);
|
||||
} catch (error) {
|
||||
logger.error('[getAvailableTools]', error);
|
||||
res.status(500).json({ message: error.message });
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { getCachedTools, getAppConfig } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
@@ -16,10 +17,18 @@ jest.mock('~/server/services/Config', () => ({
|
||||
includedTools: [],
|
||||
}),
|
||||
setCachedTools: jest.fn(),
|
||||
mergeUserTools: jest.fn(),
|
||||
}));
|
||||
|
||||
// loadAndFormatTools mock removed - no longer used in PluginController
|
||||
// getMCPManager mock removed - no longer used in PluginController
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(() => ({
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
})),
|
||||
getFlowStateManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/app/clients/tools', () => ({
|
||||
availableTools: [],
|
||||
@@ -150,6 +159,52 @@ describe('PluginController', () => {
|
||||
});
|
||||
|
||||
describe('getAvailableTools', () => {
|
||||
it('should use convertMCPToolsToPlugins for user-specific MCP tools', async () => {
|
||||
const mockUserTools = {
|
||||
[`tool1${Constants.mcp_delimiter}server1`]: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `tool1${Constants.mcp_delimiter}server1`,
|
||||
description: 'Tool 1',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
mockReq.config = {
|
||||
mcpConfig: {
|
||||
server1: {},
|
||||
},
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return empty tools initially (since getAllToolFunctions is called)
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// Mock second call to return tool definitions (includeGlobal: true)
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData).toBeDefined();
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
expect(responseData.length).toBeGreaterThan(0);
|
||||
const convertedTool = responseData.find(
|
||||
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}server1`,
|
||||
);
|
||||
expect(convertedTool).toBeDefined();
|
||||
// The real convertMCPToolsToPlugins extracts the name from the delimiter
|
||||
expect(convertedTool.name).toBe('tool1');
|
||||
});
|
||||
|
||||
it('should use filterUniquePlugins to deduplicate combined tools', async () => {
|
||||
const mockUserTools = {
|
||||
'user-tool': {
|
||||
@@ -174,6 +229,9 @@ describe('PluginController', () => {
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
@@ -196,7 +254,14 @@ describe('PluginController', () => {
|
||||
require('~/app/clients/tools').availableTools.push(mockPlugin);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// getCachedTools returns the tool definitions
|
||||
// First call returns null for user tools
|
||||
getCachedTools.mockResolvedValueOnce(null);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Second call (with includeGlobal: true) returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
tool1: {
|
||||
type: 'function',
|
||||
@@ -207,10 +272,6 @@ describe('PluginController', () => {
|
||||
},
|
||||
},
|
||||
});
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
@@ -241,7 +302,14 @@ describe('PluginController', () => {
|
||||
});
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// getCachedTools returns the tool definitions
|
||||
// First call returns null for user tools
|
||||
getCachedTools.mockResolvedValueOnce(null);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Second call (with includeGlobal: true) returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
toolkit1_function: {
|
||||
type: 'function',
|
||||
@@ -252,10 +320,6 @@ describe('PluginController', () => {
|
||||
},
|
||||
},
|
||||
});
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
@@ -267,7 +331,126 @@ describe('PluginController', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('plugin.icon behavior', () => {
|
||||
const callGetAvailableToolsWithMCPServer = async (serverConfig) => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
|
||||
const functionTools = {
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
description: 'A test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Mock the MCP manager to return tools and server config
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue(functionTools),
|
||||
getRawConfig: jest.fn().mockReturnValue(serverConfig),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// First call returns empty user tools
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Mock getAppConfig to return the mcpConfig
|
||||
mockReq.config = {
|
||||
mcpConfig: {
|
||||
'test-server': serverConfig,
|
||||
},
|
||||
};
|
||||
|
||||
// Second call (with includeGlobal: true) returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce(functionTools);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
return responseData.find(
|
||||
(tool) => tool.pluginKey === `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
);
|
||||
};
|
||||
|
||||
it('should set plugin.icon when iconPath is defined', async () => {
|
||||
const serverConfig = {
|
||||
iconPath: '/path/to/icon.png',
|
||||
};
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(serverConfig);
|
||||
expect(testTool.icon).toBe('/path/to/icon.png');
|
||||
});
|
||||
|
||||
it('should set plugin.icon to undefined when iconPath is not defined', async () => {
|
||||
const serverConfig = {};
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(serverConfig);
|
||||
expect(testTool.icon).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('helper function integration', () => {
|
||||
it('should properly handle MCP tools with custom user variables', async () => {
|
||||
const appConfig = {
|
||||
mcpConfig: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: { title: 'API Key', description: 'Your API key' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Mock MCP tools returned by getAllToolFunctions
|
||||
const mcpToolFunctions = {
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Mock the MCP manager to return tools
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue(mcpToolFunctions),
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {
|
||||
API_KEY: { title: 'API Key', description: 'Your API key' },
|
||||
},
|
||||
}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
mockReq.config = appConfig;
|
||||
|
||||
// First call returns user tools (empty in this case)
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Second call (with includeGlobal: true) returns tool definitions including our MCP tool
|
||||
getCachedTools.mockResolvedValueOnce(mcpToolFunctions);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
|
||||
// Find the MCP tool in the response
|
||||
const mcpTool = responseData.find(
|
||||
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}test-server`,
|
||||
);
|
||||
|
||||
// The actual implementation adds authConfig and sets authenticated to false when customUserVars exist
|
||||
expect(mcpTool).toBeDefined();
|
||||
expect(mcpTool.authConfig).toEqual([
|
||||
{ authField: 'API_KEY', label: 'API Key', description: 'Your API key' },
|
||||
]);
|
||||
expect(mcpTool.authenticated).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle error cases gracefully', async () => {
|
||||
mockCache.get.mockRejectedValue(new Error('Cache error'));
|
||||
|
||||
@@ -289,13 +472,23 @@ describe('PluginController', () => {
|
||||
|
||||
it('should handle null cachedTools and cachedUserTools', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// getCachedTools returns empty object instead of null
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
// First call returns null for user tools
|
||||
getCachedTools.mockResolvedValueOnce(null);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return no tools
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// Second call (with includeGlobal: true) returns empty object instead of null
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle null values gracefully
|
||||
@@ -310,9 +503,9 @@ describe('PluginController', () => {
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock getCachedTools to return undefined
|
||||
// Mock getCachedTools to return undefined for both calls
|
||||
getCachedTools.mockReset();
|
||||
getCachedTools.mockResolvedValueOnce(undefined);
|
||||
getCachedTools.mockResolvedValueOnce(undefined).mockResolvedValueOnce(undefined);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
@@ -321,6 +514,51 @@ describe('PluginController', () => {
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle `cachedToolsArray` and `mcpPlugins` both being defined', async () => {
|
||||
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
|
||||
// Use MCP delimiter for the user tool so convertMCPToolsToPlugins works
|
||||
const userTools = {
|
||||
[`user-tool${Constants.mcp_delimiter}server1`]: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `user-tool${Constants.mcp_delimiter}server1`,
|
||||
description: 'User tool',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(cachedTools);
|
||||
getCachedTools.mockResolvedValueOnce(userTools);
|
||||
mockReq.config = {
|
||||
mcpConfig: {
|
||||
server1: {},
|
||||
},
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return empty tools initially
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// The controller expects a second call to getCachedTools
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
'cached-tool': { type: 'function', function: { name: 'cached-tool' } },
|
||||
[`user-tool${Constants.mcp_delimiter}server1`]:
|
||||
userTools[`user-tool${Constants.mcp_delimiter}server1`],
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
// Should have both cached and user tools
|
||||
expect(responseData.length).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it('should handle empty toolDefinitions object', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// Reset getCachedTools to ensure clean state
|
||||
@@ -331,12 +569,76 @@ describe('PluginController', () => {
|
||||
// Ensure no plugins are available
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
|
||||
// Reset MCP manager to default state
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// With empty tool definitions, no tools should be in the final output
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle MCP tools without customUserVars', async () => {
|
||||
const appConfig = {
|
||||
mcpConfig: {
|
||||
'test-server': {
|
||||
// No customUserVars defined
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockUserTools = {
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Mock the MCP manager to return the tools
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue(mockUserTools),
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
// No customUserVars defined
|
||||
}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
mockReq.config = appConfig;
|
||||
// First call returns empty user tools
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Second call (with includeGlobal: true) returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
// Ensure no plugins in availableTools for clean test
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
expect(responseData.length).toBeGreaterThan(0);
|
||||
|
||||
const mcpTool = responseData.find(
|
||||
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}test-server`,
|
||||
);
|
||||
|
||||
expect(mcpTool).toBeDefined();
|
||||
expect(mcpTool.authenticated).toBe(true);
|
||||
// The actual implementation sets authConfig to empty array when no customUserVars
|
||||
expect(mcpTool.authConfig).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle undefined filteredTools and includedTools', async () => {
|
||||
mockReq.config = {};
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
@@ -365,129 +667,20 @@ describe('PluginController', () => {
|
||||
require('~/app/clients/tools').availableTools.push(mockToolkit);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// getCachedTools returns empty object to avoid null reference error
|
||||
// First call returns empty object
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Second call (with includeGlobal: true) returns empty object to avoid null reference error
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle null toolDefinitions gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
});
|
||||
|
||||
it('should handle undefined toolDefinitions when checking isToolDefined (traversaal_search bug)', async () => {
|
||||
// This test reproduces the bug where toolDefinitions is undefined
|
||||
// and accessing toolDefinitions[plugin.pluginKey] causes a TypeError
|
||||
const mockPlugin = {
|
||||
name: 'Traversaal Search',
|
||||
pluginKey: 'traversaal_search',
|
||||
description: 'Search plugin',
|
||||
};
|
||||
|
||||
// Add the plugin to availableTools
|
||||
require('~/app/clients/tools').availableTools.push(mockPlugin);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// CRITICAL: getCachedTools returns undefined
|
||||
// This is what causes the bug when trying to access toolDefinitions[plugin.pluginKey]
|
||||
getCachedTools.mockResolvedValueOnce(undefined);
|
||||
|
||||
// This should not throw an error with the optional chaining fix
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle undefined toolDefinitions gracefully and return empty array
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should re-initialize tools from appConfig when cache returns null', async () => {
|
||||
// Setup: Initial state with tools in appConfig
|
||||
const mockAppTools = {
|
||||
tool1: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'tool1',
|
||||
description: 'Tool 1',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
tool2: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'tool2',
|
||||
description: 'Tool 2',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Add matching plugins to availableTools
|
||||
require('~/app/clients/tools').availableTools.push(
|
||||
{ name: 'Tool 1', pluginKey: 'tool1', description: 'Tool 1' },
|
||||
{ name: 'Tool 2', pluginKey: 'tool2', description: 'Tool 2' },
|
||||
);
|
||||
|
||||
// Simulate cache cleared state (returns null)
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared)
|
||||
|
||||
mockReq.config = {
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
availableTools: mockAppTools,
|
||||
};
|
||||
|
||||
// Mock setCachedTools to verify it's called to re-initialize
|
||||
const { setCachedTools } = require('~/server/services/Config');
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should have re-initialized the cache with tools from appConfig
|
||||
expect(setCachedTools).toHaveBeenCalledWith(mockAppTools);
|
||||
|
||||
// Should still return tools successfully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData).toHaveLength(2);
|
||||
expect(responseData.find((t) => t.pluginKey === 'tool1')).toBeDefined();
|
||||
expect(responseData.find((t) => t.pluginKey === 'tool2')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle cache clear without appConfig.availableTools gracefully', async () => {
|
||||
// Setup: appConfig without availableTools
|
||||
getAppConfig.mockResolvedValue({
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
// No availableTools property
|
||||
});
|
||||
|
||||
// Clear availableTools array
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
|
||||
// Cache returns null (cleared state)
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(null); // Global tools (cache cleared)
|
||||
|
||||
mockReq.config = {
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
// No availableTools
|
||||
};
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle gracefully without crashing
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,34 +1,37 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Tools, CacheKeys, Constants, FileSources } = require('librechat-data-provider');
|
||||
const {
|
||||
webSearchKeys,
|
||||
MCPOAuthHandler,
|
||||
MCPTokenStorage,
|
||||
normalizeHttpError,
|
||||
extractWebSearchEnvVars,
|
||||
normalizeHttpError,
|
||||
MCPTokenStorage,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
getFiles,
|
||||
findToken,
|
||||
updateUser,
|
||||
deleteFiles,
|
||||
deleteConvos,
|
||||
deletePresets,
|
||||
deleteMessages,
|
||||
deleteUserById,
|
||||
deleteAllSharedLinks,
|
||||
deleteAllUserSessions,
|
||||
} = require('~/models');
|
||||
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
||||
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
|
||||
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { Tools, Constants, FileSources } = require('librechat-data-provider');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { Transaction, Balance, User, Token } = require('~/db/models');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { deleteAllSharedLinks } = require('~/models');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { clearMCPServerTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
const { findToken } = require('~/models');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
@@ -372,6 +375,9 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
|
||||
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
// 6. clear the tools cache for the server
|
||||
await clearMCPServerTools({ userId, serverName });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
|
||||
@@ -158,7 +158,7 @@ describe('duplicateAgent', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should convert `tool_resources.ocr` to `tool_resources.context`', async () => {
|
||||
it('should handle tool_resources.ocr correctly', async () => {
|
||||
const mockAgent = {
|
||||
id: 'agent_123',
|
||||
name: 'Test Agent',
|
||||
@@ -178,7 +178,7 @@ describe('duplicateAgent', () => {
|
||||
expect(createAgent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tool_resources: {
|
||||
context: { enabled: true, config: 'test' },
|
||||
ocr: { enabled: true, config: 'test' },
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -2,12 +2,7 @@ const { z } = require('zod');
|
||||
const fs = require('fs').promises;
|
||||
const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
agentCreateSchema,
|
||||
agentUpdateSchema,
|
||||
mergeAgentOcrConversion,
|
||||
convertOcrToContextInPlace,
|
||||
} = require('@librechat/api');
|
||||
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
@@ -71,7 +66,7 @@ const createAgentHandler = async (req, res) => {
|
||||
agentData.author = userId;
|
||||
agentData.tools = [];
|
||||
|
||||
const availableTools = await getCachedTools();
|
||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||
for (const tool of tools) {
|
||||
if (availableTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
@@ -203,32 +198,19 @@ const getAgentHandler = async (req, res, expandProperties = false) => {
|
||||
* @param {object} req.params - Request params
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @param {AgentUpdateParams} req.body - The Agent update parameters.
|
||||
* @returns {Promise<Agent>} 200 - success response - application/json
|
||||
* @returns {Agent} 200 - success response - application/json
|
||||
*/
|
||||
const updateAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const id = req.params.id;
|
||||
const validatedData = agentUpdateSchema.parse(req.body);
|
||||
const { _id, ...updateData } = removeNullishValues(validatedData);
|
||||
|
||||
// Convert OCR to context in incoming updateData
|
||||
convertOcrToContextInPlace(updateData);
|
||||
|
||||
const existingAgent = await getAgent({ id });
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
// Convert legacy OCR tool resource to context format in existing agent
|
||||
const ocrConversion = mergeAgentOcrConversion(existingAgent, updateData);
|
||||
if (ocrConversion.tool_resources) {
|
||||
updateData.tool_resources = ocrConversion.tool_resources;
|
||||
}
|
||||
if (ocrConversion.tools) {
|
||||
updateData.tools = ocrConversion.tools;
|
||||
}
|
||||
|
||||
let updatedAgent =
|
||||
Object.keys(updateData).length > 0
|
||||
? await updateAgent({ id }, updateData, {
|
||||
@@ -273,7 +255,7 @@ const updateAgentHandler = async (req, res) => {
|
||||
* @param {object} req - Express Request
|
||||
* @param {object} req.params - Request params
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Promise<Agent>} 201 - success response - application/json
|
||||
* @returns {Agent} 201 - success response - application/json
|
||||
*/
|
||||
const duplicateAgentHandler = async (req, res) => {
|
||||
const { id } = req.params;
|
||||
@@ -306,19 +288,9 @@ const duplicateAgentHandler = async (req, res) => {
|
||||
hour12: false,
|
||||
})})`;
|
||||
|
||||
if (_tool_resources?.[EToolResources.context]) {
|
||||
cloneData.tool_resources = {
|
||||
[EToolResources.context]: _tool_resources[EToolResources.context],
|
||||
};
|
||||
}
|
||||
|
||||
if (_tool_resources?.[EToolResources.ocr]) {
|
||||
cloneData.tool_resources = {
|
||||
/** Legacy conversion from `ocr` to `context` */
|
||||
[EToolResources.context]: {
|
||||
...(_tool_resources[EToolResources.context] ?? {}),
|
||||
..._tool_resources[EToolResources.ocr],
|
||||
},
|
||||
[EToolResources.ocr]: _tool_resources[EToolResources.ocr],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -410,7 +382,7 @@ const duplicateAgentHandler = async (req, res) => {
|
||||
* @param {object} req - Express Request
|
||||
* @param {object} req.params - Request params
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Promise<Agent>} 200 - success response - application/json
|
||||
* @returns {Agent} 200 - success response - application/json
|
||||
*/
|
||||
const deleteAgentHandler = async (req, res) => {
|
||||
try {
|
||||
@@ -512,7 +484,7 @@ const getListAgentsHandler = async (req, res) => {
|
||||
* @param {Express.Multer.File} req.file - The avatar image file.
|
||||
* @param {object} req.body - Request body
|
||||
* @param {string} [req.body.avatar] - Optional avatar for the agent's avatar.
|
||||
* @returns {Promise<void>} 200 - success response - application/json
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
const uploadAgentAvatarHandler = async (req, res) => {
|
||||
try {
|
||||
|
||||
@@ -512,7 +512,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
tool_resources: {
|
||||
/** Legacy conversion from `ocr` to `context` */
|
||||
ocr: {
|
||||
file_ids: ['ocr1', 'ocr2'],
|
||||
},
|
||||
@@ -532,8 +531,7 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tool_resources).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.ocr).toBeUndefined();
|
||||
expect(updatedAgent.tool_resources.context).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.ocr).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
|
||||
});
|
||||
|
||||
@@ -31,7 +31,7 @@ const createAssistant = async (req, res) => {
|
||||
delete assistantData.conversation_starters;
|
||||
delete assistantData.append_current_datetime;
|
||||
|
||||
const toolDefinitions = await getCachedTools();
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
|
||||
assistantData.tools = tools
|
||||
.map((tool) => {
|
||||
@@ -136,7 +136,7 @@ const patchAssistant = async (req, res) => {
|
||||
...updateData
|
||||
} = req.body;
|
||||
|
||||
const toolDefinitions = await getCachedTools();
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
|
||||
updateData.tools = (updateData.tools ?? [])
|
||||
.map((tool) => {
|
||||
|
||||
@@ -28,7 +28,7 @@ const createAssistant = async (req, res) => {
|
||||
delete assistantData.conversation_starters;
|
||||
delete assistantData.append_current_datetime;
|
||||
|
||||
const toolDefinitions = await getCachedTools();
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
|
||||
assistantData.tools = tools
|
||||
.map((tool) => {
|
||||
@@ -125,7 +125,7 @@ const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
|
||||
|
||||
let hasFileSearch = false;
|
||||
for (const tool of updateData.tools ?? []) {
|
||||
const toolDefinitions = await getCachedTools();
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
let actualTool = typeof tool === 'string' ? toolDefinitions[tool] : tool;
|
||||
|
||||
if (!actualTool && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
/**
|
||||
* MCP Tools Controller
|
||||
* Handles MCP-specific tool endpoints, decoupled from regular LibreChat tools
|
||||
*/
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
cacheMCPServerTools,
|
||||
getMCPServerTools,
|
||||
getAppConfig,
|
||||
} = require('~/server/services/Config');
|
||||
const { getMCPManager } = require('~/config');
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
*/
|
||||
const getMCPTools = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user?.id;
|
||||
if (!userId) {
|
||||
logger.warn('[getMCPTools] User ID not found in request');
|
||||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
if (!appConfig?.mcpConfig) {
|
||||
return res.status(200).json({ servers: {} });
|
||||
}
|
||||
|
||||
const mcpManager = getMCPManager();
|
||||
const configuredServers = Object.keys(appConfig.mcpConfig);
|
||||
const mcpServers = {};
|
||||
|
||||
const cachePromises = configuredServers.map((serverName) =>
|
||||
getMCPServerTools(serverName).then((tools) => ({ serverName, tools })),
|
||||
);
|
||||
const cacheResults = await Promise.all(cachePromises);
|
||||
|
||||
const serverToolsMap = new Map();
|
||||
for (const { serverName, tools } of cacheResults) {
|
||||
if (tools) {
|
||||
serverToolsMap.set(serverName, tools);
|
||||
continue;
|
||||
}
|
||||
|
||||
const serverTools = await mcpManager.getServerToolFunctions(userId, serverName);
|
||||
if (!serverTools) {
|
||||
logger.debug(`[getMCPTools] No tools found for server ${serverName}`);
|
||||
continue;
|
||||
}
|
||||
serverToolsMap.set(serverName, serverTools);
|
||||
|
||||
if (Object.keys(serverTools).length > 0) {
|
||||
// Cache asynchronously without blocking
|
||||
cacheMCPServerTools({ serverName, serverTools }).catch((err) =>
|
||||
logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Process each configured server
|
||||
for (const serverName of configuredServers) {
|
||||
try {
|
||||
const serverTools = serverToolsMap.get(serverName);
|
||||
|
||||
// Get server config once
|
||||
const serverConfig = appConfig.mcpConfig[serverName];
|
||||
const rawServerConfig = mcpManager.getRawConfig(serverName);
|
||||
|
||||
// Initialize server object with all server-level data
|
||||
const server = {
|
||||
name: serverName,
|
||||
icon: rawServerConfig?.iconPath || '',
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
tools: [],
|
||||
};
|
||||
|
||||
// Set authentication config once for the server
|
||||
if (serverConfig?.customUserVars) {
|
||||
const customVarKeys = Object.keys(serverConfig.customUserVars);
|
||||
if (customVarKeys.length > 0) {
|
||||
server.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
|
||||
authField: key,
|
||||
label: value.title || key,
|
||||
description: value.description || '',
|
||||
}));
|
||||
server.authenticated = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Process tools efficiently - no need for convertMCPToolToPlugin
|
||||
if (serverTools) {
|
||||
for (const [toolKey, toolData] of Object.entries(serverTools)) {
|
||||
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolName = toolKey.split(Constants.mcp_delimiter)[0];
|
||||
server.tools.push({
|
||||
name: toolName,
|
||||
pluginKey: toolKey,
|
||||
description: toolData.function.description || '',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Only add server if it has tools or is configured
|
||||
if (server.tools.length > 0 || serverConfig) {
|
||||
mcpServers[serverName] = server;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`[getMCPTools] Error loading tools for server ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
res.status(200).json({ servers: mcpServers });
|
||||
} catch (error) {
|
||||
logger.error('[getMCPTools]', error);
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getMCPTools,
|
||||
};
|
||||
@@ -12,7 +12,6 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { isEnabled, ErrorController } = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const { updateInterfacePermissions } = require('~/models/interface');
|
||||
@@ -155,7 +154,7 @@ const startServer = async () => {
|
||||
res.send(updatedIndexHtml);
|
||||
});
|
||||
|
||||
app.listen(port, host, async () => {
|
||||
app.listen(port, host, () => {
|
||||
if (host === '0.0.0.0') {
|
||||
logger.info(
|
||||
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
|
||||
@@ -164,9 +163,7 @@ const startServer = async () => {
|
||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
await checkMigrations();
|
||||
initializeMCPs().then(() => checkMigrations());
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { PermissionBits, hasPermissions, ResourceType } = require('librechat-data-provider');
|
||||
const { getEffectivePermissions } = require('~/server/services/PermissionService');
|
||||
const { getAgents } = require('~/models/Agent');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
||||
/**
|
||||
@@ -10,12 +10,11 @@ const { getFiles } = require('~/models/File');
|
||||
*/
|
||||
const checkAgentBasedFileAccess = async ({ userId, role, fileId }) => {
|
||||
try {
|
||||
/** Agents that have this file in their tool_resources */
|
||||
const agentsWithFile = await getAgents({
|
||||
// Find agents that have this file in their tool_resources
|
||||
const agentsWithFile = await getAgent({
|
||||
$or: [
|
||||
{ 'tool_resources.execute_code.file_ids': fileId },
|
||||
{ 'tool_resources.file_search.file_ids': fileId },
|
||||
{ 'tool_resources.context.file_ids': fileId },
|
||||
{ 'tool_resources.execute_code.file_ids': fileId },
|
||||
{ 'tool_resources.ocr.file_ids': fileId },
|
||||
],
|
||||
});
|
||||
@@ -25,7 +24,7 @@ const checkAgentBasedFileAccess = async ({ userId, role, fileId }) => {
|
||||
}
|
||||
|
||||
// Check if user has access to any of these agents
|
||||
for (const agent of agentsWithFile) {
|
||||
for (const agent of Array.isArray(agentsWithFile) ? agentsWithFile : [agentsWithFile]) {
|
||||
// Check if user is the agent author
|
||||
if (agent.author && agent.author.toString() === userId) {
|
||||
logger.debug(`[fileAccess] User is author of agent ${agent.id}`);
|
||||
@@ -84,6 +83,7 @@ const fileAccess = async (req, res, next) => {
|
||||
});
|
||||
}
|
||||
|
||||
// Get the file
|
||||
const [file] = await getFiles({ file_id: fileId });
|
||||
if (!file) {
|
||||
return res.status(404).json({
|
||||
@@ -92,18 +92,20 @@ const fileAccess = async (req, res, next) => {
|
||||
});
|
||||
}
|
||||
|
||||
// Check if user owns the file
|
||||
if (file.user && file.user.toString() === userId) {
|
||||
req.fileAccess = { file };
|
||||
return next();
|
||||
}
|
||||
|
||||
/** Agent-based access (file inherits agent permissions) */
|
||||
// Check agent-based access (file inherits agent permissions)
|
||||
const hasAgentAccess = await checkAgentBasedFileAccess({ userId, role: userRole, fileId });
|
||||
if (hasAgentAccess) {
|
||||
req.fileAccess = { file };
|
||||
return next();
|
||||
}
|
||||
|
||||
// No access
|
||||
logger.warn(`[fileAccess] User ${userId} denied access to file ${fileId}`);
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
|
||||
@@ -1,483 +0,0 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { ResourceType, PrincipalType, PrincipalModel } = require('librechat-data-provider');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { fileAccess } = require('./fileAccess');
|
||||
const { User, Role, AclEntry } = require('~/db/models');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { createFile } = require('~/models/File');
|
||||
|
||||
describe('fileAccess middleware', () => {
|
||||
let mongoServer;
|
||||
let req, res, next;
|
||||
let testUser, otherUser, thirdUser;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.connection.dropDatabase();
|
||||
|
||||
// Create test role
|
||||
await Role.create({
|
||||
name: 'test-role',
|
||||
permissions: {
|
||||
AGENTS: {
|
||||
USE: true,
|
||||
CREATE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Create test users
|
||||
testUser = await User.create({
|
||||
email: 'test@example.com',
|
||||
name: 'Test User',
|
||||
username: 'testuser',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
otherUser = await User.create({
|
||||
email: 'other@example.com',
|
||||
name: 'Other User',
|
||||
username: 'otheruser',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
thirdUser = await User.create({
|
||||
email: 'third@example.com',
|
||||
name: 'Third User',
|
||||
username: 'thirduser',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
// Setup request/response objects
|
||||
req = {
|
||||
user: { id: testUser._id.toString(), role: testUser.role },
|
||||
params: {},
|
||||
};
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
next = jest.fn();
|
||||
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('basic file access', () => {
|
||||
test('should allow access when user owns the file', async () => {
|
||||
// Create a file owned by testUser
|
||||
await createFile({
|
||||
user: testUser._id.toString(),
|
||||
file_id: 'file_owned_by_user',
|
||||
filepath: '/test/file.txt',
|
||||
filename: 'file.txt',
|
||||
type: 'text/plain',
|
||||
size: 100,
|
||||
});
|
||||
|
||||
req.params.file_id = 'file_owned_by_user';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(req.fileAccess).toBeDefined();
|
||||
expect(req.fileAccess.file).toBeDefined();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should deny access when user does not own the file and no agent access', async () => {
|
||||
// Create a file owned by otherUser
|
||||
await createFile({
|
||||
user: otherUser._id.toString(),
|
||||
file_id: 'file_owned_by_other',
|
||||
filepath: '/test/file.txt',
|
||||
filename: 'file.txt',
|
||||
type: 'text/plain',
|
||||
size: 100,
|
||||
});
|
||||
|
||||
req.params.file_id = 'file_owned_by_other';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to access this file',
|
||||
});
|
||||
});
|
||||
|
||||
test('should return 404 when file does not exist', async () => {
|
||||
req.params.file_id = 'non_existent_file';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(404);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Not Found',
|
||||
message: 'File not found',
|
||||
});
|
||||
});
|
||||
|
||||
test('should return 400 when file_id is missing', async () => {
|
||||
// Don't set file_id in params
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(400);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Bad Request',
|
||||
message: 'file_id is required',
|
||||
});
|
||||
});
|
||||
|
||||
test('should return 401 when user is not authenticated', async () => {
|
||||
req.user = null;
|
||||
req.params.file_id = 'some_file';
|
||||
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Unauthorized',
|
||||
message: 'Authentication required',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('agent-based file access', () => {
|
||||
beforeEach(async () => {
|
||||
// Create a file owned by otherUser (not testUser)
|
||||
await createFile({
|
||||
user: otherUser._id.toString(),
|
||||
file_id: 'shared_file_via_agent',
|
||||
filepath: '/test/shared.txt',
|
||||
filename: 'shared.txt',
|
||||
type: 'text/plain',
|
||||
size: 100,
|
||||
});
|
||||
});
|
||||
|
||||
test('should allow access when user is author of agent with file', async () => {
|
||||
// Create agent owned by testUser with the file
|
||||
await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: ['shared_file_via_agent'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
req.params.file_id = 'shared_file_via_agent';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(req.fileAccess).toBeDefined();
|
||||
expect(req.fileAccess.file).toBeDefined();
|
||||
});
|
||||
|
||||
test('should allow access when user has VIEW permission on agent with file', async () => {
|
||||
// Create agent owned by otherUser
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Shared Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
tool_resources: {
|
||||
execute_code: {
|
||||
file_ids: ['shared_file_via_agent'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant VIEW permission to testUser
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
permBits: 1, // VIEW permission
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.params.file_id = 'shared_file_via_agent';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(req.fileAccess).toBeDefined();
|
||||
});
|
||||
|
||||
test('should check file in ocr tool_resources', async () => {
|
||||
await createAgent({
|
||||
id: `agent_ocr_${Date.now()}`,
|
||||
name: 'OCR Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
tool_resources: {
|
||||
ocr: {
|
||||
file_ids: ['shared_file_via_agent'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
req.params.file_id = 'shared_file_via_agent';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(req.fileAccess).toBeDefined();
|
||||
});
|
||||
|
||||
test('should deny access when user has no permission on agent with file', async () => {
|
||||
// Create agent owned by otherUser without granting permission to testUser
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: ['shared_file_via_agent'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Create ACL entry for otherUser only (owner)
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.params.file_id = 'shared_file_via_agent';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
});
|
||||
});
|
||||
|
||||
describe('multiple agents with same file', () => {
|
||||
/**
|
||||
* This test suite verifies that when multiple agents have the same file,
|
||||
* all agents are checked for permissions, not just the first one found.
|
||||
* This ensures users can access files through any agent they have permission for.
|
||||
*/
|
||||
|
||||
test('should check ALL agents with file, not just first one', async () => {
|
||||
// Create a file owned by someone else
|
||||
await createFile({
|
||||
user: otherUser._id.toString(),
|
||||
file_id: 'multi_agent_file',
|
||||
filepath: '/test/multi.txt',
|
||||
filename: 'multi.txt',
|
||||
type: 'text/plain',
|
||||
size: 100,
|
||||
});
|
||||
|
||||
// Create first agent (owned by otherUser, no access for testUser)
|
||||
const agent1 = await createAgent({
|
||||
id: 'agent_no_access',
|
||||
name: 'No Access Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: ['multi_agent_file'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Create ACL for agent1 - only otherUser has access
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent1._id,
|
||||
permBits: 15,
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
// Create second agent (owned by thirdUser, but testUser has VIEW access)
|
||||
const agent2 = await createAgent({
|
||||
id: 'agent_with_access',
|
||||
name: 'Accessible Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: thirdUser._id,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: ['multi_agent_file'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant testUser VIEW access to agent2
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent2._id,
|
||||
permBits: 1, // VIEW permission
|
||||
grantedBy: thirdUser._id,
|
||||
});
|
||||
|
||||
req.params.file_id = 'multi_agent_file';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
/**
|
||||
* Should succeed because testUser has access to agent2,
|
||||
* even though they don't have access to agent1.
|
||||
* The fix ensures all agents are checked, not just the first one.
|
||||
*/
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(req.fileAccess).toBeDefined();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should find file in any agent tool_resources type', async () => {
|
||||
// Create a file
|
||||
await createFile({
|
||||
user: otherUser._id.toString(),
|
||||
file_id: 'multi_tool_file',
|
||||
filepath: '/test/tool.txt',
|
||||
filename: 'tool.txt',
|
||||
type: 'text/plain',
|
||||
size: 100,
|
||||
});
|
||||
|
||||
// Agent 1: file in file_search (no access for testUser)
|
||||
await createAgent({
|
||||
id: 'agent_file_search',
|
||||
name: 'File Search Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: ['multi_tool_file'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Agent 2: same file in execute_code (testUser has access)
|
||||
await createAgent({
|
||||
id: 'agent_execute_code',
|
||||
name: 'Execute Code Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: thirdUser._id,
|
||||
tool_resources: {
|
||||
execute_code: {
|
||||
file_ids: ['multi_tool_file'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Agent 3: same file in ocr (testUser also has access)
|
||||
await createAgent({
|
||||
id: 'agent_ocr',
|
||||
name: 'OCR Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id, // testUser owns this one
|
||||
tool_resources: {
|
||||
ocr: {
|
||||
file_ids: ['multi_tool_file'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
req.params.file_id = 'multi_tool_file';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
/**
|
||||
* Should succeed because testUser owns agent3,
|
||||
* even if other agents with the file are found first.
|
||||
*/
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(req.fileAccess).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
test('should handle agent with empty tool_resources', async () => {
|
||||
await createFile({
|
||||
user: otherUser._id.toString(),
|
||||
file_id: 'orphan_file',
|
||||
filepath: '/test/orphan.txt',
|
||||
filename: 'orphan.txt',
|
||||
type: 'text/plain',
|
||||
size: 100,
|
||||
});
|
||||
|
||||
// Create agent with no files in tool_resources
|
||||
await createAgent({
|
||||
id: `agent_empty_${Date.now()}`,
|
||||
name: 'Empty Resources Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
tool_resources: {},
|
||||
});
|
||||
|
||||
req.params.file_id = 'orphan_file';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
});
|
||||
|
||||
test('should handle agent with null tool_resources', async () => {
|
||||
await createFile({
|
||||
user: otherUser._id.toString(),
|
||||
file_id: 'another_orphan_file',
|
||||
filepath: '/test/orphan2.txt',
|
||||
filename: 'orphan2.txt',
|
||||
type: 'text/plain',
|
||||
size: 100,
|
||||
});
|
||||
|
||||
// Create agent with null tool_resources
|
||||
await createAgent({
|
||||
id: `agent_null_${Date.now()}`,
|
||||
name: 'Null Resources Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
tool_resources: null,
|
||||
});
|
||||
|
||||
req.params.file_id = 'another_orphan_file';
|
||||
await fileAccess(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
});
|
||||
});
|
||||
});
|
||||
680
api/server/routes/__tests__/costs.spec.js
Normal file
680
api/server/routes/__tests__/costs.spec.js
Normal file
@@ -0,0 +1,680 @@
|
||||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
createMethods: jest.fn(() => ({})),
|
||||
createModels: jest.fn(() => ({})),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/middleware', () => ({
|
||||
requireJwtAuth: (req, res, next) => next(),
|
||||
validateMessageReq: (req, res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getConvo: jest.fn(),
|
||||
saveConvo: jest.fn(),
|
||||
saveMessage: jest.fn(),
|
||||
getMessage: jest.fn(),
|
||||
getMessages: jest.fn(),
|
||||
updateMessage: jest.fn(),
|
||||
deleteMessages: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/db/models', () => {
|
||||
let User, Message, Transaction, Conversation;
|
||||
|
||||
return {
|
||||
get User() {
|
||||
return User;
|
||||
},
|
||||
get Message() {
|
||||
return Message;
|
||||
},
|
||||
get Transaction() {
|
||||
return Transaction;
|
||||
},
|
||||
get Conversation() {
|
||||
return Conversation;
|
||||
},
|
||||
setUser: (model) => {
|
||||
User = model;
|
||||
},
|
||||
setMessage: (model) => {
|
||||
Message = model;
|
||||
},
|
||||
setTransaction: (model) => {
|
||||
Transaction = model;
|
||||
},
|
||||
setConversation: (model) => {
|
||||
Conversation = model;
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
describe('Costs Endpoint', () => {
|
||||
let app;
|
||||
let mongoServer;
|
||||
let messagesRouter;
|
||||
let User, Message, Transaction, Conversation;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
|
||||
const userSchema = new mongoose.Schema({
|
||||
_id: String,
|
||||
name: String,
|
||||
email: String,
|
||||
});
|
||||
|
||||
const conversationSchema = new mongoose.Schema({
|
||||
conversationId: String,
|
||||
user: String,
|
||||
title: String,
|
||||
createdAt: Date,
|
||||
});
|
||||
|
||||
const messageSchema = new mongoose.Schema({
|
||||
messageId: String,
|
||||
conversationId: String,
|
||||
user: String,
|
||||
isCreatedByUser: Boolean,
|
||||
tokenCount: Number,
|
||||
createdAt: Date,
|
||||
});
|
||||
|
||||
const transactionSchema = new mongoose.Schema({
|
||||
conversationId: String,
|
||||
user: String,
|
||||
tokenType: String,
|
||||
tokenValue: Number,
|
||||
createdAt: Date,
|
||||
});
|
||||
|
||||
User = mongoose.model('User', userSchema);
|
||||
Conversation = mongoose.model('Conversation', conversationSchema);
|
||||
Message = mongoose.model('Message', messageSchema);
|
||||
Transaction = mongoose.model('Transaction', transactionSchema);
|
||||
|
||||
const dbModels = require('~/db/models');
|
||||
dbModels.setUser(User);
|
||||
dbModels.setMessage(Message);
|
||||
dbModels.setTransaction(Transaction);
|
||||
dbModels.setConversation(Conversation);
|
||||
|
||||
require('~/db/models');
|
||||
|
||||
try {
|
||||
messagesRouter = require('../messages');
|
||||
} catch (error) {
|
||||
console.error('Error loading messages router:', error);
|
||||
throw error;
|
||||
}
|
||||
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = { id: 'test-user-id' };
|
||||
next();
|
||||
});
|
||||
app.use('/api/messages', messagesRouter);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await User.deleteMany({});
|
||||
await Conversation.deleteMany({});
|
||||
await Message.deleteMany({});
|
||||
await Transaction.deleteMany({});
|
||||
});
|
||||
|
||||
describe('GET /:conversationId/costs', () => {
|
||||
const conversationId = 'test-conversation-123';
|
||||
const userId = 'test-user-id';
|
||||
|
||||
it('should return cost data for valid conversation', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const aiMessage = new Message({
|
||||
messageId: 'ai-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: 150,
|
||||
createdAt: new Date('2024-01-01T10:01:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage.save(), aiMessage.save()]);
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const completionTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'completion',
|
||||
tokenValue: 750000,
|
||||
createdAt: new Date('2024-01-01T10:01:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction.save(), completionTransaction.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toMatchObject({
|
||||
conversationId,
|
||||
totals: {
|
||||
prompt: { usd: 0.5, tokenCount: 100 },
|
||||
completion: { usd: 0.75, tokenCount: 150 },
|
||||
total: { usd: 1.25, tokenCount: 250 },
|
||||
},
|
||||
perMessage: [
|
||||
{ messageId: 'user-msg-1', tokenType: 'prompt', tokenCount: 100, usd: 0.5 },
|
||||
{ messageId: 'ai-msg-1', tokenType: 'completion', tokenCount: 150, usd: 0.75 },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('should return empty data for conversation with no messages', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toMatchObject({
|
||||
conversationId,
|
||||
totals: {
|
||||
prompt: { usd: 0, tokenCount: 0 },
|
||||
completion: { usd: 0, tokenCount: 0 },
|
||||
total: { usd: 0, tokenCount: 0 },
|
||||
},
|
||||
perMessage: [],
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle messages without transactions', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const aiMessage = new Message({
|
||||
messageId: 'ai-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: 150,
|
||||
createdAt: new Date('2024-01-01T10:01:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage.save(), aiMessage.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0);
|
||||
expect(response.body.totals.completion.usd).toBe(0);
|
||||
expect(response.body.totals.total.usd).toBe(0);
|
||||
});
|
||||
|
||||
it('should aggregate multiple transactions correctly', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction1 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 300000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const promptTransaction2 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 200000,
|
||||
createdAt: new Date('2024-01-01T10:00:45Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
expect(response.body.perMessage[0].usd).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should handle null tokenCount values', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: null,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.tokenCount).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle null tokenValue in transactions', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: null,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await promptTransaction.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle negative tokenValue using Math.abs', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: -500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await promptTransaction.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should filter by user correctly', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const otherUserId = 'other-user-id';
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const otherUserMessage = new Message({
|
||||
messageId: 'other-user-msg-1',
|
||||
conversationId,
|
||||
user: otherUserId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 200,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage.save(), otherUserMessage.save()]);
|
||||
|
||||
const userTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const otherUserTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: otherUserId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 1000000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userTransaction.save(), otherUserTransaction.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
expect(response.body.perMessage).toHaveLength(1);
|
||||
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
|
||||
});
|
||||
|
||||
it('should filter transactions by tokenType', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const otherTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'other',
|
||||
tokenValue: 1000000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction.save(), otherTransaction.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
expect(response.body.totals.completion.usd).toBe(0);
|
||||
expect(response.body.totals.total.usd).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should map transactions to messages chronologically', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage1 = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const userMessage2 = new Message({
|
||||
messageId: 'user-msg-2',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 200,
|
||||
createdAt: new Date('2024-01-01T10:01:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage1.save(), userMessage2.save()]);
|
||||
|
||||
const promptTransaction1 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const promptTransaction2 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 1000000,
|
||||
createdAt: new Date('2024-01-01T10:01:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.perMessage).toHaveLength(2);
|
||||
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
|
||||
expect(response.body.perMessage[0].usd).toBe(0.5);
|
||||
expect(response.body.perMessage[1].messageId).toBe('user-msg-2');
|
||||
expect(response.body.perMessage[1].usd).toBe(1.0);
|
||||
});
|
||||
|
||||
it('should handle database errors', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
await mongoose.connection.close();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
expect(response.body).toHaveProperty('error');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -11,9 +11,6 @@ jest.mock('@librechat/api', () => ({
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
MCPTokenStorage: {
|
||||
storeTokens: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
}));
|
||||
|
||||
@@ -50,8 +47,8 @@ jest.mock('~/server/services/Config', () => ({
|
||||
loadCustomConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config/mcp', () => ({
|
||||
updateMCPServerTools: jest.fn(),
|
||||
jest.mock('~/server/services/Config/mcpToolsCache', () => ({
|
||||
updateMCPUserTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
@@ -237,7 +234,7 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
describe('GET /:serverName/oauth/callback', () => {
|
||||
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
it('should redirect to error page when OAuth error is received', async () => {
|
||||
@@ -283,7 +280,6 @@ describe('MCP Routes', () => {
|
||||
it('should handle OAuth callback successfully', async () => {
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
@@ -299,7 +295,6 @@ describe('MCP Routes', () => {
|
||||
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
@@ -337,24 +332,11 @@ describe('MCP Routes', () => {
|
||||
'test-auth-code',
|
||||
mockFlowManager,
|
||||
);
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
userId: 'test-user-id',
|
||||
serverName: 'test-server',
|
||||
tokens: mockTokens,
|
||||
clientInfo: mockFlowState.clientInfo,
|
||||
metadata: mockFlowState.metadata,
|
||||
}),
|
||||
);
|
||||
const storeInvocation = MCPTokenStorage.storeTokens.mock.invocationCallOrder[0];
|
||||
const connectInvocation = mockMcpManager.getUserConnection.mock.invocationCallOrder[0];
|
||||
expect(storeInvocation).toBeLessThan(connectInvocation);
|
||||
expect(mockFlowManager.completeFlow).toHaveBeenCalledWith(
|
||||
'tool-flow-123',
|
||||
'mcp_oauth',
|
||||
mockTokens,
|
||||
);
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
});
|
||||
|
||||
it('should redirect to error page when callback processing fails', async () => {
|
||||
@@ -372,7 +354,6 @@ describe('MCP Routes', () => {
|
||||
it('should handle system-level OAuth completion', async () => {
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
@@ -388,7 +369,6 @@ describe('MCP Routes', () => {
|
||||
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
@@ -399,13 +379,11 @@ describe('MCP Routes', () => {
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
});
|
||||
|
||||
it('should handle reconnection failure after OAuth', async () => {
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
@@ -421,7 +399,6 @@ describe('MCP Routes', () => {
|
||||
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
@@ -441,46 +418,6 @@ describe('MCP Routes', () => {
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/success?serverName=test-server');
|
||||
expect(MCPTokenStorage.storeTokens).toHaveBeenCalled();
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens');
|
||||
});
|
||||
|
||||
it('should redirect to error page if token storage fails', async () => {
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn().mockResolvedValue(),
|
||||
deleteFlow: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const mockFlowState = {
|
||||
serverName: 'test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: { toolFlowId: 'tool-flow-123' },
|
||||
clientInfo: {},
|
||||
codeVerifier: 'test-verifier',
|
||||
};
|
||||
const mockTokens = {
|
||||
access_token: 'test-access-token',
|
||||
refresh_token: 'test-refresh-token',
|
||||
};
|
||||
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState);
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed'));
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn(),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({
|
||||
code: 'test-auth-code',
|
||||
state: 'test-flow-id',
|
||||
});
|
||||
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.location).toBe('/oauth/error?error=callback_failed');
|
||||
expect(mockMcpManager.getUserConnection).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -841,10 +778,10 @@ describe('MCP Routes', () => {
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
updateMCPServerTools.mockResolvedValue();
|
||||
updateMCPUserTools.mockResolvedValue();
|
||||
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||
success: true,
|
||||
@@ -899,10 +836,10 @@ describe('MCP Routes', () => {
|
||||
]);
|
||||
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
updateMCPServerTools.mockResolvedValue();
|
||||
updateMCPUserTools.mockResolvedValue();
|
||||
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||
success: true,
|
||||
@@ -1206,11 +1143,7 @@ describe('MCP Routes', () => {
|
||||
|
||||
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
|
||||
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
|
||||
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
|
||||
const mockTokens = {
|
||||
access_token: 'edge-access-token',
|
||||
refresh_token: 'edge-refresh-token',
|
||||
};
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({
|
||||
id: 'test-flow-id',
|
||||
userId: 'test-user-id',
|
||||
@@ -1222,8 +1155,6 @@ describe('MCP Routes', () => {
|
||||
clientInfo: {},
|
||||
codeVerifier: 'test-verifier',
|
||||
});
|
||||
MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
|
||||
const mockFlowManager = {
|
||||
completeFlow: jest.fn(),
|
||||
@@ -1248,11 +1179,6 @@ describe('MCP Routes', () => {
|
||||
it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => {
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api');
|
||||
const mockTokens = {
|
||||
access_token: 'edge-access-token',
|
||||
refresh_token: 'edge-refresh-token',
|
||||
};
|
||||
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockResolvedValue({
|
||||
@@ -1265,15 +1191,6 @@ describe('MCP Routes', () => {
|
||||
completeFlow: jest.fn(),
|
||||
};
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
MCPOAuthHandler.getFlowState.mockResolvedValue({
|
||||
serverName: 'test-server',
|
||||
userId: 'test-user-id',
|
||||
metadata: { serverUrl: 'https://example.com', oauth: {} },
|
||||
clientInfo: {},
|
||||
codeVerifier: 'test-verifier',
|
||||
});
|
||||
MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens);
|
||||
MCPTokenStorage.storeTokens.mockResolvedValue();
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockResolvedValue({
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
const { Router } = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { MCPOAuthHandler, MCPTokenStorage, getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { updateMCPServerTools } = require('~/server/services/Config/mcp');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
const { getMCPTools } = require('~/server/controllers/mcp');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { findPluginAuthsByKeys } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const router = Router();
|
||||
|
||||
/**
|
||||
* Get all MCP tools available to the user
|
||||
* Returns only MCP tools, completely decoupled from regular LibreChat tools
|
||||
*/
|
||||
router.get('/tools', requireJwtAuth, async (req, res) => {
|
||||
return getMCPTools(req, res);
|
||||
});
|
||||
|
||||
/**
|
||||
* Initiate OAuth flow
|
||||
* This endpoint is called when the user clicks the auth link in the UI
|
||||
@@ -130,41 +121,6 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
|
||||
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
||||
|
||||
/** Persist tokens immediately so reconnection uses fresh credentials */
|
||||
if (flowState?.userId && tokens) {
|
||||
try {
|
||||
await MCPTokenStorage.storeTokens({
|
||||
userId: flowState.userId,
|
||||
serverName,
|
||||
tokens,
|
||||
createToken,
|
||||
updateToken,
|
||||
findToken,
|
||||
clientInfo: flowState.clientInfo,
|
||||
metadata: flowState.metadata,
|
||||
});
|
||||
logger.debug('[MCP OAuth] Stored OAuth tokens prior to reconnection', {
|
||||
serverName,
|
||||
userId: flowState.userId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth] Failed to store OAuth tokens after callback', error);
|
||||
throw error;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear any cached `mcp_get_tokens` flow result so subsequent lookups
|
||||
* re-fetch the freshly stored credentials instead of returning stale nulls.
|
||||
*/
|
||||
if (typeof flowManager?.deleteFlow === 'function') {
|
||||
try {
|
||||
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
|
||||
} catch (error) {
|
||||
logger.warn('[MCP OAuth] Failed to clear cached token flow state', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const mcpManager = getMCPManager(flowState.userId);
|
||||
logger.debug(`[MCP OAuth] Attempting to reconnect ${serverName} with new OAuth tokens`);
|
||||
@@ -188,12 +144,9 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
|
||||
);
|
||||
|
||||
// clear any reconnection attempts
|
||||
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||
oauthReconnectionManager.clearReconnection(flowState.userId, serverName);
|
||||
|
||||
const tools = await userConnection.fetchTools();
|
||||
await updateMCPServerTools({
|
||||
await updateMCPUserTools({
|
||||
userId: flowState.userId,
|
||||
serverName,
|
||||
tools,
|
||||
});
|
||||
@@ -335,9 +288,9 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
|
||||
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const userId = req.user?.id;
|
||||
const user = req.user;
|
||||
|
||||
if (!userId) {
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
@@ -351,7 +304,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
});
|
||||
}
|
||||
|
||||
await mcpManager.disconnectUserConnection(userId, serverName);
|
||||
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||
logger.info(
|
||||
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
|
||||
);
|
||||
@@ -360,14 +313,14 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
let userMCPAuthMap;
|
||||
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
||||
userMCPAuthMap = await getUserMCPAuthMap({
|
||||
userId,
|
||||
userId: user.id,
|
||||
servers: [serverName],
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
}
|
||||
|
||||
const result = await reinitMCPServer({
|
||||
userId,
|
||||
req,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
||||
@@ -11,6 +11,7 @@ const {
|
||||
} = require('~/models');
|
||||
const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/Artifacts/update');
|
||||
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
|
||||
const { tokenValues, getValueKey, defaultRate } = require('~/models/tx');
|
||||
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
|
||||
const { getConvosQueried } = require('~/models/Conversation');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
@@ -160,6 +161,41 @@ router.post('/artifact/:messageId', async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* POST /costs
|
||||
* Get cost information for models in modelHistory array
|
||||
*/
|
||||
router.post('/costs', async (req, res) => {
|
||||
try {
|
||||
const { modelHistory } = req.body;
|
||||
|
||||
if (!Array.isArray(modelHistory)) {
|
||||
return res.status(400).json({ error: 'modelHistory must be an array' });
|
||||
}
|
||||
|
||||
const modelCostTable = {};
|
||||
|
||||
modelHistory.forEach((modelEntry) => {
|
||||
if (modelEntry && typeof modelEntry === 'object' && modelEntry.model && modelEntry.endpoint) {
|
||||
const { model, endpoint } = modelEntry;
|
||||
|
||||
const valueKey = getValueKey(model, endpoint);
|
||||
const pricing = tokenValues[valueKey];
|
||||
|
||||
modelCostTable[model] = {
|
||||
prompt: pricing?.prompt ?? defaultRate,
|
||||
completion: pricing?.completion ?? defaultRate,
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
res.status(200).json({ modelCostTable });
|
||||
} catch (error) {
|
||||
logger.error('Error fetching model costs:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
/* Note: It's necessary to add `validateMessageReq` within route definition for correct params */
|
||||
router.get('/:conversationId', validateMessageReq, async (req, res) => {
|
||||
try {
|
||||
|
||||
@@ -357,18 +357,23 @@ const resetPassword = async (userId, token, password) => {
|
||||
|
||||
/**
|
||||
* Set Auth Tokens
|
||||
*
|
||||
* @param {String | ObjectId} userId
|
||||
* @param {ServerResponse} res
|
||||
* @param {ISession | null} [session=null]
|
||||
* @param {Object} res
|
||||
* @param {String} sessionId
|
||||
* @returns
|
||||
*/
|
||||
const setAuthTokens = async (userId, res, _session = null) => {
|
||||
const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
try {
|
||||
let session = _session;
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
|
||||
let session;
|
||||
let refreshToken;
|
||||
let refreshTokenExpires;
|
||||
|
||||
if (session && session._id && session.expiration != null) {
|
||||
if (sessionId) {
|
||||
session = await findSession({ sessionId: sessionId }, { lean: false });
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
refreshToken = await generateRefreshToken(session);
|
||||
} else {
|
||||
@@ -378,9 +383,6 @@ const setAuthTokens = async (userId, res, _session = null) => {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
}
|
||||
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
|
||||
res.cookie('refreshToken', refreshToken, {
|
||||
expires: new Date(refreshTokenExpires),
|
||||
httpOnly: true,
|
||||
|
||||
@@ -36,7 +36,7 @@ async function getAppConfig(options = {}) {
|
||||
}
|
||||
|
||||
if (baseConfig.availableTools) {
|
||||
await setCachedTools(baseConfig.availableTools);
|
||||
await setCachedTools(baseConfig.availableTools, { isGlobal: true });
|
||||
}
|
||||
|
||||
await cache.set(BASE_CONFIG_KEY, baseConfig);
|
||||
|
||||
@@ -3,32 +3,89 @@ const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
* Cache key generators for different tool access patterns
|
||||
* These will support future permission-based caching
|
||||
*/
|
||||
const ToolCacheKeys = {
|
||||
/** Global tools available to all users */
|
||||
GLOBAL: 'tools:global',
|
||||
/** MCP tools cached by server name */
|
||||
MCP_SERVER: (serverName) => `tools:mcp:${serverName}`,
|
||||
/** Tools available to a specific user */
|
||||
USER: (userId) => `tools:user:${userId}`,
|
||||
/** Tools available to a specific role */
|
||||
ROLE: (roleId) => `tools:role:${roleId}`,
|
||||
/** Tools available to a specific group */
|
||||
GROUP: (groupId) => `tools:group:${groupId}`,
|
||||
/** Combined effective tools for a user (computed from all sources) */
|
||||
EFFECTIVE: (userId) => `tools:effective:${userId}`,
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves available tools from cache
|
||||
* @function getCachedTools
|
||||
* @param {Object} options - Options for retrieving tools
|
||||
* @param {string} [options.serverName] - MCP server name to get cached tools for
|
||||
* @param {string} [options.userId] - User ID for user-specific tools
|
||||
* @param {string[]} [options.roleIds] - Role IDs for role-based tools
|
||||
* @param {string[]} [options.groupIds] - Group IDs for group-based tools
|
||||
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools
|
||||
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
|
||||
*/
|
||||
async function getCachedTools(options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { serverName } = options;
|
||||
const { userId, roleIds = [], groupIds = [], includeGlobal = true } = options;
|
||||
|
||||
// Return MCP server-specific tools if requested
|
||||
if (serverName) {
|
||||
return await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
|
||||
// For now, return global tools (current behavior)
|
||||
// This will be expanded to merge tools from different sources
|
||||
if (!userId && includeGlobal) {
|
||||
return await cache.get(ToolCacheKeys.GLOBAL);
|
||||
}
|
||||
|
||||
// Default to global tools
|
||||
return await cache.get(ToolCacheKeys.GLOBAL);
|
||||
// Future implementation will merge tools from multiple sources
|
||||
// based on user permissions, roles, and groups
|
||||
if (userId) {
|
||||
/** @type {LCAvailableTools | null} Check if we have pre-computed effective tools for this user */
|
||||
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
|
||||
if (effectiveTools) {
|
||||
return effectiveTools;
|
||||
}
|
||||
|
||||
/** @type {LCAvailableTools | null} Otherwise, compute from individual sources */
|
||||
const toolSources = [];
|
||||
|
||||
if (includeGlobal) {
|
||||
const globalTools = await cache.get(ToolCacheKeys.GLOBAL);
|
||||
if (globalTools) {
|
||||
toolSources.push(globalTools);
|
||||
}
|
||||
}
|
||||
|
||||
// User-specific tools
|
||||
const userTools = await cache.get(ToolCacheKeys.USER(userId));
|
||||
if (userTools) {
|
||||
toolSources.push(userTools);
|
||||
}
|
||||
|
||||
// Role-based tools
|
||||
for (const roleId of roleIds) {
|
||||
const roleTools = await cache.get(ToolCacheKeys.ROLE(roleId));
|
||||
if (roleTools) {
|
||||
toolSources.push(roleTools);
|
||||
}
|
||||
}
|
||||
|
||||
// Group-based tools
|
||||
for (const groupId of groupIds) {
|
||||
const groupTools = await cache.get(ToolCacheKeys.GROUP(groupId));
|
||||
if (groupTools) {
|
||||
toolSources.push(groupTools);
|
||||
}
|
||||
}
|
||||
|
||||
// Merge all tool sources (for now, simple merge - future will handle conflicts)
|
||||
if (toolSources.length > 0) {
|
||||
return mergeToolSources(toolSources);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -36,34 +93,49 @@ async function getCachedTools(options = {}) {
|
||||
* @function setCachedTools
|
||||
* @param {Object} tools - The tools object to cache
|
||||
* @param {Object} options - Options for caching tools
|
||||
* @param {string} [options.serverName] - MCP server name for server-specific tools
|
||||
* @param {string} [options.userId] - User ID for user-specific tools
|
||||
* @param {string} [options.roleId] - Role ID for role-based tools
|
||||
* @param {string} [options.groupId] - Group ID for group-based tools
|
||||
* @param {boolean} [options.isGlobal=false] - Whether these are global tools
|
||||
* @param {number} [options.ttl] - Time to live in milliseconds
|
||||
* @returns {Promise<boolean>} Whether the operation was successful
|
||||
*/
|
||||
async function setCachedTools(tools, options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { serverName, ttl } = options;
|
||||
const { userId, roleId, groupId, isGlobal = false, ttl } = options;
|
||||
|
||||
// Cache by MCP server if specified
|
||||
if (serverName) {
|
||||
return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl);
|
||||
let cacheKey;
|
||||
if (isGlobal || (!userId && !roleId && !groupId)) {
|
||||
cacheKey = ToolCacheKeys.GLOBAL;
|
||||
} else if (userId) {
|
||||
cacheKey = ToolCacheKeys.USER(userId);
|
||||
} else if (roleId) {
|
||||
cacheKey = ToolCacheKeys.ROLE(roleId);
|
||||
} else if (groupId) {
|
||||
cacheKey = ToolCacheKeys.GROUP(groupId);
|
||||
}
|
||||
|
||||
// Default to global cache
|
||||
return await cache.set(ToolCacheKeys.GLOBAL, tools, ttl);
|
||||
if (!cacheKey) {
|
||||
throw new Error('Invalid cache key options provided');
|
||||
}
|
||||
|
||||
return await cache.set(cacheKey, tools, ttl);
|
||||
}
|
||||
|
||||
/**
|
||||
* Invalidates cached tools
|
||||
* @function invalidateCachedTools
|
||||
* @param {Object} options - Options for invalidating tools
|
||||
* @param {string} [options.serverName] - MCP server name to invalidate
|
||||
* @param {string} [options.userId] - User ID to invalidate
|
||||
* @param {string} [options.roleId] - Role ID to invalidate
|
||||
* @param {string} [options.groupId] - Group ID to invalidate
|
||||
* @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools
|
||||
* @param {boolean} [options.invalidateEffective=true] - Whether to invalidate effective tools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function invalidateCachedTools(options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const { serverName, invalidateGlobal = false } = options;
|
||||
const { userId, roleId, groupId, invalidateGlobal = false, invalidateEffective = true } = options;
|
||||
|
||||
const keysToDelete = [];
|
||||
|
||||
@@ -71,34 +143,116 @@ async function invalidateCachedTools(options = {}) {
|
||||
keysToDelete.push(ToolCacheKeys.GLOBAL);
|
||||
}
|
||||
|
||||
if (serverName) {
|
||||
keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName));
|
||||
if (userId) {
|
||||
keysToDelete.push(ToolCacheKeys.USER(userId));
|
||||
if (invalidateEffective) {
|
||||
keysToDelete.push(ToolCacheKeys.EFFECTIVE(userId));
|
||||
}
|
||||
}
|
||||
|
||||
if (roleId) {
|
||||
keysToDelete.push(ToolCacheKeys.ROLE(roleId));
|
||||
// TODO: In future, invalidate all users with this role
|
||||
}
|
||||
|
||||
if (groupId) {
|
||||
keysToDelete.push(ToolCacheKeys.GROUP(groupId));
|
||||
// TODO: In future, invalidate all users in this group
|
||||
}
|
||||
|
||||
await Promise.all(keysToDelete.map((key) => cache.delete(key)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets MCP tools for a specific server from cache or merges with global tools
|
||||
* @function getMCPServerTools
|
||||
* @param {string} serverName - The MCP server name
|
||||
* @returns {Promise<LCAvailableTools|null>} The available tools for the server
|
||||
* Computes and caches effective tools for a user
|
||||
* @function computeEffectiveTools
|
||||
* @param {string} userId - The user ID
|
||||
* @param {Object} context - Context containing user's roles and groups
|
||||
* @param {string[]} [context.roleIds=[]] - User's role IDs
|
||||
* @param {string[]} [context.groupIds=[]] - User's group IDs
|
||||
* @param {number} [ttl] - Time to live for the computed result
|
||||
* @returns {Promise<Object>} The computed effective tools
|
||||
*/
|
||||
async function getMCPServerTools(serverName) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName));
|
||||
async function computeEffectiveTools(userId, context = {}, ttl) {
|
||||
const { roleIds = [], groupIds = [] } = context;
|
||||
|
||||
if (serverTools) {
|
||||
return serverTools;
|
||||
// Get all tool sources
|
||||
const tools = await getCachedTools({
|
||||
userId,
|
||||
roleIds,
|
||||
groupIds,
|
||||
includeGlobal: true,
|
||||
});
|
||||
|
||||
if (tools) {
|
||||
// Cache the computed result
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.set(ToolCacheKeys.EFFECTIVE(userId), tools, ttl);
|
||||
}
|
||||
|
||||
return null;
|
||||
return tools;
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges multiple tool sources into a single tools object
|
||||
* @function mergeToolSources
|
||||
* @param {Object[]} sources - Array of tool objects to merge
|
||||
* @returns {Object} Merged tools object
|
||||
*/
|
||||
function mergeToolSources(sources) {
|
||||
// For now, simple merge that combines all tools
|
||||
// Future implementation will handle:
|
||||
// - Permission precedence (deny > allow)
|
||||
// - Tool property conflicts
|
||||
// - Metadata merging
|
||||
const merged = {};
|
||||
|
||||
for (const source of sources) {
|
||||
if (!source || typeof source !== 'object') {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const [toolId, toolConfig] of Object.entries(source)) {
|
||||
// Simple last-write-wins for now
|
||||
// Future: merge based on permission levels
|
||||
merged[toolId] = toolConfig;
|
||||
}
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
/**
|
||||
* Middleware-friendly function to get tools for a request
|
||||
* @function getToolsForRequest
|
||||
* @param {Object} req - Express request object
|
||||
* @returns {Promise<Object|null>} Available tools for the request
|
||||
*/
|
||||
async function getToolsForRequest(req) {
|
||||
const userId = req.user?.id;
|
||||
|
||||
// For now, return global tools if no user
|
||||
if (!userId) {
|
||||
return getCachedTools({ includeGlobal: true });
|
||||
}
|
||||
|
||||
// Future: Extract roles and groups from req.user
|
||||
const roleIds = req.user?.roles || [];
|
||||
const groupIds = req.user?.groups || [];
|
||||
|
||||
return getCachedTools({
|
||||
userId,
|
||||
roleIds,
|
||||
groupIds,
|
||||
includeGlobal: true,
|
||||
});
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
ToolCacheKeys,
|
||||
getCachedTools,
|
||||
setCachedTools,
|
||||
getMCPServerTools,
|
||||
getToolsForRequest,
|
||||
invalidateCachedTools,
|
||||
computeEffectiveTools,
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const appConfig = require('./app');
|
||||
const mcpToolsCache = require('./mcp');
|
||||
const { config } = require('./EndpointService');
|
||||
const getCachedTools = require('./getCachedTools');
|
||||
const mcpToolsCache = require('./mcpToolsCache');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
const loadConfigModels = require('./loadConfigModels');
|
||||
const loadDefaultModels = require('./loadDefaultModels');
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { getCachedTools, setCachedTools } = require('./getCachedTools');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* Updates MCP tools in the cache for a specific server
|
||||
* @param {Object} params - Parameters for updating MCP tools
|
||||
* @param {string} params.serverName - MCP server name
|
||||
* @param {Array} params.tools - Array of tool objects from MCP server
|
||||
* @returns {Promise<LCAvailableTools>}
|
||||
*/
|
||||
async function updateMCPServerTools({ serverName, tools }) {
|
||||
try {
|
||||
const serverTools = {};
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${mcpDelimiter}${serverName}`;
|
||||
serverTools[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
await setCachedTools(serverTools, { serverName });
|
||||
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`);
|
||||
return serverTools;
|
||||
} catch (error) {
|
||||
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges app-level tools with global tools
|
||||
* @param {import('@librechat/api').LCAvailableTools} appTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function mergeAppTools(appTools) {
|
||||
try {
|
||||
const count = Object.keys(appTools).length;
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
const cachedTools = await getCachedTools();
|
||||
const mergedTools = { ...cachedTools, ...appTools };
|
||||
await setCachedTools(mergedTools);
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`Merged ${count} app-level tools`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to merge app-level tools:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Caches MCP server tools (no longer merges with global)
|
||||
* @param {object} params
|
||||
* @param {string} params.serverName
|
||||
* @param {import('@librechat/api').LCAvailableTools} params.serverTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function cacheMCPServerTools({ serverName, serverTools }) {
|
||||
try {
|
||||
const count = Object.keys(serverTools).length;
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
// Only cache server-specific tools, no merging with global
|
||||
await setCachedTools(serverTools, { serverName });
|
||||
logger.debug(`Cached ${count} MCP server tools for ${serverName}`);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to cache MCP server tools for ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
mergeAppTools,
|
||||
cacheMCPServerTools,
|
||||
updateMCPServerTools,
|
||||
};
|
||||
143
api/server/services/Config/mcpToolsCache.js
Normal file
143
api/server/services/Config/mcpToolsCache.js
Normal file
@@ -0,0 +1,143 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { getCachedTools, setCachedTools } = require('./getCachedTools');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* Updates MCP tools in the cache for a specific server and user
|
||||
* @param {Object} params - Parameters for updating MCP tools
|
||||
* @param {string} params.userId - User ID
|
||||
* @param {string} params.serverName - MCP server name
|
||||
* @param {Array} params.tools - Array of tool objects from MCP server
|
||||
* @returns {Promise<LCAvailableTools>}
|
||||
*/
|
||||
async function updateMCPUserTools({ userId, serverName, tools }) {
|
||||
try {
|
||||
const userTools = await getCachedTools({ userId });
|
||||
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
for (const key of Object.keys(userTools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
delete userTools[key];
|
||||
}
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
userTools[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
await setCachedTools(userTools, { userId });
|
||||
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`);
|
||||
return userTools;
|
||||
} catch (error) {
|
||||
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges app-level tools with global tools
|
||||
* @param {import('@librechat/api').LCAvailableTools} appTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function mergeAppTools(appTools) {
|
||||
try {
|
||||
const count = Object.keys(appTools).length;
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
const cachedTools = await getCachedTools({ includeGlobal: true });
|
||||
const mergedTools = { ...cachedTools, ...appTools };
|
||||
await setCachedTools(mergedTools, { isGlobal: true });
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`Merged ${count} app-level tools`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to merge app-level tools:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges user-level tools with global tools
|
||||
* @param {object} params
|
||||
* @param {string} params.userId
|
||||
* @param {Record<string, FunctionTool>} params.cachedUserTools
|
||||
* @param {import('@librechat/api').LCAvailableTools} params.userTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function mergeUserTools({ userId, cachedUserTools, userTools }) {
|
||||
try {
|
||||
if (!userId) {
|
||||
return;
|
||||
}
|
||||
const count = Object.keys(userTools).length;
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
const cachedTools = cachedUserTools ?? (await getCachedTools({ userId }));
|
||||
const mergedTools = { ...cachedTools, ...userTools };
|
||||
await setCachedTools(mergedTools, { userId });
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`Merged ${count} user-level tools`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to merge user-level tools:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears all MCP tools for a specific server
|
||||
* @param {Object} params - Parameters for clearing MCP tools
|
||||
* @param {string} [params.userId] - User ID (if clearing user-specific tools)
|
||||
* @param {string} params.serverName - MCP server name
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function clearMCPServerTools({ userId, serverName }) {
|
||||
try {
|
||||
const tools = await getCachedTools({ userId, includeGlobal: !userId });
|
||||
|
||||
// Remove all tools for this server
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
let removedCount = 0;
|
||||
for (const key of Object.keys(tools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
delete tools[key];
|
||||
removedCount++;
|
||||
}
|
||||
}
|
||||
|
||||
if (removedCount > 0) {
|
||||
await setCachedTools(tools, userId ? { userId } : { isGlobal: true });
|
||||
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
|
||||
logger.debug(
|
||||
`[MCP Cache] Removed ${removedCount} tools for ${serverName}${userId ? ` user ${userId}` : ' (global)'}`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`[MCP Cache] Failed to clear tools for ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
mergeAppTools,
|
||||
mergeUserTools,
|
||||
updateMCPUserTools,
|
||||
clearMCPServerTools,
|
||||
};
|
||||
@@ -552,7 +552,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
||||
throw new Error('File search is not enabled for Agents');
|
||||
}
|
||||
// Note: File search processing continues to dual storage logic below
|
||||
} else if (tool_resource === EToolResources.context) {
|
||||
} else if (tool_resource === EToolResources.ocr) {
|
||||
const { file_id, temp_file_id = null } = metadata;
|
||||
|
||||
/**
|
||||
|
||||
@@ -20,10 +20,10 @@ const {
|
||||
ContentTypes,
|
||||
isAssistantsEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getCachedTools, getAppConfig } = require('./Config');
|
||||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getAppConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
@@ -152,8 +152,8 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
|
||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.userId - The user ID from the request object.
|
||||
* @param {string} params.serverName
|
||||
* @param {AbortSignal} params.signal
|
||||
* @param {string} params.model
|
||||
@@ -161,9 +161,9 @@ function createOAuthCallback({ runStepEmitter, runStepDeltaEmitter }) {
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap }) {
|
||||
async function reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap }) {
|
||||
const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID;
|
||||
const flowId = `${userId}:${serverName}:${Date.now()}`;
|
||||
const flowId = `${req.user?.id}:${serverName}:${Date.now()}`;
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
const stepId = 'step_oauth_login_' + serverName;
|
||||
const toolCall = {
|
||||
@@ -192,7 +192,7 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
|
||||
flowManager,
|
||||
});
|
||||
return await reinitMCPServer({
|
||||
userId,
|
||||
req,
|
||||
signal,
|
||||
serverName,
|
||||
oauthStart,
|
||||
@@ -211,8 +211,8 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
|
||||
* i.e. `availableTools`, and will reinitialize the MCP server to ensure all tools are generated.
|
||||
*
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
|
||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.userId - The user ID from the request object.
|
||||
* @param {string} params.serverName
|
||||
* @param {string} params.model
|
||||
* @param {Providers | EModelEndpoint} params.provider - The provider for the tool.
|
||||
@@ -221,16 +221,8 @@ async function reconnectServer({ res, userId, index, signal, serverName, userMCP
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @returns { Promise<Array<typeof tool | { _call: (toolInput: Object | string) => unknown}>> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTools({
|
||||
res,
|
||||
userId,
|
||||
index,
|
||||
signal,
|
||||
serverName,
|
||||
provider,
|
||||
userMCPAuthMap,
|
||||
}) {
|
||||
const result = await reconnectServer({ res, userId, index, signal, serverName, userMCPAuthMap });
|
||||
async function createMCPTools({ req, res, index, signal, serverName, provider, userMCPAuthMap }) {
|
||||
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
|
||||
if (!result || !result.tools) {
|
||||
logger.warn(`[MCP][${serverName}] Failed to reinitialize MCP server.`);
|
||||
return;
|
||||
@@ -239,8 +231,8 @@ async function createMCPTools({
|
||||
const serverTools = [];
|
||||
for (const tool of result.tools) {
|
||||
const toolInstance = await createMCPTool({
|
||||
req,
|
||||
res,
|
||||
userId,
|
||||
provider,
|
||||
userMCPAuthMap,
|
||||
availableTools: result.availableTools,
|
||||
@@ -257,8 +249,8 @@ async function createMCPTools({
|
||||
/**
|
||||
* Creates a single tool from the specified MCP Server via `toolKey`.
|
||||
* @param {Object} params
|
||||
* @param {ServerRequest} params.req - The Express request object, containing user/request info.
|
||||
* @param {ServerResponse} params.res - The Express response object for sending events.
|
||||
* @param {string} params.userId - The user ID from the request object.
|
||||
* @param {string} params.toolKey - The toolKey for the tool.
|
||||
* @param {string} params.model - The model for the tool.
|
||||
* @param {number} [params.index]
|
||||
@@ -269,31 +261,26 @@ async function createMCPTools({
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTool({
|
||||
req,
|
||||
res,
|
||||
userId,
|
||||
index,
|
||||
signal,
|
||||
toolKey,
|
||||
provider,
|
||||
userMCPAuthMap,
|
||||
availableTools,
|
||||
availableTools: tools,
|
||||
}) {
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
|
||||
const availableTools =
|
||||
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
|
||||
/** @type {LCTool | undefined} */
|
||||
let toolDefinition = availableTools?.[toolKey]?.function;
|
||||
if (!toolDefinition) {
|
||||
logger.warn(
|
||||
`[MCP][${serverName}][${toolName}] Requested tool not found in available tools, re-initializing MCP server.`,
|
||||
);
|
||||
const result = await reconnectServer({
|
||||
res,
|
||||
userId,
|
||||
index,
|
||||
signal,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
const result = await reconnectServer({ req, res, index, signal, serverName, userMCPAuthMap });
|
||||
toolDefinition = result?.availableTools?.[toolKey]?.function;
|
||||
}
|
||||
|
||||
@@ -551,20 +538,13 @@ async function getServerConnectionStatus(
|
||||
const baseConnectionState = getConnectionState();
|
||||
let finalConnectionState = baseConnectionState;
|
||||
|
||||
// connection state overrides specific to OAuth servers
|
||||
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
|
||||
// check if server is actively being reconnected
|
||||
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
|
||||
finalConnectionState = 'connecting';
|
||||
} else {
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
finalConnectionState = 'connecting';
|
||||
}
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
finalConnectionState = 'connecting';
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ jest.mock('./Config', () => ({
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(),
|
||||
getFlowStateManager: jest.fn(),
|
||||
getOAuthReconnectionManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
@@ -49,7 +48,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
let mockGetMCPManager;
|
||||
let mockGetFlowStateManager;
|
||||
let mockGetLogStores;
|
||||
let mockGetOAuthReconnectionManager;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
@@ -58,7 +56,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
mockGetMCPManager = require('~/config').getMCPManager;
|
||||
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||
mockGetLogStores = require('~/cache').getLogStores;
|
||||
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
|
||||
});
|
||||
|
||||
describe('getMCPSetupData', () => {
|
||||
@@ -357,12 +354,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
@@ -379,12 +370,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return failed flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
@@ -416,12 +401,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return active flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
@@ -453,12 +432,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return no flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => null),
|
||||
@@ -481,35 +454,6 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
});
|
||||
});
|
||||
|
||||
it('should return connecting state when OAuth server is reconnecting', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager to return true for isReconnecting
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => true),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: true,
|
||||
connectionState: 'connecting',
|
||||
});
|
||||
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not check OAuth flow status when server is connected', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(),
|
||||
|
||||
@@ -313,7 +313,7 @@ const ensurePrincipalExists = async function (principal) {
|
||||
idOnTheSource: principal.idOnTheSource,
|
||||
};
|
||||
|
||||
const userId = await createUser(userData, true, true);
|
||||
const userId = await createUser(userData, true, false);
|
||||
return userId.toString();
|
||||
}
|
||||
|
||||
|
||||
@@ -88,6 +88,7 @@ async function saveUserMessage(req, params) {
|
||||
parentMessageId: params.parentMessageId ?? Constants.NO_PARENT,
|
||||
/* For messages, use the assistant_id instead of model */
|
||||
model: params.assistant_id,
|
||||
targetModel: params.model,
|
||||
thread_id: params.thread_id,
|
||||
sender: 'User',
|
||||
text: params.text,
|
||||
|
||||
@@ -74,7 +74,7 @@ async function processRequiredActions(client, requiredActions) {
|
||||
requiredActions,
|
||||
);
|
||||
const appConfig = client.req.config;
|
||||
const toolDefinitions = await getCachedTools();
|
||||
const toolDefinitions = await getCachedTools({ userId: client.req.user.id, includeGlobal: true });
|
||||
const seenToolkits = new Set();
|
||||
const tools = requiredActions
|
||||
.map((action) => {
|
||||
@@ -353,12 +353,7 @@ async function processRequiredActions(client, requiredActions) {
|
||||
async function loadAgentTools({ req, res, agent, signal, tool_resources, openAIApiKey }) {
|
||||
if (!agent.tools || agent.tools.length === 0) {
|
||||
return {};
|
||||
} else if (
|
||||
agent.tools &&
|
||||
agent.tools.length === 1 &&
|
||||
/** Legacy handling for `ocr` as may still exist in existing Agents */
|
||||
(agent.tools[0] === AgentCapabilities.context || agent.tools[0] === AgentCapabilities.ocr)
|
||||
) {
|
||||
} else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) {
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
@@ -2,12 +2,12 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { findToken, createToken, updateToken, deleteTokens } = require('~/models');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { updateMCPServerTools } = require('~/server/services/Config');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {string} params.serverName - The name of the MCP server
|
||||
* @param {boolean} params.returnOnOAuth - Whether to initiate OAuth and return, or wait for OAuth flow to finish
|
||||
* @param {AbortSignal} [params.signal] - The abort signal to handle cancellation.
|
||||
@@ -18,7 +18,7 @@ const { getLogStores } = require('~/cache');
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
*/
|
||||
async function reinitMCPServer({
|
||||
userId,
|
||||
req,
|
||||
signal,
|
||||
forceNew,
|
||||
serverName,
|
||||
@@ -44,14 +44,14 @@ async function reinitMCPServer({
|
||||
const oauthStart =
|
||||
_oauthStart ??
|
||||
(async (authURL) => {
|
||||
logger.info(`[MCP Reinitialize] OAuth URL received for ${serverName}`);
|
||||
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
|
||||
oauthUrl = authURL;
|
||||
oauthRequired = true;
|
||||
});
|
||||
|
||||
try {
|
||||
userConnection = await mcpManager.getUserConnection({
|
||||
user: { id: userId },
|
||||
user: req.user,
|
||||
signal,
|
||||
forceNew,
|
||||
oauthStart,
|
||||
@@ -97,7 +97,8 @@ async function reinitMCPServer({
|
||||
|
||||
if (userConnection && !oauthRequired) {
|
||||
tools = await userConnection.fetchTools();
|
||||
availableTools = await updateMCPServerTools({
|
||||
availableTools = await updateMCPUserTools({
|
||||
userId: req.user.id,
|
||||
serverName,
|
||||
tools,
|
||||
});
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* Initialize OAuth reconnect manager
|
||||
*/
|
||||
async function initializeOAuthReconnectManager() {
|
||||
try {
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
const tokenMethods = {
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
deleteTokens,
|
||||
};
|
||||
await createOAuthReconnectionManager(flowManager, tokenMethods);
|
||||
logger.info(`OAuth reconnect manager initialized successfully.`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to initialize OAuth reconnect manager:', error);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = initializeOAuthReconnectManager;
|
||||
@@ -10,10 +10,6 @@ jest.mock('~/models/Message', () => ({
|
||||
bulkSaveMessages: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/ConversationTag', () => ({
|
||||
bulkIncrementTagCounts: jest.fn(),
|
||||
}));
|
||||
|
||||
let mockIdCounter = 0;
|
||||
jest.mock('uuid', () => {
|
||||
return {
|
||||
@@ -26,13 +22,11 @@ jest.mock('uuid', () => {
|
||||
|
||||
const {
|
||||
forkConversation,
|
||||
duplicateConversation,
|
||||
splitAtTargetLevel,
|
||||
getAllMessagesUpToParent,
|
||||
getMessagesUpToTargetLevel,
|
||||
cloneMessagesWithTimestamps,
|
||||
} = require('./fork');
|
||||
const { bulkIncrementTagCounts } = require('~/models/ConversationTag');
|
||||
const { getConvo, bulkSaveConvos } = require('~/models/Conversation');
|
||||
const { getMessages, bulkSaveMessages } = require('~/models/Message');
|
||||
const { createImportBatchBuilder } = require('./importBatchBuilder');
|
||||
@@ -187,120 +181,6 @@ describe('forkConversation', () => {
|
||||
}),
|
||||
).rejects.toThrow('Failed to fetch messages');
|
||||
});
|
||||
|
||||
test('should increment tag counts when forking conversation with tags', async () => {
|
||||
const mockConvoWithTags = {
|
||||
...mockConversation,
|
||||
tags: ['bookmark1', 'bookmark2'],
|
||||
};
|
||||
getConvo.mockResolvedValue(mockConvoWithTags);
|
||||
|
||||
await forkConversation({
|
||||
originalConvoId: 'abc123',
|
||||
targetMessageId: '3',
|
||||
requestUserId: 'user1',
|
||||
option: ForkOptions.DIRECT_PATH,
|
||||
});
|
||||
|
||||
// Verify that bulkIncrementTagCounts was called with correct tags
|
||||
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', ['bookmark1', 'bookmark2']);
|
||||
});
|
||||
|
||||
test('should handle conversation without tags when forking', async () => {
|
||||
const mockConvoWithoutTags = {
|
||||
...mockConversation,
|
||||
// No tags field
|
||||
};
|
||||
getConvo.mockResolvedValue(mockConvoWithoutTags);
|
||||
|
||||
await forkConversation({
|
||||
originalConvoId: 'abc123',
|
||||
targetMessageId: '3',
|
||||
requestUserId: 'user1',
|
||||
option: ForkOptions.DIRECT_PATH,
|
||||
});
|
||||
|
||||
// bulkIncrementTagCounts will be called with array containing undefined
|
||||
expect(bulkIncrementTagCounts).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle empty tags array when forking', async () => {
|
||||
const mockConvoWithEmptyTags = {
|
||||
...mockConversation,
|
||||
tags: [],
|
||||
};
|
||||
getConvo.mockResolvedValue(mockConvoWithEmptyTags);
|
||||
|
||||
await forkConversation({
|
||||
originalConvoId: 'abc123',
|
||||
targetMessageId: '3',
|
||||
requestUserId: 'user1',
|
||||
option: ForkOptions.DIRECT_PATH,
|
||||
});
|
||||
|
||||
// bulkIncrementTagCounts will be called with empty array
|
||||
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', []);
|
||||
});
|
||||
});
|
||||
|
||||
describe('duplicateConversation', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockIdCounter = 0;
|
||||
getConvo.mockResolvedValue(mockConversation);
|
||||
getMessages.mockResolvedValue(mockMessages);
|
||||
bulkSaveConvos.mockResolvedValue(null);
|
||||
bulkSaveMessages.mockResolvedValue(null);
|
||||
bulkIncrementTagCounts.mockResolvedValue(null);
|
||||
});
|
||||
|
||||
test('should duplicate conversation and increment tag counts', async () => {
|
||||
const mockConvoWithTags = {
|
||||
...mockConversation,
|
||||
tags: ['important', 'work', 'project'],
|
||||
};
|
||||
getConvo.mockResolvedValue(mockConvoWithTags);
|
||||
|
||||
await duplicateConversation({
|
||||
userId: 'user1',
|
||||
conversationId: 'abc123',
|
||||
});
|
||||
|
||||
// Verify that bulkIncrementTagCounts was called with correct tags
|
||||
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', ['important', 'work', 'project']);
|
||||
});
|
||||
|
||||
test('should duplicate conversation without tags', async () => {
|
||||
const mockConvoWithoutTags = {
|
||||
...mockConversation,
|
||||
// No tags field
|
||||
};
|
||||
getConvo.mockResolvedValue(mockConvoWithoutTags);
|
||||
|
||||
await duplicateConversation({
|
||||
userId: 'user1',
|
||||
conversationId: 'abc123',
|
||||
});
|
||||
|
||||
// bulkIncrementTagCounts will be called with array containing undefined
|
||||
expect(bulkIncrementTagCounts).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle empty tags array when duplicating', async () => {
|
||||
const mockConvoWithEmptyTags = {
|
||||
...mockConversation,
|
||||
tags: [],
|
||||
};
|
||||
getConvo.mockResolvedValue(mockConvoWithEmptyTags);
|
||||
|
||||
await duplicateConversation({
|
||||
userId: 'user1',
|
||||
conversationId: 'abc123',
|
||||
});
|
||||
|
||||
// bulkIncrementTagCounts will be called with empty array
|
||||
expect(bulkIncrementTagCounts).toHaveBeenCalledWith('user1', []);
|
||||
});
|
||||
});
|
||||
|
||||
const mockMessagesComplex = [
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider');
|
||||
const { bulkIncrementTagCounts } = require('~/models/ConversationTag');
|
||||
const { bulkSaveConvos } = require('~/models/Conversation');
|
||||
const { bulkSaveMessages } = require('~/models/Message');
|
||||
const { logger } = require('~/config');
|
||||
@@ -94,22 +93,13 @@ class ImportBatchBuilder {
|
||||
|
||||
/**
|
||||
* Saves the batch of conversations and messages to the DB.
|
||||
* Also increments tag counts for any existing tags.
|
||||
* @returns {Promise<void>} A promise that resolves when the batch is saved.
|
||||
* @throws {Error} If there is an error saving the batch.
|
||||
*/
|
||||
async saveBatch() {
|
||||
try {
|
||||
const promises = [];
|
||||
promises.push(bulkSaveConvos(this.conversations));
|
||||
promises.push(bulkSaveMessages(this.messages, true));
|
||||
promises.push(
|
||||
bulkIncrementTagCounts(
|
||||
this.requestUserId,
|
||||
this.conversations.flatMap((convo) => convo.tags),
|
||||
),
|
||||
);
|
||||
await Promise.all(promises);
|
||||
await bulkSaveConvos(this.conversations);
|
||||
await bulkSaveMessages(this.messages, true);
|
||||
logger.debug(
|
||||
`user: ${this.requestUserId} | Added ${this.conversations.length} conversations and ${this.messages.length} messages to the DB.`,
|
||||
);
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const axios = require('axios');
|
||||
const FormData = require('form-data');
|
||||
const nodemailer = require('nodemailer');
|
||||
const handlebars = require('handlebars');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logAxiosError, isEnabled, readFileAsString } = require('@librechat/api');
|
||||
const { logAxiosError, isEnabled } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Sends an email using Mailgun API.
|
||||
@@ -92,7 +93,8 @@ const sendEmailViaSMTP = async ({ transporterOptions, mailOptions }) => {
|
||||
*/
|
||||
const sendEmail = async ({ email, subject, payload, template, throwError = true }) => {
|
||||
try {
|
||||
const { content: source } = await readFileAsString(path.join(__dirname, 'emails', template));
|
||||
// Read and compile the email template
|
||||
const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8');
|
||||
const compiledTemplate = handlebars.compile(source);
|
||||
const html = compiledTemplate(payload);
|
||||
|
||||
|
||||
@@ -109,8 +109,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
||||
const username =
|
||||
(LDAP_USERNAME && userinfo[LDAP_USERNAME]) || userinfo.givenName || userinfo.mail;
|
||||
|
||||
let mail = (LDAP_EMAIL && userinfo[LDAP_EMAIL]) || userinfo.mail || username + '@ldap.local';
|
||||
mail = Array.isArray(mail) ? mail[0] : mail;
|
||||
const mail = (LDAP_EMAIL && userinfo[LDAP_EMAIL]) || userinfo.mail || username + '@ldap.local';
|
||||
|
||||
if (!userinfo.mail && !(LDAP_EMAIL && userinfo[LDAP_EMAIL])) {
|
||||
logger.warn(
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
// --- Mocks ---
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
// isEnabled used for TLS flags
|
||||
isEnabled: jest.fn(() => false),
|
||||
getBalanceConfig: jest.fn(() => ({ enabled: false })),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
findUser: jest.fn(),
|
||||
createUser: jest.fn(),
|
||||
updateUser: jest.fn(),
|
||||
countUsers: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn().mockResolvedValue({}),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/domains', () => ({
|
||||
isEmailDomainAllowed: jest.fn(() => true),
|
||||
}));
|
||||
|
||||
// Mock passport-ldapauth to capture verify callback
|
||||
let verifyCallback;
|
||||
jest.mock('passport-ldapauth', () => {
|
||||
return jest.fn().mockImplementation((options, verify) => {
|
||||
verifyCallback = verify; // capture the strategy verify function
|
||||
return { name: 'ldap', options, verify };
|
||||
});
|
||||
});
|
||||
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { findUser, createUser, updateUser, countUsers } = require('~/models');
|
||||
const { isEmailDomainAllowed } = require('~/server/services/domains');
|
||||
|
||||
// Helper to call the verify callback and wrap in a Promise for convenience
|
||||
const callVerify = (userinfo) =>
|
||||
new Promise((resolve, reject) => {
|
||||
verifyCallback(userinfo, (err, user, info) => {
|
||||
if (err) return reject(err);
|
||||
resolve({ user, info });
|
||||
});
|
||||
});
|
||||
|
||||
describe('ldapStrategy', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// minimal required env for ldapStrategy module to export
|
||||
process.env.LDAP_URL = 'ldap://example.com';
|
||||
process.env.LDAP_USER_SEARCH_BASE = 'ou=users,dc=example,dc=com';
|
||||
|
||||
// Unset optional envs to exercise defaults
|
||||
delete process.env.LDAP_CA_CERT_PATH;
|
||||
delete process.env.LDAP_FULL_NAME;
|
||||
delete process.env.LDAP_ID;
|
||||
delete process.env.LDAP_USERNAME;
|
||||
delete process.env.LDAP_EMAIL;
|
||||
delete process.env.LDAP_TLS_REJECT_UNAUTHORIZED;
|
||||
delete process.env.LDAP_STARTTLS;
|
||||
|
||||
// Default model/domain mocks
|
||||
findUser.mockReset().mockResolvedValue(null);
|
||||
createUser.mockReset().mockResolvedValue('newUserId');
|
||||
updateUser.mockReset().mockImplementation(async (id, user) => ({ _id: id, ...user }));
|
||||
countUsers.mockReset().mockResolvedValue(0);
|
||||
isEmailDomainAllowed.mockReset().mockReturnValue(true);
|
||||
|
||||
// Ensure requiring the strategy sets up the verify callback
|
||||
jest.isolateModules(() => {
|
||||
require('./ldapStrategy');
|
||||
});
|
||||
});
|
||||
|
||||
it('uses the first email when LDAP returns multiple emails (array)', async () => {
|
||||
const userinfo = {
|
||||
uid: 'uid123',
|
||||
givenName: 'Alice',
|
||||
cn: 'Alice Doe',
|
||||
mail: ['first@example.com', 'second@example.com'],
|
||||
};
|
||||
|
||||
const { user } = await callVerify(userinfo);
|
||||
|
||||
expect(user.email).toBe('first@example.com');
|
||||
expect(createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: 'ldap',
|
||||
ldapId: 'uid123',
|
||||
username: 'Alice',
|
||||
email: 'first@example.com',
|
||||
emailVerified: true,
|
||||
name: 'Alice Doe',
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('blocks login if an existing user has a different provider', async () => {
|
||||
findUser.mockResolvedValue({ _id: 'u1', email: 'first@example.com', provider: 'google' });
|
||||
|
||||
const userinfo = {
|
||||
uid: 'uid123',
|
||||
mail: 'first@example.com',
|
||||
givenName: 'Alice',
|
||||
cn: 'Alice Doe',
|
||||
};
|
||||
|
||||
const { user, info } = await callVerify(userinfo);
|
||||
|
||||
expect(user).toBe(false);
|
||||
expect(info).toEqual({ message: ErrorTypes.AUTH_FAILED });
|
||||
expect(createUser).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('updates an existing ldap user with current LDAP info', async () => {
|
||||
const existing = {
|
||||
_id: 'u2',
|
||||
provider: 'ldap',
|
||||
email: 'old@example.com',
|
||||
ldapId: 'uid123',
|
||||
username: 'olduser',
|
||||
name: 'Old Name',
|
||||
};
|
||||
findUser.mockResolvedValue(existing);
|
||||
|
||||
const userinfo = {
|
||||
uid: 'uid123',
|
||||
mail: 'new@example.com',
|
||||
givenName: 'NewFirst',
|
||||
cn: 'NewFirst NewLast',
|
||||
};
|
||||
|
||||
const { user } = await callVerify(userinfo);
|
||||
|
||||
expect(createUser).not.toHaveBeenCalled();
|
||||
expect(updateUser).toHaveBeenCalledWith(
|
||||
'u2',
|
||||
expect.objectContaining({
|
||||
provider: 'ldap',
|
||||
ldapId: 'uid123',
|
||||
email: 'new@example.com',
|
||||
username: 'NewFirst',
|
||||
name: 'NewFirst NewLast',
|
||||
}),
|
||||
);
|
||||
expect(user.email).toBe('new@example.com');
|
||||
});
|
||||
|
||||
it('falls back to username@ldap.local when no email attributes are present', async () => {
|
||||
const userinfo = {
|
||||
uid: 'uid999',
|
||||
givenName: 'John',
|
||||
cn: 'John Doe',
|
||||
// no mail and no custom LDAP_EMAIL
|
||||
};
|
||||
|
||||
const { user } = await callVerify(userinfo);
|
||||
|
||||
expect(user.email).toBe('John@ldap.local');
|
||||
});
|
||||
|
||||
it('denies login if email domain is not allowed', async () => {
|
||||
isEmailDomainAllowed.mockReturnValue(false);
|
||||
|
||||
const userinfo = {
|
||||
uid: 'uid123',
|
||||
mail: 'notallowed@blocked.com',
|
||||
givenName: 'Alice',
|
||||
cn: 'Alice Doe',
|
||||
};
|
||||
|
||||
const { user, info } = await callVerify(userinfo);
|
||||
expect(user).toBe(false);
|
||||
expect(info).toEqual({ message: 'Email domain not allowed' });
|
||||
});
|
||||
});
|
||||
@@ -41,18 +41,13 @@ const openIdJwtLogin = (openIdConfig) => {
|
||||
jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
|
||||
secretOrKeyProvider: jwksRsa.passportJwtSecret(jwksRsaOptions),
|
||||
},
|
||||
/**
|
||||
* @param {import('openid-client').IDToken} payload
|
||||
* @param {import('passport-jwt').VerifyCallback} done
|
||||
*/
|
||||
async (payload, done) => {
|
||||
try {
|
||||
const { user, error, migration } = await findOpenIDUser({
|
||||
findUser,
|
||||
email: payload?.email,
|
||||
openidId: payload?.sub,
|
||||
idOnTheSource: payload?.oid,
|
||||
email: payload?.email,
|
||||
strategyName: 'openIdJwtLogin',
|
||||
findUser,
|
||||
});
|
||||
|
||||
if (error) {
|
||||
|
||||
@@ -337,10 +337,6 @@ async function setupOpenId() {
|
||||
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
|
||||
usePKCE,
|
||||
},
|
||||
/**
|
||||
* @param {import('openid-client').TokenEndpointResponseHelpers} tokenset
|
||||
* @param {import('passport-jwt').VerifyCallback} done
|
||||
*/
|
||||
async (tokenset, done) => {
|
||||
try {
|
||||
const claims = tokenset.claims();
|
||||
@@ -358,11 +354,10 @@ async function setupOpenId() {
|
||||
}
|
||||
|
||||
const result = await findOpenIDUser({
|
||||
findUser,
|
||||
email: claims.email,
|
||||
openidId: claims.sub,
|
||||
idOnTheSource: claims.oid,
|
||||
email: claims.email,
|
||||
strategyName: 'openidStrategy',
|
||||
findUser,
|
||||
});
|
||||
let user = result.user;
|
||||
const error = result.error;
|
||||
@@ -376,10 +371,6 @@ async function setupOpenId() {
|
||||
const fullName = getFullName(userinfo);
|
||||
|
||||
if (requiredRole) {
|
||||
const requiredRoles = requiredRole
|
||||
.split(',')
|
||||
.map((role) => role.trim())
|
||||
.filter(Boolean);
|
||||
let decodedToken = '';
|
||||
if (requiredRoleTokenKind === 'access') {
|
||||
decodedToken = jwtDecode(tokenset.access_token);
|
||||
@@ -402,13 +393,9 @@ async function setupOpenId() {
|
||||
);
|
||||
}
|
||||
|
||||
if (!requiredRoles.some((role) => roles.includes(role))) {
|
||||
const rolesList =
|
||||
requiredRoles.length === 1
|
||||
? `"${requiredRoles[0]}"`
|
||||
: `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`;
|
||||
if (!roles.includes(requiredRole)) {
|
||||
return done(null, false, {
|
||||
message: `You must have ${rolesList} role to log in.`,
|
||||
message: `You must have the "${requiredRole}" role to log in.`,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -441,10 +428,6 @@ async function setupOpenId() {
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
user.idOnTheSource = userinfo.oid;
|
||||
if (userinfo.email && userinfo.email !== user.email) {
|
||||
user.email = userinfo.email;
|
||||
user.emailVerified = userinfo.email_verified || false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
|
||||
|
||||
@@ -274,7 +274,10 @@ describe('setupOpenId', () => {
|
||||
name: '',
|
||||
};
|
||||
findUser.mockImplementation(async (query) => {
|
||||
if (query.openidId === tokenset.claims().sub || query.email === tokenset.claims().email) {
|
||||
if (
|
||||
query.openidId === tokenset.claims().sub ||
|
||||
(query.email === tokenset.claims().email && query.provider === 'openid')
|
||||
) {
|
||||
return existingUser;
|
||||
}
|
||||
return null;
|
||||
@@ -335,25 +338,7 @@ describe('setupOpenId', () => {
|
||||
|
||||
// Assert – verify that the strategy rejects login
|
||||
expect(user).toBe(false);
|
||||
expect(details.message).toBe('You must have "requiredRole" role to log in.');
|
||||
});
|
||||
|
||||
it('should allow login when single required role is present (backward compatibility)', async () => {
|
||||
// Arrange – ensure single role configuration (as set in beforeEach)
|
||||
// OPENID_REQUIRED_ROLE = 'requiredRole'
|
||||
// Default jwtDecode mock in beforeEach already returns this role
|
||||
jwtDecode.mockReturnValue({
|
||||
roles: ['requiredRole', 'anotherRole'],
|
||||
});
|
||||
|
||||
// Act
|
||||
const { user } = await validate(tokenset);
|
||||
|
||||
// Assert – verify that login succeeds with single role configuration
|
||||
expect(user).toBeTruthy();
|
||||
expect(user.email).toBe(tokenset.claims().email);
|
||||
expect(user.username).toBe(tokenset.claims().preferred_username);
|
||||
expect(createUser).toHaveBeenCalled();
|
||||
expect(details.message).toBe('You must have the "requiredRole" role to log in.');
|
||||
});
|
||||
|
||||
it('should attempt to download and save the avatar if picture is provided', async () => {
|
||||
@@ -379,58 +364,6 @@ describe('setupOpenId', () => {
|
||||
// Depending on your implementation, user.avatar may be undefined or an empty string.
|
||||
});
|
||||
|
||||
it('should support comma-separated multiple roles', async () => {
|
||||
// Arrange
|
||||
process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin';
|
||||
await setupOpenId(); // Re-initialize the strategy
|
||||
verifyCallback = require('openid-client/passport').__getVerifyCallback();
|
||||
jwtDecode.mockReturnValue({
|
||||
roles: ['anotherRole', 'aThirdRole'],
|
||||
});
|
||||
|
||||
// Act
|
||||
const { user } = await validate(tokenset);
|
||||
|
||||
// Assert
|
||||
expect(user).toBeTruthy();
|
||||
expect(user.email).toBe(tokenset.claims().email);
|
||||
});
|
||||
|
||||
it('should reject login when user has none of the required multiple roles', async () => {
|
||||
// Arrange
|
||||
process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin';
|
||||
await setupOpenId(); // Re-initialize the strategy
|
||||
verifyCallback = require('openid-client/passport').__getVerifyCallback();
|
||||
jwtDecode.mockReturnValue({
|
||||
roles: ['aThirdRole', 'aFourthRole'],
|
||||
});
|
||||
|
||||
// Act
|
||||
const { user, details } = await validate(tokenset);
|
||||
|
||||
// Assert
|
||||
expect(user).toBe(false);
|
||||
expect(details.message).toBe(
|
||||
'You must have one of: "someRole", "anotherRole", "admin" role to log in.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle spaces in comma-separated roles', async () => {
|
||||
// Arrange
|
||||
process.env.OPENID_REQUIRED_ROLE = ' someRole , anotherRole , admin ';
|
||||
await setupOpenId(); // Re-initialize the strategy
|
||||
verifyCallback = require('openid-client/passport').__getVerifyCallback();
|
||||
jwtDecode.mockReturnValue({
|
||||
roles: ['someRole'],
|
||||
});
|
||||
|
||||
// Act
|
||||
const { user } = await validate(tokenset);
|
||||
|
||||
// Assert
|
||||
expect(user).toBeTruthy();
|
||||
});
|
||||
|
||||
it('should default to usePKCE false when OPENID_USE_PKCE is not defined', async () => {
|
||||
const OpenIDStrategy = require('openid-client/passport').Strategy;
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ describe('fileSearch.js - test only new file_id and page additions', () => {
|
||||
queryVectors.mockResolvedValue(mockResults);
|
||||
|
||||
const fileSearchTool = await createFileSearchTool({
|
||||
userId: 'user1',
|
||||
req: { user: { id: 'user1' } },
|
||||
files: mockFiles,
|
||||
entity_id: 'agent-123',
|
||||
});
|
||||
|
||||
@@ -873,13 +873,6 @@
|
||||
* @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports ISession
|
||||
* @typedef {import('@librechat/data-schemas').ISession} ISession
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports IBalance
|
||||
* @typedef {import('@librechat/data-schemas').IBalance} IBalance
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import React, { createContext, useContext, useState, useMemo } from 'react';
|
||||
import { EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { MCP, Action, TPlugin } from 'librechat-data-provider';
|
||||
import { Constants, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { MCP, Action, TPlugin, AgentToolType } from 'librechat-data-provider';
|
||||
import type { AgentPanelContextType, MCPServerInfo } from '~/common';
|
||||
import {
|
||||
useAvailableToolsQuery,
|
||||
useGetActionsQuery,
|
||||
useGetStartupConfig,
|
||||
useMCPToolsQuery,
|
||||
} from '~/data-provider';
|
||||
import { useAvailableToolsQuery, useGetActionsQuery, useGetStartupConfig } from '~/data-provider';
|
||||
import { useLocalize, useGetAgentsConfig, useMCPConnectionStatus } from '~/hooks';
|
||||
import { Panel, isEphemeralAgent } from '~/common';
|
||||
import { Panel } from '~/common';
|
||||
|
||||
type GroupedToolType = AgentToolType & { tools?: AgentToolType[] };
|
||||
type GroupedToolsRecord = Record<string, GroupedToolType>;
|
||||
|
||||
const AgentPanelContext = createContext<AgentPanelContextType | undefined>(undefined);
|
||||
|
||||
@@ -30,67 +28,79 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
|
||||
const [activePanel, setActivePanel] = useState<Panel>(Panel.builder);
|
||||
const [agent_id, setCurrentAgentId] = useState<string | undefined>(undefined);
|
||||
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const { data: actions } = useGetActionsQuery(EModelEndpoint.agents, {
|
||||
enabled: !isEphemeralAgent(agent_id),
|
||||
enabled: !!agent_id,
|
||||
});
|
||||
|
||||
const { data: regularTools } = useAvailableToolsQuery(EModelEndpoint.agents, {
|
||||
enabled: !isEphemeralAgent(agent_id),
|
||||
const { data: pluginTools } = useAvailableToolsQuery(EModelEndpoint.agents, {
|
||||
enabled: !!agent_id,
|
||||
});
|
||||
|
||||
const { data: mcpData } = useMCPToolsQuery({
|
||||
enabled: !isEphemeralAgent(agent_id) && startupConfig?.mcpServers != null,
|
||||
});
|
||||
|
||||
const { agentsConfig, endpointsConfig } = useGetAgentsConfig();
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const mcpServerNames = useMemo(
|
||||
() => Object.keys(startupConfig?.mcpServers ?? {}),
|
||||
[startupConfig],
|
||||
);
|
||||
|
||||
const { connectionStatus } = useMCPConnectionStatus({
|
||||
enabled: !isEphemeralAgent(agent_id) && mcpServerNames.length > 0,
|
||||
enabled: !!agent_id && mcpServerNames.length > 0,
|
||||
});
|
||||
|
||||
const mcpServersMap = useMemo(() => {
|
||||
const processedData = useMemo(() => {
|
||||
if (!pluginTools) {
|
||||
return {
|
||||
tools: [],
|
||||
groupedTools: {},
|
||||
mcpServersMap: new Map<string, MCPServerInfo>(),
|
||||
};
|
||||
}
|
||||
|
||||
const tools: AgentToolType[] = [];
|
||||
const groupedTools: GroupedToolsRecord = {};
|
||||
|
||||
const configuredServers = new Set(mcpServerNames);
|
||||
const serversMap = new Map<string, MCPServerInfo>();
|
||||
const mcpServersMap = new Map<string, MCPServerInfo>();
|
||||
|
||||
if (mcpData?.servers) {
|
||||
for (const [serverName, serverData] of Object.entries(mcpData.servers)) {
|
||||
const metadata = {
|
||||
name: serverName,
|
||||
pluginKey: serverName,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
|
||||
icon: serverData.icon || '',
|
||||
authConfig: serverData.authConfig,
|
||||
authenticated: serverData.authenticated,
|
||||
} as TPlugin;
|
||||
for (const pluginTool of pluginTools) {
|
||||
const tool: AgentToolType = {
|
||||
tool_id: pluginTool.pluginKey,
|
||||
metadata: pluginTool as TPlugin,
|
||||
};
|
||||
|
||||
const tools = serverData.tools.map((tool) => ({
|
||||
tool_id: tool.pluginKey,
|
||||
metadata: {
|
||||
...tool,
|
||||
icon: serverData.icon,
|
||||
authConfig: serverData.authConfig,
|
||||
authenticated: serverData.authenticated,
|
||||
} as TPlugin,
|
||||
}));
|
||||
tools.push(tool);
|
||||
|
||||
serversMap.set(serverName, {
|
||||
serverName,
|
||||
tools,
|
||||
isConfigured: configuredServers.has(serverName),
|
||||
isConnected: connectionStatus?.[serverName]?.connectionState === 'connected',
|
||||
metadata,
|
||||
});
|
||||
if (tool.tool_id.includes(Constants.mcp_delimiter)) {
|
||||
const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter);
|
||||
|
||||
if (!mcpServersMap.has(serverName)) {
|
||||
const metadata = {
|
||||
name: serverName,
|
||||
pluginKey: serverName,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
|
||||
icon: pluginTool.icon || '',
|
||||
} as TPlugin;
|
||||
|
||||
mcpServersMap.set(serverName, {
|
||||
serverName,
|
||||
tools: [],
|
||||
isConfigured: configuredServers.has(serverName),
|
||||
isConnected: connectionStatus?.[serverName]?.connectionState === 'connected',
|
||||
metadata,
|
||||
});
|
||||
}
|
||||
|
||||
mcpServersMap.get(serverName)!.tools.push(tool);
|
||||
} else {
|
||||
// Non-MCP tool
|
||||
groupedTools[tool.tool_id] = {
|
||||
tool_id: tool.tool_id,
|
||||
metadata: tool.metadata,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Add configured servers that don't have tools yet
|
||||
for (const mcpServerName of mcpServerNames) {
|
||||
if (serversMap.has(mcpServerName)) {
|
||||
if (mcpServersMap.has(mcpServerName)) {
|
||||
continue;
|
||||
}
|
||||
const metadata = {
|
||||
@@ -100,7 +110,7 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${mcpServerName}`,
|
||||
} as TPlugin;
|
||||
|
||||
serversMap.set(mcpServerName, {
|
||||
mcpServersMap.set(mcpServerName, {
|
||||
tools: [],
|
||||
metadata,
|
||||
isConfigured: true,
|
||||
@@ -109,8 +119,14 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
|
||||
});
|
||||
}
|
||||
|
||||
return serversMap;
|
||||
}, [mcpData, localize, mcpServerNames, connectionStatus]);
|
||||
return {
|
||||
tools,
|
||||
groupedTools,
|
||||
mcpServersMap,
|
||||
};
|
||||
}, [pluginTools, localize, mcpServerNames, connectionStatus]);
|
||||
|
||||
const { agentsConfig, endpointsConfig } = useGetAgentsConfig();
|
||||
|
||||
const value: AgentPanelContextType = {
|
||||
mcp,
|
||||
@@ -121,14 +137,16 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
|
||||
setMcps,
|
||||
agent_id,
|
||||
setAction,
|
||||
pluginTools,
|
||||
activePanel,
|
||||
regularTools,
|
||||
agentsConfig,
|
||||
startupConfig,
|
||||
mcpServersMap,
|
||||
setActivePanel,
|
||||
endpointsConfig,
|
||||
setCurrentAgentId,
|
||||
tools: processedData.tools,
|
||||
groupedTools: processedData.groupedTools,
|
||||
mcpServersMap: processedData.mcpServersMap,
|
||||
};
|
||||
|
||||
return <AgentPanelContext.Provider value={value}>{children}</AgentPanelContext.Provider>;
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
import React, { createContext, useContext, useMemo } from 'react';
|
||||
import { useChatContext } from './ChatContext';
|
||||
|
||||
interface DragDropContextValue {
|
||||
conversationId: string | null | undefined;
|
||||
agentId: string | null | undefined;
|
||||
}
|
||||
|
||||
const DragDropContext = createContext<DragDropContextValue | undefined>(undefined);
|
||||
|
||||
export function DragDropProvider({ children }: { children: React.ReactNode }) {
|
||||
const { conversation } = useChatContext();
|
||||
|
||||
/** Context value only created when conversation fields change */
|
||||
const contextValue = useMemo<DragDropContextValue>(
|
||||
() => ({
|
||||
conversationId: conversation?.conversationId,
|
||||
agentId: conversation?.agent_id,
|
||||
}),
|
||||
[conversation?.conversationId, conversation?.agent_id],
|
||||
);
|
||||
|
||||
return <DragDropContext.Provider value={contextValue}>{children}</DragDropContext.Provider>;
|
||||
}
|
||||
|
||||
export function useDragDropContext() {
|
||||
const context = useContext(DragDropContext);
|
||||
if (!context) {
|
||||
throw new Error('useDragDropContext must be used within DragDropProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
@@ -1,15 +1,10 @@
|
||||
import { createContext, useContext } from 'react';
|
||||
|
||||
type MessageContext = {
|
||||
messageId: string;
|
||||
nextType?: string;
|
||||
partIndex?: number;
|
||||
isExpanded: boolean;
|
||||
conversationId?: string | null;
|
||||
/** Submission state for cursor display - only true for latest message when submitting */
|
||||
isSubmitting?: boolean;
|
||||
/** Whether this is the latest message in the conversation */
|
||||
isLatestMessage?: boolean;
|
||||
};
|
||||
|
||||
export const MessageContext = createContext<MessageContext>({} as MessageContext);
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
import React, { createContext, useContext, useMemo } from 'react';
|
||||
import { useAddedChatContext } from './AddedChatContext';
|
||||
import { useChatContext } from './ChatContext';
|
||||
|
||||
interface MessagesViewContextValue {
|
||||
/** Core conversation data */
|
||||
conversation: ReturnType<typeof useChatContext>['conversation'];
|
||||
conversationId: string | null | undefined;
|
||||
|
||||
/** Submission and control states */
|
||||
isSubmitting: ReturnType<typeof useChatContext>['isSubmitting'];
|
||||
isSubmittingFamily: boolean;
|
||||
abortScroll: ReturnType<typeof useChatContext>['abortScroll'];
|
||||
setAbortScroll: ReturnType<typeof useChatContext>['setAbortScroll'];
|
||||
|
||||
/** Message operations */
|
||||
ask: ReturnType<typeof useChatContext>['ask'];
|
||||
regenerate: ReturnType<typeof useChatContext>['regenerate'];
|
||||
handleContinue: ReturnType<typeof useChatContext>['handleContinue'];
|
||||
|
||||
/** Message state management */
|
||||
index: ReturnType<typeof useChatContext>['index'];
|
||||
latestMessage: ReturnType<typeof useChatContext>['latestMessage'];
|
||||
setLatestMessage: ReturnType<typeof useChatContext>['setLatestMessage'];
|
||||
getMessages: ReturnType<typeof useChatContext>['getMessages'];
|
||||
setMessages: ReturnType<typeof useChatContext>['setMessages'];
|
||||
}
|
||||
|
||||
const MessagesViewContext = createContext<MessagesViewContextValue | undefined>(undefined);
|
||||
|
||||
export function MessagesViewProvider({ children }: { children: React.ReactNode }) {
|
||||
const chatContext = useChatContext();
|
||||
const addedChatContext = useAddedChatContext();
|
||||
|
||||
const {
|
||||
ask,
|
||||
index,
|
||||
regenerate,
|
||||
isSubmitting: isSubmittingRoot,
|
||||
conversation,
|
||||
latestMessage,
|
||||
setAbortScroll,
|
||||
handleContinue,
|
||||
setLatestMessage,
|
||||
abortScroll,
|
||||
getMessages,
|
||||
setMessages,
|
||||
} = chatContext;
|
||||
|
||||
const { isSubmitting: isSubmittingAdditional } = addedChatContext;
|
||||
|
||||
/** Memoize conversation-related values */
|
||||
const conversationValues = useMemo(
|
||||
() => ({
|
||||
conversation,
|
||||
conversationId: conversation?.conversationId,
|
||||
}),
|
||||
[conversation],
|
||||
);
|
||||
|
||||
/** Memoize submission states */
|
||||
const submissionStates = useMemo(
|
||||
() => ({
|
||||
isSubmitting: isSubmittingRoot,
|
||||
isSubmittingFamily: isSubmittingRoot || isSubmittingAdditional,
|
||||
abortScroll,
|
||||
setAbortScroll,
|
||||
}),
|
||||
[isSubmittingRoot, isSubmittingAdditional, abortScroll, setAbortScroll],
|
||||
);
|
||||
|
||||
/** Memoize message operations (these are typically stable references) */
|
||||
const messageOperations = useMemo(
|
||||
() => ({
|
||||
ask,
|
||||
regenerate,
|
||||
getMessages,
|
||||
setMessages,
|
||||
handleContinue,
|
||||
}),
|
||||
[ask, regenerate, handleContinue, getMessages, setMessages],
|
||||
);
|
||||
|
||||
/** Memoize message state values */
|
||||
const messageState = useMemo(
|
||||
() => ({
|
||||
index,
|
||||
latestMessage,
|
||||
setLatestMessage,
|
||||
}),
|
||||
[index, latestMessage, setLatestMessage],
|
||||
);
|
||||
|
||||
/** Combine all values into final context value */
|
||||
const contextValue = useMemo<MessagesViewContextValue>(
|
||||
() => ({
|
||||
...conversationValues,
|
||||
...submissionStates,
|
||||
...messageOperations,
|
||||
...messageState,
|
||||
}),
|
||||
[conversationValues, submissionStates, messageOperations, messageState],
|
||||
);
|
||||
|
||||
return (
|
||||
<MessagesViewContext.Provider value={contextValue}>{children}</MessagesViewContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useMessagesViewContext() {
|
||||
const context = useContext(MessagesViewContext);
|
||||
if (!context) {
|
||||
throw new Error('useMessagesViewContext must be used within MessagesViewProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
/** Hook for components that only need conversation data */
|
||||
export function useMessagesConversation() {
|
||||
const { conversation, conversationId } = useMessagesViewContext();
|
||||
return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]);
|
||||
}
|
||||
|
||||
/** Hook for components that only need submission states */
|
||||
export function useMessagesSubmission() {
|
||||
const { isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll } =
|
||||
useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll }),
|
||||
[isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll],
|
||||
);
|
||||
}
|
||||
|
||||
/** Hook for components that only need message operations */
|
||||
export function useMessagesOperations() {
|
||||
const { ask, regenerate, handleContinue, getMessages, setMessages } = useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ ask, regenerate, handleContinue, getMessages, setMessages }),
|
||||
[ask, regenerate, handleContinue, getMessages, setMessages],
|
||||
);
|
||||
}
|
||||
|
||||
/** Hook for components that only need message state */
|
||||
export function useMessagesState() {
|
||||
const { index, latestMessage, setLatestMessage } = useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ index, latestMessage, setLatestMessage }),
|
||||
[index, latestMessage, setLatestMessage],
|
||||
);
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
import React, { createContext, useContext, ReactNode, useMemo } from 'react';
|
||||
import { PermissionTypes, Permissions } from 'librechat-data-provider';
|
||||
import type { TPromptGroup } from 'librechat-data-provider';
|
||||
import type { PromptOption } from '~/common';
|
||||
import CategoryIcon from '~/components/Prompts/Groups/CategoryIcon';
|
||||
import { usePromptGroupsNav, useHasAccess } from '~/hooks';
|
||||
import { useGetAllPromptGroups } from '~/data-provider';
|
||||
import { usePromptGroupsNav } from '~/hooks';
|
||||
import { mapPromptGroups } from '~/utils';
|
||||
|
||||
type AllPromptGroupsData =
|
||||
@@ -20,21 +19,14 @@ type PromptGroupsContextType =
|
||||
data: AllPromptGroupsData;
|
||||
isLoading: boolean;
|
||||
};
|
||||
hasAccess: boolean;
|
||||
})
|
||||
| null;
|
||||
|
||||
const PromptGroupsContext = createContext<PromptGroupsContextType>(null);
|
||||
|
||||
export const PromptGroupsProvider = ({ children }: { children: ReactNode }) => {
|
||||
const hasAccess = useHasAccess({
|
||||
permissionType: PermissionTypes.PROMPTS,
|
||||
permission: Permissions.USE,
|
||||
});
|
||||
|
||||
const promptGroupsNav = usePromptGroupsNav(hasAccess);
|
||||
const promptGroupsNav = usePromptGroupsNav();
|
||||
const { data: allGroupsData, isLoading: isLoadingAll } = useGetAllPromptGroups(undefined, {
|
||||
enabled: hasAccess,
|
||||
select: (data) => {
|
||||
const mappedArray: PromptOption[] = data.map((group) => ({
|
||||
id: group._id ?? '',
|
||||
@@ -63,12 +55,11 @@ export const PromptGroupsProvider = ({ children }: { children: ReactNode }) => {
|
||||
() => ({
|
||||
...promptGroupsNav,
|
||||
allPromptGroups: {
|
||||
data: hasAccess ? allGroupsData : undefined,
|
||||
isLoading: hasAccess ? isLoadingAll : false,
|
||||
data: allGroupsData,
|
||||
isLoading: isLoadingAll,
|
||||
},
|
||||
hasAccess,
|
||||
}),
|
||||
[promptGroupsNav, allGroupsData, isLoadingAll, hasAccess],
|
||||
[promptGroupsNav, allGroupsData, isLoadingAll],
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -23,9 +23,7 @@ export * from './SetConvoContext';
|
||||
export * from './SearchContext';
|
||||
export * from './BadgeRowContext';
|
||||
export * from './SidePanelContext';
|
||||
export * from './DragDropContext';
|
||||
export * from './MCPPanelContext';
|
||||
export * from './ArtifactsContext';
|
||||
export * from './PromptGroupsContext';
|
||||
export * from './MessagesViewContext';
|
||||
export { default as BadgeRowProvider } from './BadgeRowContext';
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React from 'react';
|
||||
import { TStartupConfig } from 'librechat-data-provider';
|
||||
import { TModelSpec, TStartupConfig } from 'librechat-data-provider';
|
||||
|
||||
export interface Endpoint {
|
||||
value: string;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { RefObject } from 'react';
|
||||
import { Constants, FileSources, EModelEndpoint } from 'librechat-data-provider';
|
||||
import { FileSources, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { UseMutationResult } from '@tanstack/react-query';
|
||||
import type * as InputNumberPrimitive from 'rc-input-number';
|
||||
import type { SetterOrUpdater, RecoilState } from 'recoil';
|
||||
@@ -8,10 +8,6 @@ import type * as t from 'librechat-data-provider';
|
||||
import type { LucideIcon } from 'lucide-react';
|
||||
import type { TranslationKeys } from '~/hooks';
|
||||
|
||||
export function isEphemeralAgent(agentId: string | null | undefined): boolean {
|
||||
return agentId == null || agentId === '' || agentId === Constants.EPHEMERAL_AGENT_ID;
|
||||
}
|
||||
|
||||
export interface ConfigFieldDetail {
|
||||
title: string;
|
||||
description: string;
|
||||
@@ -236,8 +232,10 @@ export type AgentPanelContextType = {
|
||||
mcps?: t.MCP[];
|
||||
setMcp: React.Dispatch<React.SetStateAction<t.MCP | undefined>>;
|
||||
setMcps: React.Dispatch<React.SetStateAction<t.MCP[] | undefined>>;
|
||||
groupedTools: Record<string, t.AgentToolType & { tools?: t.AgentToolType[] }>;
|
||||
activePanel?: string;
|
||||
regularTools?: t.TPlugin[];
|
||||
tools: t.AgentToolType[];
|
||||
pluginTools?: t.TPlugin[];
|
||||
setActivePanel: React.Dispatch<React.SetStateAction<Panel>>;
|
||||
setCurrentAgentId: React.Dispatch<React.SetStateAction<string | undefined>>;
|
||||
agent_id?: string;
|
||||
|
||||
@@ -11,9 +11,9 @@ import {
|
||||
AgentListResponse,
|
||||
} from 'librechat-data-provider';
|
||||
import type t from 'librechat-data-provider';
|
||||
import { useLocalize, useDefaultConvo } from '~/hooks';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { renderAgentAvatar } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
interface SupportContact {
|
||||
name?: string;
|
||||
@@ -34,11 +34,11 @@ interface AgentDetailProps {
|
||||
*/
|
||||
const AgentDetail: React.FC<AgentDetailProps> = ({ agent, isOpen, onClose }) => {
|
||||
const localize = useLocalize();
|
||||
const queryClient = useQueryClient();
|
||||
// const navigate = useNavigate();
|
||||
const { conversation, newConversation } = useChatContext();
|
||||
const { showToast } = useToastContext();
|
||||
const dialogRef = useRef<HTMLDivElement>(null);
|
||||
const getDefaultConversation = useDefaultConvo();
|
||||
const { conversation, newConversation } = useChatContext();
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
/**
|
||||
* Navigate to chat with the selected agent
|
||||
@@ -62,22 +62,13 @@ const AgentDetail: React.FC<AgentDetailProps> = ({ agent, isOpen, onClose }) =>
|
||||
);
|
||||
queryClient.invalidateQueries([QueryKeys.messages]);
|
||||
|
||||
/** Template with agent configuration */
|
||||
const template = {
|
||||
conversationId: Constants.NEW_CONVO as string,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
agent_id: agent.id,
|
||||
title: localize('com_agents_chat_with', { name: agent.name || localize('com_ui_agent') }),
|
||||
};
|
||||
|
||||
const currentConvo = getDefaultConversation({
|
||||
conversation: { ...(conversation ?? {}), ...template },
|
||||
preset: template,
|
||||
});
|
||||
|
||||
newConversation({
|
||||
template: currentConvo,
|
||||
preset: template,
|
||||
template: {
|
||||
conversationId: Constants.NEW_CONVO as string,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
agent_id: agent.id,
|
||||
title: `Chat with ${agent.name || 'Agent'}`,
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -20,7 +20,6 @@ jest.mock('react-router-dom', () => ({
|
||||
jest.mock('~/hooks', () => ({
|
||||
useMediaQuery: jest.fn(() => false), // Mock as desktop by default
|
||||
useLocalize: jest.fn(),
|
||||
useDefaultConvo: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/client', () => ({
|
||||
@@ -48,12 +47,7 @@ const mockWriteText = jest.fn();
|
||||
|
||||
const mockNavigate = jest.fn();
|
||||
const mockShowToast = jest.fn();
|
||||
const mockLocalize = jest.fn((key: string, values?: Record<string, any>) => {
|
||||
if (key === 'com_agents_chat_with' && values?.name) {
|
||||
return `Chat with ${values.name}`;
|
||||
}
|
||||
return key;
|
||||
});
|
||||
const mockLocalize = jest.fn((key: string) => key);
|
||||
|
||||
const mockAgent: t.Agent = {
|
||||
id: 'test-agent-id',
|
||||
@@ -112,12 +106,8 @@ describe('AgentDetail', () => {
|
||||
(useNavigate as jest.Mock).mockReturnValue(mockNavigate);
|
||||
const { useToastContext } = require('@librechat/client');
|
||||
(useToastContext as jest.Mock).mockReturnValue({ showToast: mockShowToast });
|
||||
const { useLocalize, useDefaultConvo } = require('~/hooks');
|
||||
const { useLocalize } = require('~/hooks');
|
||||
(useLocalize as jest.Mock).mockReturnValue(mockLocalize);
|
||||
(useDefaultConvo as jest.Mock).mockReturnValue(() => ({
|
||||
conversationId: Constants.NEW_CONVO,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
}));
|
||||
|
||||
// Mock useChatContext
|
||||
const { useChatContext } = require('~/Providers');
|
||||
@@ -237,10 +227,6 @@ describe('AgentDetail', () => {
|
||||
template: {
|
||||
conversationId: Constants.NEW_CONVO,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
},
|
||||
preset: {
|
||||
conversationId: Constants.NEW_CONVO,
|
||||
endpoint: EModelEndpoint.agents,
|
||||
agent_id: 'test-agent-id',
|
||||
title: 'Chat with Test Agent',
|
||||
},
|
||||
|
||||
@@ -4,7 +4,7 @@ import {
|
||||
SandpackProvider,
|
||||
SandpackProviderProps,
|
||||
} from '@codesandbox/sandpack-react/unstyled';
|
||||
import type { SandpackPreviewRef, PreviewProps } from '@codesandbox/sandpack-react/unstyled';
|
||||
import type { SandpackPreviewRef } from '@codesandbox/sandpack-react/unstyled';
|
||||
import type { TStartupConfig } from 'librechat-data-provider';
|
||||
import type { ArtifactFiles } from '~/common';
|
||||
import { sharedFiles, sharedOptions } from '~/utils/artifacts';
|
||||
@@ -13,7 +13,6 @@ export const ArtifactPreview = memo(function ({
|
||||
files,
|
||||
fileKey,
|
||||
template,
|
||||
isMermaid,
|
||||
sharedProps,
|
||||
previewRef,
|
||||
currentCode,
|
||||
@@ -21,7 +20,6 @@ export const ArtifactPreview = memo(function ({
|
||||
}: {
|
||||
files: ArtifactFiles;
|
||||
fileKey: string;
|
||||
isMermaid: boolean;
|
||||
template: SandpackProviderProps['template'];
|
||||
sharedProps: Partial<SandpackProviderProps>;
|
||||
previewRef: React.MutableRefObject<SandpackPreviewRef>;
|
||||
@@ -56,15 +54,6 @@ export const ArtifactPreview = memo(function ({
|
||||
return _options;
|
||||
}, [startupConfig, template]);
|
||||
|
||||
const style: PreviewProps['style'] | undefined = useMemo(() => {
|
||||
if (isMermaid) {
|
||||
return {
|
||||
backgroundColor: '#282C34',
|
||||
};
|
||||
}
|
||||
return;
|
||||
}, [isMermaid]);
|
||||
|
||||
if (Object.keys(artifactFiles).length === 0) {
|
||||
return null;
|
||||
}
|
||||
@@ -84,7 +73,6 @@ export const ArtifactPreview = memo(function ({
|
||||
showRefreshButton={false}
|
||||
tabIndex={0}
|
||||
ref={previewRef}
|
||||
style={style}
|
||||
/>
|
||||
</SandpackProvider>
|
||||
);
|
||||
|
||||
@@ -8,7 +8,6 @@ import { useAutoScroll } from '~/hooks/Artifacts/useAutoScroll';
|
||||
import { ArtifactCodeEditor } from './ArtifactCodeEditor';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import { ArtifactPreview } from './ArtifactPreview';
|
||||
import { MermaidMarkdown } from './MermaidMarkdown';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function ArtifactTabs({
|
||||
@@ -45,25 +44,23 @@ export default function ArtifactTabs({
|
||||
id="artifacts-code"
|
||||
className={cn('flex-grow overflow-auto')}
|
||||
>
|
||||
{isMermaid ? (
|
||||
<MermaidMarkdown content={content} isSubmitting={isSubmitting} />
|
||||
) : (
|
||||
<ArtifactCodeEditor
|
||||
files={files}
|
||||
fileKey={fileKey}
|
||||
template={template}
|
||||
artifact={artifact}
|
||||
editorRef={editorRef}
|
||||
sharedProps={sharedProps}
|
||||
/>
|
||||
)}
|
||||
<ArtifactCodeEditor
|
||||
files={files}
|
||||
fileKey={fileKey}
|
||||
template={template}
|
||||
artifact={artifact}
|
||||
editorRef={editorRef}
|
||||
sharedProps={sharedProps}
|
||||
/>
|
||||
</Tabs.Content>
|
||||
<Tabs.Content value="preview" className="flex-grow overflow-auto">
|
||||
<Tabs.Content
|
||||
value="preview"
|
||||
className={cn('flex-grow overflow-auto', isMermaid ? 'bg-[#282C34]' : 'bg-white')}
|
||||
>
|
||||
<ArtifactPreview
|
||||
files={files}
|
||||
fileKey={fileKey}
|
||||
template={template}
|
||||
isMermaid={isMermaid}
|
||||
previewRef={previewRef}
|
||||
sharedProps={sharedProps}
|
||||
currentCode={currentCode}
|
||||
|
||||
@@ -9,7 +9,6 @@ import { useEditorContext } from '~/Providers';
|
||||
import ArtifactTabs from './ArtifactTabs';
|
||||
import { CopyCodeButton } from './Code';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
export default function Artifacts() {
|
||||
@@ -59,10 +58,9 @@ export default function Artifacts() {
|
||||
<div className="flex h-full w-full items-center justify-center">
|
||||
{/* Main Container */}
|
||||
<div
|
||||
className={cn(
|
||||
`flex h-full w-full flex-col overflow-hidden border border-border-medium bg-surface-primary text-xl text-text-primary shadow-xl transition-all duration-500 ease-in-out`,
|
||||
isVisible ? 'scale-100 opacity-100 blur-0' : 'scale-105 opacity-0 blur-sm',
|
||||
)}
|
||||
className={`flex h-full w-full flex-col overflow-hidden border border-border-medium bg-surface-primary text-xl text-text-primary shadow-xl transition-all duration-500 ease-in-out ${
|
||||
isVisible ? 'scale-100 opacity-100 blur-0' : 'scale-105 opacity-0 blur-sm'
|
||||
}`}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between border-b border-border-medium bg-surface-primary-alt p-2">
|
||||
@@ -76,17 +74,16 @@ export default function Artifacts() {
|
||||
{/* Refresh button */}
|
||||
{activeTab === 'preview' && (
|
||||
<button
|
||||
className={cn(
|
||||
'mr-2 text-text-secondary transition-transform duration-500 ease-in-out',
|
||||
isRefreshing ? 'rotate-180' : '',
|
||||
)}
|
||||
className={`mr-2 text-text-secondary transition-transform duration-500 ease-in-out ${
|
||||
isRefreshing ? 'rotate-180' : ''
|
||||
}`}
|
||||
onClick={handleRefresh}
|
||||
disabled={isRefreshing}
|
||||
aria-label="Refresh"
|
||||
>
|
||||
<RefreshCw
|
||||
size={16}
|
||||
className={cn('transform', isRefreshing ? 'animate-spin' : '')}
|
||||
className={`transform ${isRefreshing ? 'animate-spin' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
import { CodeMarkdown } from './Code';
|
||||
|
||||
export function MermaidMarkdown({
|
||||
content,
|
||||
isSubmitting,
|
||||
}: {
|
||||
content: string;
|
||||
isSubmitting: boolean;
|
||||
}) {
|
||||
return <CodeMarkdown content={`\`\`\`mermaid\n${content}\`\`\``} isSubmitting={isSubmitting} />;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useState, useEffect, useRef } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { Spinner } from '@librechat/client';
|
||||
@@ -13,6 +13,7 @@ import { useGetMessagesByConvoId } from '~/data-provider';
|
||||
import MessagesView from './Messages/MessagesView';
|
||||
import Presentation from './Presentation';
|
||||
import ChatForm from './Input/ChatForm';
|
||||
import CostBar from './CostBar';
|
||||
import Landing from './Landing';
|
||||
import Header from './Header';
|
||||
import Footer from './Footer';
|
||||
@@ -29,7 +30,13 @@ function LoadingSpinner() {
|
||||
);
|
||||
}
|
||||
|
||||
function ChatView({ index = 0 }: { index?: number }) {
|
||||
function ChatView({
|
||||
index = 0,
|
||||
modelCosts,
|
||||
}: {
|
||||
index?: number;
|
||||
modelCosts?: { modelCostTable: Record<string, { prompt: number; completion: number }> };
|
||||
}) {
|
||||
const { conversationId } = useParams();
|
||||
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
|
||||
const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
|
||||
@@ -37,6 +44,9 @@ function ChatView({ index = 0 }: { index?: number }) {
|
||||
|
||||
const fileMap = useFileMapContext();
|
||||
|
||||
const [showCostBar, setShowCostBar] = useState(false);
|
||||
const lastScrollY = useRef(0);
|
||||
|
||||
const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', {
|
||||
select: useCallback(
|
||||
(data: TMessage[]) => {
|
||||
@@ -54,6 +64,58 @@ function ChatView({ index = 0 }: { index?: number }) {
|
||||
useSSE(rootSubmission, chatHelpers, false);
|
||||
useSSE(addedSubmission, addedChatHelpers, true);
|
||||
|
||||
const checkIfAtBottom = useCallback(
|
||||
(container: HTMLElement) => {
|
||||
const currentScrollY = container.scrollTop;
|
||||
const scrollHeight = container.scrollHeight;
|
||||
const clientHeight = container.clientHeight;
|
||||
|
||||
const distanceFromBottom = scrollHeight - currentScrollY - clientHeight;
|
||||
const isAtBottom = distanceFromBottom < 10;
|
||||
|
||||
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
|
||||
setShowCostBar(isAtBottom && !isStreaming);
|
||||
lastScrollY.current = currentScrollY;
|
||||
},
|
||||
[chatHelpers.isSubmitting, addedChatHelpers.isSubmitting],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const handleScroll = (event: Event) => {
|
||||
const target = event.target as HTMLElement;
|
||||
checkIfAtBottom(target);
|
||||
};
|
||||
|
||||
const findAndAttachScrollListener = () => {
|
||||
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
|
||||
if (messagesContainer) {
|
||||
checkIfAtBottom(messagesContainer as HTMLElement);
|
||||
|
||||
messagesContainer.addEventListener('scroll', handleScroll, { passive: true });
|
||||
return () => {
|
||||
messagesContainer.removeEventListener('scroll', handleScroll);
|
||||
};
|
||||
}
|
||||
setTimeout(findAndAttachScrollListener, 100);
|
||||
};
|
||||
|
||||
const cleanup = findAndAttachScrollListener();
|
||||
|
||||
return cleanup;
|
||||
}, [messagesTree, checkIfAtBottom]);
|
||||
|
||||
useEffect(() => {
|
||||
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
|
||||
if (isStreaming) {
|
||||
setShowCostBar(false);
|
||||
} else {
|
||||
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
|
||||
if (messagesContainer) {
|
||||
checkIfAtBottom(messagesContainer as HTMLElement);
|
||||
}
|
||||
}
|
||||
}, [chatHelpers.isSubmitting, addedChatHelpers.isSubmitting, checkIfAtBottom]);
|
||||
|
||||
const methods = useForm<ChatFormValues>({
|
||||
defaultValues: { text: '' },
|
||||
});
|
||||
@@ -69,7 +131,22 @@ function ChatView({ index = 0 }: { index?: number }) {
|
||||
} else if ((isLoading || isNavigating) && !isLandingPage) {
|
||||
content = <LoadingSpinner />;
|
||||
} else if (!isLandingPage) {
|
||||
content = <MessagesView messagesTree={messagesTree} />;
|
||||
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
|
||||
content = (
|
||||
<MessagesView
|
||||
messagesTree={messagesTree}
|
||||
costBar={
|
||||
!isLandingPage &&
|
||||
modelCosts && (
|
||||
<CostBar
|
||||
messagesTree={messagesTree}
|
||||
modelCosts={modelCosts}
|
||||
showCostBar={showCostBar && !isStreaming}
|
||||
/>
|
||||
)
|
||||
}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
content = <Landing centerFormOnLanding={centerFormOnLanding} />;
|
||||
}
|
||||
|
||||
112
client/src/components/Chat/CostBar.tsx
Normal file
112
client/src/components/Chat/CostBar.tsx
Normal file
@@ -0,0 +1,112 @@
|
||||
import { useMemo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { ArrowIcon } from '@librechat/client';
|
||||
import { TModelCosts, TMessage } from 'librechat-data-provider';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
interface CostBarProps {
|
||||
messagesTree: TMessage[];
|
||||
modelCosts: TModelCosts;
|
||||
showCostBar: boolean;
|
||||
}
|
||||
|
||||
export default function CostBar({ messagesTree, modelCosts, showCostBar }: CostBarProps) {
|
||||
const localize = useLocalize();
|
||||
const showCostTracking = useRecoilValue(store.showCostTracking);
|
||||
|
||||
const conversationCosts = useMemo(() => {
|
||||
if (!modelCosts?.modelCostTable || !messagesTree) {
|
||||
return null;
|
||||
}
|
||||
|
||||
let totalPromptTokens = 0;
|
||||
let totalCompletionTokens = 0;
|
||||
let totalPromptUSD = 0;
|
||||
let totalCompletionUSD = 0;
|
||||
|
||||
const flattenMessages = (messages: TMessage[]) => {
|
||||
const flattened: TMessage[] = [];
|
||||
messages.forEach((message: TMessage) => {
|
||||
flattened.push(message);
|
||||
if (message.children && message.children.length > 0) {
|
||||
flattened.push(...flattenMessages(message.children));
|
||||
}
|
||||
});
|
||||
return flattened;
|
||||
};
|
||||
|
||||
const allMessages = flattenMessages(messagesTree);
|
||||
|
||||
allMessages.forEach((message) => {
|
||||
if (!message.tokenCount) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const modelToUse = message.isCreatedByUser ? message.targetModel : message.model;
|
||||
|
||||
const modelPricing = modelCosts.modelCostTable[modelToUse];
|
||||
if (message.isCreatedByUser) {
|
||||
totalPromptTokens += message.tokenCount;
|
||||
totalPromptUSD += (message.tokenCount / 1000000) * modelPricing.prompt;
|
||||
} else {
|
||||
totalCompletionTokens += message.tokenCount;
|
||||
totalCompletionUSD += (message.tokenCount / 1000000) * modelPricing.completion;
|
||||
}
|
||||
});
|
||||
|
||||
const totalTokens = totalPromptTokens + totalCompletionTokens;
|
||||
const totalUSD = totalPromptUSD + totalCompletionUSD;
|
||||
|
||||
return {
|
||||
totals: {
|
||||
prompt: { tokenCount: totalPromptTokens, usd: totalPromptUSD },
|
||||
completion: { tokenCount: totalCompletionTokens, usd: totalCompletionUSD },
|
||||
total: { tokenCount: totalTokens, usd: totalUSD },
|
||||
},
|
||||
};
|
||||
}, [modelCosts, messagesTree]);
|
||||
|
||||
if (!showCostTracking || !conversationCosts || !conversationCosts.totals) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'mx-auto w-full max-w-md px-4 text-xs text-muted-foreground transition-all duration-300 ease-in-out',
|
||||
showCostBar ? 'opacity-100' : 'opacity-0',
|
||||
)}
|
||||
>
|
||||
<div className="grid grid-cols-3 gap-2 text-center">
|
||||
<div>
|
||||
<div>
|
||||
<ArrowIcon direction="up" />
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: conversationCosts.totals.prompt.tokenCount,
|
||||
})}
|
||||
</div>
|
||||
<div>${Math.abs(conversationCosts.totals.prompt.usd).toFixed(6)}</div>
|
||||
</div>
|
||||
<div>
|
||||
<div>
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: conversationCosts.totals.total.tokenCount,
|
||||
})}
|
||||
</div>
|
||||
<div>${Math.abs(conversationCosts.totals.total.usd).toFixed(6)}</div>
|
||||
</div>
|
||||
<div>
|
||||
<div>
|
||||
<ArrowIcon direction="down" />
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: conversationCosts.totals.completion.tokenCount,
|
||||
})}
|
||||
</div>
|
||||
<div>${Math.abs(conversationCosts.totals.completion.usd).toFixed(6)}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useRef, useState, useMemo } from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { useRecoilState } from 'recoil';
|
||||
import { useSetRecoilState } from 'recoil';
|
||||
import { FileSearch, ImageUpIcon, TerminalSquareIcon, FileType2Icon } from 'lucide-react';
|
||||
import { EToolResources, EModelEndpoint, defaultAgentCapabilities } from 'librechat-data-provider';
|
||||
import {
|
||||
@@ -42,9 +42,7 @@ const AttachFileMenu = ({
|
||||
const isUploadDisabled = disabled ?? false;
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [isPopoverActive, setIsPopoverActive] = useState(false);
|
||||
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(
|
||||
ephemeralAgentByConvoId(conversationId),
|
||||
);
|
||||
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId));
|
||||
const [toolResource, setToolResource] = useState<EToolResources | undefined>();
|
||||
const { handleFileChange } = useFileHandling({
|
||||
overrideEndpoint: EModelEndpoint.agents,
|
||||
@@ -66,10 +64,7 @@ const AttachFileMenu = ({
|
||||
* */
|
||||
const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
|
||||
|
||||
const { fileSearchAllowedByAgent, codeAllowedByAgent } = useAgentToolPermissions(
|
||||
agentId,
|
||||
ephemeralAgent,
|
||||
);
|
||||
const { fileSearchAllowedByAgent, codeAllowedByAgent } = useAgentToolPermissions(agentId);
|
||||
|
||||
const handleUploadClick = (isImage?: boolean) => {
|
||||
if (!inputRef.current) {
|
||||
@@ -94,11 +89,11 @@ const AttachFileMenu = ({
|
||||
},
|
||||
];
|
||||
|
||||
if (capabilities.contextEnabled) {
|
||||
if (capabilities.ocrEnabled) {
|
||||
items.push({
|
||||
label: localize('com_ui_upload_ocr_text'),
|
||||
onClick: () => {
|
||||
setToolResource(EToolResources.context);
|
||||
setToolResource(EToolResources.ocr);
|
||||
onAction();
|
||||
},
|
||||
icon: <FileType2Icon className="icon-md" />,
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import React, { useMemo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { OGDialog, OGDialogTemplate } from '@librechat/client';
|
||||
import { EToolResources, defaultAgentCapabilities } from 'librechat-data-provider';
|
||||
import { ImageUpIcon, FileSearch, TerminalSquareIcon, FileType2Icon } from 'lucide-react';
|
||||
import { EToolResources, defaultAgentCapabilities } from 'librechat-data-provider';
|
||||
import {
|
||||
useAgentToolPermissions,
|
||||
useAgentCapabilities,
|
||||
useGetAgentsConfig,
|
||||
useLocalize,
|
||||
} from '~/hooks';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import { useDragDropContext } from '~/Providers';
|
||||
import { useChatContext } from '~/Providers';
|
||||
|
||||
interface DragDropModalProps {
|
||||
onOptionSelect: (option: EToolResources | undefined) => void;
|
||||
@@ -34,11 +32,9 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
|
||||
* Use definition for agents endpoint for ephemeral agents
|
||||
* */
|
||||
const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
|
||||
const { conversationId, agentId } = useDragDropContext();
|
||||
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(conversationId ?? ''));
|
||||
const { conversation } = useChatContext();
|
||||
const { fileSearchAllowedByAgent, codeAllowedByAgent } = useAgentToolPermissions(
|
||||
agentId,
|
||||
ephemeralAgent,
|
||||
conversation?.agent_id,
|
||||
);
|
||||
|
||||
const options = useMemo(() => {
|
||||
@@ -64,10 +60,10 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
|
||||
icon: <TerminalSquareIcon className="icon-md" />,
|
||||
});
|
||||
}
|
||||
if (capabilities.contextEnabled) {
|
||||
if (capabilities.ocrEnabled) {
|
||||
_options.push({
|
||||
label: localize('com_ui_upload_ocr_text'),
|
||||
value: EToolResources.context,
|
||||
value: EToolResources.ocr,
|
||||
icon: <FileType2Icon className="icon-md" />,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { useDragHelpers } from '~/hooks';
|
||||
import DragDropOverlay from '~/components/Chat/Input/Files/DragDropOverlay';
|
||||
import DragDropModal from '~/components/Chat/Input/Files/DragDropModal';
|
||||
import { DragDropProvider } from '~/Providers';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface DragDropWrapperProps {
|
||||
@@ -20,14 +19,12 @@ export default function DragDropWrapper({ children, className }: DragDropWrapper
|
||||
{children}
|
||||
{/** Always render overlay to avoid mount/unmount overhead */}
|
||||
<DragDropOverlay isActive={isActive} />
|
||||
<DragDropProvider>
|
||||
<DragDropModal
|
||||
files={draggedFiles}
|
||||
isVisible={showModal}
|
||||
setShowModal={setShowModal}
|
||||
onOptionSelect={handleOptionSelect}
|
||||
/>
|
||||
</DragDropProvider>
|
||||
<DragDropModal
|
||||
files={draggedFiles}
|
||||
isVisible={showModal}
|
||||
setShowModal={setShowModal}
|
||||
onOptionSelect={handleOptionSelect}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,13 +2,14 @@ import { useState, useRef, useEffect, useMemo, memo, useCallback } from 'react';
|
||||
import { AutoSizer, List } from 'react-virtualized';
|
||||
import { Spinner, useCombobox } from '@librechat/client';
|
||||
import { useSetRecoilState, useRecoilValue } from 'recoil';
|
||||
import { PermissionTypes, Permissions } from 'librechat-data-provider';
|
||||
import type { TPromptGroup } from 'librechat-data-provider';
|
||||
import type { PromptOption } from '~/common';
|
||||
import { removeCharIfLast, detectVariables } from '~/utils';
|
||||
import VariableDialog from '~/components/Prompts/Groups/VariableDialog';
|
||||
import { usePromptGroupsContext } from '~/Providers';
|
||||
import { useLocalize, useHasAccess } from '~/hooks';
|
||||
import MentionItem from './MentionItem';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import store from '~/store';
|
||||
|
||||
const commandChar = '/';
|
||||
@@ -53,7 +54,12 @@ function PromptsCommand({
|
||||
submitPrompt: (textPrompt: string) => void;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const { allPromptGroups, hasAccess } = usePromptGroupsContext();
|
||||
const hasAccess = useHasAccess({
|
||||
permissionType: PermissionTypes.PROMPTS,
|
||||
permission: Permissions.USE,
|
||||
});
|
||||
|
||||
const { allPromptGroups } = usePromptGroupsContext();
|
||||
const { data, isLoading } = allPromptGroups;
|
||||
|
||||
const [activeIndex, setActiveIndex] = useState(0);
|
||||
|
||||
@@ -26,7 +26,6 @@ type ContentPartsProps = {
|
||||
isCreatedByUser: boolean;
|
||||
isLast: boolean;
|
||||
isSubmitting: boolean;
|
||||
isLatestMessage?: boolean;
|
||||
edit?: boolean;
|
||||
enterEdit?: (cancel?: boolean) => void | null | undefined;
|
||||
siblingIdx?: number;
|
||||
@@ -46,7 +45,6 @@ const ContentParts = memo(
|
||||
isCreatedByUser,
|
||||
isLast,
|
||||
isSubmitting,
|
||||
isLatestMessage,
|
||||
edit,
|
||||
enterEdit,
|
||||
siblingIdx,
|
||||
@@ -57,8 +55,6 @@ const ContentParts = memo(
|
||||
const [isExpanded, setIsExpanded] = useState(showThinking);
|
||||
const attachmentMap = useMemo(() => mapAttachments(attachments ?? []), [attachments]);
|
||||
|
||||
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
|
||||
|
||||
const hasReasoningParts = useMemo(() => {
|
||||
const hasThinkPart = content?.some((part) => part?.type === ContentTypes.THINK) ?? false;
|
||||
const allThinkPartsHaveContent =
|
||||
@@ -138,9 +134,7 @@ const ContentParts = memo(
|
||||
})
|
||||
}
|
||||
label={
|
||||
effectiveIsSubmitting && isLast
|
||||
? localize('com_ui_thinking')
|
||||
: localize('com_ui_thoughts')
|
||||
isSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts')
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
@@ -161,14 +155,12 @@ const ContentParts = memo(
|
||||
conversationId,
|
||||
partIndex: idx,
|
||||
nextType: content[idx + 1]?.type,
|
||||
isSubmitting: effectiveIsSubmitting,
|
||||
isLatestMessage,
|
||||
}}
|
||||
>
|
||||
<Part
|
||||
part={part}
|
||||
attachments={attachments}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
isSubmitting={isSubmitting}
|
||||
key={`part-${messageId}-${idx}`}
|
||||
isCreatedByUser={isCreatedByUser}
|
||||
isLast={idx === content.length - 1}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { TextareaAutosize, TooltipAnchor } from '@librechat/client';
|
||||
import { useUpdateMessageMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TEditProps } from '~/common';
|
||||
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
import { cn, removeFocusRings } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import Container from './Container';
|
||||
@@ -22,8 +22,7 @@ const EditMessage = ({
|
||||
const { addedIndex } = useAddedChatContext();
|
||||
const saveButtonRef = useRef<HTMLButtonElement | null>(null);
|
||||
const submitButtonRef = useRef<HTMLButtonElement | null>(null);
|
||||
const { conversation } = useMessagesConversation();
|
||||
const { getMessages, setMessages } = useMessagesOperations();
|
||||
const { getMessages, setMessages, conversation } = useChatContext();
|
||||
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
|
||||
store.latestMessageFamily(addedIndex),
|
||||
);
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { TMessage } from 'librechat-data-provider';
|
||||
import type { TMessageContentProps, TDisplayProps } from '~/common';
|
||||
import Error from '~/components/Messages/Content/Error';
|
||||
import Thinking from '~/components/Artifacts/Thinking';
|
||||
import { useMessageContext } from '~/Providers';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import MarkdownLite from './MarkdownLite';
|
||||
import EditMessage from './EditMessage';
|
||||
import { useLocalize } from '~/hooks';
|
||||
@@ -70,12 +70,16 @@ export const ErrorMessage = ({
|
||||
};
|
||||
|
||||
const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => {
|
||||
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
|
||||
const { isSubmitting, latestMessage } = useChatContext();
|
||||
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
|
||||
const showCursorState = useMemo(
|
||||
() => showCursor === true && isSubmitting,
|
||||
[showCursor, isSubmitting],
|
||||
);
|
||||
const isLatestMessage = useMemo(
|
||||
() => message.messageId === latestMessage?.messageId,
|
||||
[message.messageId, latestMessage?.messageId],
|
||||
);
|
||||
|
||||
let content: React.ReactElement;
|
||||
if (!isCreatedByUser) {
|
||||
|
||||
@@ -85,14 +85,13 @@ const Part = memo(
|
||||
|
||||
const isToolCall =
|
||||
'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL);
|
||||
if (isToolCall && toolCall.name === Tools.execute_code) {
|
||||
if (isToolCall && toolCall.name === Tools.execute_code && toolCall.args) {
|
||||
return (
|
||||
<ExecuteCode
|
||||
attachments={attachments}
|
||||
isSubmitting={isSubmitting}
|
||||
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
|
||||
output={toolCall.output ?? ''}
|
||||
initialProgress={toolCall.progress ?? 0.1}
|
||||
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
|
||||
attachments={attachments}
|
||||
/>
|
||||
);
|
||||
} else if (
|
||||
|
||||
@@ -70,7 +70,7 @@ const ImageAttachment = memo(({ attachment }: { attachment: TAttachment }) => {
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
altText={attachment.filename || 'attachment image'}
|
||||
altText={attachment.filename}
|
||||
imagePath={filepath ?? ''}
|
||||
height={height ?? 0}
|
||||
width={width ?? 0}
|
||||
@@ -89,9 +89,8 @@ export default function Attachment({ attachment }: { attachment?: TAttachment })
|
||||
}
|
||||
|
||||
const { width, height, filepath = null } = attachment as TFile & TAttachmentMetadata;
|
||||
const isImage = attachment.filename
|
||||
? imageExtRegex.test(attachment.filename) && width != null && height != null && filepath != null
|
||||
: false;
|
||||
const isImage =
|
||||
imageExtRegex.test(attachment.filename) && width != null && height != null && filepath != null;
|
||||
|
||||
if (isImage) {
|
||||
return <ImageAttachment attachment={attachment} />;
|
||||
@@ -111,12 +110,11 @@ export function AttachmentGroup({ attachments }: { attachments?: TAttachment[] }
|
||||
|
||||
attachments.forEach((attachment) => {
|
||||
const { width, height, filepath = null } = attachment as TFile & TAttachmentMetadata;
|
||||
const isImage = attachment.filename
|
||||
? imageExtRegex.test(attachment.filename) &&
|
||||
width != null &&
|
||||
height != null &&
|
||||
filepath != null
|
||||
: false;
|
||||
const isImage =
|
||||
imageExtRegex.test(attachment.filename) &&
|
||||
width != null &&
|
||||
height != null &&
|
||||
filepath != null;
|
||||
|
||||
if (isImage) {
|
||||
imageAttachments.push(attachment);
|
||||
|
||||
@@ -6,8 +6,8 @@ import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query';
|
||||
import type { Agents } from 'librechat-data-provider';
|
||||
import type { TEditProps } from '~/common';
|
||||
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
|
||||
import Container from '~/components/Chat/Messages/Content/Container';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
import { cn, removeFocusRings } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import store from '~/store';
|
||||
@@ -25,8 +25,7 @@ const EditTextPart = ({
|
||||
}) => {
|
||||
const localize = useLocalize();
|
||||
const { addedIndex } = useAddedChatContext();
|
||||
const { conversation } = useMessagesConversation();
|
||||
const { ask, getMessages, setMessages } = useMessagesOperations();
|
||||
const { ask, getMessages, setMessages, conversation } = useChatContext();
|
||||
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
|
||||
store.latestMessageFamily(addedIndex),
|
||||
);
|
||||
|
||||
@@ -45,28 +45,26 @@ export function useParseArgs(args?: string): ParsedArgs | null {
|
||||
}
|
||||
|
||||
export default function ExecuteCode({
|
||||
isSubmitting,
|
||||
initialProgress = 0.1,
|
||||
args,
|
||||
output = '',
|
||||
attachments,
|
||||
}: {
|
||||
initialProgress: number;
|
||||
isSubmitting: boolean;
|
||||
args?: string;
|
||||
output?: string;
|
||||
attachments?: TAttachment[];
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const hasOutput = output.length > 0;
|
||||
const outputRef = useRef<string>(output);
|
||||
const codeContentRef = useRef<HTMLDivElement>(null);
|
||||
const [isAnimating, setIsAnimating] = useState(false);
|
||||
const showAnalysisCode = useRecoilValue(store.showCode);
|
||||
const [showCode, setShowCode] = useState(showAnalysisCode);
|
||||
const codeContentRef = useRef<HTMLDivElement>(null);
|
||||
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
|
||||
|
||||
const [isAnimating, setIsAnimating] = useState(false);
|
||||
const hasOutput = output.length > 0;
|
||||
const outputRef = useRef<string>(output);
|
||||
const prevShowCodeRef = useRef<boolean>(showCode);
|
||||
|
||||
const { lang, code } = useParseArgs(args) ?? ({} as ParsedArgs);
|
||||
const progress = useProgress(initialProgress);
|
||||
|
||||
@@ -138,8 +136,6 @@ export default function ExecuteCode({
|
||||
};
|
||||
}, [showCode, isAnimating]);
|
||||
|
||||
const cancelled = !isSubmitting && progress < 1;
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="relative my-2.5 flex size-5 shrink-0 items-center gap-2.5">
|
||||
@@ -147,12 +143,9 @@ export default function ExecuteCode({
|
||||
progress={progress}
|
||||
onClick={() => setShowCode((prev) => !prev)}
|
||||
inProgressText={localize('com_ui_analyzing')}
|
||||
finishedText={
|
||||
cancelled ? localize('com_ui_cancelled') : localize('com_ui_analyzing_finished')
|
||||
}
|
||||
finishedText={localize('com_ui_analyzing_finished')}
|
||||
hasInput={!!code?.length}
|
||||
isExpanded={showCode}
|
||||
error={cancelled}
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
|
||||
@@ -2,7 +2,7 @@ import { memo, useMemo, ReactElement } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import MarkdownLite from '~/components/Chat/Messages/Content/MarkdownLite';
|
||||
import Markdown from '~/components/Chat/Messages/Content/Markdown';
|
||||
import { useMessageContext } from '~/Providers';
|
||||
import { useChatContext, useMessageContext } from '~/Providers';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
@@ -18,9 +18,14 @@ type ContentType =
|
||||
| ReactElement;
|
||||
|
||||
const TextPart = memo(({ text, isCreatedByUser, showCursor }: TextPartProps) => {
|
||||
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
|
||||
const { messageId } = useMessageContext();
|
||||
const { isSubmitting, latestMessage } = useChatContext();
|
||||
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
|
||||
const showCursorState = useMemo(() => showCursor && isSubmitting, [showCursor, isSubmitting]);
|
||||
const isLatestMessage = useMemo(
|
||||
() => messageId === latestMessage?.messageId,
|
||||
[messageId, latestMessage?.messageId],
|
||||
);
|
||||
|
||||
const content: ContentType = useMemo(() => {
|
||||
if (!isCreatedByUser) {
|
||||
|
||||
@@ -21,7 +21,7 @@ type THoverButtons = {
|
||||
latestMessage: TMessage | null;
|
||||
isLast: boolean;
|
||||
index: number;
|
||||
handleFeedback?: ({ feedback }: { feedback: TFeedback | undefined }) => void;
|
||||
handleFeedback: ({ feedback }: { feedback: TFeedback | undefined }) => void;
|
||||
};
|
||||
|
||||
type HoverButtonProps = {
|
||||
@@ -238,7 +238,7 @@ const HoverButtons = ({
|
||||
/>
|
||||
|
||||
{/* Feedback Buttons */}
|
||||
{!isCreatedByUser && handleFeedback != null && (
|
||||
{!isCreatedByUser && (
|
||||
<Feedback handleFeedback={handleFeedback} feedback={message.feedback} isLast={isLast} />
|
||||
)}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import React from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { useMessageProcess } from '~/hooks';
|
||||
import type { TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessageProps } from '~/common';
|
||||
import MessageRender from './ui/MessageRender';
|
||||
// eslint-disable-next-line import/no-cycle
|
||||
@@ -28,7 +29,7 @@ const MessageContainer = React.memo(
|
||||
},
|
||||
);
|
||||
|
||||
export default function Message(props: TMessageProps) {
|
||||
export default function Message(props: TMessageProps & { costs?: TConversationCosts }) {
|
||||
const {
|
||||
showSibling,
|
||||
conversation,
|
||||
@@ -37,7 +38,7 @@ export default function Message(props: TMessageProps) {
|
||||
latestMultiMessage,
|
||||
isSubmittingFamily,
|
||||
} = useMessageProcess({ message: props.message });
|
||||
const { message, currentEditId, setCurrentEditId } = props;
|
||||
const { message, currentEditId, setCurrentEditId, costs } = props;
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
|
||||
if (!message || typeof message !== 'object') {
|
||||
@@ -62,6 +63,7 @@ export default function Message(props: TMessageProps) {
|
||||
message={message}
|
||||
isSubmittingFamily={isSubmittingFamily}
|
||||
isCard
|
||||
costs={costs}
|
||||
/>
|
||||
<MessageRender
|
||||
{...props}
|
||||
@@ -69,12 +71,13 @@ export default function Message(props: TMessageProps) {
|
||||
isCard
|
||||
message={siblingMessage ?? latestMultiMessage ?? undefined}
|
||||
isSubmittingFamily={isSubmittingFamily}
|
||||
costs={costs}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6">
|
||||
<MessageRender {...props} />
|
||||
<MessageRender {...props} costs={costs} />
|
||||
</div>
|
||||
)}
|
||||
</MessageContainer>
|
||||
@@ -85,6 +88,7 @@ export default function Message(props: TMessageProps) {
|
||||
messagesTree={children ?? []}
|
||||
currentEditId={currentEditId}
|
||||
setCurrentEditId={setCurrentEditId}
|
||||
costs={costs}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useMemo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import type { TMessageContentParts } from 'librechat-data-provider';
|
||||
import type { TMessageContentParts, TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessageProps, TMessageIcon } from '~/common';
|
||||
import { useMessageHelpers, useLocalize, useAttachments } from '~/hooks';
|
||||
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
|
||||
@@ -12,10 +12,17 @@ import SubRow from './SubRow';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
export default function Message(props: TMessageProps) {
|
||||
export default function Message(props: TMessageProps & { costs?: TConversationCosts }) {
|
||||
const localize = useLocalize();
|
||||
const { message, siblingIdx, siblingCount, setSiblingIdx, currentEditId, setCurrentEditId } =
|
||||
props;
|
||||
const {
|
||||
message,
|
||||
siblingIdx,
|
||||
siblingCount,
|
||||
setSiblingIdx,
|
||||
currentEditId,
|
||||
setCurrentEditId,
|
||||
costs,
|
||||
} = props;
|
||||
const { attachments, searchResults } = useAttachments({
|
||||
messageId: message?.messageId,
|
||||
attachments: message?.attachments,
|
||||
@@ -125,7 +132,6 @@ export default function Message(props: TMessageProps) {
|
||||
setSiblingIdx={setSiblingIdx}
|
||||
isCreatedByUser={message.isCreatedByUser}
|
||||
conversationId={conversation?.conversationId}
|
||||
isLatestMessage={messageId === latestMessage?.messageId}
|
||||
content={message.content as Array<TMessageContentParts | undefined>}
|
||||
/>
|
||||
</div>
|
||||
@@ -165,6 +171,7 @@ export default function Message(props: TMessageProps) {
|
||||
messagesTree={children ?? []}
|
||||
currentEditId={currentEditId}
|
||||
setCurrentEditId={setCurrentEditId}
|
||||
costs={costs}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
import { useState } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { CSSTransition } from 'react-transition-group';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import type { TMessage, TConversationCosts } from 'librechat-data-provider';
|
||||
import { useScreenshot, useMessageScrolling, useLocalize } from '~/hooks';
|
||||
import ScrollToBottom from '~/components/Messages/ScrollToBottom';
|
||||
import { MessagesViewProvider } from '~/Providers';
|
||||
import MultiMessage from './MultiMessage';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
function MessagesViewContent({
|
||||
export default function MessagesView({
|
||||
messagesTree: _messagesTree,
|
||||
costBar,
|
||||
costs,
|
||||
}: {
|
||||
messagesTree?: TMessage[] | null;
|
||||
costBar?: React.ReactNode;
|
||||
costs?: TConversationCosts;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const fontSize = useRecoilValue(store.fontSize);
|
||||
@@ -45,7 +48,7 @@ function MessagesViewContent({
|
||||
width: '100%',
|
||||
}}
|
||||
>
|
||||
<div className="flex flex-col pb-9 dark:bg-transparent">
|
||||
<div className="flex flex-col dark:bg-transparent">
|
||||
{(_messagesTree && _messagesTree.length == 0) || _messagesTree === null ? (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -64,18 +67,25 @@ function MessagesViewContent({
|
||||
messageId={conversationId ?? null}
|
||||
setCurrentEditId={setCurrentEditId}
|
||||
currentEditId={currentEditId ?? null}
|
||||
costs={costs}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<div
|
||||
id="messages-end"
|
||||
className="group h-0 w-full flex-shrink-0"
|
||||
className="group h-1 w-full flex-shrink-0 pb-7"
|
||||
ref={messagesEndRef}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{costBar && (
|
||||
<div className="pointer-events-none absolute bottom-2 left-1/2 z-10 -translate-x-1/2">
|
||||
{costBar}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<CSSTransition
|
||||
in={showScrollButton && scrollButtonPreference}
|
||||
timeout={{
|
||||
@@ -93,11 +103,3 @@ function MessagesViewContent({
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function MessagesView({ messagesTree }: { messagesTree?: TMessage[] | null }) {
|
||||
return (
|
||||
<MessagesViewProvider>
|
||||
<MessagesViewContent messagesTree={messagesTree} />
|
||||
</MessagesViewProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useRecoilState } from 'recoil';
|
||||
import { useEffect, useCallback } from 'react';
|
||||
import { isAssistantsEndpoint } from 'librechat-data-provider';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import type { TMessage, TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessageProps } from '~/common';
|
||||
import MessageContent from '~/components/Messages/MessageContent';
|
||||
import MessageParts from './MessageParts';
|
||||
@@ -14,7 +14,8 @@ export default function MultiMessage({
|
||||
messagesTree,
|
||||
currentEditId,
|
||||
setCurrentEditId,
|
||||
}: TMessageProps) {
|
||||
costs,
|
||||
}: TMessageProps & { costs?: TConversationCosts }) {
|
||||
const [siblingIdx, setSiblingIdx] = useRecoilState(store.messagesSiblingIdxFamily(messageId));
|
||||
|
||||
const setSiblingIdxRev = useCallback(
|
||||
@@ -27,7 +28,7 @@ export default function MultiMessage({
|
||||
useEffect(() => {
|
||||
// reset siblingIdx when the tree changes, mostly when a new message is submitting.
|
||||
setSiblingIdx(0);
|
||||
}, [messagesTree?.length, setSiblingIdx]);
|
||||
}, [messagesTree?.length]);
|
||||
|
||||
useEffect(() => {
|
||||
if (messagesTree?.length && siblingIdx >= messagesTree.length) {
|
||||
@@ -55,6 +56,7 @@ export default function MultiMessage({
|
||||
siblingIdx={messagesTree.length - siblingIdx - 1}
|
||||
siblingCount={messagesTree.length}
|
||||
setSiblingIdx={setSiblingIdxRev}
|
||||
costs={costs}
|
||||
/>
|
||||
);
|
||||
} else if (message.content) {
|
||||
@@ -67,6 +69,7 @@ export default function MultiMessage({
|
||||
siblingIdx={messagesTree.length - siblingIdx - 1}
|
||||
siblingCount={messagesTree.length}
|
||||
setSiblingIdx={setSiblingIdxRev}
|
||||
costs={costs}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -80,6 +83,7 @@ export default function MultiMessage({
|
||||
siblingIdx={messagesTree.length - siblingIdx - 1}
|
||||
siblingCount={messagesTree.length}
|
||||
setSiblingIdx={setSiblingIdxRev}
|
||||
costs={costs}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import React, { useCallback, useMemo, memo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { type TMessage } from 'librechat-data-provider';
|
||||
import { ArrowIcon } from '@librechat/client';
|
||||
import { type TMessage, TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessageProps, TMessageIcon } from '~/common';
|
||||
import MessageContent from '~/components/Chat/Messages/Content/MessageContent';
|
||||
import PlaceholderRow from '~/components/Chat/Messages/ui/PlaceholderRow';
|
||||
import SiblingSwitch from '~/components/Chat/Messages/SiblingSwitch';
|
||||
import HoverButtons from '~/components/Chat/Messages/HoverButtons';
|
||||
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
|
||||
import { useMessageActions, useLocalize } from '~/hooks';
|
||||
import { Plugin } from '~/components/Messages/Content';
|
||||
import SubRow from '~/components/Chat/Messages/SubRow';
|
||||
import { MessageContext } from '~/Providers';
|
||||
import { useMessageActions } from '~/hooks';
|
||||
import { cn, logger } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
@@ -19,6 +20,7 @@ type MessageRenderProps = {
|
||||
isCard?: boolean;
|
||||
isMultiMessage?: boolean;
|
||||
isSubmittingFamily?: boolean;
|
||||
costs?: TConversationCosts;
|
||||
} & Pick<
|
||||
TMessageProps,
|
||||
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
|
||||
@@ -35,7 +37,9 @@ const MessageRender = memo(
|
||||
isMultiMessage = false,
|
||||
setCurrentEditId,
|
||||
isSubmittingFamily = false,
|
||||
costs,
|
||||
}: MessageRenderProps) => {
|
||||
const localize = useLocalize();
|
||||
const {
|
||||
ask,
|
||||
edit,
|
||||
@@ -60,6 +64,18 @@ const MessageRender = memo(
|
||||
});
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
const fontSize = useRecoilValue(store.fontSize);
|
||||
const showCostTracking = useRecoilValue(store.showCostTracking);
|
||||
|
||||
const perMessageCost = useMemo(() => {
|
||||
if (!showCostTracking || !costs || !costs.perMessage || !msg?.messageId) {
|
||||
return null;
|
||||
}
|
||||
const entry = costs.perMessage.find((p) => p.messageId === msg.messageId);
|
||||
if (!entry) {
|
||||
return null;
|
||||
}
|
||||
return entry;
|
||||
}, [showCostTracking, costs, msg?.messageId]);
|
||||
|
||||
const handleRegenerateMessage = useCallback(() => regenerateMessage(), [regenerateMessage]);
|
||||
const hasNoChildren = !(msg?.children?.length ?? 0);
|
||||
@@ -71,9 +87,6 @@ const MessageRender = memo(
|
||||
const showCardRender = isLast && !isSubmittingFamily && isCard;
|
||||
const isLatestCard = isCard && !isSubmittingFamily && isLatestMessage;
|
||||
|
||||
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
|
||||
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
|
||||
|
||||
const iconData: TMessageIcon = useMemo(
|
||||
() => ({
|
||||
endpoint: msg?.endpoint ?? conversation?.endpoint,
|
||||
@@ -160,7 +173,26 @@ const MessageRender = memo(
|
||||
msg.isCreatedByUser ? 'user-turn' : 'agent-turn',
|
||||
)}
|
||||
>
|
||||
<h2 className={cn('select-none font-semibold', fontSize)}>{messageLabel}</h2>
|
||||
<h2 className={cn('select-none font-semibold', fontSize)}>
|
||||
{messageLabel}
|
||||
{perMessageCost && (
|
||||
<span className="ml-2 inline-flex items-center gap-2 px-2 py-0.5 text-xs text-muted-foreground">
|
||||
{perMessageCost.tokenCount > 0 && (
|
||||
<span>
|
||||
{perMessageCost.tokenType === 'prompt' ? (
|
||||
<ArrowIcon direction="up" className="inline" />
|
||||
) : (
|
||||
<ArrowIcon direction="down" className="inline" />
|
||||
)}
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: perMessageCost.tokenCount,
|
||||
})}
|
||||
</span>
|
||||
)}
|
||||
<span className="whitespace-pre">${Math.abs(perMessageCost.usd).toFixed(6)}</span>
|
||||
</span>
|
||||
)}
|
||||
</h2>
|
||||
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex max-w-full flex-grow flex-col gap-0">
|
||||
@@ -169,8 +201,6 @@ const MessageRender = memo(
|
||||
messageId: msg.messageId,
|
||||
conversationId: conversation?.conversationId,
|
||||
isExpanded: false,
|
||||
isSubmitting: effectiveIsSubmitting,
|
||||
isLatestMessage,
|
||||
}}
|
||||
>
|
||||
{msg.plugin && <Plugin plugin={msg.plugin} />}
|
||||
@@ -182,7 +212,7 @@ const MessageRender = memo(
|
||||
message={msg}
|
||||
enterEdit={enterEdit}
|
||||
error={!!(msg.error ?? false)}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
isSubmitting={isSubmitting}
|
||||
unfinished={msg.unfinished ?? false}
|
||||
isCreatedByUser={msg.isCreatedByUser ?? true}
|
||||
siblingIdx={siblingIdx ?? 0}
|
||||
@@ -191,7 +221,7 @@ const MessageRender = memo(
|
||||
</MessageContext.Provider>
|
||||
</div>
|
||||
|
||||
{hasNoChildren && (isSubmittingFamily === true || effectiveIsSubmitting) ? (
|
||||
{hasNoChildren && (isSubmittingFamily === true || isSubmitting) ? (
|
||||
<PlaceholderRow isCard={isCard} />
|
||||
) : (
|
||||
<SubRow classes="text-xs">
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useMemo, memo, type FC, useCallback } from 'react';
|
||||
import throttle from 'lodash/throttle';
|
||||
import { parseISO, isToday } from 'date-fns';
|
||||
import { Spinner, useMediaQuery } from '@librechat/client';
|
||||
import { List, AutoSizer, CellMeasurer, CellMeasurerCache } from 'react-virtualized';
|
||||
import { TConversation } from 'librechat-data-provider';
|
||||
@@ -49,17 +50,27 @@ const MemoizedConvo = memo(
|
||||
conversation,
|
||||
retainView,
|
||||
toggleNav,
|
||||
isLatestConvo,
|
||||
}: {
|
||||
conversation: TConversation;
|
||||
retainView: () => void;
|
||||
toggleNav: () => void;
|
||||
isLatestConvo: boolean;
|
||||
}) => {
|
||||
return <Convo conversation={conversation} retainView={retainView} toggleNav={toggleNav} />;
|
||||
return (
|
||||
<Convo
|
||||
conversation={conversation}
|
||||
retainView={retainView}
|
||||
toggleNav={toggleNav}
|
||||
isLatestConvo={isLatestConvo}
|
||||
/>
|
||||
);
|
||||
},
|
||||
(prevProps, nextProps) => {
|
||||
return (
|
||||
prevProps.conversation.conversationId === nextProps.conversation.conversationId &&
|
||||
prevProps.conversation.title === nextProps.conversation.title &&
|
||||
prevProps.isLatestConvo === nextProps.isLatestConvo &&
|
||||
prevProps.conversation.endpoint === nextProps.conversation.endpoint
|
||||
);
|
||||
},
|
||||
@@ -87,6 +98,13 @@ const Conversations: FC<ConversationsProps> = ({
|
||||
[filteredConversations],
|
||||
);
|
||||
|
||||
const firstTodayConvoId = useMemo(
|
||||
() =>
|
||||
filteredConversations.find((convo) => convo.updatedAt && isToday(parseISO(convo.updatedAt)))
|
||||
?.conversationId ?? undefined,
|
||||
[filteredConversations],
|
||||
);
|
||||
|
||||
const flattenedItems = useMemo(() => {
|
||||
const items: FlattenedItem[] = [];
|
||||
groupedConversations.forEach(([groupName, convos]) => {
|
||||
@@ -136,25 +154,26 @@ const Conversations: FC<ConversationsProps> = ({
|
||||
</CellMeasurer>
|
||||
);
|
||||
}
|
||||
let rendering: JSX.Element;
|
||||
if (item.type === 'header') {
|
||||
rendering = <DateLabel groupName={item.groupName} />;
|
||||
} else if (item.type === 'convo') {
|
||||
rendering = (
|
||||
<MemoizedConvo conversation={item.convo} retainView={moveToTop} toggleNav={toggleNav} />
|
||||
);
|
||||
}
|
||||
return (
|
||||
<CellMeasurer cache={cache} columnIndex={0} key={key} parent={parent} rowIndex={index}>
|
||||
{({ registerChild }) => (
|
||||
<div ref={registerChild} style={style}>
|
||||
{rendering}
|
||||
{item.type === 'header' ? (
|
||||
<DateLabel groupName={item.groupName} />
|
||||
) : item.type === 'convo' ? (
|
||||
<MemoizedConvo
|
||||
conversation={item.convo}
|
||||
retainView={moveToTop}
|
||||
toggleNav={toggleNav}
|
||||
isLatestConvo={item.convo.conversationId === firstTodayConvoId}
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
)}
|
||||
</CellMeasurer>
|
||||
);
|
||||
},
|
||||
[cache, flattenedItems, moveToTop, toggleNav],
|
||||
[cache, flattenedItems, firstTodayConvoId, moveToTop, toggleNav],
|
||||
);
|
||||
|
||||
const getRowHeight = useCallback(
|
||||
|
||||
@@ -11,17 +11,23 @@ import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { NotificationSeverity } from '~/common';
|
||||
import { ConvoOptions } from './ConvoOptions';
|
||||
import RenameForm from './RenameForm';
|
||||
import { cn, logger } from '~/utils';
|
||||
import ConvoLink from './ConvoLink';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
interface ConversationProps {
|
||||
conversation: TConversation;
|
||||
retainView: () => void;
|
||||
toggleNav: () => void;
|
||||
isLatestConvo: boolean;
|
||||
}
|
||||
|
||||
export default function Conversation({ conversation, retainView, toggleNav }: ConversationProps) {
|
||||
export default function Conversation({
|
||||
conversation,
|
||||
retainView,
|
||||
toggleNav,
|
||||
isLatestConvo,
|
||||
}: ConversationProps) {
|
||||
const params = useParams();
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
@@ -78,7 +84,6 @@ export default function Conversation({ conversation, retainView, toggleNav }: Co
|
||||
});
|
||||
setRenaming(false);
|
||||
} catch (error) {
|
||||
logger.error('Error renaming conversation', error);
|
||||
setTitleInput(title as string);
|
||||
showToast({
|
||||
message: localize('com_ui_rename_failed'),
|
||||
|
||||
@@ -66,7 +66,7 @@ function AuthField({ name, config, hasValue, control, errors }: AuthFieldProps)
|
||||
placeholder={
|
||||
hasValue
|
||||
? localize('com_ui_mcp_update_var', { 0: config.title })
|
||||
: `${localize('com_ui_mcp_enter_var', { 0: config.title })} ${localize('com_ui_optional')}`
|
||||
: localize('com_ui_mcp_enter_var', { 0: config.title })
|
||||
}
|
||||
className="w-full rounded border border-border-medium bg-transparent px-2 py-1 text-text-primary placeholder:text-text-secondary focus:outline-none sm:text-sm"
|
||||
/>
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { ArrowIcon } from '@librechat/client';
|
||||
import { useCallback, useMemo, memo } from 'react';
|
||||
import type { TMessage, TMessageContentParts } from 'librechat-data-provider';
|
||||
import type { TMessage, TMessageContentParts, TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessageProps, TMessageIcon } from '~/common';
|
||||
import ContentParts from '~/components/Chat/Messages/Content/ContentParts';
|
||||
import PlaceholderRow from '~/components/Chat/Messages/ui/PlaceholderRow';
|
||||
import { useAttachments, useMessageActions, useLocalize } from '~/hooks';
|
||||
import SiblingSwitch from '~/components/Chat/Messages/SiblingSwitch';
|
||||
import HoverButtons from '~/components/Chat/Messages/HoverButtons';
|
||||
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
|
||||
import { useAttachments, useMessageActions } from '~/hooks';
|
||||
import SubRow from '~/components/Chat/Messages/SubRow';
|
||||
import { cn, logger } from '~/utils';
|
||||
import store from '~/store';
|
||||
@@ -17,6 +18,7 @@ type ContentRenderProps = {
|
||||
isCard?: boolean;
|
||||
isMultiMessage?: boolean;
|
||||
isSubmittingFamily?: boolean;
|
||||
costs?: TConversationCosts;
|
||||
} & Pick<
|
||||
TMessageProps,
|
||||
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
|
||||
@@ -33,7 +35,9 @@ const ContentRender = memo(
|
||||
isMultiMessage = false,
|
||||
setCurrentEditId,
|
||||
isSubmittingFamily = false,
|
||||
costs,
|
||||
}: ContentRenderProps) => {
|
||||
const localize = useLocalize();
|
||||
const { attachments, searchResults } = useAttachments({
|
||||
messageId: msg?.messageId,
|
||||
attachments: msg?.attachments,
|
||||
@@ -62,6 +66,14 @@ const ContentRender = memo(
|
||||
});
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
const fontSize = useRecoilValue(store.fontSize);
|
||||
const showCostTracking = useRecoilValue(store.showCostTracking);
|
||||
|
||||
const perMessageCost = useMemo(() => {
|
||||
if (!showCostTracking || !costs || !costs.perMessage || !msg?.messageId) {
|
||||
return null;
|
||||
}
|
||||
return costs.perMessage.find((p) => p.messageId === msg.messageId) ?? null;
|
||||
}, [showCostTracking, costs, msg?.messageId]);
|
||||
|
||||
const handleRegenerateMessage = useCallback(() => regenerateMessage(), [regenerateMessage]);
|
||||
const isLast = useMemo(
|
||||
@@ -159,7 +171,26 @@ const ContentRender = memo(
|
||||
msg.isCreatedByUser ? 'user-turn' : 'agent-turn',
|
||||
)}
|
||||
>
|
||||
<h2 className={cn('select-none font-semibold', fontSize)}>{messageLabel}</h2>
|
||||
<h2 className={cn('select-none font-semibold', fontSize)}>
|
||||
{messageLabel}
|
||||
{perMessageCost && (
|
||||
<span className="ml-2 inline-flex items-center gap-2 px-2 py-0.5 text-xs text-muted-foreground">
|
||||
{perMessageCost.tokenCount > 0 && (
|
||||
<span className="mr-2">
|
||||
{perMessageCost.tokenType === 'prompt' ? (
|
||||
<ArrowIcon direction="up" className="inline" />
|
||||
) : (
|
||||
<ArrowIcon direction="down" className="inline" />
|
||||
)}
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: perMessageCost.tokenCount,
|
||||
})}
|
||||
</span>
|
||||
)}
|
||||
<span className="whitespace-pre">${Math.abs(perMessageCost.usd).toFixed(6)}</span>
|
||||
</span>
|
||||
)}
|
||||
</h2>
|
||||
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex max-w-full flex-grow flex-col gap-0">
|
||||
@@ -173,7 +204,6 @@ const ContentRender = memo(
|
||||
isSubmitting={isSubmitting}
|
||||
searchResults={searchResults}
|
||||
setSiblingIdx={setSiblingIdx}
|
||||
isLatestMessage={isLatestMessage}
|
||||
isCreatedByUser={msg.isCreatedByUser}
|
||||
conversationId={conversation?.conversationId}
|
||||
content={msg.content as Array<TMessageContentParts | undefined>}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import React from 'react';
|
||||
import { useMessageProcess } from '~/hooks';
|
||||
import type { TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessageProps } from '~/common';
|
||||
// eslint-disable-next-line import/no-cycle
|
||||
import MultiMessage from '~/components/Chat/Messages/MultiMessage';
|
||||
@@ -25,7 +26,7 @@ const MessageContainer = React.memo(
|
||||
},
|
||||
);
|
||||
|
||||
export default function MessageContent(props: TMessageProps) {
|
||||
export default function MessageContent(props: TMessageProps & { costs?: TConversationCosts }) {
|
||||
const {
|
||||
showSibling,
|
||||
conversation,
|
||||
@@ -34,7 +35,7 @@ export default function MessageContent(props: TMessageProps) {
|
||||
latestMultiMessage,
|
||||
isSubmittingFamily,
|
||||
} = useMessageProcess({ message: props.message });
|
||||
const { message, currentEditId, setCurrentEditId } = props;
|
||||
const { message, currentEditId, setCurrentEditId, costs } = props;
|
||||
|
||||
if (!message || typeof message !== 'object') {
|
||||
return null;
|
||||
@@ -53,6 +54,7 @@ export default function MessageContent(props: TMessageProps) {
|
||||
message={message}
|
||||
isSubmittingFamily={isSubmittingFamily}
|
||||
isCard
|
||||
costs={costs}
|
||||
/>
|
||||
<ContentRender
|
||||
{...props}
|
||||
@@ -60,12 +62,13 @@ export default function MessageContent(props: TMessageProps) {
|
||||
isCard
|
||||
message={siblingMessage ?? latestMultiMessage ?? undefined}
|
||||
isSubmittingFamily={isSubmittingFamily}
|
||||
costs={costs}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6 ">
|
||||
<ContentRender {...props} />
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6">
|
||||
<ContentRender {...props} costs={costs} />
|
||||
</div>
|
||||
)}
|
||||
</MessageContainer>
|
||||
@@ -76,6 +79,7 @@ export default function MessageContent(props: TMessageProps) {
|
||||
messagesTree={children ?? []}
|
||||
currentEditId={currentEditId}
|
||||
setCurrentEditId={setCurrentEditId}
|
||||
costs={costs}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -24,45 +24,35 @@ const SearchBar = forwardRef((props: SearchBarProps, ref: React.Ref<HTMLDivEleme
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [showClearIcon, setShowClearIcon] = useState(false);
|
||||
|
||||
const { newConversation: newConvo } = useNewConvo();
|
||||
const { newConversation } = useNewConvo();
|
||||
const [search, setSearchState] = useRecoilState(store.search);
|
||||
|
||||
const clearSearch = useCallback(
|
||||
(pathname?: string) => {
|
||||
if (pathname?.includes('/search') || pathname === '/c/new') {
|
||||
queryClient.removeQueries([QueryKeys.messages]);
|
||||
newConvo({ disableFocus: true });
|
||||
navigate('/c/new');
|
||||
}
|
||||
},
|
||||
[newConvo, navigate, queryClient],
|
||||
);
|
||||
const clearSearch = useCallback(() => {
|
||||
if (location.pathname.includes('/search')) {
|
||||
newConversation({ disableFocus: true });
|
||||
navigate('/c/new', { replace: true });
|
||||
}
|
||||
}, [newConversation, location.pathname, navigate]);
|
||||
|
||||
const clearText = useCallback(
|
||||
(pathname?: string) => {
|
||||
setShowClearIcon(false);
|
||||
setText('');
|
||||
setSearchState((prev) => ({
|
||||
...prev,
|
||||
query: '',
|
||||
debouncedQuery: '',
|
||||
isTyping: false,
|
||||
}));
|
||||
clearSearch(pathname);
|
||||
inputRef.current?.focus();
|
||||
},
|
||||
[setSearchState, clearSearch],
|
||||
);
|
||||
const clearText = useCallback(() => {
|
||||
setShowClearIcon(false);
|
||||
setText('');
|
||||
setSearchState((prev) => ({
|
||||
...prev,
|
||||
query: '',
|
||||
debouncedQuery: '',
|
||||
isTyping: false,
|
||||
}));
|
||||
clearSearch();
|
||||
inputRef.current?.focus();
|
||||
}, [setSearchState, clearSearch]);
|
||||
|
||||
const handleKeyUp = useCallback(
|
||||
(e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
const { value } = e.target as HTMLInputElement;
|
||||
if (e.key === 'Backspace' && value === '') {
|
||||
clearText(location.pathname);
|
||||
}
|
||||
},
|
||||
[clearText, location.pathname],
|
||||
);
|
||||
const handleKeyUp = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
const { value } = e.target as HTMLInputElement;
|
||||
if (e.key === 'Backspace' && value === '') {
|
||||
clearText();
|
||||
}
|
||||
};
|
||||
|
||||
const sendRequest = useCallback(
|
||||
(value: string) => {
|
||||
@@ -95,6 +85,8 @@ const SearchBar = forwardRef((props: SearchBarProps, ref: React.Ref<HTMLDivEleme
|
||||
debouncedSetDebouncedQuery(value);
|
||||
if (value.length > 0 && location.pathname !== '/search') {
|
||||
navigate('/search', { replace: true });
|
||||
} else if (value.length === 0 && location.pathname === '/search') {
|
||||
navigate('/c/new', { replace: true });
|
||||
}
|
||||
};
|
||||
|
||||
@@ -140,7 +132,7 @@ const SearchBar = forwardRef((props: SearchBarProps, ref: React.Ref<HTMLDivEleme
|
||||
showClearIcon ? 'opacity-100' : 'opacity-0',
|
||||
isSmallScreen === true ? 'right-[16px]' : '',
|
||||
)}
|
||||
onClick={() => clearText(location.pathname)}
|
||||
onClick={clearText}
|
||||
tabIndex={showClearIcon ? 0 : -1}
|
||||
disabled={!showClearIcon}
|
||||
>
|
||||
|
||||
@@ -76,6 +76,13 @@ const toggleSwitchConfigs = [
|
||||
hoverCardText: undefined,
|
||||
key: 'modularChat',
|
||||
},
|
||||
{
|
||||
stateAtom: store.showCostTracking,
|
||||
localizationKey: 'com_nav_show_cost_tracking',
|
||||
switchId: 'showCostTracking',
|
||||
hoverCardText: 'com_nav_info_show_cost_tracking',
|
||||
key: 'showCostTracking',
|
||||
},
|
||||
];
|
||||
|
||||
function Chat() {
|
||||
|
||||
@@ -6,7 +6,6 @@ import { LocalStorageKeys } from 'librechat-data-provider';
|
||||
import { useFormContext, Controller } from 'react-hook-form';
|
||||
import type { MenuItemProps } from '@librechat/client';
|
||||
import type { ReactNode } from 'react';
|
||||
import { usePromptGroupsContext } from '~/Providers';
|
||||
import { useCategories } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
@@ -23,9 +22,8 @@ const CategorySelector: React.FC<CategorySelectorProps> = ({
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const formContext = useFormContext();
|
||||
const { categories, emptyCategory } = useCategories();
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const { hasAccess } = usePromptGroupsContext();
|
||||
const { categories, emptyCategory } = useCategories({ hasAccess });
|
||||
|
||||
const control = formContext?.control;
|
||||
const watch = formContext?.watch;
|
||||
|
||||
@@ -7,7 +7,6 @@ import CategorySelector from '~/components/Prompts/Groups/CategorySelector';
|
||||
import VariablesDropdown from '~/components/Prompts/VariablesDropdown';
|
||||
import PromptVariables from '~/components/Prompts/PromptVariables';
|
||||
import Description from '~/components/Prompts/Description';
|
||||
import { usePromptGroupsContext } from '~/Providers';
|
||||
import { useLocalize, useHasAccess } from '~/hooks';
|
||||
import Command from '~/components/Prompts/Command';
|
||||
import { useCreatePrompt } from '~/data-provider';
|
||||
@@ -38,12 +37,10 @@ const CreatePromptForm = ({
|
||||
}) => {
|
||||
const localize = useLocalize();
|
||||
const navigate = useNavigate();
|
||||
const { hasAccess: hasUseAccess } = usePromptGroupsContext();
|
||||
const hasCreateAccess = useHasAccess({
|
||||
const hasAccess = useHasAccess({
|
||||
permissionType: PermissionTypes.PROMPTS,
|
||||
permission: Permissions.CREATE,
|
||||
});
|
||||
const hasAccess = hasUseAccess && hasCreateAccess;
|
||||
|
||||
useEffect(() => {
|
||||
let timeoutId: ReturnType<typeof setTimeout>;
|
||||
|
||||
@@ -11,8 +11,8 @@ import store from '~/store';
|
||||
|
||||
export default function FilterPrompts({ className = '' }: { className?: string }) {
|
||||
const localize = useLocalize();
|
||||
const { name, setName, hasAccess } = usePromptGroupsContext();
|
||||
const { categories } = useCategories({ className: 'h-4 w-4', hasAccess });
|
||||
const { name, setName } = usePromptGroupsContext();
|
||||
const { categories } = useCategories('h-4 w-4');
|
||||
const [displayName, setDisplayName] = useState(name || '');
|
||||
const [isSearching, setIsSearching] = useState(false);
|
||||
const [categoryFilter, setCategory] = useRecoilState(store.promptsCategory);
|
||||
|
||||
@@ -167,7 +167,6 @@ const PromptForm = () => {
|
||||
const params = useParams();
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const { hasAccess } = usePromptGroupsContext();
|
||||
const alwaysMakeProd = useRecoilValue(store.alwaysMakeProd);
|
||||
const promptId = params.promptId || '';
|
||||
|
||||
@@ -180,12 +179,10 @@ const PromptForm = () => {
|
||||
const [showSidePanel, setShowSidePanel] = useState(false);
|
||||
const sidePanelWidth = '320px';
|
||||
|
||||
const { data: group, isLoading: isLoadingGroup } = useGetPromptGroup(promptId, {
|
||||
enabled: hasAccess && !!promptId,
|
||||
});
|
||||
const { data: group, isLoading: isLoadingGroup } = useGetPromptGroup(promptId);
|
||||
const { data: prompts = [], isLoading: isLoadingPrompts } = useGetPrompts(
|
||||
{ groupId: promptId },
|
||||
{ enabled: hasAccess && !!promptId },
|
||||
{ enabled: !!promptId },
|
||||
);
|
||||
|
||||
const { hasPermission, isLoading: permissionsLoading } = useResourcePermissions(
|
||||
|
||||
@@ -76,8 +76,6 @@ export default function Message(props: TMessageProps) {
|
||||
messageId,
|
||||
isExpanded: false,
|
||||
conversationId: conversation?.conversationId,
|
||||
isSubmitting: false, // Share view is always read-only
|
||||
isLatestMessage: false, // No concept of latest message in share view
|
||||
}}
|
||||
>
|
||||
{/* Legacy Plugins */}
|
||||
|
||||
@@ -1,26 +1,27 @@
|
||||
import { useEffect } from 'react';
|
||||
import { ChevronLeft } from 'lucide-react';
|
||||
import { useForm, FormProvider } from 'react-hook-form';
|
||||
|
||||
import {
|
||||
AuthTypeEnum,
|
||||
AuthorizationTypeEnum,
|
||||
TokenExchangeMethodEnum,
|
||||
} from 'librechat-data-provider';
|
||||
import {
|
||||
Label,
|
||||
OGDialog,
|
||||
TrashIcon,
|
||||
OGDialogTrigger,
|
||||
useToastContext,
|
||||
OGDialogTemplate,
|
||||
TrashIcon,
|
||||
OGDialog,
|
||||
OGDialogTrigger,
|
||||
Label,
|
||||
useToastContext,
|
||||
} from '@librechat/client';
|
||||
import type { ActionAuthForm } from '~/common';
|
||||
import ActionsAuth from '~/components/SidePanel/Builder/ActionsAuth';
|
||||
import { useAgentPanelContext } from '~/Providers/AgentPanelContext';
|
||||
import { useDeleteAgentAction } from '~/data-provider';
|
||||
import { Panel, isEphemeralAgent } from '~/common';
|
||||
import type { ActionAuthForm } from '~/common';
|
||||
import ActionsInput from './ActionsInput';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { Panel } from '~/common';
|
||||
|
||||
export default function ActionsPanel() {
|
||||
const localize = useLocalize();
|
||||
@@ -108,7 +109,7 @@ export default function ActionsPanel() {
|
||||
<div className="absolute right-0 top-6">
|
||||
<button
|
||||
type="button"
|
||||
disabled={isEphemeralAgent(agent_id) || !action.action_id}
|
||||
disabled={!agent_id || !action.action_id}
|
||||
className="btn btn-neutral border-token-border-light relative h-9 rounded-lg font-medium"
|
||||
>
|
||||
<TrashIcon className="text-red-500" />
|
||||
@@ -126,7 +127,7 @@ export default function ActionsPanel() {
|
||||
}
|
||||
selection={{
|
||||
selectHandler: () => {
|
||||
if (isEphemeralAgent(agent_id)) {
|
||||
if (!agent_id) {
|
||||
return showToast({
|
||||
message: localize('com_agents_no_agent_id_error'),
|
||||
status: 'error',
|
||||
@@ -134,7 +135,7 @@ export default function ActionsPanel() {
|
||||
}
|
||||
deleteAgentAction.mutate({
|
||||
action_id: action.action_id,
|
||||
agent_id: agent_id || '',
|
||||
agent_id,
|
||||
});
|
||||
},
|
||||
selectClasses:
|
||||
|
||||
@@ -18,7 +18,6 @@ import { useFileMapContext, useAgentPanelContext } from '~/Providers';
|
||||
import AgentCategorySelector from './AgentCategorySelector';
|
||||
import Action from '~/components/SidePanel/Builder/Action';
|
||||
import { useLocalize, useVisibleTools } from '~/hooks';
|
||||
import { Panel, isEphemeralAgent } from '~/common';
|
||||
import { useGetAgentFiles } from '~/data-provider';
|
||||
import { icons } from '~/hooks/Endpoint/Icons';
|
||||
import Instructions from './Instructions';
|
||||
@@ -30,6 +29,7 @@ import Artifacts from './Artifacts';
|
||||
import AgentTool from './AgentTool';
|
||||
import CodeForm from './Code/Form';
|
||||
import MCPTools from './MCPTools';
|
||||
import { Panel } from '~/common';
|
||||
|
||||
const labelClass = 'mb-2 text-token-text-primary block font-medium';
|
||||
const inputClass = cn(
|
||||
@@ -48,12 +48,12 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
const {
|
||||
actions,
|
||||
setAction,
|
||||
regularTools,
|
||||
agentsConfig,
|
||||
startupConfig,
|
||||
mcpServersMap,
|
||||
setActivePanel,
|
||||
endpointsConfig,
|
||||
groupedTools: allTools,
|
||||
} = useAgentPanelContext();
|
||||
|
||||
const {
|
||||
@@ -79,9 +79,9 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
}, [fileMap, agentFiles]);
|
||||
|
||||
const {
|
||||
ocrEnabled,
|
||||
codeEnabled,
|
||||
toolsEnabled,
|
||||
contextEnabled,
|
||||
actionsEnabled,
|
||||
artifactsEnabled,
|
||||
webSearchEnabled,
|
||||
@@ -149,7 +149,7 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
}, [agent, agent_id, mergedFileMap]);
|
||||
|
||||
const handleAddActions = useCallback(() => {
|
||||
if (isEphemeralAgent(agent_id)) {
|
||||
if (!agent_id) {
|
||||
showToast({
|
||||
message: localize('com_assistants_actions_disabled'),
|
||||
status: 'warning',
|
||||
@@ -177,7 +177,7 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
Icon = icons[iconKey];
|
||||
}
|
||||
|
||||
const { toolIds, mcpServerNames } = useVisibleTools(tools, regularTools, mcpServersMap);
|
||||
const { toolIds, mcpServerNames } = useVisibleTools(tools, allTools, mcpServersMap);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -291,7 +291,7 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
{(codeEnabled ||
|
||||
fileSearchEnabled ||
|
||||
artifactsEnabled ||
|
||||
contextEnabled ||
|
||||
ocrEnabled ||
|
||||
webSearchEnabled) && (
|
||||
<div className="mb-4 flex w-full flex-col items-start gap-3">
|
||||
<label className="text-token-text-primary block font-medium">
|
||||
@@ -301,8 +301,8 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
{codeEnabled && <CodeForm agent_id={agent_id} files={code_files} />}
|
||||
{/* Web Search */}
|
||||
{webSearchEnabled && <SearchForm />}
|
||||
{/* File Context */}
|
||||
{contextEnabled && <FileContext agent_id={agent_id} files={context_files} />}
|
||||
{/* File Context (OCR) */}
|
||||
{ocrEnabled && <FileContext agent_id={agent_id} files={context_files} />}
|
||||
{/* Artifacts */}
|
||||
{artifactsEnabled && <Artifacts />}
|
||||
{/* File Search */}
|
||||
@@ -326,15 +326,16 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
</label>
|
||||
<div>
|
||||
<div className="mb-1">
|
||||
{/* Render all visible IDs */}
|
||||
{/* Render all visible IDs (including groups with subtools selected) */}
|
||||
{toolIds.map((toolId, i) => {
|
||||
const tool = regularTools?.find((t) => t.pluginKey === toolId);
|
||||
if (!allTools) return null;
|
||||
const tool = allTools[toolId];
|
||||
if (!tool) return null;
|
||||
return (
|
||||
<AgentTool
|
||||
key={`${toolId}-${i}-${agent_id}`}
|
||||
tool={toolId}
|
||||
regularTools={regularTools}
|
||||
allTools={allTools}
|
||||
agent_id={agent_id}
|
||||
/>
|
||||
);
|
||||
@@ -370,7 +371,7 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
{(actionsEnabled ?? false) && (
|
||||
<button
|
||||
type="button"
|
||||
disabled={isEphemeralAgent(agent_id)}
|
||||
disabled={!agent_id}
|
||||
onClick={handleAddActions}
|
||||
className="btn btn-neutral border-token-border-light relative h-9 w-full rounded-lg font-medium"
|
||||
aria-haspopup="dialog"
|
||||
@@ -473,15 +474,13 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
setIsOpen={setShowToolDialog}
|
||||
endpoint={EModelEndpoint.agents}
|
||||
/>
|
||||
{startupConfig?.mcpServers != null && (
|
||||
<MCPToolSelectDialog
|
||||
agentId={agent_id}
|
||||
isOpen={showMCPToolDialog}
|
||||
mcpServerNames={mcpServerNames}
|
||||
setIsOpen={setShowMCPToolDialog}
|
||||
endpoint={EModelEndpoint.agents}
|
||||
/>
|
||||
)}
|
||||
<MCPToolSelectDialog
|
||||
agentId={agent_id}
|
||||
isOpen={showMCPToolDialog}
|
||||
mcpServerNames={mcpServerNames}
|
||||
setIsOpen={setShowMCPToolDialog}
|
||||
endpoint={EModelEndpoint.agents}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user