Compare commits
1 Commits
refactor/m
...
fix/stt-co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1abf37af5a |
@@ -10,17 +10,7 @@ const {
|
||||
validateVisionModel,
|
||||
} = require('librechat-data-provider');
|
||||
const { SplitStreamHandler: _Handler } = require('@librechat/agents');
|
||||
const {
|
||||
Tokenizer,
|
||||
createFetch,
|
||||
matchModelName,
|
||||
getClaudeHeaders,
|
||||
getModelMaxTokens,
|
||||
configureReasoning,
|
||||
checkPromptCacheSupport,
|
||||
getModelMaxOutputTokens,
|
||||
createStreamEventHandlers,
|
||||
} = require('@librechat/api');
|
||||
const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api');
|
||||
const {
|
||||
truncateText,
|
||||
formatMessage,
|
||||
@@ -29,6 +19,12 @@ const {
|
||||
parseParamFromPrompt,
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const {
|
||||
getClaudeHeaders,
|
||||
configureReasoning,
|
||||
checkPromptCacheSupport,
|
||||
} = require('~/server/services/Endpoints/anthropic/helpers');
|
||||
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { sleep } = require('~/server/utils');
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
const { google } = require('googleapis');
|
||||
const { getModelMaxTokens } = require('@librechat/api');
|
||||
const { concat } = require('@langchain/core/utils/stream');
|
||||
const { ChatVertexAI } = require('@langchain/google-vertexai');
|
||||
const { Tokenizer, getSafetySettings } = require('@librechat/api');
|
||||
@@ -22,6 +21,7 @@ const {
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const {
|
||||
|
||||
@@ -7,9 +7,7 @@ const {
|
||||
createFetch,
|
||||
resolveHeaders,
|
||||
constructAzureURL,
|
||||
getModelMaxTokens,
|
||||
genAzureChatCompletion,
|
||||
getModelMaxOutputTokens,
|
||||
createStreamEventHandlers,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
@@ -33,13 +31,13 @@ const {
|
||||
titleInstruction,
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { addSpaceIfNeeded, sleep } = require('~/server/utils');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const { summaryBuffer } = require('./memory');
|
||||
const { runTitleChain } = require('./chains');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const { tokenSplit } = require('./document');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { createLLM } = require('./llm');
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const { getModelMaxTokens } = require('@librechat/api');
|
||||
const BaseClient = require('../BaseClient');
|
||||
const { getModelMaxTokens } = require('../../../utils');
|
||||
|
||||
class FakeClient extends BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
|
||||
@@ -71,10 +71,9 @@ const primeFiles = async (options) => {
|
||||
* @param {ServerRequest} options.req
|
||||
* @param {Array<{ file_id: string; filename: string }>} options.files
|
||||
* @param {string} [options.entity_id]
|
||||
* @param {boolean} [options.fileCitations=false] - Whether to include citation instructions
|
||||
* @returns
|
||||
*/
|
||||
const createFileSearchTool = async ({ req, files, entity_id, fileCitations = false }) => {
|
||||
const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
return tool(
|
||||
async ({ query }) => {
|
||||
if (files.length === 0) {
|
||||
@@ -143,9 +142,9 @@ const createFileSearchTool = async ({ req, files, entity_id, fileCitations = fal
|
||||
const formattedString = formattedResults
|
||||
.map(
|
||||
(result, index) =>
|
||||
`File: ${result.filename}${
|
||||
fileCitations ? `\nAnchor: \\ue202turn0file${index} (${result.filename})` : ''
|
||||
}\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${result.content}\n`,
|
||||
`File: ${result.filename}\nAnchor: \\ue202turn0file${index} (${result.filename})\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${
|
||||
result.content
|
||||
}\n`,
|
||||
)
|
||||
.join('\n---\n');
|
||||
|
||||
@@ -159,14 +158,12 @@ const createFileSearchTool = async ({ req, files, entity_id, fileCitations = fal
|
||||
pageRelevance: result.page ? { [result.page]: 1.0 - result.distance } : {},
|
||||
}));
|
||||
|
||||
return [formattedString, { [Tools.file_search]: { sources, fileCitations } }];
|
||||
return [formattedString, { [Tools.file_search]: { sources } }];
|
||||
},
|
||||
{
|
||||
name: Tools.file_search,
|
||||
responseFormat: 'content_and_artifact',
|
||||
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.${
|
||||
fileCitations
|
||||
? `
|
||||
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.
|
||||
|
||||
**CITE FILE SEARCH RESULTS:**
|
||||
Use anchor markers immediately after statements derived from file content. Reference the filename in your text:
|
||||
@@ -174,9 +171,7 @@ Use anchor markers immediately after statements derived from file content. Refer
|
||||
- Page reference: "According to report.docx... \\ue202turn0file1"
|
||||
- Multi-file: "Multiple sources confirm... \\ue200\\ue202turn0file0\\ue202turn0file1\\ue201"
|
||||
|
||||
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`
|
||||
: ''
|
||||
}`,
|
||||
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`,
|
||||
schema: z.object({
|
||||
query: z
|
||||
.string()
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SerpAPI } = require('@langchain/community/tools/serpapi');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
const { mcpToolPattern, loadWebSearchAuth, checkAccess } = require('@librechat/api');
|
||||
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
|
||||
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
Permissions,
|
||||
EToolResources,
|
||||
PermissionTypes,
|
||||
replaceSpecialVars,
|
||||
} = require('librechat-data-provider');
|
||||
const { Tools, Constants, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
|
||||
const {
|
||||
availableTools,
|
||||
manifestToolMap,
|
||||
@@ -34,7 +27,6 @@ const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
/**
|
||||
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
|
||||
@@ -289,29 +281,7 @@ const loadTools = async ({
|
||||
if (toolContext) {
|
||||
toolContextMap[tool] = toolContext;
|
||||
}
|
||||
|
||||
/** @type {boolean | undefined} Check if user has FILE_CITATIONS permission */
|
||||
let fileCitations;
|
||||
if (fileCitations == null && options.req?.user != null) {
|
||||
try {
|
||||
fileCitations = await checkAccess({
|
||||
user: options.req.user,
|
||||
permissionType: PermissionTypes.FILE_CITATIONS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[handleTools] FILE_CITATIONS permission check failed:', error);
|
||||
fileCitations = false;
|
||||
}
|
||||
}
|
||||
|
||||
return createFileSearchTool({
|
||||
req: options.req,
|
||||
files,
|
||||
entity_id: agent?.id,
|
||||
fileCitations,
|
||||
});
|
||||
return createFileSearchTool({ req: options.req, files, entity_id: agent?.id });
|
||||
};
|
||||
continue;
|
||||
} else if (tool === Tools.web_search) {
|
||||
@@ -342,16 +312,6 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
continue;
|
||||
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
|
||||
const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
|
||||
if (toolName === Constants.mcp_server) {
|
||||
/** Placeholder used for UI purposes */
|
||||
continue;
|
||||
}
|
||||
if (serverName && options.req?.config?.mcpConfig?.[serverName] == null) {
|
||||
logger.warn(
|
||||
`MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (toolName === Constants.mcp_all) {
|
||||
const currentMCPGenerator = async (index) =>
|
||||
createMCPTools({
|
||||
|
||||
@@ -211,7 +211,7 @@ describe('File Access Control', () => {
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should deny access when user only has VIEW permission and needs access for deletion', async () => {
|
||||
it('should deny access when user only has VIEW permission', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
@@ -263,71 +263,12 @@ describe('File Access Control', () => {
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId,
|
||||
isDelete: true,
|
||||
});
|
||||
|
||||
// Should have no access to any files when only VIEW permission
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should grant access when user has VIEW permission', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent with files
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'View-Only Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access for files
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId,
|
||||
});
|
||||
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getFiles with agent access control', () => {
|
||||
|
||||
@@ -269,7 +269,7 @@ async function getListPromptGroupsByAccess({
|
||||
const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };
|
||||
|
||||
// Add cursor condition
|
||||
if (after && typeof after === 'string' && after !== 'undefined' && after !== 'null') {
|
||||
if (after) {
|
||||
try {
|
||||
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
|
||||
const { updatedAt, _id } = cursor;
|
||||
|
||||
@@ -189,15 +189,11 @@ async function createAutoRefillTransaction(txData) {
|
||||
* @param {txData} _txData - Transaction data.
|
||||
*/
|
||||
async function createTransaction(_txData) {
|
||||
const { balance, transactions, ...txData } = _txData;
|
||||
const { balance, ...txData } = _txData;
|
||||
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (transactions?.enabled === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const transaction = new Transaction(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
calculateTokenValue(transaction);
|
||||
@@ -226,11 +222,7 @@ async function createTransaction(_txData) {
|
||||
* @param {txData} _txData - Transaction data.
|
||||
*/
|
||||
async function createStructuredTransaction(_txData) {
|
||||
const { balance, transactions, ...txData } = _txData;
|
||||
if (transactions?.enabled === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { balance, ...txData } = _txData;
|
||||
const transaction = new Transaction({
|
||||
...txData,
|
||||
endpointTokenConfig: txData.endpointTokenConfig,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||
|
||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
const { createTransaction, createStructuredTransaction } = require('./Transaction');
|
||||
const { Balance, Transaction } = require('~/db/models');
|
||||
const { createTransaction } = require('./Transaction');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
let mongoServer;
|
||||
beforeAll(async () => {
|
||||
@@ -379,188 +380,3 @@ describe('NaN Handling Tests', () => {
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Transactions Config Tests', () => {
|
||||
test('createTransaction should not save when transactions.enabled is false', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(0);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createTransaction should save when transactions.enabled is true', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created
|
||||
expect(result).toBeDefined();
|
||||
expect(result.balance).toBeLessThan(initialBalance);
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].rawAmount).toBe(-100);
|
||||
});
|
||||
|
||||
test('createTransaction should save when balance.enabled is true even if transactions config is missing', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
balance: { enabled: true },
|
||||
// No transactions config provided
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created (backward compatibility)
|
||||
expect(result).toBeDefined();
|
||||
expect(result.balance).toBeLessThan(initialBalance);
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
});
|
||||
|
||||
test('createTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created but balance unchanged
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].rawAmount).toBe(-100);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createStructuredTransaction should not save when transactions.enabled is false', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'message',
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
transactions: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createStructuredTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(0);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createStructuredTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'message',
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createStructuredTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created but balance unchanged
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].inputTokens).toBe(-10);
|
||||
expect(transactions[0].writeTokens).toBe(-100);
|
||||
expect(transactions[0].readTokens).toBe(-5);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { matchModelName } = require('@librechat/api');
|
||||
const { matchModelName } = require('../utils/tokens');
|
||||
const defaultRate = 6;
|
||||
|
||||
/**
|
||||
|
||||
@@ -75,7 +75,7 @@ const refreshController = async (req, res) => {
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString());
|
||||
const token = setOpenIDAuthTokens(tokenset, res);
|
||||
return res.status(200).send({ token, user });
|
||||
} catch (error) {
|
||||
logger.error('[refreshController] OpenID token refresh error', error);
|
||||
|
||||
@@ -74,23 +74,14 @@ const getAvailableTools = async (req, res) => {
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
const cachedUserTools = await getCachedTools({ userId });
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
const mcpManager = getMCPManager();
|
||||
const userPlugins =
|
||||
cachedUserTools != null
|
||||
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
|
||||
: undefined;
|
||||
|
||||
/** @type {TPlugin[]} */
|
||||
let mcpPlugins;
|
||||
if (appConfig?.mcpConfig) {
|
||||
const mcpManager = getMCPManager();
|
||||
mcpPlugins =
|
||||
cachedUserTools != null
|
||||
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
|
||||
: undefined;
|
||||
}
|
||||
|
||||
if (
|
||||
cachedToolsArray != null &&
|
||||
(appConfig?.mcpConfig != null ? mcpPlugins != null && mcpPlugins.length > 0 : true)
|
||||
) {
|
||||
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...cachedToolsArray]);
|
||||
if (cachedToolsArray != null && userPlugins != null) {
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
|
||||
res.status(200).json(dedupedTools);
|
||||
return;
|
||||
}
|
||||
@@ -102,9 +93,9 @@ const getAvailableTools = async (req, res) => {
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
let pluginManifest = availableTools;
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
if (appConfig?.mcpConfig != null) {
|
||||
try {
|
||||
const mcpManager = getMCPManager();
|
||||
const mcpTools = await mcpManager.getAllToolFunctions(userId);
|
||||
prelimCachedTools = prelimCachedTools ?? {};
|
||||
for (const [toolKey, toolData] of Object.entries(mcpTools)) {
|
||||
@@ -184,7 +175,7 @@ const getAvailableTools = async (req, res) => {
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
|
||||
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...finalTools]);
|
||||
const dedupedTools = filterUniquePlugins([...(userPlugins ?? []), ...finalTools]);
|
||||
res.status(200).json(dedupedTools);
|
||||
} catch (error) {
|
||||
logger.error('[getAvailableTools]', error);
|
||||
|
||||
@@ -174,19 +174,10 @@ describe('PluginController', () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
mockReq.config = {
|
||||
mcpConfig: {
|
||||
server1: {},
|
||||
},
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return empty tools initially (since getAllToolFunctions is called)
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// Mock second call to return tool definitions (includeGlobal: true)
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
@@ -514,7 +505,7 @@ describe('PluginController', () => {
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle `cachedToolsArray` and `mcpPlugins` both being defined', async () => {
|
||||
it('should handle cachedToolsArray and userPlugins both being defined', async () => {
|
||||
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
|
||||
// Use MCP delimiter for the user tool so convertMCPToolsToPlugins works
|
||||
const userTools = {
|
||||
@@ -531,19 +522,10 @@ describe('PluginController', () => {
|
||||
mockCache.get.mockResolvedValue(cachedTools);
|
||||
getCachedTools.mockResolvedValueOnce(userTools);
|
||||
mockReq.config = {
|
||||
mcpConfig: {
|
||||
server1: {},
|
||||
},
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return empty tools initially
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// The controller expects a second call to getCachedTools
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
'cached-tool': { type: 'function', function: { name: 'cached-tool' } },
|
||||
|
||||
@@ -187,7 +187,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
// Extract server name from pluginKey (format: "mcp_<serverName>")
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
logger.info(
|
||||
`[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`,
|
||||
`[updateUserPluginsController] Disconnecting MCP server ${serverName} for user ${user.id} after plugin auth update for ${pluginKey}.`,
|
||||
);
|
||||
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||
}
|
||||
|
||||
@@ -7,12 +7,10 @@ const {
|
||||
createRun,
|
||||
Tokenizer,
|
||||
checkAccess,
|
||||
logAxiosError,
|
||||
resolveHeaders,
|
||||
getBalanceConfig,
|
||||
memoryInstructions,
|
||||
formatContentStrings,
|
||||
getTransactionsConfig,
|
||||
createMemoryProcessor,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
@@ -89,10 +87,11 @@ function createTokenCounter(encoding) {
|
||||
}
|
||||
|
||||
function logToolError(graph, error, toolId) {
|
||||
logAxiosError({
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
|
||||
error,
|
||||
message: `[api/server/controllers/agents/client.js #chatCompletion] Tool Error "${toolId}"`,
|
||||
});
|
||||
toolId,
|
||||
);
|
||||
}
|
||||
|
||||
class AgentClient extends BaseClient {
|
||||
@@ -624,13 +623,11 @@ class AgentClient extends BaseClient {
|
||||
* @param {string} [params.model]
|
||||
* @param {string} [params.context='message']
|
||||
* @param {AppConfig['balance']} [params.balance]
|
||||
* @param {AppConfig['transactions']} [params.transactions]
|
||||
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
|
||||
*/
|
||||
async recordCollectedUsage({
|
||||
model,
|
||||
balance,
|
||||
transactions,
|
||||
context = 'message',
|
||||
collectedUsage = this.collectedUsage,
|
||||
}) {
|
||||
@@ -656,7 +653,6 @@ class AgentClient extends BaseClient {
|
||||
const txMetadata = {
|
||||
context,
|
||||
balance,
|
||||
transactions,
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
@@ -1055,12 +1051,7 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
const transactionsConfig = getTransactionsConfig(appConfig);
|
||||
await this.recordCollectedUsage({
|
||||
context: 'message',
|
||||
balance: balanceConfig,
|
||||
transactions: transactionsConfig,
|
||||
});
|
||||
await this.recordCollectedUsage({ context: 'message', balance: balanceConfig });
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
|
||||
@@ -1254,13 +1245,11 @@ class AgentClient extends BaseClient {
|
||||
});
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
const transactionsConfig = getTransactionsConfig(appConfig);
|
||||
await this.recordCollectedUsage({
|
||||
collectedUsage,
|
||||
context: 'title',
|
||||
model: clientOptions.model,
|
||||
balance: balanceConfig,
|
||||
transactions: transactionsConfig,
|
||||
}).catch((err) => {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',
|
||||
|
||||
@@ -237,9 +237,6 @@ describe('AgentClient - titleConvo', () => {
|
||||
balance: {
|
||||
enabled: false,
|
||||
},
|
||||
transactions: {
|
||||
enabled: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
SystemRoles,
|
||||
FileSources,
|
||||
ResourceType,
|
||||
@@ -70,9 +69,9 @@ const createAgentHandler = async (req, res) => {
|
||||
for (const tool of tools) {
|
||||
if (availableTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
} else if (systemTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
} else if (tool.includes(Constants.mcp_delimiter)) {
|
||||
}
|
||||
|
||||
if (systemTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
|
||||
const { sendEvent, getBalanceConfig } = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -34,6 +34,7 @@ const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
|
||||
const { sendEvent, getBalanceConfig } = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -31,6 +31,7 @@ const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
|
||||
/**
|
||||
|
||||
@@ -12,7 +12,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { isEnabled, ErrorController } = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const { updateInterfacePermissions } = require('~/models/interface');
|
||||
const { checkMigrations } = require('./services/start/migration');
|
||||
@@ -126,7 +126,7 @@ const startServer = async () => {
|
||||
app.use('/api/config', routes.config);
|
||||
app.use('/api/assistants', routes.assistants);
|
||||
app.use('/api/files', await routes.files.initialize());
|
||||
app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute);
|
||||
app.use('/images/', validateImageRequest, routes.staticRoute);
|
||||
app.use('/api/share', routes.share);
|
||||
app.use('/api/roles', routes.roles);
|
||||
app.use('/api/agents', routes.agents);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
const validatePasswordReset = require('./validatePasswordReset');
|
||||
const validateRegistration = require('./validateRegistration');
|
||||
const validateImageRequest = require('./validateImageRequest');
|
||||
const buildEndpointOption = require('./buildEndpointOption');
|
||||
const validateMessageReq = require('./validateMessageReq');
|
||||
const checkDomainAllowed = require('./checkDomainAllowed');
|
||||
@@ -49,5 +50,6 @@ module.exports = {
|
||||
validateMessageReq,
|
||||
buildEndpointOption,
|
||||
validateRegistration,
|
||||
validateImageRequest,
|
||||
validatePasswordReset,
|
||||
};
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const createValidateImageRequest = require('~/server/middleware/validateImageRequest');
|
||||
const validateImageRequest = require('~/server/middleware/validateImageRequest');
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
isEnabled: jest.fn(),
|
||||
jest.mock('~/server/services/Config/app', () => ({
|
||||
getAppConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('validateImageRequest middleware', () => {
|
||||
let req, res, next, validateImageRequest;
|
||||
let req, res, next;
|
||||
const validObjectId = '65cfb246f7ecadb8b1e8036b';
|
||||
const { getAppConfig } = require('~/server/services/Config/app');
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
@@ -22,278 +22,116 @@ describe('validateImageRequest middleware', () => {
|
||||
};
|
||||
next = jest.fn();
|
||||
process.env.JWT_REFRESH_SECRET = 'test-secret';
|
||||
process.env.OPENID_REUSE_TOKENS = 'false';
|
||||
|
||||
// Default: OpenID token reuse disabled
|
||||
isEnabled.mockReturnValue(false);
|
||||
// Mock getAppConfig to return secureImageLinks: true by default
|
||||
getAppConfig.mockResolvedValue({
|
||||
secureImageLinks: true,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Factory function', () => {
|
||||
test('should return a pass-through middleware if secureImageLinks is false', async () => {
|
||||
const middleware = createValidateImageRequest(false);
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return validation middleware if secureImageLinks is true', async () => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.send).toHaveBeenCalledWith('Unauthorized');
|
||||
test('should call next() if secureImageLinks is false', async () => {
|
||||
getAppConfig.mockResolvedValue({
|
||||
secureImageLinks: false,
|
||||
});
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe('Standard LibreChat token flow', () => {
|
||||
beforeEach(() => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
});
|
||||
|
||||
test('should return 401 if refresh token is not provided', async () => {
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.send).toHaveBeenCalledWith('Unauthorized');
|
||||
});
|
||||
|
||||
test('should return 403 if refresh token is invalid', async () => {
|
||||
req.headers.cookie = 'refreshToken=invalid-token';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should return 403 if refresh token is expired', async () => {
|
||||
const expiredToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${expiredToken}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should call next() for valid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 for invalid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should allow agent avatar pattern for any valid ObjectId', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should prevent file traversal attempts', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
|
||||
const traversalAttempts = [
|
||||
`/images/${validObjectId}/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
|
||||
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
|
||||
];
|
||||
|
||||
for (const attempt of traversalAttempts) {
|
||||
req.originalUrl = attempt;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
jest.clearAllMocks();
|
||||
// Reset mocks for next iteration
|
||||
res.status = jest.fn().mockReturnThis();
|
||||
res.send = jest.fn();
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle URL encoded characters in valid paths', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
test('should return 401 if refresh token is not provided', async () => {
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.send).toHaveBeenCalledWith('Unauthorized');
|
||||
});
|
||||
|
||||
describe('OpenID token flow', () => {
|
||||
beforeEach(() => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
// Enable OpenID token reuse
|
||||
isEnabled.mockReturnValue(true);
|
||||
process.env.OPENID_REUSE_TOKENS = 'true';
|
||||
});
|
||||
|
||||
test('should return 403 if no OpenID user ID cookie when token_provider is openid', async () => {
|
||||
req.headers.cookie = 'refreshToken=dummy-token; token_provider=openid';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should validate JWT-signed user ID for OpenID flow', async () => {
|
||||
const signedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
|
||||
req.originalUrl = `/images/${validObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 for invalid JWT-signed user ID', async () => {
|
||||
req.headers.cookie =
|
||||
'refreshToken=dummy-token; token_provider=openid; openid_user_id=invalid-jwt';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should return 403 for expired JWT-signed user ID', async () => {
|
||||
const expiredSignedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${expiredSignedUserId}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should validate image path against JWT-signed user ID', async () => {
|
||||
const signedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
const differentObjectId = '65cfb246f7ecadb8b1e8036c';
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
|
||||
req.originalUrl = `/images/${differentObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should allow agent avatars in OpenID flow', async () => {
|
||||
const signedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
test('should return 403 if refresh token is invalid', async () => {
|
||||
req.headers.cookie = 'refreshToken=invalid-token';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
describe('Security edge cases', () => {
|
||||
let validToken;
|
||||
test('should return 403 if refresh token is expired', async () => {
|
||||
const expiredToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${expiredToken}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
});
|
||||
test('should call next() for valid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle very long image filenames', async () => {
|
||||
const longFilename = 'a'.repeat(1000) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
test('should return 403 for invalid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should handle URLs with maximum practical length', async () => {
|
||||
// Most browsers support URLs up to ~2000 characters
|
||||
const longFilename = 'x'.repeat(1900) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
test('should return 403 for invalid ObjectId format', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/123/example.jpg'; // Invalid ObjectId
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should accept URLs just under the 2048 limit', async () => {
|
||||
// Create a URL exactly 2047 characters long
|
||||
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
|
||||
const filenameLength = 2047 - baseLength;
|
||||
const filename = 'a'.repeat(filenameLength) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${filename}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
// File traversal tests
|
||||
test('should prevent file traversal attempts', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
|
||||
test('should handle malformed URL encoding gracefully', async () => {
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/test%ZZinvalid.jpg`;
|
||||
const traversalAttempts = [
|
||||
`/images/${validObjectId}/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
|
||||
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
|
||||
];
|
||||
|
||||
for (const attempt of traversalAttempts) {
|
||||
req.originalUrl = attempt;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
jest.clearAllMocks();
|
||||
}
|
||||
});
|
||||
|
||||
test('should reject URLs with null bytes', async () => {
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/test\x00.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should handle URLs with repeated slashes', async () => {
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}//test.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should reject extremely long URLs as potential DoS', async () => {
|
||||
// Create a URL longer than 2048 characters
|
||||
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
|
||||
const filenameLength = 2049 - baseLength; // Ensure total length exceeds 2048
|
||||
const extremelyLongFilename = 'x'.repeat(filenameLength) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${extremelyLongFilename}`;
|
||||
// Verify our test URL is actually too long
|
||||
expect(req.originalUrl.length).toBeGreaterThan(2048);
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
test('should handle URL encoded characters in valid paths', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const cookies = require('cookie');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getAppConfig } = require('~/server/services/Config/app');
|
||||
|
||||
const OBJECT_ID_LENGTH = 24;
|
||||
const OBJECT_ID_PATTERN = /^[0-9a-f]{24}$/i;
|
||||
@@ -22,129 +22,50 @@ function isValidObjectId(id) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates a LibreChat refresh token
|
||||
* @param {string} refreshToken - The refresh token to validate
|
||||
* @returns {{valid: boolean, userId?: string, error?: string}} - Validation result
|
||||
* Middleware to validate image request.
|
||||
* Must be set by `secureImageLinks` via custom config file.
|
||||
*/
|
||||
function validateToken(refreshToken) {
|
||||
async function validateImageRequest(req, res, next) {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
if (!appConfig.secureImageLinks) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
|
||||
if (!refreshToken) {
|
||||
logger.warn('[validateImageRequest] Refresh token not provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
let payload;
|
||||
try {
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
|
||||
if (!isValidObjectId(payload.id)) {
|
||||
return { valid: false, error: 'Invalid User ID' };
|
||||
}
|
||||
|
||||
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
|
||||
if (payload.exp < currentTimeInSeconds) {
|
||||
return { valid: false, error: 'Refresh token expired' };
|
||||
}
|
||||
|
||||
return { valid: true, userId: payload.id };
|
||||
payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
} catch (err) {
|
||||
logger.warn('[validateToken]', err);
|
||||
return { valid: false, error: 'Invalid token' };
|
||||
logger.warn('[validateImageRequest]', err);
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
if (!isValidObjectId(payload.id)) {
|
||||
logger.warn('[validateImageRequest] Invalid User ID');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
|
||||
if (payload.exp < currentTimeInSeconds) {
|
||||
logger.warn('[validateImageRequest] Refresh token expired');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const fullPath = decodeURIComponent(req.originalUrl);
|
||||
const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`);
|
||||
|
||||
if (pathPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
next();
|
||||
} else {
|
||||
logger.warn('[validateImageRequest] Invalid image path');
|
||||
res.status(403).send('Access Denied');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory to create the `validateImageRequest` middleware with configured secureImageLinks
|
||||
* @param {boolean} [secureImageLinks] - Whether secure image links are enabled
|
||||
*/
|
||||
function createValidateImageRequest(secureImageLinks) {
|
||||
if (!secureImageLinks) {
|
||||
return (_req, _res, next) => next();
|
||||
}
|
||||
/**
|
||||
* Middleware to validate image request.
|
||||
* Supports both LibreChat refresh tokens and OpenID JWT tokens.
|
||||
* Must be set by `secureImageLinks` via custom config file.
|
||||
*/
|
||||
return async function validateImageRequest(req, res, next) {
|
||||
try {
|
||||
const cookieHeader = req.headers.cookie;
|
||||
if (!cookieHeader) {
|
||||
logger.warn('[validateImageRequest] No cookies provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
const parsedCookies = cookies.parse(cookieHeader);
|
||||
const refreshToken = parsedCookies.refreshToken;
|
||||
|
||||
if (!refreshToken) {
|
||||
logger.warn('[validateImageRequest] Token not provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
const tokenProvider = parsedCookies.token_provider;
|
||||
let userIdForPath;
|
||||
|
||||
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
const openidUserId = parsedCookies.openid_user_id;
|
||||
if (!openidUserId) {
|
||||
logger.warn('[validateImageRequest] No OpenID user ID cookie found');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const validationResult = validateToken(openidUserId);
|
||||
if (!validationResult.valid) {
|
||||
logger.warn(`[validateImageRequest] ${validationResult.error}`);
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
userIdForPath = validationResult.userId;
|
||||
} else {
|
||||
const validationResult = validateToken(refreshToken);
|
||||
if (!validationResult.valid) {
|
||||
logger.warn(`[validateImageRequest] ${validationResult.error}`);
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
userIdForPath = validationResult.userId;
|
||||
}
|
||||
|
||||
if (!userIdForPath) {
|
||||
logger.warn('[validateImageRequest] No user ID available for path validation');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const MAX_URL_LENGTH = 2048;
|
||||
if (req.originalUrl.length > MAX_URL_LENGTH) {
|
||||
logger.warn('[validateImageRequest] URL too long');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
if (req.originalUrl.includes('\x00')) {
|
||||
logger.warn('[validateImageRequest] URL contains null byte');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
let fullPath;
|
||||
try {
|
||||
fullPath = decodeURIComponent(req.originalUrl);
|
||||
} catch {
|
||||
logger.warn('[validateImageRequest] Invalid URL encoding');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const agentAvatarPattern = /^\/images\/[a-f0-9]{24}\/agent-[^/]*$/;
|
||||
if (agentAvatarPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
return next();
|
||||
}
|
||||
|
||||
const escapedUserId = userIdForPath.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const pathPattern = new RegExp(`^/images/${escapedUserId}/[^/]+$`);
|
||||
|
||||
if (pathPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
next();
|
||||
} else {
|
||||
logger.warn('[validateImageRequest] Invalid image path');
|
||||
res.status(403).send('Access Denied');
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[validateImageRequest] Error:', error);
|
||||
res.status(500).send('Internal Server Error');
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = createValidateImageRequest;
|
||||
module.exports = validateImageRequest;
|
||||
|
||||
@@ -122,11 +122,9 @@ router.get('/', async function (req, res) {
|
||||
payload.minPasswordLength = minPasswordLength;
|
||||
}
|
||||
|
||||
payload.mcpServers = {};
|
||||
const getMCPServers = () => {
|
||||
try {
|
||||
if (appConfig?.mcpConfig == null) {
|
||||
return;
|
||||
}
|
||||
const mcpManager = getMCPManager();
|
||||
if (!mcpManager) {
|
||||
return;
|
||||
@@ -135,9 +133,6 @@ router.get('/', async function (req, res) {
|
||||
if (!mcpServers) return;
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
for (const serverName in mcpServers) {
|
||||
if (!payload.mcpServers) {
|
||||
payload.mcpServers = {};
|
||||
}
|
||||
const serverConfig = mcpServers[serverName];
|
||||
payload.mcpServers[serverName] = removeNullishValues({
|
||||
startup: serverConfig?.startup,
|
||||
|
||||
@@ -4,13 +4,9 @@ const { sleep } = require('@librechat/agents');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
createImportLimiters,
|
||||
createForkLimiters,
|
||||
configMiddleware,
|
||||
} = require('~/server/middleware');
|
||||
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
|
||||
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
|
||||
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
|
||||
const { storage, importFileFilter } = require('~/server/routes/files/multer');
|
||||
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||
const { importConversations } = require('~/server/utils/import');
|
||||
@@ -175,7 +171,6 @@ router.post(
|
||||
'/import',
|
||||
importIpLimiter,
|
||||
importUserLimiter,
|
||||
configMiddleware,
|
||||
upload.single('file'),
|
||||
async (req, res) => {
|
||||
try {
|
||||
|
||||
@@ -185,7 +185,6 @@ router.delete('/', async (req, res) => {
|
||||
role: req.user.role,
|
||||
fileIds: nonOwnedFileIds,
|
||||
agentId: req.body.agent_id,
|
||||
isDelete: true,
|
||||
});
|
||||
|
||||
for (const file of nonOwnedFiles) {
|
||||
|
||||
@@ -39,7 +39,7 @@ const oauthHandler = async (req, res) => {
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
|
||||
setOpenIDAuthTokens(req.user.tokenset, res);
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
}
|
||||
|
||||
@@ -156,7 +156,7 @@ router.get('/all', async (req, res) => {
|
||||
router.get('/groups', async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { pageSize, limit, cursor, name, category, ...otherFilters } = req.query;
|
||||
const { pageSize, pageNumber, limit, cursor, name, category, ...otherFilters } = req.query;
|
||||
|
||||
const { filter, searchShared, searchSharedOnly } = buildPromptGroupFilter({
|
||||
name,
|
||||
@@ -171,13 +171,6 @@ router.get('/groups', async (req, res) => {
|
||||
actualLimit = parseInt(pageSize, 10);
|
||||
}
|
||||
|
||||
if (
|
||||
actualCursor &&
|
||||
(actualCursor === 'undefined' || actualCursor === 'null' || actualCursor.length === 0)
|
||||
) {
|
||||
actualCursor = null;
|
||||
}
|
||||
|
||||
let accessibleIds = await findAccessibleResources({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
@@ -197,7 +190,6 @@ router.get('/groups', async (req, res) => {
|
||||
publicPromptGroupIds: publiclyAccessibleIds,
|
||||
});
|
||||
|
||||
// Cursor-based pagination only
|
||||
const result = await getListPromptGroupsByAccess({
|
||||
accessibleIds: filteredAccessibleIds,
|
||||
otherParams: filter,
|
||||
@@ -206,21 +198,19 @@ router.get('/groups', async (req, res) => {
|
||||
});
|
||||
|
||||
if (!result) {
|
||||
const emptyResponse = createEmptyPromptGroupsResponse({
|
||||
pageNumber: '1',
|
||||
pageSize: actualLimit,
|
||||
actualLimit,
|
||||
});
|
||||
const emptyResponse = createEmptyPromptGroupsResponse({ pageNumber, pageSize, actualLimit });
|
||||
return res.status(200).send(emptyResponse);
|
||||
}
|
||||
|
||||
const { data: promptGroups = [], has_more = false, after = null } = result;
|
||||
|
||||
const groupsWithPublicFlag = markPublicPromptGroups(promptGroups, publiclyAccessibleIds);
|
||||
|
||||
const response = formatPromptGroupsResponse({
|
||||
promptGroups: groupsWithPublicFlag,
|
||||
pageNumber: '1', // Always 1 for cursor-based pagination
|
||||
pageSize: actualLimit.toString(),
|
||||
pageNumber,
|
||||
pageSize,
|
||||
actualLimit,
|
||||
hasMore: has_more,
|
||||
after,
|
||||
});
|
||||
|
||||
@@ -33,11 +33,22 @@ let promptRoutes;
|
||||
let Prompt, PromptGroup, AclEntry, AccessRole, User;
|
||||
let testUsers, testRoles;
|
||||
let grantPermission;
|
||||
let currentTestUser; // Track current user for middleware
|
||||
|
||||
// Helper function to set user in middleware
|
||||
function setTestUser(app, user) {
|
||||
currentTestUser = user;
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...(user.toObject ? user.toObject() : user),
|
||||
id: user.id || user._id.toString(),
|
||||
_id: user._id,
|
||||
name: user.name,
|
||||
role: user.role,
|
||||
};
|
||||
if (user.role === SystemRoles.ADMIN) {
|
||||
console.log('Setting admin user with role:', req.user.role);
|
||||
}
|
||||
next();
|
||||
});
|
||||
}
|
||||
|
||||
beforeAll(async () => {
|
||||
@@ -64,35 +75,14 @@ beforeAll(async () => {
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Add user middleware before routes
|
||||
app.use((req, res, next) => {
|
||||
if (currentTestUser) {
|
||||
req.user = {
|
||||
...(currentTestUser.toObject ? currentTestUser.toObject() : currentTestUser),
|
||||
id: currentTestUser._id.toString(),
|
||||
_id: currentTestUser._id,
|
||||
name: currentTestUser.name,
|
||||
role: currentTestUser.role,
|
||||
};
|
||||
}
|
||||
next();
|
||||
});
|
||||
// Mock authentication middleware - default to owner
|
||||
setTestUser(app, testUsers.owner);
|
||||
|
||||
// Set default user
|
||||
currentTestUser = testUsers.owner;
|
||||
|
||||
// Import routes after middleware is set up
|
||||
// Import routes after mocks are set up
|
||||
promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Always reset to owner user after each test for isolation
|
||||
if (currentTestUser !== testUsers.owner) {
|
||||
currentTestUser = testUsers.owner;
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
@@ -126,26 +116,36 @@ async function setupTestData() {
|
||||
// Create test users
|
||||
testUsers = {
|
||||
owner: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Prompt Owner',
|
||||
email: 'owner@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
viewer: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Prompt Viewer',
|
||||
email: 'viewer@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
editor: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Prompt Editor',
|
||||
email: 'editor@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
noAccess: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'No Access',
|
||||
email: 'noaccess@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
admin: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Admin',
|
||||
email: 'admin@example.com',
|
||||
role: SystemRoles.ADMIN,
|
||||
@@ -181,7 +181,8 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
it('should have routes loaded', async () => {
|
||||
// This should at least not crash
|
||||
const response = await request(app).get('/api/prompts/test-404');
|
||||
|
||||
console.log('Test 404 response status:', response.status);
|
||||
console.log('Test 404 response body:', response.body);
|
||||
// We expect a 401 or 404, not 500
|
||||
expect(response.status).not.toBe(500);
|
||||
});
|
||||
@@ -206,6 +207,12 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
|
||||
const response = await request(app).post('/api/prompts').send(promptData);
|
||||
|
||||
if (response.status !== 200) {
|
||||
console.log('POST /api/prompts error status:', response.status);
|
||||
console.log('POST /api/prompts error body:', response.body);
|
||||
console.log('Console errors:', consoleErrorSpy.mock.calls);
|
||||
}
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.prompt).toBeDefined();
|
||||
expect(response.body.prompt.prompt).toBe(promptData.prompt.prompt);
|
||||
@@ -311,8 +318,29 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
});
|
||||
|
||||
it('should allow admin access without explicit permissions', async () => {
|
||||
// Set admin user
|
||||
setTestUser(app, testUsers.admin);
|
||||
// First, reset the app to remove previous middleware
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Set admin user BEFORE adding routes
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.admin.toObject(),
|
||||
id: testUsers.admin._id.toString(),
|
||||
_id: testUsers.admin._id,
|
||||
name: testUsers.admin.name,
|
||||
role: testUsers.admin.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
|
||||
// Now add the routes
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
|
||||
console.log('Admin user:', testUsers.admin);
|
||||
console.log('Admin role:', testUsers.admin.role);
|
||||
console.log('SystemRoles.ADMIN:', SystemRoles.ADMIN);
|
||||
|
||||
const response = await request(app).get(`/api/prompts/${testPrompt._id}`).expect(200);
|
||||
|
||||
@@ -404,8 +432,21 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
grantedBy: testUsers.editor._id,
|
||||
});
|
||||
|
||||
// Set viewer user
|
||||
setTestUser(app, testUsers.viewer);
|
||||
// Recreate app with viewer user
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.viewer.toObject(),
|
||||
id: testUsers.viewer._id.toString(),
|
||||
_id: testUsers.viewer._id,
|
||||
name: testUsers.viewer.name,
|
||||
role: testUsers.viewer.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
|
||||
await request(app)
|
||||
.delete(`/api/prompts/${authorPrompt._id}`)
|
||||
@@ -458,8 +499,21 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Ensure owner user
|
||||
setTestUser(app, testUsers.owner);
|
||||
// Recreate app to ensure fresh middleware
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.owner.toObject(),
|
||||
id: testUsers.owner._id.toString(),
|
||||
_id: testUsers.owner._id,
|
||||
name: testUsers.owner.name,
|
||||
role: testUsers.owner.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
|
||||
const response = await request(app)
|
||||
.patch(`/api/prompts/${testPrompt._id}/tags/production`)
|
||||
@@ -483,8 +537,21 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Set viewer user
|
||||
setTestUser(app, testUsers.viewer);
|
||||
// Recreate app with viewer user
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.viewer.toObject(),
|
||||
id: testUsers.viewer._id.toString(),
|
||||
_id: testUsers.viewer._id,
|
||||
name: testUsers.viewer.name,
|
||||
role: testUsers.viewer.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
|
||||
await request(app).patch(`/api/prompts/${testPrompt._id}/tags/production`).expect(403);
|
||||
|
||||
@@ -543,305 +610,4 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
expect(response.body._id).toBe(publicPrompt._id.toString());
|
||||
});
|
||||
});
|
||||
|
||||
describe('Pagination', () => {
|
||||
beforeEach(async () => {
|
||||
// Create multiple prompt groups for pagination testing
|
||||
const groups = [];
|
||||
for (let i = 0; i < 15; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Test Group ${i + 1}`,
|
||||
category: 'pagination-test',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - i * 1000), // Stagger updatedAt for consistent ordering
|
||||
});
|
||||
groups.push(group);
|
||||
|
||||
// Grant owner permissions on each group
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('should correctly indicate hasMore when there are more pages', async () => {
|
||||
const response = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10' })
|
||||
.expect(200);
|
||||
|
||||
expect(response.body.promptGroups).toHaveLength(10);
|
||||
expect(response.body.has_more).toBe(true);
|
||||
expect(response.body.after).toBeTruthy();
|
||||
// Since has_more is true, pages should be a high number (9999 in our fix)
|
||||
expect(parseInt(response.body.pages)).toBeGreaterThan(1);
|
||||
});
|
||||
|
||||
it('should correctly indicate no more pages on the last page', async () => {
|
||||
// First get the cursor for page 2
|
||||
const firstPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10' })
|
||||
.expect(200);
|
||||
|
||||
expect(firstPage.body.has_more).toBe(true);
|
||||
expect(firstPage.body.after).toBeTruthy();
|
||||
|
||||
// Now fetch the second page using the cursor
|
||||
const response = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10', cursor: firstPage.body.after })
|
||||
.expect(200);
|
||||
|
||||
expect(response.body.promptGroups).toHaveLength(5); // 15 total, 10 on page 1, 5 on page 2
|
||||
expect(response.body.has_more).toBe(false);
|
||||
});
|
||||
|
||||
it('should support cursor-based pagination', async () => {
|
||||
// First page
|
||||
const firstPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5' })
|
||||
.expect(200);
|
||||
|
||||
expect(firstPage.body.promptGroups).toHaveLength(5);
|
||||
expect(firstPage.body.has_more).toBe(true);
|
||||
expect(firstPage.body.after).toBeTruthy();
|
||||
|
||||
// Second page using cursor
|
||||
const secondPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', cursor: firstPage.body.after })
|
||||
.expect(200);
|
||||
|
||||
expect(secondPage.body.promptGroups).toHaveLength(5);
|
||||
expect(secondPage.body.has_more).toBe(true);
|
||||
expect(secondPage.body.after).toBeTruthy();
|
||||
|
||||
// Verify different groups
|
||||
const firstPageIds = firstPage.body.promptGroups.map((g) => g._id);
|
||||
const secondPageIds = secondPage.body.promptGroups.map((g) => g._id);
|
||||
expect(firstPageIds).not.toEqual(secondPageIds);
|
||||
});
|
||||
|
||||
it('should paginate correctly with category filtering', async () => {
|
||||
// Create groups with different categories
|
||||
await PromptGroup.deleteMany({}); // Clear existing groups
|
||||
await AclEntry.deleteMany({});
|
||||
|
||||
// Create 8 groups with category 'test-cat-1'
|
||||
for (let i = 0; i < 8; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Category 1 Group ${i + 1}`,
|
||||
category: 'test-cat-1',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - i * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Create 7 groups with category 'test-cat-2'
|
||||
for (let i = 0; i < 7; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Category 2 Group ${i + 1}`,
|
||||
category: 'test-cat-2',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - (i + 8) * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Test pagination with category filter
|
||||
const firstPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', category: 'test-cat-1' })
|
||||
.expect(200);
|
||||
|
||||
expect(firstPage.body.promptGroups).toHaveLength(5);
|
||||
expect(firstPage.body.promptGroups.every((g) => g.category === 'test-cat-1')).toBe(true);
|
||||
expect(firstPage.body.has_more).toBe(true);
|
||||
expect(firstPage.body.after).toBeTruthy();
|
||||
|
||||
const secondPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', cursor: firstPage.body.after, category: 'test-cat-1' })
|
||||
.expect(200);
|
||||
|
||||
expect(secondPage.body.promptGroups).toHaveLength(3); // 8 total, 5 on page 1, 3 on page 2
|
||||
expect(secondPage.body.promptGroups.every((g) => g.category === 'test-cat-1')).toBe(true);
|
||||
expect(secondPage.body.has_more).toBe(false);
|
||||
});
|
||||
|
||||
it('should paginate correctly with name/keyword filtering', async () => {
|
||||
// Create groups with specific names
|
||||
await PromptGroup.deleteMany({}); // Clear existing groups
|
||||
await AclEntry.deleteMany({});
|
||||
|
||||
// Create 12 groups with 'Search' in the name
|
||||
for (let i = 0; i < 12; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Search Test Group ${i + 1}`,
|
||||
category: 'search-test',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - i * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Create 5 groups without 'Search' in the name
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Other Group ${i + 1}`,
|
||||
category: 'other-test',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - (i + 12) * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Test pagination with name filter
|
||||
const firstPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10', name: 'Search' })
|
||||
.expect(200);
|
||||
|
||||
expect(firstPage.body.promptGroups).toHaveLength(10);
|
||||
expect(firstPage.body.promptGroups.every((g) => g.name.includes('Search'))).toBe(true);
|
||||
expect(firstPage.body.has_more).toBe(true);
|
||||
expect(firstPage.body.after).toBeTruthy();
|
||||
|
||||
const secondPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10', cursor: firstPage.body.after, name: 'Search' })
|
||||
.expect(200);
|
||||
|
||||
expect(secondPage.body.promptGroups).toHaveLength(2); // 12 total, 10 on page 1, 2 on page 2
|
||||
expect(secondPage.body.promptGroups.every((g) => g.name.includes('Search'))).toBe(true);
|
||||
expect(secondPage.body.has_more).toBe(false);
|
||||
});
|
||||
|
||||
it('should paginate correctly with combined filters', async () => {
|
||||
// Create groups with various combinations
|
||||
await PromptGroup.deleteMany({}); // Clear existing groups
|
||||
await AclEntry.deleteMany({});
|
||||
|
||||
// Create 6 groups matching both category and name filters
|
||||
for (let i = 0; i < 6; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `API Test Group ${i + 1}`,
|
||||
category: 'api-category',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - i * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Create groups that only match one filter
|
||||
for (let i = 0; i < 4; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `API Other Group ${i + 1}`,
|
||||
category: 'other-category',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - (i + 6) * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Test pagination with both filters
|
||||
const response = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', name: 'API', category: 'api-category' })
|
||||
.expect(200);
|
||||
|
||||
expect(response.body.promptGroups).toHaveLength(5);
|
||||
expect(
|
||||
response.body.promptGroups.every(
|
||||
(g) => g.name.includes('API') && g.category === 'api-category',
|
||||
),
|
||||
).toBe(true);
|
||||
expect(response.body.has_more).toBe(true);
|
||||
expect(response.body.after).toBeTruthy();
|
||||
|
||||
// Page 2
|
||||
const page2 = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', cursor: response.body.after, name: 'API', category: 'api-category' })
|
||||
.expect(200);
|
||||
|
||||
expect(page2.body.promptGroups).toHaveLength(1); // 6 total, 5 on page 1, 1 on page 2
|
||||
expect(page2.body.has_more).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -49,7 +49,6 @@ const AppService = async () => {
|
||||
enabled: isEnabled(process.env.CHECK_BALANCE),
|
||||
startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
|
||||
};
|
||||
const transactions = config.transactions ?? configDefaults.transactions;
|
||||
const imageOutputType = config?.imageOutputType ?? configDefaults.imageOutputType;
|
||||
|
||||
process.env.CDN_PROVIDER = fileStrategy;
|
||||
@@ -85,7 +84,6 @@ const AppService = async () => {
|
||||
memory,
|
||||
speech,
|
||||
balance,
|
||||
transactions,
|
||||
mcpConfig,
|
||||
webSearch,
|
||||
fileStrategy,
|
||||
|
||||
@@ -402,10 +402,9 @@ const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
* @param {import('openid-client').TokenEndpointResponse & import('openid-client').TokenEndpointResponseHelpers} tokenset
|
||||
* - The tokenset object containing access and refresh tokens
|
||||
* @param {Object} res - response object
|
||||
* @param {string} [userId] - Optional MongoDB user ID for image path validation
|
||||
* @returns {String} - access token
|
||||
*/
|
||||
const setOpenIDAuthTokens = (tokenset, res, userId) => {
|
||||
const setOpenIDAuthTokens = (tokenset, res) => {
|
||||
try {
|
||||
if (!tokenset) {
|
||||
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
|
||||
@@ -436,18 +435,6 @@ const setOpenIDAuthTokens = (tokenset, res, userId) => {
|
||||
secure: isProduction,
|
||||
sameSite: 'strict',
|
||||
});
|
||||
if (userId && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
/** JWT-signed user ID cookie for image path validation when OPENID_REUSE_TOKENS is enabled */
|
||||
const signedUserId = jwt.sign({ id: userId }, process.env.JWT_REFRESH_SECRET, {
|
||||
expiresIn: expiryInMilliseconds / 1000,
|
||||
});
|
||||
res.cookie('openid_user_id', signedUserId, {
|
||||
expires: expirationDate,
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'strict',
|
||||
});
|
||||
}
|
||||
return tokenset.access_token;
|
||||
} catch (error) {
|
||||
logger.error('[setOpenIDAuthTokens] Error in setting authentication tokens:', error);
|
||||
@@ -465,7 +452,7 @@ const setOpenIDAuthTokens = (tokenset, res, userId) => {
|
||||
const resendVerificationEmail = async (req) => {
|
||||
try {
|
||||
const { email } = req.body;
|
||||
await deleteTokens({ email });
|
||||
await deleteTokens(email);
|
||||
const user = await findUser({ email }, 'email _id name');
|
||||
|
||||
if (!user) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const {
|
||||
primeResources,
|
||||
getModelMaxTokens,
|
||||
extractLibreChatParams,
|
||||
optionalChainWithEmptyCheck,
|
||||
} = require('@librechat/api');
|
||||
@@ -18,6 +17,7 @@ const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { getFiles, getToolFilesByIds } = require('~/models/File');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { AnthropicClientOptions } from '@librechat/agents';
|
||||
import { EModelEndpoint, anthropicSettings } from 'librechat-data-provider';
|
||||
import { matchModelName } from '~/utils/tokens';
|
||||
const { EModelEndpoint, anthropicSettings } = require('librechat-data-provider');
|
||||
const { matchModelName } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* @param {string} modelName
|
||||
* @returns {boolean}
|
||||
*/
|
||||
function checkPromptCacheSupport(modelName: string): boolean {
|
||||
const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic) ?? '';
|
||||
function checkPromptCacheSupport(modelName) {
|
||||
const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic);
|
||||
if (
|
||||
modelMatch.includes('claude-3-5-sonnet-latest') ||
|
||||
modelMatch.includes('claude-3.5-sonnet-latest')
|
||||
@@ -32,10 +31,7 @@ function checkPromptCacheSupport(modelName: string): boolean {
|
||||
* @param {boolean} supportsCacheControl Whether the model supports cache control
|
||||
* @returns {AnthropicClientOptions['extendedOptions']['defaultHeaders']|undefined} The headers object or undefined if not applicable
|
||||
*/
|
||||
function getClaudeHeaders(
|
||||
model: string,
|
||||
supportsCacheControl: boolean,
|
||||
): Record<string, string> | undefined {
|
||||
function getClaudeHeaders(model, supportsCacheControl) {
|
||||
if (!supportsCacheControl) {
|
||||
return undefined;
|
||||
}
|
||||
@@ -76,13 +72,9 @@ function getClaudeHeaders(
|
||||
* @param {number|null} extendedOptions.thinkingBudget The token budget for thinking
|
||||
* @returns {Object} Updated request options
|
||||
*/
|
||||
function configureReasoning(
|
||||
anthropicInput: AnthropicClientOptions & { max_tokens?: number },
|
||||
extendedOptions: { thinking?: boolean; thinkingBudget?: number | null } = {},
|
||||
): AnthropicClientOptions & { max_tokens?: number } {
|
||||
function configureReasoning(anthropicInput, extendedOptions = {}) {
|
||||
const updatedOptions = { ...anthropicInput };
|
||||
const currentMaxTokens = updatedOptions.max_tokens ?? updatedOptions.maxTokens;
|
||||
|
||||
if (
|
||||
extendedOptions.thinking &&
|
||||
updatedOptions?.model &&
|
||||
@@ -90,16 +82,11 @@ function configureReasoning(
|
||||
/claude-(?:sonnet|opus|haiku)-[4-9]/.test(updatedOptions.model))
|
||||
) {
|
||||
updatedOptions.thinking = {
|
||||
...updatedOptions.thinking,
|
||||
type: 'enabled',
|
||||
} as { type: 'enabled'; budget_tokens: number };
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
updatedOptions.thinking != null &&
|
||||
extendedOptions.thinkingBudget != null &&
|
||||
updatedOptions.thinking.type === 'enabled'
|
||||
) {
|
||||
if (updatedOptions.thinking != null && extendedOptions.thinkingBudget != null) {
|
||||
updatedOptions.thinking = {
|
||||
...updatedOptions.thinking,
|
||||
budget_tokens: extendedOptions.thinkingBudget,
|
||||
@@ -108,10 +95,9 @@ function configureReasoning(
|
||||
|
||||
if (
|
||||
updatedOptions.thinking != null &&
|
||||
updatedOptions.thinking.type === 'enabled' &&
|
||||
(currentMaxTokens == null || updatedOptions.thinking.budget_tokens > currentMaxTokens)
|
||||
) {
|
||||
const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model ?? '');
|
||||
const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model);
|
||||
updatedOptions.max_tokens = currentMaxTokens ?? maxTokens;
|
||||
|
||||
logger.warn(
|
||||
@@ -129,4 +115,4 @@ function configureReasoning(
|
||||
return updatedOptions;
|
||||
}
|
||||
|
||||
export { checkPromptCacheSupport, getClaudeHeaders, configureReasoning };
|
||||
module.exports = { checkPromptCacheSupport, getClaudeHeaders, configureReasoning };
|
||||
@@ -1,6 +1,6 @@
|
||||
const { getLLMConfig } = require('@librechat/api');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
|
||||
const AnthropicClient = require('~/app/clients/AnthropicClient');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
import { ProxyAgent } from 'undici';
|
||||
import { AnthropicClientOptions } from '@librechat/agents';
|
||||
import { anthropicSettings, removeNullishValues } from 'librechat-data-provider';
|
||||
import type {
|
||||
AnthropicConfigOptions,
|
||||
AnthropicLLMConfigResult,
|
||||
AnthropicParameters,
|
||||
} from '~/types/anthropic';
|
||||
import { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } from './helpers';
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
|
||||
const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');
|
||||
|
||||
/**
|
||||
* Generates configuration options for creating an Anthropic language model (LLM) instance.
|
||||
@@ -27,42 +21,25 @@ import { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } from '.
|
||||
*
|
||||
* @returns {Object} Configuration options for creating an Anthropic LLM instance, with null and undefined values removed.
|
||||
*/
|
||||
function getLLMConfig(
|
||||
apiKey?: string,
|
||||
options: AnthropicConfigOptions = {} as AnthropicConfigOptions,
|
||||
): AnthropicLLMConfigResult {
|
||||
function getLLMConfig(apiKey, options = {}) {
|
||||
const systemOptions = {
|
||||
thinking: options.modelOptions?.thinking ?? anthropicSettings.thinking.default,
|
||||
promptCache: options.modelOptions?.promptCache ?? anthropicSettings.promptCache.default,
|
||||
thinkingBudget:
|
||||
options.modelOptions?.thinkingBudget ?? anthropicSettings.thinkingBudget.default,
|
||||
thinking: options.modelOptions.thinking ?? anthropicSettings.thinking.default,
|
||||
promptCache: options.modelOptions.promptCache ?? anthropicSettings.promptCache.default,
|
||||
thinkingBudget: options.modelOptions.thinkingBudget ?? anthropicSettings.thinkingBudget.default,
|
||||
};
|
||||
|
||||
/** Couldn't figure out a way to still loop through the object while deleting the overlapping keys when porting this
|
||||
* over from javascript, so for now they are being deleted manually until a better way presents itself.
|
||||
*/
|
||||
if (options.modelOptions) {
|
||||
delete options.modelOptions.thinking;
|
||||
delete options.modelOptions.promptCache;
|
||||
delete options.modelOptions.thinkingBudget;
|
||||
} else {
|
||||
throw new Error('No modelOptions provided');
|
||||
for (let key in systemOptions) {
|
||||
delete options.modelOptions[key];
|
||||
}
|
||||
|
||||
const defaultOptions = {
|
||||
model: anthropicSettings.model.default,
|
||||
maxOutputTokens: anthropicSettings.maxOutputTokens.default,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
const mergedOptions = Object.assign(
|
||||
defaultOptions,
|
||||
options.modelOptions,
|
||||
) as typeof defaultOptions &
|
||||
Partial<AnthropicParameters> & { stop?: string[]; web_search?: boolean };
|
||||
const mergedOptions = Object.assign(defaultOptions, options.modelOptions);
|
||||
|
||||
/** @type {AnthropicClientOptions} */
|
||||
let requestOptions: AnthropicClientOptions & { stream?: boolean } = {
|
||||
let requestOptions = {
|
||||
apiKey,
|
||||
model: mergedOptions.model,
|
||||
stream: mergedOptions.stream,
|
||||
@@ -89,20 +66,20 @@ function getLLMConfig(
|
||||
}
|
||||
|
||||
const supportsCacheControl =
|
||||
systemOptions.promptCache === true && checkPromptCacheSupport(requestOptions.model ?? '');
|
||||
const headers = getClaudeHeaders(requestOptions.model ?? '', supportsCacheControl);
|
||||
if (headers && requestOptions.clientOptions) {
|
||||
systemOptions.promptCache === true && checkPromptCacheSupport(requestOptions.model);
|
||||
const headers = getClaudeHeaders(requestOptions.model, supportsCacheControl);
|
||||
if (headers) {
|
||||
requestOptions.clientOptions.defaultHeaders = headers;
|
||||
}
|
||||
|
||||
if (options.proxy && requestOptions.clientOptions) {
|
||||
if (options.proxy) {
|
||||
const proxyAgent = new ProxyAgent(options.proxy);
|
||||
requestOptions.clientOptions.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
if (options.reverseProxyUrl && requestOptions.clientOptions) {
|
||||
if (options.reverseProxyUrl) {
|
||||
requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
|
||||
requestOptions.anthropicApiUrl = options.reverseProxyUrl;
|
||||
}
|
||||
@@ -119,10 +96,8 @@ function getLLMConfig(
|
||||
return {
|
||||
tools,
|
||||
/** @type {AnthropicClientOptions} */
|
||||
llmConfig: removeNullishValues(
|
||||
requestOptions as Record<string, unknown>,
|
||||
) as AnthropicClientOptions,
|
||||
llmConfig: removeNullishValues(requestOptions),
|
||||
};
|
||||
}
|
||||
|
||||
export { getLLMConfig };
|
||||
module.exports = { getLLMConfig };
|
||||
@@ -1,4 +1,4 @@
|
||||
import { getLLMConfig } from './llm';
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
|
||||
|
||||
jest.mock('https-proxy-agent', () => ({
|
||||
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
|
||||
@@ -25,9 +25,9 @@ describe('getLLMConfig', () => {
|
||||
});
|
||||
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
|
||||
expect(result.llmConfig.clientOptions?.fetchOptions).toHaveProperty('dispatcher');
|
||||
expect(result.llmConfig.clientOptions?.fetchOptions?.dispatcher).toBeDefined();
|
||||
expect(result.llmConfig.clientOptions?.fetchOptions?.dispatcher.constructor.name).toBe(
|
||||
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
|
||||
'ProxyAgent',
|
||||
);
|
||||
});
|
||||
@@ -93,10 +93,9 @@ describe('getLLMConfig', () => {
|
||||
};
|
||||
const result = getLLMConfig('test-key', { modelOptions });
|
||||
const clientOptions = result.llmConfig.clientOptions;
|
||||
expect(clientOptions?.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions?.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
const defaultHeaders = clientOptions?.defaultHeaders as Record<string, string>;
|
||||
expect(defaultHeaders['anthropic-beta']).toBe(
|
||||
expect(clientOptions.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
@@ -112,10 +111,9 @@ describe('getLLMConfig', () => {
|
||||
const modelOptions = { model, promptCache: true };
|
||||
const result = getLLMConfig('test-key', { modelOptions });
|
||||
const clientOptions = result.llmConfig.clientOptions;
|
||||
expect(clientOptions?.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions?.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
const defaultHeaders = clientOptions?.defaultHeaders as Record<string, string>;
|
||||
expect(defaultHeaders['anthropic-beta']).toBe(
|
||||
expect(clientOptions.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
@@ -213,13 +211,13 @@ describe('getLLMConfig', () => {
|
||||
it('should handle empty modelOptions', () => {
|
||||
expect(() => {
|
||||
getLLMConfig('test-api-key', {});
|
||||
}).toThrow('No modelOptions provided');
|
||||
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
|
||||
});
|
||||
|
||||
it('should handle no options parameter', () => {
|
||||
expect(() => {
|
||||
getLLMConfig('test-api-key');
|
||||
}).toThrow('No modelOptions provided');
|
||||
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
|
||||
});
|
||||
|
||||
it('should handle temperature, stop sequences, and stream settings', () => {
|
||||
@@ -256,9 +254,9 @@ describe('getLLMConfig', () => {
|
||||
});
|
||||
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
|
||||
expect(result.llmConfig.clientOptions?.fetchOptions).toHaveProperty('dispatcher');
|
||||
expect(result.llmConfig.clientOptions?.fetchOptions?.dispatcher).toBeDefined();
|
||||
expect(result.llmConfig.clientOptions?.fetchOptions?.dispatcher.constructor.name).toBe(
|
||||
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
|
||||
'ProxyAgent',
|
||||
);
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'https://reverse-proxy.com');
|
||||
@@ -274,7 +272,7 @@ describe('getLLMConfig', () => {
|
||||
});
|
||||
|
||||
// claude-3-5-sonnet supports prompt caching and should get the appropriate headers
|
||||
expect(result.llmConfig.clientOptions?.defaultHeaders).toEqual({
|
||||
expect(result.llmConfig.clientOptions.defaultHeaders).toEqual({
|
||||
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31',
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,3 @@
|
||||
const { getModelMaxTokens } = require('@librechat/api');
|
||||
const { createContentAggregator } = require('@librechat/agents');
|
||||
const {
|
||||
EModelEndpoint,
|
||||
@@ -8,6 +7,7 @@ const {
|
||||
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
|
||||
const getOptions = require('~/server/services/Endpoints/bedrock/options');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
if (!endpointOption) {
|
||||
|
||||
@@ -110,7 +110,7 @@ class STTService {
|
||||
*/
|
||||
async getProviderSchema(req) {
|
||||
const appConfig =
|
||||
req.config ??
|
||||
req?.config ??
|
||||
(await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
}));
|
||||
@@ -159,11 +159,9 @@ class STTService {
|
||||
* Prepares the request for the OpenAI STT provider.
|
||||
* @param {Object} sttSchema - The STT schema for OpenAI.
|
||||
* @param {Stream} audioReadStream - The audio data to be transcribed.
|
||||
* @param {Object} audioFile - The audio file object (unused in OpenAI provider).
|
||||
* @param {string} language - The language code for the transcription.
|
||||
* @returns {Array} An array containing the URL, data, and headers for the request.
|
||||
*/
|
||||
openAIProvider(sttSchema, audioReadStream, audioFile, language) {
|
||||
openAIProvider(sttSchema, audioReadStream) {
|
||||
const url = sttSchema?.url || 'https://api.openai.com/v1/audio/transcriptions';
|
||||
const apiKey = extractEnvVariable(sttSchema.apiKey) || '';
|
||||
|
||||
@@ -172,12 +170,6 @@ class STTService {
|
||||
model: sttSchema.model,
|
||||
};
|
||||
|
||||
if (language) {
|
||||
/** Converted locale code (e.g., "en-US") to ISO-639-1 format (e.g., "en") */
|
||||
const isoLanguage = language.split('-')[0];
|
||||
data.language = isoLanguage;
|
||||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
...(apiKey && { Authorization: `Bearer ${apiKey}` }),
|
||||
@@ -192,11 +184,10 @@ class STTService {
|
||||
* @param {Object} sttSchema - The STT schema for Azure OpenAI.
|
||||
* @param {Buffer} audioBuffer - The audio data to be transcribed.
|
||||
* @param {Object} audioFile - The audio file object containing originalname, mimetype, and size.
|
||||
* @param {string} language - The language code for the transcription.
|
||||
* @returns {Array} An array containing the URL, data, and headers for the request.
|
||||
* @throws {Error} If the audio file size exceeds 25MB or the audio file format is not accepted.
|
||||
*/
|
||||
azureOpenAIProvider(sttSchema, audioBuffer, audioFile, language) {
|
||||
azureOpenAIProvider(sttSchema, audioBuffer, audioFile) {
|
||||
const url = `${genAzureEndpoint({
|
||||
azureOpenAIApiInstanceName: extractEnvVariable(sttSchema?.instanceName),
|
||||
azureOpenAIApiDeploymentName: extractEnvVariable(sttSchema?.deploymentName),
|
||||
@@ -220,12 +211,6 @@ class STTService {
|
||||
contentType: audioFile.mimetype,
|
||||
});
|
||||
|
||||
if (language) {
|
||||
/** Converted locale code (e.g., "en-US") to ISO-639-1 format (e.g., "en") */
|
||||
const isoLanguage = language.split('-')[0];
|
||||
formData.append('language', isoLanguage);
|
||||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
...(apiKey && { 'api-key': apiKey }),
|
||||
@@ -244,11 +229,10 @@ class STTService {
|
||||
* @param {Object} requestData - The data required for the STT request.
|
||||
* @param {Buffer} requestData.audioBuffer - The audio data to be transcribed.
|
||||
* @param {Object} requestData.audioFile - The audio file object containing originalname, mimetype, and size.
|
||||
* @param {string} requestData.language - The language code for the transcription.
|
||||
* @returns {Promise<string>} A promise that resolves to the transcribed text.
|
||||
* @throws {Error} If the provider is invalid, the response status is not 200, or the response data is missing.
|
||||
*/
|
||||
async sttRequest(provider, sttSchema, { audioBuffer, audioFile, language }) {
|
||||
async sttRequest(provider, sttSchema, { audioBuffer, audioFile }) {
|
||||
const strategy = this.providerStrategies[provider];
|
||||
if (!strategy) {
|
||||
throw new Error('Invalid provider');
|
||||
@@ -259,13 +243,7 @@ class STTService {
|
||||
const audioReadStream = Readable.from(audioBuffer);
|
||||
audioReadStream.path = `audio.${fileExtension}`;
|
||||
|
||||
const [url, data, headers] = strategy.call(
|
||||
this,
|
||||
sttSchema,
|
||||
audioReadStream,
|
||||
audioFile,
|
||||
language,
|
||||
);
|
||||
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);
|
||||
|
||||
try {
|
||||
const response = await axios.post(url, data, { headers });
|
||||
@@ -306,8 +284,7 @@ class STTService {
|
||||
|
||||
try {
|
||||
const [provider, sttSchema] = await this.getProviderSchema(req);
|
||||
const language = req.body?.language || '';
|
||||
const text = await this.sttRequest(provider, sttSchema, { audioBuffer, audioFile, language });
|
||||
const text = await this.sttRequest(provider, sttSchema, { audioBuffer, audioFile });
|
||||
res.json({ text });
|
||||
} catch (error) {
|
||||
logger.error('An error occurred while processing the audio:', error);
|
||||
|
||||
@@ -17,7 +17,7 @@ const { Files } = require('~/models');
|
||||
* @param {IUser} options.user - The user object
|
||||
* @param {AppConfig} options.appConfig - The app configuration object
|
||||
* @param {GraphRunnableConfig['configurable']} options.metadata - The metadata
|
||||
* @param {{ [Tools.file_search]: { sources: Object[]; fileCitations: boolean } }} options.toolArtifact - The tool artifact containing structured data
|
||||
* @param {any} options.toolArtifact - The tool artifact containing structured data
|
||||
* @param {string} options.toolCallId - The tool call ID
|
||||
* @returns {Promise<Object|null>} The file search attachment or null
|
||||
*/
|
||||
@@ -29,14 +29,12 @@ async function processFileCitations({ user, appConfig, toolArtifact, toolCallId,
|
||||
|
||||
if (user) {
|
||||
try {
|
||||
const hasFileCitationsAccess =
|
||||
toolArtifact?.[Tools.file_search]?.fileCitations ??
|
||||
(await checkAccess({
|
||||
user,
|
||||
permissionType: PermissionTypes.FILE_CITATIONS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
}));
|
||||
const hasFileCitationsAccess = await checkAccess({
|
||||
user,
|
||||
permissionType: PermissionTypes.FILE_CITATIONS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
if (!hasFileCitationsAccess) {
|
||||
logger.debug(
|
||||
|
||||
@@ -10,10 +10,9 @@ const { getAgent } = require('~/models/Agent');
|
||||
* @param {string} [params.role] - Optional user role to avoid DB query
|
||||
* @param {string[]} params.fileIds - Array of file IDs to check
|
||||
* @param {string} params.agentId - The agent ID that might grant access
|
||||
* @param {boolean} [params.isDelete] - Whether the operation is a delete operation
|
||||
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
|
||||
*/
|
||||
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => {
|
||||
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId }) => {
|
||||
const accessMap = new Map();
|
||||
|
||||
// Initialize all files as no access
|
||||
@@ -45,23 +44,22 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
if (isDelete) {
|
||||
// Check if user has EDIT permission (which would indicate collaborative access)
|
||||
const hasEditPermission = await checkPermission({
|
||||
userId,
|
||||
role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
// Check if user has EDIT permission (which would indicate collaborative access)
|
||||
const hasEditPermission = await checkPermission({
|
||||
userId,
|
||||
role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
// If user only has VIEW permission, they can't access files
|
||||
// Only users with EDIT permission or higher can access agent files
|
||||
if (!hasEditPermission) {
|
||||
return accessMap;
|
||||
}
|
||||
// If user only has VIEW permission, they can't access files
|
||||
// Only users with EDIT permission or higher can access agent files
|
||||
if (!hasEditPermission) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// User has edit permissions - check which files are actually attached
|
||||
const attachedFileIds = new Set();
|
||||
if (agent.tool_resources) {
|
||||
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
|
||||
|
||||
@@ -616,7 +616,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
||||
|
||||
if (shouldUseSTT) {
|
||||
const sttService = await STTService.getInstance();
|
||||
const { text, bytes } = await processAudioFile({ req, file, sttService });
|
||||
const { text, bytes } = await processAudioFile({ file, sttService });
|
||||
return await createTextFile({ text, bytes });
|
||||
}
|
||||
|
||||
@@ -646,8 +646,8 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
||||
req,
|
||||
file,
|
||||
file_id,
|
||||
basePath,
|
||||
entity_id,
|
||||
basePath,
|
||||
});
|
||||
|
||||
// SECOND: Upload to Vector DB
|
||||
@@ -670,18 +670,17 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
||||
req,
|
||||
file,
|
||||
file_id,
|
||||
basePath,
|
||||
entity_id,
|
||||
basePath,
|
||||
});
|
||||
}
|
||||
|
||||
let { bytes, filename, filepath: _filepath, height, width } = storageResult;
|
||||
const { bytes, filename, filepath: _filepath, height, width } = storageResult;
|
||||
// For RAG files, use embedding result; for others, use storage result
|
||||
let embedded = storageResult.embedded;
|
||||
if (tool_resource === EToolResources.file_search) {
|
||||
embedded = embeddingResult?.embedded;
|
||||
filename = embeddingResult?.filename || filename;
|
||||
}
|
||||
const embedded =
|
||||
tool_resource === EToolResources.file_search
|
||||
? embeddingResult?.embedded
|
||||
: storageResult.embedded;
|
||||
|
||||
let filepath = _filepath;
|
||||
|
||||
@@ -930,7 +929,6 @@ async function saveBase64Image(
|
||||
url,
|
||||
{ req, file_id: _file_id, filename: _filename, endpoint, context, resolution },
|
||||
) {
|
||||
const appConfig = req.config;
|
||||
const effectiveResolution = resolution ?? appConfig.fileConfig?.imageGeneration ?? 'high';
|
||||
const file_id = _file_id ?? v4();
|
||||
let filename = `${file_id}-${_filename}`;
|
||||
@@ -945,6 +943,7 @@ async function saveBase64Image(
|
||||
}
|
||||
|
||||
const image = await resizeImageBuffer(inputBuffer, effectiveResolution, endpoint);
|
||||
const appConfig = req.config;
|
||||
const source = getFileStrategy(appConfig, { isImage: true });
|
||||
const { saveBuffer } = getStrategyFunctions(source);
|
||||
const filepath = await saveBuffer({
|
||||
|
||||
@@ -271,7 +271,6 @@ async function createMCPTool({
|
||||
availableTools: tools,
|
||||
}) {
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
|
||||
const availableTools =
|
||||
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
|
||||
/** @type {LCTool | undefined} */
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
const axios = require('axios');
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { logAxiosError, inputSchema, processModelData } = require('@librechat/api');
|
||||
const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
|
||||
const { inputSchema, extractBaseURL, processModelData } = require('~/utils');
|
||||
const { OllamaClient } = require('~/app/clients/OllamaClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
|
||||
/**
|
||||
* Splits a string by commas and trims each resulting value.
|
||||
|
||||
@@ -11,8 +11,8 @@ const {
|
||||
getAnthropicModels,
|
||||
} = require('./ModelService');
|
||||
|
||||
jest.mock('@librechat/api', () => {
|
||||
const originalUtils = jest.requireActual('@librechat/api');
|
||||
jest.mock('~/utils', () => {
|
||||
const originalUtils = jest.requireActual('~/utils');
|
||||
return {
|
||||
...originalUtils,
|
||||
processModelData: jest.fn((...args) => {
|
||||
@@ -108,7 +108,7 @@ describe('fetchModels with createTokenConfig true', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
// Clears the mock's history before each test
|
||||
const _utils = require('@librechat/api');
|
||||
const _utils = require('~/utils');
|
||||
axios.get.mockResolvedValue({ data });
|
||||
});
|
||||
|
||||
@@ -120,7 +120,7 @@ describe('fetchModels with createTokenConfig true', () => {
|
||||
createTokenConfig: true,
|
||||
});
|
||||
|
||||
const { processModelData } = require('@librechat/api');
|
||||
const { processModelData } = require('~/utils');
|
||||
expect(processModelData).toHaveBeenCalled();
|
||||
expect(processModelData).toHaveBeenCalledWith(data);
|
||||
});
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const fs = require('fs').promises;
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getImporter } = require('./importers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Job definition for importing a conversation.
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
const jwt = require('jsonwebtoken');
|
||||
const mongoose = require('mongoose');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Strategy: AppleStrategy } = require('passport-apple');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { createSocialUser, handleExistingUser } = require('./process');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const socialLogin = require('./socialLogin');
|
||||
const { findUser } = require('~/models');
|
||||
const { User } = require('~/db/models');
|
||||
@@ -17,8 +17,6 @@ jest.mock('@librechat/data-schemas', () => {
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
@@ -26,19 +24,12 @@ jest.mock('./process', () => ({
|
||||
createSocialUser: jest.fn(),
|
||||
handleExistingUser: jest.fn(),
|
||||
}));
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
jest.mock('~/server/utils', () => ({
|
||||
isEnabled: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models', () => ({
|
||||
findUser: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn().mockResolvedValue({
|
||||
fileStrategy: 'local',
|
||||
balance: { enabled: false },
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('Apple Login Strategy', () => {
|
||||
let mongoServer;
|
||||
@@ -297,14 +288,7 @@ describe('Apple Login Strategy', () => {
|
||||
|
||||
expect(mockVerifyCallback).toHaveBeenCalledWith(null, existingUser);
|
||||
expect(existingUser.avatarUrl).toBeNull(); // As per getProfileDetails
|
||||
expect(handleExistingUser).toHaveBeenCalledWith(
|
||||
existingUser,
|
||||
null,
|
||||
expect.objectContaining({
|
||||
fileStrategy: 'local',
|
||||
balance: { enabled: false },
|
||||
}),
|
||||
);
|
||||
expect(handleExistingUser).toHaveBeenCalledWith(existingUser, null);
|
||||
});
|
||||
|
||||
it('should handle missing idToken gracefully', async () => {
|
||||
|
||||
@@ -183,7 +183,7 @@ const getUserInfo = async (config, accessToken, sub) => {
|
||||
const exchangedAccessToken = await exchangeAccessTokenIfNeeded(config, accessToken, sub);
|
||||
return await client.fetchUserInfo(config, exchangedAccessToken, sub);
|
||||
} catch (error) {
|
||||
logger.error('[openidStrategy] getUserInfo: Error fetching user info:', error);
|
||||
logger.warn(`[openidStrategy] getUserInfo: Error fetching user info: ${error}`);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
@@ -398,7 +398,6 @@ async function setupOpenId() {
|
||||
);
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
if (!user) {
|
||||
user = {
|
||||
provider: 'openid',
|
||||
@@ -410,6 +409,7 @@ async function setupOpenId() {
|
||||
idOnTheSource: userinfo.oid,
|
||||
};
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
user = await createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
@@ -438,9 +438,7 @@ async function setupOpenId() {
|
||||
userinfo.sub,
|
||||
);
|
||||
if (imageBuffer) {
|
||||
const { saveBuffer } = getStrategyFunctions(
|
||||
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
|
||||
);
|
||||
const { saveBuffer } = getStrategyFunctions(process.env.CDN_PROVIDER);
|
||||
const imagePath = await saveBuffer({
|
||||
fileName,
|
||||
userId: user._id.toString(),
|
||||
|
||||
@@ -3,6 +3,7 @@ const { FileSources } = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
||||
const { updateUser, createUser, getUserById } = require('~/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
||||
/**
|
||||
* Updates the avatar URL of an existing user. If the user's avatar URL does not include the query parameter
|
||||
@@ -11,15 +12,14 @@ const { updateUser, createUser, getUserById } = require('~/models');
|
||||
*
|
||||
* @param {IUser} oldUser - The existing user object that needs to be updated.
|
||||
* @param {string} avatarUrl - The new avatar URL to be set for the user.
|
||||
* @param {AppConfig} appConfig - The application configuration object.
|
||||
*
|
||||
* @returns {Promise<void>}
|
||||
* The function updates the user's avatar and saves the user object. It does not return any value.
|
||||
*
|
||||
* @throws {Error} Throws an error if there's an issue saving the updated user object.
|
||||
*/
|
||||
const handleExistingUser = async (oldUser, avatarUrl, appConfig) => {
|
||||
const fileStrategy = appConfig?.fileStrategy ?? process.env.CDN_PROVIDER;
|
||||
const handleExistingUser = async (oldUser, avatarUrl) => {
|
||||
const fileStrategy = process.env.CDN_PROVIDER;
|
||||
const isLocal = fileStrategy === FileSources.local;
|
||||
|
||||
let updatedAvatar = false;
|
||||
@@ -56,7 +56,6 @@ const handleExistingUser = async (oldUser, avatarUrl, appConfig) => {
|
||||
* @param {string} params.providerId - The provider-specific ID of the user.
|
||||
* @param {string} params.username - The username of the new user.
|
||||
* @param {string} params.name - The name of the new user.
|
||||
* @param {AppConfig} appConfig - The application configuration object.
|
||||
* @param {boolean} [params.emailVerified=false] - Optional. Indicates whether the user's email is verified. Defaults to false.
|
||||
*
|
||||
* @returns {Promise<User>}
|
||||
@@ -72,7 +71,6 @@ const createSocialUser = async ({
|
||||
providerId,
|
||||
username,
|
||||
name,
|
||||
appConfig,
|
||||
emailVerified,
|
||||
}) => {
|
||||
const update = {
|
||||
@@ -85,9 +83,10 @@ const createSocialUser = async ({
|
||||
emailVerified,
|
||||
};
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
const newUserId = await createUser(update, balanceConfig);
|
||||
const fileStrategy = appConfig?.fileStrategy ?? process.env.CDN_PROVIDER;
|
||||
const fileStrategy = process.env.CDN_PROVIDER;
|
||||
const isLocal = fileStrategy === FileSources.local;
|
||||
|
||||
if (!isLocal) {
|
||||
|
||||
@@ -220,7 +220,6 @@ async function setupSaml() {
|
||||
getUserName(profile) || getGivenName(profile) || getEmail(profile),
|
||||
);
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
if (!user) {
|
||||
user = {
|
||||
provider: 'saml',
|
||||
@@ -230,6 +229,7 @@ async function setupSaml() {
|
||||
emailVerified: true,
|
||||
name: fullName,
|
||||
};
|
||||
const appConfig = await getAppConfig();
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
user = await createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
@@ -250,9 +250,7 @@ async function setupSaml() {
|
||||
fileName = profile.nameID + '.png';
|
||||
}
|
||||
|
||||
const { saveBuffer } = getStrategyFunctions(
|
||||
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
|
||||
);
|
||||
const { saveBuffer } = getStrategyFunctions(process.env.CDN_PROVIDER);
|
||||
const imagePath = await saveBuffer({
|
||||
fileName,
|
||||
userId: user._id.toString(),
|
||||
|
||||
@@ -2,7 +2,6 @@ const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { createSocialUser, handleExistingUser } = require('./process');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { findUser } = require('~/models');
|
||||
|
||||
const socialLogin =
|
||||
@@ -13,12 +12,11 @@ const socialLogin =
|
||||
profile,
|
||||
});
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
const existingUser = await findUser({ email: email.trim() });
|
||||
const ALLOW_SOCIAL_REGISTRATION = isEnabled(process.env.ALLOW_SOCIAL_REGISTRATION);
|
||||
|
||||
if (existingUser?.provider === provider) {
|
||||
await handleExistingUser(existingUser, avatarUrl, appConfig);
|
||||
await handleExistingUser(existingUser, avatarUrl);
|
||||
return cb(null, existingUser);
|
||||
} else if (existingUser) {
|
||||
logger.info(
|
||||
@@ -40,7 +38,6 @@ const socialLogin =
|
||||
username,
|
||||
name,
|
||||
emailVerified,
|
||||
appConfig,
|
||||
});
|
||||
return cb(null, newUser);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const axios = require('axios');
|
||||
const deriveBaseURL = require('./deriveBaseURL');
|
||||
jest.mock('@librechat/api', () => {
|
||||
const originalUtils = jest.requireActual('@librechat/api');
|
||||
jest.mock('~/utils', () => {
|
||||
const originalUtils = jest.requireActual('~/utils');
|
||||
return {
|
||||
...originalUtils,
|
||||
processModelData: jest.fn((...args) => {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
const tokenHelpers = require('./tokens');
|
||||
const deriveBaseURL = require('./deriveBaseURL');
|
||||
const extractBaseURL = require('./extractBaseURL');
|
||||
const findMessageContent = require('./findMessageContent');
|
||||
@@ -5,5 +6,6 @@ const findMessageContent = require('./findMessageContent');
|
||||
module.exports = {
|
||||
deriveBaseURL,
|
||||
extractBaseURL,
|
||||
...tokenHelpers,
|
||||
findMessageContent,
|
||||
};
|
||||
|
||||
@@ -1,23 +1,5 @@
|
||||
import z from 'zod';
|
||||
import { EModelEndpoint } from 'librechat-data-provider';
|
||||
|
||||
/** Configuration object mapping model keys to their respective prompt, completion rates, and context limit
|
||||
*
|
||||
* Note: the [key: string]: unknown is not in the original JSDoc typedef in /api/typedefs.js, but I've included it since
|
||||
* getModelMaxOutputTokens calls getModelTokenValue with a key of 'output', which was not in the original JSDoc typedef,
|
||||
* but would be referenced in a TokenConfig in the if(matchedPattern) portion of getModelTokenValue.
|
||||
* So in order to preserve functionality for that case and any others which might reference an additional key I'm unaware of,
|
||||
* I've included it here until the interface can be typed more tightly.
|
||||
*/
|
||||
export interface TokenConfig {
|
||||
prompt: number;
|
||||
completion: number;
|
||||
context: number;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
/** An endpoint's config object mapping model keys to their respective prompt, completion rates, and context limit */
|
||||
export type EndpointTokenConfig = Record<string, TokenConfig>;
|
||||
const z = require('zod');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
|
||||
const openAIModels = {
|
||||
'o4-mini': 200000,
|
||||
@@ -260,7 +242,7 @@ const aggregateModels = {
|
||||
'gpt-oss-120b': 131000,
|
||||
};
|
||||
|
||||
export const maxTokensMap = {
|
||||
const maxTokensMap = {
|
||||
[EModelEndpoint.azureOpenAI]: openAIModels,
|
||||
[EModelEndpoint.openAI]: aggregateModels,
|
||||
[EModelEndpoint.agents]: aggregateModels,
|
||||
@@ -270,7 +252,7 @@ export const maxTokensMap = {
|
||||
[EModelEndpoint.bedrock]: bedrockModels,
|
||||
};
|
||||
|
||||
export const modelMaxOutputs = {
|
||||
const modelMaxOutputs = {
|
||||
o1: 32268, // -500 from max: 32,768
|
||||
'o1-mini': 65136, // -500 from max: 65,536
|
||||
'o1-preview': 32268, // -500 from max: 32,768
|
||||
@@ -279,7 +261,7 @@ export const modelMaxOutputs = {
|
||||
'gpt-5-nano': 128000,
|
||||
'gpt-oss-20b': 131000,
|
||||
'gpt-oss-120b': 131000,
|
||||
system_default: 32000,
|
||||
system_default: 1024,
|
||||
};
|
||||
|
||||
/** Outputs from https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-names */
|
||||
@@ -295,7 +277,7 @@ const anthropicMaxOutputs = {
|
||||
'claude-3-7-sonnet': 128000,
|
||||
};
|
||||
|
||||
export const maxOutputTokensMap = {
|
||||
const maxOutputTokensMap = {
|
||||
[EModelEndpoint.anthropic]: anthropicMaxOutputs,
|
||||
[EModelEndpoint.azureOpenAI]: modelMaxOutputs,
|
||||
[EModelEndpoint.openAI]: modelMaxOutputs,
|
||||
@@ -305,13 +287,10 @@ export const maxOutputTokensMap = {
|
||||
/**
|
||||
* Finds the first matching pattern in the tokens map.
|
||||
* @param {string} modelName
|
||||
* @param {Record<string, number> | EndpointTokenConfig} tokensMap
|
||||
* @param {Record<string, number>} tokensMap
|
||||
* @returns {string|null}
|
||||
*/
|
||||
export function findMatchingPattern(
|
||||
modelName: string,
|
||||
tokensMap: Record<string, number> | EndpointTokenConfig,
|
||||
): string | null {
|
||||
function findMatchingPattern(modelName, tokensMap) {
|
||||
const keys = Object.keys(tokensMap);
|
||||
for (let i = keys.length - 1; i >= 0; i--) {
|
||||
const modelKey = keys[i];
|
||||
@@ -326,79 +305,57 @@ export function findMatchingPattern(
|
||||
/**
|
||||
* Retrieves a token value for a given model name from a tokens map.
|
||||
*
|
||||
* @param modelName - The name of the model to look up.
|
||||
* @param tokensMap - The map of model names to token values.
|
||||
* @param [key='context'] - The key to look up in the tokens map.
|
||||
* @returns The token value for the given model or undefined if no match is found.
|
||||
* @param {string} modelName - The name of the model to look up.
|
||||
* @param {EndpointTokenConfig | Record<string, number>} tokensMap - The map of model names to token values.
|
||||
* @param {string} [key='context'] - The key to look up in the tokens map.
|
||||
* @returns {number|undefined} The token value for the given model or undefined if no match is found.
|
||||
*/
|
||||
export function getModelTokenValue(
|
||||
modelName: string,
|
||||
tokensMap?: EndpointTokenConfig | Record<string, number>,
|
||||
key = 'context' as keyof TokenConfig,
|
||||
): number | undefined {
|
||||
function getModelTokenValue(modelName, tokensMap, key = 'context') {
|
||||
if (typeof modelName !== 'string' || !tokensMap) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const value = tokensMap[modelName];
|
||||
if (typeof value === 'number') {
|
||||
return value;
|
||||
if (tokensMap[modelName]?.context) {
|
||||
return tokensMap[modelName].context;
|
||||
}
|
||||
|
||||
if (value?.context) {
|
||||
return value.context;
|
||||
if (tokensMap[modelName]) {
|
||||
return tokensMap[modelName];
|
||||
}
|
||||
|
||||
const matchedPattern = findMatchingPattern(modelName, tokensMap);
|
||||
|
||||
if (matchedPattern) {
|
||||
const result = tokensMap[matchedPattern];
|
||||
if (typeof result === 'number') {
|
||||
return result;
|
||||
}
|
||||
|
||||
const tokenValue = result?.[key];
|
||||
if (typeof tokenValue === 'number') {
|
||||
return tokenValue;
|
||||
}
|
||||
return tokensMap.system_default as number | undefined;
|
||||
return result?.[key] ?? result ?? tokensMap.system_default;
|
||||
}
|
||||
|
||||
return tokensMap.system_default as number | undefined;
|
||||
return tokensMap.system_default;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the maximum tokens for a given model name.
|
||||
*
|
||||
* @param modelName - The name of the model to look up.
|
||||
* @param endpoint - The endpoint (default is 'openAI').
|
||||
* @param [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
|
||||
* @returns The maximum tokens for the given model or undefined if no match is found.
|
||||
* @param {string} modelName - The name of the model to look up.
|
||||
* @param {string} endpoint - The endpoint (default is 'openAI').
|
||||
* @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
|
||||
* @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found.
|
||||
*/
|
||||
export function getModelMaxTokens(
|
||||
modelName: string,
|
||||
endpoint = EModelEndpoint.openAI,
|
||||
endpointTokenConfig?: EndpointTokenConfig,
|
||||
): number | undefined {
|
||||
const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint as keyof typeof maxTokensMap];
|
||||
function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
|
||||
const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint];
|
||||
return getModelTokenValue(modelName, tokensMap);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the maximum output tokens for a given model name.
|
||||
*
|
||||
* @param modelName - The name of the model to look up.
|
||||
* @param endpoint - The endpoint (default is 'openAI').
|
||||
* @param [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
|
||||
* @returns The maximum output tokens for the given model or undefined if no match is found.
|
||||
* @param {string} modelName - The name of the model to look up.
|
||||
* @param {string} endpoint - The endpoint (default is 'openAI').
|
||||
* @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
|
||||
* @returns {number|undefined} The maximum output tokens for the given model or undefined if no match is found.
|
||||
*/
|
||||
export function getModelMaxOutputTokens(
|
||||
modelName: string,
|
||||
endpoint = EModelEndpoint.openAI,
|
||||
endpointTokenConfig?: EndpointTokenConfig,
|
||||
): number | undefined {
|
||||
const tokensMap =
|
||||
endpointTokenConfig ?? maxOutputTokensMap[endpoint as keyof typeof maxOutputTokensMap];
|
||||
function getModelMaxOutputTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
|
||||
const tokensMap = endpointTokenConfig ?? maxOutputTokensMap[endpoint];
|
||||
return getModelTokenValue(modelName, tokensMap, 'output');
|
||||
}
|
||||
|
||||
@@ -406,24 +363,21 @@ export function getModelMaxOutputTokens(
|
||||
* Retrieves the model name key for a given model name input. If the exact model name isn't found,
|
||||
* it searches for partial matches within the model name, checking keys in reverse order.
|
||||
*
|
||||
* @param modelName - The name of the model to look up.
|
||||
* @param endpoint - The endpoint (default is 'openAI').
|
||||
* @returns The model name key for the given model; returns input if no match is found and is string.
|
||||
* @param {string} modelName - The name of the model to look up.
|
||||
* @param {string} endpoint - The endpoint (default is 'openAI').
|
||||
* @returns {string|undefined} The model name key for the given model; returns input if no match is found and is string.
|
||||
*
|
||||
* @example
|
||||
* matchModelName('gpt-4-32k-0613'); // Returns 'gpt-4-32k-0613'
|
||||
* matchModelName('gpt-4-32k-unknown'); // Returns 'gpt-4-32k'
|
||||
* matchModelName('unknown-model'); // Returns undefined
|
||||
*/
|
||||
export function matchModelName(
|
||||
modelName: string,
|
||||
endpoint = EModelEndpoint.openAI,
|
||||
): string | undefined {
|
||||
function matchModelName(modelName, endpoint = EModelEndpoint.openAI) {
|
||||
if (typeof modelName !== 'string') {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const tokensMap: Record<string, number> = maxTokensMap[endpoint as keyof typeof maxTokensMap];
|
||||
const tokensMap = maxTokensMap[endpoint];
|
||||
if (!tokensMap) {
|
||||
return modelName;
|
||||
}
|
||||
@@ -436,7 +390,7 @@ export function matchModelName(
|
||||
return matchedPattern || modelName;
|
||||
}
|
||||
|
||||
export const modelSchema = z.object({
|
||||
const modelSchema = z.object({
|
||||
id: z.string(),
|
||||
pricing: z.object({
|
||||
prompt: z.string(),
|
||||
@@ -445,7 +399,7 @@ export const modelSchema = z.object({
|
||||
context_length: z.number(),
|
||||
});
|
||||
|
||||
export const inputSchema = z.object({
|
||||
const inputSchema = z.object({
|
||||
data: z.array(modelSchema),
|
||||
});
|
||||
|
||||
@@ -454,7 +408,7 @@ export const inputSchema = z.object({
|
||||
* @param {{ data: Array<z.infer<typeof modelSchema>> }} input The input object containing base URL and data fetched from the API.
|
||||
* @returns {EndpointTokenConfig} The processed model data.
|
||||
*/
|
||||
export function processModelData(input: z.infer<typeof inputSchema>): EndpointTokenConfig {
|
||||
function processModelData(input) {
|
||||
const validationResult = inputSchema.safeParse(input);
|
||||
if (!validationResult.success) {
|
||||
throw new Error('Invalid input data');
|
||||
@@ -462,7 +416,7 @@ export function processModelData(input: z.infer<typeof inputSchema>): EndpointTo
|
||||
const { data } = validationResult.data;
|
||||
|
||||
/** @type {EndpointTokenConfig} */
|
||||
const tokenConfig: EndpointTokenConfig = {};
|
||||
const tokenConfig = {};
|
||||
|
||||
for (const model of data) {
|
||||
const modelKey = model.id;
|
||||
@@ -485,7 +439,7 @@ export function processModelData(input: z.infer<typeof inputSchema>): EndpointTo
|
||||
return tokenConfig;
|
||||
}
|
||||
|
||||
export const tiktokenModels = new Set([
|
||||
const tiktokenModels = new Set([
|
||||
'text-davinci-003',
|
||||
'text-davinci-002',
|
||||
'text-davinci-001',
|
||||
@@ -523,3 +477,17 @@ export const tiktokenModels = new Set([
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-0301',
|
||||
]);
|
||||
|
||||
module.exports = {
|
||||
inputSchema,
|
||||
modelSchema,
|
||||
maxTokensMap,
|
||||
tiktokenModels,
|
||||
maxOutputTokensMap,
|
||||
matchModelName,
|
||||
processModelData,
|
||||
getModelMaxTokens,
|
||||
getModelTokenValue,
|
||||
findMatchingPattern,
|
||||
getModelMaxOutputTokens,
|
||||
};
|
||||
@@ -1,12 +1,12 @@
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
maxTokensMap,
|
||||
matchModelName,
|
||||
processModelData,
|
||||
getModelMaxTokens,
|
||||
maxOutputTokensMap,
|
||||
findMatchingPattern,
|
||||
} = require('@librechat/api');
|
||||
getModelMaxTokens,
|
||||
processModelData,
|
||||
matchModelName,
|
||||
maxTokensMap,
|
||||
} = require('./tokens');
|
||||
|
||||
describe('getModelMaxTokens', () => {
|
||||
test('should return correct tokens for exact match', () => {
|
||||
@@ -394,7 +394,7 @@ describe('getModelMaxTokens', () => {
|
||||
});
|
||||
|
||||
test('should return correct max output tokens for GPT-5 models', () => {
|
||||
const { getModelMaxOutputTokens } = require('@librechat/api');
|
||||
const { getModelMaxOutputTokens } = require('./tokens');
|
||||
['gpt-5', 'gpt-5-mini', 'gpt-5-nano'].forEach((model) => {
|
||||
expect(getModelMaxOutputTokens(model)).toBe(maxOutputTokensMap[EModelEndpoint.openAI][model]);
|
||||
expect(getModelMaxOutputTokens(model, EModelEndpoint.openAI)).toBe(
|
||||
@@ -407,7 +407,7 @@ describe('getModelMaxTokens', () => {
|
||||
});
|
||||
|
||||
test('should return correct max output tokens for GPT-OSS models', () => {
|
||||
const { getModelMaxOutputTokens } = require('@librechat/api');
|
||||
const { getModelMaxOutputTokens } = require('./tokens');
|
||||
['gpt-oss-20b', 'gpt-oss-120b'].forEach((model) => {
|
||||
expect(getModelMaxOutputTokens(model)).toBe(maxOutputTokensMap[EModelEndpoint.openAI][model]);
|
||||
expect(getModelMaxOutputTokens(model, EModelEndpoint.openAI)).toBe(
|
||||
|
||||
@@ -37,7 +37,6 @@
|
||||
"@headlessui/react": "^2.1.2",
|
||||
"@librechat/client": "*",
|
||||
"@marsidev/react-turnstile": "^1.1.0",
|
||||
"@mcp-ui/client": "^5.7.0",
|
||||
"@radix-ui/react-accordion": "^1.1.2",
|
||||
"@radix-ui/react-alert-dialog": "^1.0.2",
|
||||
"@radix-ui/react-checkbox": "^1.0.3",
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
import React, { createContext, useContext, useState, useMemo } from 'react';
|
||||
import React, { createContext, useContext, useState } from 'react';
|
||||
import { Constants, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { MCP, Action, TPlugin, AgentToolType } from 'librechat-data-provider';
|
||||
import type { AgentPanelContextType, MCPServerInfo } from '~/common';
|
||||
import { useAvailableToolsQuery, useGetActionsQuery, useGetStartupConfig } from '~/data-provider';
|
||||
import { useLocalize, useGetAgentsConfig, useMCPConnectionStatus } from '~/hooks';
|
||||
import type { AgentPanelContextType } from '~/common';
|
||||
import { useAvailableToolsQuery, useGetActionsQuery } from '~/data-provider';
|
||||
import { useLocalize, useGetAgentsConfig } from '~/hooks';
|
||||
import { Panel } from '~/common';
|
||||
|
||||
type GroupedToolType = AgentToolType & { tools?: AgentToolType[] };
|
||||
type GroupedToolsRecord = Record<string, GroupedToolType>;
|
||||
|
||||
const AgentPanelContext = createContext<AgentPanelContextType | undefined>(undefined);
|
||||
|
||||
export function useAgentPanelContext() {
|
||||
@@ -36,117 +33,67 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
|
||||
enabled: !!agent_id,
|
||||
});
|
||||
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const mcpServerNames = useMemo(
|
||||
() => Object.keys(startupConfig?.mcpServers ?? {}),
|
||||
[startupConfig],
|
||||
);
|
||||
|
||||
const { connectionStatus } = useMCPConnectionStatus({
|
||||
enabled: !!agent_id && mcpServerNames.length > 0,
|
||||
});
|
||||
|
||||
const processedData = useMemo(() => {
|
||||
if (!pluginTools) {
|
||||
return {
|
||||
tools: [],
|
||||
groupedTools: {},
|
||||
mcpServersMap: new Map<string, MCPServerInfo>(),
|
||||
};
|
||||
}
|
||||
|
||||
const tools: AgentToolType[] = [];
|
||||
const groupedTools: GroupedToolsRecord = {};
|
||||
|
||||
const configuredServers = new Set(mcpServerNames);
|
||||
const mcpServersMap = new Map<string, MCPServerInfo>();
|
||||
|
||||
for (const pluginTool of pluginTools) {
|
||||
const tool: AgentToolType = {
|
||||
tool_id: pluginTool.pluginKey,
|
||||
metadata: pluginTool as TPlugin,
|
||||
};
|
||||
|
||||
tools.push(tool);
|
||||
const tools =
|
||||
pluginTools?.map((tool) => ({
|
||||
tool_id: tool.pluginKey,
|
||||
metadata: tool as TPlugin,
|
||||
agent_id: agent_id || '',
|
||||
})) || [];
|
||||
|
||||
const groupedTools = tools?.reduce(
|
||||
(acc, tool) => {
|
||||
if (tool.tool_id.includes(Constants.mcp_delimiter)) {
|
||||
const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter);
|
||||
|
||||
if (!mcpServersMap.has(serverName)) {
|
||||
const metadata = {
|
||||
name: serverName,
|
||||
pluginKey: serverName,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
|
||||
icon: pluginTool.icon || '',
|
||||
} as TPlugin;
|
||||
|
||||
mcpServersMap.set(serverName, {
|
||||
serverName,
|
||||
const groupKey = `${serverName.toLowerCase()}`;
|
||||
if (!acc[groupKey]) {
|
||||
acc[groupKey] = {
|
||||
tool_id: groupKey,
|
||||
metadata: {
|
||||
name: `${serverName}`,
|
||||
pluginKey: groupKey,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
|
||||
icon: tool.metadata.icon || '',
|
||||
} as TPlugin,
|
||||
agent_id: agent_id || '',
|
||||
tools: [],
|
||||
isConfigured: configuredServers.has(serverName),
|
||||
isConnected: connectionStatus?.[serverName]?.connectionState === 'connected',
|
||||
metadata,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
mcpServersMap.get(serverName)!.tools.push(tool);
|
||||
} else {
|
||||
// Non-MCP tool
|
||||
groupedTools[tool.tool_id] = {
|
||||
acc[groupKey].tools?.push({
|
||||
tool_id: tool.tool_id,
|
||||
metadata: tool.metadata,
|
||||
agent_id: agent_id || '',
|
||||
});
|
||||
} else {
|
||||
acc[tool.tool_id] = {
|
||||
tool_id: tool.tool_id,
|
||||
metadata: tool.metadata,
|
||||
agent_id: agent_id || '',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
for (const mcpServerName of mcpServerNames) {
|
||||
if (mcpServersMap.has(mcpServerName)) {
|
||||
continue;
|
||||
}
|
||||
const metadata = {
|
||||
icon: '',
|
||||
name: mcpServerName,
|
||||
pluginKey: mcpServerName,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${mcpServerName}`,
|
||||
} as TPlugin;
|
||||
|
||||
mcpServersMap.set(mcpServerName, {
|
||||
tools: [],
|
||||
metadata,
|
||||
isConfigured: true,
|
||||
serverName: mcpServerName,
|
||||
isConnected: connectionStatus?.[mcpServerName]?.connectionState === 'connected',
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
tools,
|
||||
groupedTools,
|
||||
mcpServersMap,
|
||||
};
|
||||
}, [pluginTools, localize, mcpServerNames, connectionStatus]);
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, AgentToolType & { tools?: AgentToolType[] }>,
|
||||
);
|
||||
|
||||
const { agentsConfig, endpointsConfig } = useGetAgentsConfig();
|
||||
|
||||
const value: AgentPanelContextType = {
|
||||
mcp,
|
||||
mcps,
|
||||
/** Query data for actions and tools */
|
||||
tools,
|
||||
action,
|
||||
setMcp,
|
||||
actions,
|
||||
setMcps,
|
||||
agent_id,
|
||||
setAction,
|
||||
pluginTools,
|
||||
activePanel,
|
||||
groupedTools,
|
||||
agentsConfig,
|
||||
startupConfig,
|
||||
setActivePanel,
|
||||
endpointsConfig,
|
||||
setCurrentAgentId,
|
||||
tools: processedData.tools,
|
||||
groupedTools: processedData.groupedTools,
|
||||
mcpServersMap: processedData.mcpServersMap,
|
||||
};
|
||||
|
||||
return <AgentPanelContext.Provider value={value}>{children}</AgentPanelContext.Provider>;
|
||||
|
||||
@@ -216,14 +216,6 @@ export type AgentPanelProps = {
|
||||
agentsConfig?: t.TAgentsEndpoint | null;
|
||||
};
|
||||
|
||||
export interface MCPServerInfo {
|
||||
serverName: string;
|
||||
tools: t.AgentToolType[];
|
||||
isConfigured: boolean;
|
||||
isConnected: boolean;
|
||||
metadata: t.TPlugin;
|
||||
}
|
||||
|
||||
export type AgentPanelContextType = {
|
||||
action?: t.Action;
|
||||
actions?: t.Action[];
|
||||
@@ -233,17 +225,13 @@ export type AgentPanelContextType = {
|
||||
setMcp: React.Dispatch<React.SetStateAction<t.MCP | undefined>>;
|
||||
setMcps: React.Dispatch<React.SetStateAction<t.MCP[] | undefined>>;
|
||||
groupedTools: Record<string, t.AgentToolType & { tools?: t.AgentToolType[] }>;
|
||||
activePanel?: string;
|
||||
tools: t.AgentToolType[];
|
||||
pluginTools?: t.TPlugin[];
|
||||
activePanel?: string;
|
||||
setActivePanel: React.Dispatch<React.SetStateAction<Panel>>;
|
||||
setCurrentAgentId: React.Dispatch<React.SetStateAction<string | undefined>>;
|
||||
agent_id?: string;
|
||||
startupConfig?: t.TStartupConfig | null;
|
||||
agentsConfig?: t.TAgentsEndpoint | null;
|
||||
endpointsConfig?: t.TEndpointsConfig | null;
|
||||
/** Pre-computed MCP server information indexed by server key */
|
||||
mcpServersMap: Map<string, MCPServerInfo>;
|
||||
};
|
||||
|
||||
export type AgentModelPanelProps = {
|
||||
@@ -642,10 +630,3 @@ declare global {
|
||||
google_tag_manager?: unknown;
|
||||
}
|
||||
}
|
||||
|
||||
export type UIResource = {
|
||||
uri: string;
|
||||
mimeType: string;
|
||||
text: string;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import React, { useState, useEffect, useContext } from 'react';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { Turnstile } from '@marsidev/react-turnstile';
|
||||
import { ThemeContext, Spinner, Button, isDark } from '@librechat/client';
|
||||
import { ThemeContext, Spinner, Button } from '@librechat/client';
|
||||
import type { TLoginUser, TStartupConfig } from 'librechat-data-provider';
|
||||
import type { TAuthContext } from '~/common';
|
||||
import { useResendVerificationEmail, useGetStartupConfig } from '~/data-provider';
|
||||
@@ -28,7 +28,7 @@ const LoginForm: React.FC<TLoginFormProps> = ({ onSubmit, startupConfig, error,
|
||||
|
||||
const { data: config } = useGetStartupConfig();
|
||||
const useUsernameLogin = config?.ldap?.username;
|
||||
const validTheme = isDark(theme) ? 'dark' : 'light';
|
||||
const validTheme = theme === 'dark' ? 'dark' : 'light';
|
||||
const requireCaptcha = Boolean(startupConfig.turnstile?.siteKey);
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useForm } from 'react-hook-form';
|
||||
import React, { useContext, useState } from 'react';
|
||||
import { Turnstile } from '@marsidev/react-turnstile';
|
||||
import { ThemeContext, Spinner, Button, isDark } from '@librechat/client';
|
||||
import { ThemeContext, Spinner, Button } from '@librechat/client';
|
||||
import { useNavigate, useOutletContext, useLocation } from 'react-router-dom';
|
||||
import { useRegisterUserMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TRegisterUser, TError } from 'librechat-data-provider';
|
||||
@@ -31,7 +31,7 @@ const Registration: React.FC = () => {
|
||||
const location = useLocation();
|
||||
const queryParams = new URLSearchParams(location.search);
|
||||
const token = queryParams.get('token');
|
||||
const validTheme = isDark(theme) ? 'dark' : 'light';
|
||||
const validTheme = theme === 'dark' ? 'dark' : 'light';
|
||||
|
||||
// only require captcha if we have a siteKey
|
||||
const requireCaptcha = Boolean(startupConfig?.turnstile?.siteKey);
|
||||
|
||||
@@ -253,7 +253,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
||||
handleSaveBadges={handleSaveBadges}
|
||||
setBadges={setBadges}
|
||||
/>
|
||||
<FileFormChat conversation={conversation} />
|
||||
<FileFormChat disableInputs={disableInputs} />
|
||||
{endpoint && (
|
||||
<div className={cn('flex', isRTL ? 'flex-row-reverse' : 'flex-row')}>
|
||||
<TextareaAutosize
|
||||
@@ -301,7 +301,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
||||
)}
|
||||
>
|
||||
<div className={`${isRTL ? 'mr-2' : 'ml-2'}`}>
|
||||
<AttachFileChat conversation={conversation} disableInputs={disableInputs} />
|
||||
<AttachFileChat disableInputs={disableInputs} />
|
||||
</div>
|
||||
<BadgeRow
|
||||
showEphemeralBadges={!isAgentsEndpoint(endpoint) && !isAssistantsEndpoint(endpoint)}
|
||||
|
||||
@@ -7,18 +7,14 @@ import {
|
||||
isAssistantsEndpoint,
|
||||
fileConfig as defaultFileConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import type { EndpointFileConfig, TConversation } from 'librechat-data-provider';
|
||||
import type { EndpointFileConfig } from 'librechat-data-provider';
|
||||
import { useGetFileConfig } from '~/data-provider';
|
||||
import AttachFileMenu from './AttachFileMenu';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import AttachFile from './AttachFile';
|
||||
|
||||
function AttachFileChat({
|
||||
disableInputs,
|
||||
conversation,
|
||||
}: {
|
||||
disableInputs: boolean;
|
||||
conversation: TConversation | null;
|
||||
}) {
|
||||
function AttachFileChat({ disableInputs }: { disableInputs: boolean }) {
|
||||
const { conversation } = useChatContext();
|
||||
const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
|
||||
const { endpoint, endpointType } = conversation ?? { endpoint: null };
|
||||
const isAgents = useMemo(() => isAgentsEndpoint(endpoint), [endpoint]);
|
||||
|
||||
@@ -91,10 +91,6 @@ const AttachFileMenu = ({ disabled, conversationId, endpointFileConfig }: Attach
|
||||
label: localize('com_ui_upload_file_search'),
|
||||
onClick: () => {
|
||||
setToolResource(EToolResources.file_search);
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
[EToolResources.file_search]: true,
|
||||
}));
|
||||
onAction();
|
||||
},
|
||||
icon: <FileSearch className="icon-md" />,
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import { memo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import type { TConversation } from 'librechat-data-provider';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { useFileHandling } from '~/hooks';
|
||||
import FileRow from './FileRow';
|
||||
import store from '~/store';
|
||||
|
||||
function FileFormChat({ conversation }: { conversation: TConversation | null }) {
|
||||
const { files, setFiles, setFilesLoading } = useChatContext();
|
||||
function FileFormChat({ disableInputs }: { disableInputs: boolean }) {
|
||||
const chatDirection = useRecoilValue(store.chatDirection).toLowerCase();
|
||||
const { files, setFiles, conversation, setFilesLoading } = useChatContext();
|
||||
const { endpoint: _endpoint } = conversation ?? { endpoint: null };
|
||||
const { abortUpload } = useFileHandling();
|
||||
|
||||
|
||||
@@ -59,12 +59,10 @@ export default function FileRow({
|
||||
|
||||
useEffect(() => {
|
||||
if (files.length === 0) {
|
||||
setFilesLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (files.some((file) => file.progress < 1)) {
|
||||
setFilesLoading(true);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import React, { memo, useCallback } from 'react';
|
||||
import { MultiSelect, MCPIcon } from '@librechat/client';
|
||||
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
|
||||
import { useMCPServerManager } from '~/hooks/MCP/useMCPServerManager';
|
||||
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
import { useMCPServerManager } from '~/hooks';
|
||||
|
||||
type MCPSelectProps = { conversationId?: string | null };
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import * as Ariakit from '@ariakit/react';
|
||||
import { ChevronRight } from 'lucide-react';
|
||||
import { PinIcon, MCPIcon } from '@librechat/client';
|
||||
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
|
||||
import { useMCPServerManager } from '~/hooks/MCP/useMCPServerManager';
|
||||
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
|
||||
import { useMCPServerManager } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface MCPSubMenuProps {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import debounce from 'lodash/debounce';
|
||||
import React, { createContext, useContext, useState, useMemo } from 'react';
|
||||
import { EModelEndpoint, isAgentsEndpoint, isAssistantsEndpoint } from 'librechat-data-provider';
|
||||
import { isAgentsEndpoint, isAssistantsEndpoint } from 'librechat-data-provider';
|
||||
import type * as t from 'librechat-data-provider';
|
||||
import type { Endpoint, SelectedValues } from '~/common';
|
||||
import {
|
||||
@@ -59,25 +59,7 @@ export function ModelSelectorProvider({ children, startupConfig }: ModelSelector
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const { endpoint, model, spec, agent_id, assistant_id, newConversation } =
|
||||
useModelSelectorChatContext();
|
||||
const modelSpecs = useMemo(() => {
|
||||
const specs = startupConfig?.modelSpecs?.list ?? [];
|
||||
if (!agentsMap) {
|
||||
return specs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter modelSpecs to only include agents the user has access to.
|
||||
* Use agentsMap which already contains permission-filtered agents (consistent with other components).
|
||||
*/
|
||||
return specs.filter((spec) => {
|
||||
if (spec.preset?.endpoint === EModelEndpoint.agents && spec.preset?.agent_id) {
|
||||
return spec.preset.agent_id in agentsMap;
|
||||
}
|
||||
/** Keep non-agent modelSpecs */
|
||||
return true;
|
||||
});
|
||||
}, [startupConfig, agentsMap]);
|
||||
|
||||
const modelSpecs = useMemo(() => startupConfig?.modelSpecs?.list ?? [], [startupConfig]);
|
||||
const permissionLevel = useAgentDefaultPermissionLevel();
|
||||
const { data: agents = null } = useListAgentsQuery(
|
||||
{ requiredPermission: permissionLevel },
|
||||
|
||||
@@ -88,10 +88,6 @@ export default function ToolCall({
|
||||
const url = new URL(authURL);
|
||||
return url.hostname;
|
||||
} catch (e) {
|
||||
logger.error(
|
||||
'client/src/components/Chat/Messages/Content/ToolCall.tsx - Failed to parse auth URL',
|
||||
e,
|
||||
);
|
||||
return '';
|
||||
}
|
||||
}, [auth]);
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import React from 'react';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { UIResourceRenderer } from '@mcp-ui/client';
|
||||
import UIResourceCarousel from './UIResourceCarousel';
|
||||
import type { UIResource } from '~/common';
|
||||
|
||||
function OptimizedCodeBlock({ text, maxHeight = 320 }: { text: string; maxHeight?: number }) {
|
||||
return (
|
||||
@@ -54,26 +51,6 @@ export default function ToolCallInfo({
|
||||
: localize('com_assistants_attempt_info');
|
||||
}
|
||||
|
||||
// Extract ui_resources from the output to display them in the UI
|
||||
let uiResources: UIResource[] = [];
|
||||
if (output?.includes('ui_resources')) {
|
||||
try {
|
||||
const parsedOutput = JSON.parse(output);
|
||||
const uiResourcesItem = parsedOutput.find(
|
||||
(contentItem) => contentItem.metadata?.type === 'ui_resources',
|
||||
);
|
||||
if (uiResourcesItem?.metadata?.data) {
|
||||
uiResources = uiResourcesItem.metadata.data;
|
||||
output = JSON.stringify(
|
||||
parsedOutput.filter((contentItem) => contentItem.metadata?.type !== 'ui_resources'),
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
// If JSON parsing fails, keep original output
|
||||
console.error('Failed to parse output:', error);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="w-full p-2">
|
||||
<div style={{ opacity: 1 }}>
|
||||
@@ -89,26 +66,6 @@ export default function ToolCallInfo({
|
||||
<div>
|
||||
<OptimizedCodeBlock text={formatText(output)} maxHeight={250} />
|
||||
</div>
|
||||
{uiResources.length > 0 && (
|
||||
<div className="my-2 text-sm font-medium text-text-primary">
|
||||
{localize('com_ui_ui_resources')}
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
{uiResources.length > 1 && <UIResourceCarousel uiResources={uiResources} />}
|
||||
|
||||
{uiResources.length === 1 && (
|
||||
<UIResourceRenderer
|
||||
resource={uiResources[0]}
|
||||
onUIAction={async (result) => {
|
||||
console.log('Action:', result);
|
||||
}}
|
||||
htmlProps={{
|
||||
autoResizeIframe: { width: true, height: true },
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
import { UIResourceRenderer } from '@mcp-ui/client';
|
||||
import type { UIResource } from '~/common';
|
||||
import React, { useState } from 'react';
|
||||
|
||||
interface UIResourceCarouselProps {
|
||||
uiResources: UIResource[];
|
||||
}
|
||||
|
||||
const UIResourceCarousel: React.FC<UIResourceCarouselProps> = React.memo(({ uiResources }) => {
|
||||
const [showLeftArrow, setShowLeftArrow] = useState(false);
|
||||
const [showRightArrow, setShowRightArrow] = useState(true);
|
||||
const [isContainerHovered, setIsContainerHovered] = useState(false);
|
||||
const scrollContainerRef = React.useRef<HTMLDivElement>(null);
|
||||
|
||||
const handleScroll = React.useCallback(() => {
|
||||
if (!scrollContainerRef.current) return;
|
||||
|
||||
const { scrollLeft, scrollWidth, clientWidth } = scrollContainerRef.current;
|
||||
setShowLeftArrow(scrollLeft > 0);
|
||||
setShowRightArrow(scrollLeft < scrollWidth - clientWidth - 10);
|
||||
}, []);
|
||||
|
||||
const scroll = React.useCallback((direction: 'left' | 'right') => {
|
||||
if (!scrollContainerRef.current) return;
|
||||
|
||||
const viewportWidth = scrollContainerRef.current.clientWidth;
|
||||
const scrollAmount = Math.floor(viewportWidth * 0.9);
|
||||
const currentScroll = scrollContainerRef.current.scrollLeft;
|
||||
const newScroll =
|
||||
direction === 'left' ? currentScroll - scrollAmount : currentScroll + scrollAmount;
|
||||
|
||||
scrollContainerRef.current.scrollTo({
|
||||
left: newScroll,
|
||||
behavior: 'smooth',
|
||||
});
|
||||
}, []);
|
||||
|
||||
React.useEffect(() => {
|
||||
const container = scrollContainerRef.current;
|
||||
if (container) {
|
||||
container.addEventListener('scroll', handleScroll);
|
||||
handleScroll();
|
||||
return () => container.removeEventListener('scroll', handleScroll);
|
||||
}
|
||||
}, [handleScroll]);
|
||||
|
||||
if (uiResources.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className="relative mb-4 pt-3"
|
||||
onMouseEnter={() => setIsContainerHovered(true)}
|
||||
onMouseLeave={() => setIsContainerHovered(false)}
|
||||
>
|
||||
<div
|
||||
className={`pointer-events-none absolute left-0 top-0 z-10 h-full w-24 bg-gradient-to-r from-surface-primary to-transparent transition-opacity duration-500 ease-in-out ${
|
||||
showLeftArrow ? 'opacity-100' : 'opacity-0'
|
||||
}`}
|
||||
/>
|
||||
|
||||
<div
|
||||
className={`pointer-events-none absolute right-0 top-0 z-10 h-full w-24 bg-gradient-to-l from-surface-primary to-transparent transition-opacity duration-500 ease-in-out ${
|
||||
showRightArrow ? 'opacity-100' : 'opacity-0'
|
||||
}`}
|
||||
/>
|
||||
|
||||
{showLeftArrow && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => scroll('left')}
|
||||
className={`absolute left-2 top-1/2 z-20 -translate-y-1/2 rounded-xl bg-white p-2 text-gray-800 shadow-lg transition-all duration-200 hover:scale-110 hover:bg-gray-100 hover:shadow-xl active:scale-95 dark:bg-gray-200 dark:text-gray-800 dark:hover:bg-gray-300 ${
|
||||
isContainerHovered ? 'opacity-100' : 'pointer-events-none opacity-0'
|
||||
}`}
|
||||
aria-label="Scroll left"
|
||||
>
|
||||
<svg className="h-5 w-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M15 19l-7-7 7-7"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
)}
|
||||
|
||||
<div
|
||||
ref={scrollContainerRef}
|
||||
className="hide-scrollbar flex gap-4 overflow-x-auto scroll-smooth"
|
||||
>
|
||||
{uiResources.map((uiResource, index) => {
|
||||
const height = 360;
|
||||
const width = 230;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className="flex-shrink-0 transform-gpu transition-all duration-300 ease-out animate-in fade-in-0 slide-in-from-bottom-5"
|
||||
style={{
|
||||
width: `${width}px`,
|
||||
minHeight: `${height}px`,
|
||||
animationDelay: `${index * 100}ms`,
|
||||
}}
|
||||
>
|
||||
<div className="flex h-full flex-col">
|
||||
<UIResourceRenderer
|
||||
resource={{
|
||||
uri: uiResource.uri,
|
||||
mimeType: uiResource.mimeType,
|
||||
text: uiResource.text,
|
||||
}}
|
||||
onUIAction={async (result) => {
|
||||
console.log('Action:', result);
|
||||
}}
|
||||
htmlProps={{
|
||||
autoResizeIframe: { width: true, height: true },
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{showRightArrow && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => scroll('right')}
|
||||
className={`absolute right-2 top-1/2 z-20 -translate-y-1/2 rounded-xl bg-white p-2 text-gray-800 shadow-lg transition-all duration-200 hover:scale-110 hover:bg-gray-100 hover:shadow-xl active:scale-95 dark:bg-gray-200 dark:text-gray-800 dark:hover:bg-gray-300 ${
|
||||
isContainerHovered ? 'opacity-100' : 'pointer-events-none opacity-0'
|
||||
}`}
|
||||
aria-label="Scroll right"
|
||||
>
|
||||
<svg className="h-5 w-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 5l7 7-7 7" />
|
||||
</svg>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
export default UIResourceCarousel;
|
||||
@@ -1,273 +0,0 @@
|
||||
import React from 'react';
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import ToolCallInfo from '../ToolCallInfo';
|
||||
import { UIResourceRenderer } from '@mcp-ui/client';
|
||||
import UIResourceCarousel from '../UIResourceCarousel';
|
||||
|
||||
// Mock the dependencies
|
||||
jest.mock('~/hooks', () => ({
|
||||
useLocalize: () => (key: string, values?: any) => {
|
||||
const translations: Record<string, string> = {
|
||||
com_assistants_domain_info: `Used ${values?.[0]}`,
|
||||
com_assistants_function_use: `Used ${values?.[0]}`,
|
||||
com_assistants_action_attempt: `Attempted to use ${values?.[0]}`,
|
||||
com_assistants_attempt_info: 'Attempted to use function',
|
||||
com_ui_result: 'Result',
|
||||
com_ui_ui_resources: 'UI Resources',
|
||||
};
|
||||
return translations[key] || key;
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@mcp-ui/client', () => ({
|
||||
UIResourceRenderer: jest.fn(() => null),
|
||||
}));
|
||||
|
||||
jest.mock('../UIResourceCarousel', () => ({
|
||||
__esModule: true,
|
||||
default: jest.fn(() => null),
|
||||
}));
|
||||
|
||||
// Add TextEncoder/TextDecoder polyfill for Jest environment
|
||||
import { TextEncoder, TextDecoder } from 'util';
|
||||
|
||||
if (typeof global.TextEncoder === 'undefined') {
|
||||
global.TextEncoder = TextEncoder as any;
|
||||
global.TextDecoder = TextDecoder as any;
|
||||
}
|
||||
|
||||
describe('ToolCallInfo', () => {
|
||||
const mockProps = {
|
||||
input: '{"test": "input"}',
|
||||
function_name: 'testFunction',
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('ui_resources extraction', () => {
|
||||
it('should extract single ui_resource from output', () => {
|
||||
const uiResource = {
|
||||
type: 'text',
|
||||
data: 'Test resource',
|
||||
};
|
||||
|
||||
const output = JSON.stringify([
|
||||
{ type: 'text', text: 'Regular output' },
|
||||
{
|
||||
metadata: {
|
||||
type: 'ui_resources',
|
||||
data: [uiResource],
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
render(<ToolCallInfo {...mockProps} output={output} />);
|
||||
|
||||
// Should render UIResourceRenderer for single resource
|
||||
expect(UIResourceRenderer).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
resource: uiResource,
|
||||
onUIAction: expect.any(Function),
|
||||
htmlProps: {
|
||||
autoResizeIframe: { width: true, height: true },
|
||||
},
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Should not render carousel for single resource
|
||||
expect(UIResourceCarousel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should extract multiple ui_resources from output', () => {
|
||||
const uiResources = [
|
||||
{ type: 'text', data: 'Resource 1' },
|
||||
{ type: 'text', data: 'Resource 2' },
|
||||
{ type: 'text', data: 'Resource 3' },
|
||||
];
|
||||
|
||||
const output = JSON.stringify([
|
||||
{ type: 'text', text: 'Regular output' },
|
||||
{
|
||||
metadata: {
|
||||
type: 'ui_resources',
|
||||
data: uiResources,
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
render(<ToolCallInfo {...mockProps} output={output} />);
|
||||
|
||||
// Should render carousel for multiple resources
|
||||
expect(UIResourceCarousel).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
uiResources,
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Should not render individual UIResourceRenderer
|
||||
expect(UIResourceRenderer).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should filter out ui_resources from displayed output', () => {
|
||||
const regularContent = [
|
||||
{ type: 'text', text: 'Regular output 1' },
|
||||
{ type: 'text', text: 'Regular output 2' },
|
||||
];
|
||||
|
||||
const output = JSON.stringify([
|
||||
...regularContent,
|
||||
{
|
||||
metadata: {
|
||||
type: 'ui_resources',
|
||||
data: [{ type: 'text', data: 'UI Resource' }],
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
const { container } = render(<ToolCallInfo {...mockProps} output={output} />);
|
||||
|
||||
// Check that the displayed output doesn't contain ui_resources
|
||||
const codeBlocks = container.querySelectorAll('code');
|
||||
const outputCode = codeBlocks[1]?.textContent; // Second code block is the output
|
||||
|
||||
expect(outputCode).toContain('Regular output 1');
|
||||
expect(outputCode).toContain('Regular output 2');
|
||||
expect(outputCode).not.toContain('ui_resources');
|
||||
});
|
||||
|
||||
it('should handle output without ui_resources', () => {
|
||||
const output = JSON.stringify([{ type: 'text', text: 'Regular output' }]);
|
||||
|
||||
render(<ToolCallInfo {...mockProps} output={output} />);
|
||||
|
||||
expect(UIResourceRenderer).not.toHaveBeenCalled();
|
||||
expect(UIResourceCarousel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle malformed ui_resources gracefully', () => {
|
||||
const output = JSON.stringify([
|
||||
{
|
||||
metadata: 'ui_resources', // metadata should be an object, not a string
|
||||
text: 'some text content',
|
||||
},
|
||||
]);
|
||||
|
||||
// Component should not throw error and should render without UI resources
|
||||
const { container } = render(<ToolCallInfo {...mockProps} output={output} />);
|
||||
|
||||
// Should render the component without crashing
|
||||
expect(container).toBeTruthy();
|
||||
|
||||
// UIResourceCarousel should not be called since the metadata structure is invalid
|
||||
expect(UIResourceCarousel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle ui_resources as plain text without breaking', () => {
|
||||
const outputWithTextOnly =
|
||||
'This output contains ui_resources as plain text but not as a proper structure';
|
||||
|
||||
render(<ToolCallInfo {...mockProps} output={outputWithTextOnly} />);
|
||||
|
||||
// Should render normally without errors
|
||||
expect(screen.getByText(`Used ${mockProps.function_name}`)).toBeInTheDocument();
|
||||
expect(screen.getByText('Result')).toBeInTheDocument();
|
||||
|
||||
// The output text should be displayed in a code block
|
||||
const codeBlocks = screen.getAllByText((content, element) => {
|
||||
return element?.tagName === 'CODE' && content.includes(outputWithTextOnly);
|
||||
});
|
||||
expect(codeBlocks.length).toBeGreaterThan(0);
|
||||
|
||||
// Should not render UI resources components
|
||||
expect(UIResourceRenderer).not.toHaveBeenCalled();
|
||||
expect(UIResourceCarousel).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('rendering logic', () => {
|
||||
it('should render UI Resources heading when ui_resources exist', () => {
|
||||
const output = JSON.stringify([
|
||||
{
|
||||
metadata: {
|
||||
type: 'ui_resources',
|
||||
data: [{ type: 'text', data: 'Test' }],
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
render(<ToolCallInfo {...mockProps} output={output} />);
|
||||
|
||||
expect(screen.getByText('UI Resources')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should not render UI Resources heading when no ui_resources', () => {
|
||||
const output = JSON.stringify([{ type: 'text', text: 'Regular output' }]);
|
||||
|
||||
render(<ToolCallInfo {...mockProps} output={output} />);
|
||||
|
||||
expect(screen.queryByText('UI Resources')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('should pass correct props to UIResourceRenderer', () => {
|
||||
const uiResource = {
|
||||
type: 'form',
|
||||
data: { fields: [{ name: 'test', type: 'text' }] },
|
||||
};
|
||||
|
||||
const output = JSON.stringify([
|
||||
{
|
||||
metadata: {
|
||||
type: 'ui_resources',
|
||||
data: [uiResource],
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
render(<ToolCallInfo {...mockProps} output={output} />);
|
||||
|
||||
expect(UIResourceRenderer).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
resource: uiResource,
|
||||
onUIAction: expect.any(Function),
|
||||
htmlProps: {
|
||||
autoResizeIframe: { width: true, height: true },
|
||||
},
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should console.log when UIAction is triggered', async () => {
|
||||
const consoleSpy = jest.spyOn(console, 'log').mockImplementation();
|
||||
|
||||
const output = JSON.stringify([
|
||||
{
|
||||
metadata: {
|
||||
type: 'ui_resources',
|
||||
data: [{ type: 'text', data: 'Test' }],
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
render(<ToolCallInfo {...mockProps} output={output} />);
|
||||
|
||||
const mockUIResourceRenderer = UIResourceRenderer as jest.MockedFunction<
|
||||
typeof UIResourceRenderer
|
||||
>;
|
||||
const onUIAction = mockUIResourceRenderer.mock.calls[0]?.[0]?.onUIAction;
|
||||
const testResult = { action: 'submit', data: { test: 'value' } };
|
||||
|
||||
if (onUIAction) {
|
||||
await onUIAction(testResult as any);
|
||||
}
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith('Action:', testResult);
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,219 +0,0 @@
|
||||
import React from 'react';
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react';
|
||||
import '@testing-library/jest-dom';
|
||||
import UIResourceCarousel from '../UIResourceCarousel';
|
||||
import type { UIResource } from '~/common';
|
||||
|
||||
// Mock the UIResourceRenderer component
|
||||
jest.mock('@mcp-ui/client', () => ({
|
||||
UIResourceRenderer: ({ resource, onUIAction }: any) => (
|
||||
<div data-testid="ui-resource-renderer" onClick={() => onUIAction({ action: 'test' })}>
|
||||
{resource.text || 'UI Resource'}
|
||||
</div>
|
||||
),
|
||||
}));
|
||||
|
||||
// Mock scrollTo
|
||||
const mockScrollTo = jest.fn();
|
||||
Object.defineProperty(HTMLElement.prototype, 'scrollTo', {
|
||||
configurable: true,
|
||||
value: mockScrollTo,
|
||||
});
|
||||
|
||||
describe('UIResourceCarousel', () => {
|
||||
const mockUIResources: UIResource[] = [
|
||||
{ uri: 'resource1', mimeType: 'text/html', text: 'Resource 1' },
|
||||
{ uri: 'resource2', mimeType: 'text/html', text: 'Resource 2' },
|
||||
{ uri: 'resource3', mimeType: 'text/html', text: 'Resource 3' },
|
||||
{ uri: 'resource4', mimeType: 'text/html', text: 'Resource 4' },
|
||||
{ uri: 'resource5', mimeType: 'text/html', text: 'Resource 5' },
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
// Reset scroll properties
|
||||
Object.defineProperty(HTMLElement.prototype, 'scrollLeft', {
|
||||
configurable: true,
|
||||
value: 0,
|
||||
});
|
||||
Object.defineProperty(HTMLElement.prototype, 'scrollWidth', {
|
||||
configurable: true,
|
||||
value: 1000,
|
||||
});
|
||||
Object.defineProperty(HTMLElement.prototype, 'clientWidth', {
|
||||
configurable: true,
|
||||
value: 500,
|
||||
});
|
||||
});
|
||||
|
||||
it('renders nothing when no resources provided', () => {
|
||||
const { container } = render(<UIResourceCarousel uiResources={[]} />);
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it('renders all UI resources', () => {
|
||||
render(<UIResourceCarousel uiResources={mockUIResources} />);
|
||||
const renderers = screen.getAllByTestId('ui-resource-renderer');
|
||||
expect(renderers).toHaveLength(5);
|
||||
expect(screen.getByText('Resource 1')).toBeInTheDocument();
|
||||
expect(screen.getByText('Resource 5')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows/hides navigation arrows on hover', async () => {
|
||||
const { container } = render(<UIResourceCarousel uiResources={mockUIResources} />);
|
||||
const carouselContainer = container.querySelector('.relative.mb-4.pt-3');
|
||||
|
||||
// Initially arrows should be hidden (opacity-0)
|
||||
const leftArrow = screen.queryByLabelText('Scroll left');
|
||||
const rightArrow = screen.queryByLabelText('Scroll right');
|
||||
|
||||
// Right arrow should exist but left should not (at start)
|
||||
expect(leftArrow).not.toBeInTheDocument();
|
||||
expect(rightArrow).toBeInTheDocument();
|
||||
expect(rightArrow).toHaveClass('opacity-0');
|
||||
|
||||
// Hover over container
|
||||
fireEvent.mouseEnter(carouselContainer!);
|
||||
await waitFor(() => {
|
||||
expect(rightArrow).toHaveClass('opacity-100');
|
||||
});
|
||||
|
||||
// Leave hover
|
||||
fireEvent.mouseLeave(carouselContainer!);
|
||||
await waitFor(() => {
|
||||
expect(rightArrow).toHaveClass('opacity-0');
|
||||
});
|
||||
});
|
||||
|
||||
it('handles scroll navigation', async () => {
|
||||
const { container } = render(<UIResourceCarousel uiResources={mockUIResources} />);
|
||||
const scrollContainer = container.querySelector('.hide-scrollbar');
|
||||
|
||||
// Simulate being scrolled to show left arrow
|
||||
Object.defineProperty(scrollContainer, 'scrollLeft', {
|
||||
configurable: true,
|
||||
value: 200,
|
||||
});
|
||||
|
||||
// Trigger scroll event
|
||||
fireEvent.scroll(scrollContainer!);
|
||||
|
||||
// Both arrows should now be visible
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText('Scroll left')).toBeInTheDocument();
|
||||
expect(screen.getByLabelText('Scroll right')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Hover to make arrows interactive
|
||||
const carouselContainer = container.querySelector('.relative.mb-4.pt-3');
|
||||
fireEvent.mouseEnter(carouselContainer!);
|
||||
|
||||
// Click right arrow
|
||||
fireEvent.click(screen.getByLabelText('Scroll right'));
|
||||
expect(mockScrollTo).toHaveBeenCalledWith({
|
||||
left: 650, // 200 + (500 * 0.9)
|
||||
behavior: 'smooth',
|
||||
});
|
||||
|
||||
// Click left arrow
|
||||
fireEvent.click(screen.getByLabelText('Scroll left'));
|
||||
expect(mockScrollTo).toHaveBeenCalledWith({
|
||||
left: -250, // 200 - (500 * 0.9)
|
||||
behavior: 'smooth',
|
||||
});
|
||||
});
|
||||
|
||||
it('hides right arrow when scrolled to end', async () => {
|
||||
const { container } = render(<UIResourceCarousel uiResources={mockUIResources} />);
|
||||
const scrollContainer = container.querySelector('.hide-scrollbar');
|
||||
|
||||
// Simulate scrolled to end
|
||||
Object.defineProperty(scrollContainer, 'scrollLeft', {
|
||||
configurable: true,
|
||||
value: 490, // scrollWidth - clientWidth - 10
|
||||
});
|
||||
|
||||
fireEvent.scroll(scrollContainer!);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText('Scroll left')).toBeInTheDocument();
|
||||
expect(screen.queryByLabelText('Scroll right')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it('handles UIResource actions', async () => {
|
||||
const consoleSpy = jest.spyOn(console, 'log').mockImplementation();
|
||||
render(<UIResourceCarousel uiResources={mockUIResources.slice(0, 1)} />);
|
||||
|
||||
const renderer = screen.getByTestId('ui-resource-renderer');
|
||||
fireEvent.click(renderer);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(consoleSpy).toHaveBeenCalledWith('Action:', { action: 'test' });
|
||||
});
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('applies correct dimensions to resource containers', () => {
|
||||
render(<UIResourceCarousel uiResources={mockUIResources.slice(0, 2)} />);
|
||||
const containers = screen
|
||||
.getAllByTestId('ui-resource-renderer')
|
||||
.map((el) => el.parentElement?.parentElement);
|
||||
|
||||
containers.forEach((container, index) => {
|
||||
expect(container).toHaveStyle({
|
||||
width: '230px',
|
||||
minHeight: '360px',
|
||||
animationDelay: `${index * 100}ms`,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('shows correct gradient overlays based on scroll position', () => {
|
||||
const { container } = render(<UIResourceCarousel uiResources={mockUIResources} />);
|
||||
|
||||
// At start, left gradient should be hidden, right should be visible
|
||||
const leftGradient = container.querySelector('.bg-gradient-to-r');
|
||||
const rightGradient = container.querySelector('.bg-gradient-to-l');
|
||||
|
||||
expect(leftGradient).toHaveClass('opacity-0');
|
||||
expect(rightGradient).toHaveClass('opacity-100');
|
||||
});
|
||||
|
||||
it('cleans up event listeners on unmount', () => {
|
||||
const { container, unmount } = render(<UIResourceCarousel uiResources={mockUIResources} />);
|
||||
const scrollContainer = container.querySelector('.hide-scrollbar');
|
||||
|
||||
const removeEventListenerSpy = jest.spyOn(scrollContainer!, 'removeEventListener');
|
||||
|
||||
unmount();
|
||||
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledWith('scroll', expect.any(Function));
|
||||
});
|
||||
|
||||
it('renders with animation delays for each resource', () => {
|
||||
render(<UIResourceCarousel uiResources={mockUIResources.slice(0, 3)} />);
|
||||
const resourceContainers = screen
|
||||
.getAllByTestId('ui-resource-renderer')
|
||||
.map((el) => el.parentElement?.parentElement);
|
||||
|
||||
resourceContainers.forEach((container, index) => {
|
||||
expect(container).toHaveStyle({
|
||||
animationDelay: `${index * 100}ms`,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('memoizes component properly', () => {
|
||||
const { rerender } = render(<UIResourceCarousel uiResources={mockUIResources} />);
|
||||
const firstRender = screen.getAllByTestId('ui-resource-renderer');
|
||||
|
||||
// Re-render with same props
|
||||
rerender(<UIResourceCarousel uiResources={mockUIResources} />);
|
||||
const secondRender = screen.getAllByTestId('ui-resource-renderer');
|
||||
|
||||
// Component should not re-render with same props (React.memo)
|
||||
expect(firstRender.length).toBe(secondRender.length);
|
||||
});
|
||||
});
|
||||
@@ -16,6 +16,7 @@ interface CustomUserVarsSectionProps {
|
||||
onRevoke: () => void;
|
||||
isSubmitting?: boolean;
|
||||
}
|
||||
|
||||
interface AuthFieldProps {
|
||||
name: string;
|
||||
config: CustomUserVarConfig;
|
||||
@@ -68,7 +69,7 @@ function AuthField({ name, config, hasValue, control, errors }: AuthFieldProps)
|
||||
? localize('com_ui_mcp_update_var', { 0: config.title })
|
||||
: localize('com_ui_mcp_enter_var', { 0: config.title })
|
||||
}
|
||||
className="w-full rounded border border-border-medium bg-transparent px-2 py-1 text-text-primary placeholder:text-text-secondary focus:outline-none sm:text-sm"
|
||||
className="w-full shadow-sm sm:text-sm"
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
@@ -78,22 +79,23 @@ function AuthField({ name, config, hasValue, control, errors }: AuthFieldProps)
|
||||
}
|
||||
|
||||
export default function CustomUserVarsSection({
|
||||
serverName,
|
||||
fields,
|
||||
onSave,
|
||||
onRevoke,
|
||||
serverName,
|
||||
isSubmitting = false,
|
||||
}: CustomUserVarsSectionProps) {
|
||||
const localize = useLocalize();
|
||||
|
||||
// Fetch auth value flags for the server
|
||||
const { data: authValuesData } = useMCPAuthValuesQuery(serverName, {
|
||||
enabled: !!serverName,
|
||||
});
|
||||
|
||||
const {
|
||||
reset,
|
||||
control,
|
||||
handleSubmit,
|
||||
reset,
|
||||
formState: { errors },
|
||||
} = useForm<Record<string, string>>({
|
||||
defaultValues: useMemo(() => {
|
||||
@@ -138,20 +140,10 @@ export default function CustomUserVarsSection({
|
||||
</form>
|
||||
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
type="button"
|
||||
variant="destructive"
|
||||
disabled={isSubmitting}
|
||||
onClick={handleRevokeClick}
|
||||
>
|
||||
<Button onClick={handleRevokeClick} variant="destructive" disabled={isSubmitting}>
|
||||
{localize('com_ui_revoke')}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="submit"
|
||||
disabled={isSubmitting}
|
||||
onClick={handleSubmit(onFormSubmit)}
|
||||
>
|
||||
<Button onClick={handleSubmit(onFormSubmit)} variant="submit" disabled={isSubmitting}>
|
||||
{isSubmitting ? localize('com_ui_saving') : localize('com_ui_save')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import React from 'react';
|
||||
import { RefreshCw } from 'lucide-react';
|
||||
import { Button, Spinner } from '@librechat/client';
|
||||
import { useLocalize, useMCPServerManager, useMCPConnectionStatus } from '~/hooks';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import { useMCPServerManager } from '~/hooks/MCP/useMCPServerManager';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
interface ServerInitializationSectionProps {
|
||||
sidePanel?: boolean;
|
||||
@@ -21,15 +21,16 @@ export default function ServerInitializationSection({
|
||||
}: ServerInitializationSectionProps) {
|
||||
const localize = useLocalize();
|
||||
|
||||
const { initializeServer, cancelOAuthFlow, isInitializing, isCancellable, getOAuthUrl } =
|
||||
useMCPServerManager({ conversationId });
|
||||
const {
|
||||
initializeServer,
|
||||
connectionStatus,
|
||||
cancelOAuthFlow,
|
||||
isInitializing,
|
||||
isCancellable,
|
||||
getOAuthUrl,
|
||||
} = useMCPServerManager({ conversationId });
|
||||
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const { connectionStatus } = useMCPConnectionStatus({
|
||||
enabled: !!startupConfig?.mcpServers && Object.keys(startupConfig.mcpServers).length > 0,
|
||||
});
|
||||
|
||||
const serverStatus = connectionStatus?.[serverName];
|
||||
const serverStatus = connectionStatus[serverName];
|
||||
const isConnected = serverStatus?.connectionState === 'connected';
|
||||
const canCancel = isCancellable(serverName);
|
||||
const isServerInitializing = isInitializing(serverName);
|
||||
|
||||
@@ -11,9 +11,9 @@ import store from '~/store';
|
||||
|
||||
export default function FilterPrompts({ className = '' }: { className?: string }) {
|
||||
const localize = useLocalize();
|
||||
const { name, setName } = usePromptGroupsContext();
|
||||
const { setName } = usePromptGroupsContext();
|
||||
const { categories } = useCategories('h-4 w-4');
|
||||
const [displayName, setDisplayName] = useState(name || '');
|
||||
const [displayName, setDisplayName] = useState('');
|
||||
const [isSearching, setIsSearching] = useState(false);
|
||||
const [categoryFilter, setCategory] = useRecoilState(store.promptsCategory);
|
||||
|
||||
@@ -60,26 +60,13 @@ export default function FilterPrompts({ className = '' }: { className?: string }
|
||||
[setCategory],
|
||||
);
|
||||
|
||||
// Sync displayName with name prop when it changes externally
|
||||
useEffect(() => {
|
||||
setDisplayName(name || '');
|
||||
}, [name]);
|
||||
|
||||
useEffect(() => {
|
||||
if (displayName === '') {
|
||||
// Clear immediately when empty
|
||||
setName('');
|
||||
setIsSearching(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setIsSearching(true);
|
||||
const timeout = setTimeout(() => {
|
||||
setIsSearching(false);
|
||||
setName(displayName); // Debounced setName call
|
||||
}, 500);
|
||||
return () => clearTimeout(timeout);
|
||||
}, [displayName, setName]);
|
||||
}, [displayName]);
|
||||
|
||||
return (
|
||||
<div className={cn('flex w-full gap-2 text-text-primary', className)}>
|
||||
@@ -97,6 +84,7 @@ export default function FilterPrompts({ className = '' }: { className?: string }
|
||||
value={displayName}
|
||||
onChange={(e) => {
|
||||
setDisplayName(e.target.value);
|
||||
setName(e.target.value);
|
||||
}}
|
||||
isSearching={isSearching}
|
||||
placeholder={localize('com_ui_filter_prompts_name')}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { useMemo } from 'react';
|
||||
import { useLocation } from 'react-router-dom';
|
||||
import { useMediaQuery } from '@librechat/client';
|
||||
import PanelNavigation from '~/components/Prompts/Groups/PanelNavigation';
|
||||
import ManagePrompts from '~/components/Prompts/ManagePrompts';
|
||||
import { usePromptGroupsContext } from '~/Providers';
|
||||
import List from '~/components/Prompts/Groups/List';
|
||||
import PanelNavigation from './PanelNavigation';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function GroupSidePanel({
|
||||
@@ -19,33 +19,38 @@ export default function GroupSidePanel({
|
||||
const location = useLocation();
|
||||
const isSmallerScreen = useMediaQuery('(max-width: 1024px)');
|
||||
const isChatRoute = useMemo(() => location.pathname?.startsWith('/c/'), [location.pathname]);
|
||||
|
||||
const { promptGroups, groupsQuery, nextPage, prevPage, hasNextPage, hasPreviousPage } =
|
||||
usePromptGroupsContext();
|
||||
const {
|
||||
nextPage,
|
||||
prevPage,
|
||||
isFetching,
|
||||
hasNextPage,
|
||||
groupsQuery,
|
||||
promptGroups,
|
||||
hasPreviousPage,
|
||||
} = usePromptGroupsContext();
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex h-full w-full flex-col gap-2 md:mr-2 md:w-auto md:min-w-72 lg:w-1/4 xl:w-1/4',
|
||||
'mr-2 flex h-auto w-auto min-w-72 flex-col gap-2 lg:w-1/4 xl:w-1/4',
|
||||
isDetailView === true && isSmallerScreen ? 'hidden' : '',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
<div className={cn('flex-grow overflow-y-auto', isChatRoute ? '' : 'px-2 md:px-0')}>
|
||||
<div className="flex-grow overflow-y-auto">
|
||||
<List groups={promptGroups} isChatRoute={isChatRoute} isLoading={!!groupsQuery.isLoading} />
|
||||
</div>
|
||||
<div className={cn(isChatRoute ? '' : 'px-2 pb-3 pt-2 md:px-0')}>
|
||||
<div className="flex items-center justify-between">
|
||||
{isChatRoute && <ManagePrompts className="select-none" />}
|
||||
<PanelNavigation
|
||||
onPrevious={prevPage}
|
||||
onNext={nextPage}
|
||||
nextPage={nextPage}
|
||||
prevPage={prevPage}
|
||||
isFetching={isFetching}
|
||||
hasNextPage={hasNextPage}
|
||||
hasPreviousPage={hasPreviousPage}
|
||||
isLoading={groupsQuery.isFetching}
|
||||
isChatRoute={isChatRoute}
|
||||
>
|
||||
{isChatRoute && <ManagePrompts className="select-none" />}
|
||||
</PanelNavigation>
|
||||
hasPreviousPage={hasPreviousPage}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -3,51 +3,42 @@ import { Button, ThemeSelector } from '@librechat/client';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
function PanelNavigation({
|
||||
onPrevious,
|
||||
onNext,
|
||||
hasNextPage,
|
||||
prevPage,
|
||||
nextPage,
|
||||
hasPreviousPage,
|
||||
isLoading,
|
||||
hasNextPage,
|
||||
isFetching,
|
||||
isChatRoute,
|
||||
children,
|
||||
}: {
|
||||
onPrevious: () => void;
|
||||
onNext: () => void;
|
||||
prevPage: () => void;
|
||||
nextPage: () => void;
|
||||
hasNextPage: boolean;
|
||||
hasPreviousPage: boolean;
|
||||
isLoading?: boolean;
|
||||
isFetching: boolean;
|
||||
isChatRoute: boolean;
|
||||
children?: React.ReactNode;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex gap-2">
|
||||
{!isChatRoute && <ThemeSelector returnThemeOnly={true} />}
|
||||
{children}
|
||||
</div>
|
||||
<div className="flex items-center gap-2" role="navigation" aria-label="Pagination">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={onPrevious}
|
||||
disabled={!hasPreviousPage || isLoading}
|
||||
aria-label={localize('com_ui_prev')}
|
||||
>
|
||||
<>
|
||||
<div className="flex gap-2">{!isChatRoute && <ThemeSelector returnThemeOnly={true} />}</div>
|
||||
<div
|
||||
className="flex items-center justify-between gap-2"
|
||||
role="navigation"
|
||||
aria-label="Pagination"
|
||||
>
|
||||
<Button variant="outline" size="sm" onClick={() => prevPage()} disabled={!hasPreviousPage}>
|
||||
{localize('com_ui_prev')}
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={onNext}
|
||||
disabled={!hasNextPage || isLoading}
|
||||
aria-label={localize('com_ui_next')}
|
||||
onClick={() => nextPage()}
|
||||
disabled={!hasNextPage || isFetching}
|
||||
>
|
||||
{localize('com_ui_next')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ export default function PromptsAccordion() {
|
||||
return (
|
||||
<div className="flex h-full w-full flex-col">
|
||||
<PromptSidePanel className="mt-2 space-y-2 lg:w-full xl:w-full" {...groupsNav}>
|
||||
<FilterPrompts className="items-center justify-center" />
|
||||
<FilterPrompts setName={groupsNav.setName} className="items-center justify-center" />
|
||||
<div className="flex w-full flex-row items-center justify-end">
|
||||
<AutoSendPrompt className="text-xs dark:text-white" />
|
||||
</div>
|
||||
|
||||
@@ -39,7 +39,7 @@ export default function PromptsView() {
|
||||
<DashBreadcrumb />
|
||||
<div className="flex w-full flex-grow flex-row divide-x overflow-hidden dark:divide-gray-600">
|
||||
<GroupSidePanel isDetailView={isDetailView}>
|
||||
<div className="mt-1 flex flex-row items-center justify-between px-2 md:px-2">
|
||||
<div className="mx-2 mt-1 flex flex-row items-center justify-between">
|
||||
<FilterPrompts />
|
||||
</div>
|
||||
</GroupSidePanel>
|
||||
|
||||
@@ -12,23 +12,22 @@ import {
|
||||
getIconKey,
|
||||
cn,
|
||||
} from '~/utils';
|
||||
import { ToolSelectDialog, MCPToolSelectDialog } from '~/components/Tools';
|
||||
import useAgentCapabilities from '~/hooks/Agents/useAgentCapabilities';
|
||||
import { useFileMapContext, useAgentPanelContext } from '~/Providers';
|
||||
import useAgentCapabilities from '~/hooks/Agents/useAgentCapabilities';
|
||||
import AgentCategorySelector from './AgentCategorySelector';
|
||||
import Action from '~/components/SidePanel/Builder/Action';
|
||||
import { useLocalize, useVisibleTools } from '~/hooks';
|
||||
import { ToolSelectDialog } from '~/components/Tools';
|
||||
import { useGetAgentFiles } from '~/data-provider';
|
||||
import { icons } from '~/hooks/Endpoint/Icons';
|
||||
import Instructions from './Instructions';
|
||||
import AgentAvatar from './AgentAvatar';
|
||||
import FileContext from './FileContext';
|
||||
import SearchForm from './Search/Form';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import FileSearch from './FileSearch';
|
||||
import Artifacts from './Artifacts';
|
||||
import AgentTool from './AgentTool';
|
||||
import CodeForm from './Code/Form';
|
||||
import MCPTools from './MCPTools';
|
||||
import { Panel } from '~/common';
|
||||
|
||||
const labelClass = 'mb-2 text-token-text-primary block font-medium';
|
||||
@@ -44,13 +43,10 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
const { showToast } = useToastContext();
|
||||
const methods = useFormContext<AgentForm>();
|
||||
const [showToolDialog, setShowToolDialog] = useState(false);
|
||||
const [showMCPToolDialog, setShowMCPToolDialog] = useState(false);
|
||||
const {
|
||||
actions,
|
||||
setAction,
|
||||
agentsConfig,
|
||||
startupConfig,
|
||||
mcpServersMap,
|
||||
setActivePanel,
|
||||
endpointsConfig,
|
||||
groupedTools: allTools,
|
||||
@@ -177,7 +173,19 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
Icon = icons[iconKey];
|
||||
}
|
||||
|
||||
const { toolIds, mcpServerNames } = useVisibleTools(tools, allTools, mcpServersMap);
|
||||
// Determine what to show
|
||||
const selectedToolIds = tools ?? [];
|
||||
const visibleToolIds = new Set(selectedToolIds);
|
||||
|
||||
// Check what group parent tools should be shown if any subtool is present
|
||||
Object.entries(allTools ?? {}).forEach(([toolId, toolObj]) => {
|
||||
if (toolObj.tools?.length) {
|
||||
// if any subtool of this group is selected, ensure group parent tool rendered
|
||||
if (toolObj.tools.some((st) => selectedToolIds.includes(st.tool_id))) {
|
||||
visibleToolIds.add(toolId);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -309,14 +317,6 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
{fileSearchEnabled && <FileSearch agent_id={agent_id} files={knowledge_files} />}
|
||||
</div>
|
||||
)}
|
||||
{/* MCP Section */}
|
||||
{startupConfig?.mcpServers != null && (
|
||||
<MCPTools
|
||||
agentId={agent_id}
|
||||
mcpServerNames={mcpServerNames}
|
||||
setShowMCPToolDialog={setShowMCPToolDialog}
|
||||
/>
|
||||
)}
|
||||
{/* Agent Tools & Actions */}
|
||||
<div className="mb-4">
|
||||
<label className={labelClass}>
|
||||
@@ -326,8 +326,8 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
</label>
|
||||
<div>
|
||||
<div className="mb-1">
|
||||
{/* Render all visible IDs (including groups with subtools selected) */}
|
||||
{toolIds.map((toolId, i) => {
|
||||
{/* // Render all visible IDs (including groups with subtools selected) */}
|
||||
{[...visibleToolIds].map((toolId, i) => {
|
||||
if (!allTools) return null;
|
||||
const tool = allTools[toolId];
|
||||
if (!tool) return null;
|
||||
@@ -384,6 +384,9 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/* MCP Section */}
|
||||
{/* <MCPSection /> */}
|
||||
|
||||
{/* Support Contact (Optional) */}
|
||||
<div className="mb-4">
|
||||
<div className="mb-1.5 flex items-center gap-2">
|
||||
@@ -474,13 +477,6 @@ export default function AgentConfig({ createMutation }: Pick<AgentPanelProps, 'c
|
||||
setIsOpen={setShowToolDialog}
|
||||
endpoint={EModelEndpoint.agents}
|
||||
/>
|
||||
<MCPToolSelectDialog
|
||||
agentId={agent_id}
|
||||
isOpen={showMCPToolDialog}
|
||||
mcpServerNames={mcpServerNames}
|
||||
setIsOpen={setShowMCPToolDialog}
|
||||
endpoint={EModelEndpoint.agents}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
Tools,
|
||||
Constants,
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
EModelEndpoint,
|
||||
PermissionBits,
|
||||
isAssistantsEndpoint,
|
||||
@@ -54,7 +53,7 @@ export default function AgentPanel() {
|
||||
});
|
||||
|
||||
const { hasPermission, isLoading: permissionsLoading } = useResourcePermissions(
|
||||
ResourceType.AGENT,
|
||||
'agent',
|
||||
basicAgentQuery.data?._id || '',
|
||||
);
|
||||
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
import React, { useState } from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { ChevronDown } from 'lucide-react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import * as AccordionPrimitive from '@radix-ui/react-accordion';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import {
|
||||
Label,
|
||||
Checkbox,
|
||||
OGDialog,
|
||||
Accordion,
|
||||
TrashIcon,
|
||||
AccordionItem,
|
||||
CircleHelpIcon,
|
||||
OGDialogTrigger,
|
||||
useToastContext,
|
||||
AccordionContent,
|
||||
OGDialogTemplate,
|
||||
} from '@librechat/client';
|
||||
import type { AgentForm, MCPServerInfo } from '~/common';
|
||||
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
|
||||
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
|
||||
import { useLocalize, useMCPServerManager } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function MCPTool({ serverInfo }: { serverInfo?: MCPServerInfo }) {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const updateUserPlugins = useUpdateUserPluginsMutation();
|
||||
const { getValues, setValue } = useFormContext<AgentForm>();
|
||||
const { getServerStatusIconProps, getConfigDialogProps } = useMCPServerManager();
|
||||
|
||||
const [isFocused, setIsFocused] = useState(false);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
const [accordionValue, setAccordionValue] = useState<string>('');
|
||||
const [hoveredToolId, setHoveredToolId] = useState<string | null>(null);
|
||||
|
||||
if (!serverInfo) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const currentServerName = serverInfo.serverName;
|
||||
|
||||
const getSelectedTools = () => {
|
||||
if (!serverInfo?.tools) return [];
|
||||
const formTools = getValues('tools') || [];
|
||||
return serverInfo.tools.filter((t) => formTools.includes(t.tool_id)).map((t) => t.tool_id);
|
||||
};
|
||||
|
||||
const updateFormTools = (newSelectedTools: string[]) => {
|
||||
const currentTools = getValues('tools') || [];
|
||||
const otherTools = currentTools.filter(
|
||||
(t: string) => !serverInfo?.tools?.some((st) => st.tool_id === t),
|
||||
);
|
||||
setValue('tools', [...otherTools, ...newSelectedTools]);
|
||||
};
|
||||
|
||||
const removeTool = (serverName: string) => {
|
||||
if (!serverName) {
|
||||
return;
|
||||
}
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'uninstall',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => {
|
||||
showToast({ message: `Error while deleting the tool: ${error}`, status: 'error' });
|
||||
},
|
||||
onSuccess: () => {
|
||||
const currentTools = getValues('tools');
|
||||
const remainingToolIds =
|
||||
currentTools?.filter(
|
||||
(currentToolId) =>
|
||||
currentToolId !== serverName &&
|
||||
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
|
||||
) || [];
|
||||
setValue('tools', remainingToolIds);
|
||||
showToast({ message: 'Tool deleted successfully', status: 'success' });
|
||||
},
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
const selectedTools = getSelectedTools();
|
||||
const isExpanded = accordionValue === currentServerName;
|
||||
|
||||
const statusIconProps = getServerStatusIconProps(currentServerName);
|
||||
const configDialogProps = getConfigDialogProps();
|
||||
|
||||
const statusIcon = statusIconProps && (
|
||||
<div
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
}}
|
||||
className="cursor-pointer rounded p-0.5 hover:bg-surface-secondary"
|
||||
>
|
||||
<MCPServerStatusIcon {...statusIconProps} />
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<OGDialog>
|
||||
<Accordion type="single" value={accordionValue} onValueChange={setAccordionValue} collapsible>
|
||||
<AccordionItem value={currentServerName} className="group relative w-full border-none">
|
||||
<div
|
||||
className="relative flex w-full items-center gap-1 rounded-lg p-1 hover:bg-surface-primary-alt"
|
||||
onMouseEnter={() => setIsHovering(true)}
|
||||
onMouseLeave={() => setIsHovering(false)}
|
||||
onFocus={() => setIsFocused(true)}
|
||||
onBlur={(e) => {
|
||||
if (!e.currentTarget.contains(e.relatedTarget)) {
|
||||
setIsFocused(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<AccordionPrimitive.Header asChild>
|
||||
<div
|
||||
className="flex grow cursor-pointer select-none items-center gap-1 rounded bg-transparent p-0 text-left transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-1"
|
||||
onClick={() =>
|
||||
setAccordionValue((prev) => {
|
||||
if (prev) {
|
||||
return '';
|
||||
}
|
||||
return currentServerName;
|
||||
})
|
||||
}
|
||||
>
|
||||
{statusIcon && <div className="flex items-center">{statusIcon}</div>}
|
||||
|
||||
{serverInfo.metadata.icon && (
|
||||
<div className="flex h-8 w-8 items-center justify-center overflow-hidden rounded-full">
|
||||
<div
|
||||
className="flex h-6 w-6 items-center justify-center overflow-hidden rounded-full bg-center bg-no-repeat dark:bg-white/20"
|
||||
style={{
|
||||
backgroundImage: `url(${serverInfo.metadata.icon})`,
|
||||
backgroundSize: 'cover',
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
className="grow px-2 py-1.5"
|
||||
style={{ textOverflow: 'ellipsis', wordBreak: 'break-all', overflow: 'hidden' }}
|
||||
>
|
||||
{currentServerName}
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<div className="relative flex items-center">
|
||||
<div
|
||||
className={cn(
|
||||
'absolute right-0 transition-all duration-300',
|
||||
isHovering || isFocused
|
||||
? 'translate-x-0 opacity-100'
|
||||
: 'translate-x-8 opacity-0',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<div
|
||||
data-checkbox-container
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="mt-1"
|
||||
>
|
||||
<Checkbox
|
||||
id={`select-all-${currentServerName}`}
|
||||
checked={
|
||||
selectedTools.length === serverInfo.tools?.length &&
|
||||
selectedTools.length > 0
|
||||
}
|
||||
onCheckedChange={(checked) => {
|
||||
if (serverInfo.tools) {
|
||||
const newSelectedTools = checked
|
||||
? serverInfo.tools.map((t) => t.tool_id)
|
||||
: [
|
||||
`${Constants.mcp_server}${Constants.mcp_delimiter}${currentServerName}`,
|
||||
];
|
||||
updateFormTools(newSelectedTools);
|
||||
}
|
||||
}}
|
||||
className={cn(
|
||||
'h-4 w-4 rounded border border-border-medium transition-all duration-200 hover:border-border-heavy',
|
||||
isExpanded ? 'visible' : 'pointer-events-none invisible',
|
||||
)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
const checkbox = e.currentTarget as HTMLButtonElement;
|
||||
checkbox.click();
|
||||
}
|
||||
}}
|
||||
tabIndex={isExpanded ? 0 : -1}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-1">
|
||||
{/* Caret button for accordion */}
|
||||
<AccordionPrimitive.Trigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
}}
|
||||
className={cn(
|
||||
'flex h-7 w-7 items-center justify-center rounded transition-colors duration-200 hover:bg-surface-active-alt focus:translate-x-0 focus:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-1',
|
||||
isExpanded && 'bg-surface-active-alt',
|
||||
)}
|
||||
aria-hidden="true"
|
||||
tabIndex={0}
|
||||
onFocus={() => setIsFocused(true)}
|
||||
>
|
||||
<ChevronDown
|
||||
className={cn(
|
||||
'h-4 w-4 transition-transform duration-200',
|
||||
isExpanded && 'rotate-180',
|
||||
)}
|
||||
/>
|
||||
</button>
|
||||
</AccordionPrimitive.Trigger>
|
||||
|
||||
<OGDialogTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
'flex h-7 w-7 items-center justify-center rounded transition-colors duration-200',
|
||||
'hover:bg-surface-active-alt focus:translate-x-0 focus:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-1',
|
||||
)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
aria-label={`Delete ${currentServerName}`}
|
||||
tabIndex={0}
|
||||
onFocus={() => setIsFocused(true)}
|
||||
>
|
||||
<TrashIcon className="h-4 w-4" />
|
||||
</button>
|
||||
</OGDialogTrigger>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</AccordionPrimitive.Header>
|
||||
</div>
|
||||
|
||||
<AccordionContent className="relative ml-1 pt-1 before:absolute before:bottom-2 before:left-0 before:top-0 before:w-0.5 before:bg-border-medium">
|
||||
<div className="space-y-1">
|
||||
{serverInfo.tools?.map((subTool) => (
|
||||
<label
|
||||
key={subTool.tool_id}
|
||||
htmlFor={subTool.tool_id}
|
||||
className={cn(
|
||||
'border-token-border-light hover:bg-token-surface-secondary flex cursor-pointer items-center rounded-lg border p-2',
|
||||
'ml-2 mr-1 focus-within:ring-2 focus-within:ring-ring focus-within:ring-offset-2 focus-within:ring-offset-background',
|
||||
)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
onKeyDown={(e) => {
|
||||
e.stopPropagation();
|
||||
}}
|
||||
onMouseEnter={() => setHoveredToolId(subTool.tool_id)}
|
||||
onMouseLeave={() => setHoveredToolId(null)}
|
||||
>
|
||||
<Checkbox
|
||||
id={subTool.tool_id}
|
||||
checked={selectedTools.includes(subTool.tool_id)}
|
||||
onCheckedChange={(_checked) => {
|
||||
const newSelectedTools = selectedTools.includes(subTool.tool_id)
|
||||
? selectedTools.filter((t) => t !== subTool.tool_id)
|
||||
: [...selectedTools, subTool.tool_id];
|
||||
updateFormTools(newSelectedTools);
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
e.stopPropagation();
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
const checkbox = e.currentTarget as HTMLButtonElement;
|
||||
checkbox.click();
|
||||
}
|
||||
}}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className={cn(
|
||||
'relative float-left mr-2 inline-flex h-4 w-4 cursor-pointer rounded border border-border-medium transition-[border-color] duration-200 hover:border-border-heavy focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 focus:ring-offset-background',
|
||||
)}
|
||||
/>
|
||||
<span className="text-token-text-primary select-none">
|
||||
{subTool.metadata.name}
|
||||
</span>
|
||||
{subTool.metadata.description && (
|
||||
<Ariakit.HovercardProvider placement="left-start">
|
||||
<div className="ml-auto flex h-6 w-6 items-center justify-center">
|
||||
<Ariakit.HovercardAnchor
|
||||
render={
|
||||
<Ariakit.Button
|
||||
className={cn(
|
||||
'flex h-5 w-5 cursor-help items-center rounded-full text-text-secondary transition-opacity duration-200',
|
||||
hoveredToolId === subTool.tool_id ? 'opacity-100' : 'opacity-0',
|
||||
)}
|
||||
aria-label={localize('com_ui_tool_info')}
|
||||
>
|
||||
<CircleHelpIcon className="h-4 w-4" />
|
||||
<Ariakit.VisuallyHidden>
|
||||
{localize('com_ui_tool_info')}
|
||||
</Ariakit.VisuallyHidden>
|
||||
</Ariakit.Button>
|
||||
}
|
||||
/>
|
||||
<Ariakit.HovercardDisclosure
|
||||
className="rounded-full text-text-secondary focus:outline-none focus:ring-2 focus:ring-ring"
|
||||
aria-label={localize('com_ui_tool_more_info')}
|
||||
aria-expanded={hoveredToolId === subTool.tool_id}
|
||||
aria-controls={`tool-description-${subTool.tool_id}`}
|
||||
>
|
||||
<Ariakit.VisuallyHidden>
|
||||
{localize('com_ui_tool_more_info')}
|
||||
</Ariakit.VisuallyHidden>
|
||||
<ChevronDown className="h-4 w-4" />
|
||||
</Ariakit.HovercardDisclosure>
|
||||
</div>
|
||||
<Ariakit.Hovercard
|
||||
id={`tool-description-${subTool.tool_id}`}
|
||||
gutter={14}
|
||||
shift={40}
|
||||
flip={false}
|
||||
className="z-[999] w-80 scale-95 rounded-2xl border border-border-medium bg-surface-secondary p-4 text-text-primary opacity-0 shadow-md transition-all duration-200 data-[enter]:scale-100 data-[leave]:scale-95 data-[enter]:opacity-100 data-[leave]:opacity-0"
|
||||
portal={true}
|
||||
unmountOnHide={true}
|
||||
role="tooltip"
|
||||
aria-label={subTool.metadata.description}
|
||||
>
|
||||
<div className="space-y-2">
|
||||
<p className="text-sm text-text-secondary">
|
||||
{subTool.metadata.description}
|
||||
</p>
|
||||
</div>
|
||||
</Ariakit.Hovercard>
|
||||
</Ariakit.HovercardProvider>
|
||||
)}
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
</Accordion>
|
||||
<OGDialogTemplate
|
||||
showCloseButton={false}
|
||||
title={localize('com_ui_delete_tool')}
|
||||
mainClassName="px-0"
|
||||
className="max-w-[450px]"
|
||||
main={
|
||||
<Label className="text-left text-sm font-medium">
|
||||
{localize('com_ui_delete_tool_confirm')}
|
||||
</Label>
|
||||
}
|
||||
selection={{
|
||||
selectHandler: () => removeTool(currentServerName),
|
||||
selectClasses:
|
||||
'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 transition-color duration-200 text-white',
|
||||
selectText: localize('com_ui_delete'),
|
||||
}}
|
||||
/>
|
||||
{configDialogProps && <MCPConfigDialog {...configDialogProps} />}
|
||||
</OGDialog>
|
||||
);
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
import React from 'react';
|
||||
import UninitializedMCPTool from './UninitializedMCPTool';
|
||||
import UnconfiguredMCPTool from './UnconfiguredMCPTool';
|
||||
import { useAgentPanelContext } from '~/Providers';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import MCPTool from './MCPTool';
|
||||
|
||||
export default function MCPTools({
|
||||
agentId,
|
||||
mcpServerNames,
|
||||
setShowMCPToolDialog,
|
||||
}: {
|
||||
agentId: string;
|
||||
mcpServerNames?: string[];
|
||||
setShowMCPToolDialog: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const { mcpServersMap } = useAgentPanelContext();
|
||||
|
||||
return (
|
||||
<div className="mb-4">
|
||||
<label className="text-token-text-primary mb-2 block font-medium">
|
||||
{localize('com_ui_mcp_servers')}
|
||||
</label>
|
||||
<div>
|
||||
<div className="mb-1">
|
||||
{/* Render servers with selected tools */}
|
||||
{mcpServerNames?.map((mcpServerName) => {
|
||||
const serverInfo = mcpServersMap.get(mcpServerName);
|
||||
if (!serverInfo?.isConfigured) {
|
||||
return (
|
||||
<UnconfiguredMCPTool
|
||||
key={`${mcpServerName}-${agentId}`}
|
||||
serverName={mcpServerName}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (!serverInfo) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (serverInfo.isConnected) {
|
||||
return (
|
||||
<MCPTool key={`${serverInfo.serverName}-${agentId}`} serverInfo={serverInfo} />
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<UninitializedMCPTool
|
||||
key={`${serverInfo.serverName}-${agentId}`}
|
||||
serverInfo={serverInfo}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
<div className="mt-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowMCPToolDialog(true)}
|
||||
className="btn btn-neutral border-token-border-light relative h-9 w-full rounded-lg font-medium"
|
||||
aria-haspopup="dialog"
|
||||
>
|
||||
<div className="flex w-full items-center justify-center gap-2">
|
||||
{localize('com_assistants_add_mcp_server_tools')}
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
import React, { useState } from 'react';
|
||||
import { CircleX } from 'lucide-react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import {
|
||||
Label,
|
||||
OGDialog,
|
||||
TrashIcon,
|
||||
useToastContext,
|
||||
OGDialogTrigger,
|
||||
OGDialogTemplate,
|
||||
} from '@librechat/client';
|
||||
import type { AgentForm } from '~/common';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function UnconfiguredMCPTool({ serverName }: { serverName?: string }) {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const updateUserPlugins = useUpdateUserPluginsMutation();
|
||||
const { getValues, setValue } = useFormContext<AgentForm>();
|
||||
|
||||
const [isFocused, setIsFocused] = useState(false);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
|
||||
if (!serverName) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const removeTool = () => {
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'uninstall',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => {
|
||||
showToast({
|
||||
message: localize('com_ui_delete_tool_error', { error: String(error) }),
|
||||
status: 'error',
|
||||
});
|
||||
},
|
||||
onSuccess: () => {
|
||||
const currentTools = getValues('tools');
|
||||
const remainingToolIds =
|
||||
currentTools?.filter(
|
||||
(currentToolId) =>
|
||||
currentToolId !== serverName &&
|
||||
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
|
||||
) || [];
|
||||
setValue('tools', remainingToolIds);
|
||||
showToast({ message: localize('com_ui_delete_tool_success'), status: 'success' });
|
||||
},
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<OGDialog>
|
||||
<div
|
||||
className="group relative flex w-full items-center gap-1 rounded-lg p-1 text-sm hover:bg-surface-primary-alt"
|
||||
onMouseEnter={() => setIsHovering(true)}
|
||||
onMouseLeave={() => setIsHovering(false)}
|
||||
onFocus={() => setIsFocused(true)}
|
||||
onBlur={(e) => {
|
||||
if (!e.currentTarget.contains(e.relatedTarget)) {
|
||||
setIsFocused(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex items-center">
|
||||
<div className="flex h-6 w-6 items-center justify-center rounded p-1">
|
||||
<CircleX className="h-4 w-4 text-red-500" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex grow cursor-not-allowed items-center gap-1 rounded bg-transparent p-0 text-left transition-colors">
|
||||
<div
|
||||
className="grow select-none px-2 py-1.5"
|
||||
style={{ textOverflow: 'ellipsis', wordBreak: 'break-all', overflow: 'hidden' }}
|
||||
>
|
||||
{serverName}
|
||||
<span className="ml-2 text-xs text-text-secondary">
|
||||
{' - '}
|
||||
{localize('com_ui_unavailable')}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<OGDialogTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
'flex h-7 w-7 items-center justify-center rounded transition-all duration-200 hover:bg-surface-active-alt focus:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-1',
|
||||
isHovering || isFocused ? 'opacity-100' : 'pointer-events-none opacity-0',
|
||||
)}
|
||||
aria-label={`Delete ${serverName}`}
|
||||
tabIndex={0}
|
||||
onFocus={() => setIsFocused(true)}
|
||||
>
|
||||
<TrashIcon className="h-4 w-4" />
|
||||
</button>
|
||||
</OGDialogTrigger>
|
||||
</div>
|
||||
<OGDialogTemplate
|
||||
showCloseButton={false}
|
||||
title={localize('com_ui_delete_tool')}
|
||||
mainClassName="px-0"
|
||||
className="max-w-[450px]"
|
||||
main={
|
||||
<Label className="text-left text-sm font-medium">
|
||||
{localize('com_ui_delete_tool_confirm')}
|
||||
</Label>
|
||||
}
|
||||
selection={{
|
||||
selectHandler: () => removeTool(),
|
||||
selectClasses:
|
||||
'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 transition-color duration-200 text-white',
|
||||
selectText: localize('com_ui_delete'),
|
||||
}}
|
||||
/>
|
||||
</OGDialog>
|
||||
);
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
import React, { useState } from 'react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import {
|
||||
Label,
|
||||
OGDialog,
|
||||
TrashIcon,
|
||||
OGDialogTrigger,
|
||||
useToastContext,
|
||||
OGDialogTemplate,
|
||||
} from '@librechat/client';
|
||||
import type { AgentForm, MCPServerInfo } from '~/common';
|
||||
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
|
||||
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
|
||||
import { useLocalize, useMCPServerManager } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function UninitializedMCPTool({ serverInfo }: { serverInfo?: MCPServerInfo }) {
|
||||
const [isFocused, setIsFocused] = useState(false);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const updateUserPlugins = useUpdateUserPluginsMutation();
|
||||
const { getValues, setValue } = useFormContext<AgentForm>();
|
||||
const { initializeServer, isInitializing, getServerStatusIconProps, getConfigDialogProps } =
|
||||
useMCPServerManager();
|
||||
|
||||
if (!serverInfo) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const removeTool = (serverName: string) => {
|
||||
if (!serverName) {
|
||||
return;
|
||||
}
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'uninstall',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => {
|
||||
showToast({
|
||||
message: localize('com_ui_delete_tool_error', { error: String(error) }),
|
||||
status: 'error',
|
||||
});
|
||||
},
|
||||
onSuccess: () => {
|
||||
const currentTools = getValues('tools');
|
||||
const remainingToolIds =
|
||||
currentTools?.filter(
|
||||
(currentToolId) =>
|
||||
currentToolId !== serverName &&
|
||||
!currentToolId.endsWith(`${Constants.mcp_delimiter}${serverName}`),
|
||||
) || [];
|
||||
setValue('tools', remainingToolIds);
|
||||
showToast({ message: localize('com_ui_delete_tool_success'), status: 'success' });
|
||||
},
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
const serverName = serverInfo.serverName;
|
||||
const isServerInitializing = isInitializing(serverName);
|
||||
const statusIconProps = getServerStatusIconProps(serverName);
|
||||
const configDialogProps = getConfigDialogProps();
|
||||
|
||||
const statusIcon = statusIconProps && (
|
||||
<div
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
}}
|
||||
className="cursor-pointer rounded p-0.5 hover:bg-surface-secondary"
|
||||
>
|
||||
<MCPServerStatusIcon {...statusIconProps} />
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<OGDialog>
|
||||
<div
|
||||
className="group relative flex w-full items-center gap-1 rounded-lg p-1 text-sm hover:bg-surface-primary-alt"
|
||||
onMouseEnter={() => setIsHovering(true)}
|
||||
onMouseLeave={() => setIsHovering(false)}
|
||||
onFocus={() => setIsFocused(true)}
|
||||
onBlur={(e) => {
|
||||
if (!e.currentTarget.contains(e.relatedTarget)) {
|
||||
setIsFocused(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className="flex grow cursor-pointer items-center gap-1 rounded bg-transparent p-0 text-left transition-colors"
|
||||
onClick={(e) => {
|
||||
if ((e.target as HTMLElement).closest('[data-status-icon]')) {
|
||||
return;
|
||||
}
|
||||
if (!isServerInitializing) {
|
||||
initializeServer(serverName);
|
||||
}
|
||||
}}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
if (!isServerInitializing) {
|
||||
initializeServer(serverName);
|
||||
}
|
||||
}
|
||||
}}
|
||||
aria-disabled={isServerInitializing}
|
||||
>
|
||||
{statusIcon && (
|
||||
<div className="flex items-center" data-status-icon>
|
||||
{statusIcon}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{serverInfo.metadata.icon && (
|
||||
<div className="flex h-8 w-8 items-center justify-center overflow-hidden rounded-full">
|
||||
<div
|
||||
className="flex h-6 w-6 items-center justify-center overflow-hidden rounded-full bg-center bg-no-repeat dark:bg-white/20"
|
||||
style={{
|
||||
backgroundImage: `url(${serverInfo.metadata.icon})`,
|
||||
backgroundSize: 'cover',
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
className="grow px-2 py-1.5"
|
||||
style={{ textOverflow: 'ellipsis', wordBreak: 'break-all', overflow: 'hidden' }}
|
||||
>
|
||||
{serverName}
|
||||
{isServerInitializing && (
|
||||
<span className="ml-2 text-xs text-text-secondary">
|
||||
{localize('com_ui_initializing')}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<OGDialogTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
className={cn(
|
||||
'flex h-7 w-7 items-center justify-center rounded transition-all duration-200 hover:bg-surface-active-alt focus:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-1',
|
||||
isHovering || isFocused ? 'opacity-100' : 'pointer-events-none opacity-0',
|
||||
)}
|
||||
aria-label={`Delete ${serverName}`}
|
||||
tabIndex={0}
|
||||
onFocus={() => setIsFocused(true)}
|
||||
>
|
||||
<TrashIcon className="h-4 w-4" />
|
||||
</button>
|
||||
</OGDialogTrigger>
|
||||
</div>
|
||||
<OGDialogTemplate
|
||||
showCloseButton={false}
|
||||
title={localize('com_ui_delete_tool')}
|
||||
mainClassName="px-0"
|
||||
className="max-w-[450px]"
|
||||
main={
|
||||
<Label className="text-left text-sm font-medium">
|
||||
{localize('com_ui_delete_tool_confirm')}
|
||||
</Label>
|
||||
}
|
||||
selection={{
|
||||
selectHandler: () => removeTool(serverName),
|
||||
selectClasses:
|
||||
'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 transition-color duration-200 text-white',
|
||||
selectText: localize('com_ui_delete'),
|
||||
}}
|
||||
/>
|
||||
{configDialogProps && <MCPConfigDialog {...configDialogProps} />}
|
||||
</OGDialog>
|
||||
);
|
||||
}
|
||||
@@ -6,11 +6,12 @@ import { Constants, QueryKeys } from 'librechat-data-provider';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TUpdateUserPlugins } from 'librechat-data-provider';
|
||||
import ServerInitializationSection from '~/components/MCP/ServerInitializationSection';
|
||||
import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries';
|
||||
import CustomUserVarsSection from '~/components/MCP/CustomUserVarsSection';
|
||||
import { MCPPanelProvider, useMCPPanelContext } from '~/Providers';
|
||||
import { useLocalize, useMCPConnectionStatus } from '~/hooks';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import MCPPanelSkeleton from './MCPPanelSkeleton';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
function MCPPanelContent() {
|
||||
const localize = useLocalize();
|
||||
@@ -18,10 +19,7 @@ function MCPPanelContent() {
|
||||
const { showToast } = useToastContext();
|
||||
const { conversationId } = useMCPPanelContext();
|
||||
const { data: startupConfig, isLoading: startupConfigLoading } = useGetStartupConfig();
|
||||
const { connectionStatus } = useMCPConnectionStatus({
|
||||
enabled: !!startupConfig?.mcpServers && Object.keys(startupConfig.mcpServers).length > 0,
|
||||
});
|
||||
|
||||
const { data: connectionStatusData } = useMCPConnectionStatusQuery();
|
||||
const [selectedServerNameForEditing, setSelectedServerNameForEditing] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
@@ -59,6 +57,11 @@ function MCPPanelContent() {
|
||||
}));
|
||||
}, [startupConfig?.mcpServers]);
|
||||
|
||||
const connectionStatus = useMemo(
|
||||
() => connectionStatusData?.connectionStatus || {},
|
||||
[connectionStatusData?.connectionStatus],
|
||||
);
|
||||
|
||||
const handleServerClickToEdit = (serverName: string) => {
|
||||
setSelectedServerNameForEditing(serverName);
|
||||
};
|
||||
@@ -122,7 +125,7 @@ function MCPPanelContent() {
|
||||
);
|
||||
}
|
||||
|
||||
const serverStatus = connectionStatus?.[selectedServerNameForEditing];
|
||||
const serverStatus = connectionStatus[selectedServerNameForEditing];
|
||||
|
||||
return (
|
||||
<div className="h-auto max-w-full space-y-4 overflow-x-hidden py-2">
|
||||
@@ -167,7 +170,7 @@ function MCPPanelContent() {
|
||||
<div className="h-auto max-w-full overflow-x-hidden py-2">
|
||||
<div className="space-y-2">
|
||||
{mcpServerDefinitions.map((server) => {
|
||||
const serverStatus = connectionStatus?.[server.serverName];
|
||||
const serverStatus = connectionStatus[server.serverName];
|
||||
const isConnected = serverStatus?.connectionState === 'connected';
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
import { XCircle, PlusCircleIcon, Wrench } from 'lucide-react';
|
||||
import type { AgentToolType } from 'librechat-data-provider';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
type MCPToolItemProps = {
|
||||
tool: AgentToolType;
|
||||
onAddTool: () => void;
|
||||
onRemoveTool: () => void;
|
||||
isInstalled?: boolean;
|
||||
isConfiguring?: boolean;
|
||||
isInitializing?: boolean;
|
||||
};
|
||||
|
||||
function MCPToolItem({
|
||||
tool,
|
||||
onAddTool,
|
||||
onRemoveTool,
|
||||
isInstalled = false,
|
||||
isConfiguring = false,
|
||||
isInitializing = false,
|
||||
}: MCPToolItemProps) {
|
||||
const localize = useLocalize();
|
||||
const handleClick = () => {
|
||||
if (isInstalled) {
|
||||
onRemoveTool();
|
||||
} else {
|
||||
onAddTool();
|
||||
}
|
||||
};
|
||||
|
||||
const name = tool.metadata?.name || tool.tool_id;
|
||||
const description = tool.metadata?.description || '';
|
||||
const icon = tool.metadata?.icon;
|
||||
|
||||
// Determine button state and text
|
||||
const getButtonState = () => {
|
||||
if (isInstalled) {
|
||||
return {
|
||||
text: localize('com_nav_tool_remove'),
|
||||
icon: <XCircle className="flex h-4 w-4 items-center stroke-2" />,
|
||||
className:
|
||||
'btn relative bg-gray-300 hover:bg-gray-400 dark:bg-gray-50 dark:hover:bg-gray-200',
|
||||
disabled: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (isConfiguring) {
|
||||
return {
|
||||
text: localize('com_ui_confirm'),
|
||||
icon: <PlusCircleIcon className="flex h-4 w-4 items-center stroke-2" />,
|
||||
className: 'btn btn-primary relative',
|
||||
disabled: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (isInitializing) {
|
||||
return {
|
||||
text: localize('com_ui_initializing'),
|
||||
icon: <Wrench className="flex h-4 w-4 items-center stroke-2" />,
|
||||
className: 'btn btn-primary relative opacity-75 cursor-not-allowed',
|
||||
disabled: true,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
text: localize('com_ui_add'),
|
||||
icon: <PlusCircleIcon className="flex h-4 w-4 items-center stroke-2" />,
|
||||
className: 'btn btn-primary relative',
|
||||
disabled: false,
|
||||
};
|
||||
};
|
||||
|
||||
const buttonState = getButtonState();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4 rounded border border-border-medium bg-transparent p-6">
|
||||
<div className="flex gap-4">
|
||||
<div className="h-[70px] w-[70px] shrink-0">
|
||||
<div className="relative h-full w-full">
|
||||
{icon ? (
|
||||
<img
|
||||
src={icon}
|
||||
alt={localize('com_ui_logo', { 0: name })}
|
||||
className="h-full w-full rounded-[5px] bg-white"
|
||||
/>
|
||||
) : (
|
||||
<div className="flex h-full w-full items-center justify-center rounded-[5px] border border-border-medium bg-transparent">
|
||||
<Wrench className="h-8 w-8 text-text-secondary" />
|
||||
</div>
|
||||
)}
|
||||
<div className="absolute inset-0 rounded-[5px] ring-1 ring-inset ring-black/10"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex min-w-0 flex-col items-start justify-between">
|
||||
<div className="mb-2 line-clamp-1 max-w-full text-lg leading-5 text-text-primary">
|
||||
{name}
|
||||
</div>
|
||||
<button
|
||||
className={buttonState.className}
|
||||
aria-label={`${buttonState.text} ${name}`}
|
||||
onClick={handleClick}
|
||||
disabled={buttonState.disabled}
|
||||
>
|
||||
<div className="flex w-full items-center justify-center gap-2">
|
||||
{buttonState.text}
|
||||
{buttonState.icon}
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="line-clamp-3 h-[60px] text-sm text-text-secondary">{description}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default MCPToolItem;
|
||||
@@ -1,370 +0,0 @@
|
||||
import { useEffect, useState, useMemo } from 'react';
|
||||
import { Search, X } from 'lucide-react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { Constants, EModelEndpoint } from 'librechat-data-provider';
|
||||
import { Dialog, DialogPanel, DialogTitle, Description } from '@headlessui/react';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TError, AgentToolType } from 'librechat-data-provider';
|
||||
import type { AgentForm, TPluginStoreDialogProps } from '~/common';
|
||||
import { useLocalize, usePluginDialogHelpers, useMCPServerManager } from '~/hooks';
|
||||
import { useGetStartupConfig, useAvailableToolsQuery } from '~/data-provider';
|
||||
import CustomUserVarsSection from '~/components/MCP/CustomUserVarsSection';
|
||||
import { PluginPagination } from '~/components/Plugins/Store';
|
||||
import { useAgentPanelContext } from '~/Providers';
|
||||
import MCPToolItem from './MCPToolItem';
|
||||
|
||||
function MCPToolSelectDialog({
|
||||
isOpen,
|
||||
agentId,
|
||||
setIsOpen,
|
||||
mcpServerNames,
|
||||
}: TPluginStoreDialogProps & {
|
||||
agentId: string;
|
||||
mcpServerNames?: string[];
|
||||
endpoint: EModelEndpoint.agents;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const { mcpServersMap } = useAgentPanelContext();
|
||||
const { initializeServer } = useMCPServerManager();
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const { getValues, setValue } = useFormContext<AgentForm>();
|
||||
const { refetch: refetchAvailableTools } = useAvailableToolsQuery(EModelEndpoint.agents);
|
||||
|
||||
const [isInitializing, setIsInitializing] = useState<string | null>(null);
|
||||
const [configuringServer, setConfiguringServer] = useState<string | null>(null);
|
||||
|
||||
const {
|
||||
maxPage,
|
||||
setMaxPage,
|
||||
currentPage,
|
||||
setCurrentPage,
|
||||
itemsPerPage,
|
||||
searchChanged,
|
||||
setSearchChanged,
|
||||
searchValue,
|
||||
setSearchValue,
|
||||
gridRef,
|
||||
handleSearch,
|
||||
handleChangePage,
|
||||
error,
|
||||
setError,
|
||||
errorMessage,
|
||||
setErrorMessage,
|
||||
} = usePluginDialogHelpers();
|
||||
|
||||
const updateUserPlugins = useUpdateUserPluginsMutation();
|
||||
|
||||
const handleInstallError = (error: TError) => {
|
||||
setError(true);
|
||||
const errorMessage = error.response?.data?.message ?? '';
|
||||
if (errorMessage) {
|
||||
setErrorMessage(errorMessage);
|
||||
}
|
||||
setTimeout(() => {
|
||||
setError(false);
|
||||
setErrorMessage('');
|
||||
}, 5000);
|
||||
};
|
||||
|
||||
const handleDirectAdd = async (serverName: string) => {
|
||||
try {
|
||||
setIsInitializing(serverName);
|
||||
const serverInfo = mcpServersMap.get(serverName);
|
||||
if (!serverInfo?.isConnected) {
|
||||
const result = await initializeServer(serverName);
|
||||
if (result?.success && result.oauthRequired && result.oauthUrl) {
|
||||
setIsInitializing(null);
|
||||
return;
|
||||
}
|
||||
}
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'install',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => {
|
||||
handleInstallError(error as TError);
|
||||
setIsInitializing(null);
|
||||
},
|
||||
onSuccess: async () => {
|
||||
const { data: updatedAvailableTools } = await refetchAvailableTools();
|
||||
|
||||
const currentTools = getValues('tools') || [];
|
||||
const toolsToAdd: string[] = [
|
||||
`${Constants.mcp_server}${Constants.mcp_delimiter}${serverName}`,
|
||||
];
|
||||
|
||||
if (updatedAvailableTools) {
|
||||
updatedAvailableTools.forEach((tool) => {
|
||||
if (tool.pluginKey.endsWith(`${Constants.mcp_delimiter}${serverName}`)) {
|
||||
toolsToAdd.push(tool.pluginKey);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const newTools = toolsToAdd.filter((tool) => !currentTools.includes(tool));
|
||||
if (newTools.length > 0) {
|
||||
setValue('tools', [...currentTools, ...newTools]);
|
||||
}
|
||||
setIsInitializing(null);
|
||||
},
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('Error adding MCP server:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSaveCustomVars = async (serverName: string, authData: Record<string, string>) => {
|
||||
try {
|
||||
await updateUserPlugins.mutateAsync({
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'install',
|
||||
auth: authData,
|
||||
isEntityTool: true,
|
||||
});
|
||||
|
||||
await handleDirectAdd(serverName);
|
||||
|
||||
setConfiguringServer(null);
|
||||
} catch (error) {
|
||||
console.error('Error saving custom vars:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRevokeCustomVars = (serverName: string) => {
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'uninstall',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => handleInstallError(error as TError),
|
||||
onSuccess: () => {
|
||||
setConfiguringServer(null);
|
||||
},
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
const onAddTool = async (serverName: string) => {
|
||||
if (configuringServer === serverName) {
|
||||
setConfiguringServer(null);
|
||||
await handleDirectAdd(serverName);
|
||||
return;
|
||||
}
|
||||
|
||||
const serverConfig = startupConfig?.mcpServers?.[serverName];
|
||||
const hasCustomUserVars =
|
||||
serverConfig?.customUserVars && Object.keys(serverConfig.customUserVars).length > 0;
|
||||
|
||||
if (hasCustomUserVars) {
|
||||
setConfiguringServer(serverName);
|
||||
} else {
|
||||
await handleDirectAdd(serverName);
|
||||
}
|
||||
};
|
||||
|
||||
const onRemoveTool = (serverName: string) => {
|
||||
updateUserPlugins.mutate(
|
||||
{
|
||||
pluginKey: `${Constants.mcp_prefix}${serverName}`,
|
||||
action: 'uninstall',
|
||||
auth: {},
|
||||
isEntityTool: true,
|
||||
},
|
||||
{
|
||||
onError: (error: unknown) => handleInstallError(error as TError),
|
||||
onSuccess: () => {
|
||||
const currentTools = getValues('tools') || [];
|
||||
const remainingTools = currentTools.filter(
|
||||
(tool) =>
|
||||
tool !== serverName && !tool.endsWith(`${Constants.mcp_delimiter}${serverName}`),
|
||||
);
|
||||
setValue('tools', remainingTools);
|
||||
},
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
const installedToolsSet = useMemo(() => {
|
||||
return new Set(mcpServerNames);
|
||||
}, [mcpServerNames]);
|
||||
|
||||
const mcpServers = useMemo(() => {
|
||||
const servers = Array.from(mcpServersMap.values());
|
||||
return servers.sort((a, b) => a.serverName.localeCompare(b.serverName));
|
||||
}, [mcpServersMap]);
|
||||
|
||||
const filteredServers = useMemo(() => {
|
||||
if (!searchValue) {
|
||||
return mcpServers;
|
||||
}
|
||||
return mcpServers.filter((serverInfo) =>
|
||||
serverInfo.serverName.toLowerCase().includes(searchValue.toLowerCase()),
|
||||
);
|
||||
}, [mcpServers, searchValue]);
|
||||
|
||||
useEffect(() => {
|
||||
setMaxPage(Math.ceil(filteredServers.length / itemsPerPage));
|
||||
if (searchChanged) {
|
||||
setCurrentPage(1);
|
||||
setSearchChanged(false);
|
||||
}
|
||||
}, [
|
||||
setMaxPage,
|
||||
itemsPerPage,
|
||||
searchChanged,
|
||||
setCurrentPage,
|
||||
setSearchChanged,
|
||||
filteredServers.length,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={isOpen}
|
||||
onClose={() => {
|
||||
setIsOpen(false);
|
||||
setCurrentPage(1);
|
||||
setSearchValue('');
|
||||
setConfiguringServer(null);
|
||||
setIsInitializing(null);
|
||||
}}
|
||||
className="relative z-[102]"
|
||||
>
|
||||
<div className="fixed inset-0 bg-surface-primary opacity-60 transition-opacity dark:opacity-80" />
|
||||
<div className="fixed inset-0 flex items-center justify-center p-4">
|
||||
<DialogPanel
|
||||
className="relative max-h-[90vh] w-full transform overflow-hidden overflow-y-auto rounded-lg bg-surface-secondary text-left shadow-xl transition-all max-sm:h-full sm:mx-7 sm:my-8 sm:max-w-2xl lg:max-w-5xl xl:max-w-7xl"
|
||||
style={{ minHeight: '610px' }}
|
||||
>
|
||||
<div className="flex items-center justify-between border-b-[1px] border-border-medium px-4 pb-4 pt-5 sm:p-6">
|
||||
<div className="flex items-center">
|
||||
<div className="text-center sm:text-left">
|
||||
<DialogTitle className="text-lg font-medium leading-6 text-text-primary">
|
||||
{localize('com_nav_tool_dialog_mcp_server_tools')}
|
||||
</DialogTitle>
|
||||
<Description className="text-sm text-text-secondary">
|
||||
{localize('com_nav_tool_dialog_description')}
|
||||
</Description>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<button
|
||||
onClick={() => {
|
||||
setIsOpen(false);
|
||||
setCurrentPage(1);
|
||||
setConfiguringServer(null);
|
||||
setIsInitializing(null);
|
||||
}}
|
||||
className="inline-block rounded-full text-text-secondary transition-colors hover:text-text-primary"
|
||||
aria-label="Close dialog"
|
||||
type="button"
|
||||
>
|
||||
<X aria-hidden="true" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div
|
||||
className="relative m-4 rounded border border-red-400 bg-red-100 px-4 py-3 text-red-700"
|
||||
role="alert"
|
||||
>
|
||||
{localize('com_nav_plugin_auth_error')} {errorMessage}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{configuringServer && (
|
||||
<div className="p-4 sm:p-6 sm:pt-4">
|
||||
<div className="mb-4">
|
||||
<p className="text-sm text-text-secondary">
|
||||
{localize('com_ui_mcp_configure_server_description', { 0: configuringServer })}
|
||||
</p>
|
||||
</div>
|
||||
<CustomUserVarsSection
|
||||
serverName={configuringServer}
|
||||
fields={startupConfig?.mcpServers?.[configuringServer]?.customUserVars || {}}
|
||||
onSave={(authData) => handleSaveCustomVars(configuringServer, authData)}
|
||||
onRevoke={() => handleRevokeCustomVars(configuringServer)}
|
||||
isSubmitting={updateUserPlugins.isLoading}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="p-4 sm:p-6 sm:pt-4">
|
||||
<div className="mt-4 flex flex-col gap-4">
|
||||
<div
|
||||
className="flex items-center justify-center space-x-4"
|
||||
onClick={() => setConfiguringServer(null)}
|
||||
>
|
||||
<Search className="h-6 w-6 text-text-tertiary" />
|
||||
<input
|
||||
type="text"
|
||||
value={searchValue}
|
||||
onChange={handleSearch}
|
||||
placeholder={localize('com_nav_tool_search')}
|
||||
className="w-64 rounded border border-border-medium bg-transparent px-2 py-1 text-text-primary focus:outline-none"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div
|
||||
ref={gridRef}
|
||||
className="grid grid-cols-1 gap-3 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4"
|
||||
style={{ minHeight: '410px' }}
|
||||
>
|
||||
{filteredServers
|
||||
.slice((currentPage - 1) * itemsPerPage, currentPage * itemsPerPage)
|
||||
.map((serverInfo) => {
|
||||
const isInstalled = installedToolsSet.has(serverInfo.serverName);
|
||||
const isConfiguring = configuringServer === serverInfo.serverName;
|
||||
const isServerInitializing = isInitializing === serverInfo.serverName;
|
||||
|
||||
const tool: AgentToolType = {
|
||||
agent_id: agentId,
|
||||
tool_id: serverInfo.serverName,
|
||||
metadata: {
|
||||
...serverInfo.metadata,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${serverInfo.serverName}`,
|
||||
},
|
||||
};
|
||||
|
||||
return (
|
||||
<MCPToolItem
|
||||
tool={tool}
|
||||
isInstalled={isInstalled}
|
||||
key={serverInfo.serverName}
|
||||
isConfiguring={isConfiguring}
|
||||
isInitializing={isServerInitializing}
|
||||
onAddTool={() => onAddTool(serverInfo.serverName)}
|
||||
onRemoveTool={() => onRemoveTool(serverInfo.serverName)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-2 flex flex-col items-center gap-2 sm:flex-row sm:justify-between">
|
||||
{maxPage > 0 ? (
|
||||
<PluginPagination
|
||||
currentPage={currentPage}
|
||||
maxPage={maxPage}
|
||||
onChangePage={handleChangePage}
|
||||
/>
|
||||
) : (
|
||||
<div style={{ height: '21px' }}></div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</DialogPanel>
|
||||
</div>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
export default MCPToolSelectDialog;
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useEffect } from 'react';
|
||||
import { Search, X } from 'lucide-react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import { Constants, isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import { Dialog, DialogPanel, DialogTitle, Description } from '@headlessui/react';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import type {
|
||||
@@ -15,6 +15,7 @@ import type { AgentForm, TPluginStoreDialogProps } from '~/common';
|
||||
import { PluginPagination, PluginAuthForm } from '~/components/Plugins/Store';
|
||||
import { useAgentPanelContext } from '~/Providers/AgentPanelContext';
|
||||
import { useLocalize, usePluginDialogHelpers } from '~/hooks';
|
||||
import { useAvailableToolsQuery } from '~/data-provider';
|
||||
import ToolItem from './ToolItem';
|
||||
|
||||
function ToolSelectDialog({
|
||||
@@ -25,9 +26,10 @@ function ToolSelectDialog({
|
||||
endpoint: AssistantsEndpoint | EModelEndpoint.agents;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const isAgentTools = isAgentsEndpoint(endpoint);
|
||||
const { getValues, setValue } = useFormContext<AgentForm>();
|
||||
const { groupedTools, pluginTools } = useAgentPanelContext();
|
||||
const { data: tools } = useAvailableToolsQuery(endpoint);
|
||||
const { groupedTools } = useAgentPanelContext();
|
||||
const isAgentTools = isAgentsEndpoint(endpoint);
|
||||
|
||||
const {
|
||||
maxPage,
|
||||
@@ -119,28 +121,38 @@ function ToolSelectDialog({
|
||||
|
||||
const onAddTool = (pluginKey: string) => {
|
||||
setShowPluginAuthForm(false);
|
||||
const availablePluginFromKey = pluginTools?.find((p) => p.pluginKey === pluginKey);
|
||||
setSelectedPlugin(availablePluginFromKey);
|
||||
const getAvailablePluginFromKey = tools?.find((p) => p.pluginKey === pluginKey);
|
||||
setSelectedPlugin(getAvailablePluginFromKey);
|
||||
|
||||
const { authConfig, authenticated = false } = availablePluginFromKey ?? {};
|
||||
if (authConfig && authConfig.length > 0 && !authenticated) {
|
||||
setShowPluginAuthForm(true);
|
||||
const isMCPTool = pluginKey.includes(Constants.mcp_delimiter);
|
||||
|
||||
if (isMCPTool) {
|
||||
// MCP tools have their variables configured elsewhere (e.g., MCPPanel or MCPSelect),
|
||||
// so we directly proceed to install without showing the auth form.
|
||||
handleInstall({ pluginKey, action: 'install', auth: {} });
|
||||
} else {
|
||||
handleInstall({
|
||||
pluginKey,
|
||||
action: 'install',
|
||||
auth: {},
|
||||
});
|
||||
const { authConfig, authenticated = false } = getAvailablePluginFromKey ?? {};
|
||||
if (authConfig && authConfig.length > 0 && !authenticated) {
|
||||
setShowPluginAuthForm(true);
|
||||
} else {
|
||||
handleInstall({
|
||||
pluginKey,
|
||||
action: 'install',
|
||||
auth: {},
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const filteredTools = Object.values(groupedTools || {}).filter(
|
||||
(currentTool: AgentToolType & { tools?: AgentToolType[] }) => {
|
||||
if (currentTool.metadata?.name?.toLowerCase().includes(searchValue.toLowerCase())) {
|
||||
(tool: AgentToolType & { tools?: AgentToolType[] }) => {
|
||||
// Check if the parent tool matches
|
||||
if (tool.metadata?.name?.toLowerCase().includes(searchValue.toLowerCase())) {
|
||||
return true;
|
||||
}
|
||||
if (currentTool.tools) {
|
||||
return currentTool.tools.some((childTool) =>
|
||||
// Check if any child tools match
|
||||
if (tool.tools) {
|
||||
return tool.tools.some((childTool) =>
|
||||
childTool.metadata?.name?.toLowerCase().includes(searchValue.toLowerCase()),
|
||||
);
|
||||
}
|
||||
@@ -157,9 +169,9 @@ function ToolSelectDialog({
|
||||
}
|
||||
}
|
||||
}, [
|
||||
pluginTools,
|
||||
searchValue,
|
||||
tools,
|
||||
itemsPerPage,
|
||||
searchValue,
|
||||
filteredTools,
|
||||
searchChanged,
|
||||
setMaxPage,
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
export { default as MCPToolSelectDialog } from './MCPToolSelectDialog';
|
||||
export { default as ToolSelectDialog } from './ToolSelectDialog';
|
||||
export { default as ToolItem } from './ToolItem';
|
||||
|
||||
@@ -400,27 +400,22 @@ export const usePromptGroupsInfiniteQuery = (
|
||||
params?: t.TPromptGroupsWithFilterRequest,
|
||||
config?: UseInfiniteQueryOptions<t.PromptGroupListResponse, unknown>,
|
||||
) => {
|
||||
const { name, pageSize, category } = params || {};
|
||||
const { name, pageSize, category, ...rest } = params || {};
|
||||
return useInfiniteQuery<t.PromptGroupListResponse, unknown>(
|
||||
[QueryKeys.promptGroups, name, category, pageSize],
|
||||
({ pageParam }) => {
|
||||
const queryParams: t.TPromptGroupsWithFilterRequest = {
|
||||
({ pageParam = '1' }) =>
|
||||
dataService.getPromptGroups({
|
||||
...rest,
|
||||
name,
|
||||
category: category || '',
|
||||
limit: (pageSize || 10).toString(),
|
||||
};
|
||||
|
||||
// Only add cursor if it's a valid string
|
||||
if (pageParam && typeof pageParam === 'string') {
|
||||
queryParams.cursor = pageParam;
|
||||
}
|
||||
|
||||
return dataService.getPromptGroups(queryParams);
|
||||
},
|
||||
pageNumber: pageParam?.toString(),
|
||||
pageSize: (pageSize || 10).toString(),
|
||||
}),
|
||||
{
|
||||
getNextPageParam: (lastPage) => {
|
||||
// Use cursor-based pagination - ensure we return a valid cursor or undefined
|
||||
return lastPage.has_more && lastPage.after ? lastPage.after : undefined;
|
||||
const currentPageNumber = Number(lastPage.pageNumber);
|
||||
const totalPages = Number(lastPage.pages);
|
||||
return currentPageNumber < totalPages ? currentPageNumber + 1 : undefined;
|
||||
},
|
||||
refetchOnWindowFocus: false,
|
||||
refetchOnReconnect: false,
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { v4 } from 'uuid';
|
||||
import { useSetRecoilState } from 'recoil';
|
||||
import { useToastContext } from '@librechat/client';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import {
|
||||
QueryKeys,
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
EToolResources,
|
||||
mergeFileConfig,
|
||||
isAgentsEndpoint,
|
||||
isAssistantsEndpoint,
|
||||
@@ -22,7 +19,6 @@ import useLocalize, { TranslationKeys } from '~/hooks/useLocalize';
|
||||
import { useDelayedUploadToast } from './useDelayedUploadToast';
|
||||
import { processFileForUpload } from '~/utils/heicConverter';
|
||||
import { useChatContext } from '~/Providers/ChatContext';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import { logger, validateFiles } from '~/utils';
|
||||
import useClientResize from './useClientResize';
|
||||
import useUpdateFiles from './useUpdateFiles';
|
||||
@@ -43,9 +39,6 @@ const useFileHandling = (params?: UseFileHandling) => {
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const { startUploadTimer, clearUploadTimer } = useDelayedUploadToast();
|
||||
const { files, setFiles, setFilesLoading, conversation } = useChatContext();
|
||||
const setEphemeralAgent = useSetRecoilState(
|
||||
ephemeralAgentByConvoId(conversation?.conversationId ?? Constants.NEW_CONVO),
|
||||
);
|
||||
const setError = (error: string) => setErrors((prevErrors) => [...prevErrors, error]);
|
||||
const { addFile, replaceFile, updateFileById, deleteFileById } = useUpdateFiles(
|
||||
params?.fileSetter ?? setFiles,
|
||||
@@ -140,13 +133,6 @@ const useFileHandling = (params?: UseFileHandling) => {
|
||||
const error = _error as TError | undefined;
|
||||
console.log('upload error', error);
|
||||
const file_id = body.get('file_id');
|
||||
const tool_resource = body.get('tool_resource');
|
||||
if (tool_resource === EToolResources.execute_code) {
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
[EToolResources.execute_code]: false,
|
||||
}));
|
||||
}
|
||||
clearUploadTimer(file_id as string);
|
||||
deleteFileById(file_id as string);
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@ import { useGetModelsQuery } from 'librechat-data-provider/react-query';
|
||||
import {
|
||||
Permissions,
|
||||
alternateName,
|
||||
PermissionBits,
|
||||
EModelEndpoint,
|
||||
PermissionTypes,
|
||||
isAgentsEndpoint,
|
||||
getConfigDefaults,
|
||||
isAssistantsEndpoint,
|
||||
PermissionBits,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TAssistantsMap, TEndpointsConfig } from 'librechat-data-provider';
|
||||
import type { MentionOption } from '~/common';
|
||||
@@ -19,7 +19,6 @@ import {
|
||||
useGetStartupConfig,
|
||||
} from '~/data-provider';
|
||||
import useAssistantListMap from '~/hooks/Assistants/useAssistantListMap';
|
||||
import { useAgentsMapContext } from '~/Providers/AgentsMapContext';
|
||||
import { mapEndpoints, getPresetTitle } from '~/utils';
|
||||
import { EndpointIcon } from '~/components/Endpoints';
|
||||
import useHasAccess from '~/hooks/Roles/useHasAccess';
|
||||
@@ -63,7 +62,6 @@ export default function useMentions({
|
||||
permission: Permissions.USE,
|
||||
});
|
||||
|
||||
const agentsMap = useAgentsMapContext();
|
||||
const { data: presets } = useGetPresetsQuery();
|
||||
const { data: modelsConfig } = useGetModelsQuery();
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
@@ -131,24 +129,7 @@ export default function useMentions({
|
||||
[listMap, assistantMap, endpointsConfig],
|
||||
);
|
||||
|
||||
const modelSpecs = useMemo(() => {
|
||||
const specs = startupConfig?.modelSpecs?.list ?? [];
|
||||
if (!agentsMap) {
|
||||
return specs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter modelSpecs to only include agents the user has access to.
|
||||
* Use agentsMap which already contains permission-filtered agents (consistent with other components).
|
||||
*/
|
||||
return specs.filter((spec) => {
|
||||
if (spec.preset?.endpoint === EModelEndpoint.agents && spec.preset?.agent_id) {
|
||||
return spec.preset.agent_id in agentsMap;
|
||||
}
|
||||
/** Keep non-agent modelSpecs */
|
||||
return true;
|
||||
});
|
||||
}, [startupConfig, agentsMap]);
|
||||
const modelSpecs = useMemo(() => startupConfig?.modelSpecs?.list ?? [], [startupConfig]);
|
||||
|
||||
const options: MentionOption[] = useMemo(() => {
|
||||
let validEndpoints = endpoints;
|
||||
|
||||
@@ -25,7 +25,6 @@ const useSpeechToTextExternal = (
|
||||
|
||||
const [minDecibels] = useRecoilState(store.decibelValue);
|
||||
const [autoSendText] = useRecoilState(store.autoSendText);
|
||||
const [languageSTT] = useRecoilState<string>(store.languageSTT);
|
||||
const [speechToText] = useRecoilState<boolean>(store.speechToText);
|
||||
const [autoTranscribeAudio] = useRecoilState<boolean>(store.autoTranscribeAudio);
|
||||
|
||||
@@ -122,9 +121,6 @@ const useSpeechToTextExternal = (
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('audio', audioBlob, `audio.${fileExtension}`);
|
||||
if (languageSTT) {
|
||||
formData.append('language', languageSTT);
|
||||
}
|
||||
setIsRequestBeingMade(true);
|
||||
cleanup();
|
||||
processAudio(formData);
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
export * from './useGetMCPTools';
|
||||
export * from './useMCPConnectionStatus';
|
||||
export * from './useMCPSelect';
|
||||
export * from './useVisibleTools';
|
||||
export * from './useGetMCPTools';
|
||||
export { useMCPServerManager } from './useMCPServerManager';
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries';
|
||||
|
||||
export function useMCPConnectionStatus({ enabled }: { enabled?: boolean } = {}) {
|
||||
const { data } = useMCPConnectionStatusQuery({
|
||||
enabled,
|
||||
});
|
||||
|
||||
return {
|
||||
connectionStatus: data?.connectionStatus,
|
||||
};
|
||||
}
|
||||
@@ -9,7 +9,8 @@ import {
|
||||
} from 'librechat-data-provider/react-query';
|
||||
import type { TUpdateUserPlugins, TPlugin } from 'librechat-data-provider';
|
||||
import type { ConfigFieldDetail } from '~/common';
|
||||
import { useLocalize, useMCPSelect, useGetMCPTools, useMCPConnectionStatus } from '~/hooks';
|
||||
import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries';
|
||||
import { useLocalize, useMCPSelect, useGetMCPTools } from '~/hooks';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
|
||||
interface ServerState {
|
||||
@@ -20,7 +21,7 @@ interface ServerState {
|
||||
pollInterval: NodeJS.Timeout | null;
|
||||
}
|
||||
|
||||
export function useMCPServerManager({ conversationId }: { conversationId?: string | null } = {}) {
|
||||
export function useMCPServerManager({ conversationId }: { conversationId?: string | null }) {
|
||||
const localize = useLocalize();
|
||||
const queryClient = useQueryClient();
|
||||
const { showToast } = useToastContext();
|
||||
@@ -82,9 +83,13 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
|
||||
return initialStates;
|
||||
});
|
||||
|
||||
const { connectionStatus } = useMCPConnectionStatus({
|
||||
const { data: connectionStatusData } = useMCPConnectionStatusQuery({
|
||||
enabled: !!startupConfig?.mcpServers && Object.keys(startupConfig.mcpServers).length > 0,
|
||||
});
|
||||
const connectionStatus = useMemo(
|
||||
() => connectionStatusData?.connectionStatus || {},
|
||||
[connectionStatusData?.connectionStatus],
|
||||
);
|
||||
|
||||
/** Filter disconnected servers when values change, but only after initial load
|
||||
This prevents clearing selections on page refresh when servers haven't connected yet
|
||||
@@ -92,7 +97,7 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
|
||||
const hasInitialLoadCompleted = useRef(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!connectionStatus || Object.keys(connectionStatus).length === 0) {
|
||||
if (!connectionStatusData || Object.keys(connectionStatus).length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -110,7 +115,7 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
|
||||
if (connectedSelected.length !== mcpValues.length) {
|
||||
setMCPValues(connectedSelected);
|
||||
}
|
||||
}, [connectionStatus, mcpValues, setMCPValues]);
|
||||
}, [connectionStatus, connectionStatusData, mcpValues, setMCPValues]);
|
||||
|
||||
const updateServerState = useCallback((serverName: string, updates: Partial<ServerState>) => {
|
||||
setServerStates((prev) => {
|
||||
@@ -224,46 +229,46 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
|
||||
const initializeServer = useCallback(
|
||||
async (serverName: string, autoOpenOAuth: boolean = true) => {
|
||||
updateServerState(serverName, { isInitializing: true });
|
||||
|
||||
try {
|
||||
const response = await reinitializeMutation.mutateAsync(serverName);
|
||||
if (!response.success) {
|
||||
|
||||
if (response.success) {
|
||||
if (response.oauthRequired && response.oauthUrl) {
|
||||
updateServerState(serverName, {
|
||||
oauthUrl: response.oauthUrl,
|
||||
oauthStartTime: Date.now(),
|
||||
isCancellable: true,
|
||||
isInitializing: true,
|
||||
});
|
||||
|
||||
if (autoOpenOAuth) {
|
||||
window.open(response.oauthUrl, '_blank', 'noopener,noreferrer');
|
||||
}
|
||||
|
||||
startServerPolling(serverName);
|
||||
} else {
|
||||
await queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]);
|
||||
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_initialized_success', { 0: serverName }),
|
||||
status: 'success',
|
||||
});
|
||||
|
||||
const currentValues = mcpValues ?? [];
|
||||
if (!currentValues.includes(serverName)) {
|
||||
setMCPValues([...currentValues, serverName]);
|
||||
}
|
||||
|
||||
cleanupServerState(serverName);
|
||||
}
|
||||
} else {
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_init_failed', { 0: serverName }),
|
||||
status: 'error',
|
||||
});
|
||||
cleanupServerState(serverName);
|
||||
return response;
|
||||
}
|
||||
|
||||
if (response.oauthRequired && response.oauthUrl) {
|
||||
updateServerState(serverName, {
|
||||
oauthUrl: response.oauthUrl,
|
||||
oauthStartTime: Date.now(),
|
||||
isCancellable: true,
|
||||
isInitializing: true,
|
||||
});
|
||||
|
||||
if (autoOpenOAuth) {
|
||||
window.open(response.oauthUrl, '_blank', 'noopener,noreferrer');
|
||||
}
|
||||
|
||||
startServerPolling(serverName);
|
||||
} else {
|
||||
await queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]);
|
||||
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_initialized_success', { 0: serverName }),
|
||||
status: 'success',
|
||||
});
|
||||
|
||||
const currentValues = mcpValues ?? [];
|
||||
if (!currentValues.includes(serverName)) {
|
||||
setMCPValues([...currentValues, serverName]);
|
||||
}
|
||||
|
||||
cleanupServerState(serverName);
|
||||
}
|
||||
return response;
|
||||
} catch (error) {
|
||||
console.error(`[MCP Manager] Failed to initialize ${serverName}:`, error);
|
||||
showToast({
|
||||
@@ -346,7 +351,7 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
|
||||
return;
|
||||
}
|
||||
|
||||
const serverStatus = connectionStatus?.[serverName];
|
||||
const serverStatus = connectionStatus[serverName];
|
||||
if (serverStatus?.connectionState === 'connected') {
|
||||
connectedServers.push(serverName);
|
||||
} else {
|
||||
@@ -376,7 +381,7 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
|
||||
const filteredValues = currentValues.filter((name) => name !== serverName);
|
||||
setMCPValues(filteredValues);
|
||||
} else {
|
||||
const serverStatus = connectionStatus?.[serverName];
|
||||
const serverStatus = connectionStatus[serverName];
|
||||
if (serverStatus?.connectionState === 'connected') {
|
||||
setMCPValues([...currentValues, serverName]);
|
||||
} else {
|
||||
@@ -450,7 +455,7 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
|
||||
const getServerStatusIconProps = useCallback(
|
||||
(serverName: string) => {
|
||||
const tool = mcpToolDetails?.find((t) => t.name === serverName);
|
||||
const serverStatus = connectionStatus?.[serverName];
|
||||
const serverStatus = connectionStatus[serverName];
|
||||
const serverConfig = startupConfig?.mcpServers?.[serverName];
|
||||
|
||||
const handleConfigClick = (e: React.MouseEvent) => {
|
||||
@@ -527,7 +532,7 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
|
||||
|
||||
return {
|
||||
serverName: selectedToolForConfig.name,
|
||||
serverStatus: connectionStatus?.[selectedToolForConfig.name],
|
||||
serverStatus: connectionStatus[selectedToolForConfig.name],
|
||||
isOpen: isConfigModalOpen,
|
||||
onOpenChange: handleDialogOpenChange,
|
||||
fieldsSchema,
|
||||
@@ -548,6 +553,7 @@ export function useMCPServerManager({ conversationId }: { conversationId?: strin
|
||||
|
||||
return {
|
||||
configuredServers,
|
||||
connectionStatus,
|
||||
initializeServer,
|
||||
cancelOAuthFlow,
|
||||
isInitializing,
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import { useMemo } from 'react';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import type { AgentToolType } from 'librechat-data-provider';
|
||||
import type { MCPServerInfo } from '~/common';
|
||||
|
||||
type GroupedToolType = AgentToolType & { tools?: AgentToolType[] };
|
||||
type GroupedToolsRecord = Record<string, GroupedToolType>;
|
||||
|
||||
interface VisibleToolsResult {
|
||||
toolIds: string[];
|
||||
mcpServerNames: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom hook to calculate visible tool IDs based on selected tools and their parent groups.
|
||||
* If any subtool of a group is selected, the parent group tool is also made visible.
|
||||
*
|
||||
* @param selectedToolIds - Array of selected tool IDs
|
||||
* @param allTools - Record of all available tools
|
||||
* @param mcpServersMap - Map of all MCP servers
|
||||
* @returns Object containing separate arrays of visible tool IDs for regular and MCP tools
|
||||
*/
|
||||
export function useVisibleTools(
|
||||
selectedToolIds: string[] | undefined,
|
||||
allTools: GroupedToolsRecord | undefined,
|
||||
mcpServersMap: Map<string, MCPServerInfo>,
|
||||
): VisibleToolsResult {
|
||||
return useMemo(() => {
|
||||
const mcpServers = new Set<string>();
|
||||
const selectedSet = new Set<string>();
|
||||
const regularToolIds = new Set<string>();
|
||||
|
||||
for (const toolId of selectedToolIds ?? []) {
|
||||
if (!toolId.includes(Constants.mcp_delimiter)) {
|
||||
selectedSet.add(toolId);
|
||||
continue;
|
||||
}
|
||||
const serverName = toolId.split(Constants.mcp_delimiter)[1];
|
||||
if (!serverName) {
|
||||
continue;
|
||||
}
|
||||
mcpServers.add(serverName);
|
||||
}
|
||||
|
||||
if (allTools) {
|
||||
for (const [toolId, toolObj] of Object.entries(allTools)) {
|
||||
if (selectedSet.has(toolId)) {
|
||||
regularToolIds.add(toolId);
|
||||
}
|
||||
|
||||
if (toolObj.tools?.length) {
|
||||
for (const subtool of toolObj.tools) {
|
||||
if (selectedSet.has(subtool.tool_id)) {
|
||||
regularToolIds.add(toolId);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mcpServersMap) {
|
||||
for (const [mcpServerName] of mcpServersMap) {
|
||||
if (mcpServers.has(mcpServerName)) {
|
||||
continue;
|
||||
}
|
||||
/** Legacy check */
|
||||
if (selectedSet.has(mcpServerName)) {
|
||||
mcpServers.add(mcpServerName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
toolIds: Array.from(regularToolIds).sort((a, b) => a.localeCompare(b)),
|
||||
mcpServerNames: Array.from(mcpServers).sort((a, b) => a.localeCompare(b)),
|
||||
};
|
||||
}, [allTools, mcpServersMap, selectedToolIds]);
|
||||
}
|
||||
@@ -98,10 +98,6 @@ export function useToolToggle({
|
||||
if (isAuthenticated !== undefined && !isAuthenticated && setIsDialogOpen) {
|
||||
setIsDialogOpen(true);
|
||||
e?.preventDefault?.();
|
||||
setEphemeralAgent((prev) => ({
|
||||
...(prev || {}),
|
||||
[toolKey]: false,
|
||||
}));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user