Compare commits

..

1 Commits

Author SHA1 Message Date
Danny Avila
7251308244 mcp example, mock i/o for client-to-server communications 2025-06-26 13:33:59 -04:00
311 changed files with 4596 additions and 16875 deletions

View File

@@ -349,11 +349,6 @@ REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1 CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1 MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20 NON_BROWSER_VIOLATION_SCORE=20
TTS_VIOLATION_SCORE=0
STT_VIOLATION_SCORE=0
FORK_VIOLATION_SCORE=0
IMPORT_VIOLATION_SCORE=0
FILE_UPLOAD_VIOLATION_SCORE=0
LOGIN_MAX=7 LOGIN_MAX=7
LOGIN_WINDOW=5 LOGIN_WINDOW=5
@@ -458,8 +453,8 @@ OPENID_REUSE_TOKENS=
OPENID_JWKS_URL_CACHE_ENABLED= OPENID_JWKS_URL_CACHE_ENABLED=
OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching
#Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint. #Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint.
OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED= OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED=
OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API
# Set to true to use the OpenID Connect end session endpoint for logout # Set to true to use the OpenID Connect end session endpoint for logout
OPENID_USE_END_SESSION_ENDPOINT= OPENID_USE_END_SESSION_ENDPOINT=
@@ -580,10 +575,6 @@ ALLOW_SHARED_LINKS_PUBLIC=true
# If you have another service in front of your LibreChat doing compression, disable express based compression here # If you have another service in front of your LibreChat doing compression, disable express based compression here
# DISABLE_COMPRESSION=true # DISABLE_COMPRESSION=true
# If you have gzipped version of uploaded image images in the same folder, this will enable gzip scan and serving of these images
# Note: The images folder will be scanned on startup and a ma kept in memory. Be careful for large number of images.
# ENABLE_IMAGE_OUTPUT_GZIP_SCAN=true
#===================================================# #===================================================#
# UI # # UI #
#===================================================# #===================================================#
@@ -601,31 +592,11 @@ HELP_AND_FAQ_URL=https://librechat.ai
# REDIS Options # # REDIS Options #
#===============# #===============#
# Enable Redis for caching and session storage # REDIS_URI=10.10.10.10:6379
# USE_REDIS=true # USE_REDIS=true
# Single Redis instance # USE_REDIS_CLUSTER=true
# REDIS_URI=redis://127.0.0.1:6379 # REDIS_CA=/path/to/ca.crt
# Redis cluster (multiple nodes)
# REDIS_URI=redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003
# Redis with TLS/SSL encryption and CA certificate
# REDIS_URI=rediss://127.0.0.1:6380
# REDIS_CA=/path/to/ca-cert.pem
# Redis authentication (if required)
# REDIS_USERNAME=your_redis_username
# REDIS_PASSWORD=your_redis_password
# Redis key prefix configuration
# Use environment variable name for dynamic prefix (recommended for cloud deployments)
# REDIS_KEY_PREFIX_VAR=K_REVISION
# Or use static prefix directly
# REDIS_KEY_PREFIX=librechat
# Redis connection limits
# REDIS_MAX_LISTENERS=40
#==================================================# #==================================================#
# Others # # Others #

View File

@@ -1,32 +0,0 @@
name: Publish `@librechat/client` to NPM
on:
workflow_dispatch:
inputs:
reason:
description: 'Reason for manual trigger'
required: false
default: 'Manual publish requested'
jobs:
build-and-publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Use Node.js
uses: actions/setup-node@v4
with:
node-version: '18.x'
- name: Check if client package exists
run: |
if [ -d "packages/client" ]; then
echo "Client package directory found"
else
echo "Client package directory not found - workflow ready for future use"
exit 0
fi
- name: Placeholder for future publishing
run: echo "Client package publishing workflow is ready"

9
.gitignore vendored
View File

@@ -125,12 +125,3 @@ helm/**/.values.yaml
# SAML Idp cert # SAML Idp cert
*.cert *.cert
# AI Assistants
/.claude/
/.cursor/
/.copilot/
/.aider/
/.openai/
/.tabnine/
/.codeium

View File

@@ -7,8 +7,49 @@ All notable changes to this project will be documented in this file.
## [Unreleased]
- no changes ### ✨ New Features
- ✨ feat: implement search parameter updates by **@mawburn** in [#7151](https://github.com/danny-avila/LibreChat/pull/7151)
- 🎏 feat: Add MCP support for Streamable HTTP Transport by **@benverhees** in [#7353](https://github.com/danny-avila/LibreChat/pull/7353)
- 🔒 feat: Add Content Security Policy using Helmet middleware by **@rubentalstra** in [#7377](https://github.com/danny-avila/LibreChat/pull/7377)
- ✨ feat: Add Normalization for MCP Server Names by **@danny-avila** in [#7421](https://github.com/danny-avila/LibreChat/pull/7421)
- 📊 feat: Improve Helm Chart by **@hofq** in [#3638](https://github.com/danny-avila/LibreChat/pull/3638)
- 🦾 feat: Claude-4 Support by **@danny-avila** in [#7509](https://github.com/danny-avila/LibreChat/pull/7509)
- 🪨 feat: Bedrock Support for Claude-4 Reasoning by **@danny-avila** in [#7517](https://github.com/danny-avila/LibreChat/pull/7517)
### 🌍 Internationalization
- 🌍 i18n: Add `Danish` and `Czech` and `Catalan` localization support by **@rubentalstra** in [#7373](https://github.com/danny-avila/LibreChat/pull/7373)
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7375](https://github.com/danny-avila/LibreChat/pull/7375)
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7468](https://github.com/danny-avila/LibreChat/pull/7468)
### 🔧 Fixes
- 💬 fix: update aria-label for accessibility in ConvoLink component by **@berry-13** in [#7320](https://github.com/danny-avila/LibreChat/pull/7320)
- 🔑 fix: use `apiKey` instead of `openAIApiKey` in OpenAI-like Config by **@danny-avila** in [#7337](https://github.com/danny-avila/LibreChat/pull/7337)
- 🔄 fix: update navigation logic in `useFocusChatEffect` to ensure correct search parameters are used by **@mawburn** in [#7340](https://github.com/danny-avila/LibreChat/pull/7340)
- 🔄 fix: Improve MCP Connection Cleanup by **@danny-avila** in [#7400](https://github.com/danny-avila/LibreChat/pull/7400)
- 🛡️ fix: Preset and Validation Logic for URL Query Params by **@danny-avila** in [#7407](https://github.com/danny-avila/LibreChat/pull/7407)
- 🌘 fix: artifact of preview text is illegible in dark mode by **@nhtruong** in [#7405](https://github.com/danny-avila/LibreChat/pull/7405)
- 🛡️ fix: Temporarily Remove CSP until Configurable by **@danny-avila** in [#7419](https://github.com/danny-avila/LibreChat/pull/7419)
- 💽 fix: Exclude index page `/` from static cache settings by **@sbruel** in [#7382](https://github.com/danny-avila/LibreChat/pull/7382)
### ⚙️ Other Changes
- 📜 docs: CHANGELOG for release v0.7.8 by **@github-actions[bot]** in [#7290](https://github.com/danny-avila/LibreChat/pull/7290)
- 📦 chore: Update API Package Dependencies by **@danny-avila** in [#7359](https://github.com/danny-avila/LibreChat/pull/7359)
- 📜 docs: Unreleased Changelog by **@github-actions[bot]** in [#7321](https://github.com/danny-avila/LibreChat/pull/7321)
- 📜 docs: Unreleased Changelog by **@github-actions[bot]** in [#7434](https://github.com/danny-avila/LibreChat/pull/7434)
- 🛡️ chore: `multer` v2.0.0 for CVE-2025-47935 and CVE-2025-47944 by **@danny-avila** in [#7454](https://github.com/danny-avila/LibreChat/pull/7454)
- 📂 refactor: Improve `FileAttachment` & File Form Deletion by **@danny-avila** in [#7471](https://github.com/danny-avila/LibreChat/pull/7471)
- 📊 chore: Remove Old Helm Chart by **@hofq** in [#7512](https://github.com/danny-avila/LibreChat/pull/7512)
- 🪖 chore: bump helm app version to v0.7.8 by **@austin-barrington** in [#7524](https://github.com/danny-avila/LibreChat/pull/7524)
---
## [v0.7.8] - ## [v0.7.8] -
Changes from v0.7.8-rc1 to v0.7.8. Changes from v0.7.8-rc1 to v0.7.8.
@@ -50,7 +91,6 @@ Changes from v0.7.8-rc1 to v0.7.8.
--- ---
## [v0.7.8-rc1] - ## [v0.7.8-rc1] -
## [v0.7.8-rc1] -
Changes from v0.7.7 to v0.7.8-rc1. Changes from v0.7.7 to v0.7.8-rc1.

View File

@@ -1,4 +1,4 @@
# v0.7.9-rc1 # v0.7.8
# Base node image # Base node image
FROM node:20-alpine AS node FROM node:20-alpine AS node

View File

@@ -1,5 +1,5 @@
# Dockerfile.multi # Dockerfile.multi
# v0.7.9-rc1 # v0.7.8
# Base for all builds # Base for all builds
FROM node:20-alpine AS base-min FROM node:20-alpine AS base-min

View File

@@ -52,7 +52,7 @@
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features - 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
- 🤖 **AI Model Selection**: - 🤖 **AI Model Selection**:
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure) - Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required - [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints): - Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai, - Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
@@ -66,9 +66,10 @@
- 🔦 **Agents & Tools Integration**: - 🔦 **Agents & Tools Integration**:
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**: - **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding - No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more - Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more - Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools - [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
- Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions
- 🔍 **Web Search**: - 🔍 **Web Search**:
- Search the internet and retrieve relevant information to enhance your AI context - Search the internet and retrieve relevant information to enhance your AI context

View File

@@ -13,6 +13,7 @@ const {
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models'); const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
const { checkBalance } = require('~/models/balanceMethods'); const { checkBalance } = require('~/models/balanceMethods');
const { truncateToolCallOutputs } = require('./prompts'); const { truncateToolCallOutputs } = require('./prompts');
const { addSpaceIfNeeded } = require('~/server/utils');
const { getFiles } = require('~/models/File'); const { getFiles } = require('~/models/File');
const TextStream = require('./TextStream'); const TextStream = require('./TextStream');
const { logger } = require('~/config'); const { logger } = require('~/config');
@@ -108,15 +109,12 @@ class BaseClient {
/** /**
* Abstract method to record token usage. Subclasses must implement this method. * Abstract method to record token usage. Subclasses must implement this method.
* If a correction to the token usage is needed, the method should return an object with the corrected token counts. * If a correction to the token usage is needed, the method should return an object with the corrected token counts.
* Should only be used if `recordCollectedUsage` was not used instead.
* @param {string} [model]
* @param {number} promptTokens * @param {number} promptTokens
* @param {number} completionTokens * @param {number} completionTokens
* @returns {Promise<void>} * @returns {Promise<void>}
*/ */
async recordTokenUsage({ model, promptTokens, completionTokens }) { async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', { logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
model,
promptTokens, promptTokens,
completionTokens, completionTokens,
}); });
@@ -200,10 +198,6 @@ class BaseClient {
this.currentMessages[this.currentMessages.length - 1].messageId = head; this.currentMessages[this.currentMessages.length - 1].messageId = head;
} }
if (opts.isRegenerate && responseMessageId.endsWith('_')) {
responseMessageId = crypto.randomUUID();
}
this.responseMessageId = responseMessageId; this.responseMessageId = responseMessageId;
return { return {
@@ -578,7 +572,7 @@ class BaseClient {
}); });
} }
const { editedContent } = opts; const { generation = '' } = opts;
// It's not necessary to push to currentMessages // It's not necessary to push to currentMessages
// depending on subclass implementation of handling messages // depending on subclass implementation of handling messages
@@ -593,21 +587,11 @@ class BaseClient {
isCreatedByUser: false, isCreatedByUser: false,
model: this.modelOptions?.model ?? this.model, model: this.modelOptions?.model ?? this.model,
sender: this.sender, sender: this.sender,
text: generation,
}; };
this.currentMessages.push(userMessage, latestMessage); this.currentMessages.push(userMessage, latestMessage);
} else if (editedContent != null) { } else {
// Handle editedContent for content parts latestMessage.text = generation;
if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
const { index, text, type } = editedContent;
if (index >= 0 && index < latestMessage.content.length) {
const contentPart = latestMessage.content[index];
if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
contentPart[ContentTypes.THINK] = text;
} else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
contentPart[ContentTypes.TEXT] = text;
}
}
}
} }
this.continued = true; this.continued = true;
} else { } else {
@@ -688,32 +672,16 @@ class BaseClient {
}; };
if (typeof completion === 'string') { if (typeof completion === 'string') {
responseMessage.text = completion; responseMessage.text = addSpaceIfNeeded(generation) + completion;
} else if ( } else if (
Array.isArray(completion) && Array.isArray(completion) &&
(this.clientName === EModelEndpoint.agents || (this.clientName === EModelEndpoint.agents ||
isParamEndpoint(this.options.endpoint, this.options.endpointType)) isParamEndpoint(this.options.endpoint, this.options.endpointType))
) { ) {
responseMessage.text = ''; responseMessage.text = '';
responseMessage.content = completion;
if (!opts.editedContent || this.currentMessages.length === 0) {
responseMessage.content = completion;
} else {
const latestMessage = this.currentMessages[this.currentMessages.length - 1];
if (!latestMessage?.content) {
responseMessage.content = completion;
} else {
const existingContent = [...latestMessage.content];
const { type: editedType } = opts.editedContent;
responseMessage.content = this.mergeEditedContent(
existingContent,
completion,
editedType,
);
}
}
} else if (Array.isArray(completion)) { } else if (Array.isArray(completion)) {
responseMessage.text = completion.join(''); responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
} }
if ( if (
@@ -744,13 +712,9 @@ class BaseClient {
} else { } else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage); responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = responseMessage.tokenCount; completionTokens = responseMessage.tokenCount;
await this.recordTokenUsage({
usage,
promptTokens,
completionTokens,
model: responseMessage.model,
});
} }
await this.recordTokenUsage({ promptTokens, completionTokens, usage });
} }
if (userMessagePromise) { if (userMessagePromise) {
@@ -1131,50 +1095,6 @@ class BaseClient {
return numTokens; return numTokens;
} }
/**
* Merges completion content with existing content when editing TEXT or THINK types
* @param {Array} existingContent - The existing content array
* @param {Array} newCompletion - The new completion content
* @param {string} editedType - The type of content being edited
* @returns {Array} The merged content array
*/
mergeEditedContent(existingContent, newCompletion, editedType) {
if (!newCompletion.length) {
return existingContent.concat(newCompletion);
}
if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
return existingContent.concat(newCompletion);
}
const lastIndex = existingContent.length - 1;
const lastExisting = existingContent[lastIndex];
const firstNew = newCompletion[0];
if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
return existingContent.concat(newCompletion);
}
const mergedContent = [...existingContent];
if (editedType === ContentTypes.TEXT) {
mergedContent[lastIndex] = {
...mergedContent[lastIndex],
[ContentTypes.TEXT]:
(mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
};
} else {
mergedContent[lastIndex] = {
...mergedContent[lastIndex],
[ContentTypes.THINK]:
(mergedContent[lastIndex][ContentTypes.THINK] || '') +
(firstNew[ContentTypes.THINK] || ''),
};
}
// Add remaining completion items
return mergedContent.concat(newCompletion.slice(1));
}
async sendPayload(payload, opts = {}) { async sendPayload(payload, opts = {}) {
if (opts && typeof opts === 'object') { if (opts && typeof opts === 'object') {
this.setOptions(opts); this.setOptions(opts);

View File

@@ -1,7 +1,6 @@
const axios = require('axios'); const axios = require('axios');
const { isEnabled } = require('@librechat/api'); const { isEnabled } = require('~/server/utils');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('~/config');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const footer = `Use the context as your learned knowledge to better answer the user. const footer = `Use the context as your learned knowledge to better answer the user.
@@ -19,7 +18,7 @@ function createContextHandlers(req, userMessageContent) {
const queryPromises = []; const queryPromises = [];
const processedFiles = []; const processedFiles = [];
const processedIds = new Set(); const processedIds = new Set();
const jwtToken = generateShortLivedToken(req.user.id); const jwtToken = req.headers.authorization.split(' ')[1];
const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT); const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);
const query = async (file) => { const query = async (file) => {

View File

@@ -237,9 +237,41 @@ const formatAgentMessages = (payload) => {
return messages; return messages;
}; };
/**
* Formats an array of messages for LangChain, making sure all content fields are strings
* @param {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} payload - The array of messages to format.
* @returns {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} - The array of formatted LangChain messages, including ToolMessages for tool calls.
*/
const formatContentStrings = (payload) => {
const messages = [];
for (const message of payload) {
if (typeof message.content === 'string') {
continue;
}
if (!Array.isArray(message.content)) {
continue;
}
// Reduce text types to a single string, ignore all other types
const content = message.content.reduce((acc, curr) => {
if (curr.type === ContentTypes.TEXT) {
return `${acc}${curr[ContentTypes.TEXT]}\n`;
}
return acc;
}, '');
message.content = content.trim();
}
return messages;
};
module.exports = { module.exports = {
formatMessage, formatMessage,
formatFromLangChain, formatFromLangChain,
formatAgentMessages, formatAgentMessages,
formatContentStrings,
formatLangChainMessages, formatLangChainMessages,
}; };

View File

@@ -422,46 +422,6 @@ describe('BaseClient', () => {
expect(response).toEqual(expectedResult); expect(response).toEqual(expectedResult);
}); });
test('should replace responseMessageId with new UUID when isRegenerate is true and messageId ends with underscore', async () => {
const mockCrypto = require('crypto');
const newUUID = 'new-uuid-1234';
jest.spyOn(mockCrypto, 'randomUUID').mockReturnValue(newUUID);
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe(newUUID);
expect(TestClient.responseMessageId).not.toBe('existing-message-id_');
mockCrypto.randomUUID.mockRestore();
});
test('should not replace responseMessageId when isRegenerate is false', async () => {
const opts = {
isRegenerate: false,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id_');
});
test('should not replace responseMessageId when it does not end with underscore', async () => {
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id');
});
test('sendMessage should work with provided conversationId and parentMessageId', async () => { test('sendMessage should work with provided conversationId and parentMessageId', async () => {
const userMessage = 'Second message in the conversation'; const userMessage = 'Second message in the conversation';
const opts = { const opts = {

View File

@@ -1,35 +1,26 @@
const { z } = require('zod'); const { z } = require('zod');
const axios = require('axios'); const axios = require('axios');
const { tool } = require('@langchain/core/tools'); const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Tools, EToolResources } = require('librechat-data-provider'); const { Tools, EToolResources } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { getFiles } = require('~/models/File'); const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
/** /**
* *
* @param {Object} options * @param {Object} options
* @param {ServerRequest} options.req * @param {ServerRequest} options.req
* @param {Agent['tool_resources']} options.tool_resources * @param {Agent['tool_resources']} options.tool_resources
* @param {string} [options.agentId] - The agent ID for file access control
* @returns {Promise<{ * @returns {Promise<{
* files: Array<{ file_id: string; filename: string }>, * files: Array<{ file_id: string; filename: string }>,
* toolContext: string * toolContext: string
* }>} * }>}
*/ */
const primeFiles = async (options) => { const primeFiles = async (options) => {
const { tool_resources, req, agentId } = options; const { tool_resources } = options;
const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? []; const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
const agentResourceIds = new Set(file_ids); const agentResourceIds = new Set(file_ids);
const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? []; const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
const dbFiles = ( const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
(await getFiles(
{ file_id: { $in: file_ids } },
null,
{ text: 0 },
{ userId: req?.user?.id, agentId },
)) ?? []
).concat(resourceFiles);
let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`; let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;
@@ -68,7 +59,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
if (files.length === 0) { if (files.length === 0) {
return 'No files to search. Instruct the user to add files for the search.'; return 'No files to search. Instruct the user to add files for the search.';
} }
const jwtToken = generateShortLivedToken(req.user.id); const jwtToken = req.headers.authorization.split(' ')[1];
if (!jwtToken) { if (!jwtToken) {
return 'There was an error authenticating the file search request.'; return 'There was an error authenticating the file search request.';
} }

View File

@@ -1,9 +1,14 @@
const { mcpToolPattern } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { SerpAPI } = require('@langchain/community/tools/serpapi'); const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator'); const { Calculator } = require('@langchain/community/tools/calculator');
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents'); const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider'); const {
Tools,
EToolResources,
loadWebSearchAuth,
replaceSpecialVars,
} = require('librechat-data-provider');
const { const {
availableTools, availableTools,
manifestToolMap, manifestToolMap,
@@ -240,13 +245,7 @@ const loadTools = async ({
authFields: [EnvVar.CODE_API_KEY], authFields: [EnvVar.CODE_API_KEY],
}); });
const codeApiKey = authValues[EnvVar.CODE_API_KEY]; const codeApiKey = authValues[EnvVar.CODE_API_KEY];
const { files, toolContext } = await primeCodeFiles( const { files, toolContext } = await primeCodeFiles(options, codeApiKey);
{
...options,
agentId: agent?.id,
},
codeApiKey,
);
if (toolContext) { if (toolContext) {
toolContextMap[tool] = toolContext; toolContextMap[tool] = toolContext;
} }
@@ -261,10 +260,7 @@ const loadTools = async ({
continue; continue;
} else if (tool === Tools.file_search) { } else if (tool === Tools.file_search) {
requestedTools[tool] = async () => { requestedTools[tool] = async () => {
const { files, toolContext } = await primeSearchFiles({ const { files, toolContext } = await primeSearchFiles(options);
...options,
agentId: agent?.id,
});
if (toolContext) { if (toolContext) {
toolContextMap[tool] = toolContext; toolContextMap[tool] = toolContext;
} }

View File

@@ -1,33 +0,0 @@
const fs = require('fs');
const { math, isEnabled } = require('@librechat/api');
// To ensure that different deployments do not interfere with each other's cache, we use a prefix for the Redis keys.
// This prefix is usually the deployment ID, which is often passed to the container or pod as an env var.
// Set REDIS_KEY_PREFIX_VAR to the env var that contains the deployment ID.
const REDIS_KEY_PREFIX_VAR = process.env.REDIS_KEY_PREFIX_VAR;
const REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX;
if (REDIS_KEY_PREFIX_VAR && REDIS_KEY_PREFIX) {
throw new Error('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
}
const USE_REDIS = isEnabled(process.env.USE_REDIS);
if (USE_REDIS && !process.env.REDIS_URI) {
throw new Error('USE_REDIS is enabled but REDIS_URI is not set.');
}
const cacheConfig = {
USE_REDIS,
REDIS_URI: process.env.REDIS_URI,
REDIS_USERNAME: process.env.REDIS_USERNAME,
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
CI: isEnabled(process.env.CI),
DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
BAN_DURATION: math(process.env.BAN_DURATION, 7200000), // 2 hours
};
module.exports = { cacheConfig };

View File

@@ -1,108 +0,0 @@
const fs = require('fs');
describe('cacheConfig', () => {
let originalEnv;
let originalReadFileSync;
beforeEach(() => {
originalEnv = { ...process.env };
originalReadFileSync = fs.readFileSync;
// Clear all related env vars first
delete process.env.REDIS_URI;
delete process.env.REDIS_CA;
delete process.env.REDIS_KEY_PREFIX_VAR;
delete process.env.REDIS_KEY_PREFIX;
delete process.env.USE_REDIS;
// Clear require cache
jest.resetModules();
});
afterEach(() => {
process.env = originalEnv;
fs.readFileSync = originalReadFileSync;
jest.resetModules();
});
describe('REDIS_KEY_PREFIX validation and resolution', () => {
test('should throw error when both REDIS_KEY_PREFIX_VAR and REDIS_KEY_PREFIX are set', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.REDIS_KEY_PREFIX = 'manual-prefix';
expect(() => {
require('./cacheConfig');
}).toThrow('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
});
test('should resolve REDIS_KEY_PREFIX from variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.DEPLOYMENT_ID = 'test-deployment-123';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('test-deployment-123');
});
test('should use direct REDIS_KEY_PREFIX value', () => {
process.env.REDIS_KEY_PREFIX = 'direct-prefix';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('direct-prefix');
});
test('should default to empty string when no prefix is configured', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle empty variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'EMPTY_VAR';
process.env.EMPTY_VAR = '';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle undefined variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'UNDEFINED_VAR';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
});
describe('USE_REDIS and REDIS_URI validation', () => {
test('should throw error when USE_REDIS is enabled but REDIS_URI is not set', () => {
process.env.USE_REDIS = 'true';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
test('should not throw error when USE_REDIS is enabled and REDIS_URI is set', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
expect(() => {
require('./cacheConfig');
}).not.toThrow();
});
test('should handle empty REDIS_URI when USE_REDIS is enabled', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = '';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
});
describe('REDIS_CA file reading', () => {
test('should be null when REDIS_CA is not set', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_CA).toBeNull();
});
});
});

View File

@@ -1,66 +0,0 @@
const KeyvRedis = require('@keyv/redis').default;
const { Keyv } = require('keyv');
const { cacheConfig } = require('./cacheConfig');
const { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } = require('./redisClients');
const { Time } = require('librechat-data-provider');
const { RedisStore: ConnectRedis } = require('connect-redis');
const MemoryStore = require('memorystore')(require('express-session'));
const { violationFile } = require('./keyvFiles');
const { RedisStore } = require('rate-limit-redis');
/**
* Creates a cache instance using Redis or a fallback store. Suitable for general caching needs.
* @param {string} namespace - The cache namespace.
* @param {number} [ttl] - Time to live for cache entries.
* @param {object} [fallbackStore] - Optional fallback store if Redis is not used.
* @returns {Keyv} Cache instance.
*/
const standardCache = (namespace, ttl = undefined, fallbackStore = undefined) => {
if (cacheConfig.USE_REDIS) {
const keyvRedis = new KeyvRedis(keyvRedisClient);
const cache = new Keyv(keyvRedis, { namespace, ttl });
keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
return cache;
}
if (fallbackStore) return new Keyv({ store: fallbackStore, namespace, ttl });
return new Keyv({ namespace, ttl });
};
/**
* Creates a cache instance for storing violation data.
* Uses a file-based fallback store if Redis is not enabled.
* @param {string} namespace - The cache namespace for violations.
* @param {number} [ttl] - Time to live for cache entries.
* @returns {Keyv} Cache instance for violations.
*/
const violationCache = (namespace, ttl = undefined) => {
return standardCache(`violations:${namespace}`, ttl, violationFile);
};
/**
* Creates a session cache instance using Redis or in-memory store.
* @param {string} namespace - The session namespace.
* @param {number} [ttl] - Time to live for session entries.
* @returns {MemoryStore | ConnectRedis} Session store instance.
*/
const sessionCache = (namespace, ttl = undefined) => {
namespace = namespace.endsWith(':') ? namespace : `${namespace}:`;
if (!cacheConfig.USE_REDIS) return new MemoryStore({ ttl, checkPeriod: Time.ONE_DAY });
return new ConnectRedis({ client: ioredisClient, ttl, prefix: namespace });
};
/**
* Creates a rate limiter cache using Redis.
* @param {string} prefix - The key prefix for rate limiting.
* @returns {RedisStore|undefined} RedisStore instance or undefined if Redis is not used.
*/
const limiterCache = (prefix) => {
if (!prefix) throw new Error('prefix is required');
if (!cacheConfig.USE_REDIS) return undefined;
prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;
return new RedisStore({ sendCommand, prefix });
};
const sendCommand = (...args) => ioredisClient?.call(...args);
module.exports = { standardCache, sessionCache, violationCache, limiterCache };

View File

@@ -1,270 +0,0 @@
const { Time } = require('librechat-data-provider');
// Mock dependencies first
const mockKeyvRedis = {
namespace: '',
keyPrefixSeparator: '',
};
const mockKeyv = jest.fn().mockReturnValue({ mock: 'keyv' });
const mockConnectRedis = jest.fn().mockReturnValue({ mock: 'connectRedis' });
const mockMemoryStore = jest.fn().mockReturnValue({ mock: 'memoryStore' });
const mockRedisStore = jest.fn().mockReturnValue({ mock: 'redisStore' });
const mockIoredisClient = {
call: jest.fn(),
};
const mockKeyvRedisClient = {};
const mockViolationFile = {};
// Mock modules before requiring the main module
jest.mock('@keyv/redis', () => ({
default: jest.fn().mockImplementation(() => mockKeyvRedis),
}));
jest.mock('keyv', () => ({
Keyv: mockKeyv,
}));
jest.mock('./cacheConfig', () => ({
cacheConfig: {
USE_REDIS: false,
REDIS_KEY_PREFIX: 'test',
},
}));
jest.mock('./redisClients', () => ({
keyvRedisClient: mockKeyvRedisClient,
ioredisClient: mockIoredisClient,
GLOBAL_PREFIX_SEPARATOR: '::',
}));
jest.mock('./keyvFiles', () => ({
violationFile: mockViolationFile,
}));
jest.mock('connect-redis', () => ({ RedisStore: mockConnectRedis }));
jest.mock('memorystore', () => jest.fn(() => mockMemoryStore));
jest.mock('rate-limit-redis', () => ({
RedisStore: mockRedisStore,
}));
// Import after mocking
const { standardCache, sessionCache, violationCache, limiterCache } = require('./cacheFactory');
const { cacheConfig } = require('./cacheConfig');
describe('cacheFactory', () => {
beforeEach(() => {
jest.clearAllMocks();
// Reset cache config mock
cacheConfig.USE_REDIS = false;
cacheConfig.REDIS_KEY_PREFIX = 'test';
});
describe('redisCache', () => {
it('should create Redis cache when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
expect(mockKeyvRedis.namespace).toBe(cacheConfig.REDIS_KEY_PREFIX);
expect(mockKeyvRedis.keyPrefixSeparator).toBe('::');
});
it('should create Redis cache with undefined ttl when not provided', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
standardCache(namespace);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl: undefined });
});
it('should use fallback store when USE_REDIS is false and fallbackStore is provided', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
const fallbackStore = { some: 'store' };
standardCache(namespace, ttl, fallbackStore);
expect(mockKeyv).toHaveBeenCalledWith({ store: fallbackStore, namespace, ttl });
});
it('should create default Keyv instance when USE_REDIS is false and no fallbackStore', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(mockKeyv).toHaveBeenCalledWith({ namespace, ttl });
});
it('should handle namespace and ttl as undefined', () => {
cacheConfig.USE_REDIS = false;
standardCache();
expect(mockKeyv).toHaveBeenCalledWith({ namespace: undefined, ttl: undefined });
});
});
describe('violationCache', () => {
it('should create violation cache with prefixed namespace', () => {
const namespace = 'test-violations';
const ttl = 7200;
// We can't easily mock the internal redisCache call since it's in the same module
// But we can test that the function executes without throwing
expect(() => violationCache(namespace, ttl)).not.toThrow();
});
it('should create violation cache with undefined ttl', () => {
const namespace = 'test-violations';
violationCache(namespace);
// The function should call redisCache with violations: prefixed namespace
// Since we can't easily mock the internal redisCache call, we test the behavior
expect(() => violationCache(namespace)).not.toThrow();
});
it('should handle undefined namespace', () => {
expect(() => violationCache(undefined)).not.toThrow();
});
});
describe('sessionCache', () => {
it('should return MemoryStore when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockMemoryStore).toHaveBeenCalledWith({ ttl, checkPeriod: Time.ONE_DAY });
expect(result).toBe(mockMemoryStore());
});
it('should return ConnectRedis when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl,
prefix: `${namespace}:`,
});
expect(result).toBe(mockConnectRedis());
});
it('should add colon to namespace if not present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should not add colon to namespace if already present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions:';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should handle undefined ttl', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockMemoryStore).toHaveBeenCalledWith({
ttl: undefined,
checkPeriod: Time.ONE_DAY,
});
});
});
describe('limiterCache', () => {
it('should return undefined when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const result = limiterCache('prefix');
expect(result).toBeUndefined();
});
it('should return RedisStore when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const result = limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: `rate-limit:`,
});
expect(result).toBe(mockRedisStore());
});
it('should add colon to prefix if not present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should not add colon to prefix if already present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit:');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should pass sendCommand function that calls ioredisClient.call', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
const sendCommandCall = mockRedisStore.mock.calls[0][0];
const sendCommand = sendCommandCall.sendCommand;
// Test that sendCommand properly delegates to ioredisClient.call
const args = ['GET', 'test-key'];
sendCommand(...args);
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
});
it('should handle undefined prefix', () => {
cacheConfig.USE_REDIS = true;
expect(() => limiterCache()).toThrow('prefix is required');
});
});
});

View File

@@ -1,52 +1,113 @@
const { cacheConfig } = require('./cacheConfig');
const { Keyv } = require('keyv'); const { Keyv } = require('keyv');
const { isEnabled, math } = require('@librechat/api');
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider'); const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
const { logFile } = require('./keyvFiles'); const { logFile, violationFile } = require('./keyvFiles');
const keyvRedis = require('./keyvRedis');
const keyvMongo = require('./keyvMongo'); const keyvMongo = require('./keyvMongo');
const { standardCache, sessionCache, violationCache } = require('./cacheFactory');
const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE, CI } = process.env ?? {};
const duration = math(BAN_DURATION, 7200000);
const isRedisEnabled = isEnabled(USE_REDIS);
const debugMemoryCache = isEnabled(DEBUG_MEMORY_CACHE);
const createViolationInstance = (namespace) => {
const config = isRedisEnabled ? { store: keyvRedis } : { store: violationFile, namespace };
return new Keyv(config);
};
// Serve cache from memory so no need to clear it on startup/exit
const pending_req = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.PENDING_REQ });
const config = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
const roles = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ROLES });
const mcpTools = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.MCP_TOOLS });
const audioRuns = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
const messages = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.ONE_MINUTE })
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.ONE_MINUTE });
const flows = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.FLOWS, ttl: Time.ONE_MINUTE * 3 });
const tokenConfig = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
const genTitle = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
const s3ExpiryInterval = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });
const modelQueries = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
const abortKeys = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
const openIdExchangedTokensCache = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.OPENID_EXCHANGED_TOKENS, ttl: Time.TEN_MINUTES });
const namespaces = { const namespaces = {
[ViolationTypes.GENERAL]: new Keyv({ store: logFile, namespace: 'violations' }), [CacheKeys.ROLES]: roles,
[ViolationTypes.LOGINS]: violationCache(ViolationTypes.LOGINS), [CacheKeys.MCP_TOOLS]: mcpTools,
[ViolationTypes.CONCURRENT]: violationCache(ViolationTypes.CONCURRENT), [CacheKeys.CONFIG_STORE]: config,
[ViolationTypes.NON_BROWSER]: violationCache(ViolationTypes.NON_BROWSER), [CacheKeys.PENDING_REQ]: pending_req,
[ViolationTypes.MESSAGE_LIMIT]: violationCache(ViolationTypes.MESSAGE_LIMIT), [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
[ViolationTypes.REGISTRATIONS]: violationCache(ViolationTypes.REGISTRATIONS), [CacheKeys.ENCODED_DOMAINS]: new Keyv({
[ViolationTypes.TOKEN_BALANCE]: violationCache(ViolationTypes.TOKEN_BALANCE),
[ViolationTypes.TTS_LIMIT]: violationCache(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: violationCache(ViolationTypes.STT_LIMIT),
[ViolationTypes.CONVO_ACCESS]: violationCache(ViolationTypes.CONVO_ACCESS),
[ViolationTypes.TOOL_CALL_LIMIT]: violationCache(ViolationTypes.TOOL_CALL_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: violationCache(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.VERIFY_EMAIL_LIMIT]: violationCache(ViolationTypes.VERIFY_EMAIL_LIMIT),
[ViolationTypes.RESET_PASSWORD_LIMIT]: violationCache(ViolationTypes.RESET_PASSWORD_LIMIT),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: violationCache(ViolationTypes.ILLEGAL_MODEL_REQUEST),
[ViolationTypes.BAN]: new Keyv({
store: keyvMongo, store: keyvMongo,
namespace: CacheKeys.BANS, namespace: CacheKeys.ENCODED_DOMAINS,
ttl: cacheConfig.BAN_DURATION, ttl: 0,
}), }),
general: new Keyv({ store: logFile, namespace: 'violations' }),
[CacheKeys.OPENID_SESSION]: sessionCache(CacheKeys.OPENID_SESSION), concurrent: createViolationInstance('concurrent'),
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION), non_browser: createViolationInstance('non_browser'),
message_limit: createViolationInstance('message_limit'),
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES), token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
[CacheKeys.MCP_TOOLS]: standardCache(CacheKeys.MCP_TOOLS), registrations: createViolationInstance('registrations'),
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE), [ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ), [ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }), [ViolationTypes.CONVO_ACCESS]: createViolationInstance(ViolationTypes.CONVO_ACCESS),
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES), [ViolationTypes.TOOL_CALL_LIMIT]: createViolationInstance(ViolationTypes.TOOL_CALL_LIMIT),
[CacheKeys.TOKEN_CONFIG]: standardCache(CacheKeys.TOKEN_CONFIG, Time.THIRTY_MINUTES), [ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[CacheKeys.GEN_TITLE]: standardCache(CacheKeys.GEN_TITLE, Time.TWO_MINUTES), [ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
[CacheKeys.S3_EXPIRY_INTERVAL]: standardCache(CacheKeys.S3_EXPIRY_INTERVAL, Time.THIRTY_MINUTES), [ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
[CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES), ViolationTypes.RESET_PASSWORD_LIMIT,
[CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES),
[CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE),
[CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 3),
[CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache(
CacheKeys.OPENID_EXCHANGED_TOKENS,
Time.TEN_MINUTES,
), ),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
ViolationTypes.ILLEGAL_MODEL_REQUEST,
),
logins: createViolationInstance('logins'),
[CacheKeys.ABORT_KEYS]: abortKeys,
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
[CacheKeys.GEN_TITLE]: genTitle,
[CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
[CacheKeys.MODEL_QUERIES]: modelQueries,
[CacheKeys.AUDIO_RUNS]: audioRuns,
[CacheKeys.MESSAGES]: messages,
[CacheKeys.FLOWS]: flows,
[CacheKeys.OPENID_EXCHANGED_TOKENS]: openIdExchangedTokensCache,
}; };
/** /**
@@ -55,10 +116,7 @@ const namespaces = {
*/ */
function getTTLStores() { function getTTLStores() {
return Object.values(namespaces).filter( return Object.values(namespaces).filter(
(store) => (store) => store instanceof Keyv && typeof store.opts?.ttl === 'number' && store.opts.ttl > 0,
store instanceof Keyv &&
parseInt(store.opts?.ttl ?? '0') > 0 &&
!store.opts?.store?.constructor?.name?.includes('Redis'), // Only include non-Redis stores
); );
} }
@@ -94,18 +152,18 @@ async function clearExpiredFromCache(cache) {
if (data?.expires && data.expires <= expiryTime) { if (data?.expires && data.expires <= expiryTime) {
const deleted = await cache.opts.store.delete(key); const deleted = await cache.opts.store.delete(key);
if (!deleted) { if (!deleted) {
cacheConfig.DEBUG_MEMORY_CACHE && debugMemoryCache &&
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`); console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
continue; continue;
} }
cleared++; cleared++;
} }
} catch (error) { } catch (error) {
cacheConfig.DEBUG_MEMORY_CACHE && debugMemoryCache &&
console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error); console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
const deleted = await cache.opts.store.delete(key); const deleted = await cache.opts.store.delete(key);
if (!deleted) { if (!deleted) {
cacheConfig.DEBUG_MEMORY_CACHE && debugMemoryCache &&
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`); console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
continue; continue;
} }
@@ -114,7 +172,7 @@ async function clearExpiredFromCache(cache) {
} }
if (cleared > 0) { if (cleared > 0) {
cacheConfig.DEBUG_MEMORY_CACHE && debugMemoryCache &&
console.log( console.log(
`[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`, `[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`,
); );
@@ -155,7 +213,7 @@ async function clearAllExpiredFromCache() {
} }
} }
if (!cacheConfig.USE_REDIS && !cacheConfig.CI) { if (!isRedisEnabled && !isEnabled(CI)) {
/** @type {Set<NodeJS.Timeout>} */ /** @type {Set<NodeJS.Timeout>} */
const cleanupIntervals = new Set(); const cleanupIntervals = new Set();
@@ -166,7 +224,7 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
cleanupIntervals.add(cleanup); cleanupIntervals.add(cleanup);
if (cacheConfig.DEBUG_MEMORY_CACHE) { if (debugMemoryCache) {
const monitor = setInterval(() => { const monitor = setInterval(() => {
const ttlStores = getTTLStores(); const ttlStores = getTTLStores();
const memory = process.memoryUsage(); const memory = process.memoryUsage();
@@ -187,13 +245,13 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
} }
const dispose = () => { const dispose = () => {
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Cleaning up and shutting down...'); debugMemoryCache && console.log('[Cache] Cleaning up and shutting down...');
cleanupIntervals.forEach((interval) => clearInterval(interval)); cleanupIntervals.forEach((interval) => clearInterval(interval));
cleanupIntervals.clear(); cleanupIntervals.clear();
// One final cleanup before exit // One final cleanup before exit
clearAllExpiredFromCache().then(() => { clearAllExpiredFromCache().then(() => {
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Final cleanup completed'); debugMemoryCache && console.log('[Cache] Final cleanup completed');
process.exit(0); process.exit(0);
}); });
}; };

92
api/cache/ioredisClient.js vendored Normal file
View File

@@ -0,0 +1,92 @@
const fs = require('fs');
const Redis = require('ioredis');
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_MAX_LISTENERS } = process.env;
/** @type {import('ioredis').Redis | import('ioredis').Cluster} */
let ioredisClient;
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
var value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
ioredisClient = new Redis.Cluster(hosts, { redisOptions });
} else {
ioredisClient = new Redis(REDIS_URI, redisOptions);
}
ioredisClient.on('ready', () => {
logger.info('IoRedis connection ready');
});
ioredisClient.on('reconnecting', () => {
logger.info('IoRedis connection reconnecting');
});
ioredisClient.on('end', () => {
logger.info('IoRedis connection ended');
});
ioredisClient.on('close', () => {
logger.info('IoRedis connection closed');
});
ioredisClient.on('error', (err) => logger.error('IoRedis connection error:', err));
ioredisClient.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] IoRedis initialized for rate limiters. If you have issues, disable Redis or restart the server.',
);
} else {
logger.info('[Optional] IoRedis not initialized for rate limiters.');
}
module.exports = ioredisClient;

109
api/cache/keyvRedis.js vendored Normal file
View File

@@ -0,0 +1,109 @@
const fs = require('fs');
const ioredis = require('ioredis');
const KeyvRedis = require('@keyv/redis').default;
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } =
process.env;
let keyvRedis;
const redis_prefix = REDIS_KEY_PREFIX || '';
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
/** @type {import('@keyv/redis').KeyvRedisOptions} */
let keyvOpts = {
useRedisSets: false,
keyPrefix: redis_prefix,
};
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
var value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
const cluster = new ioredis.Cluster(hosts, { redisOptions });
keyvRedis = new KeyvRedis(cluster, keyvOpts);
} else {
keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts);
}
const pingInterval = setInterval(
() => {
logger.debug('KeyvRedis ping');
keyvRedis.client.ping().catch((err) => logger.error('Redis keep-alive ping failed:', err));
},
5 * 60 * 1000,
);
keyvRedis.on('ready', () => {
logger.info('KeyvRedis connection ready');
});
keyvRedis.on('reconnecting', () => {
logger.info('KeyvRedis connection reconnecting');
});
keyvRedis.on('end', () => {
logger.info('KeyvRedis connection ended');
});
keyvRedis.on('close', () => {
clearInterval(pingInterval);
logger.info('KeyvRedis connection closed');
});
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
keyvRedis.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.',
);
} else {
logger.info('[Optional] Redis not initialized.');
}
module.exports = keyvRedis;

View File

@@ -1,5 +1,4 @@
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const { ViolationTypes } = require('librechat-data-provider');
const getLogStores = require('./getLogStores'); const getLogStores = require('./getLogStores');
const banViolation = require('./banViolation'); const banViolation = require('./banViolation');
@@ -10,14 +9,14 @@ const banViolation = require('./banViolation');
* @param {Object} res - Express response object. * @param {Object} res - Express response object.
* @param {string} type - The type of violation. * @param {string} type - The type of violation.
* @param {Object} errorMessage - The error message to log. * @param {Object} errorMessage - The error message to log.
* @param {number | string} [score=1] - The severity of the violation. Defaults to 1 * @param {number} [score=1] - The severity of the violation. Defaults to 1
*/ */
const logViolation = async (req, res, type, errorMessage, score = 1) => { const logViolation = async (req, res, type, errorMessage, score = 1) => {
const userId = req.user?.id ?? req.user?._id; const userId = req.user?.id ?? req.user?._id;
if (!userId) { if (!userId) {
return; return;
} }
const logs = getLogStores(ViolationTypes.GENERAL); const logs = getLogStores('general');
const violationLogs = getLogStores(type); const violationLogs = getLogStores(type);
const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId; const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;

View File

@@ -1,57 +0,0 @@
const IoRedis = require('ioredis');
const { cacheConfig } = require('./cacheConfig');
const { createClient, createCluster } = require('@keyv/redis');
const GLOBAL_PREFIX_SEPARATOR = '::';
const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri));
const username = urls?.[0].username || cacheConfig.REDIS_USERNAME;
const password = urls?.[0].password || cacheConfig.REDIS_PASSWORD;
const ca = cacheConfig.REDIS_CA;
/** @type {import('ioredis').Redis | import('ioredis').Cluster | null} */
let ioredisClient = null;
if (cacheConfig.USE_REDIS) {
const redisOptions = {
username: username,
password: password,
tls: ca ? { ca } : undefined,
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
};
ioredisClient =
urls.length === 1
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
: new IoRedis.Cluster(cacheConfig.REDIS_URI, { redisOptions });
// Pinging the Redis server every 5 minutes to keep the connection alive
const pingInterval = setInterval(() => ioredisClient.ping(), 5 * 60 * 1000);
ioredisClient.on('close', () => clearInterval(pingInterval));
ioredisClient.on('end', () => clearInterval(pingInterval));
}
/** @type {import('@keyv/redis').RedisClient | import('@keyv/redis').RedisCluster | null} */
let keyvRedisClient = null;
if (cacheConfig.USE_REDIS) {
// ** WARNING ** Keyv Redis client does not support Prefix like ioredis above.
// The prefix feature will be handled by the Keyv-Redis store in cacheFactory.js
const redisOptions = { username, password, socket: { tls: ca != null, ca } };
keyvRedisClient =
urls.length === 1
? createClient({ url: cacheConfig.REDIS_URI, ...redisOptions })
: createCluster({
rootNodes: cacheConfig.REDIS_URI.split(',').map((url) => ({ url })),
defaults: redisOptions,
});
keyvRedisClient.setMaxListeners(cacheConfig.REDIS_MAX_LISTENERS);
// Pinging the Redis server every 5 minutes to keep the connection alive
const keyvPingInterval = setInterval(() => keyvRedisClient.ping(), 5 * 60 * 1000);
keyvRedisClient.on('disconnect', () => clearInterval(keyvPingInterval));
keyvRedisClient.on('end', () => clearInterval(keyvPingInterval));
}
module.exports = { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR };

View File

@@ -90,7 +90,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
} }
const instructions = req.body.promptPrefix; const instructions = req.body.promptPrefix;
const result = { return {
id: agent_id, id: agent_id,
instructions, instructions,
provider: endpoint, provider: endpoint,
@@ -98,11 +98,6 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
model, model,
tools, tools,
}; };
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
result.artifacts = ephemeralAgent.artifacts;
}
return result;
}; };
/** /**

View File

@@ -1,6 +1,6 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api'); const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig'); const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
const { getMessages, deleteMessages } = require('./Message'); const { getMessages, deleteMessages } = require('./Message');
const { Conversation } = require('~/db/models'); const { Conversation } = require('~/db/models');

View File

@@ -1,7 +1,5 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { EToolResources, FileContext, Constants } = require('librechat-data-provider'); const { EToolResources, FileContext } = require('librechat-data-provider');
const { getProjectByName } = require('./Project');
const { getAgent } = require('./Agent');
const { File } = require('~/db/models'); const { File } = require('~/db/models');
/** /**
@@ -14,124 +12,17 @@ const findFileById = async (file_id, options = {}) => {
return await File.findOne({ file_id, ...options }).lean(); return await File.findOne({ file_id, ...options }).lean();
}; };
/**
* Checks if a user has access to multiple files through a shared agent (batch operation)
* @param {string} userId - The user ID to check access for
* @param {string[]} fileIds - Array of file IDs to check
* @param {string} agentId - The agent ID that might grant access
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
*/
const hasAccessToFilesViaAgent = async (userId, fileIds, agentId, checkCollaborative = true) => {
const accessMap = new Map();
// Initialize all files as no access
fileIds.forEach((fileId) => accessMap.set(fileId, false));
try {
const agent = await getAgent({ id: agentId });
if (!agent) {
return accessMap;
}
// Check if user is the author - if so, grant access to all files
if (agent.author.toString() === userId) {
fileIds.forEach((fileId) => accessMap.set(fileId, true));
return accessMap;
}
// Check if agent is shared with the user via projects
if (!agent.projectIds || agent.projectIds.length === 0) {
return accessMap;
}
// Check if agent is in global project
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
if (
!globalProject ||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
) {
return accessMap;
}
// Agent is globally shared - check if it's collaborative
if (checkCollaborative && !agent.isCollaborative) {
return accessMap;
}
// Check which files are actually attached
const attachedFileIds = new Set();
if (agent.tool_resources) {
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
}
}
}
// Grant access only to files that are attached to this agent
fileIds.forEach((fileId) => {
if (attachedFileIds.has(fileId)) {
accessMap.set(fileId, true);
}
});
return accessMap;
} catch (error) {
logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
return accessMap;
}
};
/** /**
* Retrieves files matching a given filter, sorted by the most recently updated. * Retrieves files matching a given filter, sorted by the most recently updated.
* @param {Object} filter - The filter criteria to apply. * @param {Object} filter - The filter criteria to apply.
* @param {Object} [_sortOptions] - Optional sort parameters. * @param {Object} [_sortOptions] - Optional sort parameters.
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results. * @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
* Default excludes the 'text' field. * Default excludes the 'text' field.
* @param {Object} [options] - Additional options
* @param {string} [options.userId] - User ID for access control
* @param {string} [options.agentId] - Agent ID that might grant access to files
* @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents. * @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
*/ */
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => { const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
const sortOptions = { updatedAt: -1, ..._sortOptions }; const sortOptions = { updatedAt: -1, ..._sortOptions };
const files = await File.find(filter).select(selectFields).sort(sortOptions).lean(); return await File.find(filter).select(selectFields).sort(sortOptions).lean();
// If userId and agentId are provided, filter files based on access
if (options.userId && options.agentId) {
// Collect file IDs that need access check
const filesToCheck = [];
const ownedFiles = [];
for (const file of files) {
if (file.user && file.user.toString() === options.userId) {
ownedFiles.push(file);
} else {
filesToCheck.push(file);
}
}
if (filesToCheck.length === 0) {
return ownedFiles;
}
// Batch check access for all non-owned files
const fileIds = filesToCheck.map((f) => f.file_id);
const accessMap = await hasAccessToFilesViaAgent(
options.userId,
fileIds,
options.agentId,
false,
);
// Filter files based on access
const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
return [...ownedFiles, ...accessibleFiles];
}
return files;
}; };
/** /**
@@ -285,5 +176,4 @@ module.exports = {
deleteFiles, deleteFiles,
deleteFileByFilter, deleteFileByFilter,
batchUpdateFiles, batchUpdateFiles,
hasAccessToFilesViaAgent,
}; };

View File

@@ -1,264 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { fileSchema } = require('@librechat/data-schemas');
const { agentSchema } = require('@librechat/data-schemas');
const { projectSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { getFiles, createFile } = require('./File');
const { getProjectByName } = require('./Project');
const { createAgent } = require('./Agent');
let File;
let Agent;
let Project;
describe('File Access Control', () => {
let mongoServer;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
File = mongoose.models.File || mongoose.model('File', fileSchema);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await File.deleteMany({});
await Agent.deleteMany({});
await Project.deleteMany({});
});
describe('hasAccessToFilesViaAgent', () => {
it('should efficiently check access for multiple files at once', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
// Create files
for (const fileId of fileIds) {
await createFile({
user: authorId,
file_id: fileId,
filename: `file-${fileId}.txt`,
filepath: `/uploads/${fileId}`,
});
}
// Create agent with only first two files attached
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileIds[0], fileIds[1]],
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for all files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have access only to the first two files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(false);
expect(accessMap.get(fileIds[3])).toBe(false);
});
it('should grant access to all files when user is the agent author', async () => {
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
// Create agent
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: [fileIds[0]], // Only one file attached
},
},
});
// Check access as the author
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);
// Author should have access to all files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(true);
});
it('should handle non-existent agent gracefully', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const fileIds = [uuidv4(), uuidv4()];
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');
// Should have no access to any files
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should deny access when agent is not collaborative', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4()];
// Create agent with files but isCollaborative: false
await createAgent({
id: agentId,
name: 'Non-Collaborative Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: false,
tool_resources: {
file_search: {
file_ids: fileIds,
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have no access to any files when isCollaborative is false
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
});
describe('getFiles with agent access control', () => {
test('should return files owned by user and files accessible through agent', async () => {
const authorId = new mongoose.Types.ObjectId();
const userId = new mongoose.Types.ObjectId();
const agentId = `agent_${uuidv4()}`;
const ownedFileId = `file_${uuidv4()}`;
const sharedFileId = `file_${uuidv4()}`;
const inaccessibleFileId = `file_${uuidv4()}`;
// Create/get global project using getProjectByName which will upsert
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);
// Create agent with shared file
await createAgent({
id: agentId,
name: 'Shared Agent',
provider: 'test',
model: 'test-model',
author: authorId,
projectIds: [globalProject._id],
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [sharedFileId],
},
},
});
// Create files
await createFile({
file_id: ownedFileId,
user: userId,
filename: 'owned.txt',
filepath: '/uploads/owned.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: sharedFileId,
user: authorId,
filename: 'shared.txt',
filepath: '/uploads/shared.txt',
type: 'text/plain',
bytes: 200,
embedded: true,
});
await createFile({
file_id: inaccessibleFileId,
user: authorId,
filename: 'inaccessible.txt',
filepath: '/uploads/inaccessible.txt',
type: 'text/plain',
bytes: 300,
});
// Get files with access control
const files = await getFiles(
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
null,
{ text: 0 },
{ userId: userId.toString(), agentId },
);
expect(files).toHaveLength(2);
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
});
test('should return all files when no userId/agentId provided', async () => {
const userId = new mongoose.Types.ObjectId();
const fileId1 = `file_${uuidv4()}`;
const fileId2 = `file_${uuidv4()}`;
await createFile({
file_id: fileId1,
user: userId,
filename: 'file1.txt',
filepath: '/uploads/file1.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: fileId2,
user: new mongoose.Types.ObjectId(),
filename: 'file2.txt',
filepath: '/uploads/file2.txt',
type: 'text/plain',
bytes: 200,
});
const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
expect(files).toHaveLength(2);
});
});
});

View File

@@ -1,7 +1,7 @@
const { z } = require('zod'); const { z } = require('zod');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api'); const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig'); const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
const { Message } = require('~/db/models'); const { Message } = require('~/db/models');
const idSchema = z.string().uuid(); const idSchema = z.string().uuid();

View File

@@ -135,11 +135,10 @@ const tokenValues = Object.assign(
'grok-2-1212': { prompt: 2.0, completion: 10.0 }, 'grok-2-1212': { prompt: 2.0, completion: 10.0 },
'grok-2-latest': { prompt: 2.0, completion: 10.0 }, 'grok-2-latest': { prompt: 2.0, completion: 10.0 },
'grok-2': { prompt: 2.0, completion: 10.0 }, 'grok-2': { prompt: 2.0, completion: 10.0 },
'grok-3-mini-fast': { prompt: 0.6, completion: 4 }, 'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
'grok-3-mini': { prompt: 0.3, completion: 0.5 }, 'grok-3-mini': { prompt: 0.3, completion: 0.5 },
'grok-3-fast': { prompt: 5.0, completion: 25.0 }, 'grok-3-fast': { prompt: 5.0, completion: 25.0 },
'grok-3': { prompt: 3.0, completion: 15.0 }, 'grok-3': { prompt: 3.0, completion: 15.0 },
'grok-4': { prompt: 3.0, completion: 15.0 },
'grok-beta': { prompt: 5.0, completion: 15.0 }, 'grok-beta': { prompt: 5.0, completion: 15.0 },
'mistral-large': { prompt: 2.0, completion: 6.0 }, 'mistral-large': { prompt: 2.0, completion: 6.0 },
'pixtral-large': { prompt: 2.0, completion: 6.0 }, 'pixtral-large': { prompt: 2.0, completion: 6.0 },

View File

@@ -636,15 +636,6 @@ describe('Grok Model Tests - Pricing', () => {
); );
}); });
test('should return correct prompt and completion rates for Grok 4 model', () => {
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => { test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe( expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
tokenValues['grok-3'].prompt, tokenValues['grok-3'].prompt,
@@ -671,15 +662,6 @@ describe('Grok Model Tests - Pricing', () => {
tokenValues['grok-3-mini-fast'].completion, tokenValues['grok-3-mini-fast'].completion,
); );
}); });
test('should return correct prompt and completion rates for Grok 4 model with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
}); });
}); });

View File

@@ -1,6 +1,6 @@
{ {
"name": "@librechat/backend", "name": "@librechat/backend",
"version": "v0.7.9-rc1", "version": "v0.7.8",
"description": "", "description": "",
"scripts": { "scripts": {
"start": "echo 'please run this from the root directory'", "start": "echo 'please run this from the root directory'",
@@ -44,20 +44,19 @@
"@googleapis/youtube": "^20.0.0", "@googleapis/youtube": "^20.0.0",
"@keyv/redis": "^4.3.3", "@keyv/redis": "^4.3.3",
"@langchain/community": "^0.3.47", "@langchain/community": "^0.3.47",
"@langchain/core": "^0.3.62", "@langchain/core": "^0.3.60",
"@langchain/google-genai": "^0.2.13", "@langchain/google-genai": "^0.2.13",
"@langchain/google-vertexai": "^0.2.13", "@langchain/google-vertexai": "^0.2.13",
"@langchain/openai": "^0.5.18",
"@langchain/textsplitters": "^0.1.0", "@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.63", "@librechat/agents": "^2.4.46",
"@librechat/api": "*", "@librechat/api": "*",
"@librechat/data-schemas": "*", "@librechat/data-schemas": "*",
"@node-saml/passport-saml": "^5.0.0", "@node-saml/passport-saml": "^5.0.0",
"@waylaidwanderer/fetch-event-source": "^3.0.1", "@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.8.2", "axios": "^1.8.2",
"bcryptjs": "^2.4.3", "bcryptjs": "^2.4.3",
"compression": "^1.8.1", "compression": "^1.7.4",
"connect-redis": "^8.1.0", "connect-redis": "^7.1.0",
"cookie": "^0.7.2", "cookie": "^0.7.2",
"cookie-parser": "^1.4.7", "cookie-parser": "^1.4.7",
"cors": "^2.8.5", "cors": "^2.8.5",
@@ -67,7 +66,7 @@
"express": "^4.21.2", "express": "^4.21.2",
"express-mongo-sanitize": "^2.2.0", "express-mongo-sanitize": "^2.2.0",
"express-rate-limit": "^7.4.1", "express-rate-limit": "^7.4.1",
"express-session": "^1.18.2", "express-session": "^1.18.1",
"express-static-gzip": "^2.2.0", "express-static-gzip": "^2.2.0",
"file-type": "^18.7.0", "file-type": "^18.7.0",
"firebase": "^11.0.2", "firebase": "^11.0.2",
@@ -88,7 +87,7 @@
"mime": "^3.0.0", "mime": "^3.0.0",
"module-alias": "^2.2.3", "module-alias": "^2.2.3",
"mongoose": "^8.12.1", "mongoose": "^8.12.1",
"multer": "^2.0.2", "multer": "^2.0.1",
"nanoid": "^3.3.7", "nanoid": "^3.3.7",
"node-fetch": "^2.7.0", "node-fetch": "^2.7.0",
"nodemailer": "^6.9.15", "nodemailer": "^6.9.15",

View File

@@ -24,23 +24,17 @@ const handleValidationError = (err, res) => {
} }
}; };
module.exports = (err, _req, res, _next) => { // eslint-disable-next-line no-unused-vars
module.exports = (err, req, res, next) => {
try { try {
if (err.name === 'ValidationError') { if (err.name === 'ValidationError') {
return handleValidationError(err, res); return (err = handleValidationError(err, res));
} }
if (err.code && err.code == 11000) { if (err.code && err.code == 11000) {
return handleDuplicateKeyError(err, res); return (err = handleDuplicateKeyError(err, res));
} }
// Special handling for errors like SyntaxError
if (err.statusCode && err.body) {
return res.status(err.statusCode).send(err.body);
}
logger.error('ErrorController => error', err);
return res.status(500).send('An unknown error occurred.');
} catch (err) { } catch (err) {
logger.error('ErrorController => processing error', err); logger.error('ErrorController => error', err);
return res.status(500).send('Processing error in ErrorController.'); res.status(500).send('An unknown error occurred.');
} }
}; };

View File

@@ -1,241 +0,0 @@
const errorController = require('./ErrorController');
const { logger } = require('~/config');
// Mock the logger
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
},
}));
describe('ErrorController', () => {
let mockReq, mockRes, mockNext;
beforeEach(() => {
mockReq = {};
mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn(),
};
mockNext = jest.fn();
logger.error.mockClear();
});
describe('ValidationError handling', () => {
it('should handle ValidationError with single error', () => {
const validationError = {
name: 'ValidationError',
errors: {
email: { message: 'Email is required', path: 'email' },
},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '["Email is required"]',
fields: '["email"]',
});
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
});
it('should handle ValidationError with multiple errors', () => {
const validationError = {
name: 'ValidationError',
errors: {
email: { message: 'Email is required', path: 'email' },
password: { message: 'Password is required', path: 'password' },
},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '"Email is required Password is required"',
fields: '["email","password"]',
});
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
});
it('should handle ValidationError with empty errors object', () => {
const validationError = {
name: 'ValidationError',
errors: {},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '[]',
fields: '[]',
});
});
});
describe('Duplicate key error handling', () => {
it('should handle duplicate key error (code 11000)', () => {
const duplicateKeyError = {
code: 11000,
keyValue: { email: 'test@example.com' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email"] already exists.',
fields: '["email"]',
});
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
});
it('should handle duplicate key error with multiple fields', () => {
const duplicateKeyError = {
code: 11000,
keyValue: { email: 'test@example.com', username: 'testuser' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email","username"] already exists.',
fields: '["email","username"]',
});
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
});
it('should handle error with code 11000 as string', () => {
const duplicateKeyError = {
code: '11000',
keyValue: { email: 'test@example.com' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email"] already exists.',
fields: '["email"]',
});
});
});
describe('SyntaxError handling', () => {
it('should handle errors with statusCode and body', () => {
const syntaxError = {
statusCode: 400,
body: 'Invalid JSON syntax',
};
errorController(syntaxError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
});
it('should handle errors with different statusCode and body', () => {
const customError = {
statusCode: 422,
body: { error: 'Unprocessable entity' },
};
errorController(customError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(422);
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
});
it('should handle error with statusCode but no body', () => {
const partialError = {
statusCode: 400,
};
errorController(partialError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
});
it('should handle error with body but no statusCode', () => {
const partialError = {
body: 'Some error message',
};
errorController(partialError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
});
});
describe('Unknown error handling', () => {
it('should handle unknown errors', () => {
const unknownError = new Error('Some unknown error');
errorController(unknownError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError);
});
it('should handle errors with code other than 11000', () => {
const mongoError = {
code: 11100,
message: 'Some MongoDB error',
};
errorController(mongoError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError);
});
it('should handle null/undefined errors', () => {
errorController(null, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
expect(logger.error).toHaveBeenCalledWith(
'ErrorController => processing error',
expect.any(Error),
);
});
});
describe('Catch block handling', () => {
beforeEach(() => {
// Restore logger mock to normal behavior for these tests
logger.error.mockRestore();
logger.error = jest.fn();
});
it('should handle errors when logger.error throws', () => {
// Create fresh mocks for this test
const freshMockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn(),
};
// Mock logger to throw on the first call, succeed on the second
logger.error
.mockImplementationOnce(() => {
throw new Error('Logger error');
})
.mockImplementation(() => {});
const testError = new Error('Test error');
errorController(testError, mockReq, freshMockRes, mockNext);
expect(freshMockRes.status).toHaveBeenCalledWith(500);
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
expect(logger.error).toHaveBeenCalledTimes(2);
});
});
});

View File

@@ -1,10 +1,11 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { CacheKeys, AuthType, Constants } = require('librechat-data-provider'); const { CacheKeys, AuthType } = require('librechat-data-provider');
const { getCustomConfig, getCachedTools } = require('~/server/services/Config'); const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
const { getToolkitKey } = require('~/server/services/ToolService'); const { getToolkitKey } = require('~/server/services/ToolService');
const { getMCPManager, getFlowStateManager } = require('~/config'); const { getMCPManager, getFlowStateManager } = require('~/config');
const { availableTools } = require('~/app/clients/tools'); const { availableTools } = require('~/app/clients/tools');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
const { Constants } = require('librechat-data-provider');
/** /**
* Filters out duplicate plugins from the list of plugins. * Filters out duplicate plugins from the list of plugins.
@@ -139,9 +140,9 @@ function createGetServerTools() {
const getAvailableTools = async (req, res) => { const getAvailableTools = async (req, res) => {
try { try {
const cache = getLogStores(CacheKeys.CONFIG_STORE); const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedToolsArray = await cache.get(CacheKeys.TOOLS); const cachedTools = await cache.get(CacheKeys.TOOLS);
if (cachedToolsArray) { if (cachedTools) {
res.status(200).json(cachedToolsArray); res.status(200).json(cachedTools);
return; return;
} }
@@ -172,7 +173,7 @@ const getAvailableTools = async (req, res) => {
} }
}); });
const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {}; const toolDefinitions = await getCachedTools({ includeGlobal: true });
const toolsOutput = []; const toolsOutput = [];
for (const plugin of authenticatedPlugins) { for (const plugin of authenticatedPlugins) {

View File

@@ -1,5 +1,11 @@
const {
Tools,
Constants,
FileSources,
webSearchKeys,
extractWebSearchEnvVars,
} = require('librechat-data-provider');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { webSearchKeys, extractWebSearchEnvVars } = require('@librechat/api');
const { const {
getFiles, getFiles,
updateUser, updateUser,
@@ -14,7 +20,6 @@ const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/service
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService'); const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService'); const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud'); const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { Tools, Constants, FileSources } = require('librechat-data-provider');
const { processDeleteRequest } = require('~/server/services/Files/process'); const { processDeleteRequest } = require('~/server/services/Files/process');
const { Transaction, Balance, User } = require('~/db/models'); const { Transaction, Balance, User } = require('~/db/models');
const { deleteToolCalls } = require('~/models/ToolCall'); const { deleteToolCalls } = require('~/models/ToolCall');

View File

@@ -1,195 +0,0 @@
const { duplicateAgent } = require('../v1');
const { getAgent, createAgent } = require('~/models/Agent');
const { getActions } = require('~/models/Action');
const { nanoid } = require('nanoid');
jest.mock('~/models/Agent');
jest.mock('~/models/Action');
jest.mock('nanoid');
describe('duplicateAgent', () => {
let req, res;
beforeEach(() => {
req = {
params: { id: 'agent_123' },
user: { id: 'user_456' },
};
res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
};
jest.clearAllMocks();
});
it('should duplicate an agent successfully', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
author: 'user_789',
versions: [{ name: 'Test Agent', version: 1 }],
__v: 0,
};
const mockNewAgent = {
id: 'agent_new_123',
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
author: 'user_456',
versions: [
{
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
createdAt: new Date(),
updatedAt: new Date(),
},
],
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue(mockNewAgent);
await duplicateAgent(req, res);
expect(getAgent).toHaveBeenCalledWith({ id: 'agent_123' });
expect(getActions).toHaveBeenCalledWith({ agent_id: 'agent_123' }, true);
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
id: 'agent_new_123',
author: 'user_456',
name: expect.stringContaining('Test Agent ('),
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
}),
);
expect(createAgent).toHaveBeenCalledWith(
expect.not.objectContaining({
versions: expect.anything(),
__v: expect.anything(),
}),
);
expect(res.status).toHaveBeenCalledWith(201);
expect(res.json).toHaveBeenCalledWith({
agent: mockNewAgent,
actions: [],
});
});
it('should ensure duplicated agent has clean versions array without nested fields', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
description: 'Test Description',
versions: [
{
name: 'Test Agent',
versions: [{ name: 'Nested' }],
__v: 1,
},
],
__v: 2,
};
const mockNewAgent = {
id: 'agent_new_123',
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
versions: [
{
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
createdAt: new Date(),
updatedAt: new Date(),
},
],
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue(mockNewAgent);
await duplicateAgent(req, res);
expect(mockNewAgent.versions).toHaveLength(1);
const firstVersion = mockNewAgent.versions[0];
expect(firstVersion).not.toHaveProperty('versions');
expect(firstVersion).not.toHaveProperty('__v');
expect(mockNewAgent).not.toHaveProperty('__v');
expect(res.status).toHaveBeenCalledWith(201);
});
it('should return 404 if agent not found', async () => {
getAgent.mockResolvedValue(null);
await duplicateAgent(req, res);
expect(res.status).toHaveBeenCalledWith(404);
expect(res.json).toHaveBeenCalledWith({
error: 'Agent not found',
status: 'error',
});
});
it('should handle tool_resources.ocr correctly', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
tool_resources: {
ocr: { enabled: true, config: 'test' },
other: { should: 'not be copied' },
},
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue({ id: 'agent_new_123' });
await duplicateAgent(req, res);
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
tool_resources: {
ocr: { enabled: true, config: 'test' },
},
}),
);
});
it('should handle errors gracefully', async () => {
getAgent.mockRejectedValue(new Error('Database error'));
await duplicateAgent(req, res);
expect(res.status).toHaveBeenCalledWith(500);
expect(res.json).toHaveBeenCalledWith({ error: 'Database error' });
});
});

View File

@@ -1,14 +1,10 @@
require('events').EventEmitter.defaultMaxListeners = 100; require('events').EventEmitter.defaultMaxListeners = 100;
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const { const {
sendEvent, sendEvent,
createRun, createRun,
Tokenizer, Tokenizer,
checkAccess,
memoryInstructions, memoryInstructions,
formatContentStrings,
createMemoryProcessor, createMemoryProcessor,
} = require('@librechat/api'); } = require('@librechat/api');
const { const {
@@ -17,6 +13,7 @@ const {
GraphEvents, GraphEvents,
formatMessage, formatMessage,
formatAgentMessages, formatAgentMessages,
formatContentStrings,
getTokenCountForMessage, getTokenCountForMessage,
createMetadataAggregator, createMetadataAggregator,
} = require('@librechat/agents'); } = require('@librechat/agents');
@@ -26,26 +23,24 @@ const {
VisionModes, VisionModes,
ContentTypes, ContentTypes,
EModelEndpoint, EModelEndpoint,
KnownEndpoints,
PermissionTypes, PermissionTypes,
isAgentsEndpoint, isAgentsEndpoint,
AgentCapabilities, AgentCapabilities,
bedrockInputSchema, bedrockInputSchema,
removeNullishValues, removeNullishValues,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { const { DynamicStructuredTool } = require('@langchain/core/tools');
findPluginAuthsByKeys, const { getBufferString, HumanMessage } = require('@langchain/core/messages');
getFormattedMemories, const { createGetMCPAuthMap, checkCapability } = require('~/server/services/Config');
deleteMemory,
setMemory,
} = require('~/models');
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { getProviderConfig } = require('~/server/services/Endpoints'); const { getProviderConfig } = require('~/server/services/Endpoints');
const { checkAccess } = require('~/server/middleware/roles/access');
const BaseClient = require('~/app/clients/BaseClient'); const BaseClient = require('~/app/clients/BaseClient');
const { getRoleByName } = require('~/models/Role');
const { loadAgent } = require('~/models/Agent'); const { loadAgent } = require('~/models/Agent');
const { getMCPManager } = require('~/config'); const { getMCPManager } = require('~/config');
@@ -58,7 +53,6 @@ const omitTitleOptions = new Set([
'thinkingBudget', 'thinkingBudget',
'includeThoughts', 'includeThoughts',
'maxOutputTokens', 'maxOutputTokens',
'additionalModelRequestFields',
]); ]);
/** /**
@@ -75,6 +69,8 @@ const payloadParser = ({ req, agent, endpoint }) => {
return req.body.endpointOption.model_parameters; return req.body.endpointOption.model_parameters;
}; };
const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi]; const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
function createTokenCounter(encoding) { function createTokenCounter(encoding) {
@@ -405,12 +401,7 @@ class AgentClient extends BaseClient {
if (user.personalization?.memories === false) { if (user.personalization?.memories === false) {
return; return;
} }
const hasAccess = await checkAccess({ const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]);
user,
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE],
getRoleByName,
});
if (!hasAccess) { if (!hasAccess) {
logger.debug( logger.debug(
@@ -455,12 +446,6 @@ class AgentClient extends BaseClient {
res: this.options.res, res: this.options.res,
agent: prelimAgent, agent: prelimAgent,
allowedProviders, allowedProviders,
endpointOption: {
endpoint:
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
? EModelEndpoint.agents
: memoryConfig.agent?.provider,
},
}); });
if (!agent) { if (!agent) {
@@ -534,10 +519,7 @@ class AgentClient extends BaseClient {
messagesToProcess = [...messages.slice(-messageWindowSize)]; messagesToProcess = [...messages.slice(-messageWindowSize)];
} }
} }
return await this.processMemory(messagesToProcess);
const bufferString = getBufferString(messagesToProcess);
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
return await this.processMemory([bufferMessage]);
} catch (error) { } catch (error) {
logger.error('Memory Agent failed to process memory', error); logger.error('Memory Agent failed to process memory', error);
} }
@@ -709,12 +691,17 @@ class AgentClient extends BaseClient {
version: 'v2', version: 'v2',
}; };
const getUserMCPAuthMap = await createGetMCPAuthMap();
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name)); const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages( let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
payload, payload,
this.indexTokenCountMap, this.indexTokenCountMap,
toolSet, toolSet,
); );
if (legacyContentEndpoints.has(this.options.agent.endpoint?.toLowerCase())) {
initialMessages = formatContentStrings(initialMessages);
}
/** /**
* *
@@ -778,9 +765,6 @@ class AgentClient extends BaseClient {
} }
let messages = _messages; let messages = _messages;
if (agent.useLegacyContent === true) {
messages = formatContentStrings(messages);
}
if ( if (
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes( agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
'prompt-caching', 'prompt-caching',
@@ -829,11 +813,10 @@ class AgentClient extends BaseClient {
} }
try { try {
if (await hasCustomUserVars()) { if (getUserMCPAuthMap) {
config.configurable.userMCPAuthMap = await getMCPAuthMap({ config.configurable.userMCPAuthMap = await getUserMCPAuthMap({
tools: agent.tools, tools: agent.tools,
userId: this.options.req.user.id, userId: this.options.req.user.id,
findPluginAuthsByKeys,
}); });
} }
} catch (err) { } catch (err) {
@@ -1051,12 +1034,6 @@ class AgentClient extends BaseClient {
options.llmConfig?.azureOpenAIApiInstanceName == null options.llmConfig?.azureOpenAIApiInstanceName == null
) { ) {
provider = Providers.OPENAI; provider = Providers.OPENAI;
} else if (
endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName != null &&
provider !== Providers.AZURE
) {
provider = Providers.AZURE;
} }
/** @type {import('@librechat/agents').ClientOptions} */ /** @type {import('@librechat/agents').ClientOptions} */
@@ -1135,52 +1112,8 @@ class AgentClient extends BaseClient {
} }
} }
/** /** Silent method, as `recordCollectedUsage` is used instead */
* @param {object} params async recordTokenUsage() {}
* @param {number} params.promptTokens
* @param {number} params.completionTokens
* @param {OpenAIUsageMetadata} [params.usage]
* @param {string} [params.model]
* @param {string} [params.context='message']
* @returns {Promise<void>}
*/
async recordTokenUsage({ model, promptTokens, completionTokens, usage, context = 'message' }) {
try {
await spendTokens(
{
model,
context,
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ promptTokens, completionTokens },
);
if (
usage &&
typeof usage === 'object' &&
'reasoning_tokens' in usage &&
typeof usage.reasoning_tokens === 'number'
) {
await spendTokens(
{
model,
context: 'reasoning',
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ completionTokens: usage.reasoning_tokens },
);
}
} catch (error) {
logger.error(
'[api/server/controllers/agents/client.js #recordTokenUsage] Error recording token usage',
error,
);
}
}
getEncoding() { getEncoding() {
return 'o200k_base'; return 'o200k_base';

View File

@@ -12,14 +12,10 @@ const { saveMessage } = require('~/models');
const AgentController = async (req, res, next, initializeClient, addTitle) => { const AgentController = async (req, res, next, initializeClient, addTitle) => {
let { let {
text, text,
isRegenerate,
endpointOption, endpointOption,
conversationId, conversationId,
isContinued = false,
editedContent = null,
parentMessageId = null, parentMessageId = null,
overrideParentMessageId = null, overrideParentMessageId = null,
responseMessageId: editedResponseMessageId = null,
} = req.body; } = req.body;
let sender; let sender;
@@ -71,7 +67,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
handler(); handler();
} }
} catch (e) { } catch (e) {
logger.error('[AgentController] Error in cleanup handler', e); // Ignore cleanup errors
} }
} }
} }
@@ -159,7 +155,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
try { try {
res.removeListener('close', closeHandler); res.removeListener('close', closeHandler);
} catch (e) { } catch (e) {
logger.error('[AgentController] Error removing close listener', e); // Ignore
} }
}); });
@@ -167,15 +163,10 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
user: userId, user: userId,
onStart, onStart,
getReqData, getReqData,
isContinued,
isRegenerate,
editedContent,
conversationId, conversationId,
parentMessageId, parentMessageId,
abortController, abortController,
overrideParentMessageId, overrideParentMessageId,
isEdited: !!editedContent,
responseMessageId: editedResponseMessageId,
progressOptions: { progressOptions: {
res, res,
}, },

View File

@@ -1,8 +1,6 @@
const { z } = require('zod');
const fs = require('fs').promises; const fs = require('fs').promises;
const { nanoid } = require('nanoid'); const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
const { const {
Tools, Tools,
Constants, Constants,
@@ -10,7 +8,6 @@ const {
SystemRoles, SystemRoles,
EToolResources, EToolResources,
actionDelimiter, actionDelimiter,
removeNullishValues,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { const {
getAgent, getAgent,
@@ -33,7 +30,6 @@ const { deleteFileByFilter } = require('~/models/File');
const systemTools = { const systemTools = {
[Tools.execute_code]: true, [Tools.execute_code]: true,
[Tools.file_search]: true, [Tools.file_search]: true,
[Tools.web_search]: true,
}; };
/** /**
@@ -46,13 +42,9 @@ const systemTools = {
*/ */
const createAgentHandler = async (req, res) => { const createAgentHandler = async (req, res) => {
try { try {
const validatedData = agentCreateSchema.parse(req.body); const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body;
const { tools = [], ...agentData } = removeNullishValues(validatedData);
const { id: userId } = req.user; const { id: userId } = req.user;
agentData.id = `agent_${nanoid()}`;
agentData.author = userId;
agentData.tools = []; agentData.tools = [];
const availableTools = await getCachedTools({ includeGlobal: true }); const availableTools = await getCachedTools({ includeGlobal: true });
@@ -66,13 +58,19 @@ const createAgentHandler = async (req, res) => {
} }
} }
Object.assign(agentData, {
author: userId,
name,
description,
instructions,
provider,
model,
});
agentData.id = `agent_${nanoid()}`;
const agent = await createAgent(agentData); const agent = await createAgent(agentData);
res.status(201).json(agent); res.status(201).json(agent);
} catch (error) { } catch (error) {
if (error instanceof z.ZodError) {
logger.error('[/Agents] Validation error', error.errors);
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
}
logger.error('[/Agents] Error creating agent', error); logger.error('[/Agents] Error creating agent', error);
res.status(500).json({ error: error.message }); res.status(500).json({ error: error.message });
} }
@@ -156,16 +154,14 @@ const getAgentHandler = async (req, res) => {
const updateAgentHandler = async (req, res) => { const updateAgentHandler = async (req, res) => {
try { try {
const id = req.params.id; const id = req.params.id;
const validatedData = agentUpdateSchema.parse(req.body); const { projectIds, removeProjectIds, ...updateData } = req.body;
const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
const isAdmin = req.user.role === SystemRoles.ADMIN; const isAdmin = req.user.role === SystemRoles.ADMIN;
const existingAgent = await getAgent({ id }); const existingAgent = await getAgent({ id });
const isAuthor = existingAgent.author.toString() === req.user.id;
if (!existingAgent) { if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' }); return res.status(404).json({ error: 'Agent not found' });
} }
const isAuthor = existingAgent.author.toString() === req.user.id;
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor; const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) { if (!hasEditPermission) {
@@ -204,11 +200,6 @@ const updateAgentHandler = async (req, res) => {
return res.json(updatedAgent); return res.json(updatedAgent);
} catch (error) { } catch (error) {
if (error instanceof z.ZodError) {
logger.error('[/Agents/:id] Validation error', error.errors);
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
}
logger.error('[/Agents/:id] Error updating Agent', error); logger.error('[/Agents/:id] Error updating Agent', error);
if (error.statusCode === 409) { if (error.statusCode === 409) {
@@ -251,8 +242,6 @@ const duplicateAgentHandler = async (req, res) => {
createdAt: _createdAt, createdAt: _createdAt,
updatedAt: _updatedAt, updatedAt: _updatedAt,
tool_resources: _tool_resources = {}, tool_resources: _tool_resources = {},
versions: _versions,
__v: _v,
...cloneData ...cloneData
} = agent; } = agent;
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', { cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
@@ -391,22 +380,6 @@ const uploadAgentAvatarHandler = async (req, res) => {
return res.status(400).json({ message: 'Agent ID is required' }); return res.status(400).json({ message: 'Agent ID is required' });
} }
const isAdmin = req.user.role === SystemRoles.ADMIN;
const existingAgent = await getAgent({ id: agent_id });
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
const isAuthor = existingAgent.author.toString() === req.user.id;
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) {
return res.status(403).json({
error: 'You do not have permission to modify this non-collaborative agent',
});
}
const buffer = await fs.readFile(req.file.path); const buffer = await fs.readFile(req.file.path);
const fileStrategy = req.app.locals.fileStrategy; const fileStrategy = req.app.locals.fileStrategy;
@@ -429,7 +402,14 @@ const uploadAgentAvatarHandler = async (req, res) => {
source: fileStrategy, source: fileStrategy,
}; };
let _avatar = existingAgent.avatar; let _avatar;
try {
const agent = await getAgent({ id: agent_id });
_avatar = agent.avatar;
} catch (error) {
logger.error('[/:agent_id/avatar] Error fetching agent', error);
_avatar = {};
}
if (_avatar && _avatar.source) { if (_avatar && _avatar.source) {
const { deleteFile } = getStrategyFunctions(_avatar.source); const { deleteFile } = getStrategyFunctions(_avatar.source);
@@ -451,7 +431,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
}; };
promises.push( promises.push(
await updateAgent({ id: agent_id }, data, { await updateAgent({ id: agent_id, author: req.user.id }, data, {
updatingUserId: req.user.id, updatingUserId: req.user.id,
}), }),
); );

View File

@@ -1,659 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { agentSchema } = require('@librechat/data-schemas');
// Only mock the dependencies that are not database-related
jest.mock('~/server/services/Config', () => ({
getCachedTools: jest.fn().mockResolvedValue({
web_search: true,
execute_code: true,
file_search: true,
}),
}));
jest.mock('~/models/Project', () => ({
getProjectByName: jest.fn().mockResolvedValue(null),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(),
}));
jest.mock('~/server/services/Files/images/avatar', () => ({
resizeAvatar: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
refreshS3Url: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
filterFile: jest.fn(),
}));
jest.mock('~/models/Action', () => ({
updateAction: jest.fn(),
getActions: jest.fn().mockResolvedValue([]),
}));
jest.mock('~/models/File', () => ({
deleteFileByFilter: jest.fn(),
}));
const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');
/**
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
*/
let Agent;
describe('Agent Controllers - Mass Assignment Protection', () => {
let mongoServer;
let mockReq;
let mockRes;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
}, 20000);
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await Agent.deleteMany({});
// Reset all mocks
jest.clearAllMocks();
// Setup mock request and response objects
mockReq = {
user: {
id: new mongoose.Types.ObjectId().toString(),
role: 'USER',
},
body: {},
params: {},
app: {
locals: {
fileStrategy: 'local',
},
},
};
mockRes = {
status: jest.fn().mockReturnThis(),
json: jest.fn().mockReturnThis(),
};
});
describe('createAgentHandler', () => {
test('should create agent with allowed fields only', async () => {
const validData = {
name: 'Test Agent',
description: 'A test agent',
instructions: 'Be helpful',
provider: 'openai',
model: 'gpt-4',
tools: ['web_search'],
model_parameters: { temperature: 0.7 },
tool_resources: {
file_search: { file_ids: ['file1', 'file2'] },
},
};
mockReq.body = validData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
expect(mockRes.json).toHaveBeenCalled();
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.name).toBe('Test Agent');
expect(createdAgent.description).toBe('A test agent');
expect(createdAgent.provider).toBe('openai');
expect(createdAgent.model).toBe('gpt-4');
expect(createdAgent.author.toString()).toBe(mockReq.user.id);
expect(createdAgent.tools).toContain('web_search');
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb).toBeDefined();
expect(agentInDb.name).toBe('Test Agent');
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
});
test('should reject creation with unauthorized fields (mass assignment protection)', async () => {
const maliciousData = {
// Required fields
provider: 'openai',
model: 'gpt-4',
name: 'Malicious Agent',
// Unauthorized fields that should be stripped
author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
authorName: 'Hacker', // Should be stripped
isCollaborative: true, // Should be stripped on creation
versions: [], // Should be stripped
_id: new mongoose.Types.ObjectId(), // Should be stripped
id: 'custom_agent_id', // Should be overridden
createdAt: new Date('2020-01-01'), // Should be stripped
updatedAt: new Date('2020-01-01'), // Should be stripped
};
mockReq.body = maliciousData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify unauthorized fields were not set
expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
expect(createdAgent.authorName).toBeUndefined();
expect(createdAgent.isCollaborative).toBeFalsy();
expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix
// Verify timestamps are recent (not the malicious dates)
const createdTime = new Date(createdAgent.createdAt).getTime();
const now = Date.now();
expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
expect(agentInDb.authorName).toBeUndefined();
});
test('should validate required fields', async () => {
const invalidData = {
name: 'Missing Required Fields',
// Missing provider and model
};
mockReq.body = invalidData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
details: expect.any(Array),
}),
);
// Verify nothing was created in database
const count = await Agent.countDocuments();
expect(count).toBe(0);
});
test('should handle tool_resources validation', async () => {
const dataWithInvalidToolResources = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Tool Resources',
tool_resources: {
// Valid resources
file_search: {
file_ids: ['file1', 'file2'],
vector_store_ids: ['vs1'],
},
execute_code: {
file_ids: ['file3'],
},
// Invalid resource (should be stripped by schema)
invalid_resource: {
file_ids: ['file4'],
},
},
};
mockReq.body = dataWithInvalidToolResources;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.tool_resources).toBeDefined();
expect(createdAgent.tool_resources.file_search).toBeDefined();
expect(createdAgent.tool_resources.execute_code).toBeDefined();
expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
});
test('should handle avatar validation', async () => {
const dataWithAvatar = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Avatar',
avatar: {
filepath: 'https://example.com/avatar.png',
source: 's3',
},
};
mockReq.body = dataWithAvatar;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.avatar).toEqual({
filepath: 'https://example.com/avatar.png',
source: 's3',
});
});
test('should handle invalid avatar format', async () => {
const dataWithInvalidAvatar = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Invalid Avatar',
avatar: 'just-a-string', // Invalid format
};
mockReq.body = dataWithInvalidAvatar;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
}),
);
});
});
describe('updateAgentHandler', () => {
let existingAgentId;
let existingAgentAuthorId;
beforeEach(async () => {
// Create an existing agent for update tests
existingAgentAuthorId = new mongoose.Types.ObjectId();
const agent = await Agent.create({
id: `agent_${uuidv4()}`,
name: 'Original Agent',
provider: 'openai',
model: 'gpt-3.5-turbo',
author: existingAgentAuthorId,
description: 'Original description',
isCollaborative: false,
versions: [
{
name: 'Original Agent',
provider: 'openai',
model: 'gpt-3.5-turbo',
description: 'Original description',
createdAt: new Date(),
updatedAt: new Date(),
},
],
});
existingAgentId = agent.id;
});
test('should update agent with allowed fields only', async () => {
mockReq.user.id = existingAgentAuthorId.toString(); // Set as author
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Updated Agent',
description: 'Updated description',
model: 'gpt-4',
isCollaborative: true, // This IS allowed in updates
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(400);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Updated Agent');
expect(updatedAgent.description).toBe('Updated description');
expect(updatedAgent.model).toBe('gpt-4');
expect(updatedAgent.isCollaborative).toBe(true);
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Updated Agent');
expect(agentInDb.isCollaborative).toBe(true);
});
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Updated Name',
// Unauthorized fields that should be stripped
author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author
authorName: 'Hacker', // Should be stripped
id: 'different_agent_id', // Should be stripped
_id: new mongoose.Types.ObjectId(), // Should be stripped
versions: [], // Should be stripped
createdAt: new Date('2020-01-01'), // Should be stripped
updatedAt: new Date('2020-01-01'), // Should be stripped
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
// Verify unauthorized fields were not changed
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed
expect(updatedAgent.authorName).toBeUndefined();
expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed
expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString());
expect(agentInDb.id).toBe(existingAgentId);
});
test('should reject update from non-author when not collaborative', async () => {
const differentUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = differentUserId; // Different user
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Unauthorized Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalledWith({
error: 'You do not have permission to modify this non-collaborative agent',
});
// Verify agent was not modified in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Original Agent');
});
test('should allow update from non-author when collaborative', async () => {
// First make the agent collaborative
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
const differentUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = differentUserId; // Different user
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Collaborative Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Collaborative Update');
// Author field should be removed for non-author
expect(updatedAgent.author).toBeUndefined();
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Collaborative Update');
});
test('should allow admin to update any agent', async () => {
const adminUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = adminUserId;
mockReq.user.role = 'ADMIN'; // Set as admin
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Admin Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Admin Update');
});
test('should handle projectIds updates', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
const projectId1 = new mongoose.Types.ObjectId().toString();
const projectId2 = new mongoose.Types.ObjectId().toString();
mockReq.body = {
projectIds: [projectId1, projectId2],
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent).toBeDefined();
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
});
test('should validate tool_resources in updates', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
tool_resources: {
ocr: {
file_ids: ['ocr1', 'ocr2'],
},
execute_code: {
file_ids: ['img1'],
},
// Invalid tool resource
invalid_tool: {
file_ids: ['invalid'],
},
},
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.tool_resources).toBeDefined();
expect(updatedAgent.tool_resources.ocr).toBeDefined();
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
});
test('should return 404 for non-existent agent', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
mockReq.body = {
name: 'Update Non-existent',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(404);
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
});
test('should handle validation errors properly', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
model_parameters: 'invalid-not-an-object', // Should be an object
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
details: expect.any(Array),
}),
);
});
});
describe('Mass Assignment Attack Scenarios', () => {
test('should prevent setting system fields during creation', async () => {
const systemFields = {
provider: 'openai',
model: 'gpt-4',
name: 'System Fields Test',
// System fields that should never be settable by users
__v: 99,
_id: new mongoose.Types.ObjectId(),
versions: [
{
name: 'Fake Version',
provider: 'fake',
model: 'fake-model',
},
],
};
mockReq.body = systemFields;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify system fields were not affected
expect(createdAgent.__v).not.toBe(99);
expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version
expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation
expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.__v).not.toBe(99);
});
test('should prevent privilege escalation through isCollaborative', async () => {
// Create a non-collaborative agent
const authorId = new mongoose.Types.ObjectId();
const agent = await Agent.create({
id: `agent_${uuidv4()}`,
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
isCollaborative: false,
versions: [
{
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
createdAt: new Date(),
updatedAt: new Date(),
},
],
});
// Try to make it collaborative as a different user
const attackerId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = attackerId;
mockReq.params.id = agent.id;
mockReq.body = {
isCollaborative: true, // Trying to escalate privileges
};
await updateAgentHandler(mockReq, mockRes);
// Should be rejected
expect(mockRes.status).toHaveBeenCalledWith(403);
// Verify in database that it's still not collaborative
const agentInDb = await Agent.findOne({ id: agent.id });
expect(agentInDb.isCollaborative).toBe(false);
});
test('should prevent author hijacking', async () => {
const originalAuthorId = new mongoose.Types.ObjectId();
const attackerId = new mongoose.Types.ObjectId();
// Admin creates an agent
mockReq.user.id = originalAuthorId.toString();
mockReq.user.role = 'ADMIN';
mockReq.body = {
provider: 'openai',
model: 'gpt-4',
name: 'Admin Agent',
author: attackerId.toString(), // Trying to set different author
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Author should be the actual user, not the attempted value
expect(createdAgent.author.toString()).toBe(originalAuthorId.toString());
expect(createdAgent.author.toString()).not.toBe(attackerId.toString());
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.author.toString()).toBe(originalAuthorId.toString());
});
test('should strip unknown fields to prevent future vulnerabilities', async () => {
mockReq.body = {
provider: 'openai',
model: 'gpt-4',
name: 'Future Proof Test',
// Unknown fields that might be added in future
superAdminAccess: true,
bypassAllChecks: true,
internalFlag: 'secret',
futureFeature: 'exploit',
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify unknown fields were stripped
expect(createdAgent.superAdminAccess).toBeUndefined();
expect(createdAgent.bypassAllChecks).toBeUndefined();
expect(createdAgent.internalFlag).toBeUndefined();
expect(createdAgent.futureFeature).toBeUndefined();
// Also check in database
const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean();
expect(agentInDb.superAdminAccess).toBeUndefined();
expect(agentInDb.bypassAllChecks).toBeUndefined();
expect(agentInDb.internalFlag).toBeUndefined();
expect(agentInDb.futureFeature).toBeUndefined();
});
});
});

View File

@@ -1,21 +1,21 @@
const { nanoid } = require('nanoid'); const { nanoid } = require('nanoid');
const { EnvVar } = require('@librechat/agents'); const { EnvVar } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { checkAccess, loadWebSearchAuth } = require('@librechat/api');
const { const {
Tools, Tools,
AuthType, AuthType,
Permissions, Permissions,
ToolCallTypes, ToolCallTypes,
PermissionTypes, PermissionTypes,
loadWebSearchAuth,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { processCodeOutput } = require('~/server/services/Files/Code/process');
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall'); const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { loadTools } = require('~/app/clients/tools/util'); const { loadTools } = require('~/app/clients/tools/util');
const { getRoleByName } = require('~/models/Role'); const { checkAccess } = require('~/server/middleware');
const { getMessage } = require('~/models/Message'); const { getMessage } = require('~/models/Message');
const { logger } = require('~/config');
const fieldsMap = { const fieldsMap = {
[Tools.execute_code]: [EnvVar.CODE_API_KEY], [Tools.execute_code]: [EnvVar.CODE_API_KEY],
@@ -79,7 +79,6 @@ const verifyToolAuth = async (req, res) => {
throwError: false, throwError: false,
}); });
} catch (error) { } catch (error) {
logger.error('Error loading auth values', error);
res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED }); res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED });
return; return;
} }
@@ -133,12 +132,7 @@ const callTool = async (req, res) => {
logger.debug(`[${toolId}/call] User: ${req.user.id}`); logger.debug(`[${toolId}/call] User: ${req.user.id}`);
let hasAccess = true; let hasAccess = true;
if (toolAccessPermType[toolId]) { if (toolAccessPermType[toolId]) {
hasAccess = await checkAccess({ hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]);
user: req.user,
permissionType: toolAccessPermType[toolId],
permissions: [Permissions.USE],
getRoleByName,
});
} }
if (!hasAccess) { if (!hasAccess) {
logger.warn( logger.warn(

View File

@@ -55,6 +55,7 @@ const startServer = async () => {
/* Middleware */ /* Middleware */
app.use(noIndex); app.use(noIndex);
app.use(errorController);
app.use(express.json({ limit: '3mb' })); app.use(express.json({ limit: '3mb' }));
app.use(express.urlencoded({ extended: true, limit: '3mb' })); app.use(express.urlencoded({ extended: true, limit: '3mb' }));
app.use(mongoSanitize()); app.use(mongoSanitize());
@@ -120,9 +121,6 @@ const startServer = async () => {
app.use('/api/tags', routes.tags); app.use('/api/tags', routes.tags);
app.use('/api/mcp', routes.mcp); app.use('/api/mcp', routes.mcp);
// Add the error controller one more time after all routes
app.use(errorController);
app.use((req, res) => { app.use((req, res) => {
res.set({ res.set({
'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate', 'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',

View File

@@ -1,4 +1,5 @@
const fs = require('fs'); const fs = require('fs');
const path = require('path');
const request = require('supertest'); const request = require('supertest');
const { MongoMemoryServer } = require('mongodb-memory-server'); const { MongoMemoryServer } = require('mongodb-memory-server');
const mongoose = require('mongoose'); const mongoose = require('mongoose');
@@ -58,30 +59,6 @@ describe('Server Configuration', () => {
expect(response.headers['pragma']).toBe('no-cache'); expect(response.headers['pragma']).toBe('no-cache');
expect(response.headers['expires']).toBe('0'); expect(response.headers['expires']).toBe('0');
}); });
it('should return 500 for unknown errors via ErrorController', async () => {
// Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
// Mock MongoDB operations to fail
const originalFindOne = mongoose.models.User.findOne;
const mockError = new Error('MongoDB operation failed');
mongoose.models.User.findOne = jest.fn().mockImplementation(() => {
throw mockError;
});
try {
const response = await request(app).post('/api/auth/login').send({
email: 'test@example.com',
password: 'password123',
});
expect(response.status).toBe(500);
expect(response.text).toBe('An unknown error occurred.');
} finally {
// Restore original function
mongoose.models.User.findOne = originalFindOne;
}
});
}); });
// Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely // Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely

View File

@@ -1,4 +1,3 @@
const { handleError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { const {
EndpointURLs, EndpointURLs,
@@ -15,6 +14,7 @@ const openAI = require('~/server/services/Endpoints/openAI');
const agents = require('~/server/services/Endpoints/agents'); const agents = require('~/server/services/Endpoints/agents');
const custom = require('~/server/services/Endpoints/custom'); const custom = require('~/server/services/Endpoints/custom');
const google = require('~/server/services/Endpoints/google'); const google = require('~/server/services/Endpoints/google');
const { handleError } = require('~/server/utils');
const buildFunction = { const buildFunction = {
[EModelEndpoint.openAI]: openAI.buildOptions, [EModelEndpoint.openAI]: openAI.buildOptions,

View File

@@ -18,6 +18,7 @@ const message = 'Your account has been temporarily banned due to violations of o
* @function * @function
* @param {Object} req - Express Request object. * @param {Object} req - Express Request object.
* @param {Object} res - Express Response object. * @param {Object} res - Express Response object.
* @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request.
* *
* @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function. * @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
*/ */
@@ -134,7 +135,6 @@ const checkBan = async (req, res, next = () => {}) => {
return await banResponse(req, res); return await banResponse(req, res);
} catch (error) { } catch (error) {
logger.error('Error in checkBan middleware:', error); logger.error('Error in checkBan middleware:', error);
return next(error);
} }
}; };

View File

@@ -1,4 +1,4 @@
const { Time, CacheKeys, ViolationTypes } = require('librechat-data-provider'); const { Time, CacheKeys } = require('librechat-data-provider');
const clearPendingReq = require('~/cache/clearPendingReq'); const clearPendingReq = require('~/cache/clearPendingReq');
const { logViolation, getLogStores } = require('~/cache'); const { logViolation, getLogStores } = require('~/cache');
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
@@ -37,7 +37,7 @@ const concurrentLimiter = async (req, res, next) => {
const userId = req.user?.id ?? req.user?._id ?? ''; const userId = req.user?.id ?? req.user?._id ?? '';
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1); const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
const type = ViolationTypes.CONCURRENT; const type = 'concurrent';
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`; const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
const pendingRequests = +((await cache.get(key)) ?? 0); const pendingRequests = +((await cache.get(key)) ?? 0);

View File

@@ -1,79 +0,0 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const logViolation = require('~/cache/logViolation');
const getEnvironmentVariables = () => {
const FORK_IP_MAX = parseInt(process.env.FORK_IP_MAX) || 30;
const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1;
const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7;
const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1;
const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE;
const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000;
const forkIpMax = FORK_IP_MAX;
const forkIpWindowInMinutes = forkIpWindowMs / 60000;
const forkUserWindowMs = FORK_USER_WINDOW * 60 * 1000;
const forkUserMax = FORK_USER_MAX;
const forkUserWindowInMinutes = forkUserWindowMs / 60000;
return {
forkIpWindowMs,
forkIpMax,
forkIpWindowInMinutes,
forkUserWindowMs,
forkUserMax,
forkUserWindowInMinutes,
forkViolationScore: FORK_VIOLATION_SCORE,
};
};
const createForkHandler = (ip = true) => {
const {
forkIpMax,
forkUserMax,
forkViolationScore,
forkIpWindowInMinutes,
forkUserWindowInMinutes,
} = getEnvironmentVariables();
return async (req, res) => {
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
const errorMessage = {
type,
max: ip ? forkIpMax : forkUserMax,
limiter: ip ? 'ip' : 'user',
windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, forkViolationScore);
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
};
};
const createForkLimiters = () => {
const { forkIpWindowMs, forkIpMax, forkUserWindowMs, forkUserMax } = getEnvironmentVariables();
const ipLimiterOptions = {
windowMs: forkIpWindowMs,
max: forkIpMax,
handler: createForkHandler(),
store: limiterCache('fork_ip_limiter'),
};
const userLimiterOptions = {
windowMs: forkUserWindowMs,
max: forkUserMax,
handler: createForkHandler(false),
keyGenerator: function (req) {
return req.user?.id;
},
store: limiterCache('fork_user_limiter'),
};
const forkIpLimiter = rateLimit(ipLimiterOptions);
const forkUserLimiter = rateLimit(userLimiterOptions);
return { forkIpLimiter, forkUserLimiter };
};
module.exports = { createForkLimiters };

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory'); const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => { const getEnvironmentVariables = () => {
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100; const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15; const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50; const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15; const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE;
const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000; const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
const importIpMax = IMPORT_IP_MAX; const importIpMax = IMPORT_IP_MAX;
@@ -25,18 +27,12 @@ const getEnvironmentVariables = () => {
importUserWindowMs, importUserWindowMs,
importUserMax, importUserMax,
importUserWindowInMinutes, importUserWindowInMinutes,
importViolationScore: IMPORT_VIOLATION_SCORE,
}; };
}; };
const createImportHandler = (ip = true) => { const createImportHandler = (ip = true) => {
const { const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } =
importIpMax, getEnvironmentVariables();
importUserMax,
importViolationScore,
importIpWindowInMinutes,
importUserWindowInMinutes,
} = getEnvironmentVariables();
return async (req, res) => { return async (req, res) => {
const type = ViolationTypes.FILE_UPLOAD_LIMIT; const type = ViolationTypes.FILE_UPLOAD_LIMIT;
@@ -47,7 +43,7 @@ const createImportHandler = (ip = true) => {
windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes, windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
}; };
await logViolation(req, res, type, errorMessage, importViolationScore); await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many conversation import requests. Try again later' }); res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
}; };
}; };
@@ -60,7 +56,6 @@ const createImportLimiters = () => {
windowMs: importIpWindowMs, windowMs: importIpWindowMs,
max: importIpMax, max: importIpMax,
handler: createImportHandler(), handler: createImportHandler(),
store: limiterCache('import_ip_limiter'),
}; };
const userLimiterOptions = { const userLimiterOptions = {
windowMs: importUserWindowMs, windowMs: importUserWindowMs,
@@ -69,9 +64,23 @@ const createImportLimiters = () => {
keyGenerator: function (req) { keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available return req.user?.id; // Use the user ID or NULL if not available
}, },
store: limiterCache('import_user_limiter'),
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for import rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'import_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'import_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const importIpLimiter = rateLimit(ipLimiterOptions); const importIpLimiter = rateLimit(ipLimiterOptions);
const importUserLimiter = rateLimit(userLimiterOptions); const importUserLimiter = rateLimit(userLimiterOptions);
return { importIpLimiter, importUserLimiter }; return { importIpLimiter, importUserLimiter };

View File

@@ -4,7 +4,6 @@ const createSTTLimiters = require('./sttLimiters');
const loginLimiter = require('./loginLimiter'); const loginLimiter = require('./loginLimiter');
const importLimiters = require('./importLimiters'); const importLimiters = require('./importLimiters');
const uploadLimiters = require('./uploadLimiters'); const uploadLimiters = require('./uploadLimiters');
const forkLimiters = require('./forkLimiters');
const registerLimiter = require('./registerLimiter'); const registerLimiter = require('./registerLimiter');
const toolCallLimiter = require('./toolCallLimiter'); const toolCallLimiter = require('./toolCallLimiter');
const messageLimiters = require('./messageLimiters'); const messageLimiters = require('./messageLimiters');
@@ -15,7 +14,6 @@ module.exports = {
...uploadLimiters, ...uploadLimiters,
...importLimiters, ...importLimiters,
...messageLimiters, ...messageLimiters,
...forkLimiters,
loginLimiter, loginLimiter,
registerLimiter, registerLimiter,
toolCallLimiter, toolCallLimiter,

View File

@@ -1,8 +1,9 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider'); const { RedisStore } = require('rate-limit-redis');
const { removePorts } = require('~/server/utils'); const { removePorts, isEnabled } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory'); const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env; const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
const windowMs = LOGIN_WINDOW * 60 * 1000; const windowMs = LOGIN_WINDOW * 60 * 1000;
@@ -11,7 +12,7 @@ const windowInMinutes = windowMs / 60000;
const message = `Too many login attempts, please try again after ${windowInMinutes} minutes.`; const message = `Too many login attempts, please try again after ${windowInMinutes} minutes.`;
const handler = async (req, res) => { const handler = async (req, res) => {
const type = ViolationTypes.LOGINS; const type = 'logins';
const errorMessage = { const errorMessage = {
type, type,
max, max,
@@ -27,9 +28,17 @@ const limiterOptions = {
max, max,
handler, handler,
keyGenerator: removePorts, keyGenerator: removePorts,
store: limiterCache('login_limiter'),
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for login rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'login_limiter:',
});
limiterOptions.store = store;
}
const loginLimiter = rateLimit(limiterOptions); const loginLimiter = rateLimit(limiterOptions);
module.exports = loginLimiter; module.exports = loginLimiter;

View File

@@ -1,15 +1,16 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider'); const { RedisStore } = require('rate-limit-redis');
const denyRequest = require('~/server/middleware/denyRequest'); const denyRequest = require('~/server/middleware/denyRequest');
const { limiterCache } = require('~/cache/cacheFactory'); const ioredisClient = require('~/cache/ioredisClient');
const { isEnabled } = require('~/server/utils');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { const {
MESSAGE_IP_MAX = 40, MESSAGE_IP_MAX = 40,
MESSAGE_IP_WINDOW = 1, MESSAGE_IP_WINDOW = 1,
MESSAGE_USER_MAX = 40, MESSAGE_USER_MAX = 40,
MESSAGE_USER_WINDOW = 1, MESSAGE_USER_WINDOW = 1,
MESSAGE_VIOLATION_SCORE: score,
} = process.env; } = process.env;
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000; const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
@@ -30,7 +31,7 @@ const userWindowInMinutes = userWindowMs / 60000;
*/ */
const createHandler = (ip = true) => { const createHandler = (ip = true) => {
return async (req, res) => { return async (req, res) => {
const type = ViolationTypes.MESSAGE_LIMIT; const type = 'message_limit';
const errorMessage = { const errorMessage = {
type, type,
max: ip ? ipMax : userMax, max: ip ? ipMax : userMax,
@@ -38,7 +39,7 @@ const createHandler = (ip = true) => {
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes, windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
}; };
await logViolation(req, res, type, errorMessage, score); await logViolation(req, res, type, errorMessage);
return await denyRequest(req, res, errorMessage); return await denyRequest(req, res, errorMessage);
}; };
}; };
@@ -50,7 +51,6 @@ const ipLimiterOptions = {
windowMs: ipWindowMs, windowMs: ipWindowMs,
max: ipMax, max: ipMax,
handler: createHandler(), handler: createHandler(),
store: limiterCache('message_ip_limiter'),
}; };
const userLimiterOptions = { const userLimiterOptions = {
@@ -60,9 +60,23 @@ const userLimiterOptions = {
keyGenerator: function (req) { keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available return req.user?.id; // Use the user ID or NULL if not available
}, },
store: limiterCache('message_user_limiter'),
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for message rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'message_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'message_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
/** /**
* Message request rate limiter by IP * Message request rate limiter by IP
*/ */

View File

@@ -1,8 +1,9 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider'); const { RedisStore } = require('rate-limit-redis');
const { removePorts } = require('~/server/utils'); const { removePorts, isEnabled } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory'); const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env; const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
const windowMs = REGISTER_WINDOW * 60 * 1000; const windowMs = REGISTER_WINDOW * 60 * 1000;
@@ -11,7 +12,7 @@ const windowInMinutes = windowMs / 60000;
const message = `Too many accounts created, please try again after ${windowInMinutes} minutes`; const message = `Too many accounts created, please try again after ${windowInMinutes} minutes`;
const handler = async (req, res) => { const handler = async (req, res) => {
const type = ViolationTypes.REGISTRATIONS; const type = 'registrations';
const errorMessage = { const errorMessage = {
type, type,
max, max,
@@ -27,9 +28,17 @@ const limiterOptions = {
max, max,
handler, handler,
keyGenerator: removePorts, keyGenerator: removePorts,
store: limiterCache('register_limiter'),
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for register rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'register_limiter:',
});
limiterOptions.store = store;
}
const registerLimiter = rateLimit(limiterOptions); const registerLimiter = rateLimit(limiterOptions);
module.exports = registerLimiter; module.exports = registerLimiter;

View File

@@ -1,8 +1,10 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils'); const { removePorts, isEnabled } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory'); const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { const {
RESET_PASSWORD_WINDOW = 2, RESET_PASSWORD_WINDOW = 2,
@@ -31,9 +33,17 @@ const limiterOptions = {
max, max,
handler, handler,
keyGenerator: removePorts, keyGenerator: removePorts,
store: limiterCache('reset_password_limiter'),
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for reset password rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'reset_password_limiter:',
});
limiterOptions.store = store;
}
const resetPasswordLimiter = rateLimit(limiterOptions); const resetPasswordLimiter = rateLimit(limiterOptions);
module.exports = resetPasswordLimiter; module.exports = resetPasswordLimiter;

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory'); const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => { const getEnvironmentVariables = () => {
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100; const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1; const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50; const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1; const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE;
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000; const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
const sttIpMax = STT_IP_MAX; const sttIpMax = STT_IP_MAX;
@@ -25,12 +27,11 @@ const getEnvironmentVariables = () => {
sttUserWindowMs, sttUserWindowMs,
sttUserMax, sttUserMax,
sttUserWindowInMinutes, sttUserWindowInMinutes,
sttViolationScore: STT_VIOLATION_SCORE,
}; };
}; };
const createSTTHandler = (ip = true) => { const createSTTHandler = (ip = true) => {
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } = const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
getEnvironmentVariables(); getEnvironmentVariables();
return async (req, res) => { return async (req, res) => {
@@ -42,7 +43,7 @@ const createSTTHandler = (ip = true) => {
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes, windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
}; };
await logViolation(req, res, type, errorMessage, sttViolationScore); await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many STT requests. Try again later' }); res.status(429).json({ message: 'Too many STT requests. Try again later' });
}; };
}; };
@@ -54,7 +55,6 @@ const createSTTLimiters = () => {
windowMs: sttIpWindowMs, windowMs: sttIpWindowMs,
max: sttIpMax, max: sttIpMax,
handler: createSTTHandler(), handler: createSTTHandler(),
store: limiterCache('stt_ip_limiter'),
}; };
const userLimiterOptions = { const userLimiterOptions = {
@@ -64,9 +64,23 @@ const createSTTLimiters = () => {
keyGenerator: function (req) { keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available return req.user?.id; // Use the user ID or NULL if not available
}, },
store: limiterCache('stt_user_limiter'),
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for STT rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'stt_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'stt_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const sttIpLimiter = rateLimit(ipLimiterOptions); const sttIpLimiter = rateLimit(ipLimiterOptions);
const sttUserLimiter = rateLimit(userLimiterOptions); const sttUserLimiter = rateLimit(userLimiterOptions);

View File

@@ -1,9 +1,10 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory'); const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { TOOL_CALL_VIOLATION_SCORE: score } = process.env; const { logger } = require('~/config');
const handler = async (req, res) => { const handler = async (req, res) => {
const type = ViolationTypes.TOOL_CALL_LIMIT; const type = ViolationTypes.TOOL_CALL_LIMIT;
@@ -14,7 +15,7 @@ const handler = async (req, res) => {
windowInMinutes: 1, windowInMinutes: 1,
}; };
await logViolation(req, res, type, errorMessage, score); await logViolation(req, res, type, errorMessage, 0);
res.status(429).json({ message: 'Too many tool call requests. Try again later' }); res.status(429).json({ message: 'Too many tool call requests. Try again later' });
}; };
@@ -25,9 +26,17 @@ const limiterOptions = {
keyGenerator: function (req) { keyGenerator: function (req) {
return req.user?.id; return req.user?.id;
}, },
store: limiterCache('tool_call_limiter'),
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for tool call rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'tool_call_limiter:',
});
limiterOptions.store = store;
}
const toolCallLimiter = rateLimit(limiterOptions); const toolCallLimiter = rateLimit(limiterOptions);
module.exports = toolCallLimiter; module.exports = toolCallLimiter;

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { limiterCache } = require('~/cache/cacheFactory'); const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => { const getEnvironmentVariables = () => {
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100; const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1; const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50; const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1; const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE;
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000; const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
const ttsIpMax = TTS_IP_MAX; const ttsIpMax = TTS_IP_MAX;
@@ -25,12 +27,11 @@ const getEnvironmentVariables = () => {
ttsUserWindowMs, ttsUserWindowMs,
ttsUserMax, ttsUserMax,
ttsUserWindowInMinutes, ttsUserWindowInMinutes,
ttsViolationScore: TTS_VIOLATION_SCORE,
}; };
}; };
const createTTSHandler = (ip = true) => { const createTTSHandler = (ip = true) => {
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } = const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
getEnvironmentVariables(); getEnvironmentVariables();
return async (req, res) => { return async (req, res) => {
@@ -42,7 +43,7 @@ const createTTSHandler = (ip = true) => {
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes, windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
}; };
await logViolation(req, res, type, errorMessage, ttsViolationScore); await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many TTS requests. Try again later' }); res.status(429).json({ message: 'Too many TTS requests. Try again later' });
}; };
}; };
@@ -54,19 +55,32 @@ const createTTSLimiters = () => {
windowMs: ttsIpWindowMs, windowMs: ttsIpWindowMs,
max: ttsIpMax, max: ttsIpMax,
handler: createTTSHandler(), handler: createTTSHandler(),
store: limiterCache('tts_ip_limiter'),
}; };
const userLimiterOptions = { const userLimiterOptions = {
windowMs: ttsUserWindowMs, windowMs: ttsUserWindowMs,
max: ttsUserMax, max: ttsUserMax,
handler: createTTSHandler(false), handler: createTTSHandler(false),
store: limiterCache('tts_user_limiter'),
keyGenerator: function (req) { keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available return req.user?.id; // Use the user ID or NULL if not available
}, },
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for TTS rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'tts_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'tts_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const ttsIpLimiter = rateLimit(ipLimiterOptions); const ttsIpLimiter = rateLimit(ipLimiterOptions);
const ttsUserLimiter = rateLimit(userLimiterOptions); const ttsUserLimiter = rateLimit(userLimiterOptions);

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory'); const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation'); const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => { const getEnvironmentVariables = () => {
const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100; const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100;
const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15; const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50; const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15; const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE;
const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000; const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
const fileUploadIpMax = FILE_UPLOAD_IP_MAX; const fileUploadIpMax = FILE_UPLOAD_IP_MAX;
@@ -25,7 +27,6 @@ const getEnvironmentVariables = () => {
fileUploadUserWindowMs, fileUploadUserWindowMs,
fileUploadUserMax, fileUploadUserMax,
fileUploadUserWindowInMinutes, fileUploadUserWindowInMinutes,
fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE,
}; };
}; };
@@ -35,7 +36,6 @@ const createFileUploadHandler = (ip = true) => {
fileUploadIpWindowInMinutes, fileUploadIpWindowInMinutes,
fileUploadUserMax, fileUploadUserMax,
fileUploadUserWindowInMinutes, fileUploadUserWindowInMinutes,
fileUploadViolationScore,
} = getEnvironmentVariables(); } = getEnvironmentVariables();
return async (req, res) => { return async (req, res) => {
@@ -47,7 +47,7 @@ const createFileUploadHandler = (ip = true) => {
windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes, windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
}; };
await logViolation(req, res, type, errorMessage, fileUploadViolationScore); await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many file upload requests. Try again later' }); res.status(429).json({ message: 'Too many file upload requests. Try again later' });
}; };
}; };
@@ -60,7 +60,6 @@ const createFileLimiters = () => {
windowMs: fileUploadIpWindowMs, windowMs: fileUploadIpWindowMs,
max: fileUploadIpMax, max: fileUploadIpMax,
handler: createFileUploadHandler(), handler: createFileUploadHandler(),
store: limiterCache('file_upload_ip_limiter'),
}; };
const userLimiterOptions = { const userLimiterOptions = {
@@ -70,9 +69,23 @@ const createFileLimiters = () => {
keyGenerator: function (req) { keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available return req.user?.id; // Use the user ID or NULL if not available
}, },
store: limiterCache('file_upload_user_limiter'),
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for file upload rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'file_upload_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'file_upload_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const fileUploadIpLimiter = rateLimit(ipLimiterOptions); const fileUploadIpLimiter = rateLimit(ipLimiterOptions);
const fileUploadUserLimiter = rateLimit(userLimiterOptions); const fileUploadUserLimiter = rateLimit(userLimiterOptions);

View File

@@ -1,8 +1,10 @@
const rateLimit = require('express-rate-limit'); const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils'); const { removePorts, isEnabled } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory'); const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { const {
VERIFY_EMAIL_WINDOW = 2, VERIFY_EMAIL_WINDOW = 2,
@@ -31,9 +33,17 @@ const limiterOptions = {
max, max,
handler, handler,
keyGenerator: removePorts, keyGenerator: removePorts,
store: limiterCache('verify_email_limiter'),
}; };
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for verify email rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'verify_email_limiter:',
});
limiterOptions.store = store;
}
const verifyEmailLimiter = rateLimit(limiterOptions); const verifyEmailLimiter = rateLimit(limiterOptions);
module.exports = verifyEmailLimiter; module.exports = verifyEmailLimiter;

View File

@@ -0,0 +1,78 @@
const { getRoleByName } = require('~/models/Role');
const { logger } = require('~/config');
/**
* Core function to check if a user has one or more required permissions
*
* @param {object} user - The user object
* @param {PermissionTypes} permissionType - The type of permission to check
* @param {Permissions[]} permissions - The list of specific permissions to check
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
* @param {object} [checkObject] - The object to check properties against
* @returns {Promise<boolean>} Whether the user has the required permissions
*/
const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
if (!user) {
return false;
}
const role = await getRoleByName(user.role);
if (role && role.permissions && role.permissions[permissionType]) {
const hasAnyPermission = permissions.some((permission) => {
if (role.permissions[permissionType][permission]) {
return true;
}
if (bodyProps[permission] && checkObject) {
return bodyProps[permission].some((prop) =>
Object.prototype.hasOwnProperty.call(checkObject, prop),
);
}
return false;
});
return hasAnyPermission;
}
return false;
};
/**
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
*
* @param {PermissionTypes} permissionType - The type of permission to check.
* @param {Permissions[]} permissions - The list of specific permissions to check.
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
* @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
*/
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
return async (req, res, next) => {
try {
const hasAccess = await checkAccess(
req.user,
permissionType,
permissions,
bodyProps,
req.body,
);
if (hasAccess) {
return next();
}
logger.warn(
`[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`,
);
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
} catch (error) {
logger.error(error);
return res.status(500).json({ message: `Server error: ${error.message}` });
}
};
};
module.exports = {
checkAccess,
generateCheckAccess,
};

View File

@@ -1,5 +1,8 @@
const checkAdmin = require('./admin'); const checkAdmin = require('./admin');
const { checkAccess, generateCheckAccess } = require('./access');
module.exports = { module.exports = {
checkAdmin, checkAdmin,
checkAccess,
generateCheckAccess,
}; };

View File

@@ -1,6 +1,5 @@
const uap = require('ua-parser-js'); const uap = require('ua-parser-js');
const { ViolationTypes } = require('librechat-data-provider'); const { handleError } = require('../utils');
const { handleError } = require('@librechat/api');
const { logViolation } = require('../../cache'); const { logViolation } = require('../../cache');
/** /**
@@ -22,7 +21,7 @@ async function uaParser(req, res, next) {
const ua = uap(req.headers['user-agent']); const ua = uap(req.headers['user-agent']);
if (!ua.browser.name) { if (!ua.browser.name) {
const type = ViolationTypes.NON_BROWSER; const type = 'non_browser';
await logViolation(req, res, type, { type }, score); await logViolation(req, res, type, { type }, score);
return handleError(res, { message: 'Illegal request' }); return handleError(res, { message: 'Illegal request' });
} }

View File

@@ -1,4 +1,4 @@
const { handleError } = require('@librechat/api'); const { handleError } = require('../utils');
function validateEndpoint(req, res, next) { function validateEndpoint(req, res, next) {
const { endpoint: _endpoint, endpointType } = req.body; const { endpoint: _endpoint, endpointType } = req.body;

View File

@@ -1,6 +1,6 @@
const { handleError } = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { getModelsConfig } = require('~/server/controllers/ModelController'); const { getModelsConfig } = require('~/server/controllers/ModelController');
const { handleError } = require('~/server/utils');
const { logViolation } = require('~/cache'); const { logViolation } = require('~/cache');
/** /**
* Validates the model of the request. * Validates the model of the request.

View File

@@ -1,162 +0,0 @@
const fs = require('fs');
const path = require('path');
const express = require('express');
const request = require('supertest');
const zlib = require('zlib');
// Create test setup
const mockTestDir = path.join(__dirname, 'test-static-route');
// Mock the paths module to point to our test directory
jest.mock('~/config/paths', () => ({
imageOutput: mockTestDir,
}));
describe('Static Route Integration', () => {
let app;
let staticRoute;
let testDir;
let testImagePath;
beforeAll(() => {
// Create a test directory and files
testDir = mockTestDir;
testImagePath = path.join(testDir, 'test-image.jpg');
if (!fs.existsSync(testDir)) {
fs.mkdirSync(testDir, { recursive: true });
}
// Create a test image file
fs.writeFileSync(testImagePath, 'fake-image-data');
// Create a gzipped version of the test image (for gzip scanning tests)
fs.writeFileSync(testImagePath + '.gz', zlib.gzipSync('fake-image-data'));
});
afterAll(() => {
// Clean up test files
if (fs.existsSync(testDir)) {
fs.rmSync(testDir, { recursive: true, force: true });
}
});
// Helper function to set up static route with specific config
const setupStaticRoute = (skipGzipScan = false) => {
if (skipGzipScan) {
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
} else {
process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN = 'true';
}
staticRoute = require('../static');
app.use('/images', staticRoute);
};
beforeEach(() => {
// Clear the module cache to get fresh imports
jest.resetModules();
app = express();
// Clear environment variables
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
delete process.env.NODE_ENV;
});
describe('route functionality', () => {
it('should serve static image files', async () => {
process.env.NODE_ENV = 'production';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.body.toString()).toBe('fake-image-data');
});
it('should return 404 for non-existent files', async () => {
setupStaticRoute();
const response = await request(app).get('/images/nonexistent.jpg');
expect(response.status).toBe(404);
});
});
describe('cache behavior', () => {
it('should set cache headers for images in production', async () => {
process.env.NODE_ENV = 'production';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
});
it('should not set cache headers in development', async () => {
process.env.NODE_ENV = 'development';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
// Our middleware should not set the production cache-control header in development
expect(response.headers['cache-control']).not.toBe('public, max-age=172800, s-maxage=86400');
});
});
describe('gzip compression behavior', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should serve gzipped files when gzip scanning is enabled', async () => {
setupStaticRoute(false); // Enable gzip scanning
const response = await request(app)
.get('/images/test-image.jpg')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.body.toString()).toBe('fake-image-data');
});
it('should not serve gzipped files when gzip scanning is disabled', async () => {
setupStaticRoute(true); // Disable gzip scanning
const response = await request(app)
.get('/images/test-image.jpg')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.body.toString()).toBe('fake-image-data');
});
});
describe('path configuration', () => {
it('should use the configured imageOutput path', async () => {
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.body.toString()).toBe('fake-image-data');
});
it('should serve from subdirectories', async () => {
// Create a subdirectory with a file
const subDir = path.join(testDir, 'thumbs');
fs.mkdirSync(subDir, { recursive: true });
const thumbPath = path.join(subDir, 'thumb.jpg');
fs.writeFileSync(thumbPath, 'thumbnail-data');
setupStaticRoute();
const response = await request(app).get('/images/thumbs/thumb.jpg').expect(200);
expect(response.body.toString()).toBe('thumbnail-data');
// Clean up
fs.rmSync(subDir, { recursive: true, force: true });
});
});
});

View File

@@ -1,28 +1,14 @@
const express = require('express'); const express = require('express');
const { nanoid } = require('nanoid'); const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas'); const { actionDelimiter, SystemRoles, removeNullishValues } = require('librechat-data-provider');
const { generateCheckAccess } = require('@librechat/api');
const {
SystemRoles,
Permissions,
PermissionTypes,
actionDelimiter,
removeNullishValues,
} = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { isActionDomainAllowed } = require('~/server/services/domains'); const { isActionDomainAllowed } = require('~/server/services/domains');
const { getAgent, updateAgent } = require('~/models/Agent'); const { getAgent, updateAgent } = require('~/models/Agent');
const { getRoleByName } = require('~/models/Role'); const { logger } = require('~/config');
const router = express.Router(); const router = express.Router();
const checkAgentCreate = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
// If the user has ADMIN role // If the user has ADMIN role
// then action edition is possible even if not owner of the assistant // then action edition is possible even if not owner of the assistant
const isAdmin = (req) => { const isAdmin = (req) => {
@@ -55,7 +41,7 @@ router.get('/', async (req, res) => {
* @param {ActionMetadata} req.body.metadata - Metadata for the action. * @param {ActionMetadata} req.body.metadata - Metadata for the action.
* @returns {Object} 200 - success response - application/json * @returns {Object} 200 - success response - application/json
*/ */
router.post('/:agent_id', checkAgentCreate, async (req, res) => { router.post('/:agent_id', async (req, res) => {
try { try {
const { agent_id } = req.params; const { agent_id } = req.params;
@@ -163,7 +149,7 @@ router.post('/:agent_id', checkAgentCreate, async (req, res) => {
* @param {string} req.params.action_id - The ID of the action to delete. * @param {string} req.params.action_id - The ID of the action to delete.
* @returns {Object} 200 - success response - application/json * @returns {Object} 200 - success response - application/json
*/ */
router.delete('/:agent_id/:action_id', checkAgentCreate, async (req, res) => { router.delete('/:agent_id/:action_id', async (req, res) => {
try { try {
const { agent_id, action_id } = req.params; const { agent_id, action_id } = req.params;
const admin = isAdmin(req); const admin = isAdmin(req);

View File

@@ -1,28 +1,22 @@
const express = require('express'); const express = require('express');
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { const {
setHeaders, setHeaders,
moderateText, moderateText,
// validateModel, // validateModel,
generateCheckAccess,
validateConvoAccess, validateConvoAccess,
buildEndpointOption, buildEndpointOption,
} = require('~/server/middleware'); } = require('~/server/middleware');
const { initializeClient } = require('~/server/services/Endpoints/agents'); const { initializeClient } = require('~/server/services/Endpoints/agents');
const AgentController = require('~/server/controllers/agents/request'); const AgentController = require('~/server/controllers/agents/request');
const addTitle = require('~/server/services/Endpoints/agents/title'); const addTitle = require('~/server/services/Endpoints/agents/title');
const { getRoleByName } = require('~/models/Role');
const router = express.Router(); const router = express.Router();
router.use(moderateText); router.use(moderateText);
const checkAgentAccess = generateCheckAccess({ const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
skipCheck: skipAgentCheck,
getRoleByName,
});
router.use(checkAgentAccess); router.use(checkAgentAccess);
router.use(validateConvoAccess); router.use(validateConvoAccess);

View File

@@ -1,4 +1,5 @@
const express = require('express'); const express = require('express');
const { addTool } = require('@librechat/api');
const { callTool, verifyToolAuth, getToolCalls } = require('~/server/controllers/tools'); const { callTool, verifyToolAuth, getToolCalls } = require('~/server/controllers/tools');
const { getAvailableTools } = require('~/server/controllers/PluginController'); const { getAvailableTools } = require('~/server/controllers/PluginController');
const { toolCallLimiter } = require('~/server/middleware/limiters'); const { toolCallLimiter } = require('~/server/middleware/limiters');
@@ -36,4 +37,12 @@ router.get('/:toolId/auth', verifyToolAuth);
*/ */
router.post('/:toolId/call', toolCallLimiter, callTool); router.post('/:toolId/call', toolCallLimiter, callTool);
/**
* Add a new tool to the system
* @route POST /agents/tools/add
* @param {object} req.body - Request body containing tool data
* @returns {object} Created tool object
*/
router.post('/add', addTool);
module.exports = router; module.exports = router;

View File

@@ -1,36 +1,29 @@
const express = require('express'); const express = require('express');
const { generateCheckAccess } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { requireJwtAuth } = require('~/server/middleware'); const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const v1 = require('~/server/controllers/agents/v1'); const v1 = require('~/server/controllers/agents/v1');
const { getRoleByName } = require('~/models/Role');
const actions = require('./actions'); const actions = require('./actions');
const tools = require('./tools'); const tools = require('./tools');
const router = express.Router(); const router = express.Router();
const avatar = express.Router(); const avatar = express.Router();
const checkAgentAccess = generateCheckAccess({ const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
permissionType: PermissionTypes.AGENTS, const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [
permissions: [Permissions.USE], Permissions.USE,
getRoleByName, Permissions.CREATE,
}); ]);
const checkAgentCreate = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
const checkGlobalAgentShare = generateCheckAccess({ const checkGlobalAgentShare = generateCheckAccess(
permissionType: PermissionTypes.AGENTS, PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE], [Permissions.USE, Permissions.CREATE],
bodyProps: { {
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'], [Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
}, },
getRoleByName, );
});
router.use(requireJwtAuth); router.use(requireJwtAuth);
router.use(checkAgentAccess);
/** /**
* Agent actions route. * Agent actions route.

View File

@@ -1,17 +1,16 @@
const multer = require('multer'); const multer = require('multer');
const express = require('express'); const express = require('express');
const { sleep } = require('@librechat/agents');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation'); const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork'); const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
const { storage, importFileFilter } = require('~/server/routes/files/multer'); const { storage, importFileFilter } = require('~/server/routes/files/multer');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { importConversations } = require('~/server/utils/import'); const { importConversations } = require('~/server/utils/import');
const { createImportLimiters } = require('~/server/middleware');
const { deleteToolCalls } = require('~/models/ToolCall'); const { deleteToolCalls } = require('~/models/ToolCall');
const { isEnabled, sleep } = require('~/server/utils');
const getLogStores = require('~/cache/getLogStores'); const getLogStores = require('~/cache/getLogStores');
const { logger } = require('~/config');
const assistantClients = { const assistantClients = {
[EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'), [EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'),
@@ -44,7 +43,6 @@ router.get('/', async (req, res) => {
}); });
res.status(200).json(result); res.status(200).json(result);
} catch (error) { } catch (error) {
logger.error('Error fetching conversations', error);
res.status(500).json({ error: 'Error fetching conversations' }); res.status(500).json({ error: 'Error fetching conversations' });
} }
}); });
@@ -158,7 +156,6 @@ router.post('/update', async (req, res) => {
}); });
const { importIpLimiter, importUserLimiter } = createImportLimiters(); const { importIpLimiter, importUserLimiter } = createImportLimiters();
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
const upload = multer({ storage: storage, fileFilter: importFileFilter }); const upload = multer({ storage: storage, fileFilter: importFileFilter });
/** /**
@@ -192,7 +189,7 @@ router.post(
* @param {express.Response<TForkConvoResponse>} res - Express response object. * @param {express.Response<TForkConvoResponse>} res - Express response object.
* @returns {Promise<void>} - The response after forking the conversation. * @returns {Promise<void>} - The response after forking the conversation.
*/ */
router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => { router.post('/fork', async (req, res) => {
try { try {
/** @type {TForkConvoRequest} */ /** @type {TForkConvoRequest} */
const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body; const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body;

View File

@@ -1,282 +0,0 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
// Mock dependencies
jest.mock('~/server/services/Files/process', () => ({
processDeleteRequest: jest.fn().mockResolvedValue({}),
filterFile: jest.fn(),
processFileUpload: jest.fn(),
processAgentFileUpload: jest.fn(),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(() => ({})),
}));
jest.mock('~/server/controllers/assistants/helpers', () => ({
getOpenAIClient: jest.fn(),
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
refreshS3FileUrls: jest.fn(),
}));
jest.mock('~/cache', () => ({
getLogStores: jest.fn(() => ({
get: jest.fn(),
set: jest.fn(),
})),
}));
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
},
}));
const { createFile } = require('~/models/File');
const { createAgent } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
// Import the router after mocks
const router = require('./files');
describe('File Routes - Agent Files Endpoint', () => {
let app;
let mongoServer;
let authorId;
let otherUserId;
let agentId;
let fileId1;
let fileId2;
let fileId3;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
// Initialize models
require('~/db/models');
app = express();
app.use(express.json());
// Mock authentication middleware
app.use((req, res, next) => {
req.user = { id: otherUserId || 'default-user' };
req.app = { locals: {} };
next();
});
app.use('/files', router);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
jest.clearAllMocks();
// Clear database
const collections = mongoose.connection.collections;
for (const key in collections) {
await collections[key].deleteMany({});
}
authorId = new mongoose.Types.ObjectId().toString();
otherUserId = new mongoose.Types.ObjectId().toString();
agentId = uuidv4();
fileId1 = uuidv4();
fileId2 = uuidv4();
fileId3 = uuidv4();
// Create files
await createFile({
user: authorId,
file_id: fileId1,
filename: 'agent-file1.txt',
filepath: `/uploads/${authorId}/${fileId1}`,
bytes: 1024,
type: 'text/plain',
});
await createFile({
user: authorId,
file_id: fileId2,
filename: 'agent-file2.txt',
filepath: `/uploads/${authorId}/${fileId2}`,
bytes: 2048,
type: 'text/plain',
});
await createFile({
user: otherUserId,
file_id: fileId3,
filename: 'user-file.txt',
filepath: `/uploads/${otherUserId}/${fileId3}`,
bytes: 512,
type: 'text/plain',
});
// Create an agent with files attached
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileId1, fileId2],
},
},
});
// Share the agent globally
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
if (globalProject) {
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
}
});
describe('GET /files/agent/:agent_id', () => {
it('should return files accessible through the agent for non-author', async () => {
const response = await request(app).get(`/files/agent/${agentId}`);
expect(response.status).toBe(200);
expect(response.body).toHaveLength(2); // Only agent files, not user-owned files
const fileIds = response.body.map((f) => f.file_id);
expect(fileIds).toContain(fileId1);
expect(fileIds).toContain(fileId2);
expect(fileIds).not.toContain(fileId3); // User's own file not included
});
it('should return 400 when agent_id is not provided', async () => {
const response = await request(app).get('/files/agent/');
expect(response.status).toBe(404); // Express returns 404 for missing route parameter
});
it('should return empty array for non-existent agent', async () => {
const response = await request(app).get('/files/agent/non-existent-agent');
expect(response.status).toBe(200);
expect(response.body).toEqual([]); // Empty array for non-existent agent
});
it('should return empty array when agent is not collaborative', async () => {
// Create a non-collaborative agent
const nonCollabAgentId = uuidv4();
await createAgent({
id: nonCollabAgentId,
name: 'Non-Collaborative Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: false,
tool_resources: {
file_search: {
file_ids: [fileId1],
},
},
});
// Share it globally
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
if (globalProject) {
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: nonCollabAgentId }, { projectIds: [globalProject._id] });
}
const response = await request(app).get(`/files/agent/${nonCollabAgentId}`);
expect(response.status).toBe(200);
expect(response.body).toEqual([]); // Empty array when not collaborative
});
it('should return agent files for agent author', async () => {
// Create a new app instance with author authentication
const authorApp = express();
authorApp.use(express.json());
authorApp.use((req, res, next) => {
req.user = { id: authorId };
req.app = { locals: {} };
next();
});
authorApp.use('/files', router);
const response = await request(authorApp).get(`/files/agent/${agentId}`);
expect(response.status).toBe(200);
expect(response.body).toHaveLength(2); // Agent files for author
const fileIds = response.body.map((f) => f.file_id);
expect(fileIds).toContain(fileId1);
expect(fileIds).toContain(fileId2);
expect(fileIds).not.toContain(fileId3); // User's own file not included
});
it('should return files uploaded by other users to shared agent for author', async () => {
// Create a file uploaded by another user
const otherUserFileId = uuidv4();
const anotherUserId = new mongoose.Types.ObjectId().toString();
await createFile({
user: anotherUserId,
file_id: otherUserFileId,
filename: 'other-user-file.txt',
filepath: `/uploads/${anotherUserId}/${otherUserFileId}`,
bytes: 4096,
type: 'text/plain',
});
// Update agent to include the file uploaded by another user
const { updateAgent } = require('~/models/Agent');
await updateAgent(
{ id: agentId },
{
tool_resources: {
file_search: {
file_ids: [fileId1, fileId2, otherUserFileId],
},
},
},
);
// Create app instance with author authentication
const authorApp = express();
authorApp.use(express.json());
authorApp.use((req, res, next) => {
req.user = { id: authorId };
req.app = { locals: {} };
next();
});
authorApp.use('/files', router);
const response = await request(authorApp).get(`/files/agent/${agentId}`);
expect(response.status).toBe(200);
expect(response.body).toHaveLength(3); // Including file from another user
const fileIds = response.body.map((f) => f.file_id);
expect(fileIds).toContain(fileId1);
expect(fileIds).toContain(fileId2);
expect(fileIds).toContain(otherUserFileId); // File uploaded by another user
});
});
});

View File

@@ -5,7 +5,6 @@ const {
Time, Time,
isUUID, isUUID,
CacheKeys, CacheKeys,
Constants,
FileSources, FileSources,
EModelEndpoint, EModelEndpoint,
isAgentsEndpoint, isAgentsEndpoint,
@@ -17,12 +16,11 @@ const {
processDeleteRequest, processDeleteRequest,
processAgentFileUpload, processAgentFileUpload,
} = require('~/server/services/Files/process'); } = require('~/server/services/Files/process');
const { getFiles, batchUpdateFiles, hasAccessToFilesViaAgent } = require('~/models/File');
const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud'); const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
const { getProjectByName } = require('~/models/Project'); const { getFiles, batchUpdateFiles } = require('~/models/File');
const { getAssistant } = require('~/models/Assistant'); const { getAssistant } = require('~/models/Assistant');
const { getAgent } = require('~/models/Agent'); const { getAgent } = require('~/models/Agent');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
@@ -52,68 +50,6 @@ router.get('/', async (req, res) => {
} }
}); });
/**
* Get files specific to an agent
* @route GET /files/agent/:agent_id
* @param {string} agent_id - The agent ID to get files for
* @returns {Promise<TFile[]>} Array of files attached to the agent
*/
router.get('/agent/:agent_id', async (req, res) => {
try {
const { agent_id } = req.params;
const userId = req.user.id;
if (!agent_id) {
return res.status(400).json({ error: 'Agent ID is required' });
}
// Get the agent to check ownership and attached files
const agent = await getAgent({ id: agent_id });
if (!agent) {
// No agent found, return empty array
return res.status(200).json([]);
}
// Check if user has access to the agent
if (agent.author.toString() !== userId) {
// Non-authors need the agent to be globally shared and collaborative
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
if (
!globalProject ||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString()) ||
!agent.isCollaborative
) {
return res.status(200).json([]);
}
}
// Collect all file IDs from agent's tool resources
const agentFileIds = [];
if (agent.tool_resources) {
for (const [, resource] of Object.entries(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
agentFileIds.push(...resource.file_ids);
}
}
}
// If no files attached to agent, return empty array
if (agentFileIds.length === 0) {
return res.status(200).json([]);
}
// Get only the files attached to this agent
const files = await getFiles({ file_id: { $in: agentFileIds } }, null, { text: 0 });
res.status(200).json(files);
} catch (error) {
logger.error('[/files/agent/:agent_id] Error fetching agent files:', error);
res.status(500).json({ error: 'Failed to fetch agent files' });
}
});
router.get('/config', async (req, res) => { router.get('/config', async (req, res) => {
try { try {
res.status(200).json(req.app.locals.fileConfig); res.status(200).json(req.app.locals.fileConfig);
@@ -150,62 +86,11 @@ router.delete('/', async (req, res) => {
const fileIds = files.map((file) => file.file_id); const fileIds = files.map((file) => file.file_id);
const dbFiles = await getFiles({ file_id: { $in: fileIds } }); const dbFiles = await getFiles({ file_id: { $in: fileIds } });
const unauthorizedFiles = dbFiles.filter((file) => file.user.toString() !== req.user.id);
const ownedFiles = [];
const nonOwnedFiles = [];
const fileMap = new Map();
for (const file of dbFiles) {
fileMap.set(file.file_id, file);
if (file.user.toString() === req.user.id) {
ownedFiles.push(file);
} else {
nonOwnedFiles.push(file);
}
}
// If all files are owned by the user, no need for further checks
if (nonOwnedFiles.length === 0) {
await processDeleteRequest({ req, files: ownedFiles });
logger.debug(
`[/files] Files deleted successfully: ${ownedFiles
.filter((f) => f.file_id)
.map((f) => f.file_id)
.join(', ')}`,
);
res.status(200).json({ message: 'Files deleted successfully' });
return;
}
// Check access for non-owned files
let authorizedFiles = [...ownedFiles];
let unauthorizedFiles = [];
if (req.body.agent_id && nonOwnedFiles.length > 0) {
// Batch check access for all non-owned files
const nonOwnedFileIds = nonOwnedFiles.map((f) => f.file_id);
const accessMap = await hasAccessToFilesViaAgent(
req.user.id,
nonOwnedFileIds,
req.body.agent_id,
);
// Separate authorized and unauthorized files
for (const file of nonOwnedFiles) {
if (accessMap.get(file.file_id)) {
authorizedFiles.push(file);
} else {
unauthorizedFiles.push(file);
}
}
} else {
// No agent context, all non-owned files are unauthorized
unauthorizedFiles = nonOwnedFiles;
}
if (unauthorizedFiles.length > 0) { if (unauthorizedFiles.length > 0) {
return res.status(403).json({ return res.status(403).json({
message: 'You can only delete files you have access to', message: 'You can only delete your own files',
unauthorizedFiles: unauthorizedFiles.map((f) => f.file_id), unauthorizedFiles: unauthorizedFiles.map((f) => f.file_id),
}); });
} }
@@ -246,10 +131,10 @@ router.delete('/', async (req, res) => {
.json({ message: 'File associations removed successfully from Azure Assistant' }); .json({ message: 'File associations removed successfully from Azure Assistant' });
} }
await processDeleteRequest({ req, files: authorizedFiles }); await processDeleteRequest({ req, files: dbFiles });
logger.debug( logger.debug(
`[/files] Files deleted successfully: ${authorizedFiles `[/files] Files deleted successfully: ${files
.filter((f) => f.file_id) .filter((f) => f.file_id)
.map((f) => f.file_id) .map((f) => f.file_id)
.join(', ')}`, .join(', ')}`,

View File

@@ -1,302 +0,0 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
// Mock dependencies
jest.mock('~/server/services/Files/process', () => ({
processDeleteRequest: jest.fn().mockResolvedValue({}),
filterFile: jest.fn(),
processFileUpload: jest.fn(),
processAgentFileUpload: jest.fn(),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(() => ({})),
}));
jest.mock('~/server/controllers/assistants/helpers', () => ({
getOpenAIClient: jest.fn(),
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
refreshS3FileUrls: jest.fn(),
}));
jest.mock('~/cache', () => ({
getLogStores: jest.fn(() => ({
get: jest.fn(),
set: jest.fn(),
})),
}));
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
},
}));
const { createFile } = require('~/models/File');
const { createAgent } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
const { processDeleteRequest } = require('~/server/services/Files/process');
// Import the router after mocks
const router = require('./files');
describe('File Routes - Delete with Agent Access', () => {
let app;
let mongoServer;
let authorId;
let otherUserId;
let agentId;
let fileId;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
// Initialize models
require('~/db/models');
app = express();
app.use(express.json());
// Mock authentication middleware
app.use((req, res, next) => {
req.user = { id: otherUserId || 'default-user' };
req.app = { locals: {} };
next();
});
app.use('/files', router);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
jest.clearAllMocks();
// Clear database
const collections = mongoose.connection.collections;
for (const key in collections) {
await collections[key].deleteMany({});
}
authorId = new mongoose.Types.ObjectId().toString();
otherUserId = new mongoose.Types.ObjectId().toString();
fileId = uuidv4();
// Create a file owned by the author
await createFile({
user: authorId,
file_id: fileId,
filename: 'test.txt',
filepath: `/uploads/${authorId}/${fileId}`,
bytes: 1024,
type: 'text/plain',
});
// Create an agent with the file attached
const agent = await createAgent({
id: uuidv4(),
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileId],
},
},
});
agentId = agent.id;
// Share the agent globally
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
if (globalProject) {
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
}
});
describe('DELETE /files', () => {
it('should allow deleting files owned by the user', async () => {
// Create a file owned by the current user
const userFileId = uuidv4();
await createFile({
user: otherUserId,
file_id: userFileId,
filename: 'user-file.txt',
filepath: `/uploads/${otherUserId}/${userFileId}`,
bytes: 1024,
type: 'text/plain',
});
const response = await request(app)
.delete('/files')
.send({
files: [
{
file_id: userFileId,
filepath: `/uploads/${otherUserId}/${userFileId}`,
},
],
});
expect(response.status).toBe(200);
expect(response.body.message).toBe('Files deleted successfully');
expect(processDeleteRequest).toHaveBeenCalled();
});
it('should prevent deleting files not owned by user without agent context', async () => {
const response = await request(app)
.delete('/files')
.send({
files: [
{
file_id: fileId,
filepath: `/uploads/${authorId}/${fileId}`,
},
],
});
expect(response.status).toBe(403);
expect(response.body.message).toBe('You can only delete files you have access to');
expect(response.body.unauthorizedFiles).toContain(fileId);
expect(processDeleteRequest).not.toHaveBeenCalled();
});
it('should allow deleting files accessible through shared agent', async () => {
const response = await request(app)
.delete('/files')
.send({
agent_id: agentId,
files: [
{
file_id: fileId,
filepath: `/uploads/${authorId}/${fileId}`,
},
],
});
expect(response.status).toBe(200);
expect(response.body.message).toBe('Files deleted successfully');
expect(processDeleteRequest).toHaveBeenCalled();
});
it('should prevent deleting files not attached to the specified agent', async () => {
// Create another file not attached to the agent
const unattachedFileId = uuidv4();
await createFile({
user: authorId,
file_id: unattachedFileId,
filename: 'unattached.txt',
filepath: `/uploads/${authorId}/${unattachedFileId}`,
bytes: 1024,
type: 'text/plain',
});
const response = await request(app)
.delete('/files')
.send({
agent_id: agentId,
files: [
{
file_id: unattachedFileId,
filepath: `/uploads/${authorId}/${unattachedFileId}`,
},
],
});
expect(response.status).toBe(403);
expect(response.body.message).toBe('You can only delete files you have access to');
expect(response.body.unauthorizedFiles).toContain(unattachedFileId);
});
it('should handle mixed authorized and unauthorized files', async () => {
// Create a file owned by the current user
const userFileId = uuidv4();
await createFile({
user: otherUserId,
file_id: userFileId,
filename: 'user-file.txt',
filepath: `/uploads/${otherUserId}/${userFileId}`,
bytes: 1024,
type: 'text/plain',
});
// Create an unauthorized file
const unauthorizedFileId = uuidv4();
await createFile({
user: authorId,
file_id: unauthorizedFileId,
filename: 'unauthorized.txt',
filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
bytes: 1024,
type: 'text/plain',
});
const response = await request(app)
.delete('/files')
.send({
agent_id: agentId,
files: [
{
file_id: fileId, // Authorized through agent
filepath: `/uploads/${authorId}/${fileId}`,
},
{
file_id: userFileId, // Owned by user
filepath: `/uploads/${otherUserId}/${userFileId}`,
},
{
file_id: unauthorizedFileId, // Not authorized
filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
},
],
});
expect(response.status).toBe(403);
expect(response.body.message).toBe('You can only delete files you have access to');
expect(response.body.unauthorizedFiles).toContain(unauthorizedFileId);
expect(response.body.unauthorizedFiles).not.toContain(fileId);
expect(response.body.unauthorizedFiles).not.toContain(userFileId);
});
it('should prevent deleting files when agent is not collaborative', async () => {
// Update the agent to be non-collaborative
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: agentId }, { isCollaborative: false });
const response = await request(app)
.delete('/files')
.send({
agent_id: agentId,
files: [
{
file_id: fileId,
filepath: `/uploads/${authorId}/${fileId}`,
},
],
});
expect(response.status).toBe(403);
expect(response.body.message).toBe('You can only delete files you have access to');
expect(response.body.unauthorizedFiles).toContain(fileId);
expect(processDeleteRequest).not.toHaveBeenCalled();
});
});
});

View File

@@ -477,9 +477,7 @@ describe('Multer Configuration', () => {
done(new Error('Expected mkdirSync to throw an error but no error was thrown')); done(new Error('Expected mkdirSync to throw an error but no error was thrown'));
} catch (error) { } catch (error) {
// This is the expected behavior - mkdirSync throws synchronously for invalid paths // This is the expected behavior - mkdirSync throws synchronously for invalid paths
// On Linux, this typically returns EACCES (permission denied) expect(error.code).toBe('EACCES');
// On macOS/Darwin, this returns ENOENT (no such file or directory)
expect(['EACCES', 'ENOENT']).toContain(error.code);
done(); done();
} }
}); });

View File

@@ -1,43 +1,37 @@
const express = require('express'); const express = require('express');
const { Tokenizer, generateCheckAccess } = require('@librechat/api'); const { Tokenizer } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { const {
getAllUserMemories, getAllUserMemories,
toggleUserMemories, toggleUserMemories,
createMemory, createMemory,
deleteMemory,
setMemory, setMemory,
deleteMemory,
} = require('~/models'); } = require('~/models');
const { requireJwtAuth } = require('~/server/middleware'); const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const { getRoleByName } = require('~/models/Role');
const router = express.Router(); const router = express.Router();
const checkMemoryRead = generateCheckAccess({ const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [
permissionType: PermissionTypes.MEMORIES, Permissions.USE,
permissions: [Permissions.USE, Permissions.READ], Permissions.READ,
getRoleByName, ]);
}); const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [
const checkMemoryCreate = generateCheckAccess({ Permissions.USE,
permissionType: PermissionTypes.MEMORIES, Permissions.CREATE,
permissions: [Permissions.USE, Permissions.CREATE], ]);
getRoleByName, const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [
}); Permissions.USE,
const checkMemoryUpdate = generateCheckAccess({ Permissions.UPDATE,
permissionType: PermissionTypes.MEMORIES, ]);
permissions: [Permissions.USE, Permissions.UPDATE], const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [
getRoleByName, Permissions.USE,
}); Permissions.UPDATE,
const checkMemoryDelete = generateCheckAccess({ ]);
permissionType: PermissionTypes.MEMORIES, const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [
permissions: [Permissions.USE, Permissions.UPDATE], Permissions.USE,
getRoleByName, Permissions.OPT_OUT,
}); ]);
const checkMemoryOptOut = generateCheckAccess({
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE, Permissions.OPT_OUT],
getRoleByName,
});
router.use(requireJwtAuth); router.use(requireJwtAuth);
@@ -172,68 +166,40 @@ router.patch('/preferences', checkMemoryOptOut, async (req, res) => {
/** /**
* PATCH /memories/:key * PATCH /memories/:key
* Updates the value of an existing memory entry for the authenticated user. * Updates the value of an existing memory entry for the authenticated user.
* Body: { key?: string, value: string } * Body: { value: string }
* Returns 200 and { updated: true, memory: <updatedDoc> } when successful. * Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
*/ */
router.patch('/:key', checkMemoryUpdate, async (req, res) => { router.patch('/:key', checkMemoryUpdate, async (req, res) => {
const { key: urlKey } = req.params; const { key } = req.params;
const { key: bodyKey, value } = req.body || {}; const { value } = req.body || {};
if (typeof value !== 'string' || value.trim() === '') { if (typeof value !== 'string' || value.trim() === '') {
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' }); return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
} }
// Use the key from the body if provided, otherwise use the key from the URL
const newKey = bodyKey || urlKey;
try { try {
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base'); const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
const memories = await getAllUserMemories(req.user.id); const memories = await getAllUserMemories(req.user.id);
const existingMemory = memories.find((m) => m.key === urlKey); const existingMemory = memories.find((m) => m.key === key);
if (!existingMemory) { if (!existingMemory) {
return res.status(404).json({ error: 'Memory not found.' }); return res.status(404).json({ error: 'Memory not found.' });
} }
// If the key is changing, we need to handle it specially const result = await setMemory({
if (newKey !== urlKey) { userId: req.user.id,
const keyExists = memories.find((m) => m.key === newKey); key,
if (keyExists) { value,
return res.status(409).json({ error: 'Memory with this key already exists.' }); tokenCount,
} });
const createResult = await createMemory({ if (!result.ok) {
userId: req.user.id, return res.status(500).json({ error: 'Failed to update memory.' });
key: newKey,
value,
tokenCount,
});
if (!createResult.ok) {
return res.status(500).json({ error: 'Failed to create new memory.' });
}
const deleteResult = await deleteMemory({ userId: req.user.id, key: urlKey });
if (!deleteResult.ok) {
return res.status(500).json({ error: 'Failed to delete old memory.' });
}
} else {
// Key is not changing, just update the value
const result = await setMemory({
userId: req.user.id,
key: newKey,
value,
tokenCount,
});
if (!result.ok) {
return res.status(500).json({ error: 'Failed to update memory.' });
}
} }
const updatedMemories = await getAllUserMemories(req.user.id); const updatedMemories = await getAllUserMemories(req.user.id);
const updatedMemory = updatedMemories.find((m) => m.key === newKey); const updatedMemory = updatedMemories.find((m) => m.key === key);
res.json({ updated: true, memory: updatedMemory }); res.json({ updated: true, memory: updatedMemory });
} catch (error) { } catch (error) {

View File

@@ -235,13 +235,12 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) =
return res.status(400).json({ error: 'Content part not found' }); return res.status(400).json({ error: 'Content part not found' });
} }
const currentPartType = updatedContent[index].type; if (updatedContent[index].type !== ContentTypes.TEXT) {
if (currentPartType !== ContentTypes.TEXT && currentPartType !== ContentTypes.THINK) {
return res.status(400).json({ error: 'Cannot update non-text content' }); return res.status(400).json({ error: 'Cannot update non-text content' });
} }
const oldText = updatedContent[index][currentPartType]; const oldText = updatedContent[index].text;
updatedContent[index] = { type: currentPartType, [currentPartType]: text }; updatedContent[index] = { type: ContentTypes.TEXT, text };
let tokenCount = message.tokenCount; let tokenCount = message.tokenCount;
if (tokenCount !== undefined) { if (tokenCount !== undefined) {

View File

@@ -1,7 +1,5 @@
const express = require('express'); const express = require('express');
const { logger } = require('@librechat/data-schemas'); const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider');
const { generateCheckAccess } = require('@librechat/api');
const { Permissions, SystemRoles, PermissionTypes } = require('librechat-data-provider');
const { const {
getPrompt, getPrompt,
getPrompts, getPrompts,
@@ -16,30 +14,24 @@ const {
// updatePromptLabels, // updatePromptLabels,
makePromptProduction, makePromptProduction,
} = require('~/models/Prompt'); } = require('~/models/Prompt');
const { requireJwtAuth } = require('~/server/middleware'); const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const { getRoleByName } = require('~/models/Role'); const { logger } = require('~/config');
const router = express.Router(); const router = express.Router();
const checkPromptAccess = generateCheckAccess({ const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]);
permissionType: PermissionTypes.PROMPTS, const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [
permissions: [Permissions.USE], Permissions.USE,
getRoleByName, Permissions.CREATE,
}); ]);
const checkPromptCreate = generateCheckAccess({
permissionType: PermissionTypes.PROMPTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
const checkGlobalPromptShare = generateCheckAccess({ const checkGlobalPromptShare = generateCheckAccess(
permissionType: PermissionTypes.PROMPTS, PermissionTypes.PROMPTS,
permissions: [Permissions.USE, Permissions.CREATE], [Permissions.USE, Permissions.CREATE],
bodyProps: { {
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'], [Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
}, },
getRoleByName, );
});
router.use(requireJwtAuth); router.use(requireJwtAuth);
router.use(checkPromptAccess); router.use(checkPromptAccess);

View File

@@ -1,11 +1,8 @@
const express = require('express'); const express = require('express');
const staticCache = require('../utils/staticCache'); const staticCache = require('../utils/staticCache');
const paths = require('~/config/paths'); const paths = require('~/config/paths');
const { isEnabled } = require('~/server/utils');
const skipGzipScan = !isEnabled(process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN);
const router = express.Router(); const router = express.Router();
router.use(staticCache(paths.imageOutput, { skipGzipScan })); router.use(staticCache(paths.imageOutput));
module.exports = router; module.exports = router;

View File

@@ -1,24 +1,18 @@
const express = require('express'); const express = require('express');
const { logger } = require('@librechat/data-schemas');
const { generateCheckAccess } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { const {
updateTagsForConversation, getConversationTags,
updateConversationTag, updateConversationTag,
createConversationTag, createConversationTag,
deleteConversationTag, deleteConversationTag,
getConversationTags, updateTagsForConversation,
} = require('~/models/ConversationTag'); } = require('~/models/ConversationTag');
const { requireJwtAuth } = require('~/server/middleware'); const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const { getRoleByName } = require('~/models/Role'); const { logger } = require('~/config');
const router = express.Router(); const router = express.Router();
const checkBookmarkAccess = generateCheckAccess({ const checkBookmarkAccess = generateCheckAccess(PermissionTypes.BOOKMARKS, [Permissions.USE]);
permissionType: PermissionTypes.BOOKMARKS,
permissions: [Permissions.USE],
getRoleByName,
});
router.use(requireJwtAuth); router.use(requireJwtAuth);
router.use(checkBookmarkAccess); router.use(checkBookmarkAccess);

View File

@@ -1,11 +1,12 @@
const { agentsConfigSetup, loadWebSearchConfig } = require('@librechat/api');
const { const {
FileSources, FileSources,
loadOCRConfig, loadOCRConfig,
EModelEndpoint, EModelEndpoint,
loadMemoryConfig, loadMemoryConfig,
getConfigDefaults, getConfigDefaults,
loadWebSearchConfig,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const { agentsConfigSetup } = require('@librechat/api');
const { const {
checkHealth, checkHealth,
checkConfig, checkConfig,

View File

@@ -152,14 +152,12 @@ describe('AppService', () => {
filteredTools: undefined, filteredTools: undefined,
includedTools: undefined, includedTools: undefined,
webSearch: { webSearch: {
safeSearch: 1,
jinaApiKey: '${JINA_API_KEY}',
cohereApiKey: '${COHERE_API_KEY}', cohereApiKey: '${COHERE_API_KEY}',
serperApiKey: '${SERPER_API_KEY}',
searxngApiKey: '${SEARXNG_API_KEY}',
firecrawlApiKey: '${FIRECRAWL_API_KEY}', firecrawlApiKey: '${FIRECRAWL_API_KEY}',
firecrawlApiUrl: '${FIRECRAWL_API_URL}', firecrawlApiUrl: '${FIRECRAWL_API_URL}',
searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}', jinaApiKey: '${JINA_API_KEY}',
safeSearch: 1,
serperApiKey: '${SERPER_API_KEY}',
}, },
memory: undefined, memory: undefined,
agents: { agents: {

View File

@@ -1,5 +1,4 @@
const bcrypt = require('bcryptjs'); const bcrypt = require('bcryptjs');
const jwt = require('jsonwebtoken');
const { webcrypto } = require('node:crypto'); const { webcrypto } = require('node:crypto');
const { isEnabled } = require('@librechat/api'); const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
@@ -500,18 +499,6 @@ const resendVerificationEmail = async (req) => {
}; };
} }
}; };
/**
* Generate a short-lived JWT token
* @param {String} userId - The ID of the user
* @param {String} [expireIn='5m'] - The expiration time for the token (default is 5 minutes)
* @returns {String} - The generated JWT token
*/
const generateShortLivedToken = (userId, expireIn = '5m') => {
return jwt.sign({ id: userId }, process.env.JWT_SECRET, {
expiresIn: expireIn,
algorithm: 'HS256',
});
};
module.exports = { module.exports = {
logoutUser, logoutUser,
@@ -519,8 +506,7 @@ module.exports = {
registerUser, registerUser,
setAuthTokens, setAuthTokens,
resetPassword, resetPassword,
setOpenIDAuthTokens,
requestPasswordReset, requestPasswordReset,
resendVerificationEmail, resendVerificationEmail,
generateShortLivedToken, setOpenIDAuthTokens,
}; };

View File

@@ -1,9 +1,10 @@
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { isEnabled, getUserMCPAuthMap } = require('@librechat/api'); const { getUserMCPAuthMap } = require('@librechat/api');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName } = require('~/server/utils'); const { normalizeEndpointName, isEnabled } = require('~/server/utils');
const loadCustomConfig = require('./loadCustomConfig'); const loadCustomConfig = require('./loadCustomConfig');
const { getCachedTools } = require('./getCachedTools'); const { getCachedTools } = require('./getCachedTools');
const { findPluginAuthsByKeys } = require('~/models');
const getLogStores = require('~/cache/getLogStores'); const getLogStores = require('~/cache/getLogStores');
/** /**
@@ -54,48 +55,46 @@ const getCustomEndpointConfig = async (endpoint) => {
); );
}; };
/** async function createGetMCPAuthMap() {
* @param {Object} params
* @param {string} params.userId
* @param {GenericTool[]} [params.tools]
* @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
*/
async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
try {
if (!tools || tools.length === 0) {
return;
}
const appTools = await getCachedTools({
userId,
});
return await getUserMCPAuthMap({
tools,
userId,
appTools,
findPluginAuthsByKeys,
});
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
err,
);
}
}
/**
* @returns {Promise<boolean>}
*/
async function hasCustomUserVars() {
const customConfig = await getCustomConfig(); const customConfig = await getCustomConfig();
const mcpServers = customConfig?.mcpServers; const mcpServers = customConfig?.mcpServers;
return Object.values(mcpServers ?? {}).some((server) => server.customUserVars); const hasCustomUserVars = Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
if (!hasCustomUserVars) {
return;
}
/**
* @param {Object} params
* @param {GenericTool[]} [params.tools]
* @param {string} params.userId
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
*/
return async function ({ tools, userId }) {
try {
if (!tools || tools.length === 0) {
return;
}
const appTools = await getCachedTools({
userId,
});
return await getUserMCPAuthMap({
tools,
userId,
appTools,
findPluginAuthsByKeys,
});
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
err,
);
}
};
} }
module.exports = { module.exports = {
getMCPAuthMap,
getCustomConfig, getCustomConfig,
getBalanceConfig, getBalanceConfig,
hasCustomUserVars, createGetMCPAuthMap,
getCustomEndpointConfig, getCustomEndpointConfig,
}; };

View File

@@ -1,10 +1,4 @@
const { const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider');
CacheKeys,
EModelEndpoint,
isAgentsEndpoint,
orderEndpointsConfig,
defaultAgentCapabilities,
} = require('librechat-data-provider');
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig'); const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
const loadConfigEndpoints = require('./loadConfigEndpoints'); const loadConfigEndpoints = require('./loadConfigEndpoints');
const getLogStores = require('~/cache/getLogStores'); const getLogStores = require('~/cache/getLogStores');
@@ -86,12 +80,8 @@ async function getEndpointsConfig(req) {
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
const checkCapability = async (req, capability) => { const checkCapability = async (req, capability) => {
const isAgents = isAgentsEndpoint(req.body?.original_endpoint || req.body?.endpoint);
const endpointsConfig = await getEndpointsConfig(req); const endpointsConfig = await getEndpointsConfig(req);
const capabilities = const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
isAgents || endpointsConfig?.[EModelEndpoint.agents]?.capabilities != null
? (endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [])
: defaultAgentCapabilities;
return capabilities.includes(capability); return capabilities.includes(capability);
}; };

View File

@@ -1,7 +1,5 @@
const path = require('path');
const { logger } = require('@librechat/data-schemas');
const { loadServiceKey, isUserProvided } = require('@librechat/api');
const { EModelEndpoint } = require('librechat-data-provider'); const { EModelEndpoint } = require('librechat-data-provider');
const { isUserProvided } = require('~/server/utils');
const { config } = require('./EndpointService'); const { config } = require('./EndpointService');
const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = config; const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = config;
@@ -11,41 +9,37 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go
* @param {Express.Request} req - The request object * @param {Express.Request} req - The request object
*/ */
async function loadAsyncEndpoints(req) { async function loadAsyncEndpoints(req) {
let i = 0;
let serviceKey, googleUserProvides; let serviceKey, googleUserProvides;
try {
/** Check if GOOGLE_KEY is provided at all(including 'user_provided') */ serviceKey = require('~/data/auth.json');
const isGoogleKeyProvided = googleKey && googleKey.trim() !== ''; } catch (e) {
if (i === 0) {
if (isGoogleKeyProvided) { i++;
/** If GOOGLE_KEY is provided, check if it's user_provided */
googleUserProvides = isUserProvided(googleKey);
} else {
/** Only attempt to load service key if GOOGLE_KEY is not provided */
const serviceKeyPath =
process.env.GOOGLE_SERVICE_KEY_FILE || path.join(__dirname, '../../..', 'data', 'auth.json');
try {
serviceKey = await loadServiceKey(serviceKeyPath);
} catch (error) {
logger.error('Error loading service key', error);
serviceKey = null;
} }
} }
const google = serviceKey || isGoogleKeyProvided ? { userProvide: googleUserProvides } : false; if (isUserProvided(googleKey)) {
googleUserProvides = true;
if (i <= 1) {
i++;
}
}
const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;
const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins; const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins;
const gptPlugins = const gptPlugins =
useAzure || openAIApiKey || azureOpenAIApiKey useAzure || openAIApiKey || azureOpenAIApiKey
? { ? {
availableAgents: ['classic', 'functions'], availableAgents: ['classic', 'functions'],
userProvide: useAzure ? false : userProvidedOpenAI, userProvide: useAzure ? false : userProvidedOpenAI,
userProvideURL: useAzure userProvideURL: useAzure
? false ? false
: config[EModelEndpoint.openAI]?.userProvideURL || : config[EModelEndpoint.openAI]?.userProvideURL ||
config[EModelEndpoint.azureOpenAI]?.userProvideURL, config[EModelEndpoint.azureOpenAI]?.userProvideURL,
azure: useAzurePlugins || useAzure, azure: useAzurePlugins || useAzure,
} }
: false; : false;
return { google, gptPlugins }; return { google, gptPlugins };

View File

@@ -8,12 +8,11 @@ const {
ErrorTypes, ErrorTypes,
EModelEndpoint, EModelEndpoint,
EToolResources, EToolResources,
isAgentsEndpoint,
replaceSpecialVars, replaceSpecialVars,
providerEndpointMap, providerEndpointMap,
} = require('librechat-data-provider'); } = require('librechat-data-provider');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { getProviderConfig } = require('~/server/services/Endpoints'); const { getProviderConfig } = require('~/server/services/Endpoints');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { processFiles } = require('~/server/services/Files/process'); const { processFiles } = require('~/server/services/Files/process');
const { getFiles, getToolFilesByIds } = require('~/models/File'); const { getFiles, getToolFilesByIds } = require('~/models/File');
const { getConvoFiles } = require('~/models/Conversation'); const { getConvoFiles } = require('~/models/Conversation');
@@ -43,11 +42,7 @@ const initializeAgent = async ({
allowedProviders, allowedProviders,
isInitialAgent = false, isInitialAgent = false,
}) => { }) => {
if ( if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
isAgentsEndpoint(endpointOption?.endpoint) &&
allowedProviders.size > 0 &&
!allowedProviders.has(agent.provider)
) {
throw new Error( throw new Error(
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`, `{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
); );
@@ -87,11 +82,10 @@ const initializeAgent = async ({
attachments: currentFiles, attachments: currentFiles,
tool_resources: agent.tool_resources, tool_resources: agent.tool_resources,
requestFileSet: new Set(requestFiles?.map((file) => file.file_id)), requestFileSet: new Set(requestFiles?.map((file) => file.file_id)),
agentId: agent.id,
}); });
const provider = agent.provider; const provider = agent.provider;
const { tools: structuredTools, toolContextMap } = const { tools, toolContextMap } =
(await loadTools?.({ (await loadTools?.({
req, req,
res, res,
@@ -146,24 +140,6 @@ const initializeAgent = async ({
agent.provider = options.provider; agent.provider = options.provider;
} }
/** @type {import('@librechat/agents').GenericTool[]} */
let tools = options.tools?.length ? options.tools : structuredTools;
if (
(agent.provider === Providers.GOOGLE || agent.provider === Providers.VERTEXAI) &&
options.tools?.length &&
structuredTools?.length
) {
throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`);
} else if (
(agent.provider === Providers.OPENAI ||
agent.provider === Providers.AZURE ||
agent.provider === Providers.ANTHROPIC) &&
options.tools?.length &&
structuredTools?.length
) {
tools = structuredTools.concat(options.tools);
}
/** @type {import('@librechat/agents').ClientOptions} */ /** @type {import('@librechat/agents').ClientOptions} */
agent.model_parameters = { ...options.llmConfig }; agent.model_parameters = { ...options.llmConfig };
if (options.configOptions) { if (options.configOptions) {
@@ -190,7 +166,6 @@ const initializeAgent = async ({
attachments, attachments,
resendFiles, resendFiles,
toolContextMap, toolContextMap,
useLegacyContent: !!options.useLegacyContent,
maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9, maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9,
}; };
}; };

View File

@@ -78,17 +78,7 @@ function getLLMConfig(apiKey, options = {}) {
requestOptions.anthropicApiUrl = options.reverseProxyUrl; requestOptions.anthropicApiUrl = options.reverseProxyUrl;
} }
const tools = [];
if (mergedOptions.web_search) {
tools.push({
type: 'web_search_20250305',
name: 'web_search',
});
}
return { return {
tools,
/** @type {AnthropicClientOptions} */ /** @type {AnthropicClientOptions} */
llmConfig: removeNullishValues(requestOptions), llmConfig: removeNullishValues(requestOptions),
}; };

View File

@@ -139,9 +139,6 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
); );
clientOptions.modelOptions.user = req.user.id; clientOptions.modelOptions.user = req.user.id;
const options = getOpenAIConfig(apiKey, clientOptions, endpoint); const options = getOpenAIConfig(apiKey, clientOptions, endpoint);
if (options != null) {
options.useLegacyContent = true;
}
if (!customOptions.streamRate) { if (!customOptions.streamRate) {
return options; return options;
} }
@@ -159,7 +156,6 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
} }
return { return {
useLegacyContent: true,
llmConfig: modelOptions, llmConfig: modelOptions,
}; };
} }

View File

@@ -1,6 +1,5 @@
const path = require('path'); const { getGoogleConfig, isEnabled } = require('@librechat/api');
const { EModelEndpoint, AuthKeys } = require('librechat-data-provider'); const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
const { getGoogleConfig, isEnabled, loadServiceKey } = require('@librechat/api');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { GoogleClient } = require('~/app'); const { GoogleClient } = require('~/app');
@@ -16,25 +15,10 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
} }
let serviceKey = {}; let serviceKey = {};
try {
/** Check if GOOGLE_KEY is provided at all (including 'user_provided') */ serviceKey = require('~/data/auth.json');
const isGoogleKeyProvided = } catch (_e) {
(GOOGLE_KEY && GOOGLE_KEY.trim() !== '') || (isUserProvided && userKey != null); // Do nothing
if (!isGoogleKeyProvided) {
/** Only attempt to load service key if GOOGLE_KEY is not provided */
try {
const serviceKeyPath =
process.env.GOOGLE_SERVICE_KEY_FILE ||
path.join(__dirname, '../../../..', 'data', 'auth.json');
serviceKey = await loadServiceKey(serviceKeyPath);
if (!serviceKey) {
serviceKey = {};
}
} catch (_e) {
// Service key loading failed, but that's okay if not required
serviceKey = {};
}
} }
const credentials = isUserProvided const credentials = isUserProvided

View File

@@ -7,16 +7,6 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize');
const initGoogle = require('~/server/services/Endpoints/google/initialize'); const initGoogle = require('~/server/services/Endpoints/google/initialize');
const { getCustomEndpointConfig } = require('~/server/services/Config'); const { getCustomEndpointConfig } = require('~/server/services/Config');
/** Check if the provider is a known custom provider
* @param {string | undefined} [provider] - The provider string
* @returns {boolean} - True if the provider is a known custom provider, false otherwise
*/
function isKnownCustomProvider(provider) {
return [Providers.XAI, Providers.OLLAMA, Providers.DEEPSEEK, Providers.OPENROUTER].includes(
provider?.toLowerCase() || '',
);
}
const providerConfigMap = { const providerConfigMap = {
[Providers.XAI]: initCustom, [Providers.XAI]: initCustom,
[Providers.OLLAMA]: initCustom, [Providers.OLLAMA]: initCustom,
@@ -56,13 +46,6 @@ async function getProviderConfig(provider) {
overrideProvider = Providers.OPENAI; overrideProvider = Providers.OPENAI;
} }
if (isKnownCustomProvider(overrideProvider || provider) && !customEndpointConfig) {
customEndpointConfig = await getCustomEndpointConfig(provider);
if (!customEndpointConfig) {
throw new Error(`Provider ${provider} not supported`);
}
}
return { return {
getOptions, getOptions,
overrideProvider, overrideProvider,

View File

@@ -65,20 +65,19 @@ const initializeClient = async ({
const isAzureOpenAI = endpoint === EModelEndpoint.azureOpenAI; const isAzureOpenAI = endpoint === EModelEndpoint.azureOpenAI;
/** @type {false | TAzureConfig} */ /** @type {false | TAzureConfig} */
const azureConfig = isAzureOpenAI && req.app.locals[EModelEndpoint.azureOpenAI]; const azureConfig = isAzureOpenAI && req.app.locals[EModelEndpoint.azureOpenAI];
let serverless = false;
if (isAzureOpenAI && azureConfig) { if (isAzureOpenAI && azureConfig) {
const { modelGroupMap, groupMap } = azureConfig; const { modelGroupMap, groupMap } = azureConfig;
const { const {
azureOptions, azureOptions,
baseURL, baseURL,
headers = {}, headers = {},
serverless: _serverless, serverless,
} = mapModelToAzureConfig({ } = mapModelToAzureConfig({
modelName, modelName,
modelGroupMap, modelGroupMap,
groupMap, groupMap,
}); });
serverless = _serverless;
clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
clientOptions.headers = resolveHeaders( clientOptions.headers = resolveHeaders(
@@ -144,9 +143,6 @@ const initializeClient = async ({
clientOptions = Object.assign({ modelOptions }, clientOptions); clientOptions = Object.assign({ modelOptions }, clientOptions);
clientOptions.modelOptions.user = req.user.id; clientOptions.modelOptions.user = req.user.id;
const options = getOpenAIConfig(apiKey, clientOptions); const options = getOpenAIConfig(apiKey, clientOptions);
if (options != null && serverless === true) {
options.useLegacyContent = true;
}
const streamRate = clientOptions.streamRate; const streamRate = clientOptions.streamRate;
if (!streamRate) { if (!streamRate) {
return options; return options;

View File

@@ -152,7 +152,6 @@ async function getSessionInfo(fileIdentifier, apiKey) {
* @param {Object} options * @param {Object} options
* @param {ServerRequest} options.req * @param {ServerRequest} options.req
* @param {Agent['tool_resources']} options.tool_resources * @param {Agent['tool_resources']} options.tool_resources
* @param {string} [options.agentId] - The agent ID for file access control
* @param {string} apiKey * @param {string} apiKey
* @returns {Promise<{ * @returns {Promise<{
* files: Array<{ id: string; session_id: string; name: string }>, * files: Array<{ id: string; session_id: string; name: string }>,
@@ -160,18 +159,11 @@ async function getSessionInfo(fileIdentifier, apiKey) {
* }>} * }>}
*/ */
const primeFiles = async (options, apiKey) => { const primeFiles = async (options, apiKey) => {
const { tool_resources, req, agentId } = options; const { tool_resources } = options;
const file_ids = tool_resources?.[EToolResources.execute_code]?.file_ids ?? []; const file_ids = tool_resources?.[EToolResources.execute_code]?.file_ids ?? [];
const agentResourceIds = new Set(file_ids); const agentResourceIds = new Set(file_ids);
const resourceFiles = tool_resources?.[EToolResources.execute_code]?.files ?? []; const resourceFiles = tool_resources?.[EToolResources.execute_code]?.files ?? [];
const dbFiles = ( const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
(await getFiles(
{ file_id: { $in: file_ids } },
null,
{ text: 0 },
{ userId: req?.user?.id, agentId },
)) ?? []
).concat(resourceFiles);
const files = []; const files = [];
const sessions = new Map(); const sessions = new Map();

View File

@@ -1,11 +1,10 @@
const fs = require('fs'); const fs = require('fs');
const path = require('path'); const path = require('path');
const axios = require('axios'); const axios = require('axios');
const { logger } = require('@librechat/data-schemas');
const { EModelEndpoint } = require('librechat-data-provider'); const { EModelEndpoint } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { getBufferMetadata } = require('~/server/utils'); const { getBufferMetadata } = require('~/server/utils');
const paths = require('~/config/paths'); const paths = require('~/config/paths');
const { logger } = require('~/config');
/** /**
* Saves a file to a specified output path with a new filename. * Saves a file to a specified output path with a new filename.
@@ -207,7 +206,7 @@ const deleteLocalFile = async (req, file) => {
const cleanFilepath = file.filepath.split('?')[0]; const cleanFilepath = file.filepath.split('?')[0];
if (file.embedded && process.env.RAG_API_URL) { if (file.embedded && process.env.RAG_API_URL) {
const jwtToken = generateShortLivedToken(req.user.id); const jwtToken = req.headers.authorization.split(' ')[1];
axios.delete(`${process.env.RAG_API_URL}/documents`, { axios.delete(`${process.env.RAG_API_URL}/documents`, {
headers: { headers: {
Authorization: `Bearer ${jwtToken}`, Authorization: `Bearer ${jwtToken}`,

View File

@@ -4,7 +4,6 @@ const FormData = require('form-data');
const { logAxiosError } = require('@librechat/api'); const { logAxiosError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { FileSources } = require('librechat-data-provider'); const { FileSources } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
/** /**
* Deletes a file from the vector database. This function takes a file object, constructs the full path, and * Deletes a file from the vector database. This function takes a file object, constructs the full path, and
@@ -24,8 +23,7 @@ const deleteVectors = async (req, file) => {
return; return;
} }
try { try {
const jwtToken = generateShortLivedToken(req.user.id); const jwtToken = req.headers.authorization.split(' ')[1];
return await axios.delete(`${process.env.RAG_API_URL}/documents`, { return await axios.delete(`${process.env.RAG_API_URL}/documents`, {
headers: { headers: {
Authorization: `Bearer ${jwtToken}`, Authorization: `Bearer ${jwtToken}`,
@@ -72,7 +70,7 @@ async function uploadVectors({ req, file, file_id, entity_id }) {
} }
try { try {
const jwtToken = generateShortLivedToken(req.user.id); const jwtToken = req.headers.authorization.split(' ')[1];
const formData = new FormData(); const formData = new FormData();
formData.append('file_id', file_id); formData.append('file_id', file_id);
formData.append('file', fs.createReadStream(file.path)); formData.append('file', fs.createReadStream(file.path));

View File

@@ -55,9 +55,7 @@ const processFiles = async (files, fileIds) => {
} }
if (!fileIds) { if (!fileIds) {
const results = await Promise.all(promises); return await Promise.all(promises);
// Filter out null results from failed updateFileUsage calls
return results.filter((result) => result != null);
} }
for (let file_id of fileIds) { for (let file_id of fileIds) {
@@ -69,9 +67,7 @@ const processFiles = async (files, fileIds) => {
} }
// TODO: calculate token cost when image is first uploaded // TODO: calculate token cost when image is first uploaded
const results = await Promise.all(promises); return await Promise.all(promises);
// Filter out null results from failed updateFileUsage calls
return results.filter((result) => result != null);
}; };
/** /**

View File

@@ -1,208 +0,0 @@
// Mock the updateFileUsage function before importing the actual processFiles
jest.mock('~/models/File', () => ({
updateFileUsage: jest.fn(),
}));
// Mock winston and logger configuration to avoid dependency issues
jest.mock('~/config', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
error: jest.fn(),
},
}));
// Mock all other dependencies that might cause issues
jest.mock('librechat-data-provider', () => ({
isUUID: { parse: jest.fn() },
megabyte: 1024 * 1024,
FileContext: { message_attachment: 'message_attachment' },
FileSources: { local: 'local' },
EModelEndpoint: { assistants: 'assistants' },
EToolResources: { file_search: 'file_search' },
mergeFileConfig: jest.fn(),
removeNullishValues: jest.fn((obj) => obj),
isAssistantsEndpoint: jest.fn(),
}));
jest.mock('~/server/services/Files/images', () => ({
convertImage: jest.fn(),
resizeAndConvert: jest.fn(),
resizeImageBuffer: jest.fn(),
}));
jest.mock('~/server/controllers/assistants/v2', () => ({
addResourceFileId: jest.fn(),
deleteResourceFileId: jest.fn(),
}));
jest.mock('~/models/Agent', () => ({
addAgentResourceFile: jest.fn(),
removeAgentResourceFiles: jest.fn(),
}));
jest.mock('~/server/controllers/assistants/helpers', () => ({
getOpenAIClient: jest.fn(),
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
jest.mock('~/server/services/Config', () => ({
checkCapability: jest.fn(),
}));
jest.mock('~/server/utils/queue', () => ({
LB_QueueAsyncCall: jest.fn(),
}));
jest.mock('./strategies', () => ({
getStrategyFunctions: jest.fn(),
}));
jest.mock('~/server/utils', () => ({
determineFileType: jest.fn(),
}));
// Import the actual processFiles function after all mocks are set up
const { processFiles } = require('./process');
const { updateFileUsage } = require('~/models/File');
describe('processFiles', () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe('null filtering functionality', () => {
it('should filter out null results from updateFileUsage when files do not exist', async () => {
const mockFiles = [
{ file_id: 'existing-file-1' },
{ file_id: 'non-existent-file' },
{ file_id: 'existing-file-2' },
];
// Mock updateFileUsage to return null for non-existent files
updateFileUsage.mockImplementation(({ file_id }) => {
if (file_id === 'non-existent-file') {
return Promise.resolve(null); // Simulate file not found in the database
}
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles);
expect(updateFileUsage).toHaveBeenCalledTimes(3);
expect(result).toEqual([
{ file_id: 'existing-file-1', usage: 1 },
{ file_id: 'existing-file-2', usage: 1 },
]);
// Critical test - ensure no null values in result
expect(result).not.toContain(null);
expect(result).not.toContain(undefined);
expect(result.length).toBe(2); // Only valid files should be returned
});
it('should return empty array when all updateFileUsage calls return null', async () => {
const mockFiles = [{ file_id: 'non-existent-1' }, { file_id: 'non-existent-2' }];
// All updateFileUsage calls return null
updateFileUsage.mockResolvedValue(null);
const result = await processFiles(mockFiles);
expect(updateFileUsage).toHaveBeenCalledTimes(2);
expect(result).toEqual([]);
expect(result).not.toContain(null);
expect(result.length).toBe(0);
});
it('should work correctly when all files exist', async () => {
const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }];
updateFileUsage.mockImplementation(({ file_id }) => {
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles);
expect(result).toEqual([
{ file_id: 'file-1', usage: 1 },
{ file_id: 'file-2', usage: 1 },
]);
expect(result).not.toContain(null);
expect(result.length).toBe(2);
});
it('should handle fileIds parameter and filter nulls correctly', async () => {
const mockFiles = [{ file_id: 'file-1' }];
const mockFileIds = ['file-2', 'non-existent-file'];
updateFileUsage.mockImplementation(({ file_id }) => {
if (file_id === 'non-existent-file') {
return Promise.resolve(null);
}
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles, mockFileIds);
expect(result).toEqual([
{ file_id: 'file-1', usage: 1 },
{ file_id: 'file-2', usage: 1 },
]);
expect(result).not.toContain(null);
expect(result).not.toContain(undefined);
expect(result.length).toBe(2);
});
it('should handle duplicate file_ids correctly', async () => {
const mockFiles = [
{ file_id: 'duplicate-file' },
{ file_id: 'duplicate-file' }, // Duplicate should be ignored
{ file_id: 'unique-file' },
];
updateFileUsage.mockImplementation(({ file_id }) => {
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles);
// Should only call updateFileUsage twice (duplicate ignored)
expect(updateFileUsage).toHaveBeenCalledTimes(2);
expect(result).toEqual([
{ file_id: 'duplicate-file', usage: 1 },
{ file_id: 'unique-file', usage: 1 },
]);
expect(result.length).toBe(2);
});
});
describe('edge cases', () => {
it('should handle empty files array', async () => {
const result = await processFiles([]);
expect(result).toEqual([]);
expect(updateFileUsage).not.toHaveBeenCalled();
});
it('should handle mixed null and undefined returns from updateFileUsage', async () => {
const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }, { file_id: 'file-3' }];
updateFileUsage.mockImplementation(({ file_id }) => {
if (file_id === 'file-1') return Promise.resolve(null);
if (file_id === 'file-2') return Promise.resolve(undefined);
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles);
expect(result).toEqual([{ file_id: 'file-3', usage: 1 }]);
expect(result).not.toContain(null);
expect(result).not.toContain(undefined);
expect(result.length).toBe(1);
});
});
});

View File

@@ -1,9 +1,5 @@
const { FileSources } = require('librechat-data-provider'); const { FileSources } = require('librechat-data-provider');
const { const { uploadMistralOCR, uploadAzureMistralOCR } = require('@librechat/api');
uploadMistralOCR,
uploadAzureMistralOCR,
uploadGoogleVertexMistralOCR,
} = require('@librechat/api');
const { const {
getFirebaseURL, getFirebaseURL,
prepareImageURL, prepareImageURL,
@@ -226,26 +222,6 @@ const azureMistralOCRStrategy = () => ({
handleFileUpload: uploadAzureMistralOCR, handleFileUpload: uploadAzureMistralOCR,
}); });
const vertexMistralOCRStrategy = () => ({
/** @type {typeof saveFileFromURL | null} */
saveURL: null,
/** @type {typeof getLocalFileURL | null} */
getFileURL: null,
/** @type {typeof saveLocalBuffer | null} */
saveBuffer: null,
/** @type {typeof processLocalAvatar | null} */
processAvatar: null,
/** @type {typeof uploadLocalImage | null} */
handleImageUpload: null,
/** @type {typeof prepareImagesLocal | null} */
prepareImagePayload: null,
/** @type {typeof deleteLocalFile | null} */
deleteFile: null,
/** @type {typeof getLocalFileStream | null} */
getDownloadStream: null,
handleFileUpload: uploadGoogleVertexMistralOCR,
});
// Strategy Selector // Strategy Selector
const getStrategyFunctions = (fileSource) => { const getStrategyFunctions = (fileSource) => {
if (fileSource === FileSources.firebase) { if (fileSource === FileSources.firebase) {
@@ -268,8 +244,6 @@ const getStrategyFunctions = (fileSource) => {
return mistralOCRStrategy(); return mistralOCRStrategy();
} else if (fileSource === FileSources.azure_mistral_ocr) { } else if (fileSource === FileSources.azure_mistral_ocr) {
return azureMistralOCRStrategy(); return azureMistralOCRStrategy();
} else if (fileSource === FileSources.vertexai_mistral_ocr) {
return vertexMistralOCRStrategy();
} else { } else {
throw new Error('Invalid file source'); throw new Error('Invalid file source');
} }

View File

@@ -2,16 +2,16 @@ const { z } = require('zod');
const { tool } = require('@langchain/core/tools'); const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas'); const { logger } = require('@librechat/data-schemas');
const { Time, CacheKeys, StepTypes } = require('librechat-data-provider'); const { Time, CacheKeys, StepTypes } = require('librechat-data-provider');
const { sendEvent, normalizeServerName, MCPOAuthHandler } = require('@librechat/api');
const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents'); const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents');
const { Constants, ContentTypes, isAssistantsEndpoint } = require('librechat-data-provider');
const { const {
sendEvent, Constants,
MCPOAuthHandler, ContentTypes,
normalizeServerName, isAssistantsEndpoint,
convertWithResolvedRefs, convertJsonSchemaToZod,
} = require('@librechat/api'); } = require('librechat-data-provider');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config'); const { getMCPManager, getFlowStateManager } = require('~/config');
const { findToken, createToken, updateToken } = require('~/models');
const { getCachedTools } = require('./Config'); const { getCachedTools } = require('./Config');
const { getLogStores } = require('~/cache'); const { getLogStores } = require('~/cache');
@@ -113,7 +113,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
/** @type {LCTool} */ /** @type {LCTool} */
const { description, parameters } = toolDefinition; const { description, parameters } = toolDefinition;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE; const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
let schema = convertWithResolvedRefs(parameters, { let schema = convertJsonSchemaToZod(parameters, {
allowEmptyObject: !isGoogle, allowEmptyObject: !isGoogle,
transformOneOfAnyOf: true, transformOneOfAnyOf: true,
}); });

View File

@@ -44,9 +44,6 @@ async function initializeMCP(app) {
await mcpManager.mapAvailableTools(toolsCopy, flowManager); await mcpManager.mapAvailableTools(toolsCopy, flowManager);
await setCachedTools(toolsCopy, { isGlobal: true }); await setCachedTools(toolsCopy, { isGlobal: true });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug('Cleared tools array cache after MCP initialization');
logger.info('MCP servers initialized successfully'); logger.info('MCP servers initialized successfully');
} catch (error) { } catch (error) {
logger.error('Failed to initialize MCP servers:', error); logger.error('Failed to initialize MCP servers:', error);

View File

@@ -1,6 +1,6 @@
const { webSearchKeys } = require('@librechat/api');
const { const {
Constants, Constants,
webSearchKeys,
deprecatedAzureVariables, deprecatedAzureVariables,
conflictingAzureVariables, conflictingAzureVariables,
extractVariableName, extractVariableName,

View File

@@ -1,5 +1,8 @@
const { Keyv } = require('keyv');
const passport = require('passport'); const passport = require('passport');
const session = require('express-session'); const session = require('express-session');
const MemoryStore = require('memorystore')(session);
const RedisStore = require('connect-redis').default;
const { const {
setupOpenId, setupOpenId,
googleLogin, googleLogin,
@@ -11,9 +14,8 @@ const {
openIdJwtLogin, openIdJwtLogin,
} = require('~/strategies'); } = require('~/strategies');
const { isEnabled } = require('~/server/utils'); const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config'); const { logger } = require('~/config');
const { getLogStores } = require('~/cache');
const { CacheKeys } = require('librechat-data-provider');
/** /**
* *
@@ -49,8 +51,17 @@ const configureSocialLogins = async (app) => {
secret: process.env.OPENID_SESSION_SECRET, secret: process.env.OPENID_SESSION_SECRET,
resave: false, resave: false,
saveUninitialized: false, saveUninitialized: false,
store: getLogStores(CacheKeys.OPENID_SESSION),
}; };
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for session storage in OpenID...');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.client;
sessionOptions.store = new RedisStore({ client, prefix: 'openid_session' });
} else {
sessionOptions.store = new MemoryStore({
checkPeriod: 86400000, // prune expired entries every 24h
});
}
app.use(session(sessionOptions)); app.use(session(sessionOptions));
app.use(passport.session()); app.use(passport.session());
const config = await setupOpenId(); const config = await setupOpenId();
@@ -71,8 +82,17 @@ const configureSocialLogins = async (app) => {
secret: process.env.SAML_SESSION_SECRET, secret: process.env.SAML_SESSION_SECRET,
resave: false, resave: false,
saveUninitialized: false, saveUninitialized: false,
store: getLogStores(CacheKeys.SAML_SESSION),
}; };
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for session storage in SAML...');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.client;
sessionOptions.store = new RedisStore({ client, prefix: 'saml_session' });
} else {
sessionOptions.store = new MemoryStore({
checkPeriod: 86400000, // prune expired entries every 24h
});
}
app.use(session(sessionOptions)); app.use(session(sessionOptions));
app.use(passport.session()); app.use(passport.session());
setupSaml(); setupSaml();

View File

@@ -1,407 +0,0 @@
const fs = require('fs');
const path = require('path');
const express = require('express');
const request = require('supertest');
const zlib = require('zlib');
const staticCache = require('../staticCache');
describe('staticCache', () => {
let app;
let testDir;
let testFile;
let indexFile;
let manifestFile;
let swFile;
beforeAll(() => {
// Create a test directory and files
testDir = path.join(__dirname, 'test-static');
if (!fs.existsSync(testDir)) {
fs.mkdirSync(testDir, { recursive: true });
}
// Create test files
testFile = path.join(testDir, 'test.js');
indexFile = path.join(testDir, 'index.html');
manifestFile = path.join(testDir, 'manifest.json');
swFile = path.join(testDir, 'sw.js');
const jsContent = 'console.log("test");';
const htmlContent = '<html><body>Test</body></html>';
const jsonContent = '{"name": "test"}';
const swContent = 'self.addEventListener("install", () => {});';
fs.writeFileSync(testFile, jsContent);
fs.writeFileSync(indexFile, htmlContent);
fs.writeFileSync(manifestFile, jsonContent);
fs.writeFileSync(swFile, swContent);
// Create gzipped versions of some files
fs.writeFileSync(testFile + '.gz', zlib.gzipSync(jsContent));
fs.writeFileSync(path.join(testDir, 'test.css'), 'body { color: red; }');
fs.writeFileSync(path.join(testDir, 'test.css.gz'), zlib.gzipSync('body { color: red; }'));
// Create a file that only exists in gzipped form
fs.writeFileSync(
path.join(testDir, 'only-gzipped.js.gz'),
zlib.gzipSync('console.log("only gzipped");'),
);
// Create a subdirectory for dist/images testing
const distImagesDir = path.join(testDir, 'dist', 'images');
fs.mkdirSync(distImagesDir, { recursive: true });
fs.writeFileSync(path.join(distImagesDir, 'logo.png'), 'fake-png-data');
});
afterAll(() => {
// Clean up test files
if (fs.existsSync(testDir)) {
fs.rmSync(testDir, { recursive: true, force: true });
}
});
beforeEach(() => {
app = express();
// Clear environment variables
delete process.env.NODE_ENV;
delete process.env.STATIC_CACHE_S_MAX_AGE;
delete process.env.STATIC_CACHE_MAX_AGE;
});
describe('cache headers in production', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should set standard cache headers for regular files', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
});
it('should set no-cache headers for index.html', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/index.html').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
});
it('should set no-cache headers for manifest.json', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/manifest.json').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
});
it('should set no-cache headers for sw.js', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/sw.js').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
});
it('should not set cache headers for /dist/images/ files', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/dist/images/logo.png').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=0');
});
it('should set no-cache headers when noCache option is true', async () => {
app.use(staticCache(testDir, { noCache: true }));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
});
});
describe('cache headers in non-production', () => {
beforeEach(() => {
process.env.NODE_ENV = 'development';
});
it('should not set cache headers in development', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
// Our middleware should not set cache-control in non-production
// Express static might set its own default headers
const cacheControl = response.headers['cache-control'];
expect(cacheControl).toBe('public, max-age=0');
});
});
describe('environment variable configuration', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should use custom s-maxage from environment', async () => {
process.env.STATIC_CACHE_S_MAX_AGE = '3600';
// Need to re-require to pick up new env vars
jest.resetModules();
const freshStaticCache = require('../staticCache');
app.use(freshStaticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=3600');
});
it('should use custom max-age from environment', async () => {
process.env.STATIC_CACHE_MAX_AGE = '7200';
// Need to re-require to pick up new env vars
jest.resetModules();
const freshStaticCache = require('../staticCache');
app.use(freshStaticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=7200, s-maxage=86400');
});
it('should use both custom values from environment', async () => {
process.env.STATIC_CACHE_S_MAX_AGE = '1800';
process.env.STATIC_CACHE_MAX_AGE = '3600';
// Need to re-require to pick up new env vars
jest.resetModules();
const freshStaticCache = require('../staticCache');
app.use(freshStaticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=3600, s-maxage=1800');
});
});
describe('express-static-gzip behavior', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should serve gzipped files when client accepts gzip encoding', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip, deflate')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['content-type']).toMatch(/javascript/);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
// Content should be decompressed by supertest
expect(response.text).toBe('console.log("test");');
});
it('should fall back to uncompressed files when client does not accept gzip', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'identity')
.expect(200);
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.headers['content-type']).toMatch(/javascript/);
expect(response.text).toBe('console.log("test");');
});
it('should serve gzipped CSS files with correct content-type', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/test.css')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['content-type']).toMatch(/css/);
expect(response.text).toBe('body { color: red; }');
});
it('should serve uncompressed files when no gzipped version exists', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/manifest.json')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.headers['content-type']).toMatch(/json/);
expect(response.text).toBe('{"name": "test"}');
});
it('should handle files that only exist in gzipped form', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/only-gzipped.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['content-type']).toMatch(/javascript/);
expect(response.text).toBe('console.log("only gzipped");');
});
it('should return 404 for gzip-only files when client does not accept gzip', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/only-gzipped.js')
.set('Accept-Encoding', 'identity');
expect(response.status).toBe(404);
});
it('should handle cache headers correctly for gzipped content', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
expect(response.headers['content-type']).toMatch(/javascript/);
});
it('should preserve original MIME types for gzipped files', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const jsResponse = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
const cssResponse = await request(app)
.get('/test.css')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(jsResponse.headers['content-type']).toMatch(/javascript/);
expect(cssResponse.headers['content-type']).toMatch(/css/);
expect(jsResponse.headers['content-encoding']).toBe('gzip');
expect(cssResponse.headers['content-encoding']).toBe('gzip');
});
});
describe('skipGzipScan option comparison', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should use express.static (no gzip) when skipGzipScan is true', async () => {
app.use(staticCache(testDir, { skipGzipScan: true }));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
// Should NOT serve gzipped version even though client accepts it
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
expect(response.text).toBe('console.log("test");');
});
it('should use expressStaticGzip when skipGzipScan is false', async () => {
app.use(staticCache(testDir));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
// Should serve gzipped version when client accepts it
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
expect(response.text).toBe('console.log("test");');
});
});
describe('file serving', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should serve files correctly', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.text).toBe('console.log("test");');
expect(response.headers['content-type']).toMatch(/javascript|text/);
});
it('should return 404 for non-existent files', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/nonexistent.js');
expect(response.status).toBe(404);
});
it('should serve HTML files', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/index.html').expect(200);
expect(response.text).toBe('<html><body>Test</body></html>');
expect(response.headers['content-type']).toMatch(/html/);
});
});
describe('edge cases', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should handle webmanifest files', async () => {
// Create a webmanifest file
const webmanifestFile = path.join(testDir, 'site.webmanifest');
fs.writeFileSync(webmanifestFile, '{"name": "test app"}');
app.use(staticCache(testDir));
const response = await request(app).get('/site.webmanifest').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
// Clean up
fs.unlinkSync(webmanifestFile);
});
it('should handle files in subdirectories', async () => {
const subDir = path.join(testDir, 'subdir');
fs.mkdirSync(subDir, { recursive: true });
const subFile = path.join(subDir, 'nested.js');
fs.writeFileSync(subFile, 'console.log("nested");');
app.use(staticCache(testDir));
const response = await request(app).get('/subdir/nested.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
expect(response.text).toBe('console.log("nested");');
// Clean up
fs.rmSync(subDir, { recursive: true, force: true });
});
});
});

Some files were not shown because too many files have changed in this diff Show More