Compare commits


1 Commit

Author SHA1 Message Date
Danny Avila
6900fa73c6 fix: Add useBackToNewChat Hook for Back Navigation Handling
- Introduced a new hook `useBackToNewChat` to manage state when navigating back to the `/c/new` route.
- The hook listens for browser back/forward navigation events and resets the conversation state appropriately.
- Integrated the hook into `ChatRoute.tsx` to ensure proper functionality during navigation.
- Added documentation detailing the problem, solution, and implementation of the new hook.
2025-08-27 19:09:00 -04:00
307 changed files with 3844 additions and 19689 deletions
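
The hook itself (`useBackToNewChat`) is not among the hunks excerpted below. As a rough illustration of the behavior described in the commit message, here is a minimal TypeScript sketch, assuming a `popstate` listener and an injected conversation-reset callback; the names and wiring are hypothetical and not the commit's actual implementation:

```typescript
import { useEffect } from 'react';

/**
 * Hypothetical sketch of a back-navigation hook: when a browser back/forward
 * navigation lands the user on the /c/new route, reset the conversation
 * state so the view starts fresh. The `resetConversation` callback is an
 * assumed injection point, not necessarily the commit's actual API.
 */
export default function useBackToNewChat(resetConversation: () => void) {
  useEffect(() => {
    const handlePopState = () => {
      if (window.location.pathname === '/c/new') {
        resetConversation();
      }
    };

    window.addEventListener('popstate', handlePopState);
    return () => window.removeEventListener('popstate', handlePopState);
  }, [resetConversation]);
}

// Illustrative integration in ChatRoute.tsx (helper names assumed):
// const { newConversation } = useNewConvo();
// useBackToNewChat(() => newConversation());
```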

View File

@@ -690,8 +690,8 @@ HELP_AND_FAQ_URL=https://librechat.ai
# REDIS_PING_INTERVAL=300
# Force specific cache namespaces to use in-memory storage even when Redis is enabled
# Comma-separated list of CacheKeys (e.g., ROLES,MESSAGES)
# FORCED_IN_MEMORY_CACHE_NAMESPACES=ROLES,MESSAGES
# Comma-separated list of CacheKeys (e.g., STATIC_CONFIG,ROLES,MESSAGES)
# FORCED_IN_MEMORY_CACHE_NAMESPACES=STATIC_CONFIG,ROLES
#==================================================#
# Others #

View File

@@ -1,4 +1,4 @@
# v0.8.0-rc4
# v0.8.0-rc3
# Base node image
FROM node:20-alpine AS node
@@ -30,7 +30,7 @@ RUN \
# Allow mounting of these files, which have no default
touch .env ; \
# Create directories for the volumes to inherit the correct permissions
mkdir -p /app/client/public/images /app/api/logs /app/uploads ; \
mkdir -p /app/client/public/images /app/api/logs ; \
npm config set fetch-retry-maxtimeout 600000 ; \
npm config set fetch-retries 5 ; \
npm config set fetch-retry-mintimeout 15000 ; \
@@ -44,6 +44,8 @@ RUN \
npm prune --production; \
npm cache clean --force
RUN mkdir -p /app/client/public/images /app/api/logs
# Node API setup
EXPOSE 3080
ENV HOST=0.0.0.0

View File

@@ -1,5 +1,5 @@
# Dockerfile.multi
# v0.8.0-rc4
# v0.8.0-rc3
# Base for all builds
FROM node:20-alpine AS base-min

View File

@@ -75,7 +75,6 @@
- 🔍 **Web Search**:
- Search the internet and retrieve relevant information to enhance your AI context
- Combines search providers, content scrapers, and result rerankers for optimal results
- **Customizable Jina Reranking**: Configure custom Jina API URLs for reranking services
- **[Learn More →](https://www.librechat.ai/docs/features/web_search)**
- 🪄 **Generative UI with Code Artifacts**:

View File

@@ -10,17 +10,7 @@ const {
validateVisionModel,
} = require('librechat-data-provider');
const { SplitStreamHandler: _Handler } = require('@librechat/agents');
const {
Tokenizer,
createFetch,
matchModelName,
getClaudeHeaders,
getModelMaxTokens,
configureReasoning,
checkPromptCacheSupport,
getModelMaxOutputTokens,
createStreamEventHandlers,
} = require('@librechat/api');
const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api');
const {
truncateText,
formatMessage,
@@ -29,6 +19,12 @@ const {
parseParamFromPrompt,
createContextHandlers,
} = require('./prompts');
const {
getClaudeHeaders,
configureReasoning,
checkPromptCacheSupport,
} = require('~/server/services/Endpoints/anthropic/helpers');
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { sleep } = require('~/server/utils');

View File

@@ -233,7 +233,6 @@ class BaseClient {
sender: 'User',
text,
isCreatedByUser: true,
targetModel: this.modelOptions?.model ?? this.model,
};
}

View File

@@ -1,5 +1,4 @@
const { google } = require('googleapis');
const { getModelMaxTokens } = require('@librechat/api');
const { concat } = require('@langchain/core/utils/stream');
const { ChatVertexAI } = require('@langchain/google-vertexai');
const { Tokenizer, getSafetySettings } = require('@librechat/api');
@@ -22,6 +21,7 @@ const {
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { spendTokens } = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
const {

View File

@@ -7,9 +7,7 @@ const {
createFetch,
resolveHeaders,
constructAzureURL,
getModelMaxTokens,
genAzureChatCompletion,
getModelMaxOutputTokens,
createStreamEventHandlers,
} = require('@librechat/api');
const {
@@ -33,13 +31,13 @@ const {
titleInstruction,
createContextHandlers,
} = require('./prompts');
const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { addSpaceIfNeeded, sleep } = require('~/server/utils');
const { spendTokens } = require('~/models/spendTokens');
const { handleOpenAIErrors } = require('./tools/util');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { extractBaseURL } = require('~/utils');
const { tokenSplit } = require('./document');
const BaseClient = require('./BaseClient');
const { createLLM } = require('./llm');

View File

@@ -1,5 +1,5 @@
const { getModelMaxTokens } = require('@librechat/api');
const BaseClient = require('../BaseClient');
const { getModelMaxTokens } = require('../../../utils');
class FakeClient extends BaseClient {
constructor(apiKey, options = {}) {

View File

@@ -71,10 +71,9 @@ const primeFiles = async (options) => {
* @param {ServerRequest} options.req
* @param {Array<{ file_id: string; filename: string }>} options.files
* @param {string} [options.entity_id]
* @param {boolean} [options.fileCitations=false] - Whether to include citation instructions
* @returns
*/
const createFileSearchTool = async ({ req, files, entity_id, fileCitations = false }) => {
const createFileSearchTool = async ({ req, files, entity_id }) => {
return tool(
async ({ query }) => {
if (files.length === 0) {
@@ -143,9 +142,9 @@ const createFileSearchTool = async ({ req, files, entity_id, fileCitations = fal
const formattedString = formattedResults
.map(
(result, index) =>
`File: ${result.filename}${
fileCitations ? `\nAnchor: \\ue202turn0file${index} (${result.filename})` : ''
}\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${result.content}\n`,
`File: ${result.filename}\nAnchor: \\ue202turn0file${index} (${result.filename})\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${
result.content
}\n`,
)
.join('\n---\n');
@@ -159,14 +158,12 @@ const createFileSearchTool = async ({ req, files, entity_id, fileCitations = fal
pageRelevance: result.page ? { [result.page]: 1.0 - result.distance } : {},
}));
return [formattedString, { [Tools.file_search]: { sources, fileCitations } }];
return [formattedString, { [Tools.file_search]: { sources } }];
},
{
name: Tools.file_search,
responseFormat: 'content_and_artifact',
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.${
fileCitations
? `
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.
**CITE FILE SEARCH RESULTS:**
Use anchor markers immediately after statements derived from file content. Reference the filename in your text:
@@ -174,9 +171,7 @@ Use anchor markers immediately after statements derived from file content. Refer
- Page reference: "According to report.docx... \\ue202turn0file1"
- Multi-file: "Multiple sources confirm... \\ue200\\ue202turn0file0\\ue202turn0file1\\ue201"
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`
: ''
}`,
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`,
schema: z.object({
query: z
.string()

View File

@@ -1,16 +1,9 @@
const { logger } = require('@librechat/data-schemas');
const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator');
const { mcpToolPattern, loadWebSearchAuth, checkAccess } = require('@librechat/api');
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
const {
Tools,
Constants,
Permissions,
EToolResources,
PermissionTypes,
replaceSpecialVars,
} = require('librechat-data-provider');
const { Tools, Constants, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
const {
availableTools,
manifestToolMap,
@@ -34,7 +27,6 @@ const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { getCachedTools } = require('~/server/services/Config');
const { getRoleByName } = require('~/models/Role');
/**
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
@@ -289,29 +281,7 @@ const loadTools = async ({
if (toolContext) {
toolContextMap[tool] = toolContext;
}
/** @type {boolean | undefined} Check if user has FILE_CITATIONS permission */
let fileCitations;
if (fileCitations == null && options.req?.user != null) {
try {
fileCitations = await checkAccess({
user: options.req.user,
permissionType: PermissionTypes.FILE_CITATIONS,
permissions: [Permissions.USE],
getRoleByName,
});
} catch (error) {
logger.error('[handleTools] FILE_CITATIONS permission check failed:', error);
fileCitations = false;
}
}
return createFileSearchTool({
req: options.req,
files,
entity_id: agent?.id,
fileCitations,
});
return createFileSearchTool({ req: options.req, files, entity_id: agent?.id });
};
continue;
} else if (tool === Tools.web_search) {
@@ -342,16 +312,6 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
continue;
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
if (toolName === Constants.mcp_server) {
/** Placeholder used for UI purposes */
continue;
}
if (serverName && options.req?.config?.mcpConfig?.[serverName] == null) {
logger.warn(
`MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`,
);
continue;
}
if (toolName === Constants.mcp_all) {
const currentMCPGenerator = async (index) =>
createMCPTools({

View File

@@ -157,11 +157,12 @@ describe('cacheConfig', () => {
describe('FORCED_IN_MEMORY_CACHE_NAMESPACES validation', () => {
test('should parse comma-separated cache keys correctly', () => {
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, MESSAGES ';
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, STATIC_CONFIG ,MESSAGES ';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([
'ROLES',
'STATIC_CONFIG',
'MESSAGES',
]);
});

View File

@@ -31,8 +31,8 @@ const namespaces = {
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
[CacheKeys.APP_CONFIG]: standardCache(CacheKeys.APP_CONFIG),
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
[CacheKeys.STATIC_CONFIG]: standardCache(CacheKeys.STATIC_CONFIG),
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),

View File

@@ -112,17 +112,8 @@ module.exports = {
update.expiredAt = null;
}
/** @type {{ $set: Partial<TConversation>; $addToSet?: Record<string, any>; $unset?: Record<keyof TConversation, number> }} */
/** @type {{ $set: Partial<TConversation>; $unset?: Record<keyof TConversation, number> }} */
const updateOperation = { $set: update };
if (convo.model && convo.endpoint) {
updateOperation.$addToSet = {
modelHistory: {
model: convo.model,
endpoint: convo.endpoint,
},
};
}
if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) {
updateOperation.$unset = metadata.unsetFields;
}

View File

@@ -211,7 +211,7 @@ describe('File Access Control', () => {
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should deny access when user only has VIEW permission and needs access for deletion', async () => {
it('should deny access when user only has VIEW permission', async () => {
const userId = new mongoose.Types.ObjectId();
const authorId = new mongoose.Types.ObjectId();
const agentId = uuidv4();
@@ -263,71 +263,12 @@ describe('File Access Control', () => {
role: SystemRoles.USER,
fileIds,
agentId,
isDelete: true,
});
// Should have no access to any files when only VIEW permission
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should grant access when user has VIEW permission', async () => {
const userId = new mongoose.Types.ObjectId();
const authorId = new mongoose.Types.ObjectId();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4()];
// Create users
await User.create({
_id: userId,
email: 'user@example.com',
emailVerified: true,
provider: 'local',
});
await User.create({
_id: authorId,
email: 'author@example.com',
emailVerified: true,
provider: 'local',
});
// Create agent with files
const agent = await createAgent({
id: agentId,
name: 'View-Only Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: fileIds,
},
},
});
// Grant only VIEW permission to user on the agent
await grantPermission({
principalType: PrincipalType.USER,
principalId: userId,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
accessRoleId: AccessRoleIds.AGENT_VIEWER,
grantedBy: authorId,
});
// Check access for files
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
const accessMap = await hasAccessToFilesViaAgent({
userId: userId,
role: SystemRoles.USER,
fileIds,
agentId,
});
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
});
});
describe('getFiles with agent access control', () => {

View File

@@ -269,7 +269,7 @@ async function getListPromptGroupsByAccess({
const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };
// Add cursor condition
if (after && typeof after === 'string' && after !== 'undefined' && after !== 'null') {
if (after) {
try {
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
const { updatedAt, _id } = cursor;

View File

@@ -189,15 +189,11 @@ async function createAutoRefillTransaction(txData) {
* @param {txData} _txData - Transaction data.
*/
async function createTransaction(_txData) {
const { balance, transactions, ...txData } = _txData;
const { balance, ...txData } = _txData;
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return;
}
if (transactions?.enabled === false) {
return;
}
const transaction = new Transaction(txData);
transaction.endpointTokenConfig = txData.endpointTokenConfig;
calculateTokenValue(transaction);
@@ -226,11 +222,7 @@ async function createTransaction(_txData) {
* @param {txData} _txData - Transaction data.
*/
async function createStructuredTransaction(_txData) {
const { balance, transactions, ...txData } = _txData;
if (transactions?.enabled === false) {
return;
}
const { balance, ...txData } = _txData;
const transaction = new Transaction({
...txData,
endpointTokenConfig: txData.endpointTokenConfig,

View File

@@ -1,9 +1,10 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { createTransaction, createStructuredTransaction } = require('./Transaction');
const { Balance, Transaction } = require('~/db/models');
const { createTransaction } = require('./Transaction');
const { Balance } = require('~/db/models');
let mongoServer;
beforeAll(async () => {
@@ -379,188 +380,3 @@ describe('NaN Handling Tests', () => {
expect(balance.tokenCredits).toBe(initialBalance);
});
});
describe('Transactions Config Tests', () => {
test('createTransaction should not save when transactions.enabled is false', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
rawAmount: -100,
tokenType: 'prompt',
transactions: { enabled: false },
};
// Act
const result = await createTransaction(txData);
// Assert: No transaction should be created
expect(result).toBeUndefined();
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(0);
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});
test('createTransaction should save when transactions.enabled is true', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
rawAmount: -100,
tokenType: 'prompt',
transactions: { enabled: true },
balance: { enabled: true },
};
// Act
const result = await createTransaction(txData);
// Assert: Transaction should be created
expect(result).toBeDefined();
expect(result.balance).toBeLessThan(initialBalance);
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(1);
expect(transactions[0].rawAmount).toBe(-100);
});
test('createTransaction should save when balance.enabled is true even if transactions config is missing', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
rawAmount: -100,
tokenType: 'prompt',
balance: { enabled: true },
// No transactions config provided
};
// Act
const result = await createTransaction(txData);
// Assert: Transaction should be created (backward compatibility)
expect(result).toBeDefined();
expect(result.balance).toBeLessThan(initialBalance);
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(1);
});
test('createTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
rawAmount: -100,
tokenType: 'prompt',
transactions: { enabled: true },
balance: { enabled: false },
};
// Act
const result = await createTransaction(txData);
// Assert: Transaction should be created but balance unchanged
expect(result).toBeUndefined();
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(1);
expect(transactions[0].rawAmount).toBe(-100);
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});
test('createStructuredTransaction should not save when transactions.enabled is false', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'claude-3-5-sonnet';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'message',
tokenType: 'prompt',
inputTokens: -10,
writeTokens: -100,
readTokens: -5,
transactions: { enabled: false },
};
// Act
const result = await createStructuredTransaction(txData);
// Assert: No transaction should be created
expect(result).toBeUndefined();
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(0);
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});
test('createStructuredTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'claude-3-5-sonnet';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'message',
tokenType: 'prompt',
inputTokens: -10,
writeTokens: -100,
readTokens: -5,
transactions: { enabled: true },
balance: { enabled: false },
};
// Act
const result = await createStructuredTransaction(txData);
// Assert: Transaction should be created but balance unchanged
expect(result).toBeUndefined();
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(1);
expect(transactions[0].inputTokens).toBe(-10);
expect(transactions[0].writeTokens).toBe(-100);
expect(transactions[0].readTokens).toBe(-5);
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});
});

View File

@@ -1,4 +1,4 @@
const { matchModelName } = require('@librechat/api');
const { matchModelName } = require('../utils/tokens');
const defaultRate = 6;
/**

View File

@@ -1,6 +1,6 @@
{
"name": "@librechat/backend",
"version": "v0.8.0-rc4",
"version": "v0.8.0-rc3",
"description": "",
"scripts": {
"start": "echo 'please run this from the root directory'",
@@ -49,14 +49,14 @@
"@langchain/google-vertexai": "^0.2.13",
"@langchain/openai": "^0.5.18",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.79",
"@librechat/agents": "^2.4.76",
"@librechat/api": "*",
"@librechat/data-schemas": "*",
"@microsoft/microsoft-graph-client": "^3.0.7",
"@modelcontextprotocol/sdk": "^1.17.1",
"@node-saml/passport-saml": "^5.1.0",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.12.1",
"axios": "^1.8.2",
"bcryptjs": "^2.4.3",
"compression": "^1.8.1",
"connect-redis": "^8.1.0",

View File

@@ -75,7 +75,7 @@ const refreshController = async (req, res) => {
if (!user) {
return res.status(401).redirect('/login');
}
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString());
const token = setOpenIDAuthTokens(tokenset, res);
return res.status(200).send({ token, user });
} catch (error) {
logger.error('[refreshController] OpenID token refresh error', error);

View File

@@ -74,23 +74,14 @@ const getAvailableTools = async (req, res) => {
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
const cachedUserTools = await getCachedTools({ userId });
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
const mcpManager = getMCPManager();
const userPlugins =
cachedUserTools != null
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
: undefined;
/** @type {TPlugin[]} */
let mcpPlugins;
if (appConfig?.mcpConfig) {
const mcpManager = getMCPManager();
mcpPlugins =
cachedUserTools != null
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
: undefined;
}
if (
cachedToolsArray != null &&
(appConfig?.mcpConfig != null ? mcpPlugins != null && mcpPlugins.length > 0 : true)
) {
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...cachedToolsArray]);
if (cachedToolsArray != null && userPlugins != null) {
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
res.status(200).json(dedupedTools);
return;
}
@@ -102,9 +93,9 @@ const getAvailableTools = async (req, res) => {
/** @type {import('@librechat/api').LCManifestTool[]} */
let pluginManifest = availableTools;
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
if (appConfig?.mcpConfig != null) {
try {
const mcpManager = getMCPManager();
const mcpTools = await mcpManager.getAllToolFunctions(userId);
prelimCachedTools = prelimCachedTools ?? {};
for (const [toolKey, toolData] of Object.entries(mcpTools)) {
@@ -184,7 +175,7 @@ const getAvailableTools = async (req, res) => {
const finalTools = filterUniquePlugins(toolsOutput);
await cache.set(CacheKeys.TOOLS, finalTools);
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...finalTools]);
const dedupedTools = filterUniquePlugins([...(userPlugins ?? []), ...finalTools]);
res.status(200).json(dedupedTools);
} catch (error) {
logger.error('[getAvailableTools]', error);

View File

@@ -174,19 +174,10 @@ describe('PluginController', () => {
mockCache.get.mockResolvedValue(null);
getCachedTools.mockResolvedValueOnce(mockUserTools);
mockReq.config = {
mcpConfig: {
server1: {},
},
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Mock MCP manager to return empty tools initially (since getAllToolFunctions is called)
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
// Mock second call to return tool definitions (includeGlobal: true)
getCachedTools.mockResolvedValueOnce(mockUserTools);
@@ -514,7 +505,7 @@ describe('PluginController', () => {
expect(mockRes.json).toHaveBeenCalledWith([]);
});
it('should handle `cachedToolsArray` and `mcpPlugins` both being defined', async () => {
it('should handle cachedToolsArray and userPlugins both being defined', async () => {
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
// Use MCP delimiter for the user tool so convertMCPToolsToPlugins works
const userTools = {
@@ -531,19 +522,10 @@ describe('PluginController', () => {
mockCache.get.mockResolvedValue(cachedTools);
getCachedTools.mockResolvedValueOnce(userTools);
mockReq.config = {
mcpConfig: {
server1: {},
},
mcpConfig: null,
paths: { structuredTools: '/mock/path' },
};
// Mock MCP manager to return empty tools initially
const mockMCPManager = {
getAllToolFunctions: jest.fn().mockResolvedValue({}),
getRawConfig: jest.fn().mockReturnValue({}),
};
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
// The controller expects a second call to getCachedTools
getCachedTools.mockResolvedValueOnce({
'cached-tool': { type: 'function', function: { name: 'cached-tool' } },

View File

@@ -1,10 +1,5 @@
const { logger } = require('@librechat/data-schemas');
const {
webSearchKeys,
extractWebSearchEnvVars,
normalizeHttpError,
MCPTokenStorage,
} = require('@librechat/api');
const { webSearchKeys, extractWebSearchEnvVars, normalizeHttpError } = require('@librechat/api');
const {
getFiles,
updateUser,
@@ -21,17 +16,11 @@ const { verifyEmail, resendVerificationEmail } = require('~/server/services/Auth
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { Tools, Constants, FileSources } = require('librechat-data-provider');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { Transaction, Balance, User, Token } = require('~/db/models');
const { Transaction, Balance, User } = require('~/db/models');
const { getAppConfig } = require('~/server/services/Config');
const { deleteToolCalls } = require('~/models/ToolCall');
const { deleteAllSharedLinks } = require('~/models');
const { getMCPManager } = require('~/config');
const { MCPOAuthHandler } = require('@librechat/api');
const { getFlowStateManager } = require('~/config');
const { CacheKeys } = require('librechat-data-provider');
const { getLogStores } = require('~/cache');
const { clearMCPServerTools } = require('~/server/services/Config/mcpToolsCache');
const { findToken } = require('~/models');
const getUserController = async (req, res) => {
const appConfig = await getAppConfig({ role: req.user?.role });
@@ -173,15 +162,6 @@ const updateUserPluginsController = async (req, res) => {
);
({ status, message } = normalizeHttpError(authService));
}
try {
// if the MCP server uses OAuth, perform a full cleanup and token revocation
await maybeUninstallOAuthMCP(user.id, pluginKey, appConfig);
} catch (error) {
logger.error(
`[updateUserPluginsController] Error uninstalling OAuth MCP for ${pluginKey}:`,
error,
);
}
} else {
// This handles:
// 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}).
@@ -207,7 +187,7 @@ const updateUserPluginsController = async (req, res) => {
// Extract server name from pluginKey (format: "mcp_<serverName>")
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
logger.info(
`[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`,
`[updateUserPluginsController] Disconnecting MCP server ${serverName} for user ${user.id} after plugin auth update for ${pluginKey}.`,
);
await mcpManager.disconnectUserConnection(user.id, serverName);
}
@@ -289,97 +269,6 @@ const resendVerificationController = async (req, res) => {
}
};
/**
* OAuth MCP specific uninstall logic
*/
const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
if (!pluginKey.startsWith(Constants.mcp_prefix)) {
// this is not an MCP server, so nothing to do here
return;
}
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
const mcpManager = getMCPManager(userId);
const serverConfig = mcpManager.getRawConfig(serverName) ?? appConfig?.mcpServers?.[serverName];
if (!mcpManager.getOAuthServers().has(serverName)) {
// this server does not use OAuth, so nothing to do here as well
return;
}
// 1. get client info used for revocation (client id, secret)
const clientTokenData = await MCPTokenStorage.getClientInfoAndMetadata({
userId,
serverName,
findToken,
});
if (clientTokenData == null) {
return;
}
const { clientInfo, clientMetadata } = clientTokenData;
// 2. get decrypted tokens before deletion
const tokens = await MCPTokenStorage.getTokens({
userId,
serverName,
findToken,
});
// 3. revoke OAuth tokens at the provider
const revocationEndpoint =
serverConfig.oauth?.revocation_endpoint ?? clientMetadata.revocation_endpoint;
const revocationEndpointAuthMethodsSupported =
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
clientMetadata.revocation_endpoint_auth_methods_supported;
if (tokens?.access_token) {
try {
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.access_token, 'access', {
serverUrl: serverConfig.url,
clientId: clientInfo.client_id,
clientSecret: clientInfo.client_secret ?? '',
revocationEndpoint,
revocationEndpointAuthMethodsSupported,
});
} catch (error) {
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
}
}
if (tokens?.refresh_token) {
try {
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.refresh_token, 'refresh', {
serverUrl: serverConfig.url,
clientId: clientInfo.client_id,
clientSecret: clientInfo.client_secret ?? '',
revocationEndpoint,
revocationEndpointAuthMethodsSupported,
});
} catch (error) {
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
}
}
// 4. delete tokens from the DB after revocation attempts
await MCPTokenStorage.deleteUserTokens({
userId,
serverName,
deleteToken: async (filter) => {
await Token.deleteOne(filter);
},
});
// 5. clear the flow state for the OAuth tokens
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
await flowManager.deleteFlow(flowId, 'mcp_oauth');
// 6. clear the tools cache for the server
await clearMCPServerTools({ userId, serverName });
};
module.exports = {
getUserController,
getTermsStatusController,

View File

@@ -1,342 +0,0 @@
const { Tools } = require('librechat-data-provider');
// Mock all dependencies before requiring the module
jest.mock('nanoid', () => ({
nanoid: jest.fn(() => 'mock-id'),
}));
jest.mock('@librechat/api', () => ({
sendEvent: jest.fn(),
}));
jest.mock('@librechat/data-schemas', () => ({
logger: {
error: jest.fn(),
},
}));
jest.mock('@librechat/agents', () => ({
EnvVar: { CODE_API_KEY: 'CODE_API_KEY' },
Providers: { GOOGLE: 'google' },
GraphEvents: {},
getMessageId: jest.fn(),
ToolEndHandler: jest.fn(),
handleToolCalls: jest.fn(),
ChatModelStreamHandler: jest.fn(),
}));
jest.mock('~/server/services/Files/Citations', () => ({
processFileCitations: jest.fn(),
}));
jest.mock('~/server/services/Files/Code/process', () => ({
processCodeOutput: jest.fn(),
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
saveBase64Image: jest.fn(),
}));
describe('createToolEndCallback', () => {
let req, res, artifactPromises, createToolEndCallback;
let logger;
beforeEach(() => {
jest.clearAllMocks();
// Get the mocked logger
logger = require('@librechat/data-schemas').logger;
// Now require the module after all mocks are set up
const callbacks = require('../callbacks');
createToolEndCallback = callbacks.createToolEndCallback;
req = {
user: { id: 'user123' },
};
res = {
headersSent: false,
write: jest.fn(),
};
artifactPromises = [];
});
describe('ui_resources artifact handling', () => {
it('should process ui_resources artifact and return attachment when headers not sent', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {
0: { type: 'button', label: 'Click me' },
1: { type: 'input', placeholder: 'Enter text' },
},
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
// Wait for all promises to resolve
const results = await Promise.all(artifactPromises);
// When headers are not sent, it returns attachment without writing
expect(res.write).not.toHaveBeenCalled();
const attachment = results[0];
expect(attachment).toEqual({
type: Tools.ui_resources,
messageId: 'run456',
toolCallId: 'tool123',
conversationId: 'thread789',
[Tools.ui_resources]: {
0: { type: 'button', label: 'Click me' },
1: { type: 'input', placeholder: 'Enter text' },
},
});
});
it('should write to response when headers are already sent', async () => {
res.headersSent = true;
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {
0: { type: 'carousel', items: [] },
},
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
expect(res.write).toHaveBeenCalled();
expect(results[0]).toEqual({
type: Tools.ui_resources,
messageId: 'run456',
toolCallId: 'tool123',
conversationId: 'thread789',
[Tools.ui_resources]: {
0: { type: 'carousel', items: [] },
},
});
});
it('should handle errors when processing ui_resources', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
// Mock res.write to throw an error
res.headersSent = true;
res.write.mockImplementation(() => {
throw new Error('Write failed');
});
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {
0: { type: 'test' },
},
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
expect(logger.error).toHaveBeenCalledWith(
'Error processing artifact content:',
expect.any(Error),
);
expect(results[0]).toBeNull();
});
it('should handle multiple artifacts including ui_resources', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {
0: { type: 'chart', data: [] },
},
},
[Tools.web_search]: {
results: ['result1', 'result2'],
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
// Both ui_resources and web_search should be processed
expect(artifactPromises).toHaveLength(2);
expect(results).toHaveLength(2);
// Check ui_resources attachment
const uiResourceAttachment = results.find((r) => r?.type === Tools.ui_resources);
expect(uiResourceAttachment).toBeTruthy();
expect(uiResourceAttachment[Tools.ui_resources]).toEqual({
0: { type: 'chart', data: [] },
});
// Check web_search attachment
const webSearchAttachment = results.find((r) => r?.type === Tools.web_search);
expect(webSearchAttachment).toBeTruthy();
expect(webSearchAttachment[Tools.web_search]).toEqual({
results: ['result1', 'result2'],
});
});
it('should not process artifacts when output has no artifacts', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
content: 'Some regular content',
// No artifact property
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
expect(artifactPromises).toHaveLength(0);
expect(res.write).not.toHaveBeenCalled();
});
});
describe('edge cases', () => {
it('should handle empty ui_resources data object', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: {},
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
expect(results[0]).toEqual({
type: Tools.ui_resources,
messageId: 'run456',
toolCallId: 'tool123',
conversationId: 'thread789',
[Tools.ui_resources]: {},
});
});
it('should handle ui_resources with complex nested data', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const complexData = {
0: {
type: 'form',
fields: [
{ name: 'field1', type: 'text', required: true },
{ name: 'field2', type: 'select', options: ['a', 'b', 'c'] },
],
nested: {
deep: {
value: 123,
array: [1, 2, 3],
},
},
},
};
const output = {
tool_call_id: 'tool123',
artifact: {
[Tools.ui_resources]: {
data: complexData,
},
},
};
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output }, metadata);
const results = await Promise.all(artifactPromises);
expect(results[0][Tools.ui_resources]).toEqual(complexData);
});
it('should handle when output is undefined', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback({ output: undefined }, metadata);
expect(artifactPromises).toHaveLength(0);
expect(res.write).not.toHaveBeenCalled();
});
it('should handle when data parameter is undefined', async () => {
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
const metadata = {
run_id: 'run456',
thread_id: 'thread789',
};
await toolEndCallback(undefined, metadata);
expect(artifactPromises).toHaveLength(0);
expect(res.write).not.toHaveBeenCalled();
});
});
});

View File

@@ -265,30 +265,6 @@ function createToolEndCallback({ req, res, artifactPromises }) {
);
}
// TODO: a lot of duplicated code in createToolEndCallback
// we should refactor this to use a helper function in a follow-up PR
if (output.artifact[Tools.ui_resources]) {
artifactPromises.push(
(async () => {
const attachment = {
type: Tools.ui_resources,
messageId: metadata.run_id,
toolCallId: output.tool_call_id,
conversationId: metadata.thread_id,
[Tools.ui_resources]: output.artifact[Tools.ui_resources].data,
};
if (!res.headersSent) {
return attachment;
}
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
return attachment;
})().catch((error) => {
logger.error('Error processing artifact content:', error);
return null;
}),
);
}
if (output.artifact[Tools.web_search]) {
artifactPromises.push(
(async () => {

View File

@@ -7,12 +7,9 @@ const {
createRun,
Tokenizer,
checkAccess,
logAxiosError,
resolveHeaders,
getBalanceConfig,
memoryInstructions,
formatContentStrings,
getTransactionsConfig,
createMemoryProcessor,
} = require('@librechat/api');
const {
@@ -89,10 +86,11 @@ function createTokenCounter(encoding) {
}
function logToolError(graph, error, toolId) {
logAxiosError({
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
error,
message: `[api/server/controllers/agents/client.js #chatCompletion] Tool Error "${toolId}"`,
});
toolId,
);
}
class AgentClient extends BaseClient {
@@ -624,13 +622,11 @@ class AgentClient extends BaseClient {
* @param {string} [params.model]
* @param {string} [params.context='message']
* @param {AppConfig['balance']} [params.balance]
* @param {AppConfig['transactions']} [params.transactions]
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
*/
async recordCollectedUsage({
model,
balance,
transactions,
context = 'message',
collectedUsage = this.collectedUsage,
}) {
@@ -656,7 +652,6 @@ class AgentClient extends BaseClient {
const txMetadata = {
context,
balance,
transactions,
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
@@ -872,10 +867,11 @@ class AgentClient extends BaseClient {
if (agent.useLegacyContent === true) {
messages = formatContentStrings(messages);
}
const defaultHeaders =
agent.model_parameters?.clientOptions?.defaultHeaders ??
agent.model_parameters?.configuration?.defaultHeaders;
if (defaultHeaders?.['anthropic-beta']?.includes('prompt-caching')) {
if (
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
'prompt-caching',
)
) {
messages = addCacheControl(messages);
}
@@ -883,16 +879,6 @@ class AgentClient extends BaseClient {
memoryPromise = this.runMemory(messages);
}
/** Resolve request-based headers for Custom Endpoints. Note: if this is added to
* non-custom endpoints, needs consideration of varying provider header configs.
*/
if (agent.model_parameters?.configuration?.defaultHeaders != null) {
agent.model_parameters.configuration.defaultHeaders = resolveHeaders({
headers: agent.model_parameters.configuration.defaultHeaders,
body: config.configurable.requestBody,
});
}
run = await createRun({
agent,
req: this.options.req,
@@ -1054,12 +1040,7 @@ class AgentClient extends BaseClient {
}
const balanceConfig = getBalanceConfig(appConfig);
const transactionsConfig = getTransactionsConfig(appConfig);
await this.recordCollectedUsage({
context: 'message',
balance: balanceConfig,
transactions: transactionsConfig,
});
await this.recordCollectedUsage({ context: 'message', balance: balanceConfig });
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
@@ -1200,20 +1181,6 @@ class AgentClient extends BaseClient {
clientOptions.json = true;
}
/** Resolve request-based headers for Custom Endpoints. Note: if this is added to
* non-custom endpoints, needs consideration of varying provider header configs.
*/
if (clientOptions?.configuration?.defaultHeaders != null) {
clientOptions.configuration.defaultHeaders = resolveHeaders({
headers: clientOptions.configuration.defaultHeaders,
body: {
messageId: this.responseMessageId,
conversationId: this.conversationId,
parentMessageId: this.parentMessageId,
},
});
}
try {
const titleResult = await this.run.generateTitle({
provider,
@@ -1253,13 +1220,11 @@ class AgentClient extends BaseClient {
});
const balanceConfig = getBalanceConfig(appConfig);
const transactionsConfig = getTransactionsConfig(appConfig);
await this.recordCollectedUsage({
collectedUsage,
context: 'title',
model: clientOptions.model,
balance: balanceConfig,
transactions: transactionsConfig,
}).catch((err) => {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',

View File

@@ -237,9 +237,6 @@ describe('AgentClient - titleConvo', () => {
balance: {
enabled: false,
},
transactions: {
enabled: true,
},
});
});

View File

@@ -5,7 +5,6 @@ const { logger } = require('@librechat/data-schemas');
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
const {
Tools,
Constants,
SystemRoles,
FileSources,
ResourceType,
@@ -70,9 +69,9 @@ const createAgentHandler = async (req, res) => {
for (const tool of tools) {
if (availableTools[tool]) {
agentData.tools.push(tool);
} else if (systemTools[tool]) {
agentData.tools.push(tool);
} else if (tool.includes(Constants.mcp_delimiter)) {
}
if (systemTools[tool]) {
agentData.tools.push(tool);
}
}

View File

@@ -1,7 +1,7 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
const { sendEvent, getBalanceConfig } = require('@librechat/api');
const {
Time,
Constants,
@@ -34,6 +34,7 @@ const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
/**

View File

@@ -1,7 +1,7 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
const { sendEvent, getBalanceConfig } = require('@librechat/api');
const {
Time,
Constants,
@@ -31,6 +31,7 @@ const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
/**

View File

@@ -12,7 +12,7 @@ const { logger } = require('@librechat/data-schemas');
const mongoSanitize = require('express-mongo-sanitize');
const { isEnabled, ErrorController } = require('@librechat/api');
const { connectDb, indexSync } = require('~/db');
const createValidateImageRequest = require('./middleware/validateImageRequest');
const validateImageRequest = require('./middleware/validateImageRequest');
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
const { updateInterfacePermissions } = require('~/models/interface');
const { checkMigrations } = require('./services/start/migration');
@@ -126,7 +126,7 @@ const startServer = async () => {
app.use('/api/config', routes.config);
app.use('/api/assistants', routes.assistants);
app.use('/api/files', await routes.files.initialize());
app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute);
app.use('/images/', validateImageRequest, routes.staticRoute);
app.use('/api/share', routes.share);
app.use('/api/roles', routes.roles);
app.use('/api/agents', routes.agents);

View File

@@ -11,25 +11,18 @@ const { getAppConfig } = require('~/server/services/Config');
* @param {Object} res - Express response object.
* @param {Function} next - Next middleware function.
*
* @returns {Promise<void>} - Calls next middleware if the domain's email is allowed, otherwise redirects to login
* @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if the domain's email is allowed
*/
const checkDomainAllowed = async (req, res, next) => {
try {
const email = req?.user?.email;
const appConfig = await getAppConfig({
role: req?.user?.role,
});
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
res.redirect('/login');
return;
}
next();
} catch (error) {
logger.error('[checkDomainAllowed] Error checking domain:', error);
res.redirect('/login');
const checkDomainAllowed = async (req, res, next = () => {}) => {
const email = req?.user?.email;
const appConfig = await getAppConfig({
role: req?.user?.role,
});
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
return res.redirect('/login');
} else {
return next();
}
};

View File

@@ -1,5 +1,6 @@
const validatePasswordReset = require('./validatePasswordReset');
const validateRegistration = require('./validateRegistration');
const validateImageRequest = require('./validateImageRequest');
const buildEndpointOption = require('./buildEndpointOption');
const validateMessageReq = require('./validateMessageReq');
const checkDomainAllowed = require('./checkDomainAllowed');
@@ -49,5 +50,6 @@ module.exports = {
validateMessageReq,
buildEndpointOption,
validateRegistration,
validateImageRequest,
validatePasswordReset,
};

View File

@@ -1,14 +1,14 @@
const jwt = require('jsonwebtoken');
const { isEnabled } = require('@librechat/api');
const createValidateImageRequest = require('~/server/middleware/validateImageRequest');
const validateImageRequest = require('~/server/middleware/validateImageRequest');
jest.mock('@librechat/api', () => ({
isEnabled: jest.fn(),
jest.mock('~/server/services/Config/app', () => ({
getAppConfig: jest.fn(),
}));
describe('validateImageRequest middleware', () => {
let req, res, next, validateImageRequest;
let req, res, next;
const validObjectId = '65cfb246f7ecadb8b1e8036b';
const { getAppConfig } = require('~/server/services/Config/app');
beforeEach(() => {
jest.clearAllMocks();
@@ -22,278 +22,116 @@ describe('validateImageRequest middleware', () => {
};
next = jest.fn();
process.env.JWT_REFRESH_SECRET = 'test-secret';
process.env.OPENID_REUSE_TOKENS = 'false';
// Default: OpenID token reuse disabled
isEnabled.mockReturnValue(false);
// Mock getAppConfig to return secureImageLinks: true by default
getAppConfig.mockResolvedValue({
secureImageLinks: true,
});
});
afterEach(() => {
jest.clearAllMocks();
});
describe('Factory function', () => {
test('should return a pass-through middleware if secureImageLinks is false', async () => {
const middleware = createValidateImageRequest(false);
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
expect(res.status).not.toHaveBeenCalled();
});
test('should return validation middleware if secureImageLinks is true', async () => {
validateImageRequest = createValidateImageRequest(true);
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(401);
expect(res.send).toHaveBeenCalledWith('Unauthorized');
test('should call next() if secureImageLinks is false', async () => {
getAppConfig.mockResolvedValue({
secureImageLinks: false,
});
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
describe('Standard LibreChat token flow', () => {
beforeEach(() => {
validateImageRequest = createValidateImageRequest(true);
});
test('should return 401 if refresh token is not provided', async () => {
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(401);
expect(res.send).toHaveBeenCalledWith('Unauthorized');
});
test('should return 403 if refresh token is invalid', async () => {
req.headers.cookie = 'refreshToken=invalid-token';
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should return 403 if refresh token is expired', async () => {
const expiredToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${expiredToken}`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should call next() for valid image path', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/example.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 for invalid image path', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should allow agent avatar pattern for any valid ObjectId', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should prevent file traversal attempts', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
const traversalAttempts = [
`/images/${validObjectId}/../../../etc/passwd`,
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
];
for (const attempt of traversalAttempts) {
req.originalUrl = attempt;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
jest.clearAllMocks();
// Reset mocks for next iteration
res.status = jest.fn().mockReturnThis();
res.send = jest.fn();
}
});
test('should handle URL encoded characters in valid paths', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 401 if refresh token is not provided', async () => {
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(401);
expect(res.send).toHaveBeenCalledWith('Unauthorized');
});
describe('OpenID token flow', () => {
beforeEach(() => {
validateImageRequest = createValidateImageRequest(true);
// Enable OpenID token reuse
isEnabled.mockReturnValue(true);
process.env.OPENID_REUSE_TOKENS = 'true';
});
test('should return 403 if no OpenID user ID cookie when token_provider is openid', async () => {
req.headers.cookie = 'refreshToken=dummy-token; token_provider=openid';
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should validate JWT-signed user ID for OpenID flow', async () => {
const signedUserId = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
req.originalUrl = `/images/${validObjectId}/example.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 for invalid JWT-signed user ID', async () => {
req.headers.cookie =
'refreshToken=dummy-token; token_provider=openid; openid_user_id=invalid-jwt';
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should return 403 for expired JWT-signed user ID', async () => {
const expiredSignedUserId = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${expiredSignedUserId}`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should validate image path against JWT-signed user ID', async () => {
const signedUserId = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
const differentObjectId = '65cfb246f7ecadb8b1e8036c';
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
req.originalUrl = `/images/${differentObjectId}/example.jpg`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should allow agent avatars in OpenID flow', async () => {
const signedUserId = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 if refresh token is invalid', async () => {
req.headers.cookie = 'refreshToken=invalid-token';
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
describe('Security edge cases', () => {
let validToken;
test('should return 403 if refresh token is expired', async () => {
const expiredToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${expiredToken}`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
beforeEach(() => {
validateImageRequest = createValidateImageRequest(true);
validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
});
test('should call next() for valid image path', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/example.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should handle very long image filenames', async () => {
const longFilename = 'a'.repeat(1000) + '.jpg';
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 for invalid image path', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should handle URLs with maximum practical length', async () => {
// Most browsers support URLs up to ~2000 characters
const longFilename = 'x'.repeat(1900) + '.jpg';
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 for invalid ObjectId format', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/123/example.jpg'; // Invalid ObjectId
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should accept URLs just under the 2048 limit', async () => {
// Create a URL exactly 2047 characters long
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
const filenameLength = 2047 - baseLength;
const filename = 'a'.repeat(filenameLength) + '.jpg';
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/${filename}`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
// File traversal tests
test('should prevent file traversal attempts', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
const traversalAttempts = [
`/images/${validObjectId}/../../../etc/passwd`,
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
];
for (const attempt of traversalAttempts) {
req.originalUrl = attempt;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
jest.clearAllMocks();
}
});
test('should handle malformed URL encoding gracefully', async () => {
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/test%ZZinvalid.jpg`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should reject URLs with null bytes', async () => {
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/test\x00.jpg`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should handle URLs with repeated slashes', async () => {
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}//test.jpg`;
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should reject extremely long URLs as potential DoS', async () => {
// Create a URL longer than 2048 characters
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
const filenameLength = 2049 - baseLength; // Ensure total length exceeds 2048
const extremelyLongFilename = 'x'.repeat(filenameLength) + '.jpg';
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/${extremelyLongFilename}`;
// Verify our test URL is actually too long
expect(req.originalUrl.length).toBeGreaterThan(2048);
await validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should handle URL encoded characters in valid paths', async () => {
const validToken = jwt.sign(
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
await validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
});

View File

@@ -1,7 +1,7 @@
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { getAppConfig } = require('~/server/services/Config/app');
const OBJECT_ID_LENGTH = 24;
const OBJECT_ID_PATTERN = /^[0-9a-f]{24}$/i;
@@ -22,129 +22,50 @@ function isValidObjectId(id) {
}
/**
* Validates a LibreChat refresh token
* @param {string} refreshToken - The refresh token to validate
* @returns {{valid: boolean, userId?: string, error?: string}} - Validation result
* Middleware to validate image request.
* Must be set by `secureImageLinks` via custom config file.
*/
function validateToken(refreshToken) {
async function validateImageRequest(req, res, next) {
const appConfig = await getAppConfig({ role: req.user?.role });
if (!appConfig.secureImageLinks) {
return next();
}
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
if (!refreshToken) {
logger.warn('[validateImageRequest] Refresh token not provided');
return res.status(401).send('Unauthorized');
}
let payload;
try {
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
if (!isValidObjectId(payload.id)) {
return { valid: false, error: 'Invalid User ID' };
}
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
if (payload.exp < currentTimeInSeconds) {
return { valid: false, error: 'Refresh token expired' };
}
return { valid: true, userId: payload.id };
payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
} catch (err) {
logger.warn('[validateToken]', err);
return { valid: false, error: 'Invalid token' };
logger.warn('[validateImageRequest]', err);
return res.status(403).send('Access Denied');
}
if (!isValidObjectId(payload.id)) {
logger.warn('[validateImageRequest] Invalid User ID');
return res.status(403).send('Access Denied');
}
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
if (payload.exp < currentTimeInSeconds) {
logger.warn('[validateImageRequest] Refresh token expired');
return res.status(403).send('Access Denied');
}
const fullPath = decodeURIComponent(req.originalUrl);
const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`);
if (pathPattern.test(fullPath)) {
logger.debug('[validateImageRequest] Image request validated');
next();
} else {
logger.warn('[validateImageRequest] Invalid image path');
res.status(403).send('Access Denied');
}
}
/**
* Factory to create the `validateImageRequest` middleware with configured secureImageLinks
* @param {boolean} [secureImageLinks] - Whether secure image links are enabled
*/
function createValidateImageRequest(secureImageLinks) {
if (!secureImageLinks) {
return (_req, _res, next) => next();
}
/**
* Middleware to validate image request.
* Supports both LibreChat refresh tokens and OpenID JWT tokens.
* Must be set by `secureImageLinks` via custom config file.
*/
return async function validateImageRequest(req, res, next) {
try {
const cookieHeader = req.headers.cookie;
if (!cookieHeader) {
logger.warn('[validateImageRequest] No cookies provided');
return res.status(401).send('Unauthorized');
}
const parsedCookies = cookies.parse(cookieHeader);
const refreshToken = parsedCookies.refreshToken;
if (!refreshToken) {
logger.warn('[validateImageRequest] Token not provided');
return res.status(401).send('Unauthorized');
}
const tokenProvider = parsedCookies.token_provider;
let userIdForPath;
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
const openidUserId = parsedCookies.openid_user_id;
if (!openidUserId) {
logger.warn('[validateImageRequest] No OpenID user ID cookie found');
return res.status(403).send('Access Denied');
}
const validationResult = validateToken(openidUserId);
if (!validationResult.valid) {
logger.warn(`[validateImageRequest] ${validationResult.error}`);
return res.status(403).send('Access Denied');
}
userIdForPath = validationResult.userId;
} else {
const validationResult = validateToken(refreshToken);
if (!validationResult.valid) {
logger.warn(`[validateImageRequest] ${validationResult.error}`);
return res.status(403).send('Access Denied');
}
userIdForPath = validationResult.userId;
}
if (!userIdForPath) {
logger.warn('[validateImageRequest] No user ID available for path validation');
return res.status(403).send('Access Denied');
}
const MAX_URL_LENGTH = 2048;
if (req.originalUrl.length > MAX_URL_LENGTH) {
logger.warn('[validateImageRequest] URL too long');
return res.status(403).send('Access Denied');
}
if (req.originalUrl.includes('\x00')) {
logger.warn('[validateImageRequest] URL contains null byte');
return res.status(403).send('Access Denied');
}
let fullPath;
try {
fullPath = decodeURIComponent(req.originalUrl);
} catch {
logger.warn('[validateImageRequest] Invalid URL encoding');
return res.status(403).send('Access Denied');
}
const agentAvatarPattern = /^\/images\/[a-f0-9]{24}\/agent-[^/]*$/;
if (agentAvatarPattern.test(fullPath)) {
logger.debug('[validateImageRequest] Image request validated');
return next();
}
const escapedUserId = userIdForPath.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const pathPattern = new RegExp(`^/images/${escapedUserId}/[^/]+$`);
if (pathPattern.test(fullPath)) {
logger.debug('[validateImageRequest] Image request validated');
next();
} else {
logger.warn('[validateImageRequest] Invalid image path');
res.status(403).send('Access Denied');
}
} catch (error) {
logger.error('[validateImageRequest] Error:', error);
res.status(500).send('Internal Server Error');
}
};
}
module.exports = createValidateImageRequest;
module.exports = validateImageRequest;
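For context, a minimal usage sketch of mounting this middleware in front of the static images route. This is not taken from the diff: the require path and images directory are assumptions, and it assumes the plain `validateImageRequest` export is active; with the factory export it would instead be mounted as `createValidateImageRequest(appConfig.secureImageLinks)`.
const express = require('express');
const validateImageRequest = require('~/server/middleware/validateImageRequest');
const app = express();
// The middleware resolves the app config itself and passes through when
// `secureImageLinks` is not enabled, so it can be mounted unconditionally.
app.use('/images', validateImageRequest, express.static('client/public/images'));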

View File

@@ -1,680 +0,0 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
jest.mock('@librechat/data-schemas', () => ({
logger: {
debug: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
},
createMethods: jest.fn(() => ({})),
createModels: jest.fn(() => ({})),
}));
jest.mock('~/server/middleware', () => ({
requireJwtAuth: (req, res, next) => next(),
validateMessageReq: (req, res, next) => next(),
}));
jest.mock('~/models', () => ({
getConvo: jest.fn(),
saveConvo: jest.fn(),
saveMessage: jest.fn(),
getMessage: jest.fn(),
getMessages: jest.fn(),
updateMessage: jest.fn(),
deleteMessages: jest.fn(),
}));
jest.mock('~/db/models', () => {
let User, Message, Transaction, Conversation;
return {
get User() {
return User;
},
get Message() {
return Message;
},
get Transaction() {
return Transaction;
},
get Conversation() {
return Conversation;
},
setUser: (model) => {
User = model;
},
setMessage: (model) => {
Message = model;
},
setTransaction: (model) => {
Transaction = model;
},
setConversation: (model) => {
Conversation = model;
},
};
});
describe('Costs Endpoint', () => {
let app;
let mongoServer;
let messagesRouter;
let User, Message, Transaction, Conversation;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
const userSchema = new mongoose.Schema({
_id: String,
name: String,
email: String,
});
const conversationSchema = new mongoose.Schema({
conversationId: String,
user: String,
title: String,
createdAt: Date,
});
const messageSchema = new mongoose.Schema({
messageId: String,
conversationId: String,
user: String,
isCreatedByUser: Boolean,
tokenCount: Number,
createdAt: Date,
});
const transactionSchema = new mongoose.Schema({
conversationId: String,
user: String,
tokenType: String,
tokenValue: Number,
createdAt: Date,
});
User = mongoose.model('User', userSchema);
Conversation = mongoose.model('Conversation', conversationSchema);
Message = mongoose.model('Message', messageSchema);
Transaction = mongoose.model('Transaction', transactionSchema);
const dbModels = require('~/db/models');
dbModels.setUser(User);
dbModels.setMessage(Message);
dbModels.setTransaction(Transaction);
dbModels.setConversation(Conversation);
require('~/db/models');
try {
messagesRouter = require('../messages');
} catch (error) {
console.error('Error loading messages router:', error);
throw error;
}
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = { id: 'test-user-id' };
next();
});
app.use('/api/messages', messagesRouter);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await User.deleteMany({});
await Conversation.deleteMany({});
await Message.deleteMany({});
await Transaction.deleteMany({});
});
describe('GET /:conversationId/costs', () => {
const conversationId = 'test-conversation-123';
const userId = 'test-user-id';
it('should return cost data for valid conversation', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const aiMessage = new Message({
messageId: 'ai-msg-1',
conversationId,
user: userId,
isCreatedByUser: false,
tokenCount: 150,
createdAt: new Date('2024-01-01T10:01:00Z'),
});
await Promise.all([userMessage.save(), aiMessage.save()]);
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const completionTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'completion',
tokenValue: 750000,
createdAt: new Date('2024-01-01T10:01:30Z'),
});
await Promise.all([promptTransaction.save(), completionTransaction.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body).toMatchObject({
conversationId,
totals: {
prompt: { usd: 0.5, tokenCount: 100 },
completion: { usd: 0.75, tokenCount: 150 },
total: { usd: 1.25, tokenCount: 250 },
},
perMessage: [
{ messageId: 'user-msg-1', tokenType: 'prompt', tokenCount: 100, usd: 0.5 },
{ messageId: 'ai-msg-1', tokenType: 'completion', tokenCount: 150, usd: 0.75 },
],
});
});
it('should return empty data for conversation with no messages', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body).toMatchObject({
conversationId,
totals: {
prompt: { usd: 0, tokenCount: 0 },
completion: { usd: 0, tokenCount: 0 },
total: { usd: 0, tokenCount: 0 },
},
perMessage: [],
});
});
it('should handle messages without transactions', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const aiMessage = new Message({
messageId: 'ai-msg-1',
conversationId,
user: userId,
isCreatedByUser: false,
tokenCount: 150,
createdAt: new Date('2024-01-01T10:01:00Z'),
});
await Promise.all([userMessage.save(), aiMessage.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0);
expect(response.body.totals.completion.usd).toBe(0);
expect(response.body.totals.total.usd).toBe(0);
});
it('should aggregate multiple transactions correctly', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction1 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 300000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const promptTransaction2 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 200000,
createdAt: new Date('2024-01-01T10:00:45Z'),
});
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
expect(response.body.perMessage[0].usd).toBe(0.5);
});
it('should handle null tokenCount values', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: null,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.tokenCount).toBe(0);
});
it('should handle null tokenValue in transactions', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: null,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await promptTransaction.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0);
});
it('should handle negative tokenValue using Math.abs', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: -500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await promptTransaction.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
});
it('should filter by user correctly', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const otherUserId = 'other-user-id';
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const otherUserMessage = new Message({
messageId: 'other-user-msg-1',
conversationId,
user: otherUserId,
isCreatedByUser: true,
tokenCount: 200,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await Promise.all([userMessage.save(), otherUserMessage.save()]);
const userTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const otherUserTransaction = new Transaction({
conversationId,
user: otherUserId,
tokenType: 'prompt',
tokenValue: 1000000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await Promise.all([userTransaction.save(), otherUserTransaction.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
expect(response.body.perMessage).toHaveLength(1);
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
});
it('should filter transactions by tokenType', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const otherTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'other',
tokenValue: 1000000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await Promise.all([promptTransaction.save(), otherTransaction.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
expect(response.body.totals.completion.usd).toBe(0);
expect(response.body.totals.total.usd).toBe(0.5);
});
it('should map transactions to messages chronologically', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage1 = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const userMessage2 = new Message({
messageId: 'user-msg-2',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 200,
createdAt: new Date('2024-01-01T10:01:00Z'),
});
await Promise.all([userMessage1.save(), userMessage2.save()]);
const promptTransaction1 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const promptTransaction2 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 1000000,
createdAt: new Date('2024-01-01T10:01:30Z'),
});
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.perMessage).toHaveLength(2);
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
expect(response.body.perMessage[0].usd).toBe(0.5);
expect(response.body.perMessage[1].messageId).toBe('user-msg-2');
expect(response.body.perMessage[1].usd).toBe(1.0);
});
it('should handle database errors', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
await mongoose.connection.close();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(500);
expect(response.body).toHaveProperty('error');
});
});
});
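The expectations above imply the conversion the costs endpoint applies: transaction `tokenValue` is stored in millionths of a USD (500000 maps to $0.50), spend may be recorded as a negative value, and only `prompt`/`completion` token types count. A minimal sketch of that aggregation, as a hypothetical helper rather than the actual route implementation:
/**
 * Sums transaction value per token type and converts credits to USD.
 * Assumes tokenValue is stored in millionths of a USD and that spend
 * may be negative, hence Math.abs; null values count as zero.
 */
function summarizeCosts(transactions) {
  const totals = { prompt: { usd: 0 }, completion: { usd: 0 }, total: { usd: 0 } };
  for (const tx of transactions) {
    if (tx.tokenType !== 'prompt' && tx.tokenType !== 'completion') {
      continue; // other token types are ignored by the endpoint
    }
    const usd = Math.abs(tx.tokenValue ?? 0) / 1_000_000;
    totals[tx.tokenType].usd += usd;
    totals.total.usd += usd;
  }
  return totals;
}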

View File

@@ -122,11 +122,9 @@ router.get('/', async function (req, res) {
payload.minPasswordLength = minPasswordLength;
}
payload.mcpServers = {};
const getMCPServers = () => {
try {
if (appConfig?.mcpConfig == null) {
return;
}
const mcpManager = getMCPManager();
if (!mcpManager) {
return;
@@ -135,9 +133,6 @@ router.get('/', async function (req, res) {
if (!mcpServers) return;
const oauthServers = mcpManager.getOAuthServers();
for (const serverName in mcpServers) {
if (!payload.mcpServers) {
payload.mcpServers = {};
}
const serverConfig = mcpServers[serverName];
payload.mcpServers[serverName] = removeNullishValues({
startup: serverConfig?.startup,

View File

@@ -4,13 +4,9 @@ const { sleep } = require('@librechat/agents');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const {
createImportLimiters,
createForkLimiters,
configMiddleware,
} = require('~/server/middleware');
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
const { storage, importFileFilter } = require('~/server/routes/files/multer');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { importConversations } = require('~/server/utils/import');
@@ -175,7 +171,6 @@ router.post(
'/import',
importIpLimiter,
importUserLimiter,
configMiddleware,
upload.single('file'),
async (req, res) => {
try {

View File

@@ -31,7 +31,6 @@ const { getAssistant } = require('~/models/Assistant');
const { getAgent } = require('~/models/Agent');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
const { Readable } = require('stream');
const router = express.Router();
@@ -185,7 +184,6 @@ router.delete('/', async (req, res) => {
role: req.user.role,
fileIds: nonOwnedFileIds,
agentId: req.body.agent_id,
isDelete: true,
});
for (const file of nonOwnedFiles) {
@@ -327,6 +325,11 @@ router.get('/download/:userId/:file_id', fileAccess, async (req, res) => {
res.setHeader('X-File-Metadata', JSON.stringify(file));
};
/** @type {{ body: import('stream').PassThrough } | undefined} */
let passThrough;
/** @type {ReadableStream | undefined} */
let fileStream;
if (checkOpenAIStorage(file.source)) {
req.body = { model: file.model };
const endpointMap = {
@@ -339,19 +342,12 @@ router.get('/download/:userId/:file_id', fileAccess, async (req, res) => {
overrideEndpoint: endpointMap[file.source],
});
logger.debug(`Downloading file ${file_id} from OpenAI`);
const passThrough = await getDownloadStream(file_id, openai);
passThrough = await getDownloadStream(file_id, openai);
setHeaders();
logger.debug(`File ${file_id} downloaded from OpenAI`);
// Handle both Node.js and Web streams
const stream =
passThrough.body && typeof passThrough.body.getReader === 'function'
? Readable.fromWeb(passThrough.body)
: passThrough.body;
stream.pipe(res);
passThrough.body.pipe(res);
} else {
const fileStream = await getDownloadStream(req, file.filepath);
fileStream = await getDownloadStream(req, file.filepath);
fileStream.on('error', (streamError) => {
logger.error('[DOWNLOAD ROUTE] Stream error:', streamError);

View File

@@ -11,7 +11,6 @@ const {
} = require('~/models');
const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/Artifacts/update');
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
const { tokenValues, getValueKey, defaultRate } = require('~/models/tx');
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
const { getConvosQueried } = require('~/models/Conversation');
const { countTokens } = require('~/server/utils');
@@ -161,41 +160,6 @@ router.post('/artifact/:messageId', async (req, res) => {
}
});
/**
* POST /costs
* Get cost information for models in modelHistory array
*/
router.post('/costs', async (req, res) => {
try {
const { modelHistory } = req.body;
if (!Array.isArray(modelHistory)) {
return res.status(400).json({ error: 'modelHistory must be an array' });
}
const modelCostTable = {};
modelHistory.forEach((modelEntry) => {
if (modelEntry && typeof modelEntry === 'object' && modelEntry.model && modelEntry.endpoint) {
const { model, endpoint } = modelEntry;
const valueKey = getValueKey(model, endpoint);
const pricing = tokenValues[valueKey];
modelCostTable[model] = {
prompt: pricing?.prompt ?? defaultRate,
completion: pricing?.completion ?? defaultRate,
};
}
});
res.status(200).json({ modelCostTable });
} catch (error) {
logger.error('Error fetching model costs:', error);
res.status(500).json({ error: 'Internal server error' });
}
});
/* Note: It's necessary to add `validateMessageReq` within route definition for correct params */
router.get('/:conversationId', validateMessageReq, async (req, res) => {
try {

View File

@@ -26,12 +26,9 @@ const domains = {
router.use(logHeaders);
router.use(loginLimiter);
const oauthHandler = async (req, res, next) => {
const oauthHandler = async (req, res) => {
try {
if (res.headersSent) {
return;
}
await checkDomainAllowed(req, res);
await checkBan(req, res);
if (req.banned) {
return;
@@ -42,14 +39,13 @@ const oauthHandler = async (req, res, next) => {
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
) {
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
setOpenIDAuthTokens(req.user.tokenset, res);
} else {
await setAuthTokens(req.user._id, res);
}
res.redirect(domains.client);
} catch (err) {
logger.error('Error in setting authentication tokens:', err);
next(err);
}
};
@@ -83,7 +79,6 @@ router.get(
scope: ['openid', 'profile', 'email'],
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -109,7 +104,6 @@ router.get(
profileFields: ['id', 'email', 'name'],
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -131,7 +125,6 @@ router.get(
session: false,
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -155,7 +148,6 @@ router.get(
scope: ['user:email', 'read:user'],
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -179,7 +171,6 @@ router.get(
scope: ['identify', 'email'],
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);
@@ -201,7 +192,6 @@ router.post(
session: false,
}),
setBalanceConfig,
checkDomainAllowed,
oauthHandler,
);

View File

@@ -156,7 +156,7 @@ router.get('/all', async (req, res) => {
router.get('/groups', async (req, res) => {
try {
const userId = req.user.id;
const { pageSize, limit, cursor, name, category, ...otherFilters } = req.query;
const { pageSize, pageNumber, limit, cursor, name, category, ...otherFilters } = req.query;
const { filter, searchShared, searchSharedOnly } = buildPromptGroupFilter({
name,
@@ -171,13 +171,6 @@ router.get('/groups', async (req, res) => {
actualLimit = parseInt(pageSize, 10);
}
if (
actualCursor &&
(actualCursor === 'undefined' || actualCursor === 'null' || actualCursor.length === 0)
) {
actualCursor = null;
}
let accessibleIds = await findAccessibleResources({
userId,
role: req.user.role,
@@ -197,7 +190,6 @@ router.get('/groups', async (req, res) => {
publicPromptGroupIds: publiclyAccessibleIds,
});
// Cursor-based pagination only
const result = await getListPromptGroupsByAccess({
accessibleIds: filteredAccessibleIds,
otherParams: filter,
@@ -206,21 +198,19 @@ router.get('/groups', async (req, res) => {
});
if (!result) {
const emptyResponse = createEmptyPromptGroupsResponse({
pageNumber: '1',
pageSize: actualLimit,
actualLimit,
});
const emptyResponse = createEmptyPromptGroupsResponse({ pageNumber, pageSize, actualLimit });
return res.status(200).send(emptyResponse);
}
const { data: promptGroups = [], has_more = false, after = null } = result;
const groupsWithPublicFlag = markPublicPromptGroups(promptGroups, publiclyAccessibleIds);
const response = formatPromptGroupsResponse({
promptGroups: groupsWithPublicFlag,
pageNumber: '1', // Always 1 for cursor-based pagination
pageSize: actualLimit.toString(),
pageNumber,
pageSize,
actualLimit,
hasMore: has_more,
after,
});
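A hedged illustration of how a client would walk the cursor-based pagination this route returns; the fetch wrapper and base URL are hypothetical, while the `limit`, `cursor`, `promptGroups`, `has_more`, and `after` fields come from the request/response shape exercised in the tests further down.
// Collects every accessible prompt group by following the `after` cursor
// until the server reports no more pages.
async function fetchAllPromptGroups(baseUrl, limit = 25) {
  const groups = [];
  let cursor = null;
  let hasMore = true;
  while (hasMore) {
    const params = new URLSearchParams({ limit: String(limit) });
    if (cursor) {
      params.set('cursor', cursor);
    }
    const res = await fetch(`${baseUrl}/api/prompts/groups?${params.toString()}`);
    const body = await res.json();
    groups.push(...(body.promptGroups ?? []));
    hasMore = body.has_more === true;
    cursor = body.after;
  }
  return groups;
}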

View File

@@ -33,11 +33,22 @@ let promptRoutes;
let Prompt, PromptGroup, AclEntry, AccessRole, User;
let testUsers, testRoles;
let grantPermission;
let currentTestUser; // Track current user for middleware
// Helper function to set user in middleware
function setTestUser(app, user) {
currentTestUser = user;
app.use((req, res, next) => {
req.user = {
...(user.toObject ? user.toObject() : user),
id: user.id || user._id.toString(),
_id: user._id,
name: user.name,
role: user.role,
};
if (user.role === SystemRoles.ADMIN) {
console.log('Setting admin user with role:', req.user.role);
}
next();
});
}
beforeAll(async () => {
@@ -64,35 +75,14 @@ beforeAll(async () => {
app = express();
app.use(express.json());
// Add user middleware before routes
app.use((req, res, next) => {
if (currentTestUser) {
req.user = {
...(currentTestUser.toObject ? currentTestUser.toObject() : currentTestUser),
id: currentTestUser._id.toString(),
_id: currentTestUser._id,
name: currentTestUser.name,
role: currentTestUser.role,
};
}
next();
});
// Mock authentication middleware - default to owner
setTestUser(app, testUsers.owner);
// Set default user
currentTestUser = testUsers.owner;
// Import routes after middleware is set up
// Import routes after mocks are set up
promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
});
afterEach(() => {
// Always reset to owner user after each test for isolation
if (currentTestUser !== testUsers.owner) {
currentTestUser = testUsers.owner;
}
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
@@ -126,26 +116,36 @@ async function setupTestData() {
// Create test users
testUsers = {
owner: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'Prompt Owner',
email: 'owner@example.com',
role: SystemRoles.USER,
}),
viewer: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'Prompt Viewer',
email: 'viewer@example.com',
role: SystemRoles.USER,
}),
editor: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'Prompt Editor',
email: 'editor@example.com',
role: SystemRoles.USER,
}),
noAccess: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'No Access',
email: 'noaccess@example.com',
role: SystemRoles.USER,
}),
admin: await User.create({
id: new ObjectId().toString(),
_id: new ObjectId(),
name: 'Admin',
email: 'admin@example.com',
role: SystemRoles.ADMIN,
@@ -181,7 +181,8 @@ describe('Prompt Routes - ACL Permissions', () => {
it('should have routes loaded', async () => {
// This should at least not crash
const response = await request(app).get('/api/prompts/test-404');
console.log('Test 404 response status:', response.status);
console.log('Test 404 response body:', response.body);
// We expect a 401 or 404, not 500
expect(response.status).not.toBe(500);
});
@@ -206,6 +207,12 @@ describe('Prompt Routes - ACL Permissions', () => {
const response = await request(app).post('/api/prompts').send(promptData);
if (response.status !== 200) {
console.log('POST /api/prompts error status:', response.status);
console.log('POST /api/prompts error body:', response.body);
console.log('Console errors:', consoleErrorSpy.mock.calls);
}
expect(response.status).toBe(200);
expect(response.body.prompt).toBeDefined();
expect(response.body.prompt.prompt).toBe(promptData.prompt.prompt);
@@ -311,8 +318,29 @@ describe('Prompt Routes - ACL Permissions', () => {
});
it('should allow admin access without explicit permissions', async () => {
// Set admin user
setTestUser(app, testUsers.admin);
// First, reset the app to remove previous middleware
app = express();
app.use(express.json());
// Set admin user BEFORE adding routes
app.use((req, res, next) => {
req.user = {
...testUsers.admin.toObject(),
id: testUsers.admin._id.toString(),
_id: testUsers.admin._id,
name: testUsers.admin.name,
role: testUsers.admin.role,
};
next();
});
// Now add the routes
const promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
console.log('Admin user:', testUsers.admin);
console.log('Admin role:', testUsers.admin.role);
console.log('SystemRoles.ADMIN:', SystemRoles.ADMIN);
const response = await request(app).get(`/api/prompts/${testPrompt._id}`).expect(200);
@@ -404,8 +432,21 @@ describe('Prompt Routes - ACL Permissions', () => {
grantedBy: testUsers.editor._id,
});
// Set viewer user
setTestUser(app, testUsers.viewer);
// Recreate app with viewer user
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = {
...testUsers.viewer.toObject(),
id: testUsers.viewer._id.toString(),
_id: testUsers.viewer._id,
name: testUsers.viewer.name,
role: testUsers.viewer.role,
};
next();
});
const promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
await request(app)
.delete(`/api/prompts/${authorPrompt._id}`)
@@ -458,8 +499,21 @@ describe('Prompt Routes - ACL Permissions', () => {
grantedBy: testUsers.owner._id,
});
// Ensure owner user
setTestUser(app, testUsers.owner);
// Recreate app to ensure fresh middleware
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = {
...testUsers.owner.toObject(),
id: testUsers.owner._id.toString(),
_id: testUsers.owner._id,
name: testUsers.owner.name,
role: testUsers.owner.role,
};
next();
});
const promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
const response = await request(app)
.patch(`/api/prompts/${testPrompt._id}/tags/production`)
@@ -483,8 +537,21 @@ describe('Prompt Routes - ACL Permissions', () => {
grantedBy: testUsers.owner._id,
});
// Set viewer user
setTestUser(app, testUsers.viewer);
// Recreate app with viewer user
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = {
...testUsers.viewer.toObject(),
id: testUsers.viewer._id.toString(),
_id: testUsers.viewer._id,
name: testUsers.viewer.name,
role: testUsers.viewer.role,
};
next();
});
const promptRoutes = require('./prompts');
app.use('/api/prompts', promptRoutes);
await request(app).patch(`/api/prompts/${testPrompt._id}/tags/production`).expect(403);
@@ -543,305 +610,4 @@ describe('Prompt Routes - ACL Permissions', () => {
expect(response.body._id).toBe(publicPrompt._id.toString());
});
});
describe('Pagination', () => {
beforeEach(async () => {
// Create multiple prompt groups for pagination testing
const groups = [];
for (let i = 0; i < 15; i++) {
const group = await PromptGroup.create({
name: `Test Group ${i + 1}`,
category: 'pagination-test',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - i * 1000), // Stagger updatedAt for consistent ordering
});
groups.push(group);
// Grant owner permissions on each group
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
});
afterEach(async () => {
await PromptGroup.deleteMany({});
await AclEntry.deleteMany({});
});
it('should correctly indicate hasMore when there are more pages', async () => {
const response = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10' })
.expect(200);
expect(response.body.promptGroups).toHaveLength(10);
expect(response.body.has_more).toBe(true);
expect(response.body.after).toBeTruthy();
// Since has_more is true, pages should be a high number (9999 in our fix)
expect(parseInt(response.body.pages)).toBeGreaterThan(1);
});
it('should correctly indicate no more pages on the last page', async () => {
// First get the cursor for page 2
const firstPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10' })
.expect(200);
expect(firstPage.body.has_more).toBe(true);
expect(firstPage.body.after).toBeTruthy();
// Now fetch the second page using the cursor
const response = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10', cursor: firstPage.body.after })
.expect(200);
expect(response.body.promptGroups).toHaveLength(5); // 15 total, 10 on page 1, 5 on page 2
expect(response.body.has_more).toBe(false);
});
it('should support cursor-based pagination', async () => {
// First page
const firstPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5' })
.expect(200);
expect(firstPage.body.promptGroups).toHaveLength(5);
expect(firstPage.body.has_more).toBe(true);
expect(firstPage.body.after).toBeTruthy();
// Second page using cursor
const secondPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', cursor: firstPage.body.after })
.expect(200);
expect(secondPage.body.promptGroups).toHaveLength(5);
expect(secondPage.body.has_more).toBe(true);
expect(secondPage.body.after).toBeTruthy();
// Verify different groups
const firstPageIds = firstPage.body.promptGroups.map((g) => g._id);
const secondPageIds = secondPage.body.promptGroups.map((g) => g._id);
expect(firstPageIds).not.toEqual(secondPageIds);
});
it('should paginate correctly with category filtering', async () => {
// Create groups with different categories
await PromptGroup.deleteMany({}); // Clear existing groups
await AclEntry.deleteMany({});
// Create 8 groups with category 'test-cat-1'
for (let i = 0; i < 8; i++) {
const group = await PromptGroup.create({
name: `Category 1 Group ${i + 1}`,
category: 'test-cat-1',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - i * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Create 7 groups with category 'test-cat-2'
for (let i = 0; i < 7; i++) {
const group = await PromptGroup.create({
name: `Category 2 Group ${i + 1}`,
category: 'test-cat-2',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - (i + 8) * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Test pagination with category filter
const firstPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', category: 'test-cat-1' })
.expect(200);
expect(firstPage.body.promptGroups).toHaveLength(5);
expect(firstPage.body.promptGroups.every((g) => g.category === 'test-cat-1')).toBe(true);
expect(firstPage.body.has_more).toBe(true);
expect(firstPage.body.after).toBeTruthy();
const secondPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', cursor: firstPage.body.after, category: 'test-cat-1' })
.expect(200);
expect(secondPage.body.promptGroups).toHaveLength(3); // 8 total, 5 on page 1, 3 on page 2
expect(secondPage.body.promptGroups.every((g) => g.category === 'test-cat-1')).toBe(true);
expect(secondPage.body.has_more).toBe(false);
});
it('should paginate correctly with name/keyword filtering', async () => {
// Create groups with specific names
await PromptGroup.deleteMany({}); // Clear existing groups
await AclEntry.deleteMany({});
// Create 12 groups with 'Search' in the name
for (let i = 0; i < 12; i++) {
const group = await PromptGroup.create({
name: `Search Test Group ${i + 1}`,
category: 'search-test',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - i * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Create 5 groups without 'Search' in the name
for (let i = 0; i < 5; i++) {
const group = await PromptGroup.create({
name: `Other Group ${i + 1}`,
category: 'other-test',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - (i + 12) * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Test pagination with name filter
const firstPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10', name: 'Search' })
.expect(200);
expect(firstPage.body.promptGroups).toHaveLength(10);
expect(firstPage.body.promptGroups.every((g) => g.name.includes('Search'))).toBe(true);
expect(firstPage.body.has_more).toBe(true);
expect(firstPage.body.after).toBeTruthy();
const secondPage = await request(app)
.get('/api/prompts/groups')
.query({ limit: '10', cursor: firstPage.body.after, name: 'Search' })
.expect(200);
expect(secondPage.body.promptGroups).toHaveLength(2); // 12 total, 10 on page 1, 2 on page 2
expect(secondPage.body.promptGroups.every((g) => g.name.includes('Search'))).toBe(true);
expect(secondPage.body.has_more).toBe(false);
});
it('should paginate correctly with combined filters', async () => {
// Create groups with various combinations
await PromptGroup.deleteMany({}); // Clear existing groups
await AclEntry.deleteMany({});
// Create 6 groups matching both category and name filters
for (let i = 0; i < 6; i++) {
const group = await PromptGroup.create({
name: `API Test Group ${i + 1}`,
category: 'api-category',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - i * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Create groups that only match one filter
for (let i = 0; i < 4; i++) {
const group = await PromptGroup.create({
name: `API Other Group ${i + 1}`,
category: 'other-category',
author: testUsers.owner._id,
authorName: testUsers.owner.name,
productionId: new ObjectId(),
updatedAt: new Date(Date.now() - (i + 6) * 1000),
});
await grantPermission({
principalType: PrincipalType.USER,
principalId: testUsers.owner._id,
resourceType: ResourceType.PROMPTGROUP,
resourceId: group._id,
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
grantedBy: testUsers.owner._id,
});
}
// Test pagination with both filters
const response = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', name: 'API', category: 'api-category' })
.expect(200);
expect(response.body.promptGroups).toHaveLength(5);
expect(
response.body.promptGroups.every(
(g) => g.name.includes('API') && g.category === 'api-category',
),
).toBe(true);
expect(response.body.has_more).toBe(true);
expect(response.body.after).toBeTruthy();
// Page 2
const page2 = await request(app)
.get('/api/prompts/groups')
.query({ limit: '5', cursor: response.body.after, name: 'API', category: 'api-category' })
.expect(200);
expect(page2.body.promptGroups).toHaveLength(1); // 6 total, 5 on page 1, 1 on page 2
expect(page2.body.has_more).toBe(false);
});
});
});

View File

@@ -49,7 +49,6 @@ const AppService = async () => {
enabled: isEnabled(process.env.CHECK_BALANCE),
startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
};
const transactions = config.transactions ?? configDefaults.transactions;
const imageOutputType = config?.imageOutputType ?? configDefaults.imageOutputType;
process.env.CDN_PROVIDER = fileStrategy;
@@ -85,7 +84,6 @@ const AppService = async () => {
memory,
speech,
balance,
transactions,
mcpConfig,
webSearch,
fileStrategy,

View File

@@ -152,7 +152,6 @@ describe('AppService', () => {
webSearch: expect.objectContaining({
safeSearch: 1,
jinaApiKey: '${JINA_API_KEY}',
jinaApiUrl: '${JINA_API_URL}',
cohereApiKey: '${COHERE_API_KEY}',
serperApiKey: '${SERPER_API_KEY}',
searxngApiKey: '${SEARXNG_API_KEY}',

View File

@@ -3,12 +3,12 @@ const jwt = require('jsonwebtoken');
const { webcrypto } = require('node:crypto');
const { logger } = require('@librechat/data-schemas');
const { isEnabled, checkEmailConfig } = require('@librechat/api');
const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider');
const { SystemRoles, errorsToString } = require('librechat-data-provider');
const {
findUser,
findToken,
createUser,
updateUser,
findToken,
countUsers,
getUserById,
findSession,
@@ -181,14 +181,6 @@ const registerUser = async (user, additionalData = {}) => {
let newUserId;
try {
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const errorMessage =
'The email address provided cannot be used. Please use a different email address.';
logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`);
return { status: 403, message: errorMessage };
}
const existingUser = await findUser({ email }, 'email _id');
if (existingUser) {
@@ -203,6 +195,14 @@ const registerUser = async (user, additionalData = {}) => {
return { status: 200, message: genericVerificationMessage };
}
const appConfig = await getAppConfig({ role: user.role });
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const errorMessage =
'The email address provided cannot be used. Please use a different email address.';
logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`);
return { status: 403, message: errorMessage };
}
//determine if this is the first registered user (not counting anonymous_user)
const isFirstRegisteredUser = (await countUsers()) === 0;
@@ -252,13 +252,6 @@ const registerUser = async (user, additionalData = {}) => {
*/
const requestPasswordReset = async (req) => {
const { email } = req.body;
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
const error = new Error(ErrorTypes.AUTH_FAILED);
error.code = ErrorTypes.AUTH_FAILED;
error.message = 'Email domain not allowed';
return error;
}
const user = await findUser({ email }, 'email _id');
const emailEnabled = checkEmailConfig();
@@ -409,10 +402,9 @@ const setAuthTokens = async (userId, res, sessionId = null) => {
* @param {import('openid-client').TokenEndpointResponse & import('openid-client').TokenEndpointResponseHelpers} tokenset
* - The tokenset object containing access and refresh tokens
* @param {Object} res - response object
* @param {string} [userId] - Optional MongoDB user ID for image path validation
* @returns {String} - access token
*/
const setOpenIDAuthTokens = (tokenset, res, userId) => {
const setOpenIDAuthTokens = (tokenset, res) => {
try {
if (!tokenset) {
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
@@ -443,18 +435,6 @@ const setOpenIDAuthTokens = (tokenset, res, userId) => {
secure: isProduction,
sameSite: 'strict',
});
if (userId && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
/** JWT-signed user ID cookie for image path validation when OPENID_REUSE_TOKENS is enabled */
const signedUserId = jwt.sign({ id: userId }, process.env.JWT_REFRESH_SECRET, {
expiresIn: expiryInMilliseconds / 1000,
});
res.cookie('openid_user_id', signedUserId, {
expires: expirationDate,
httpOnly: true,
secure: isProduction,
sameSite: 'strict',
});
}
return tokenset.access_token;
} catch (error) {
logger.error('[setOpenIDAuthTokens] Error in setting authentication tokens:', error);
@@ -472,7 +452,7 @@ const setOpenIDAuthTokens = (tokenset, res, userId) => {
const resendVerificationEmail = async (req) => {
try {
const { email } = req.body;
await deleteTokens({ email });
await deleteTokens(email);
const user = await findUser({ email }, 'email _id name');
if (!user) {

View File

@@ -4,8 +4,6 @@ const AppService = require('~/server/services/AppService');
const { setCachedTools } = require('./getCachedTools');
const getLogStores = require('~/cache/getLogStores');
const BASE_CONFIG_KEY = '_BASE_';
/**
* Get the app configuration based on user context
* @param {Object} [options]
@@ -16,8 +14,8 @@ const BASE_CONFIG_KEY = '_BASE_';
async function getAppConfig(options = {}) {
const { role, refresh } = options;
const cache = getLogStores(CacheKeys.APP_CONFIG);
const cacheKey = role ? role : BASE_CONFIG_KEY;
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cacheKey = role ? `${CacheKeys.APP_CONFIG}:${role}` : CacheKeys.APP_CONFIG;
if (!refresh) {
const cached = await cache.get(cacheKey);
@@ -26,7 +24,7 @@ async function getAppConfig(options = {}) {
}
}
let baseConfig = await cache.get(BASE_CONFIG_KEY);
let baseConfig = await cache.get(CacheKeys.APP_CONFIG);
if (!baseConfig) {
logger.info('[getAppConfig] App configuration not initialized. Initializing AppService...');
baseConfig = await AppService();
@@ -39,7 +37,7 @@ async function getAppConfig(options = {}) {
await setCachedTools(baseConfig.availableTools, { isGlobal: true });
}
await cache.set(BASE_CONFIG_KEY, baseConfig);
await cache.set(CacheKeys.APP_CONFIG, baseConfig);
}
// For now, return the base config
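
A minimal sketch of the role-scoped key scheme built with the `${CacheKeys.APP_CONFIG}:${role}` expression in this hunk (illustrative only; CacheKeys comes from librechat-data-provider as shown above):

const { CacheKeys } = require('librechat-data-provider');

// Mirrors the cacheKey expression above: role-scoped entries share the
// APP_CONFIG prefix, while the base config uses the bare key.
const cacheKeyFor = (role) =>
  role ? `${CacheKeys.APP_CONFIG}:${role}` : CacheKeys.APP_CONFIG;

cacheKeyFor('ADMIN'); // role-scoped entry, e.g. `${CacheKeys.APP_CONFIG}:ADMIN`
cacheKeyFor(); // the shared base entry, CacheKeys.APP_CONFIG itself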

View File

@@ -119,6 +119,10 @@ https://www.librechat.ai/docs/configuration/stt_tts`);
.filter((endpoint) => endpoint.customParams)
.forEach((endpoint) => parseCustomParams(endpoint.name, endpoint.customParams));
if (customConfig.cache) {
const cache = getLogStores(CacheKeys.STATIC_CONFIG);
await cache.set(CacheKeys.LIBRECHAT_YAML_CONFIG, customConfig);
}
if (result.data.modelSpecs) {
customConfig.modelSpecs = result.data.modelSpecs;

View File

@@ -48,11 +48,16 @@ const axios = require('axios');
const { loadYaml } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const loadCustomConfig = require('./loadCustomConfig');
const getLogStores = require('~/cache/getLogStores');
describe('loadCustomConfig', () => {
const mockSet = jest.fn();
const mockCache = { set: mockSet };
beforeEach(() => {
jest.resetAllMocks();
delete process.env.CONFIG_PATH;
getLogStores.mockReturnValue(mockCache);
});
it('should return null and log error if remote config fetch fails', async () => {
@@ -89,6 +94,7 @@ describe('loadCustomConfig', () => {
const result = await loadCustomConfig();
expect(result).toEqual(mockConfig);
expect(mockSet).toHaveBeenCalledWith(expect.anything(), mockConfig);
});
it('should return null and log if config schema validation fails', async () => {
@@ -128,6 +134,7 @@ describe('loadCustomConfig', () => {
axios.get.mockResolvedValue({ data: mockConfig });
const result = await loadCustomConfig();
expect(result).toEqual(mockConfig);
expect(mockSet).toHaveBeenCalledWith(expect.anything(), mockConfig);
});
it('should return null if the remote config file is not found', async () => {
@@ -161,6 +168,7 @@ describe('loadCustomConfig', () => {
process.env.CONFIG_PATH = 'validConfig.yaml';
loadYaml.mockReturnValueOnce(mockConfig);
await loadCustomConfig();
expect(mockSet).not.toHaveBeenCalled();
});
it('should log the loaded custom config', async () => {

View File

@@ -1,7 +1,6 @@
const { Providers } = require('@librechat/agents');
const {
primeResources,
getModelMaxTokens,
extractLibreChatParams,
optionalChainWithEmptyCheck,
} = require('@librechat/api');
@@ -18,6 +17,7 @@ const { getProviderConfig } = require('~/server/services/Endpoints');
const { processFiles } = require('~/server/services/Files/process');
const { getFiles, getToolFilesByIds } = require('~/models/File');
const { getConvoFiles } = require('~/models/Conversation');
const { getModelMaxTokens } = require('~/utils');
/**
* @param {object} params

View File

@@ -1,14 +1,13 @@
import { logger } from '@librechat/data-schemas';
import { AnthropicClientOptions } from '@librechat/agents';
import { EModelEndpoint, anthropicSettings } from 'librechat-data-provider';
import { matchModelName } from '~/utils/tokens';
const { EModelEndpoint, anthropicSettings } = require('librechat-data-provider');
const { matchModelName } = require('~/utils');
const { logger } = require('~/config');
/**
* @param {string} modelName
* @returns {boolean}
*/
function checkPromptCacheSupport(modelName: string): boolean {
const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic) ?? '';
function checkPromptCacheSupport(modelName) {
const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic);
if (
modelMatch.includes('claude-3-5-sonnet-latest') ||
modelMatch.includes('claude-3.5-sonnet-latest')
@@ -32,10 +31,7 @@ function checkPromptCacheSupport(modelName: string): boolean {
* @param {boolean} supportsCacheControl Whether the model supports cache control
* @returns {AnthropicClientOptions['extendedOptions']['defaultHeaders']|undefined} The headers object or undefined if not applicable
*/
function getClaudeHeaders(
model: string,
supportsCacheControl: boolean,
): Record<string, string> | undefined {
function getClaudeHeaders(model, supportsCacheControl) {
if (!supportsCacheControl) {
return undefined;
}
@@ -76,13 +72,9 @@ function getClaudeHeaders(
* @param {number|null} extendedOptions.thinkingBudget The token budget for thinking
* @returns {Object} Updated request options
*/
function configureReasoning(
anthropicInput: AnthropicClientOptions & { max_tokens?: number },
extendedOptions: { thinking?: boolean; thinkingBudget?: number | null } = {},
): AnthropicClientOptions & { max_tokens?: number } {
function configureReasoning(anthropicInput, extendedOptions = {}) {
const updatedOptions = { ...anthropicInput };
const currentMaxTokens = updatedOptions.max_tokens ?? updatedOptions.maxTokens;
if (
extendedOptions.thinking &&
updatedOptions?.model &&
@@ -90,16 +82,11 @@ function configureReasoning(
/claude-(?:sonnet|opus|haiku)-[4-9]/.test(updatedOptions.model))
) {
updatedOptions.thinking = {
...updatedOptions.thinking,
type: 'enabled',
} as { type: 'enabled'; budget_tokens: number };
};
}
if (
updatedOptions.thinking != null &&
extendedOptions.thinkingBudget != null &&
updatedOptions.thinking.type === 'enabled'
) {
if (updatedOptions.thinking != null && extendedOptions.thinkingBudget != null) {
updatedOptions.thinking = {
...updatedOptions.thinking,
budget_tokens: extendedOptions.thinkingBudget,
@@ -108,10 +95,9 @@ function configureReasoning(
if (
updatedOptions.thinking != null &&
updatedOptions.thinking.type === 'enabled' &&
(currentMaxTokens == null || updatedOptions.thinking.budget_tokens > currentMaxTokens)
) {
const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model ?? '');
const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model);
updatedOptions.max_tokens = currentMaxTokens ?? maxTokens;
logger.warn(
@@ -129,4 +115,4 @@ function configureReasoning(
return updatedOptions;
}
export { checkPromptCacheSupport, getClaudeHeaders, configureReasoning };
module.exports = { checkPromptCacheSupport, getClaudeHeaders, configureReasoning };
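
A minimal usage sketch for these helpers, based on the behavior exercised by the spec later in this diff; the model names and values are illustrative, and the require path assumes the anthropic endpoint directory shown in this compare:

const {
  checkPromptCacheSupport,
  getClaudeHeaders,
  configureReasoning,
} = require('~/server/services/Endpoints/anthropic/helpers');

// Reasoning: enables `thinking` for Claude 3.7 / Claude 4 family models.
const reqOptions = configureReasoning(
  { model: 'claude-3-7-sonnet', max_tokens: 8192 },
  { thinking: true, thinkingBudget: 2000 },
);
// reqOptions.thinking -> { type: 'enabled', budget_tokens: 2000 }

// Prompt caching: beta headers are only added when the model supports it.
const supportsCache = checkPromptCacheSupport('claude-3-5-sonnet');
const headers = getClaudeHeaders('claude-3-5-sonnet', supportsCache);
// headers -> { 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31' }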

View File

@@ -1,6 +1,6 @@
const { getLLMConfig } = require('@librechat/api');
const { EModelEndpoint } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
const AnthropicClient = require('~/app/clients/AnthropicClient');
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
@@ -40,6 +40,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
clientOptions = Object.assign(
{
proxy: PROXY ?? null,
userId: req.user.id,
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
modelOptions: endpointOption?.model_parameters ?? {},
},
@@ -48,7 +49,6 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
if (overrideModel) {
clientOptions.modelOptions.model = overrideModel;
}
clientOptions.modelOptions.user = req.user.id;
return getLLMConfig(anthropicApiKey, clientOptions);
}

View File

@@ -0,0 +1,103 @@
const { ProxyAgent } = require('undici');
const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');
/**
* Generates configuration options for creating an Anthropic language model (LLM) instance.
*
* @param {string} apiKey - The API key for authentication with Anthropic.
* @param {Object} [options={}] - Additional options for configuring the LLM.
* @param {Object} [options.modelOptions] - Model-specific options.
* @param {string} [options.modelOptions.model] - The name of the model to use.
* @param {number} [options.modelOptions.maxOutputTokens] - The maximum number of tokens to generate.
* @param {number} [options.modelOptions.temperature] - Controls randomness in output generation.
* @param {number} [options.modelOptions.topP] - Controls diversity of output generation.
* @param {number} [options.modelOptions.topK] - Controls the number of top tokens to consider.
* @param {string[]} [options.modelOptions.stop] - Sequences where the API will stop generating further tokens.
* @param {boolean} [options.modelOptions.stream] - Whether to stream the response.
* @param {string} options.userId - The user ID for tracking and personalization.
* @param {string} [options.proxy] - Proxy server URL.
* @param {string} [options.reverseProxyUrl] - URL for a reverse proxy, if used.
*
* @returns {Object} Configuration options for creating an Anthropic LLM instance, with null and undefined values removed.
*/
function getLLMConfig(apiKey, options = {}) {
const systemOptions = {
thinking: options.modelOptions.thinking ?? anthropicSettings.thinking.default,
promptCache: options.modelOptions.promptCache ?? anthropicSettings.promptCache.default,
thinkingBudget: options.modelOptions.thinkingBudget ?? anthropicSettings.thinkingBudget.default,
};
for (let key in systemOptions) {
delete options.modelOptions[key];
}
const defaultOptions = {
model: anthropicSettings.model.default,
maxOutputTokens: anthropicSettings.maxOutputTokens.default,
stream: true,
};
const mergedOptions = Object.assign(defaultOptions, options.modelOptions);
/** @type {AnthropicClientOptions} */
let requestOptions = {
apiKey,
model: mergedOptions.model,
stream: mergedOptions.stream,
temperature: mergedOptions.temperature,
stopSequences: mergedOptions.stop,
maxTokens:
mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model),
clientOptions: {},
invocationKwargs: {
metadata: {
user_id: options.userId,
},
},
};
requestOptions = configureReasoning(requestOptions, systemOptions);
if (!/claude-3[-.]7/.test(mergedOptions.model)) {
requestOptions.topP = mergedOptions.topP;
requestOptions.topK = mergedOptions.topK;
} else if (requestOptions.thinking == null) {
requestOptions.topP = mergedOptions.topP;
requestOptions.topK = mergedOptions.topK;
}
const supportsCacheControl =
systemOptions.promptCache === true && checkPromptCacheSupport(requestOptions.model);
const headers = getClaudeHeaders(requestOptions.model, supportsCacheControl);
if (headers) {
requestOptions.clientOptions.defaultHeaders = headers;
}
if (options.proxy) {
const proxyAgent = new ProxyAgent(options.proxy);
requestOptions.clientOptions.fetchOptions = {
dispatcher: proxyAgent,
};
}
if (options.reverseProxyUrl) {
requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
requestOptions.anthropicApiUrl = options.reverseProxyUrl;
}
const tools = [];
if (mergedOptions.web_search) {
tools.push({
type: 'web_search_20250305',
name: 'web_search',
});
}
return {
tools,
/** @type {AnthropicClientOptions} */
llmConfig: removeNullishValues(requestOptions),
};
}
module.exports = { getLLMConfig };
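
A usage sketch for getLLMConfig that matches the expectations in the spec file that follows; the user ID and model options are illustrative, and note that the system options (thinking, promptCache, thinkingBudget) are stripped from modelOptions in place:

const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');

const { llmConfig, tools } = getLLMConfig(process.env.ANTHROPIC_API_KEY, {
  userId: 'user-123', // hypothetical user ID, forwarded as invocation metadata
  proxy: process.env.PROXY ?? null,
  reverseProxyUrl: null,
  modelOptions: {
    model: 'claude-3-7-sonnet',
    thinking: true,
    thinkingBudget: 2000,
    promptCache: true,
    maxOutputTokens: 8192,
  },
});
// llmConfig.thinking         -> { type: 'enabled', budget_tokens: 2000 }
// llmConfig.maxTokens        -> 8192
// llmConfig.invocationKwargs -> { metadata: { user_id: 'user-123' } }
// tools                      -> [] unless modelOptions.web_search is set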

View File

@@ -0,0 +1,341 @@
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
jest.mock('https-proxy-agent', () => ({
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
}));
describe('getLLMConfig', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('should create a basic configuration with default values', () => {
const result = getLLMConfig('test-api-key', { modelOptions: {} });
expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key');
expect(result.llmConfig).toHaveProperty('model', 'claude-3-5-sonnet-latest');
expect(result.llmConfig).toHaveProperty('stream', true);
expect(result.llmConfig).toHaveProperty('maxTokens');
});
it('should include proxy settings when provided', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {},
proxy: 'http://proxy:8080',
});
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
'ProxyAgent',
);
});
it('should include reverse proxy URL when provided', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {},
reverseProxyUrl: 'http://reverse-proxy',
});
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy');
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'http://reverse-proxy');
});
it('should include topK and topP for non-Claude-3.7 models', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-opus',
topK: 10,
topP: 0.9,
},
});
expect(result.llmConfig).toHaveProperty('topK', 10);
expect(result.llmConfig).toHaveProperty('topP', 0.9);
});
it('should include topK and topP for Claude-3.5 models', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-5-sonnet',
topK: 10,
topP: 0.9,
},
});
expect(result.llmConfig).toHaveProperty('topK', 10);
expect(result.llmConfig).toHaveProperty('topP', 0.9);
});
it('should NOT include topK and topP for Claude-3-7 models with thinking enabled (hyphen notation)', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
topK: 10,
topP: 0.9,
thinking: true,
},
});
expect(result.llmConfig).not.toHaveProperty('topK');
expect(result.llmConfig).not.toHaveProperty('topP');
expect(result.llmConfig).toHaveProperty('thinking');
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
// When thinking is enabled, it uses the default thinkingBudget of 2000
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
});
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model', () => {
const modelOptions = {
model: 'claude-sonnet-4-20250514',
promptCache: true,
};
const result = getLLMConfig('test-key', { modelOptions });
const clientOptions = result.llmConfig.clientOptions;
expect(clientOptions.defaultHeaders).toBeDefined();
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
'prompt-caching-2024-07-31,context-1m-2025-08-07',
);
});
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model formats', () => {
const modelVariations = [
'claude-sonnet-4-20250514',
'claude-sonnet-4-latest',
'anthropic/claude-sonnet-4-20250514',
];
modelVariations.forEach((model) => {
const modelOptions = { model, promptCache: true };
const result = getLLMConfig('test-key', { modelOptions });
const clientOptions = result.llmConfig.clientOptions;
expect(clientOptions.defaultHeaders).toBeDefined();
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
'prompt-caching-2024-07-31,context-1m-2025-08-07',
);
});
});
it('should NOT include topK and topP for Claude-3.7 models with thinking enabled (decimal notation)', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3.7-sonnet',
topK: 10,
topP: 0.9,
thinking: true,
},
});
expect(result.llmConfig).not.toHaveProperty('topK');
expect(result.llmConfig).not.toHaveProperty('topP');
expect(result.llmConfig).toHaveProperty('thinking');
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
// When thinking is enabled, it uses the default thinkingBudget of 2000
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
});
it('should handle custom maxOutputTokens', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-opus',
maxOutputTokens: 2048,
},
});
expect(result.llmConfig).toHaveProperty('maxTokens', 2048);
});
it('should handle promptCache setting', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-5-sonnet',
promptCache: true,
},
});
// We're not checking specific header values since that depends on the actual helper function
// Just verifying that the promptCache setting is processed
expect(result.llmConfig).toBeDefined();
});
it('should include topK and topP for Claude-3.7 models when thinking is not enabled', () => {
// Test with thinking explicitly set to false (hyphen notation)
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
topK: 10,
topP: 0.9,
thinking: false,
},
});
expect(result.llmConfig).toHaveProperty('topK', 10);
expect(result.llmConfig).toHaveProperty('topP', 0.9);
// Test with thinking explicitly set to false
const result2 = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
topK: 10,
topP: 0.9,
thinking: false,
},
});
expect(result2.llmConfig).toHaveProperty('topK', 10);
expect(result2.llmConfig).toHaveProperty('topP', 0.9);
// Test with decimal notation as well
const result3 = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3.7-sonnet',
topK: 10,
topP: 0.9,
thinking: false,
},
});
expect(result3.llmConfig).toHaveProperty('topK', 10);
expect(result3.llmConfig).toHaveProperty('topP', 0.9);
});
describe('Edge cases', () => {
it('should handle missing apiKey', () => {
const result = getLLMConfig(undefined, { modelOptions: {} });
expect(result.llmConfig).not.toHaveProperty('apiKey');
});
it('should handle empty modelOptions', () => {
expect(() => {
getLLMConfig('test-api-key', {});
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
});
it('should handle no options parameter', () => {
expect(() => {
getLLMConfig('test-api-key');
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
});
it('should handle temperature, stop sequences, and stream settings', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
temperature: 0.7,
stop: ['\n\n', 'END'],
stream: false,
},
});
expect(result.llmConfig).toHaveProperty('temperature', 0.7);
expect(result.llmConfig).toHaveProperty('stopSequences', ['\n\n', 'END']);
expect(result.llmConfig).toHaveProperty('stream', false);
});
it('should handle maxOutputTokens when explicitly set to falsy value', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-opus',
maxOutputTokens: null,
},
});
// The actual anthropicSettings.maxOutputTokens.reset('claude-3-opus') returns 4096
expect(result.llmConfig).toHaveProperty('maxTokens', 4096);
});
it('should handle both proxy and reverseProxyUrl', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {},
proxy: 'http://proxy:8080',
reverseProxyUrl: 'https://reverse-proxy.com',
});
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
'ProxyAgent',
);
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'https://reverse-proxy.com');
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'https://reverse-proxy.com');
});
it('should handle prompt cache with supported model', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-5-sonnet',
promptCache: true,
},
});
// claude-3-5-sonnet supports prompt caching and should get the appropriate headers
expect(result.llmConfig.clientOptions.defaultHeaders).toEqual({
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31',
});
});
it('should handle thinking and thinkingBudget options', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
thinking: true,
thinkingBudget: 10000, // This exceeds the default max_tokens of 8192
},
});
// The function should add thinking configuration for claude-3-7 models
expect(result.llmConfig).toHaveProperty('thinking');
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
// With claude-3-7-sonnet, the max_tokens default is 8192
// budget_tokens gets adjusted to 90% of max_tokens (8192 * 0.9, floored to 7372) when it exceeds max_tokens
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 7372);
// Test with budget_tokens within max_tokens limit
const result2 = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
thinking: true,
thinkingBudget: 2000,
},
});
expect(result2.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
});
it('should remove system options from modelOptions', () => {
const modelOptions = {
model: 'claude-3-opus',
thinking: true,
promptCache: true,
thinkingBudget: 1000,
temperature: 0.5,
};
getLLMConfig('test-api-key', { modelOptions });
expect(modelOptions).not.toHaveProperty('thinking');
expect(modelOptions).not.toHaveProperty('promptCache');
expect(modelOptions).not.toHaveProperty('thinkingBudget');
expect(modelOptions).toHaveProperty('temperature', 0.5);
});
it('should handle all nullish values removal', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
temperature: null,
topP: undefined,
topK: 0,
stop: [],
},
});
expect(result.llmConfig).not.toHaveProperty('temperature');
expect(result.llmConfig).not.toHaveProperty('topP');
expect(result.llmConfig).toHaveProperty('topK', 0);
expect(result.llmConfig).toHaveProperty('stopSequences', []);
});
});
});

View File

@@ -1,4 +1,3 @@
const { getModelMaxTokens } = require('@librechat/api');
const { createContentAggregator } = require('@librechat/agents');
const {
EModelEndpoint,
@@ -8,6 +7,7 @@ const {
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
const getOptions = require('~/server/services/Endpoints/bedrock/options');
const AgentClient = require('~/server/controllers/agents/client');
const { getModelMaxTokens } = require('~/utils');
const initializeClient = async ({ req, res, endpointOption }) => {
if (!endpointOption) {

View File

@@ -36,12 +36,10 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
/** Intentionally excludes passing `body`, i.e. `req.body`, as
* values may not be accurate until `AgentClient` is initialized
*/
let resolvedHeaders = resolveHeaders({
headers: endpointConfig.headers,
user: req.user,
body: req.body,
});
if (CUSTOM_API_KEY.match(envVarRegex)) {

View File

@@ -76,10 +76,7 @@ describe('custom/initializeClient', () => {
expect(resolveHeaders).toHaveBeenCalledWith({
headers: { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
user: { id: 'user-123', email: 'test@example.com', role: 'user' },
/**
* Note: Request-based Header Resolution is deferred until right before LLM request is made
body: { endpoint: 'test-endpoint' }, // body - supports {{LIBRECHAT_BODY_*}} placeholders
*/
});
});
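
For context, a rough sketch of the substitution this assertion relies on; the real resolveHeaders helper is mocked in this spec, so the resolver below is a toy stand-in and its output shape is an assumption for illustration only:

// Toy resolver: replaces LIBRECHAT_USER_* placeholders with fields from the
// user object, which is the behavior the asserted call implies.
const resolveUserPlaceholders = (headers, user) =>
  Object.fromEntries(
    Object.entries(headers).map(([key, value]) => [
      key,
      value
        .replace('{{LIBRECHAT_USER_ID}}', user.id)
        .replace('{{LIBRECHAT_USER_EMAIL}}', user.email),
    ]),
  );

resolveUserPlaceholders(
  { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
  { id: 'user-123', email: 'test@example.com' },
);
// -> { 'x-user': 'user-123', 'x-email': 'test@example.com' } (assumed shape)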

View File

@@ -159,11 +159,9 @@ class STTService {
* Prepares the request for the OpenAI STT provider.
* @param {Object} sttSchema - The STT schema for OpenAI.
* @param {Stream} audioReadStream - The audio data to be transcribed.
* @param {Object} audioFile - The audio file object (unused in OpenAI provider).
* @param {string} language - The language code for the transcription.
* @returns {Array} An array containing the URL, data, and headers for the request.
*/
openAIProvider(sttSchema, audioReadStream, audioFile, language) {
openAIProvider(sttSchema, audioReadStream) {
const url = sttSchema?.url || 'https://api.openai.com/v1/audio/transcriptions';
const apiKey = extractEnvVariable(sttSchema.apiKey) || '';
@@ -172,12 +170,6 @@ class STTService {
model: sttSchema.model,
};
if (language) {
/** Converted locale code (e.g., "en-US") to ISO-639-1 format (e.g., "en") */
const isoLanguage = language.split('-')[0];
data.language = isoLanguage;
}
const headers = {
'Content-Type': 'multipart/form-data',
...(apiKey && { Authorization: `Bearer ${apiKey}` }),
@@ -192,11 +184,10 @@ class STTService {
* @param {Object} sttSchema - The STT schema for Azure OpenAI.
* @param {Buffer} audioBuffer - The audio data to be transcribed.
* @param {Object} audioFile - The audio file object containing originalname, mimetype, and size.
* @param {string} language - The language code for the transcription.
* @returns {Array} An array containing the URL, data, and headers for the request.
* @throws {Error} If the audio file size exceeds 25MB or the audio file format is not accepted.
*/
azureOpenAIProvider(sttSchema, audioBuffer, audioFile, language) {
azureOpenAIProvider(sttSchema, audioBuffer, audioFile) {
const url = `${genAzureEndpoint({
azureOpenAIApiInstanceName: extractEnvVariable(sttSchema?.instanceName),
azureOpenAIApiDeploymentName: extractEnvVariable(sttSchema?.deploymentName),
@@ -220,12 +211,6 @@ class STTService {
contentType: audioFile.mimetype,
});
if (language) {
/** Converted locale code (e.g., "en-US") to ISO-639-1 format (e.g., "en") */
const isoLanguage = language.split('-')[0];
formData.append('language', isoLanguage);
}
const headers = {
'Content-Type': 'multipart/form-data',
...(apiKey && { 'api-key': apiKey }),
@@ -244,11 +229,10 @@ class STTService {
* @param {Object} requestData - The data required for the STT request.
* @param {Buffer} requestData.audioBuffer - The audio data to be transcribed.
* @param {Object} requestData.audioFile - The audio file object containing originalname, mimetype, and size.
* @param {string} requestData.language - The language code for the transcription.
* @returns {Promise<string>} A promise that resolves to the transcribed text.
* @throws {Error} If the provider is invalid, the response status is not 200, or the response data is missing.
*/
async sttRequest(provider, sttSchema, { audioBuffer, audioFile, language }) {
async sttRequest(provider, sttSchema, { audioBuffer, audioFile }) {
const strategy = this.providerStrategies[provider];
if (!strategy) {
throw new Error('Invalid provider');
@@ -259,13 +243,7 @@ class STTService {
const audioReadStream = Readable.from(audioBuffer);
audioReadStream.path = `audio.${fileExtension}`;
const [url, data, headers] = strategy.call(
this,
sttSchema,
audioReadStream,
audioFile,
language,
);
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);
try {
const response = await axios.post(url, data, { headers });
@@ -306,8 +284,7 @@ class STTService {
try {
const [provider, sttSchema] = await this.getProviderSchema(req);
const language = req.body?.language || '';
const text = await this.sttRequest(provider, sttSchema, { audioBuffer, audioFile, language });
const text = await this.sttRequest(provider, sttSchema, { audioBuffer, audioFile });
res.json({ text });
} catch (error) {
logger.error('An error occurred while processing the audio:', error);
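
A minimal sketch of the locale normalization performed in the lines removed above, before the audio is posted to the provider (values are illustrative):

// Converts a locale code (e.g. "en-US") to the ISO-639-1 code the
// transcription APIs expect, mirroring `language.split('-')[0]` above.
const toIsoLanguage = (locale) => (locale ? locale.split('-')[0] : undefined);

toIsoLanguage('en-US'); // -> 'en'
toIsoLanguage('pt-BR'); // -> 'pt'
toIsoLanguage(''); // -> undefined (the language field is simply omitted)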

View File

@@ -14,18 +14,16 @@ const { getProvider } = require('./TTSService');
*/
async function getVoices(req, res) {
try {
const appConfig =
req.config ??
(await getAppConfig({
role: req.user?.role,
}));
const appConfig = await getAppConfig({
role: req.user?.role,
});
const ttsSchema = appConfig?.speech?.tts;
if (!ttsSchema) {
if (!appConfig || !appConfig?.speech?.tts) {
throw new Error('Configuration or TTS schema is missing');
}
const provider = await getProvider(appConfig);
const ttsSchema = appConfig?.speech?.tts;
const provider = await getProvider(ttsSchema);
let voices;
switch (provider) {

View File

@@ -17,7 +17,7 @@ const { Files } = require('~/models');
* @param {IUser} options.user - The user object
* @param {AppConfig} options.appConfig - The app configuration object
* @param {GraphRunnableConfig['configurable']} options.metadata - The metadata
* @param {{ [Tools.file_search]: { sources: Object[]; fileCitations: boolean } }} options.toolArtifact - The tool artifact containing structured data
* @param {any} options.toolArtifact - The tool artifact containing structured data
* @param {string} options.toolCallId - The tool call ID
* @returns {Promise<Object|null>} The file search attachment or null
*/
@@ -29,14 +29,12 @@ async function processFileCitations({ user, appConfig, toolArtifact, toolCallId,
if (user) {
try {
const hasFileCitationsAccess =
toolArtifact?.[Tools.file_search]?.fileCitations ??
(await checkAccess({
user,
permissionType: PermissionTypes.FILE_CITATIONS,
permissions: [Permissions.USE],
getRoleByName,
}));
const hasFileCitationsAccess = await checkAccess({
user,
permissionType: PermissionTypes.FILE_CITATIONS,
permissions: [Permissions.USE],
getRoleByName,
});
if (!hasFileCitationsAccess) {
logger.debug(

View File

@@ -10,10 +10,9 @@ const { getAgent } = require('~/models/Agent');
* @param {string} [params.role] - Optional user role to avoid DB query
* @param {string[]} params.fileIds - Array of file IDs to check
* @param {string} params.agentId - The agent ID that might grant access
* @param {boolean} [params.isDelete] - Whether the operation is a delete operation
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
*/
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => {
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId }) => {
const accessMap = new Map();
// Initialize all files as having no access
@@ -45,23 +44,22 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDele
return accessMap;
}
if (isDelete) {
// Check if user has EDIT permission (which would indicate collaborative access)
const hasEditPermission = await checkPermission({
userId,
role,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
requiredPermission: PermissionBits.EDIT,
});
// Check if user has EDIT permission (which would indicate collaborative access)
const hasEditPermission = await checkPermission({
userId,
role,
resourceType: ResourceType.AGENT,
resourceId: agent._id,
requiredPermission: PermissionBits.EDIT,
});
// If user only has VIEW permission, they can't access files
// Only users with EDIT permission or higher can access agent files
if (!hasEditPermission) {
return accessMap;
}
// If user only has VIEW permission, they can't access files
// Only users with EDIT permission or higher can access agent files
if (!hasEditPermission) {
return accessMap;
}
// User has edit permissions - check which files are actually attached
const attachedFileIds = new Set();
if (agent.tool_resources) {
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
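
A usage sketch for this helper; the import path and IDs below are hypothetical and only meant to show how the returned Map, keyed by file ID, would be consumed:

// Hypothetical import path and IDs, for illustration only.
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');

async function canUserSeeAgentFile() {
  const accessMap = await hasAccessToFilesViaAgent({
    userId: 'user-123',
    role: 'USER',
    agentId: 'agent_abc',
    fileIds: ['file-1', 'file-2'],
  });
  // true only when the user has EDIT permission on the agent and the file
  // is actually attached to one of its tool_resources.
  return accessMap.get('file-1') === true;
}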

View File

@@ -616,7 +616,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
if (shouldUseSTT) {
const sttService = await STTService.getInstance();
const { text, bytes } = await processAudioFile({ req, file, sttService });
const { text, bytes } = await processAudioFile({ file, sttService });
return await createTextFile({ text, bytes });
}
@@ -646,8 +646,8 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
req,
file,
file_id,
basePath,
entity_id,
basePath,
});
// SECOND: Upload to Vector DB
@@ -670,18 +670,17 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
req,
file,
file_id,
basePath,
entity_id,
basePath,
});
}
let { bytes, filename, filepath: _filepath, height, width } = storageResult;
const { bytes, filename, filepath: _filepath, height, width } = storageResult;
// For RAG files, use embedding result; for others, use storage result
let embedded = storageResult.embedded;
if (tool_resource === EToolResources.file_search) {
embedded = embeddingResult?.embedded;
filename = embeddingResult?.filename || filename;
}
const embedded =
tool_resource === EToolResources.file_search
? embeddingResult?.embedded
: storageResult.embedded;
let filepath = _filepath;
@@ -930,7 +929,6 @@ async function saveBase64Image(
url,
{ req, file_id: _file_id, filename: _filename, endpoint, context, resolution },
) {
const appConfig = req.config;
const effectiveResolution = resolution ?? appConfig.fileConfig?.imageGeneration ?? 'high';
const file_id = _file_id ?? v4();
let filename = `${file_id}-${_filename}`;
@@ -945,6 +943,7 @@ async function saveBase64Image(
}
const image = await resizeImageBuffer(inputBuffer, effectiveResolution, endpoint);
const appConfig = req.config;
const source = getFileStrategy(appConfig, { isImage: true });
const { saveBuffer } = getStrategyFunctions(source);
const filepath = await saveBuffer({

View File

@@ -271,7 +271,6 @@ async function createMCPTool({
availableTools: tools,
}) {
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
const availableTools =
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
/** @type {LCTool | undefined} */

View File

@@ -1,13 +1,13 @@
const axios = require('axios');
const { Providers } = require('@librechat/agents');
const { logAxiosError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { logAxiosError, inputSchema, processModelData } = require('@librechat/api');
const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
const { inputSchema, extractBaseURL, processModelData } = require('~/utils');
const { OllamaClient } = require('~/app/clients/OllamaClient');
const { isUserProvided } = require('~/server/utils');
const getLogStores = require('~/cache/getLogStores');
const { extractBaseURL } = require('~/utils');
/**
* Splits a string by commas and trims each resulting value.

View File

@@ -11,8 +11,8 @@ const {
getAnthropicModels,
} = require('./ModelService');
jest.mock('@librechat/api', () => {
const originalUtils = jest.requireActual('@librechat/api');
jest.mock('~/utils', () => {
const originalUtils = jest.requireActual('~/utils');
return {
...originalUtils,
processModelData: jest.fn((...args) => {
@@ -108,7 +108,7 @@ describe('fetchModels with createTokenConfig true', () => {
beforeEach(() => {
// Clears the mock's history before each test
const _utils = require('@librechat/api');
const _utils = require('~/utils');
axios.get.mockResolvedValue({ data });
});
@@ -120,7 +120,7 @@ describe('fetchModels with createTokenConfig true', () => {
createTokenConfig: true,
});
const { processModelData } = require('@librechat/api');
const { processModelData } = require('~/utils');
expect(processModelData).toHaveBeenCalled();
expect(processModelData).toHaveBeenCalledWith(data);
});

View File

@@ -88,7 +88,6 @@ async function saveUserMessage(req, params) {
parentMessageId: params.parentMessageId ?? Constants.NO_PARENT,
/* For messages, use the assistant_id instead of model */
model: params.assistant_id,
targetModel: params.model,
thread_id: params.thread_id,
sender: 'User',
text: params.text,

View File

@@ -229,7 +229,7 @@
>
<!--[if mso]><style>.v-button {background: transparent !important;}</style><![endif]-->
<div align='left'>
<!--[if mso]><v:roundrect xmlns:v="urn:schemas-microsoft-com:vml" xmlns:w="urn:schemas-microsoft-com:office:word" href="{{verificationLink}}" style="height:37px; v-text-anchor:middle; width:114px;" arcsize="11%" stroke="f" fillcolor="#10a37f"><w:anchorlock/><center style="color:#FFFFFF;"><![endif]-->
<!--[if mso]><v:roundrect xmlns:v="urn:schemas-microsoft-com:vml" xmlns:w="urn:schemas-microsoft-com:office:word" href="href=&quot;{{verificationLink}}&quot;" style="height:37px; v-text-anchor:middle; width:114px;" arcsize="11%" stroke="f" fillcolor="#10a37f"><w:anchorlock/><center style="color:#FFFFFF;"><![endif]-->
<a
href='{{verificationLink}}'
target='_blank'

View File

@@ -1,6 +1,6 @@
const fs = require('fs').promises;
const { logger } = require('@librechat/data-schemas');
const { getImporter } = require('./importers');
const { logger } = require('~/config');
/**
* Job definition for importing a conversation.

View File

@@ -1,9 +1,9 @@
const { v4: uuidv4 } = require('uuid');
const { logger } = require('@librechat/data-schemas');
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
const { createImportBatchBuilder } = require('./importBatchBuilder');
const { cloneMessagesWithTimestamps } = require('./fork');
const getLogStores = require('~/cache/getLogStores');
const logger = require('~/config/winston');
/**
* Returns the appropriate importer function based on the provided JSON data.
@@ -212,29 +212,6 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
}
}
/**
* Helper function to find the nearest non-system parent
* @param {string} parentId - The ID of the parent message.
* @returns {string} The ID of the nearest non-system parent message.
*/
const findNonSystemParent = (parentId) => {
if (!parentId || !messageMap.has(parentId)) {
return Constants.NO_PARENT;
}
const parentMapping = conv.mapping[parentId];
if (!parentMapping?.message) {
return Constants.NO_PARENT;
}
/* If parent is a system message, traverse up to find the nearest non-system parent */
if (parentMapping.message.author?.role === 'system') {
return findNonSystemParent(parentMapping.parent);
}
return messageMap.get(parentId);
};
// Create and save messages using the mapped IDs
const messages = [];
for (const [id, mapping] of Object.entries(conv.mapping)) {
@@ -243,27 +220,23 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
messageMap.delete(id);
continue;
} else if (role === 'system') {
// Skip system messages but keep their ID in messageMap for parent references
messageMap.delete(id);
continue;
}
const newMessageId = messageMap.get(id);
const parentMessageId = findNonSystemParent(mapping.parent);
const parentMessageId =
mapping.parent && messageMap.has(mapping.parent)
? messageMap.get(mapping.parent)
: Constants.NO_PARENT;
const messageText = formatMessageText(mapping.message);
const isCreatedByUser = role === 'user';
let sender = isCreatedByUser ? 'user' : 'assistant';
let sender = isCreatedByUser ? 'user' : 'GPT-3.5';
const model = mapping.message.metadata.model_slug || openAISettings.model.default;
if (!isCreatedByUser) {
/** Extracted model name from model slug */
const gptMatch = model.match(/gpt-(.+)/i);
if (gptMatch) {
sender = `GPT-${gptMatch[1]}`;
} else {
sender = model || 'assistant';
}
if (model.includes('gpt-4')) {
sender = 'GPT-4';
}
messages.push({
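
A minimal sketch of the model-slug-to-sender mapping removed in this hunk, which the spec changes that follow also exercise (illustrative):

// GPT model slugs keep their suffix ("gpt-4o-mini" -> "GPT-4o-mini");
// anything else falls back to the raw slug, or 'assistant' when missing.
const deriveSender = (modelSlug) => {
  const gptMatch = modelSlug?.match(/gpt-(.+)/i);
  if (gptMatch) {
    return `GPT-${gptMatch[1]}`;
  }
  return modelSlug || 'assistant';
};

deriveSender('gpt-4o-mini'); // -> 'GPT-4o-mini'
deriveSender('claude-3-opus'); // -> 'claude-3-opus'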

View File

@@ -99,404 +99,6 @@ describe('importChatGptConvo', () => {
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
});
it('should handle system messages without breaking parent-child relationships', async () => {
/**
* Test data that reproduces the message graph "breaking" when it encounters a system message
*/
const testData = [
{
title: 'System Message Parent Test',
create_time: 1714585031.148505,
update_time: 1714585060.879308,
mapping: {
'root-node': {
id: 'root-node',
message: null,
parent: null,
children: ['user-msg-1'],
},
'user-msg-1': {
id: 'user-msg-1',
message: {
id: 'user-msg-1',
author: { role: 'user' },
create_time: 1714585031.150442,
content: { content_type: 'text', parts: ['First user message'] },
metadata: { model_slug: 'gpt-4' },
},
parent: 'root-node',
children: ['assistant-msg-1'],
},
'assistant-msg-1': {
id: 'assistant-msg-1',
message: {
id: 'assistant-msg-1',
author: { role: 'assistant' },
create_time: 1714585032.150442,
content: { content_type: 'text', parts: ['First assistant response'] },
metadata: { model_slug: 'gpt-4' },
},
parent: 'user-msg-1',
children: ['system-msg'],
},
'system-msg': {
id: 'system-msg',
message: {
id: 'system-msg',
author: { role: 'system' },
create_time: 1714585033.150442,
content: { content_type: 'text', parts: ['System message in middle'] },
metadata: { model_slug: 'gpt-4' },
},
parent: 'assistant-msg-1',
children: ['user-msg-2'],
},
'user-msg-2': {
id: 'user-msg-2',
message: {
id: 'user-msg-2',
author: { role: 'user' },
create_time: 1714585034.150442,
content: { content_type: 'text', parts: ['Second user message'] },
metadata: { model_slug: 'gpt-4' },
},
parent: 'system-msg',
children: ['assistant-msg-2'],
},
'assistant-msg-2': {
id: 'assistant-msg-2',
message: {
id: 'assistant-msg-2',
author: { role: 'assistant' },
create_time: 1714585035.150442,
content: { content_type: 'text', parts: ['Second assistant response'] },
metadata: { model_slug: 'gpt-4' },
},
parent: 'user-msg-2',
children: [],
},
},
},
];
const requestUserId = 'user-123';
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
jest.spyOn(importBatchBuilder, 'saveMessage');
const importer = getImporter(testData);
await importer(testData, requestUserId, () => importBatchBuilder);
/** 2 user messages + 2 assistant messages (system message should be skipped) */
const expectedMessages = 4;
expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedMessages);
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
const messageMap = new Map();
savedMessages.forEach((msg) => {
messageMap.set(msg.text, msg);
});
const firstUser = messageMap.get('First user message');
const firstAssistant = messageMap.get('First assistant response');
const secondUser = messageMap.get('Second user message');
const secondAssistant = messageMap.get('Second assistant response');
expect(firstUser).toBeDefined();
expect(firstAssistant).toBeDefined();
expect(secondUser).toBeDefined();
expect(secondAssistant).toBeDefined();
expect(firstUser.parentMessageId).toBe(Constants.NO_PARENT);
expect(firstAssistant.parentMessageId).toBe(firstUser.messageId);
// This is the key test: second user message should have first assistant as parent
// (not NO_PARENT which would indicate the system message broke the chain)
expect(secondUser.parentMessageId).toBe(firstAssistant.messageId);
expect(secondAssistant.parentMessageId).toBe(secondUser.messageId);
});
it('should maintain correct sender for user messages regardless of GPT-4 model', async () => {
/**
* Test data with GPT-4 model to ensure user messages keep 'user' sender
*/
const testData = [
{
title: 'GPT-4 Sender Test',
create_time: 1714585031.148505,
update_time: 1714585060.879308,
mapping: {
'root-node': {
id: 'root-node',
message: null,
parent: null,
children: ['user-msg-1'],
},
'user-msg-1': {
id: 'user-msg-1',
message: {
id: 'user-msg-1',
author: { role: 'user' },
create_time: 1714585031.150442,
content: { content_type: 'text', parts: ['User message with GPT-4'] },
metadata: { model_slug: 'gpt-4' },
},
parent: 'root-node',
children: ['assistant-msg-1'],
},
'assistant-msg-1': {
id: 'assistant-msg-1',
message: {
id: 'assistant-msg-1',
author: { role: 'assistant' },
create_time: 1714585032.150442,
content: { content_type: 'text', parts: ['Assistant response with GPT-4'] },
metadata: { model_slug: 'gpt-4' },
},
parent: 'user-msg-1',
children: ['user-msg-2'],
},
'user-msg-2': {
id: 'user-msg-2',
message: {
id: 'user-msg-2',
author: { role: 'user' },
create_time: 1714585033.150442,
content: { content_type: 'text', parts: ['Another user message with GPT-4o-mini'] },
metadata: { model_slug: 'gpt-4o-mini' },
},
parent: 'assistant-msg-1',
children: ['assistant-msg-2'],
},
'assistant-msg-2': {
id: 'assistant-msg-2',
message: {
id: 'assistant-msg-2',
author: { role: 'assistant' },
create_time: 1714585034.150442,
content: { content_type: 'text', parts: ['Assistant response with GPT-3.5'] },
metadata: { model_slug: 'gpt-3.5-turbo' },
},
parent: 'user-msg-2',
children: [],
},
},
},
];
const requestUserId = 'user-123';
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
jest.spyOn(importBatchBuilder, 'saveMessage');
const importer = getImporter(testData);
await importer(testData, requestUserId, () => importBatchBuilder);
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
const userMsg1 = savedMessages.find((msg) => msg.text === 'User message with GPT-4');
const assistantMsg1 = savedMessages.find((msg) => msg.text === 'Assistant response with GPT-4');
const userMsg2 = savedMessages.find(
(msg) => msg.text === 'Another user message with GPT-4o-mini',
);
const assistantMsg2 = savedMessages.find(
(msg) => msg.text === 'Assistant response with GPT-3.5',
);
expect(userMsg1.sender).toBe('user');
expect(userMsg1.isCreatedByUser).toBe(true);
expect(userMsg1.model).toBe('gpt-4');
expect(userMsg2.sender).toBe('user');
expect(userMsg2.isCreatedByUser).toBe(true);
expect(userMsg2.model).toBe('gpt-4o-mini');
expect(assistantMsg1.sender).toBe('GPT-4');
expect(assistantMsg1.isCreatedByUser).toBe(false);
expect(assistantMsg1.model).toBe('gpt-4');
expect(assistantMsg2.sender).toBe('GPT-3.5-turbo');
expect(assistantMsg2.isCreatedByUser).toBe(false);
expect(assistantMsg2.model).toBe('gpt-3.5-turbo');
});
it('should correctly extract and format model names from various model slugs', async () => {
/**
* Test data with various model slugs to test dynamic model identifier extraction
*/
const testData = [
{
title: 'Dynamic Model Identifier Test',
create_time: 1714585031.148505,
update_time: 1714585060.879308,
mapping: {
'root-node': {
id: 'root-node',
message: null,
parent: null,
children: ['msg-1'],
},
'msg-1': {
id: 'msg-1',
message: {
id: 'msg-1',
author: { role: 'user' },
create_time: 1714585031.150442,
content: { content_type: 'text', parts: ['Test message'] },
metadata: {},
},
parent: 'root-node',
children: ['msg-2', 'msg-3', 'msg-4', 'msg-5', 'msg-6', 'msg-7', 'msg-8', 'msg-9'],
},
'msg-2': {
id: 'msg-2',
message: {
id: 'msg-2',
author: { role: 'assistant' },
create_time: 1714585032.150442,
content: { content_type: 'text', parts: ['GPT-4 response'] },
metadata: { model_slug: 'gpt-4' },
},
parent: 'msg-1',
children: [],
},
'msg-3': {
id: 'msg-3',
message: {
id: 'msg-3',
author: { role: 'assistant' },
create_time: 1714585033.150442,
content: { content_type: 'text', parts: ['GPT-4o response'] },
metadata: { model_slug: 'gpt-4o' },
},
parent: 'msg-1',
children: [],
},
'msg-4': {
id: 'msg-4',
message: {
id: 'msg-4',
author: { role: 'assistant' },
create_time: 1714585034.150442,
content: { content_type: 'text', parts: ['GPT-4o-mini response'] },
metadata: { model_slug: 'gpt-4o-mini' },
},
parent: 'msg-1',
children: [],
},
'msg-5': {
id: 'msg-5',
message: {
id: 'msg-5',
author: { role: 'assistant' },
create_time: 1714585035.150442,
content: { content_type: 'text', parts: ['GPT-3.5-turbo response'] },
metadata: { model_slug: 'gpt-3.5-turbo' },
},
parent: 'msg-1',
children: [],
},
'msg-6': {
id: 'msg-6',
message: {
id: 'msg-6',
author: { role: 'assistant' },
create_time: 1714585036.150442,
content: { content_type: 'text', parts: ['GPT-4-turbo response'] },
metadata: { model_slug: 'gpt-4-turbo' },
},
parent: 'msg-1',
children: [],
},
'msg-7': {
id: 'msg-7',
message: {
id: 'msg-7',
author: { role: 'assistant' },
create_time: 1714585037.150442,
content: { content_type: 'text', parts: ['GPT-4-1106-preview response'] },
metadata: { model_slug: 'gpt-4-1106-preview' },
},
parent: 'msg-1',
children: [],
},
'msg-8': {
id: 'msg-8',
message: {
id: 'msg-8',
author: { role: 'assistant' },
create_time: 1714585038.150442,
content: { content_type: 'text', parts: ['Claude response'] },
metadata: { model_slug: 'claude-3-opus' },
},
parent: 'msg-1',
children: [],
},
'msg-9': {
id: 'msg-9',
message: {
id: 'msg-9',
author: { role: 'assistant' },
create_time: 1714585039.150442,
content: { content_type: 'text', parts: ['No model slug response'] },
metadata: {},
},
parent: 'msg-1',
children: [],
},
},
},
];
const requestUserId = 'user-123';
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
jest.spyOn(importBatchBuilder, 'saveMessage');
const importer = getImporter(testData);
await importer(testData, requestUserId, () => importBatchBuilder);
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
// Test various GPT model slug formats
const gpt4 = savedMessages.find((msg) => msg.text === 'GPT-4 response');
expect(gpt4.sender).toBe('GPT-4');
expect(gpt4.model).toBe('gpt-4');
const gpt4o = savedMessages.find((msg) => msg.text === 'GPT-4o response');
expect(gpt4o.sender).toBe('GPT-4o');
expect(gpt4o.model).toBe('gpt-4o');
const gpt4oMini = savedMessages.find((msg) => msg.text === 'GPT-4o-mini response');
expect(gpt4oMini.sender).toBe('GPT-4o-mini');
expect(gpt4oMini.model).toBe('gpt-4o-mini');
const gpt35Turbo = savedMessages.find((msg) => msg.text === 'GPT-3.5-turbo response');
expect(gpt35Turbo.sender).toBe('GPT-3.5-turbo');
expect(gpt35Turbo.model).toBe('gpt-3.5-turbo');
const gpt4Turbo = savedMessages.find((msg) => msg.text === 'GPT-4-turbo response');
expect(gpt4Turbo.sender).toBe('GPT-4-turbo');
expect(gpt4Turbo.model).toBe('gpt-4-turbo');
const gpt4Preview = savedMessages.find((msg) => msg.text === 'GPT-4-1106-preview response');
expect(gpt4Preview.sender).toBe('GPT-4-1106-preview');
expect(gpt4Preview.model).toBe('gpt-4-1106-preview');
// Test non-GPT model (should use the model slug as sender)
const claude = savedMessages.find((msg) => msg.text === 'Claude response');
expect(claude.sender).toBe('claude-3-opus');
expect(claude.model).toBe('claude-3-opus');
// Test missing model slug (should default to openAISettings.model.default)
const noModel = savedMessages.find((msg) => msg.text === 'No model slug response');
// When no model slug is provided, it defaults to gpt-4o-mini which gets formatted to GPT-4o-mini
expect(noModel.sender).toBe('GPT-4o-mini');
expect(noModel.model).toBe(openAISettings.model.default);
// Verify user message is unaffected
const userMsg = savedMessages.find((msg) => msg.text === 'Test message');
expect(userMsg.sender).toBe('user');
expect(userMsg.isCreatedByUser).toBe(true);
});
});
describe('importLibreChatConvo', () => {

View File

@@ -1,10 +1,10 @@
const jwt = require('jsonwebtoken');
const mongoose = require('mongoose');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { Strategy: AppleStrategy } = require('passport-apple');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { createSocialUser, handleExistingUser } = require('./process');
const { isEnabled } = require('~/server/utils');
const socialLogin = require('./socialLogin');
const { findUser } = require('~/models');
const { User } = require('~/db/models');
@@ -17,8 +17,6 @@ jest.mock('@librechat/data-schemas', () => {
logger: {
error: jest.fn(),
debug: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
},
};
});
@@ -26,19 +24,12 @@ jest.mock('./process', () => ({
createSocialUser: jest.fn(),
handleExistingUser: jest.fn(),
}));
jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
jest.mock('~/server/utils', () => ({
isEnabled: jest.fn(),
}));
jest.mock('~/models', () => ({
findUser: jest.fn(),
}));
jest.mock('~/server/services/Config', () => ({
getAppConfig: jest.fn().mockResolvedValue({
fileStrategy: 'local',
balance: { enabled: false },
}),
}));
describe('Apple Login Strategy', () => {
let mongoServer;
@@ -297,14 +288,7 @@ describe('Apple Login Strategy', () => {
expect(mockVerifyCallback).toHaveBeenCalledWith(null, existingUser);
expect(existingUser.avatarUrl).toBeNull(); // As per getProfileDetails
expect(handleExistingUser).toHaveBeenCalledWith(
existingUser,
null,
expect.objectContaining({
fileStrategy: 'local',
balance: { enabled: false },
}),
);
expect(handleExistingUser).toHaveBeenCalledWith(existingUser, null);
});
it('should handle missing idToken gracefully', async () => {

View File

@@ -4,7 +4,6 @@ const { logger } = require('@librechat/data-schemas');
const { isEnabled, getBalanceConfig } = require('@librechat/api');
const { SystemRoles, ErrorTypes } = require('librechat-data-provider');
const { createUser, findUser, updateUser, countUsers } = require('~/models');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { getAppConfig } = require('~/server/services/Config');
const {
@@ -122,18 +121,9 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
);
}
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(mail, appConfig?.registration?.allowedDomains)) {
logger.error(
`[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`,
);
return done(null, false, { message: 'Email domain not allowed' });
}
if (!user) {
const isFirstRegisteredUser = (await countUsers()) === 0;
const role = isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER;
user = {
provider: 'ldap',
ldapId,
@@ -143,6 +133,7 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
name: fullName,
role,
};
const appConfig = await getAppConfig({ role: user?.role });
const balanceConfig = getBalanceConfig(appConfig);
const userId = await createUser(user, balanceConfig);
user._id = userId;

View File

@@ -15,7 +15,6 @@ const {
getBalanceConfig,
} = require('@librechat/api');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { findUser, createUser, updateUser } = require('~/models');
const { getAppConfig } = require('~/server/services/Config');
const getLogStores = require('~/cache/getLogStores');
@@ -184,7 +183,7 @@ const getUserInfo = async (config, accessToken, sub) => {
const exchangedAccessToken = await exchangeAccessTokenIfNeeded(config, accessToken, sub);
return await client.fetchUserInfo(config, exchangedAccessToken, sub);
} catch (error) {
logger.error('[openidStrategy] getUserInfo: Error fetching user info:', error);
logger.warn(`[openidStrategy] getUserInfo: Error fetching user info: ${error}`);
return null;
}
};
@@ -340,19 +339,6 @@ async function setupOpenId() {
async (tokenset, done) => {
try {
const claims = tokenset.claims();
const userinfo = {
...claims,
...(await getUserInfo(openidConfig, tokenset.access_token, claims.sub)),
};
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
logger.error(
`[OpenID Strategy] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
);
return done(null, false, { message: 'Email domain not allowed' });
}
const result = await findOpenIDUser({
openidId: claims.sub,
email: claims.email,
@@ -367,7 +353,10 @@ async function setupOpenId() {
message: ErrorTypes.AUTH_FAILED,
});
}
const userinfo = {
...claims,
...(await getUserInfo(openidConfig, tokenset.access_token, claims.sub)),
};
const fullName = getFullName(userinfo);
if (requiredRole) {
@@ -420,6 +409,7 @@ async function setupOpenId() {
idOnTheSource: userinfo.oid,
};
const appConfig = await getAppConfig();
const balanceConfig = getBalanceConfig(appConfig);
user = await createUser(user, balanceConfig, true, true);
} else {
@@ -448,9 +438,7 @@ async function setupOpenId() {
userinfo.sub,
);
if (imageBuffer) {
const { saveBuffer } = getStrategyFunctions(
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
);
const { saveBuffer } = getStrategyFunctions(process.env.CDN_PROVIDER);
const imagePath = await saveBuffer({
fileName,
userId: user._id.toString(),

View File

@@ -3,6 +3,7 @@ const { FileSources } = require('librechat-data-provider');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
const { updateUser, createUser, getUserById } = require('~/models');
const { getAppConfig } = require('~/server/services/Config');
/**
* Updates the avatar URL of an existing user. If the user's avatar URL does not include the query parameter
@@ -11,15 +12,14 @@ const { updateUser, createUser, getUserById } = require('~/models');
*
* @param {IUser} oldUser - The existing user object that needs to be updated.
* @param {string} avatarUrl - The new avatar URL to be set for the user.
* @param {AppConfig} appConfig - The application configuration object.
*
* @returns {Promise<void>}
* The function updates the user's avatar and saves the user object. It does not return any value.
*
* @throws {Error} Throws an error if there's an issue saving the updated user object.
*/
const handleExistingUser = async (oldUser, avatarUrl, appConfig) => {
const fileStrategy = appConfig?.fileStrategy ?? process.env.CDN_PROVIDER;
const handleExistingUser = async (oldUser, avatarUrl) => {
const fileStrategy = process.env.CDN_PROVIDER;
const isLocal = fileStrategy === FileSources.local;
let updatedAvatar = false;
@@ -56,7 +56,6 @@ const handleExistingUser = async (oldUser, avatarUrl, appConfig) => {
* @param {string} params.providerId - The provider-specific ID of the user.
* @param {string} params.username - The username of the new user.
* @param {string} params.name - The name of the new user.
* @param {AppConfig} appConfig - The application configuration object.
* @param {boolean} [params.emailVerified=false] - Optional. Indicates whether the user's email is verified. Defaults to false.
*
* @returns {Promise<User>}
@@ -72,7 +71,6 @@ const createSocialUser = async ({
providerId,
username,
name,
appConfig,
emailVerified,
}) => {
const update = {
@@ -85,9 +83,10 @@ const createSocialUser = async ({
emailVerified,
};
const appConfig = await getAppConfig();
const balanceConfig = getBalanceConfig(appConfig);
const newUserId = await createUser(update, balanceConfig);
const fileStrategy = appConfig?.fileStrategy ?? process.env.CDN_PROVIDER;
const fileStrategy = process.env.CDN_PROVIDER;
const isLocal = fileStrategy === FileSources.local;
if (!isLocal) {

View File

@@ -7,7 +7,6 @@ const { ErrorTypes } = require('librechat-data-provider');
const { hashToken, logger } = require('@librechat/data-schemas');
const { Strategy: SamlStrategy } = require('@node-saml/passport-saml');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { findUser, createUser, updateUser } = require('~/models');
const { getAppConfig } = require('~/server/services/Config');
const paths = require('~/config/paths');
@@ -193,25 +192,16 @@ async function setupSaml() {
logger.info(`[samlStrategy] SAML authentication received for NameID: ${profile.nameID}`);
logger.debug('[samlStrategy] SAML profile:', profile);
const userEmail = getEmail(profile) || '';
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) {
logger.error(
`[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`,
);
return done(null, false, { message: 'Email domain not allowed' });
}
let user = await findUser({ samlId: profile.nameID });
logger.info(
`[samlStrategy] User ${user ? 'found' : 'not found'} with SAML ID: ${profile.nameID}`,
);
if (!user) {
user = await findUser({ email: userEmail });
const email = getEmail(profile) || '';
user = await findUser({ email });
logger.info(
`[samlStrategy] User ${user ? 'found' : 'not found'} with email: ${userEmail}`,
`[samlStrategy] User ${user ? 'found' : 'not found'} with email: ${profile.email}`,
);
}
@@ -235,10 +225,11 @@ async function setupSaml() {
provider: 'saml',
samlId: profile.nameID,
username,
email: userEmail,
email: getEmail(profile) || '',
emailVerified: true,
name: fullName,
};
const appConfig = await getAppConfig();
const balanceConfig = getBalanceConfig(appConfig);
user = await createUser(user, balanceConfig, true, true);
} else {
@@ -259,9 +250,7 @@ async function setupSaml() {
fileName = profile.nameID + '.png';
}
const { saveBuffer } = getStrategyFunctions(
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
);
const { saveBuffer } = getStrategyFunctions(process.env.CDN_PROVIDER);
const imagePath = await saveBuffer({
fileName,
userId: user._id.toString(),

View File

@@ -2,8 +2,6 @@ const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { ErrorTypes } = require('librechat-data-provider');
const { createSocialUser, handleExistingUser } = require('./process');
const { isEmailDomainAllowed } = require('~/server/services/domains');
const { getAppConfig } = require('~/server/services/Config');
const { findUser } = require('~/models');
const socialLogin =
@@ -14,22 +12,11 @@ const socialLogin =
profile,
});
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
logger.error(
`[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`,
);
const error = new Error(ErrorTypes.AUTH_FAILED);
error.code = ErrorTypes.AUTH_FAILED;
error.message = 'Email domain not allowed';
return cb(error);
}
const existingUser = await findUser({ email: email.trim() });
const ALLOW_SOCIAL_REGISTRATION = isEnabled(process.env.ALLOW_SOCIAL_REGISTRATION);
if (existingUser?.provider === provider) {
await handleExistingUser(existingUser, avatarUrl, appConfig);
await handleExistingUser(existingUser, avatarUrl);
return cb(null, existingUser);
} else if (existingUser) {
logger.info(
@@ -41,29 +28,19 @@ const socialLogin =
return cb(error);
}
const ALLOW_SOCIAL_REGISTRATION = isEnabled(process.env.ALLOW_SOCIAL_REGISTRATION);
if (!ALLOW_SOCIAL_REGISTRATION) {
logger.error(
`[${provider}Login] Registration blocked - social registration is disabled [Email: ${email}]`,
);
const error = new Error(ErrorTypes.AUTH_FAILED);
error.code = ErrorTypes.AUTH_FAILED;
error.message = 'Social registration is disabled';
return cb(error);
if (ALLOW_SOCIAL_REGISTRATION) {
const newUser = await createSocialUser({
email,
avatarUrl,
provider,
providerKey: `${provider}Id`,
providerId: id,
username,
name,
emailVerified,
});
return cb(null, newUser);
}
const newUser = await createSocialUser({
email,
avatarUrl,
provider,
providerKey: `${provider}Id`,
providerId: id,
username,
name,
emailVerified,
appConfig,
});
return cb(null, newUser);
} catch (err) {
logger.error(`[${provider}Login]`, err);
return cb(err);
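The two sides of this hunk express the same registration gate in different shapes: one rejects early when social registration is disabled, the other nests user creation inside an if-enabled branch. A condensed sketch of the guard-clause form, with placeholder names and messages (this is an illustration of the control flow, not the strategy's actual signature):
// Guard-clause form: handle the existing user, reject early when registration
// is disabled, then fall through to a single creation path.
async function finishSocialLogin({ allowRegistration, existingUser, createUser, cb }) {
  if (existingUser) {
    return cb(null, existingUser);
  }
  if (!allowRegistration) {
    const error = new Error('AUTH_FAILED'); // placeholder for ErrorTypes.AUTH_FAILED
    error.code = 'AUTH_FAILED';
    error.message = 'Social registration is disabled';
    return cb(error);
  }
  const newUser = await createUser();
  return cb(null, newUser);
}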

View File

@@ -1,7 +1,7 @@
const axios = require('axios');
const deriveBaseURL = require('./deriveBaseURL');
jest.mock('@librechat/api', () => {
const originalUtils = jest.requireActual('@librechat/api');
jest.mock('~/utils', () => {
const originalUtils = jest.requireActual('~/utils');
return {
...originalUtils,
processModelData: jest.fn((...args) => {

View File

@@ -1,3 +1,4 @@
const tokenHelpers = require('./tokens');
const deriveBaseURL = require('./deriveBaseURL');
const extractBaseURL = require('./extractBaseURL');
const findMessageContent = require('./findMessageContent');
@@ -5,5 +6,6 @@ const findMessageContent = require('./findMessageContent');
module.exports = {
deriveBaseURL,
extractBaseURL,
...tokenHelpers,
findMessageContent,
};

View File

@@ -1,23 +1,5 @@
import z from 'zod';
import { EModelEndpoint } from 'librechat-data-provider';
/** Configuration object mapping model keys to their respective prompt, completion rates, and context limit
*
* Note: the [key: string]: unknown is not in the original JSDoc typedef in /api/typedefs.js, but I've included it since
* getModelMaxOutputTokens calls getModelTokenValue with a key of 'output', which was not in the original JSDoc typedef,
* but would be referenced in a TokenConfig in the if(matchedPattern) portion of getModelTokenValue.
* So in order to preserve functionality for that case and any others which might reference an additional key I'm unaware of,
* I've included it here until the interface can be typed more tightly.
*/
export interface TokenConfig {
prompt: number;
completion: number;
context: number;
[key: string]: unknown;
}
/** An endpoint's config object mapping model keys to their respective prompt, completion rates, and context limit */
export type EndpointTokenConfig = Record<string, TokenConfig>;
const z = require('zod');
const { EModelEndpoint } = require('librechat-data-provider');
const openAIModels = {
'o4-mini': 200000,
@@ -260,7 +242,7 @@ const aggregateModels = {
'gpt-oss-120b': 131000,
};
export const maxTokensMap = {
const maxTokensMap = {
[EModelEndpoint.azureOpenAI]: openAIModels,
[EModelEndpoint.openAI]: aggregateModels,
[EModelEndpoint.agents]: aggregateModels,
@@ -270,7 +252,7 @@ export const maxTokensMap = {
[EModelEndpoint.bedrock]: bedrockModels,
};
export const modelMaxOutputs = {
const modelMaxOutputs = {
o1: 32268, // -500 from max: 32,768
'o1-mini': 65136, // -500 from max: 65,536
'o1-preview': 32268, // -500 from max: 32,768
@@ -279,7 +261,7 @@ export const modelMaxOutputs = {
'gpt-5-nano': 128000,
'gpt-oss-20b': 131000,
'gpt-oss-120b': 131000,
system_default: 32000,
system_default: 1024,
};
/** Outputs from https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-names */
@@ -295,7 +277,7 @@ const anthropicMaxOutputs = {
'claude-3-7-sonnet': 128000,
};
export const maxOutputTokensMap = {
const maxOutputTokensMap = {
[EModelEndpoint.anthropic]: anthropicMaxOutputs,
[EModelEndpoint.azureOpenAI]: modelMaxOutputs,
[EModelEndpoint.openAI]: modelMaxOutputs,
@@ -305,13 +287,10 @@ export const maxOutputTokensMap = {
/**
* Finds the first matching pattern in the tokens map.
* @param {string} modelName
* @param {Record<string, number> | EndpointTokenConfig} tokensMap
* @param {Record<string, number>} tokensMap
* @returns {string|null}
*/
export function findMatchingPattern(
modelName: string,
tokensMap: Record<string, number> | EndpointTokenConfig,
): string | null {
function findMatchingPattern(modelName, tokensMap) {
const keys = Object.keys(tokensMap);
for (let i = keys.length - 1; i >= 0; i--) {
const modelKey = keys[i];
@@ -326,79 +305,57 @@ export function findMatchingPattern(
/**
* Retrieves a token value for a given model name from a tokens map.
*
* @param modelName - The name of the model to look up.
* @param tokensMap - The map of model names to token values.
* @param [key='context'] - The key to look up in the tokens map.
* @returns The token value for the given model or undefined if no match is found.
* @param {string} modelName - The name of the model to look up.
* @param {EndpointTokenConfig | Record<string, number>} tokensMap - The map of model names to token values.
* @param {string} [key='context'] - The key to look up in the tokens map.
* @returns {number|undefined} The token value for the given model or undefined if no match is found.
*/
export function getModelTokenValue(
modelName: string,
tokensMap?: EndpointTokenConfig | Record<string, number>,
key = 'context' as keyof TokenConfig,
): number | undefined {
function getModelTokenValue(modelName, tokensMap, key = 'context') {
if (typeof modelName !== 'string' || !tokensMap) {
return undefined;
}
const value = tokensMap[modelName];
if (typeof value === 'number') {
return value;
if (tokensMap[modelName]?.context) {
return tokensMap[modelName].context;
}
if (value?.context) {
return value.context;
if (tokensMap[modelName]) {
return tokensMap[modelName];
}
const matchedPattern = findMatchingPattern(modelName, tokensMap);
if (matchedPattern) {
const result = tokensMap[matchedPattern];
if (typeof result === 'number') {
return result;
}
const tokenValue = result?.[key];
if (typeof tokenValue === 'number') {
return tokenValue;
}
return tokensMap.system_default as number | undefined;
return result?.[key] ?? result ?? tokensMap.system_default;
}
return tokensMap.system_default as number | undefined;
return tokensMap.system_default;
}
/**
* Retrieves the maximum tokens for a given model name.
*
* @param modelName - The name of the model to look up.
* @param endpoint - The endpoint (default is 'openAI').
* @param [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
* @returns The maximum tokens for the given model or undefined if no match is found.
* @param {string} modelName - The name of the model to look up.
* @param {string} endpoint - The endpoint (default is 'openAI').
* @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
* @returns {number|undefined} The maximum tokens for the given model or undefined if no match is found.
*/
export function getModelMaxTokens(
modelName: string,
endpoint = EModelEndpoint.openAI,
endpointTokenConfig?: EndpointTokenConfig,
): number | undefined {
const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint as keyof typeof maxTokensMap];
function getModelMaxTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
const tokensMap = endpointTokenConfig ?? maxTokensMap[endpoint];
return getModelTokenValue(modelName, tokensMap);
}
/**
* Retrieves the maximum output tokens for a given model name.
*
* @param modelName - The name of the model to look up.
* @param endpoint - The endpoint (default is 'openAI').
* @param [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
* @returns The maximum output tokens for the given model or undefined if no match is found.
* @param {string} modelName - The name of the model to look up.
* @param {string} endpoint - The endpoint (default is 'openAI').
* @param {EndpointTokenConfig} [endpointTokenConfig] - Token Config for current endpoint to use for max tokens lookup
* @returns {number|undefined} The maximum output tokens for the given model or undefined if no match is found.
*/
export function getModelMaxOutputTokens(
modelName: string,
endpoint = EModelEndpoint.openAI,
endpointTokenConfig?: EndpointTokenConfig,
): number | undefined {
const tokensMap =
endpointTokenConfig ?? maxOutputTokensMap[endpoint as keyof typeof maxOutputTokensMap];
function getModelMaxOutputTokens(modelName, endpoint = EModelEndpoint.openAI, endpointTokenConfig) {
const tokensMap = endpointTokenConfig ?? maxOutputTokensMap[endpoint];
return getModelTokenValue(modelName, tokensMap, 'output');
}
@@ -406,24 +363,21 @@ export function getModelMaxOutputTokens(
* Retrieves the model name key for a given model name input. If the exact model name isn't found,
* it searches for partial matches within the model name, checking keys in reverse order.
*
* @param modelName - The name of the model to look up.
* @param endpoint - The endpoint (default is 'openAI').
* @returns The model name key for the given model; returns input if no match is found and is string.
* @param {string} modelName - The name of the model to look up.
* @param {string} endpoint - The endpoint (default is 'openAI').
* @returns {string|undefined} The model name key for the given model; returns input if no match is found and is string.
*
* @example
* matchModelName('gpt-4-32k-0613'); // Returns 'gpt-4-32k-0613'
* matchModelName('gpt-4-32k-unknown'); // Returns 'gpt-4-32k'
* matchModelName('unknown-model'); // Returns undefined
*/
export function matchModelName(
modelName: string,
endpoint = EModelEndpoint.openAI,
): string | undefined {
function matchModelName(modelName, endpoint = EModelEndpoint.openAI) {
if (typeof modelName !== 'string') {
return undefined;
}
const tokensMap: Record<string, number> = maxTokensMap[endpoint as keyof typeof maxTokensMap];
const tokensMap = maxTokensMap[endpoint];
if (!tokensMap) {
return modelName;
}
@@ -436,7 +390,7 @@ export function matchModelName(
return matchedPattern || modelName;
}
export const modelSchema = z.object({
const modelSchema = z.object({
id: z.string(),
pricing: z.object({
prompt: z.string(),
@@ -445,7 +399,7 @@ export const modelSchema = z.object({
context_length: z.number(),
});
export const inputSchema = z.object({
const inputSchema = z.object({
data: z.array(modelSchema),
});
@@ -454,7 +408,7 @@ export const inputSchema = z.object({
* @param {{ data: Array<z.infer<typeof modelSchema>> }} input The input object containing base URL and data fetched from the API.
* @returns {EndpointTokenConfig} The processed model data.
*/
export function processModelData(input: z.infer<typeof inputSchema>): EndpointTokenConfig {
function processModelData(input) {
const validationResult = inputSchema.safeParse(input);
if (!validationResult.success) {
throw new Error('Invalid input data');
@@ -462,7 +416,7 @@ export function processModelData(input: z.infer<typeof inputSchema>): EndpointTo
const { data } = validationResult.data;
/** @type {EndpointTokenConfig} */
const tokenConfig: EndpointTokenConfig = {};
const tokenConfig = {};
for (const model of data) {
const modelKey = model.id;
@@ -485,7 +439,7 @@ export function processModelData(input: z.infer<typeof inputSchema>): EndpointTo
return tokenConfig;
}
export const tiktokenModels = new Set([
const tiktokenModels = new Set([
'text-davinci-003',
'text-davinci-002',
'text-davinci-001',
@@ -523,3 +477,17 @@ export const tiktokenModels = new Set([
'gpt-3.5-turbo',
'gpt-3.5-turbo-0301',
]);
module.exports = {
inputSchema,
modelSchema,
maxTokensMap,
tiktokenModels,
maxOutputTokensMap,
matchModelName,
processModelData,
getModelMaxTokens,
getModelTokenValue,
findMatchingPattern,
getModelMaxOutputTokens,
};
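For reference, a minimal usage sketch of the token helpers defined above, based on the maps and JSDoc examples shown in this hunk. The require path assumes the CommonJS layout on this side of the diff, and the quoted fallback values differ between the two revisions as noted in the comments:
// Minimal sketch of the token lookup helpers (CommonJS layout shown above).
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('./tokens');

// Exact key lookup: aggregateModels defines 'gpt-oss-120b' with a 131000-token context.
console.log(getModelMaxTokens('gpt-oss-120b')); // 131000

// Partial matching: unknown suffixes resolve to the closest configured pattern,
// per the JSDoc example above.
console.log(matchModelName('gpt-4-32k-unknown')); // 'gpt-4-32k'

// Output lookup reads the 'output' key and falls back to system_default when
// nothing matches (1024 on one side of this diff, 32000 on the other).
console.log(getModelMaxOutputTokens('entirely-unknown-model'));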

View File

@@ -1,12 +1,12 @@
const { EModelEndpoint } = require('librechat-data-provider');
const {
maxTokensMap,
matchModelName,
processModelData,
getModelMaxTokens,
maxOutputTokensMap,
findMatchingPattern,
} = require('@librechat/api');
getModelMaxTokens,
processModelData,
matchModelName,
maxTokensMap,
} = require('./tokens');
describe('getModelMaxTokens', () => {
test('should return correct tokens for exact match', () => {
@@ -394,7 +394,7 @@ describe('getModelMaxTokens', () => {
});
test('should return correct max output tokens for GPT-5 models', () => {
const { getModelMaxOutputTokens } = require('@librechat/api');
const { getModelMaxOutputTokens } = require('./tokens');
['gpt-5', 'gpt-5-mini', 'gpt-5-nano'].forEach((model) => {
expect(getModelMaxOutputTokens(model)).toBe(maxOutputTokensMap[EModelEndpoint.openAI][model]);
expect(getModelMaxOutputTokens(model, EModelEndpoint.openAI)).toBe(
@@ -407,7 +407,7 @@ describe('getModelMaxTokens', () => {
});
test('should return correct max output tokens for GPT-OSS models', () => {
const { getModelMaxOutputTokens } = require('@librechat/api');
const { getModelMaxOutputTokens } = require('./tokens');
['gpt-oss-20b', 'gpt-oss-120b'].forEach((model) => {
expect(getModelMaxOutputTokens(model)).toBe(maxOutputTokensMap[EModelEndpoint.openAI][model]);
expect(getModelMaxOutputTokens(model, EModelEndpoint.openAI)).toBe(

BIN bun.lockb (binary file not shown)

View File

@@ -1,4 +1,3 @@
/** v0.8.0-rc4 */
module.exports = {
roots: ['<rootDir>/src'],
testEnvironment: 'jsdom',
@@ -29,8 +28,7 @@ module.exports = {
'jest-file-loader',
'^test/(.*)$': '<rootDir>/test/$1',
'^~/(.*)$': '<rootDir>/src/$1',
'^librechat-data-provider/react-query$':
'<rootDir>/../node_modules/librechat-data-provider/src/react-query',
'^librechat-data-provider/react-query$': '<rootDir>/../node_modules/librechat-data-provider/src/react-query',
},
restoreMocks: true,
testResultsProcessor: 'jest-junit',

View File

@@ -1,6 +1,6 @@
{
"name": "@librechat/frontend",
"version": "v0.8.0-rc4",
"version": "v0.8.0-rc3",
"description": "",
"type": "module",
"scripts": {
@@ -37,7 +37,6 @@
"@headlessui/react": "^2.1.2",
"@librechat/client": "*",
"@marsidev/react-turnstile": "^1.1.0",
"@mcp-ui/client": "^5.7.0",
"@radix-ui/react-accordion": "^1.1.2",
"@radix-ui/react-alert-dialog": "^1.0.2",
"@radix-ui/react-checkbox": "^1.0.3",
@@ -148,8 +147,8 @@
"tailwindcss": "^3.4.1",
"ts-jest": "^29.2.5",
"typescript": "^5.3.3",
"vite": "^6.3.6",
"vite-plugin-compression2": "^2.2.1",
"vite": "^6.3.4",
"vite-plugin-compression2": "^1.3.3",
"vite-plugin-node-polyfills": "^0.23.0",
"vite-plugin-pwa": "^0.21.2"
}

View File

@@ -1,14 +1,11 @@
import React, { createContext, useContext, useState, useMemo } from 'react';
import React, { createContext, useContext, useState } from 'react';
import { Constants, EModelEndpoint } from 'librechat-data-provider';
import type { MCP, Action, TPlugin, AgentToolType } from 'librechat-data-provider';
import type { AgentPanelContextType, MCPServerInfo } from '~/common';
import { useAvailableToolsQuery, useGetActionsQuery, useGetStartupConfig } from '~/data-provider';
import { useLocalize, useGetAgentsConfig, useMCPConnectionStatus } from '~/hooks';
import type { AgentPanelContextType } from '~/common';
import { useAvailableToolsQuery, useGetActionsQuery } from '~/data-provider';
import { useLocalize, useGetAgentsConfig } from '~/hooks';
import { Panel } from '~/common';
type GroupedToolType = AgentToolType & { tools?: AgentToolType[] };
type GroupedToolsRecord = Record<string, GroupedToolType>;
const AgentPanelContext = createContext<AgentPanelContextType | undefined>(undefined);
export function useAgentPanelContext() {
@@ -36,117 +33,67 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
enabled: !!agent_id,
});
const { data: startupConfig } = useGetStartupConfig();
const mcpServerNames = useMemo(
() => Object.keys(startupConfig?.mcpServers ?? {}),
[startupConfig],
);
const { connectionStatus } = useMCPConnectionStatus({
enabled: !!agent_id && mcpServerNames.length > 0,
});
const processedData = useMemo(() => {
if (!pluginTools) {
return {
tools: [],
groupedTools: {},
mcpServersMap: new Map<string, MCPServerInfo>(),
};
}
const tools: AgentToolType[] = [];
const groupedTools: GroupedToolsRecord = {};
const configuredServers = new Set(mcpServerNames);
const mcpServersMap = new Map<string, MCPServerInfo>();
for (const pluginTool of pluginTools) {
const tool: AgentToolType = {
tool_id: pluginTool.pluginKey,
metadata: pluginTool as TPlugin,
};
tools.push(tool);
const tools =
pluginTools?.map((tool) => ({
tool_id: tool.pluginKey,
metadata: tool as TPlugin,
agent_id: agent_id || '',
})) || [];
const groupedTools = tools?.reduce(
(acc, tool) => {
if (tool.tool_id.includes(Constants.mcp_delimiter)) {
const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter);
if (!mcpServersMap.has(serverName)) {
const metadata = {
name: serverName,
pluginKey: serverName,
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
icon: pluginTool.icon || '',
} as TPlugin;
mcpServersMap.set(serverName, {
serverName,
const groupKey = `${serverName.toLowerCase()}`;
if (!acc[groupKey]) {
acc[groupKey] = {
tool_id: groupKey,
metadata: {
name: `${serverName}`,
pluginKey: groupKey,
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
icon: tool.metadata.icon || '',
} as TPlugin,
agent_id: agent_id || '',
tools: [],
isConfigured: configuredServers.has(serverName),
isConnected: connectionStatus?.[serverName]?.connectionState === 'connected',
metadata,
});
};
}
mcpServersMap.get(serverName)!.tools.push(tool);
} else {
// Non-MCP tool
groupedTools[tool.tool_id] = {
acc[groupKey].tools?.push({
tool_id: tool.tool_id,
metadata: tool.metadata,
agent_id: agent_id || '',
});
} else {
acc[tool.tool_id] = {
tool_id: tool.tool_id,
metadata: tool.metadata,
agent_id: agent_id || '',
};
}
}
for (const mcpServerName of mcpServerNames) {
if (mcpServersMap.has(mcpServerName)) {
continue;
}
const metadata = {
icon: '',
name: mcpServerName,
pluginKey: mcpServerName,
description: `${localize('com_ui_tool_collection_prefix')} ${mcpServerName}`,
} as TPlugin;
mcpServersMap.set(mcpServerName, {
tools: [],
metadata,
isConfigured: true,
serverName: mcpServerName,
isConnected: connectionStatus?.[mcpServerName]?.connectionState === 'connected',
});
}
return {
tools,
groupedTools,
mcpServersMap,
};
}, [pluginTools, localize, mcpServerNames, connectionStatus]);
return acc;
},
{} as Record<string, AgentToolType & { tools?: AgentToolType[] }>,
);
const { agentsConfig, endpointsConfig } = useGetAgentsConfig();
const value: AgentPanelContextType = {
mcp,
mcps,
/** Query data for actions and tools */
tools,
action,
setMcp,
actions,
setMcps,
agent_id,
setAction,
pluginTools,
activePanel,
groupedTools,
agentsConfig,
startupConfig,
setActivePanel,
endpointsConfig,
setCurrentAgentId,
tools: processedData.tools,
groupedTools: processedData.groupedTools,
mcpServersMap: processedData.mcpServersMap,
};
return <AgentPanelContext.Provider value={value}>{children}</AgentPanelContext.Provider>;

View File

@@ -2,14 +2,7 @@ import React, { createContext, useContext, useEffect, useRef } from 'react';
import { useSetRecoilState } from 'recoil';
import { Tools, Constants, LocalStorageKeys, AgentCapabilities } from 'librechat-data-provider';
import type { TAgentsEndpoint } from 'librechat-data-provider';
import {
useMCPServerManager,
useSearchApiKeyForm,
useGetAgentsConfig,
useCodeApiKeyForm,
useToolToggle,
} from '~/hooks';
import { getTimestampedValue, setTimestamp } from '~/utils/timestamps';
import { useSearchApiKeyForm, useGetAgentsConfig, useCodeApiKeyForm, useToolToggle } from '~/hooks';
import { ephemeralAgentByConvoId } from '~/store';
interface BadgeRowContextType {
@@ -21,7 +14,6 @@ interface BadgeRowContextType {
codeInterpreter: ReturnType<typeof useToolToggle>;
codeApiKeyForm: ReturnType<typeof useCodeApiKeyForm>;
searchApiKeyForm: ReturnType<typeof useSearchApiKeyForm>;
mcpServerManager: ReturnType<typeof useMCPServerManager>;
}
const BadgeRowContext = createContext<BadgeRowContextType | undefined>(undefined);
@@ -45,11 +37,10 @@ export default function BadgeRowProvider({
isSubmitting,
conversationId,
}: BadgeRowProviderProps) {
const lastKeyRef = useRef<string>('');
const hasInitializedRef = useRef(false);
const lastKeyRef = useRef<string>('');
const { agentsConfig } = useGetAgentsConfig();
const key = conversationId ?? Constants.NEW_CONVO;
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(key));
/** Initialize ephemeralAgent from localStorage on mount and when conversation changes */
@@ -62,15 +53,16 @@ export default function BadgeRowProvider({
hasInitializedRef.current = true;
lastKeyRef.current = key;
// Load all localStorage values
const codeToggleKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${key}`;
const webSearchToggleKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${key}`;
const fileSearchToggleKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${key}`;
const artifactsToggleKey = `${LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_}${key}`;
const codeToggleValue = getTimestampedValue(codeToggleKey);
const webSearchToggleValue = getTimestampedValue(webSearchToggleKey);
const fileSearchToggleValue = getTimestampedValue(fileSearchToggleKey);
const artifactsToggleValue = getTimestampedValue(artifactsToggleKey);
const codeToggleValue = localStorage.getItem(codeToggleKey);
const webSearchToggleValue = localStorage.getItem(webSearchToggleKey);
const fileSearchToggleValue = localStorage.getItem(fileSearchToggleKey);
const artifactsToggleValue = localStorage.getItem(artifactsToggleKey);
const initialValues: Record<string, any> = {};
@@ -106,37 +98,15 @@ export default function BadgeRowProvider({
}
}
/**
* Always set values for all tools (use defaults if not in `localStorage`)
* If `ephemeralAgent` is `null`, create a new object with just our tool values
*/
const finalValues = {
// Always set values for all tools (use defaults if not in localStorage)
// If ephemeralAgent is null, create a new object with just our tool values
setEphemeralAgent((prev) => ({
...(prev || {}),
[Tools.execute_code]: initialValues[Tools.execute_code] ?? false,
[Tools.web_search]: initialValues[Tools.web_search] ?? false,
[Tools.file_search]: initialValues[Tools.file_search] ?? false,
[AgentCapabilities.artifacts]: initialValues[AgentCapabilities.artifacts] ?? false,
};
setEphemeralAgent((prev) => ({
...(prev || {}),
...finalValues,
}));
Object.entries(finalValues).forEach(([toolKey, value]) => {
if (value !== false) {
let storageKey = artifactsToggleKey;
if (toolKey === Tools.execute_code) {
storageKey = codeToggleKey;
} else if (toolKey === Tools.web_search) {
storageKey = webSearchToggleKey;
} else if (toolKey === Tools.file_search) {
storageKey = fileSearchToggleKey;
}
// Store the value and set timestamp for existing values
localStorage.setItem(storageKey, JSON.stringify(value));
setTimestamp(storageKey);
}
});
}
}, [key, isSubmitting, setEphemeralAgent]);
@@ -186,8 +156,6 @@ export default function BadgeRowProvider({
isAuthenticated: true,
});
const mcpServerManager = useMCPServerManager({ conversationId });
const value: BadgeRowContextType = {
webSearch,
artifacts,
@@ -197,7 +165,6 @@ export default function BadgeRowProvider({
codeApiKeyForm,
codeInterpreter,
searchApiKeyForm,
mcpServerManager,
};
return <BadgeRowContext.Provider value={value}>{children}</BadgeRowContext.Provider>;
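Both sides of this hunk initialize ephemeral tool toggles from per-conversation localStorage keys of the form `${LocalStorageKeys.LAST_*_TOGGLE_}${conversationId}`, defaulting to false when nothing is stored. A small browser-context sketch of that read; the key prefix and JSON parsing are inferred from the surrounding code, so treat them as assumptions:
// Read a per-conversation toggle, defaulting to false when the key is absent
// or unparsable. 'LAST_CODE_TOGGLE_' mirrors the LocalStorageKeys usage above.
function readToggle(prefix, conversationId) {
  const raw = localStorage.getItem(`${prefix}${conversationId}`);
  if (raw == null) {
    return false;
  }
  try {
    return JSON.parse(raw) === true;
  } catch {
    return false;
  }
}

const codeInterpreterOn = readToggle('LAST_CODE_TOGGLE_', 'new');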

View File

@@ -1,31 +0,0 @@
import React, { createContext, useContext, useMemo } from 'react';
import { Constants } from 'librechat-data-provider';
import { useChatContext } from './ChatContext';
interface MCPPanelContextValue {
conversationId: string;
}
const MCPPanelContext = createContext<MCPPanelContextValue | undefined>(undefined);
export function MCPPanelProvider({ children }: { children: React.ReactNode }) {
const { conversation } = useChatContext();
/** Context value only created when conversationId changes */
const contextValue = useMemo<MCPPanelContextValue>(
() => ({
conversationId: conversation?.conversationId ?? Constants.NEW_CONVO,
}),
[conversation?.conversationId],
);
return <MCPPanelContext.Provider value={contextValue}>{children}</MCPPanelContext.Provider>;
}
export function useMCPPanelContext() {
const context = useContext(MCPPanelContext);
if (!context) {
throw new Error('useMCPPanelContext must be used within MCPPanelProvider');
}
return context;
}

View File

@@ -23,7 +23,6 @@ export * from './SetConvoContext';
export * from './SearchContext';
export * from './BadgeRowContext';
export * from './SidePanelContext';
export * from './MCPPanelContext';
export * from './ArtifactsContext';
export * from './PromptGroupsContext';
export { default as BadgeRowProvider } from './BadgeRowContext';

View File

@@ -8,11 +8,6 @@ import type * as t from 'librechat-data-provider';
import type { LucideIcon } from 'lucide-react';
import type { TranslationKeys } from '~/hooks';
export interface ConfigFieldDetail {
title: string;
description: string;
}
export type CodeBarProps = {
lang: string;
error?: boolean;
@@ -216,14 +211,6 @@ export type AgentPanelProps = {
agentsConfig?: t.TAgentsEndpoint | null;
};
export interface MCPServerInfo {
serverName: string;
tools: t.AgentToolType[];
isConfigured: boolean;
isConnected: boolean;
metadata: t.TPlugin;
}
export type AgentPanelContextType = {
action?: t.Action;
actions?: t.Action[];
@@ -233,17 +220,13 @@ export type AgentPanelContextType = {
setMcp: React.Dispatch<React.SetStateAction<t.MCP | undefined>>;
setMcps: React.Dispatch<React.SetStateAction<t.MCP[] | undefined>>;
groupedTools: Record<string, t.AgentToolType & { tools?: t.AgentToolType[] }>;
activePanel?: string;
tools: t.AgentToolType[];
pluginTools?: t.TPlugin[];
activePanel?: string;
setActivePanel: React.Dispatch<React.SetStateAction<Panel>>;
setCurrentAgentId: React.Dispatch<React.SetStateAction<string | undefined>>;
agent_id?: string;
startupConfig?: t.TStartupConfig | null;
agentsConfig?: t.TAgentsEndpoint | null;
endpointsConfig?: t.TEndpointsConfig | null;
/** Pre-computed MCP server information indexed by server key */
mcpServersMap: Map<string, MCPServerInfo>;
};
export type AgentModelPanelProps = {

View File

@@ -13,7 +13,7 @@ interface AgentGridProps {
category: string; // Currently selected category
searchQuery: string; // Current search query
onSelectAgent: (agent: t.Agent) => void; // Callback when agent is selected
scrollElementRef?: React.RefObject<HTMLElement>; // Parent scroll container ref for infinite scroll
scrollElement?: HTMLElement | null; // Parent scroll container for infinite scroll
}
/**
@@ -23,7 +23,7 @@ const AgentGrid: React.FC<AgentGridProps> = ({
category,
searchQuery,
onSelectAgent,
scrollElementRef,
scrollElement,
}) => {
const localize = useLocalize();
@@ -87,7 +87,7 @@ const AgentGrid: React.FC<AgentGridProps> = ({
// Set up infinite scroll
const { setScrollElement } = useInfiniteScroll({
hasNextPage,
isLoading: isFetching || isFetchingNextPage,
isFetchingNextPage,
fetchNextPage: () => {
if (hasNextPage && !isFetching) {
fetchNextPage();
@@ -99,11 +99,10 @@ const AgentGrid: React.FC<AgentGridProps> = ({
// Connect the scroll element when it's provided
useEffect(() => {
const scrollElement = scrollElementRef?.current;
if (scrollElement) {
setScrollElement(scrollElement);
}
}, [scrollElementRef, setScrollElement]);
}, [scrollElement, setScrollElement]);
/**
* Get category display name from API data or use fallback

View File

@@ -11,6 +11,7 @@ import { useDocumentTitle, useHasAccess, useLocalize, TranslationKeys } from '~/
import { useGetEndpointsQuery, useGetAgentCategoriesQuery } from '~/data-provider';
import MarketplaceAdminSettings from './MarketplaceAdminSettings';
import { SidePanelProvider, useChatContext } from '~/Providers';
import { MarketplaceProvider } from './MarketplaceContext';
import { SidePanelGroup } from '~/components/SidePanel';
import { OpenSidebar } from '~/components/Chat/Menus';
import CategoryTabs from './CategoryTabs';
@@ -197,21 +198,21 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
*/
const handleSearch = (query: string) => {
const newParams = new URLSearchParams(searchParams);
const currentCategory = displayCategory;
if (query.trim()) {
newParams.set('q', query.trim());
// Switch to "all" category when starting a new search
navigate(`/agents/all?${newParams.toString()}`);
} else {
newParams.delete('q');
}
// Always preserve current category when searching or clearing search
if (currentCategory === 'promoted') {
navigate(`/agents${newParams.toString() ? `?${newParams.toString()}` : ''}`);
} else {
navigate(
`/agents/${currentCategory}${newParams.toString() ? `?${newParams.toString()}` : ''}`,
);
// Preserve current category when clearing search
const currentCategory = displayCategory;
if (currentCategory === 'promoted') {
navigate(`/agents${newParams.toString() ? `?${newParams.toString()}` : ''}`);
} else {
navigate(
`/agents/${currentCategory}${newParams.toString() ? `?${newParams.toString()}` : ''}`,
);
}
}
};
@@ -271,176 +272,100 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
}
return (
<div className={`relative flex w-full grow overflow-hidden bg-presentation ${className}`}>
<SidePanelProvider>
<SidePanelGroup
defaultLayout={defaultLayout}
fullPanelCollapse={fullCollapse}
defaultCollapsed={defaultCollapsed}
>
<main className="flex h-full flex-col overflow-hidden" role="main">
{/* Scrollable container */}
<div
ref={scrollContainerRef}
className="scrollbar-gutter-stable relative flex h-full flex-col overflow-y-auto overflow-x-hidden"
>
{/* Simplified header for agents marketplace - only show nav controls when needed */}
{!isSmallScreen && (
<div className="sticky top-0 z-20 flex items-center justify-between bg-surface-secondary p-2 font-semibold text-text-primary md:h-14">
<div className="mx-1 flex items-center gap-2">
{!navVisible ? (
<>
<OpenSidebar setNavVisible={setNavVisible} />
<TooltipAnchor
description={localize('com_ui_new_chat')}
render={
<Button
size="icon"
variant="outline"
data-testid="agents-new-chat-button"
aria-label={localize('com_ui_new_chat')}
className="rounded-xl border border-border-light bg-surface-secondary p-2 hover:bg-surface-hover max-md:hidden"
onClick={handleNewChat}
>
<NewChatIcon />
</Button>
}
/>
</>
) : (
// Invisible placeholder to maintain height
<div className="h-10 w-10" />
)}
</div>
</div>
)}
{/* Hero Section - scrolls away */}
{!isSmallScreen && (
<div className="container mx-auto max-w-4xl">
<div className={cn('mb-8 text-center', 'mt-12')}>
<h1 className="mb-3 text-3xl font-bold tracking-tight text-text-primary md:text-5xl">
{localize('com_agents_marketplace')}
</h1>
<p className="mx-auto mb-6 max-w-2xl text-lg text-text-secondary">
{localize('com_agents_marketplace_subtitle')}
</p>
</div>
</div>
)}
{/* Sticky wrapper for search bar and categories */}
<MarketplaceProvider>
<SidePanelProvider>
<SidePanelGroup
defaultLayout={defaultLayout}
fullPanelCollapse={fullCollapse}
defaultCollapsed={defaultCollapsed}
>
<main className="flex h-full flex-col overflow-hidden" role="main">
{/* Scrollable container */}
<div
className={cn(
'sticky z-10 bg-presentation pb-4',
isSmallScreen ? 'top-0' : 'top-14',
)}
ref={scrollContainerRef}
className="scrollbar-gutter-stable relative flex h-full flex-col overflow-y-auto overflow-x-hidden"
>
<div className="container mx-auto max-w-4xl px-4">
{/* Search bar */}
<div className="mx-auto flex max-w-2xl gap-2 pb-6">
<SearchBar value={searchQuery} onSearch={handleSearch} />
{/* TODO: Remove this once we have a better way to handle admin settings */}
{/* Admin Settings */}
<MarketplaceAdminSettings />
{/* Simplified header for agents marketplace - only show nav controls when needed */}
{!isSmallScreen && (
<div className="sticky top-0 z-20 flex items-center justify-between bg-surface-secondary p-2 font-semibold text-text-primary md:h-14">
<div className="mx-1 flex items-center gap-2">
{!navVisible ? (
<>
<OpenSidebar setNavVisible={setNavVisible} />
<TooltipAnchor
description={localize('com_ui_new_chat')}
render={
<Button
size="icon"
variant="outline"
data-testid="agents-new-chat-button"
aria-label={localize('com_ui_new_chat')}
className="rounded-xl border border-border-light bg-surface-secondary p-2 hover:bg-surface-hover max-md:hidden"
onClick={handleNewChat}
>
<NewChatIcon />
</Button>
}
/>
</>
) : (
// Invisible placeholder to maintain height
<div className="h-10 w-10" />
)}
</div>
</div>
)}
{/* Hero Section - scrolls away */}
{!isSmallScreen && (
<div className="container mx-auto max-w-4xl">
<div className={cn('mb-8 text-center', 'mt-12')}>
<h1 className="mb-3 text-3xl font-bold tracking-tight text-text-primary md:text-5xl">
{localize('com_agents_marketplace')}
</h1>
<p className="mx-auto mb-6 max-w-2xl text-lg text-text-secondary">
{localize('com_agents_marketplace_subtitle')}
</p>
</div>
</div>
)}
{/* Sticky wrapper for search bar and categories */}
<div
className={cn(
'sticky z-10 bg-presentation pb-4',
isSmallScreen ? 'top-0' : 'top-14',
)}
>
<div className="container mx-auto max-w-4xl px-4">
{/* Search bar */}
<div className="mx-auto flex max-w-2xl gap-2 pb-6">
<SearchBar value={searchQuery} onSearch={handleSearch} />
{/* TODO: Remove this once we have a better way to handle admin settings */}
{/* Admin Settings */}
<MarketplaceAdminSettings />
</div>
{/* Category tabs */}
<CategoryTabs
categories={categoriesQuery.data || []}
activeTab={displayCategory}
isLoading={categoriesQuery.isLoading}
onChange={handleTabChange}
/>
</div>
</div>
{/* Scrollable content area */}
<div className="container mx-auto max-w-4xl px-4 pb-8">
{/* Two-pane animated container wrapping category header + grid */}
<div className="relative overflow-hidden">
{/* Current content pane */}
<div
className={cn(
isTransitioning &&
(animationDirection === 'right'
? 'motion-safe:animate-slide-out-left'
: 'motion-safe:animate-slide-out-right'),
)}
key={`pane-current-${displayCategory}`}
>
{/* Category header - only show when not searching */}
{!searchQuery && (
<div className="mb-6 mt-6">
{(() => {
// Get category data for display
const getCategoryData = () => {
if (displayCategory === 'promoted') {
return {
name: localize('com_agents_top_picks'),
description: localize('com_agents_recommended'),
};
}
if (displayCategory === 'all') {
return {
name: localize('com_agents_all'),
description: localize('com_agents_all_description'),
};
}
// Find the category in the API data
const categoryData = categoriesQuery.data?.find(
(cat) => cat.value === displayCategory,
);
if (categoryData) {
return {
name: categoryData.label?.startsWith('com_')
? localize(categoryData.label as TranslationKeys)
: categoryData.label,
description: categoryData.description?.startsWith('com_')
? localize(categoryData.description as TranslationKeys)
: categoryData.description || '',
};
}
// Fallback for unknown categories
return {
name:
displayCategory.charAt(0).toUpperCase() + displayCategory.slice(1),
description: '',
};
};
const { name, description } = getCategoryData();
return (
<div className="text-left">
<h2 className="text-2xl font-bold text-text-primary">{name}</h2>
{description && (
<p className="mt-2 text-text-secondary">{description}</p>
)}
</div>
);
})()}
</div>
)}
{/* Agent grid */}
<AgentGrid
key={`grid-${displayCategory}`}
category={displayCategory}
searchQuery={searchQuery}
onSelectAgent={handleAgentSelect}
scrollElementRef={scrollContainerRef}
{/* Category tabs */}
<CategoryTabs
categories={categoriesQuery.data || []}
activeTab={displayCategory}
isLoading={categoriesQuery.isLoading}
onChange={handleTabChange}
/>
</div>
{/* Next content pane, only during transition */}
{isTransitioning && nextCategory && (
</div>
{/* Scrollable content area */}
<div className="container mx-auto max-w-4xl px-4 pb-8">
{/* Two-pane animated container wrapping category header + grid */}
<div className="relative overflow-hidden">
{/* Current content pane */}
<div
className={cn(
'absolute inset-0',
animationDirection === 'right'
? 'motion-safe:animate-slide-in-right'
: 'motion-safe:animate-slide-in-left',
isTransitioning &&
(animationDirection === 'right'
? 'motion-safe:animate-slide-out-left'
: 'motion-safe:animate-slide-out-right'),
)}
key={`pane-next-${nextCategory}-${animationDirection}`}
key={`pane-current-${displayCategory}`}
>
{/* Category header - only show when not searching */}
{!searchQuery && (
@@ -448,13 +373,13 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
{(() => {
// Get category data for display
const getCategoryData = () => {
if (nextCategory === 'promoted') {
if (displayCategory === 'promoted') {
return {
name: localize('com_agents_top_picks'),
description: localize('com_agents_recommended'),
};
}
if (nextCategory === 'all') {
if (displayCategory === 'all') {
return {
name: localize('com_agents_all'),
description: localize('com_agents_all_description'),
@@ -463,7 +388,7 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
// Find the category in the API data
const categoryData = categoriesQuery.data?.find(
(cat) => cat.value === nextCategory,
(cat) => cat.value === displayCategory,
);
if (categoryData) {
return {
@@ -471,9 +396,7 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
? localize(categoryData.label as TranslationKeys)
: categoryData.label,
description: categoryData.description?.startsWith('com_')
? localize(
categoryData.description as Parameters<typeof localize>[0],
)
? localize(categoryData.description as TranslationKeys)
: categoryData.description || '',
};
}
@@ -481,8 +404,8 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
// Fallback for unknown categories
return {
name:
(nextCategory || '').charAt(0).toUpperCase() +
(nextCategory || '').slice(1),
displayCategory.charAt(0).toUpperCase() +
displayCategory.slice(1),
description: '',
};
};
@@ -503,30 +426,113 @@ const AgentMarketplace: React.FC<AgentMarketplaceProps> = ({ className = '' }) =
{/* Agent grid */}
<AgentGrid
key={`grid-${nextCategory}`}
category={nextCategory}
key={`grid-${displayCategory}`}
category={displayCategory}
searchQuery={searchQuery}
onSelectAgent={handleAgentSelect}
scrollElementRef={scrollContainerRef}
scrollElement={scrollContainerRef.current}
/>
</div>
)}
{/* Note: Using Tailwind keyframes for slide in/out animations */}
{/* Next content pane, only during transition */}
{isTransitioning && nextCategory && (
<div
className={cn(
'absolute inset-0',
animationDirection === 'right'
? 'motion-safe:animate-slide-in-right'
: 'motion-safe:animate-slide-in-left',
)}
key={`pane-next-${nextCategory}-${animationDirection}`}
>
{/* Category header - only show when not searching */}
{!searchQuery && (
<div className="mb-6 mt-6">
{(() => {
// Get category data for display
const getCategoryData = () => {
if (nextCategory === 'promoted') {
return {
name: localize('com_agents_top_picks'),
description: localize('com_agents_recommended'),
};
}
if (nextCategory === 'all') {
return {
name: localize('com_agents_all'),
description: localize('com_agents_all_description'),
};
}
// Find the category in the API data
const categoryData = categoriesQuery.data?.find(
(cat) => cat.value === nextCategory,
);
if (categoryData) {
return {
name: categoryData.label?.startsWith('com_')
? localize(categoryData.label as TranslationKeys)
: categoryData.label,
description: categoryData.description?.startsWith('com_')
? localize(
categoryData.description as Parameters<
typeof localize
>[0],
)
: categoryData.description || '',
};
}
// Fallback for unknown categories
return {
name:
(nextCategory || '').charAt(0).toUpperCase() +
(nextCategory || '').slice(1),
description: '',
};
};
const { name, description } = getCategoryData();
return (
<div className="text-left">
<h2 className="text-2xl font-bold text-text-primary">{name}</h2>
{description && (
<p className="mt-2 text-text-secondary">{description}</p>
)}
</div>
);
})()}
</div>
)}
{/* Agent grid */}
<AgentGrid
key={`grid-${nextCategory}`}
category={nextCategory}
searchQuery={searchQuery}
onSelectAgent={handleAgentSelect}
scrollElement={scrollContainerRef.current}
/>
</div>
)}
{/* Note: Using Tailwind keyframes for slide in/out animations */}
</div>
</div>
{/* Agent detail dialog */}
{isDetailOpen && selectedAgent && (
<AgentDetail
agent={selectedAgent}
isOpen={isDetailOpen}
onClose={handleDetailClose}
/>
)}
</div>
{/* Agent detail dialog */}
{isDetailOpen && selectedAgent && (
<AgentDetail
agent={selectedAgent}
isOpen={isDetailOpen}
onClose={handleDetailClose}
/>
)}
</div>
</main>
</SidePanelGroup>
</SidePanelProvider>
</main>
</SidePanelGroup>
</SidePanelProvider>
</MarketplaceProvider>
</div>
);
};

View File

@@ -1,6 +1,5 @@
import React from 'react';
import { render, screen, fireEvent, waitFor, act } from '@testing-library/react';
import { render, screen, fireEvent } from '@testing-library/react';
import '@testing-library/jest-dom';
import AgentGrid from '../AgentGrid';
import type t from 'librechat-data-provider';
@@ -82,115 +81,6 @@ import { useMarketplaceAgentsInfiniteQuery } from '~/data-provider/Agents';
const mockUseMarketplaceAgentsInfiniteQuery = jest.mocked(useMarketplaceAgentsInfiniteQuery);
// Helper to create mock API response
const createMockResponse = (
agentIds: string[],
hasMore: boolean,
afterCursor?: string,
): t.AgentListResponse => ({
object: 'list',
data: agentIds.map(
(id) =>
({
id,
name: `Agent ${id}`,
description: `Description for ${id}`,
created_at: Date.now(),
model: 'gpt-4',
tools: [],
instructions: '',
avatar: null,
provider: 'openai',
model_parameters: {
temperature: 0.7,
top_p: 1,
frequency_penalty: 0,
presence_penalty: 0,
maxContextTokens: 2000,
max_context_tokens: 2000,
max_output_tokens: 2000,
},
}) as t.Agent,
),
first_id: agentIds[0] || '',
last_id: agentIds[agentIds.length - 1] || '',
has_more: hasMore,
after: afterCursor,
});
// Helper to setup mock viewport
const setupViewport = (scrollHeight: number, clientHeight: number) => {
const listeners: { [key: string]: EventListener[] } = {};
return {
scrollHeight,
clientHeight,
scrollTop: 0,
addEventListener: jest.fn((event: string, listener: EventListener) => {
if (!listeners[event]) {
listeners[event] = [];
}
listeners[event].push(listener);
}),
removeEventListener: jest.fn((event: string, listener: EventListener) => {
if (listeners[event]) {
listeners[event] = listeners[event].filter((l) => l !== listener);
}
}),
dispatchEvent: jest.fn((event: Event) => {
const eventListeners = listeners[event.type];
if (eventListeners) {
eventListeners.forEach((listener) => listener(event));
}
return true;
}),
} as unknown as HTMLElement;
};
// Helper to create mock infinite query return value
const createMockInfiniteQuery = (
pages: t.AgentListResponse[],
options?: {
isLoading?: boolean;
hasNextPage?: boolean;
fetchNextPage?: jest.Mock;
isFetchingNextPage?: boolean;
},
) =>
({
data: {
pages,
pageParams: pages.map((_, i) => (i === 0 ? undefined : `cursor-${i * 6}`)),
},
isLoading: options?.isLoading ?? false,
error: null,
isFetching: false,
hasNextPage: options?.hasNextPage ?? pages[pages.length - 1]?.has_more ?? false,
isFetchingNextPage: options?.isFetchingNextPage ?? false,
fetchNextPage: options?.fetchNextPage ?? jest.fn(),
refetch: jest.fn(),
// Add missing required properties for UseInfiniteQueryResult
isError: false,
isLoadingError: false,
isRefetchError: false,
isSuccess: true,
status: 'success' as const,
dataUpdatedAt: Date.now(),
errorUpdateCount: 0,
errorUpdatedAt: 0,
failureCount: 0,
failureReason: null,
fetchStatus: 'idle' as const,
isFetched: true,
isFetchedAfterMount: true,
isInitialLoading: false,
isPaused: false,
isPlaceholderData: false,
isPending: false,
isRefetching: false,
isStale: false,
remove: jest.fn(),
}) as any;
describe('AgentGrid Integration with useGetMarketplaceAgentsQuery', () => {
const mockOnSelectAgent = jest.fn();
@@ -453,15 +343,6 @@ describe('AgentGrid Integration with useGetMarketplaceAgentsQuery', () => {
});
describe('Infinite Scroll Functionality', () => {
beforeEach(() => {
// Silence console.log in tests
jest.spyOn(console, 'log').mockImplementation(() => {});
});
afterEach(() => {
jest.restoreAllMocks();
});
it('should show loading indicator when fetching next page', () => {
mockUseMarketplaceAgentsInfiniteQuery.mockReturnValue({
...defaultMockQueryResult,
@@ -515,358 +396,5 @@ describe('AgentGrid Integration with useGetMarketplaceAgentsQuery', () => {
expect(screen.queryByText("You've reached the end of the results")).not.toBeInTheDocument();
});
describe('Auto-fetch to fill viewport', () => {
it('should NOT auto-fetch when viewport is filled (5 agents, has_more=false)', async () => {
const mockResponse = createMockResponse(['1', '2', '3', '4', '5'], false);
const fetchNextPage = jest.fn();
mockUseMarketplaceAgentsInfiniteQuery.mockReturnValue(
createMockInfiniteQuery([mockResponse], { fetchNextPage }),
);
const scrollElement = setupViewport(500, 1000); // Content smaller than viewport
const scrollElementRef = { current: scrollElement };
const Wrapper = createWrapper();
render(
<Wrapper>
<AgentGrid
category="all"
searchQuery=""
onSelectAgent={mockOnSelectAgent}
scrollElementRef={scrollElementRef}
/>
</Wrapper>,
);
// Wait for initial render
await waitFor(() => {
expect(screen.getAllByRole('gridcell')).toHaveLength(5);
});
// Wait to ensure no auto-fetch happens
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 200));
});
// fetchNextPage should NOT be called since has_more is false
expect(fetchNextPage).not.toHaveBeenCalled();
});
it('should auto-fetch when viewport not filled (7 agents, big viewport)', async () => {
const firstPage = createMockResponse(['1', '2', '3', '4', '5', '6'], true, 'cursor-6');
const secondPage = createMockResponse(['7'], false);
let currentPages = [firstPage];
const fetchNextPage = jest.fn();
// Mock that updates pages when fetchNextPage is called
mockUseMarketplaceAgentsInfiniteQuery.mockImplementation(() =>
createMockInfiniteQuery(currentPages, {
fetchNextPage: jest.fn().mockImplementation(() => {
fetchNextPage();
currentPages = [firstPage, secondPage];
return Promise.resolve();
}),
hasNextPage: true,
}),
);
const scrollElement = setupViewport(400, 1200); // Large viewport (content < viewport)
const scrollElementRef = { current: scrollElement };
const Wrapper = createWrapper();
const { rerender } = render(
<Wrapper>
<AgentGrid
category="all"
searchQuery=""
onSelectAgent={mockOnSelectAgent}
scrollElementRef={scrollElementRef}
/>
</Wrapper>,
);
// Wait for initial 6 agents
await waitFor(() => {
expect(screen.getAllByRole('gridcell')).toHaveLength(6);
});
// Wait for ResizeObserver and auto-fetch to trigger
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 150));
});
// Auto-fetch should have been triggered (multiple times due to reliability checks)
expect(fetchNextPage).toHaveBeenCalled();
expect(fetchNextPage.mock.calls.length).toBeGreaterThanOrEqual(1);
// Update mock data and re-render
currentPages = [firstPage, secondPage];
rerender(
<Wrapper>
<AgentGrid
category="all"
searchQuery=""
onSelectAgent={mockOnSelectAgent}
scrollElementRef={scrollElementRef}
/>
</Wrapper>,
);
// Should now show all 7 agents
await waitFor(() => {
expect(screen.getAllByRole('gridcell')).toHaveLength(7);
});
});
it('should NOT auto-fetch when viewport is filled (7 agents, small viewport)', async () => {
const firstPage = createMockResponse(['1', '2', '3', '4', '5', '6'], true, 'cursor-6');
const fetchNextPage = jest.fn();
mockUseMarketplaceAgentsInfiniteQuery.mockReturnValue(
createMockInfiniteQuery([firstPage], { fetchNextPage, hasNextPage: true }),
);
const scrollElement = setupViewport(1200, 600); // Small viewport, content fills it
const scrollElementRef = { current: scrollElement };
const Wrapper = createWrapper();
render(
<Wrapper>
<AgentGrid
category="all"
searchQuery=""
onSelectAgent={mockOnSelectAgent}
scrollElementRef={scrollElementRef}
/>
</Wrapper>,
);
// Wait for initial 6 agents
await waitFor(() => {
expect(screen.getAllByRole('gridcell')).toHaveLength(6);
});
// Wait to ensure no auto-fetch happens
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 200));
});
// Should NOT auto-fetch since viewport is filled
expect(fetchNextPage).not.toHaveBeenCalled();
});
it('should auto-fetch once to fill viewport then stop (20 agents)', async () => {
const allPages = [
createMockResponse(['1', '2', '3', '4', '5', '6'], true, 'cursor-6'),
createMockResponse(['7', '8', '9', '10', '11', '12'], true, 'cursor-12'),
createMockResponse(['13', '14', '15', '16', '17', '18'], true, 'cursor-18'),
createMockResponse(['19', '20'], false),
];
let currentPages = [allPages[0]];
let fetchCount = 0;
const fetchNextPage = jest.fn();
mockUseMarketplaceAgentsInfiniteQuery.mockImplementation(() =>
createMockInfiniteQuery(currentPages, {
fetchNextPage: jest.fn().mockImplementation(() => {
fetchCount++;
fetchNextPage();
if (currentPages.length < 2) {
currentPages = allPages.slice(0, 2);
}
return Promise.resolve();
}),
hasNextPage: currentPages.length < 2,
}),
);
const scrollElement = setupViewport(600, 1000); // Viewport fits ~12 agents
const scrollElementRef = { current: scrollElement };
const Wrapper = createWrapper();
const { rerender } = render(
<Wrapper>
<AgentGrid
category="all"
searchQuery=""
onSelectAgent={mockOnSelectAgent}
scrollElementRef={scrollElementRef}
/>
</Wrapper>,
);
// Wait for initial 6 agents
await waitFor(() => {
expect(screen.getAllByRole('gridcell')).toHaveLength(6);
});
// Should auto-fetch to fill viewport
await waitFor(
() => {
expect(fetchNextPage).toHaveBeenCalledTimes(1);
},
{ timeout: 500 },
);
// Simulate viewport being filled after 12 agents
Object.defineProperty(scrollElement, 'scrollHeight', {
value: 1200,
writable: true,
configurable: true,
});
currentPages = allPages.slice(0, 2);
rerender(
<Wrapper>
<AgentGrid
category="all"
searchQuery=""
onSelectAgent={mockOnSelectAgent}
scrollElementRef={scrollElementRef}
/>
</Wrapper>,
);
// Should show 12 agents
await waitFor(() => {
expect(screen.getAllByRole('gridcell')).toHaveLength(12);
});
// Wait to ensure no additional auto-fetch
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 200));
});
// Should only have fetched once (to fill viewport)
expect(fetchCount).toBe(1);
expect(fetchNextPage).toHaveBeenCalledTimes(1);
});
it('should auto-fetch when viewport resizes to be taller (window resize)', async () => {
const firstPage = createMockResponse(['1', '2', '3', '4', '5', '6'], true, 'cursor-6');
const secondPage = createMockResponse(['7', '8', '9', '10', '11', '12'], true, 'cursor-12');
let currentPages = [firstPage];
const fetchNextPage = jest.fn();
let resizeObserverCallback: ResizeObserverCallback | null = null;
// Mock that updates pages when fetchNextPage is called
mockUseMarketplaceAgentsInfiniteQuery.mockImplementation(() =>
createMockInfiniteQuery(currentPages, {
fetchNextPage: jest.fn().mockImplementation(() => {
fetchNextPage();
if (currentPages.length === 1) {
currentPages = [firstPage, secondPage];
}
return Promise.resolve();
}),
hasNextPage: currentPages.length === 1,
}),
);
// Mock ResizeObserver to capture the callback
const ResizeObserverMock = jest.fn().mockImplementation((callback) => {
resizeObserverCallback = callback;
return {
observe: jest.fn(),
disconnect: jest.fn(),
unobserve: jest.fn(),
};
});
global.ResizeObserver = ResizeObserverMock as any;
// Start with a small viewport that fits the content
const scrollElement = setupViewport(800, 600);
const scrollElementRef = { current: scrollElement };
const Wrapper = createWrapper();
const { rerender } = render(
<Wrapper>
<AgentGrid
category="all"
searchQuery=""
onSelectAgent={mockOnSelectAgent}
scrollElementRef={scrollElementRef}
/>
</Wrapper>,
);
// Wait for initial 6 agents
await waitFor(() => {
expect(screen.getAllByRole('gridcell')).toHaveLength(6);
});
// Verify ResizeObserver was set up
expect(ResizeObserverMock).toHaveBeenCalled();
expect(resizeObserverCallback).not.toBeNull();
// Initially no fetch should happen as viewport is filled
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 100));
});
expect(fetchNextPage).not.toHaveBeenCalled();
// Simulate window resize - make viewport taller
Object.defineProperty(scrollElement, 'clientHeight', {
value: 1200, // Now taller than content
writable: true,
configurable: true,
});
// Trigger ResizeObserver callback to simulate resize detection
act(() => {
if (resizeObserverCallback) {
resizeObserverCallback(
[
{
target: scrollElement,
contentRect: {
x: 0,
y: 0,
width: 800,
height: 1200,
top: 0,
right: 800,
bottom: 1200,
left: 0,
} as DOMRectReadOnly,
borderBoxSize: [],
contentBoxSize: [],
devicePixelContentBoxSize: [],
} as ResizeObserverEntry,
],
{} as ResizeObserver,
);
}
});
// Should trigger auto-fetch due to viewport now being larger than content
await waitFor(
() => {
expect(fetchNextPage).toHaveBeenCalledTimes(1);
},
{ timeout: 500 },
);
// Update the component with new data
rerender(
<Wrapper>
<AgentGrid
category="all"
searchQuery=""
onSelectAgent={mockOnSelectAgent}
scrollElementRef={scrollElementRef}
/>
</Wrapper>,
);
// Should now show 12 agents after fetching
await waitFor(() => {
expect(screen.getAllByRole('gridcell')).toHaveLength(12);
});
});
});
});
});

View File

@@ -194,7 +194,7 @@ describe('Virtual Scrolling Performance', () => {
// Performance check: rendering should be fast
const renderTime = endTime - startTime;
expect(renderTime).toBeLessThan(720);
expect(renderTime).toBeLessThan(650);
console.log(`Rendered 1000 agents in ${renderTime.toFixed(2)}ms`);
console.log(`Only ${renderedCards.length} DOM nodes created for 1000 agents`);

View File

@@ -1,7 +1,7 @@
import React, { useState, useEffect, useContext } from 'react';
import { useForm } from 'react-hook-form';
import { Turnstile } from '@marsidev/react-turnstile';
import { ThemeContext, Spinner, Button, isDark } from '@librechat/client';
import { ThemeContext, Spinner, Button } from '@librechat/client';
import type { TLoginUser, TStartupConfig } from 'librechat-data-provider';
import type { TAuthContext } from '~/common';
import { useResendVerificationEmail, useGetStartupConfig } from '~/data-provider';
@@ -28,7 +28,7 @@ const LoginForm: React.FC<TLoginFormProps> = ({ onSubmit, startupConfig, error,
const { data: config } = useGetStartupConfig();
const useUsernameLogin = config?.ldap?.username;
const validTheme = isDark(theme) ? 'dark' : 'light';
const validTheme = theme === 'dark' ? 'dark' : 'light';
const requireCaptcha = Boolean(startupConfig.turnstile?.siteKey);
useEffect(() => {

View File

@@ -1,7 +1,7 @@
import { useForm } from 'react-hook-form';
import React, { useContext, useState } from 'react';
import { Turnstile } from '@marsidev/react-turnstile';
import { ThemeContext, Spinner, Button, isDark } from '@librechat/client';
import { ThemeContext, Spinner, Button } from '@librechat/client';
import { useNavigate, useOutletContext, useLocation } from 'react-router-dom';
import { useRegisterUserMutation } from 'librechat-data-provider/react-query';
import type { TRegisterUser, TError } from 'librechat-data-provider';
@@ -31,7 +31,7 @@ const Registration: React.FC = () => {
const location = useLocation();
const queryParams = new URLSearchParams(location.search);
const token = queryParams.get('token');
const validTheme = isDark(theme) ? 'dark' : 'light';
const validTheme = theme === 'dark' ? 'dark' : 'light';
// only require captcha if we have a siteKey
const requireCaptcha = Boolean(startupConfig?.turnstile?.siteKey);

View File

@@ -1,4 +1,4 @@
import { memo, useCallback, useState, useEffect, useRef } from 'react';
import { memo, useCallback } from 'react';
import { useRecoilValue } from 'recoil';
import { useForm } from 'react-hook-form';
import { Spinner } from '@librechat/client';
@@ -13,7 +13,6 @@ import { useGetMessagesByConvoId } from '~/data-provider';
import MessagesView from './Messages/MessagesView';
import Presentation from './Presentation';
import ChatForm from './Input/ChatForm';
import CostBar from './CostBar';
import Landing from './Landing';
import Header from './Header';
import Footer from './Footer';
@@ -30,13 +29,7 @@ function LoadingSpinner() {
);
}
function ChatView({
index = 0,
modelCosts,
}: {
index?: number;
modelCosts?: { modelCostTable: Record<string, { prompt: number; completion: number }> };
}) {
function ChatView({ index = 0 }: { index?: number }) {
const { conversationId } = useParams();
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
@@ -44,9 +37,6 @@ function ChatView({
const fileMap = useFileMapContext();
const [showCostBar, setShowCostBar] = useState(false);
const lastScrollY = useRef(0);
const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', {
select: useCallback(
(data: TMessage[]) => {
@@ -64,58 +54,6 @@ function ChatView({
useSSE(rootSubmission, chatHelpers, false);
useSSE(addedSubmission, addedChatHelpers, true);
const checkIfAtBottom = useCallback(
(container: HTMLElement) => {
const currentScrollY = container.scrollTop;
const scrollHeight = container.scrollHeight;
const clientHeight = container.clientHeight;
const distanceFromBottom = scrollHeight - currentScrollY - clientHeight;
const isAtBottom = distanceFromBottom < 10;
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
setShowCostBar(isAtBottom && !isStreaming);
lastScrollY.current = currentScrollY;
},
[chatHelpers.isSubmitting, addedChatHelpers.isSubmitting],
);
useEffect(() => {
const handleScroll = (event: Event) => {
const target = event.target as HTMLElement;
checkIfAtBottom(target);
};
const findAndAttachScrollListener = () => {
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
if (messagesContainer) {
checkIfAtBottom(messagesContainer as HTMLElement);
messagesContainer.addEventListener('scroll', handleScroll, { passive: true });
return () => {
messagesContainer.removeEventListener('scroll', handleScroll);
};
}
setTimeout(findAndAttachScrollListener, 100);
};
const cleanup = findAndAttachScrollListener();
return cleanup;
}, [messagesTree, checkIfAtBottom]);
useEffect(() => {
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
if (isStreaming) {
setShowCostBar(false);
} else {
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
if (messagesContainer) {
checkIfAtBottom(messagesContainer as HTMLElement);
}
}
}, [chatHelpers.isSubmitting, addedChatHelpers.isSubmitting, checkIfAtBottom]);
const methods = useForm<ChatFormValues>({
defaultValues: { text: '' },
});
@@ -131,22 +69,7 @@ function ChatView({
} else if ((isLoading || isNavigating) && !isLandingPage) {
content = <LoadingSpinner />;
} else if (!isLandingPage) {
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
content = (
<MessagesView
messagesTree={messagesTree}
costBar={
!isLandingPage &&
modelCosts && (
<CostBar
messagesTree={messagesTree}
modelCosts={modelCosts}
showCostBar={showCostBar && !isStreaming}
/>
)
}
/>
);
content = <MessagesView messagesTree={messagesTree} />;
} else {
content = <Landing centerFormOnLanding={centerFormOnLanding} />;
}

View File

@@ -1,112 +0,0 @@
import { useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import { ArrowIcon } from '@librechat/client';
import { TModelCosts, TMessage } from 'librechat-data-provider';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
import store from '~/store';
interface CostBarProps {
messagesTree: TMessage[];
modelCosts: TModelCosts;
showCostBar: boolean;
}
export default function CostBar({ messagesTree, modelCosts, showCostBar }: CostBarProps) {
const localize = useLocalize();
const showCostTracking = useRecoilValue(store.showCostTracking);
const conversationCosts = useMemo(() => {
if (!modelCosts?.modelCostTable || !messagesTree) {
return null;
}
let totalPromptTokens = 0;
let totalCompletionTokens = 0;
let totalPromptUSD = 0;
let totalCompletionUSD = 0;
const flattenMessages = (messages: TMessage[]) => {
const flattened: TMessage[] = [];
messages.forEach((message: TMessage) => {
flattened.push(message);
if (message.children && message.children.length > 0) {
flattened.push(...flattenMessages(message.children));
}
});
return flattened;
};
const allMessages = flattenMessages(messagesTree);
allMessages.forEach((message) => {
if (!message.tokenCount) {
return null;
}
const modelToUse = message.isCreatedByUser ? message.targetModel : message.model;
const modelPricing = modelCosts.modelCostTable[modelToUse];
if (message.isCreatedByUser) {
totalPromptTokens += message.tokenCount;
totalPromptUSD += (message.tokenCount / 1000000) * modelPricing.prompt;
} else {
totalCompletionTokens += message.tokenCount;
totalCompletionUSD += (message.tokenCount / 1000000) * modelPricing.completion;
}
});
const totalTokens = totalPromptTokens + totalCompletionTokens;
const totalUSD = totalPromptUSD + totalCompletionUSD;
return {
totals: {
prompt: { tokenCount: totalPromptTokens, usd: totalPromptUSD },
completion: { tokenCount: totalCompletionTokens, usd: totalCompletionUSD },
total: { tokenCount: totalTokens, usd: totalUSD },
},
};
}, [modelCosts, messagesTree]);
if (!showCostTracking || !conversationCosts || !conversationCosts.totals) {
return null;
}
return (
<div
className={cn(
'mx-auto w-full max-w-md px-4 text-xs text-muted-foreground transition-all duration-300 ease-in-out',
showCostBar ? 'opacity-100' : 'opacity-0',
)}
>
<div className="grid grid-cols-3 gap-2 text-center">
<div>
<div>
<ArrowIcon direction="up" />
{localize('com_ui_token_abbreviation', {
0: conversationCosts.totals.prompt.tokenCount,
})}
</div>
<div>${Math.abs(conversationCosts.totals.prompt.usd).toFixed(6)}</div>
</div>
<div>
<div>
{localize('com_ui_token_abbreviation', {
0: conversationCosts.totals.total.tokenCount,
})}
</div>
<div>${Math.abs(conversationCosts.totals.total.usd).toFixed(6)}</div>
</div>
<div>
<div>
<ArrowIcon direction="down" />
{localize('com_ui_token_abbreviation', {
0: conversationCosts.totals.completion.tokenCount,
})}
</div>
<div>${Math.abs(conversationCosts.totals.completion.usd).toFixed(6)}</div>
</div>
</div>
</div>
);
}

View File

@@ -253,7 +253,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
handleSaveBadges={handleSaveBadges}
setBadges={setBadges}
/>
<FileFormChat conversation={conversation} />
<FileFormChat disableInputs={disableInputs} />
{endpoint && (
<div className={cn('flex', isRTL ? 'flex-row-reverse' : 'flex-row')}>
<TextareaAutosize
@@ -301,7 +301,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
)}
>
<div className={`${isRTL ? 'mr-2' : 'ml-2'}`}>
<AttachFileChat conversation={conversation} disableInputs={disableInputs} />
<AttachFileChat disableInputs={disableInputs} />
</div>
<BadgeRow
showEphemeralBadges={!isAgentsEndpoint(endpoint) && !isAssistantsEndpoint(endpoint)}

View File

@@ -7,18 +7,14 @@ import {
isAssistantsEndpoint,
fileConfig as defaultFileConfig,
} from 'librechat-data-provider';
import type { EndpointFileConfig, TConversation } from 'librechat-data-provider';
import type { EndpointFileConfig } from 'librechat-data-provider';
import { useGetFileConfig } from '~/data-provider';
import AttachFileMenu from './AttachFileMenu';
import { useChatContext } from '~/Providers';
import AttachFile from './AttachFile';
function AttachFileChat({
disableInputs,
conversation,
}: {
disableInputs: boolean;
conversation: TConversation | null;
}) {
function AttachFileChat({ disableInputs }: { disableInputs: boolean }) {
const { conversation } = useChatContext();
const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
const { endpoint, endpointType } = conversation ?? { endpoint: null };
const isAgents = useMemo(() => isAgentsEndpoint(endpoint), [endpoint]);
@@ -39,7 +35,6 @@ function AttachFileChat({
<AttachFileMenu
disabled={disableInputs}
conversationId={conversationId}
agentId={conversation?.agent_id}
endpointFileConfig={endpointFileConfig}
/>
);

View File

@@ -11,13 +11,7 @@ import {
SharePointIcon,
} from '@librechat/client';
import type { EndpointFileConfig } from 'librechat-data-provider';
import {
useAgentToolPermissions,
useAgentCapabilities,
useGetAgentsConfig,
useFileHandling,
useLocalize,
} from '~/hooks';
import { useLocalize, useGetAgentsConfig, useFileHandling, useAgentCapabilities } from '~/hooks';
import useSharePointFileHandling from '~/hooks/Files/useSharePointFileHandling';
import { SharePointPickerDialog } from '~/components/SharePoint';
import { useGetStartupConfig } from '~/data-provider';
@@ -27,17 +21,11 @@ import { cn } from '~/utils';
interface AttachFileMenuProps {
conversationId: string;
agentId?: string | null;
disabled?: boolean | null;
endpointFileConfig?: EndpointFileConfig;
}
const AttachFileMenu = ({
agentId,
disabled,
conversationId,
endpointFileConfig,
}: AttachFileMenuProps) => {
const AttachFileMenu = ({ disabled, conversationId, endpointFileConfig }: AttachFileMenuProps) => {
const localize = useLocalize();
const isUploadDisabled = disabled ?? false;
const inputRef = useRef<HTMLInputElement>(null);
@@ -64,8 +52,6 @@ const AttachFileMenu = ({
* */
const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
const { fileSearchAllowedByAgent, codeAllowedByAgent } = useAgentToolPermissions(agentId);
const handleUploadClick = (isImage?: boolean) => {
if (!inputRef.current) {
return;
@@ -100,22 +86,18 @@ const AttachFileMenu = ({
});
}
if (capabilities.fileSearchEnabled && fileSearchAllowedByAgent) {
if (capabilities.fileSearchEnabled) {
items.push({
label: localize('com_ui_upload_file_search'),
onClick: () => {
setToolResource(EToolResources.file_search);
setEphemeralAgent((prev) => ({
...prev,
[EToolResources.file_search]: true,
}));
onAction();
},
icon: <FileSearch className="icon-md" />,
});
}
if (capabilities.codeEnabled && codeAllowedByAgent) {
if (capabilities.codeEnabled) {
items.push({
label: localize('com_ui_upload_code_files'),
onClick: () => {
@@ -156,8 +138,6 @@ const AttachFileMenu = ({
setToolResource,
setEphemeralAgent,
sharePointEnabled,
codeAllowedByAgent,
fileSearchAllowedByAgent,
setIsSharePointDialogOpen,
]);

Some files were not shown because too many files have changed in this diff.