Compare commits

..

5 Commits

Author SHA1 Message Date
Rakshit Tiwari
3fb97b7c4a Fixing the naming to clientresize from clientsideresize 2025-06-24 19:33:24 +05:30
Rakshit Tiwari
0417a38a6e Merge remote-tracking branch 'origin/main' into feature/client-side-image-resize 2025-06-24 19:29:31 +05:30
Rakshit Tiwari
b6de2a8557 Addressing eslint errors 2025-06-16 16:42:11 +05:30
Rakshit Tiwari
a77ad276f1 Addressing comments from author 2025-06-15 17:21:56 +05:30
Rakshit Tiwari
8703b9c2d8 feat: Add optional client-side image resizing to prevent upload errors 2025-06-15 12:40:56 +05:30
410 changed files with 8717 additions and 20176 deletions

View File

@@ -58,7 +58,7 @@ DEBUG_CONSOLE=false
# Endpoints #
#===================================================#
# ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic
# ENDPOINTS=openAI,assistants,azureOpenAI,google,gptPlugins,anthropic
PROXY=
@@ -142,10 +142,10 @@ GOOGLE_KEY=user_provided
# GOOGLE_AUTH_HEADER=true
# Gemini API (AI Studio)
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# Vertex AI
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
@@ -349,11 +349,6 @@ REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20
TTS_VIOLATION_SCORE=0
STT_VIOLATION_SCORE=0
FORK_VIOLATION_SCORE=0
IMPORT_VIOLATION_SCORE=0
FILE_UPLOAD_VIOLATION_SCORE=0
LOGIN_MAX=7
LOGIN_WINDOW=5
@@ -458,8 +453,8 @@ OPENID_REUSE_TOKENS=
OPENID_JWKS_URL_CACHE_ENABLED=
OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching
#Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint.
OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED=
OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API
OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED=
OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API
# Set to true to use the OpenID Connect end session endpoint for logout
OPENID_USE_END_SESSION_ENDPOINT=
@@ -580,10 +575,6 @@ ALLOW_SHARED_LINKS_PUBLIC=true
# If you have another service in front of your LibreChat doing compression, disable express based compression here
# DISABLE_COMPRESSION=true
# If you have gzipped version of uploaded image images in the same folder, this will enable gzip scan and serving of these images
# Note: The images folder will be scanned on startup and a ma kept in memory. Be careful for large number of images.
# ENABLE_IMAGE_OUTPUT_GZIP_SCAN=true
#===================================================#
# UI #
#===================================================#
@@ -601,31 +592,11 @@ HELP_AND_FAQ_URL=https://librechat.ai
# REDIS Options #
#===============#
# Enable Redis for caching and session storage
# REDIS_URI=10.10.10.10:6379
# USE_REDIS=true
# Single Redis instance
# REDIS_URI=redis://127.0.0.1:6379
# Redis cluster (multiple nodes)
# REDIS_URI=redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003
# Redis with TLS/SSL encryption and CA certificate
# REDIS_URI=rediss://127.0.0.1:6380
# REDIS_CA=/path/to/ca-cert.pem
# Redis authentication (if required)
# REDIS_USERNAME=your_redis_username
# REDIS_PASSWORD=your_redis_password
# Redis key prefix configuration
# Use environment variable name for dynamic prefix (recommended for cloud deployments)
# REDIS_KEY_PREFIX_VAR=K_REVISION
# Or use static prefix directly
# REDIS_KEY_PREFIX=librechat
# Redis connection limits
# REDIS_MAX_LISTENERS=40
# USE_REDIS_CLUSTER=true
# REDIS_CA=/path/to/ca.crt
#==================================================#
# Others #
@@ -686,4 +657,4 @@ OPENWEATHER_API_KEY=
# Reranker (Required)
# JINA_API_KEY=your_jina_api_key
# or
# COHERE_API_KEY=your_cohere_api_key
# COHERE_API_KEY=your_cohere_api_key

View File

@@ -1,32 +0,0 @@
name: Publish `@librechat/client` to NPM
on:
workflow_dispatch:
inputs:
reason:
description: 'Reason for manual trigger'
required: false
default: 'Manual publish requested'
jobs:
build-and-publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Use Node.js
uses: actions/setup-node@v4
with:
node-version: '18.x'
- name: Check if client package exists
run: |
if [ -d "packages/client" ]; then
echo "Client package directory found"
else
echo "Client package directory not found - workflow ready for future use"
exit 0
fi
- name: Placeholder for future publishing
run: echo "Client package publishing workflow is ready"

9
.gitignore vendored
View File

@@ -125,12 +125,3 @@ helm/**/.values.yaml
# SAML Idp cert
*.cert
# AI Assistants
/.claude/
/.cursor/
/.copilot/
/.aider/
/.openai/
/.tabnine/
/.codeium

View File

@@ -7,8 +7,49 @@ All notable changes to this project will be documented in this file.
## [Unreleased]
- no changes
### ✨ New Features
- ✨ feat: implement search parameter updates by **@mawburn** in [#7151](https://github.com/danny-avila/LibreChat/pull/7151)
- 🎏 feat: Add MCP support for Streamable HTTP Transport by **@benverhees** in [#7353](https://github.com/danny-avila/LibreChat/pull/7353)
- 🔒 feat: Add Content Security Policy using Helmet middleware by **@rubentalstra** in [#7377](https://github.com/danny-avila/LibreChat/pull/7377)
- ✨ feat: Add Normalization for MCP Server Names by **@danny-avila** in [#7421](https://github.com/danny-avila/LibreChat/pull/7421)
- 📊 feat: Improve Helm Chart by **@hofq** in [#3638](https://github.com/danny-avila/LibreChat/pull/3638)
- 🦾 feat: Claude-4 Support by **@danny-avila** in [#7509](https://github.com/danny-avila/LibreChat/pull/7509)
- 🪨 feat: Bedrock Support for Claude-4 Reasoning by **@danny-avila** in [#7517](https://github.com/danny-avila/LibreChat/pull/7517)
### 🌍 Internationalization
- 🌍 i18n: Add `Danish` and `Czech` and `Catalan` localization support by **@rubentalstra** in [#7373](https://github.com/danny-avila/LibreChat/pull/7373)
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7375](https://github.com/danny-avila/LibreChat/pull/7375)
- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7468](https://github.com/danny-avila/LibreChat/pull/7468)
### 🔧 Fixes
- 💬 fix: update aria-label for accessibility in ConvoLink component by **@berry-13** in [#7320](https://github.com/danny-avila/LibreChat/pull/7320)
- 🔑 fix: use `apiKey` instead of `openAIApiKey` in OpenAI-like Config by **@danny-avila** in [#7337](https://github.com/danny-avila/LibreChat/pull/7337)
- 🔄 fix: update navigation logic in `useFocusChatEffect` to ensure correct search parameters are used by **@mawburn** in [#7340](https://github.com/danny-avila/LibreChat/pull/7340)
- 🔄 fix: Improve MCP Connection Cleanup by **@danny-avila** in [#7400](https://github.com/danny-avila/LibreChat/pull/7400)
- 🛡️ fix: Preset and Validation Logic for URL Query Params by **@danny-avila** in [#7407](https://github.com/danny-avila/LibreChat/pull/7407)
- 🌘 fix: artifact of preview text is illegible in dark mode by **@nhtruong** in [#7405](https://github.com/danny-avila/LibreChat/pull/7405)
- 🛡️ fix: Temporarily Remove CSP until Configurable by **@danny-avila** in [#7419](https://github.com/danny-avila/LibreChat/pull/7419)
- 💽 fix: Exclude index page `/` from static cache settings by **@sbruel** in [#7382](https://github.com/danny-avila/LibreChat/pull/7382)
### ⚙️ Other Changes
- 📜 docs: CHANGELOG for release v0.7.8 by **@github-actions[bot]** in [#7290](https://github.com/danny-avila/LibreChat/pull/7290)
- 📦 chore: Update API Package Dependencies by **@danny-avila** in [#7359](https://github.com/danny-avila/LibreChat/pull/7359)
- 📜 docs: Unreleased Changelog by **@github-actions[bot]** in [#7321](https://github.com/danny-avila/LibreChat/pull/7321)
- 📜 docs: Unreleased Changelog by **@github-actions[bot]** in [#7434](https://github.com/danny-avila/LibreChat/pull/7434)
- 🛡️ chore: `multer` v2.0.0 for CVE-2025-47935 and CVE-2025-47944 by **@danny-avila** in [#7454](https://github.com/danny-avila/LibreChat/pull/7454)
- 📂 refactor: Improve `FileAttachment` & File Form Deletion by **@danny-avila** in [#7471](https://github.com/danny-avila/LibreChat/pull/7471)
- 📊 chore: Remove Old Helm Chart by **@hofq** in [#7512](https://github.com/danny-avila/LibreChat/pull/7512)
- 🪖 chore: bump helm app version to v0.7.8 by **@austin-barrington** in [#7524](https://github.com/danny-avila/LibreChat/pull/7524)
---
## [v0.7.8] -
Changes from v0.7.8-rc1 to v0.7.8.
@@ -50,7 +91,6 @@ Changes from v0.7.8-rc1 to v0.7.8.
---
## [v0.7.8-rc1] -
## [v0.7.8-rc1] -
Changes from v0.7.7 to v0.7.8-rc1.

View File

@@ -1,4 +1,4 @@
# v0.7.9-rc1
# v0.7.8
# Base node image
FROM node:20-alpine AS node

View File

@@ -1,5 +1,5 @@
# Dockerfile.multi
# v0.7.9-rc1
# v0.7.8
# Base for all builds
FROM node:20-alpine AS base-min

View File

@@ -52,7 +52,7 @@
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
- 🤖 **AI Model Selection**:
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure)
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
@@ -66,9 +66,10 @@
- 🔦 **Agents & Tools Integration**:
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
- Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
- Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions
- 🔍 **Web Search**:
- Search the internet and retrieve relevant information to enhance your AI context

View File

@@ -13,6 +13,7 @@ const {
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
const { checkBalance } = require('~/models/balanceMethods');
const { truncateToolCallOutputs } = require('./prompts');
const { addSpaceIfNeeded } = require('~/server/utils');
const { getFiles } = require('~/models/File');
const TextStream = require('./TextStream');
const { logger } = require('~/config');
@@ -108,15 +109,12 @@ class BaseClient {
/**
* Abstract method to record token usage. Subclasses must implement this method.
* If a correction to the token usage is needed, the method should return an object with the corrected token counts.
* Should only be used if `recordCollectedUsage` was not used instead.
* @param {string} [model]
* @param {number} promptTokens
* @param {number} completionTokens
* @returns {Promise<void>}
*/
async recordTokenUsage({ model, promptTokens, completionTokens }) {
async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
model,
promptTokens,
completionTokens,
});
@@ -200,10 +198,6 @@ class BaseClient {
this.currentMessages[this.currentMessages.length - 1].messageId = head;
}
if (opts.isRegenerate && responseMessageId.endsWith('_')) {
responseMessageId = crypto.randomUUID();
}
this.responseMessageId = responseMessageId;
return {
@@ -578,7 +572,7 @@ class BaseClient {
});
}
const { editedContent } = opts;
const { generation = '' } = opts;
// It's not necessary to push to currentMessages
// depending on subclass implementation of handling messages
@@ -593,21 +587,11 @@ class BaseClient {
isCreatedByUser: false,
model: this.modelOptions?.model ?? this.model,
sender: this.sender,
text: generation,
};
this.currentMessages.push(userMessage, latestMessage);
} else if (editedContent != null) {
// Handle editedContent for content parts
if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
const { index, text, type } = editedContent;
if (index >= 0 && index < latestMessage.content.length) {
const contentPart = latestMessage.content[index];
if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
contentPart[ContentTypes.THINK] = text;
} else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
contentPart[ContentTypes.TEXT] = text;
}
}
}
} else {
latestMessage.text = generation;
}
this.continued = true;
} else {
@@ -688,32 +672,16 @@ class BaseClient {
};
if (typeof completion === 'string') {
responseMessage.text = completion;
responseMessage.text = addSpaceIfNeeded(generation) + completion;
} else if (
Array.isArray(completion) &&
(this.clientName === EModelEndpoint.agents ||
isParamEndpoint(this.options.endpoint, this.options.endpointType))
) {
responseMessage.text = '';
if (!opts.editedContent || this.currentMessages.length === 0) {
responseMessage.content = completion;
} else {
const latestMessage = this.currentMessages[this.currentMessages.length - 1];
if (!latestMessage?.content) {
responseMessage.content = completion;
} else {
const existingContent = [...latestMessage.content];
const { type: editedType } = opts.editedContent;
responseMessage.content = this.mergeEditedContent(
existingContent,
completion,
editedType,
);
}
}
responseMessage.content = completion;
} else if (Array.isArray(completion)) {
responseMessage.text = completion.join('');
responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
}
if (
@@ -744,13 +712,9 @@ class BaseClient {
} else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = responseMessage.tokenCount;
await this.recordTokenUsage({
usage,
promptTokens,
completionTokens,
model: responseMessage.model,
});
}
await this.recordTokenUsage({ promptTokens, completionTokens, usage });
}
if (userMessagePromise) {
@@ -828,8 +792,7 @@ class BaseClient {
userMessage.tokenCount = userMessageTokenCount;
/*
Note: `AgentController` saves the user message if not saved here
(noted by `savedMessageIds`), so we update the count of its `userMessage` reference
Note: `AskController` saves the user message, so we update the count of its `userMessage` reference
*/
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
@@ -838,8 +801,7 @@ class BaseClient {
}
/*
Note: we update the user message to be sure it gets the calculated token count;
though `AgentController` saves the user message if not saved here
(noted by `savedMessageIds`), EditController does not
though `AskController` saves the user message, EditController does not
*/
await userMessagePromise;
await this.updateMessageInDatabase({
@@ -1131,50 +1093,6 @@ class BaseClient {
return numTokens;
}
/**
* Merges completion content with existing content when editing TEXT or THINK types
* @param {Array} existingContent - The existing content array
* @param {Array} newCompletion - The new completion content
* @param {string} editedType - The type of content being edited
* @returns {Array} The merged content array
*/
mergeEditedContent(existingContent, newCompletion, editedType) {
if (!newCompletion.length) {
return existingContent.concat(newCompletion);
}
if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
return existingContent.concat(newCompletion);
}
const lastIndex = existingContent.length - 1;
const lastExisting = existingContent[lastIndex];
const firstNew = newCompletion[0];
if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
return existingContent.concat(newCompletion);
}
const mergedContent = [...existingContent];
if (editedType === ContentTypes.TEXT) {
mergedContent[lastIndex] = {
...mergedContent[lastIndex],
[ContentTypes.TEXT]:
(mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
};
} else {
mergedContent[lastIndex] = {
...mergedContent[lastIndex],
[ContentTypes.THINK]:
(mergedContent[lastIndex][ContentTypes.THINK] || '') +
(firstNew[ContentTypes.THINK] || ''),
};
}
// Add remaining completion items
return mergedContent.concat(newCompletion.slice(1));
}
async sendPayload(payload, opts = {}) {
if (opts && typeof opts === 'object') {
this.setOptions(opts);

View File

@@ -0,0 +1,804 @@
const { Keyv } = require('keyv');
const crypto = require('crypto');
const { CohereClient } = require('cohere-ai');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { constructAzureURL, genAzureChatCompletion } = require('@librechat/api');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
ImageDetail,
EModelEndpoint,
resolveHeaders,
CohereConstants,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { createContextHandlers } = require('./prompts');
const { createCoherePayload } = require('./llm');
const { extractBaseURL } = require('~/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const CHATGPT_MODEL = 'gpt-3.5-turbo';
const tokenizersCache = {};
class ChatGPTClient extends BaseClient {
constructor(apiKey, options = {}, cacheOptions = {}) {
super(apiKey, options, cacheOptions);
cacheOptions.namespace = cacheOptions.namespace || 'chatgpt';
this.conversationsCache = new Keyv(cacheOptions);
this.setOptions(options);
}
setOptions(options) {
if (this.options && !this.options.replaceOptions) {
// nested options aren't spread properly, so we need to do this manually
this.options.modelOptions = {
...this.options.modelOptions,
...options.modelOptions,
};
delete options.modelOptions;
// now we can merge options
this.options = {
...this.options,
...options,
};
} else {
this.options = options;
}
if (this.options.openaiApiKey) {
this.apiKey = this.options.openaiApiKey;
}
const modelOptions = this.options.modelOptions || {};
this.modelOptions = {
...modelOptions,
// set some good defaults (check for undefined in some cases because they may be 0)
model: modelOptions.model || CHATGPT_MODEL,
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
presence_penalty:
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
stop: modelOptions.stop,
};
this.isChatGptModel = this.modelOptions.model.includes('gpt-');
const { isChatGptModel } = this;
this.isUnofficialChatGptModel =
this.modelOptions.model.startsWith('text-chat') ||
this.modelOptions.model.startsWith('text-davinci-002-render');
const { isUnofficialChatGptModel } = this;
// Davinci models have a max context length of 4097 tokens.
this.maxContextTokens = this.options.maxContextTokens || (isChatGptModel ? 4095 : 4097);
// I decided to reserve 1024 tokens for the response.
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
// Earlier messages will be dropped until the prompt is within the limit.
this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
throw new Error(
`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
);
}
this.userLabel = this.options.userLabel || 'User';
this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
if (isChatGptModel) {
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
// without tripping the stop sequences, so I'm using "||>" instead.
this.startToken = '||>';
this.endToken = '';
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
} else if (isUnofficialChatGptModel) {
this.startToken = '<|im_start|>';
this.endToken = '<|im_end|>';
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
});
} else {
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
// as a single token. So we're using this instead.
this.startToken = '||>';
this.endToken = '';
try {
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
} catch {
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
}
}
if (!this.modelOptions.stop) {
const stopTokens = [this.startToken];
if (this.endToken && this.endToken !== this.startToken) {
stopTokens.push(this.endToken);
}
stopTokens.push(`\n${this.userLabel}:`);
stopTokens.push('<|diff_marker|>');
// I chose not to do one for `chatGptLabel` because I've never seen it happen
this.modelOptions.stop = stopTokens;
}
if (this.options.reverseProxyUrl) {
this.completionsUrl = this.options.reverseProxyUrl;
} else if (isChatGptModel) {
this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
} else {
this.completionsUrl = 'https://api.openai.com/v1/completions';
}
return this;
}
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
if (tokenizersCache[encoding]) {
return tokenizersCache[encoding];
}
let tokenizer;
if (isModelName) {
tokenizer = encodingForModel(encoding, extendSpecialTokens);
} else {
tokenizer = getEncoding(encoding, extendSpecialTokens);
}
tokenizersCache[encoding] = tokenizer;
return tokenizer;
}
/** @type {getCompletion} */
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
let modelOptions = { ...this.modelOptions };
if (typeof onProgress === 'function') {
modelOptions.stream = true;
}
if (this.isChatGptModel) {
modelOptions.messages = input;
} else {
modelOptions.prompt = input;
}
if (this.useOpenRouter && modelOptions.prompt) {
delete modelOptions.stop;
}
const { debug } = this.options;
let baseURL = this.completionsUrl;
if (debug) {
console.debug();
console.debug(baseURL);
console.debug(modelOptions);
console.debug();
}
const opts = {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
};
if (this.isVisionModel) {
modelOptions.max_tokens = 4000;
}
/** @type {TAzureConfig | undefined} */
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
const isAzure = this.azure || this.options.azure;
if (
(isAzure && this.isVisionModel && azureConfig) ||
(azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
) {
const { modelGroupMap, groupMap } = azureConfig;
const {
azureOptions,
baseURL,
headers = {},
serverless,
} = mapModelToAzureConfig({
modelName: modelOptions.model,
modelGroupMap,
groupMap,
});
opts.headers = resolveHeaders(headers);
this.langchainProxy = extractBaseURL(baseURL);
this.apiKey = azureOptions.azureOpenAIApiKey;
const groupName = modelGroupMap[modelOptions.model].group;
this.options.addParams = azureConfig.groupMap[groupName].addParams;
this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
// Note: `forcePrompt` not re-assigned as only chat models are vision models
this.azure = !serverless && azureOptions;
this.azureEndpoint =
!serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
if (serverless === true) {
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
this.options.headers['api-key'] = this.apiKey;
}
}
if (this.options.defaultQuery) {
opts.defaultQuery = this.options.defaultQuery;
}
if (this.options.headers) {
opts.headers = { ...opts.headers, ...this.options.headers };
}
if (isAzure) {
// Azure does not accept `model` in the body, so we need to remove it.
delete modelOptions.model;
baseURL = this.langchainProxy
? constructAzureURL({
baseURL: this.langchainProxy,
azureOptions: this.azure,
})
: this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
if (this.options.forcePrompt) {
baseURL += '/completions';
} else {
baseURL += '/chat/completions';
}
opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
opts.headers = { ...opts.headers, 'api-key': this.apiKey };
} else if (this.apiKey) {
opts.headers.Authorization = `Bearer ${this.apiKey}`;
}
if (process.env.OPENAI_ORGANIZATION) {
opts.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION;
}
if (this.useOpenRouter) {
opts.headers['HTTP-Referer'] = 'https://librechat.ai';
opts.headers['X-Title'] = 'LibreChat';
}
/* hacky fixes for Mistral AI API:
- Re-orders system message to the top of the messages payload, as not allowed anywhere else
- If there is only one message and it's a system message, change the role to user
*/
if (baseURL.includes('https://api.mistral.ai/v1') && modelOptions.messages) {
const { messages } = modelOptions;
const systemMessageIndex = messages.findIndex((msg) => msg.role === 'system');
if (systemMessageIndex > 0) {
const [systemMessage] = messages.splice(systemMessageIndex, 1);
messages.unshift(systemMessage);
}
modelOptions.messages = messages;
if (messages.length === 1 && messages[0].role === 'system') {
modelOptions.messages[0].role = 'user';
}
}
if (this.options.addParams && typeof this.options.addParams === 'object') {
modelOptions = {
...modelOptions,
...this.options.addParams,
};
logger.debug('[ChatGPTClient] chatCompletion: added params', {
addParams: this.options.addParams,
modelOptions,
});
}
if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
this.options.dropParams.forEach((param) => {
delete modelOptions[param];
});
logger.debug('[ChatGPTClient] chatCompletion: dropped params', {
dropParams: this.options.dropParams,
modelOptions,
});
}
if (baseURL.startsWith(CohereConstants.API_URL)) {
const payload = createCoherePayload({ modelOptions });
return await this.cohereChatCompletion({ payload, onTokenProgress });
}
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
baseURL = baseURL.split('v1')[0] + 'v1/completions';
} else if (
baseURL.includes('v1') &&
!baseURL.includes('/chat/completions') &&
this.isChatCompletion
) {
baseURL = baseURL.split('v1')[0] + 'v1/chat/completions';
}
const BASE_URL = new URL(baseURL);
if (opts.defaultQuery) {
Object.entries(opts.defaultQuery).forEach(([key, value]) => {
BASE_URL.searchParams.append(key, value);
});
delete opts.defaultQuery;
}
const completionsURL = BASE_URL.toString();
opts.body = JSON.stringify(modelOptions);
if (modelOptions.stream) {
return new Promise(async (resolve, reject) => {
try {
let done = false;
await fetchEventSource(completionsURL, {
...opts,
signal: abortController.signal,
async onopen(response) {
if (response.status === 200) {
return;
}
if (debug) {
console.debug(response);
}
let error;
try {
const body = await response.text();
error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
error.json = JSON.parse(body);
} catch {
error = error || new Error(`Failed to send message. HTTP ${response.status}`);
}
throw error;
},
onclose() {
if (debug) {
console.debug('Server closed the connection unexpectedly, returning...');
}
// workaround for private API not sending [DONE] event
if (!done) {
onProgress('[DONE]');
resolve();
}
},
onerror(err) {
if (debug) {
console.debug(err);
}
// rethrow to stop the operation
throw err;
},
onmessage(message) {
if (debug) {
console.debug(message);
}
if (!message.data || message.event === 'ping') {
return;
}
if (message.data === '[DONE]') {
onProgress('[DONE]');
resolve();
done = true;
return;
}
onProgress(JSON.parse(message.data));
},
});
} catch (err) {
reject(err);
}
});
}
const response = await fetch(completionsURL, {
...opts,
signal: abortController.signal,
});
if (response.status !== 200) {
const body = await response.text();
const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
try {
error.json = JSON.parse(body);
} catch {
error.body = body;
}
throw error;
}
return response.json();
}
/** @type {cohereChatCompletion} */
async cohereChatCompletion({ payload, onTokenProgress }) {
const cohere = new CohereClient({
token: this.apiKey,
environment: this.completionsUrl,
});
if (!payload.stream) {
const chatResponse = await cohere.chat(payload);
return chatResponse.text;
}
const chatStream = await cohere.chatStream(payload);
let reply = '';
for await (const message of chatStream) {
if (!message) {
continue;
}
if (message.eventType === 'text-generation' && message.text) {
onTokenProgress(message.text);
reply += message.text;
}
/*
Cohere API Chinese Unicode character replacement hotfix.
Should be un-commented when the following issue is resolved:
https://github.com/cohere-ai/cohere-typescript/issues/151
else if (message.eventType === 'stream-end' && message.response) {
reply = message.response.text;
}
*/
}
return reply;
}
async generateTitle(userMessage, botMessage) {
const instructionsPayload = {
role: 'system',
content: `Write an extremely concise subtitle for this conversation with no more than a few words. All words should be capitalized. Exclude punctuation.
||>Message:
${userMessage.message}
||>Response:
${botMessage.message}
||>Title:`,
};
const titleGenClientOptions = JSON.parse(JSON.stringify(this.options));
titleGenClientOptions.modelOptions = {
model: 'gpt-3.5-turbo',
temperature: 0,
presence_penalty: 0,
frequency_penalty: 0,
};
const titleGenClient = new ChatGPTClient(this.apiKey, titleGenClientOptions);
const result = await titleGenClient.getCompletion([instructionsPayload], null);
// remove any non-alphanumeric characters, replace multiple spaces with 1, and then trim
return result.choices[0].message.content
.replace(/[^a-zA-Z0-9' ]/g, '')
.replace(/\s+/g, ' ')
.trim();
}
async sendMessage(message, opts = {}) {
if (opts.clientOptions && typeof opts.clientOptions === 'object') {
this.setOptions(opts.clientOptions);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || crypto.randomUUID();
let conversation =
typeof opts.conversation === 'object'
? opts.conversation
: await this.conversationsCache.get(conversationId);
let isNewConversation = false;
if (!conversation) {
conversation = {
messages: [],
createdAt: Date.now(),
};
isNewConversation = true;
}
const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation;
const userMessage = {
id: crypto.randomUUID(),
parentMessageId,
role: 'User',
message,
};
conversation.messages.push(userMessage);
// Doing it this way instead of having each message be a separate element in the array seems to be more reliable,
// especially when it comes to keeping the AI in character. It also seems to improve coherency and context retention.
const { prompt: payload, context } = await this.buildPrompt(
conversation.messages,
userMessage.id,
{
isChatGptModel: this.isChatGptModel,
promptPrefix: opts.promptPrefix,
},
);
if (this.options.keepNecessaryMessagesOnly) {
conversation.messages = context;
}
let reply = '';
let result = null;
if (typeof opts.onProgress === 'function') {
await this.getCompletion(
payload,
(progressMessage) => {
if (progressMessage === '[DONE]') {
return;
}
const token = this.isChatGptModel
? progressMessage.choices[0].delta.content
: progressMessage.choices[0].text;
// first event's delta content is always undefined
if (!token) {
return;
}
if (this.options.debug) {
console.debug(token);
}
if (token === this.endToken) {
return;
}
opts.onProgress(token);
reply += token;
},
opts.abortController || new AbortController(),
);
} else {
result = await this.getCompletion(
payload,
null,
opts.abortController || new AbortController(),
);
if (this.options.debug) {
console.debug(JSON.stringify(result));
}
if (this.isChatGptModel) {
reply = result.choices[0].message.content;
} else {
reply = result.choices[0].text.replace(this.endToken, '');
}
}
// avoids some rendering issues when using the CLI app
if (this.options.debug) {
console.debug();
}
reply = reply.trim();
const replyMessage = {
id: crypto.randomUUID(),
parentMessageId: userMessage.id,
role: 'ChatGPT',
message: reply,
};
conversation.messages.push(replyMessage);
const returnData = {
response: replyMessage.message,
conversationId,
parentMessageId: replyMessage.parentMessageId,
messageId: replyMessage.id,
details: result || {},
};
if (shouldGenerateTitle) {
conversation.title = await this.generateTitle(userMessage, replyMessage);
returnData.title = conversation.title;
}
await this.conversationsCache.set(conversationId, conversation);
if (this.options.returnConversation) {
returnData.conversation = conversation;
}
return returnData;
}
async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) {
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
// Handle attachments and create augmentedPrompt
if (this.options.attachments) {
const attachments = await this.options.attachments;
const lastMessage = messages[messages.length - 1];
if (this.message_file_map) {
this.message_file_map[lastMessage.messageId] = attachments;
} else {
this.message_file_map = {
[lastMessage.messageId]: attachments,
};
}
const files = await this.addImageURLs(lastMessage, attachments);
this.options.attachments = files;
this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text);
}
if (this.message_file_map) {
this.contextHandlers = createContextHandlers(
this.options.req,
messages[messages.length - 1].text,
);
}
// Calculate image token cost and process embedded files
messages.forEach((message, i) => {
if (this.message_file_map && this.message_file_map[message.messageId]) {
const attachments = this.message_file_map[message.messageId];
for (const file of attachments) {
if (file.embedded) {
this.contextHandlers?.processFile(file);
continue;
}
messages[i].tokenCount =
(messages[i].tokenCount || 0) +
this.calculateImageTokenCost({
width: file.width,
height: file.height,
detail: this.options.imageDetail ?? ImageDetail.auto,
});
}
}
});
if (this.contextHandlers) {
this.augmentedPrompt = await this.contextHandlers.createContext();
promptPrefix = this.augmentedPrompt + promptPrefix;
}
if (promptPrefix) {
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
}
const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond.
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
let currentTokenCount;
if (isChatGptModel) {
currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
} else {
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
}
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
const context = [];
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && messages.length > 0) {
const message = messages.pop();
const roleLabel =
message?.isCreatedByUser || message?.role?.toLowerCase() === 'user'
? this.userLabel
: this.chatGptLabel;
const messageString = `${this.startToken}${roleLabel}:\n${
message?.text ?? message?.message
}${this.endToken}\n`;
let newPromptBody;
if (promptBody || isChatGptModel) {
newPromptBody = `${messageString}${promptBody}`;
} else {
// Always insert prompt prefix before the last user message, if not gpt-3.5-turbo.
// This makes the AI obey the prompt instructions better, which is important for custom instructions.
// After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things
// like "what's the last thing I wrote?".
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
}
context.unshift(message);
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setImmediate(resolve));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = `${promptBody}${promptSuffix}`;
if (isChatGptModel) {
messagePayload.content = prompt;
// Add 3 tokens for Assistant Label priming after all messages have been counted.
currentTokenCount += 3;
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (isChatGptModel) {
return { prompt: [instructionsPayload, messagePayload], context };
}
return { prompt, context, promptTokens: currentTokenCount };
}
getTokenCount(text) {
return this.gptEncoder.encode(text, 'all').length;
}
/**
* Algorithm adapted from "6. Counting tokens for chat API calls" of
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
*
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
*
* @param {Object} message
*/
getTokenCountForMessage(message) {
// Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
let tokensPerMessage = 3;
let tokensPerName = 1;
if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
tokensPerMessage = 4;
tokensPerName = -1;
}
let numTokens = tokensPerMessage;
for (let [key, value] of Object.entries(message)) {
numTokens += this.getTokenCount(value);
if (key === 'name') {
numTokens += tokensPerName;
}
}
return numTokens;
}
}
module.exports = ChatGPTClient;

View File

@@ -1,7 +1,7 @@
const { google } = require('googleapis');
const { Tokenizer } = require('@librechat/api');
const { concat } = require('@langchain/core/utils/stream');
const { ChatVertexAI } = require('@langchain/google-vertexai');
const { Tokenizer, getSafetySettings } = require('@librechat/api');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
@@ -12,13 +12,13 @@ const {
endpointSettings,
parseTextParts,
EModelEndpoint,
googleSettings,
ContentTypes,
VisionModes,
ErrorTypes,
Constants,
AuthKeys,
} = require('librechat-data-provider');
const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { spendTokens } = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
@@ -166,16 +166,6 @@ class GoogleClient extends BaseClient {
);
}
// Add thinking configuration
this.modelOptions.thinkingConfig = {
thinkingBudget:
(this.modelOptions.thinking ?? googleSettings.thinking.default)
? this.modelOptions.thinkingBudget
: 0,
};
delete this.modelOptions.thinking;
delete this.modelOptions.thinkingBudget;
this.sender =
this.options.sender ??
getResponseSender({

View File

@@ -5,7 +5,6 @@ const {
isEnabled,
Tokenizer,
createFetch,
resolveHeaders,
constructAzureURL,
genAzureChatCompletion,
createStreamEventHandlers,
@@ -16,6 +15,7 @@ const {
ContentTypes,
parseTextParts,
EModelEndpoint,
resolveHeaders,
KnownEndpoints,
openAISettings,
ImageDetailCost,
@@ -37,6 +37,7 @@ const { addSpaceIfNeeded, sleep } = require('~/server/utils');
const { spendTokens } = require('~/models/spendTokens');
const { handleOpenAIErrors } = require('./tools/util');
const { createLLM, RunManager } = require('./llm');
const ChatGPTClient = require('./ChatGPTClient');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document');
@@ -46,6 +47,12 @@ const { logger } = require('~/config');
class OpenAIClient extends BaseClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
this.ChatGPTClient = new ChatGPTClient();
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
/** @type {getCompletion} */
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
/** @type {cohereChatCompletion} */
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
this.contextStrategy = options.contextStrategy
? options.contextStrategy.toLowerCase()
: 'discard';
@@ -372,12 +379,23 @@ class OpenAIClient extends BaseClient {
return files;
}
async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) {
async buildMessages(
messages,
parentMessageId,
{ isChatCompletion = false, promptPrefix = null },
opts,
) {
let orderedMessages = this.constructor.getMessagesForConversation({
messages,
parentMessageId,
summary: this.shouldSummarize,
});
if (!isChatCompletion) {
return await this.buildPrompt(orderedMessages, {
isChatGptModel: isChatCompletion,
promptPrefix,
});
}
let payload;
let instructions;

View File

@@ -0,0 +1,542 @@
const OpenAIClient = require('./OpenAIClient');
const { CallbackManager } = require('@langchain/core/callbacks/manager');
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
const { processFileURL } = require('~/server/services/Files/process');
const { EModelEndpoint } = require('librechat-data-provider');
const { checkBalance } = require('~/models/balanceMethods');
const { formatLangChainMessages } = require('./prompts');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
const { logger } = require('~/config');
class PluginsClient extends OpenAIClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
this.sender = options.sender ?? 'Assistant';
this.tools = [];
this.actions = [];
this.setOptions(options);
this.openAIApiKey = this.apiKey;
this.executor = null;
}
setOptions(options) {
this.agentOptions = { ...options.agentOptions };
this.functionsAgent = this.agentOptions?.agent === 'functions';
this.agentIsGpt3 = this.agentOptions?.model?.includes('gpt-3');
super.setOptions(options);
this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');
if (this.options.reverseProxyUrl) {
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
}
}
getSaveOptions() {
return {
artifacts: this.options.artifacts,
chatGptLabel: this.options.chatGptLabel,
modelLabel: this.options.modelLabel,
promptPrefix: this.options.promptPrefix,
tools: this.options.tools,
...this.modelOptions,
agentOptions: this.agentOptions,
iconURL: this.options.iconURL,
greeting: this.options.greeting,
spec: this.options.spec,
};
}
saveLatestAction(action) {
this.actions.push(action);
}
getFunctionModelName(input) {
if (/-(?!0314)\d{4}/.test(input)) {
return input;
} else if (input.includes('gpt-3.5-turbo')) {
return 'gpt-3.5-turbo';
} else if (input.includes('gpt-4')) {
return 'gpt-4';
} else {
return 'gpt-3.5-turbo';
}
}
getBuildMessagesOptions(opts) {
return {
isChatCompletion: true,
promptPrefix: opts.promptPrefix,
abortController: opts.abortController,
};
}
async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
const modelOptions = {
modelName: this.agentOptions.model,
temperature: this.agentOptions.temperature,
};
const model = this.initializeLLM({
...modelOptions,
context: 'plugins',
initialMessageCount: this.currentMessages.length + 1,
});
logger.debug(
`[PluginsClient] Agent Model: ${model.modelName} | Temp: ${model.temperature} | Functions: ${this.functionsAgent}`,
);
// Map Messages to Langchain format
const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), {
userName: this.options?.name,
});
logger.debug('[PluginsClient] pastMessages: ' + pastMessages.length);
// TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS)
const memory = new BufferMemory({
llm: model,
chatHistory: new ChatMessageHistory(pastMessages),
});
const { loadedTools } = await loadTools({
user,
model,
tools: this.options.tools,
functions: this.functionsAgent,
options: {
memory,
signal: this.abortController.signal,
openAIApiKey: this.openAIApiKey,
conversationId: this.conversationId,
fileStrategy: this.options.req.app.locals.fileStrategy,
processFileURL,
message,
},
useSpecs: true,
});
if (loadedTools.length === 0) {
return;
}
this.tools = loadedTools;
logger.debug('[PluginsClient] Requested Tools', this.options.tools);
logger.debug(
'[PluginsClient] Loaded Tools',
this.tools.map((tool) => tool.name),
);
const handleAction = (action, runId, callback = null) => {
this.saveLatestAction(action);
logger.debug('[PluginsClient] Latest Agent Action ', this.actions[this.actions.length - 1]);
if (typeof callback === 'function') {
callback(action, runId);
}
};
// initialize agent
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
let customInstructions = (this.options.promptPrefix ?? '').trim();
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
customInstructions = `${customInstructions ?? ''}\n${this.options.artifactsPrompt}`.trim();
}
this.executor = await initializer({
model,
signal,
pastMessages,
tools: this.tools,
customInstructions,
verbose: this.options.debug,
returnIntermediateSteps: true,
customName: this.options.chatGptLabel,
currentDateString: this.currentDateString,
callbackManager: CallbackManager.fromHandlers({
async handleAgentAction(action, runId) {
handleAction(action, runId, onAgentAction);
},
async handleChainEnd(action) {
if (typeof onChainEnd === 'function') {
onChainEnd(action);
}
},
}),
});
logger.debug('[PluginsClient] Loaded agent.');
}
async executorCall(message, { signal, stream, onToolStart, onToolEnd }) {
let errorMessage = '';
const maxAttempts = 1;
for (let attempts = 1; attempts <= maxAttempts; attempts++) {
const errorInput = buildErrorInput({
message,
errorMessage,
actions: this.actions,
functionsAgent: this.functionsAgent,
});
const input = attempts > 1 ? errorInput : message;
logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`);
if (errorMessage.length > 0) {
logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));
}
try {
this.result = await this.executor.call({ input, signal }, [
{
async handleToolStart(...args) {
await onToolStart(...args);
},
async handleToolEnd(...args) {
await onToolEnd(...args);
},
async handleLLMEnd(output) {
const { generations } = output;
const { text } = generations[0][0];
if (text && typeof stream === 'function') {
await stream(text);
}
},
},
]);
break; // Exit the loop if the function call is successful
} catch (err) {
logger.error('[PluginsClient] executorCall error:', err);
if (attempts === maxAttempts) {
const { run } = this.runManager.getRunByConversationId(this.conversationId);
const defaultOutput = `Encountered an error while attempting to respond: ${err.message}`;
this.result.output = run && run.error ? run.error : defaultOutput;
this.result.errorMessage = run && run.error ? run.error : err.message;
this.result.intermediateSteps = this.actions;
break;
}
}
}
}
/**
*
* @param {TMessage} responseMessage
* @param {Partial<TMessage>} saveOptions
* @param {string} user
* @returns
*/
async handleResponseMessage(responseMessage, saveOptions, user) {
const { output, errorMessage, ...result } = this.result;
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
output,
errorMessage,
...result,
});
const { error } = responseMessage;
if (!error) {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
responseMessage.completionTokens = this.getTokenCount(responseMessage.text);
}
// Record usage only when completion is skipped as it is already recorded in the agent phase.
if (!this.agentOptions.skipCompletion && !error) {
await this.recordTokenUsage(responseMessage);
}
const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return { ...responseMessage, ...result, databasePromise };
}
async sendMessage(message, opts = {}) {
/** @type {Promise<TMessage>} */
let userMessagePromise;
/** @type {{ filteredTools: string[], includedTools: string[] }} */
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
if (includedTools.length > 0) {
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
this.options.tools = tools;
} else {
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
this.options.tools = tools;
}
// If a message is edited, no tools can be used.
const completionMode = this.options.tools.length === 0 || opts.isEdited;
if (completionMode) {
this.setOptions(opts);
return super.sendMessage(message, opts);
}
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
const {
user,
conversationId,
responseMessageId,
saveOptions,
userMessage,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
} = await this.handleStartMethods(message, opts);
if (opts.progressCallback) {
opts.onProgress = opts.progressCallback.call(null, {
...(opts.progressOptions ?? {}),
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
}
this.currentMessages.push(userMessage);
let {
prompt: payload,
tokenCountMap,
promptTokens,
} = await this.buildMessages(
this.currentMessages,
userMessage.messageId,
this.getBuildMessagesOptions({
promptPrefix: null,
abortController: this.abortController,
}),
);
if (tokenCountMap) {
logger.debug('[PluginsClient] tokenCountMap', { tokenCountMap });
if (tokenCountMap[userMessage.messageId]) {
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
logger.debug('[PluginsClient] userMessage.tokenCount', userMessage.tokenCount);
}
this.handleTokenCountMap(tokenCountMap);
}
this.result = {};
if (payload) {
this.currentMessages = payload;
}
if (!this.skipSaveUserMessage) {
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise,
});
}
}
const balance = this.options.req?.app?.locals?.balance;
if (balance?.enabled) {
await checkBalance({
req: this.options.req,
res: this.options.res,
txData: {
user: this.user,
tokenType: 'prompt',
amount: promptTokens,
debug: this.options.debug,
model: this.modelOptions.model,
endpoint: EModelEndpoint.openAI,
},
});
}
const responseMessage = {
endpoint: EModelEndpoint.gptPlugins,
iconURL: this.options.iconURL,
messageId: responseMessageId,
conversationId,
parentMessageId: userMessage.messageId,
isCreatedByUser: false,
model: this.modelOptions.model,
sender: this.sender,
promptTokens,
};
await this.initialize({
user,
message,
onAgentAction,
onChainEnd,
signal: this.abortController.signal,
onProgress: opts.onProgress,
});
// const stream = async (text) => {
// await this.generateTextStream.call(this, text, opts.onProgress, { delay: 1 });
// };
await this.executorCall(message, {
signal: this.abortController.signal,
// stream,
onToolStart,
onToolEnd,
});
// If message was aborted mid-generation
if (this.result?.errorMessage?.length > 0 && this.result?.errorMessage?.includes('cancel')) {
responseMessage.text = 'Cancelled.';
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
// If error occurred during generation (likely token_balance)
if (this.result?.errorMessage?.length > 0) {
responseMessage.error = true;
responseMessage.text = this.result.output;
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) {
const partialText = opts.getPartialText();
const trimmedPartial = opts.getPartialText().replaceAll(':::plugin:::\n', '');
responseMessage.text =
trimmedPartial.length === 0 ? `${partialText}${this.result.output}` : partialText;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output) {
responseMessage.text = this.result.output;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
logger.debug('[PluginsClient] Completion phase: this.result', this.result);
const promptPrefix = buildPromptPrefix({
result: this.result,
message,
functionsAgent: this.functionsAgent,
});
logger.debug('[PluginsClient]', { promptPrefix });
payload = await this.buildCompletionPrompt({
messages: this.currentMessages,
promptPrefix,
});
logger.debug('[PluginsClient] buildCompletionPrompt Payload', payload);
responseMessage.text = await this.sendCompletion(payload, opts);
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
async buildCompletionPrompt({ messages, promptPrefix: _promptPrefix }) {
logger.debug('[PluginsClient] buildCompletionPrompt messages', messages);
const orderedMessages = messages;
let promptPrefix = _promptPrefix.trim();
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
const promptSuffix = `${this.startToken}${this.chatGptLabel ?? 'Assistant'}:\n`;
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
if (this.isGpt3) {
instructionsPayload.role = 'user';
messagePayload.role = 'user';
instructionsPayload.content += `\n${promptSuffix}`;
}
// testing if this works with browser endpoint
if (!this.isGpt3 && this.options.reverseProxyUrl) {
instructionsPayload.role = 'user';
}
let currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
const message = orderedMessages.pop();
const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
let messageString = `${this.startToken}${roleLabel}:\n${
message.text ?? message.content ?? ''
}${this.endToken}\n`;
let newPromptBody = `${messageString}${promptBody}`;
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setTimeout(resolve, 0));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = promptBody;
messagePayload.content = prompt;
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;
if (this.isGpt3 && messagePayload.content.length > 0) {
const context = 'Chat History:\n';
messagePayload.content = `${context}${prompt}`;
currentTokenCount += this.getTokenCount(context);
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (this.isGpt3) {
messagePayload.content += promptSuffix;
return [instructionsPayload, messagePayload];
}
const result = [messagePayload, instructionsPayload];
if (this.functionsAgent && !this.isGpt3) {
result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`;
}
return result.filter((message) => message.content.length > 0);
}
}
module.exports = PluginsClient;

View File

@@ -1,11 +1,15 @@
const ChatGPTClient = require('./ChatGPTClient');
const OpenAIClient = require('./OpenAIClient');
const PluginsClient = require('./PluginsClient');
const GoogleClient = require('./GoogleClient');
const TextStream = require('./TextStream');
const AnthropicClient = require('./AnthropicClient');
const toolUtils = require('./tools/util');
module.exports = {
ChatGPTClient,
OpenAIClient,
PluginsClient,
GoogleClient,
TextStream,
AnthropicClient,

View File

@@ -1,7 +1,6 @@
const axios = require('axios');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const footer = `Use the context as your learned knowledge to better answer the user.
@@ -19,7 +18,7 @@ function createContextHandlers(req, userMessageContent) {
const queryPromises = [];
const processedFiles = [];
const processedIds = new Set();
const jwtToken = generateShortLivedToken(req.user.id);
const jwtToken = req.headers.authorization.split(' ')[1];
const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);
const query = async (file) => {
@@ -97,35 +96,35 @@ function createContextHandlers(req, userMessageContent) {
resolvedQueries.length === 0
? '\n\tThe semantic search did not return any results.'
: resolvedQueries
.map((queryResult, index) => {
const file = processedFiles[index];
let contextItems = queryResult.data;
.map((queryResult, index) => {
const file = processedFiles[index];
let contextItems = queryResult.data;
const generateContext = (currentContext) =>
`
const generateContext = (currentContext) =>
`
<file>
<filename>${file.filename}</filename>
<context>${currentContext}
</context>
</file>`;
if (useFullContext) {
return generateContext(`\n${contextItems}`);
}
if (useFullContext) {
return generateContext(`\n${contextItems}`);
}
contextItems = queryResult.data
.map((item) => {
const pageContent = item[0].page_content;
return `
contextItems = queryResult.data
.map((item) => {
const pageContent = item[0].page_content;
return `
<contextItem>
<![CDATA[${pageContent?.trim()}]]>
</contextItem>`;
})
.join('');
})
.join('');
return generateContext(contextItems);
})
.join('');
return generateContext(contextItems);
})
.join('');
if (useFullContext) {
const prompt = `${header}

View File

@@ -237,9 +237,41 @@ const formatAgentMessages = (payload) => {
return messages;
};
/**
* Formats an array of messages for LangChain, making sure all content fields are strings
* @param {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} payload - The array of messages to format.
* @returns {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} - The array of formatted LangChain messages, including ToolMessages for tool calls.
*/
const formatContentStrings = (payload) => {
const messages = [];
for (const message of payload) {
if (typeof message.content === 'string') {
continue;
}
if (!Array.isArray(message.content)) {
continue;
}
// Reduce text types to a single string, ignore all other types
const content = message.content.reduce((acc, curr) => {
if (curr.type === ContentTypes.TEXT) {
return `${acc}${curr[ContentTypes.TEXT]}\n`;
}
return acc;
}, '');
message.content = content.trim();
}
return messages;
};
module.exports = {
formatMessage,
formatFromLangChain,
formatAgentMessages,
formatContentStrings,
formatLangChainMessages,
};

View File

@@ -422,46 +422,6 @@ describe('BaseClient', () => {
expect(response).toEqual(expectedResult);
});
test('should replace responseMessageId with new UUID when isRegenerate is true and messageId ends with underscore', async () => {
const mockCrypto = require('crypto');
const newUUID = 'new-uuid-1234';
jest.spyOn(mockCrypto, 'randomUUID').mockReturnValue(newUUID);
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe(newUUID);
expect(TestClient.responseMessageId).not.toBe('existing-message-id_');
mockCrypto.randomUUID.mockRestore();
});
test('should not replace responseMessageId when isRegenerate is false', async () => {
const opts = {
isRegenerate: false,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id_');
});
test('should not replace responseMessageId when it does not end with underscore', async () => {
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id');
});
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
const userMessage = 'Second message in the conversation';
const opts = {

View File

@@ -531,6 +531,44 @@ describe('OpenAIClient', () => {
});
});
describe('sendMessage/getCompletion/chatCompletion', () => {
afterEach(() => {
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
});
it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => {
const model = 'text-davinci-003';
const onProgress = jest.fn().mockImplementation(() => ({}));
const testClient = new OpenAIClient('test-api-key', {
...defaultOptions,
modelOptions: { model },
});
const getCompletion = jest.spyOn(testClient, 'getCompletion');
await testClient.sendMessage('Hi mom!', { onProgress });
expect(getCompletion).toHaveBeenCalled();
expect(getCompletion.mock.calls.length).toBe(1);
expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n');
expect(fetchEventSource).toHaveBeenCalled();
expect(fetchEventSource.mock.calls.length).toBe(1);
// Check if the first argument (url) is correct
const firstCallArgs = fetchEventSource.mock.calls[0];
const expectedURL = 'https://api.openai.com/v1/completions';
expect(firstCallArgs[0]).toBe(expectedURL);
const requestBody = JSON.parse(firstCallArgs[1].body);
expect(requestBody).toHaveProperty('model');
expect(requestBody.model).toBe(model);
});
});
describe('checkVisionRequest functionality', () => {
let client;
const attachments = [{ type: 'image/png' }];

View File

@@ -0,0 +1,314 @@
const crypto = require('crypto');
const { Constants } = require('librechat-data-provider');
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
const PluginsClient = require('../PluginsClient');
jest.mock('~/db/connect');
jest.mock('~/models/Conversation', () => {
return function () {
return {
save: jest.fn(),
deleteConvos: jest.fn(),
};
};
});
const defaultAzureOptions = {
azureOpenAIApiInstanceName: 'your-instance-name',
azureOpenAIApiDeploymentName: 'your-deployment-name',
azureOpenAIApiVersion: '2020-07-01-preview',
};
describe('PluginsClient', () => {
let TestAgent;
let options = {
tools: [],
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
};
let parentMessageId;
let conversationId;
const fakeMessages = [];
const userMessage = 'Hello, ChatGPT!';
const apiKey = 'fake-api-key';
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, options);
TestAgent.loadHistory = jest
.fn()
.mockImplementation((conversationId, parentMessageId = null) => {
if (!conversationId) {
TestAgent.currentMessages = [];
return Promise.resolve([]);
}
const orderedMessages = TestAgent.constructor.getMessagesForConversation({
messages: fakeMessages,
parentMessageId,
});
const chatMessages = orderedMessages.map((msg) =>
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
? new HumanMessage(msg.text)
: new AIMessage(msg.text),
);
TestAgent.currentMessages = orderedMessages;
return Promise.resolve(chatMessages);
});
TestAgent.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => {
if (opts && typeof opts === 'object') {
TestAgent.setOptions(opts);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || Constants.NO_PARENT;
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
this.pastMessages = await TestAgent.loadHistory(
conversationId,
TestAgent.options?.parentMessageId,
);
const userMessage = {
text: message,
sender: 'ChatGPT',
isCreatedByUser: true,
messageId: userMessageId,
parentMessageId,
conversationId,
};
const response = {
sender: 'ChatGPT',
text: 'Hello, User!',
isCreatedByUser: false,
messageId: crypto.randomUUID(),
parentMessageId: userMessage.messageId,
conversationId,
};
fakeMessages.push(userMessage);
fakeMessages.push(response);
return response;
});
});
test('initializes PluginsClient without crashing', () => {
expect(TestAgent).toBeInstanceOf(PluginsClient);
});
test('check setOptions function', () => {
expect(TestAgent.agentIsGpt3).toBe(true);
});
describe('sendMessage', () => {
test('sendMessage should return a response message', async () => {
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: expect.any(String),
});
const response = await TestAgent.sendMessage(userMessage);
parentMessageId = response.messageId;
conversationId = response.conversationId;
expect(response).toEqual(expectedResult);
});
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
const userMessage = 'Second message in the conversation';
const opts = {
conversationId,
parentMessageId,
};
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: opts.conversationId,
});
const response = await TestAgent.sendMessage(userMessage, opts);
parentMessageId = response.messageId;
expect(response.conversationId).toEqual(conversationId);
expect(response).toEqual(expectedResult);
});
test('should return chat history', async () => {
const chatMessages = await TestAgent.loadHistory(conversationId, parentMessageId);
expect(TestAgent.currentMessages).toHaveLength(4);
expect(chatMessages[0].text).toEqual(userMessage);
});
});
describe('getFunctionModelName', () => {
let client;
beforeEach(() => {
client = new PluginsClient('dummy_api_key');
});
test('should return the input when it includes a dash followed by four digits', () => {
expect(client.getFunctionModelName('-1234')).toBe('-1234');
expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview');
});
test('should return the input for all function-capable models (`0613` models and above)', () => {
expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613');
expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview');
expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106');
});
test('should return the corresponding model if input is non-function capable (`0314` models)', () => {
expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => {
expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-4" when the input includes "gpt-4"', () => {
expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4');
});
test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => {
expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
});
});
describe('Azure OpenAI tests specific to Plugins', () => {
// TODO: add more tests for Azure OpenAI integration with Plugins
// let client;
// beforeEach(() => {
// client = new PluginsClient('dummy_api_key');
// });
test('should not call getFunctionModelName when azure options are set', () => {
const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName');
const model = 'gpt-4-turbo';
// note, without the azure change in PR #1766, `getFunctionModelName` is called twice
const testClient = new PluginsClient('dummy_api_key', {
agentOptions: {
model,
agent: 'functions',
},
azure: defaultAzureOptions,
});
expect(spy).not.toHaveBeenCalled();
expect(testClient.agentOptions.model).toBe(model);
spy.mockRestore();
});
});
describe('sendMessage with filtered tools', () => {
let TestAgent;
const apiKey = 'fake-api-key';
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, {
tools: mockTools,
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
});
TestAgent.options.req = {
app: {
locals: {},
},
};
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
if (includedTools.length > 0) {
const tools = TestAgent.options.tools.filter((plugin) =>
includedTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
} else {
const tools = TestAgent.options.tools.filter(
(plugin) => !filteredTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
}
return {
text: 'Mocked response',
tools: TestAgent.options.tools,
};
});
});
test('should filter out tools when filteredTools is provided', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should only include specified tools when includedTools is provided', async () => {
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should prioritize includedTools over filteredTools', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool1' }),
expect.objectContaining({ name: 'tool2' }),
]),
);
});
test('should not modify tools when no filters are provided', async () => {
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(4);
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
});
});
});

View File

@@ -107,12 +107,6 @@ const getImageEditPromptDescription = () => {
return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
};
function createAbortHandler() {
return function () {
logger.debug('[ImageGenOAI] Image generation aborted');
};
}
/**
* Creates OpenAI Image tools (generation and editing)
* @param {Object} fields - Configuration fields
@@ -207,18 +201,10 @@ function createOpenAIImageTools(fields = {}) {
}
let resp;
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try {
if (runnableConfig?.signal) {
derivedSignal = AbortSignal.any([runnableConfig.signal]);
abortHandler = createAbortHandler();
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
const derivedSignal = runnableConfig?.signal
? AbortSignal.any([runnableConfig.signal])
: undefined;
resp = await openai.images.generate(
{
model: 'gpt-image-1',
@@ -242,10 +228,6 @@ function createOpenAIImageTools(fields = {}) {
logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable:
Error Message: ${error.message}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
}
if (!resp) {
@@ -427,17 +409,10 @@ Error Message: ${error.message}`);
headers['Authorization'] = `Bearer ${apiKey}`;
}
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try {
if (runnableConfig?.signal) {
derivedSignal = AbortSignal.any([runnableConfig.signal]);
abortHandler = createAbortHandler();
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
const derivedSignal = runnableConfig?.signal
? AbortSignal.any([runnableConfig.signal])
: undefined;
/** @type {import('axios').AxiosRequestConfig} */
const axiosConfig = {
@@ -492,10 +467,6 @@ Error Message: ${error.message}`);
logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable:
Error Message: ${error.message || 'Unknown error'}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
}
},
{

View File

@@ -1,35 +1,26 @@
const { z } = require('zod');
const axios = require('axios');
const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Tools, EToolResources } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
/**
*
* @param {Object} options
* @param {ServerRequest} options.req
* @param {Agent['tool_resources']} options.tool_resources
* @param {string} [options.agentId] - The agent ID for file access control
* @returns {Promise<{
* files: Array<{ file_id: string; filename: string }>,
* toolContext: string
* }>}
*/
const primeFiles = async (options) => {
const { tool_resources, req, agentId } = options;
const { tool_resources } = options;
const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
const agentResourceIds = new Set(file_ids);
const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
const dbFiles = (
(await getFiles(
{ file_id: { $in: file_ids } },
null,
{ text: 0 },
{ userId: req?.user?.id, agentId },
)) ?? []
).concat(resourceFiles);
const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;
@@ -68,7 +59,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
if (files.length === 0) {
return 'No files to search. Instruct the user to add files for the search.';
}
const jwtToken = generateShortLivedToken(req.user.id);
const jwtToken = req.headers.authorization.split(' ')[1];
if (!jwtToken) {
return 'There was an error authenticating the file search request.';
}

View File

@@ -1,9 +1,14 @@
const { mcpToolPattern } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator');
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
const {
Tools,
EToolResources,
loadWebSearchAuth,
replaceSpecialVars,
} = require('librechat-data-provider');
const {
availableTools,
manifestToolMap,
@@ -240,13 +245,7 @@ const loadTools = async ({
authFields: [EnvVar.CODE_API_KEY],
});
const codeApiKey = authValues[EnvVar.CODE_API_KEY];
const { files, toolContext } = await primeCodeFiles(
{
...options,
agentId: agent?.id,
},
codeApiKey,
);
const { files, toolContext } = await primeCodeFiles(options, codeApiKey);
if (toolContext) {
toolContextMap[tool] = toolContext;
}
@@ -261,10 +260,7 @@ const loadTools = async ({
continue;
} else if (tool === Tools.file_search) {
requestedTools[tool] = async () => {
const { files, toolContext } = await primeSearchFiles({
...options,
agentId: agent?.id,
});
const { files, toolContext } = await primeSearchFiles(options);
if (toolContext) {
toolContextMap[tool] = toolContext;
}

View File

@@ -1,8 +1,7 @@
const { logger } = require('@librechat/data-schemas');
const { isEnabled, math } = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, math, removePorts } = require('~/server/utils');
const { deleteAllUserSessions } = require('~/models');
const { removePorts } = require('~/server/utils');
const getLogStores = require('./getLogStores');
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};

View File

@@ -1,33 +0,0 @@
const fs = require('fs');
const { math, isEnabled } = require('@librechat/api');
// To ensure that different deployments do not interfere with each other's cache, we use a prefix for the Redis keys.
// This prefix is usually the deployment ID, which is often passed to the container or pod as an env var.
// Set REDIS_KEY_PREFIX_VAR to the env var that contains the deployment ID.
const REDIS_KEY_PREFIX_VAR = process.env.REDIS_KEY_PREFIX_VAR;
const REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX;
if (REDIS_KEY_PREFIX_VAR && REDIS_KEY_PREFIX) {
throw new Error('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
}
const USE_REDIS = isEnabled(process.env.USE_REDIS);
if (USE_REDIS && !process.env.REDIS_URI) {
throw new Error('USE_REDIS is enabled but REDIS_URI is not set.');
}
const cacheConfig = {
USE_REDIS,
REDIS_URI: process.env.REDIS_URI,
REDIS_USERNAME: process.env.REDIS_USERNAME,
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
CI: isEnabled(process.env.CI),
DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
BAN_DURATION: math(process.env.BAN_DURATION, 7200000), // 2 hours
};
module.exports = { cacheConfig };

View File

@@ -1,108 +0,0 @@
const fs = require('fs');
describe('cacheConfig', () => {
let originalEnv;
let originalReadFileSync;
beforeEach(() => {
originalEnv = { ...process.env };
originalReadFileSync = fs.readFileSync;
// Clear all related env vars first
delete process.env.REDIS_URI;
delete process.env.REDIS_CA;
delete process.env.REDIS_KEY_PREFIX_VAR;
delete process.env.REDIS_KEY_PREFIX;
delete process.env.USE_REDIS;
// Clear require cache
jest.resetModules();
});
afterEach(() => {
process.env = originalEnv;
fs.readFileSync = originalReadFileSync;
jest.resetModules();
});
describe('REDIS_KEY_PREFIX validation and resolution', () => {
test('should throw error when both REDIS_KEY_PREFIX_VAR and REDIS_KEY_PREFIX are set', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.REDIS_KEY_PREFIX = 'manual-prefix';
expect(() => {
require('./cacheConfig');
}).toThrow('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
});
test('should resolve REDIS_KEY_PREFIX from variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.DEPLOYMENT_ID = 'test-deployment-123';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('test-deployment-123');
});
test('should use direct REDIS_KEY_PREFIX value', () => {
process.env.REDIS_KEY_PREFIX = 'direct-prefix';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('direct-prefix');
});
test('should default to empty string when no prefix is configured', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle empty variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'EMPTY_VAR';
process.env.EMPTY_VAR = '';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle undefined variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'UNDEFINED_VAR';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
});
describe('USE_REDIS and REDIS_URI validation', () => {
test('should throw error when USE_REDIS is enabled but REDIS_URI is not set', () => {
process.env.USE_REDIS = 'true';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
test('should not throw error when USE_REDIS is enabled and REDIS_URI is set', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
expect(() => {
require('./cacheConfig');
}).not.toThrow();
});
test('should handle empty REDIS_URI when USE_REDIS is enabled', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = '';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
});
describe('REDIS_CA file reading', () => {
test('should be null when REDIS_CA is not set', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_CA).toBeNull();
});
});
});

View File

@@ -1,66 +0,0 @@
const KeyvRedis = require('@keyv/redis').default;
const { Keyv } = require('keyv');
const { cacheConfig } = require('./cacheConfig');
const { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } = require('./redisClients');
const { Time } = require('librechat-data-provider');
const { RedisStore: ConnectRedis } = require('connect-redis');
const MemoryStore = require('memorystore')(require('express-session'));
const { violationFile } = require('./keyvFiles');
const { RedisStore } = require('rate-limit-redis');
/**
* Creates a cache instance using Redis or a fallback store. Suitable for general caching needs.
* @param {string} namespace - The cache namespace.
* @param {number} [ttl] - Time to live for cache entries.
* @param {object} [fallbackStore] - Optional fallback store if Redis is not used.
* @returns {Keyv} Cache instance.
*/
const standardCache = (namespace, ttl = undefined, fallbackStore = undefined) => {
if (cacheConfig.USE_REDIS) {
const keyvRedis = new KeyvRedis(keyvRedisClient);
const cache = new Keyv(keyvRedis, { namespace, ttl });
keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
return cache;
}
if (fallbackStore) return new Keyv({ store: fallbackStore, namespace, ttl });
return new Keyv({ namespace, ttl });
};
/**
* Creates a cache instance for storing violation data.
* Uses a file-based fallback store if Redis is not enabled.
* @param {string} namespace - The cache namespace for violations.
* @param {number} [ttl] - Time to live for cache entries.
* @returns {Keyv} Cache instance for violations.
*/
const violationCache = (namespace, ttl = undefined) => {
return standardCache(`violations:${namespace}`, ttl, violationFile);
};
/**
* Creates a session cache instance using Redis or in-memory store.
* @param {string} namespace - The session namespace.
* @param {number} [ttl] - Time to live for session entries.
* @returns {MemoryStore | ConnectRedis} Session store instance.
*/
const sessionCache = (namespace, ttl = undefined) => {
namespace = namespace.endsWith(':') ? namespace : `${namespace}:`;
if (!cacheConfig.USE_REDIS) return new MemoryStore({ ttl, checkPeriod: Time.ONE_DAY });
return new ConnectRedis({ client: ioredisClient, ttl, prefix: namespace });
};
/**
* Creates a rate limiter cache using Redis.
* @param {string} prefix - The key prefix for rate limiting.
* @returns {RedisStore|undefined} RedisStore instance or undefined if Redis is not used.
*/
const limiterCache = (prefix) => {
if (!prefix) throw new Error('prefix is required');
if (!cacheConfig.USE_REDIS) return undefined;
prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;
return new RedisStore({ sendCommand, prefix });
};
const sendCommand = (...args) => ioredisClient?.call(...args);
module.exports = { standardCache, sessionCache, violationCache, limiterCache };

View File

@@ -1,270 +0,0 @@
const { Time } = require('librechat-data-provider');
// Mock dependencies first
const mockKeyvRedis = {
namespace: '',
keyPrefixSeparator: '',
};
const mockKeyv = jest.fn().mockReturnValue({ mock: 'keyv' });
const mockConnectRedis = jest.fn().mockReturnValue({ mock: 'connectRedis' });
const mockMemoryStore = jest.fn().mockReturnValue({ mock: 'memoryStore' });
const mockRedisStore = jest.fn().mockReturnValue({ mock: 'redisStore' });
const mockIoredisClient = {
call: jest.fn(),
};
const mockKeyvRedisClient = {};
const mockViolationFile = {};
// Mock modules before requiring the main module
jest.mock('@keyv/redis', () => ({
default: jest.fn().mockImplementation(() => mockKeyvRedis),
}));
jest.mock('keyv', () => ({
Keyv: mockKeyv,
}));
jest.mock('./cacheConfig', () => ({
cacheConfig: {
USE_REDIS: false,
REDIS_KEY_PREFIX: 'test',
},
}));
jest.mock('./redisClients', () => ({
keyvRedisClient: mockKeyvRedisClient,
ioredisClient: mockIoredisClient,
GLOBAL_PREFIX_SEPARATOR: '::',
}));
jest.mock('./keyvFiles', () => ({
violationFile: mockViolationFile,
}));
jest.mock('connect-redis', () => ({ RedisStore: mockConnectRedis }));
jest.mock('memorystore', () => jest.fn(() => mockMemoryStore));
jest.mock('rate-limit-redis', () => ({
RedisStore: mockRedisStore,
}));
// Import after mocking
const { standardCache, sessionCache, violationCache, limiterCache } = require('./cacheFactory');
const { cacheConfig } = require('./cacheConfig');
describe('cacheFactory', () => {
beforeEach(() => {
jest.clearAllMocks();
// Reset cache config mock
cacheConfig.USE_REDIS = false;
cacheConfig.REDIS_KEY_PREFIX = 'test';
});
describe('redisCache', () => {
it('should create Redis cache when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
expect(mockKeyvRedis.namespace).toBe(cacheConfig.REDIS_KEY_PREFIX);
expect(mockKeyvRedis.keyPrefixSeparator).toBe('::');
});
it('should create Redis cache with undefined ttl when not provided', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
standardCache(namespace);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl: undefined });
});
it('should use fallback store when USE_REDIS is false and fallbackStore is provided', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
const fallbackStore = { some: 'store' };
standardCache(namespace, ttl, fallbackStore);
expect(mockKeyv).toHaveBeenCalledWith({ store: fallbackStore, namespace, ttl });
});
it('should create default Keyv instance when USE_REDIS is false and no fallbackStore', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(mockKeyv).toHaveBeenCalledWith({ namespace, ttl });
});
it('should handle namespace and ttl as undefined', () => {
cacheConfig.USE_REDIS = false;
standardCache();
expect(mockKeyv).toHaveBeenCalledWith({ namespace: undefined, ttl: undefined });
});
});
describe('violationCache', () => {
it('should create violation cache with prefixed namespace', () => {
const namespace = 'test-violations';
const ttl = 7200;
// We can't easily mock the internal redisCache call since it's in the same module
// But we can test that the function executes without throwing
expect(() => violationCache(namespace, ttl)).not.toThrow();
});
it('should create violation cache with undefined ttl', () => {
const namespace = 'test-violations';
violationCache(namespace);
// The function should call redisCache with violations: prefixed namespace
// Since we can't easily mock the internal redisCache call, we test the behavior
expect(() => violationCache(namespace)).not.toThrow();
});
it('should handle undefined namespace', () => {
expect(() => violationCache(undefined)).not.toThrow();
});
});
describe('sessionCache', () => {
it('should return MemoryStore when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockMemoryStore).toHaveBeenCalledWith({ ttl, checkPeriod: Time.ONE_DAY });
expect(result).toBe(mockMemoryStore());
});
it('should return ConnectRedis when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl,
prefix: `${namespace}:`,
});
expect(result).toBe(mockConnectRedis());
});
it('should add colon to namespace if not present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should not add colon to namespace if already present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions:';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should handle undefined ttl', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockMemoryStore).toHaveBeenCalledWith({
ttl: undefined,
checkPeriod: Time.ONE_DAY,
});
});
});
describe('limiterCache', () => {
it('should return undefined when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const result = limiterCache('prefix');
expect(result).toBeUndefined();
});
it('should return RedisStore when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const result = limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: `rate-limit:`,
});
expect(result).toBe(mockRedisStore());
});
it('should add colon to prefix if not present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should not add colon to prefix if already present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit:');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should pass sendCommand function that calls ioredisClient.call', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
const sendCommandCall = mockRedisStore.mock.calls[0][0];
const sendCommand = sendCommandCall.sendCommand;
// Test that sendCommand properly delegates to ioredisClient.call
const args = ['GET', 'test-key'];
sendCommand(...args);
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
});
it('should handle undefined prefix', () => {
cacheConfig.USE_REDIS = true;
expect(() => limiterCache()).toThrow('prefix is required');
});
});
});

View File

@@ -1,52 +1,113 @@
const { cacheConfig } = require('./cacheConfig');
const { Keyv } = require('keyv');
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
const { logFile } = require('./keyvFiles');
const { logFile, violationFile } = require('./keyvFiles');
const { isEnabled, math } = require('~/server/utils');
const keyvRedis = require('./keyvRedis');
const keyvMongo = require('./keyvMongo');
const { standardCache, sessionCache, violationCache } = require('./cacheFactory');
const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE, CI } = process.env ?? {};
const duration = math(BAN_DURATION, 7200000);
const isRedisEnabled = isEnabled(USE_REDIS);
const debugMemoryCache = isEnabled(DEBUG_MEMORY_CACHE);
const createViolationInstance = (namespace) => {
const config = isRedisEnabled ? { store: keyvRedis } : { store: violationFile, namespace };
return new Keyv(config);
};
// Serve cache from memory so no need to clear it on startup/exit
const pending_req = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.PENDING_REQ });
const config = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
const roles = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ROLES });
const mcpTools = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.MCP_TOOLS });
const audioRuns = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
const messages = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.ONE_MINUTE })
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.ONE_MINUTE });
const flows = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.FLOWS, ttl: Time.ONE_MINUTE * 3 });
const tokenConfig = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
const genTitle = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
const s3ExpiryInterval = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });
const modelQueries = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
const abortKeys = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
const openIdExchangedTokensCache = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.OPENID_EXCHANGED_TOKENS, ttl: Time.TEN_MINUTES });
const namespaces = {
[ViolationTypes.GENERAL]: new Keyv({ store: logFile, namespace: 'violations' }),
[ViolationTypes.LOGINS]: violationCache(ViolationTypes.LOGINS),
[ViolationTypes.CONCURRENT]: violationCache(ViolationTypes.CONCURRENT),
[ViolationTypes.NON_BROWSER]: violationCache(ViolationTypes.NON_BROWSER),
[ViolationTypes.MESSAGE_LIMIT]: violationCache(ViolationTypes.MESSAGE_LIMIT),
[ViolationTypes.REGISTRATIONS]: violationCache(ViolationTypes.REGISTRATIONS),
[ViolationTypes.TOKEN_BALANCE]: violationCache(ViolationTypes.TOKEN_BALANCE),
[ViolationTypes.TTS_LIMIT]: violationCache(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: violationCache(ViolationTypes.STT_LIMIT),
[ViolationTypes.CONVO_ACCESS]: violationCache(ViolationTypes.CONVO_ACCESS),
[ViolationTypes.TOOL_CALL_LIMIT]: violationCache(ViolationTypes.TOOL_CALL_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: violationCache(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.VERIFY_EMAIL_LIMIT]: violationCache(ViolationTypes.VERIFY_EMAIL_LIMIT),
[ViolationTypes.RESET_PASSWORD_LIMIT]: violationCache(ViolationTypes.RESET_PASSWORD_LIMIT),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: violationCache(ViolationTypes.ILLEGAL_MODEL_REQUEST),
[ViolationTypes.BAN]: new Keyv({
[CacheKeys.ROLES]: roles,
[CacheKeys.MCP_TOOLS]: mcpTools,
[CacheKeys.CONFIG_STORE]: config,
[CacheKeys.PENDING_REQ]: pending_req,
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({
store: keyvMongo,
namespace: CacheKeys.BANS,
ttl: cacheConfig.BAN_DURATION,
namespace: CacheKeys.ENCODED_DOMAINS,
ttl: 0,
}),
[CacheKeys.OPENID_SESSION]: sessionCache(CacheKeys.OPENID_SESSION),
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
[CacheKeys.MCP_TOOLS]: standardCache(CacheKeys.MCP_TOOLS),
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),
[CacheKeys.TOKEN_CONFIG]: standardCache(CacheKeys.TOKEN_CONFIG, Time.THIRTY_MINUTES),
[CacheKeys.GEN_TITLE]: standardCache(CacheKeys.GEN_TITLE, Time.TWO_MINUTES),
[CacheKeys.S3_EXPIRY_INTERVAL]: standardCache(CacheKeys.S3_EXPIRY_INTERVAL, Time.THIRTY_MINUTES),
[CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES),
[CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES),
[CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE),
[CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 3),
[CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache(
CacheKeys.OPENID_EXCHANGED_TOKENS,
Time.TEN_MINUTES,
general: new Keyv({ store: logFile, namespace: 'violations' }),
concurrent: createViolationInstance('concurrent'),
non_browser: createViolationInstance('non_browser'),
message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
registrations: createViolationInstance('registrations'),
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
[ViolationTypes.CONVO_ACCESS]: createViolationInstance(ViolationTypes.CONVO_ACCESS),
[ViolationTypes.TOOL_CALL_LIMIT]: createViolationInstance(ViolationTypes.TOOL_CALL_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
[ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
ViolationTypes.RESET_PASSWORD_LIMIT,
),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
ViolationTypes.ILLEGAL_MODEL_REQUEST,
),
logins: createViolationInstance('logins'),
[CacheKeys.ABORT_KEYS]: abortKeys,
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
[CacheKeys.GEN_TITLE]: genTitle,
[CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
[CacheKeys.MODEL_QUERIES]: modelQueries,
[CacheKeys.AUDIO_RUNS]: audioRuns,
[CacheKeys.MESSAGES]: messages,
[CacheKeys.FLOWS]: flows,
[CacheKeys.OPENID_EXCHANGED_TOKENS]: openIdExchangedTokensCache,
};
/**
@@ -55,10 +116,7 @@ const namespaces = {
*/
function getTTLStores() {
return Object.values(namespaces).filter(
(store) =>
store instanceof Keyv &&
parseInt(store.opts?.ttl ?? '0') > 0 &&
!store.opts?.store?.constructor?.name?.includes('Redis'), // Only include non-Redis stores
(store) => store instanceof Keyv && typeof store.opts?.ttl === 'number' && store.opts.ttl > 0,
);
}
@@ -94,18 +152,18 @@ async function clearExpiredFromCache(cache) {
if (data?.expires && data.expires <= expiryTime) {
const deleted = await cache.opts.store.delete(key);
if (!deleted) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
continue;
}
cleared++;
}
} catch (error) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
const deleted = await cache.opts.store.delete(key);
if (!deleted) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
continue;
}
@@ -114,7 +172,7 @@ async function clearExpiredFromCache(cache) {
}
if (cleared > 0) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.log(
`[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`,
);
@@ -155,7 +213,7 @@ async function clearAllExpiredFromCache() {
}
}
if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
if (!isRedisEnabled && !isEnabled(CI)) {
/** @type {Set<NodeJS.Timeout>} */
const cleanupIntervals = new Set();
@@ -166,7 +224,7 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
cleanupIntervals.add(cleanup);
if (cacheConfig.DEBUG_MEMORY_CACHE) {
if (debugMemoryCache) {
const monitor = setInterval(() => {
const ttlStores = getTTLStores();
const memory = process.memoryUsage();
@@ -187,13 +245,13 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
}
const dispose = () => {
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Cleaning up and shutting down...');
debugMemoryCache && console.log('[Cache] Cleaning up and shutting down...');
cleanupIntervals.forEach((interval) => clearInterval(interval));
cleanupIntervals.clear();
// One final cleanup before exit
clearAllExpiredFromCache().then(() => {
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Final cleanup completed');
debugMemoryCache && console.log('[Cache] Final cleanup completed');
process.exit(0);
});
};

92
api/cache/ioredisClient.js vendored Normal file
View File

@@ -0,0 +1,92 @@
const fs = require('fs');
const Redis = require('ioredis');
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_MAX_LISTENERS } = process.env;
/** @type {import('ioredis').Redis | import('ioredis').Cluster} */
let ioredisClient;
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
var value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
ioredisClient = new Redis.Cluster(hosts, { redisOptions });
} else {
ioredisClient = new Redis(REDIS_URI, redisOptions);
}
ioredisClient.on('ready', () => {
logger.info('IoRedis connection ready');
});
ioredisClient.on('reconnecting', () => {
logger.info('IoRedis connection reconnecting');
});
ioredisClient.on('end', () => {
logger.info('IoRedis connection ended');
});
ioredisClient.on('close', () => {
logger.info('IoRedis connection closed');
});
ioredisClient.on('error', (err) => logger.error('IoRedis connection error:', err));
ioredisClient.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] IoRedis initialized for rate limiters. If you have issues, disable Redis or restart the server.',
);
} else {
logger.info('[Optional] IoRedis not initialized for rate limiters.');
}
module.exports = ioredisClient;

109
api/cache/keyvRedis.js vendored Normal file
View File

@@ -0,0 +1,109 @@
const fs = require('fs');
const ioredis = require('ioredis');
const KeyvRedis = require('@keyv/redis').default;
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } =
process.env;
let keyvRedis;
const redis_prefix = REDIS_KEY_PREFIX || '';
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
/** @type {import('@keyv/redis').KeyvRedisOptions} */
let keyvOpts = {
useRedisSets: false,
keyPrefix: redis_prefix,
};
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
var value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
const cluster = new ioredis.Cluster(hosts, { redisOptions });
keyvRedis = new KeyvRedis(cluster, keyvOpts);
} else {
keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts);
}
const pingInterval = setInterval(
() => {
logger.debug('KeyvRedis ping');
keyvRedis.client.ping().catch((err) => logger.error('Redis keep-alive ping failed:', err));
},
5 * 60 * 1000,
);
keyvRedis.on('ready', () => {
logger.info('KeyvRedis connection ready');
});
keyvRedis.on('reconnecting', () => {
logger.info('KeyvRedis connection reconnecting');
});
keyvRedis.on('end', () => {
logger.info('KeyvRedis connection ended');
});
keyvRedis.on('close', () => {
clearInterval(pingInterval);
logger.info('KeyvRedis connection closed');
});
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
keyvRedis.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.',
);
} else {
logger.info('[Optional] Redis not initialized.');
}
module.exports = keyvRedis;

View File

@@ -1,5 +1,4 @@
const { isEnabled } = require('~/server/utils');
const { ViolationTypes } = require('librechat-data-provider');
const getLogStores = require('./getLogStores');
const banViolation = require('./banViolation');
@@ -10,14 +9,14 @@ const banViolation = require('./banViolation');
* @param {Object} res - Express response object.
* @param {string} type - The type of violation.
* @param {Object} errorMessage - The error message to log.
* @param {number | string} [score=1] - The severity of the violation. Defaults to 1
* @param {number} [score=1] - The severity of the violation. Defaults to 1
*/
const logViolation = async (req, res, type, errorMessage, score = 1) => {
const userId = req.user?.id ?? req.user?._id;
if (!userId) {
return;
}
const logs = getLogStores(ViolationTypes.GENERAL);
const logs = getLogStores('general');
const violationLogs = getLogStores(type);
const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;

View File

@@ -1,57 +0,0 @@
const IoRedis = require('ioredis');
const { cacheConfig } = require('./cacheConfig');
const { createClient, createCluster } = require('@keyv/redis');
const GLOBAL_PREFIX_SEPARATOR = '::';
const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri));
const username = urls?.[0].username || cacheConfig.REDIS_USERNAME;
const password = urls?.[0].password || cacheConfig.REDIS_PASSWORD;
const ca = cacheConfig.REDIS_CA;
/** @type {import('ioredis').Redis | import('ioredis').Cluster | null} */
let ioredisClient = null;
if (cacheConfig.USE_REDIS) {
const redisOptions = {
username: username,
password: password,
tls: ca ? { ca } : undefined,
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
};
ioredisClient =
urls.length === 1
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
: new IoRedis.Cluster(cacheConfig.REDIS_URI, { redisOptions });
// Pinging the Redis server every 5 minutes to keep the connection alive
const pingInterval = setInterval(() => ioredisClient.ping(), 5 * 60 * 1000);
ioredisClient.on('close', () => clearInterval(pingInterval));
ioredisClient.on('end', () => clearInterval(pingInterval));
}
/** @type {import('@keyv/redis').RedisClient | import('@keyv/redis').RedisCluster | null} */
let keyvRedisClient = null;
if (cacheConfig.USE_REDIS) {
// ** WARNING ** Keyv Redis client does not support Prefix like ioredis above.
// The prefix feature will be handled by the Keyv-Redis store in cacheFactory.js
const redisOptions = { username, password, socket: { tls: ca != null, ca } };
keyvRedisClient =
urls.length === 1
? createClient({ url: cacheConfig.REDIS_URI, ...redisOptions })
: createCluster({
rootNodes: cacheConfig.REDIS_URI.split(',').map((url) => ({ url })),
defaults: redisOptions,
});
keyvRedisClient.setMaxListeners(cacheConfig.REDIS_MAX_LISTENERS);
// Pinging the Redis server every 5 minutes to keep the connection alive
const keyvPingInterval = setInterval(() => keyvRedisClient.ping(), 5 * 60 * 1000);
keyvRedisClient.on('disconnect', () => clearInterval(keyvPingInterval));
keyvRedisClient.on('end', () => clearInterval(keyvPingInterval));
}
module.exports = { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR };

View File

@@ -70,9 +70,6 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
if (ephemeralAgent?.execute_code === true) {
tools.push(Tools.execute_code);
}
if (ephemeralAgent?.file_search === true) {
tools.push(Tools.file_search);
}
if (ephemeralAgent?.web_search === true) {
tools.push(Tools.web_search);
}
@@ -90,7 +87,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
}
const instructions = req.body.promptPrefix;
const result = {
return {
id: agent_id,
instructions,
provider: endpoint,
@@ -98,11 +95,6 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
model,
tools,
};
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
result.artifacts = ephemeralAgent.artifacts;
}
return result;
};
/**

View File

@@ -43,7 +43,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -413,7 +413,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -670,7 +670,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -1332,7 +1332,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -1514,7 +1514,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -1798,7 +1798,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -2350,7 +2350,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();

View File

@@ -1,6 +1,4 @@
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getMessages, deleteMessages } = require('./Message');
const { Conversation } = require('~/db/models');
@@ -100,15 +98,10 @@ module.exports = {
update.conversationId = newConversationId;
}
if (req?.body?.isTemporary) {
try {
const customConfig = await getCustomConfig();
update.expiredAt = createTempChatExpirationDate(customConfig);
} catch (err) {
logger.error('Error creating temporary chat expiration date:', err);
logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
update.expiredAt = null;
}
if (req.body.isTemporary) {
const expiredAt = new Date();
expiredAt.setDate(expiredAt.getDate() + 30);
update.expiredAt = expiredAt;
} else {
update.expiredAt = null;
}

View File

@@ -1,7 +1,5 @@
const { logger } = require('@librechat/data-schemas');
const { EToolResources, FileContext, Constants } = require('librechat-data-provider');
const { getProjectByName } = require('./Project');
const { getAgent } = require('./Agent');
const { EToolResources } = require('librechat-data-provider');
const { File } = require('~/db/models');
/**
@@ -14,124 +12,17 @@ const findFileById = async (file_id, options = {}) => {
return await File.findOne({ file_id, ...options }).lean();
};
/**
* Checks if a user has access to multiple files through a shared agent (batch operation)
* @param {string} userId - The user ID to check access for
* @param {string[]} fileIds - Array of file IDs to check
* @param {string} agentId - The agent ID that might grant access
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
*/
const hasAccessToFilesViaAgent = async (userId, fileIds, agentId, checkCollaborative = true) => {
const accessMap = new Map();
// Initialize all files as no access
fileIds.forEach((fileId) => accessMap.set(fileId, false));
try {
const agent = await getAgent({ id: agentId });
if (!agent) {
return accessMap;
}
// Check if user is the author - if so, grant access to all files
if (agent.author.toString() === userId) {
fileIds.forEach((fileId) => accessMap.set(fileId, true));
return accessMap;
}
// Check if agent is shared with the user via projects
if (!agent.projectIds || agent.projectIds.length === 0) {
return accessMap;
}
// Check if agent is in global project
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
if (
!globalProject ||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
) {
return accessMap;
}
// Agent is globally shared - check if it's collaborative
if (checkCollaborative && !agent.isCollaborative) {
return accessMap;
}
// Check which files are actually attached
const attachedFileIds = new Set();
if (agent.tool_resources) {
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
}
}
}
// Grant access only to files that are attached to this agent
fileIds.forEach((fileId) => {
if (attachedFileIds.has(fileId)) {
accessMap.set(fileId, true);
}
});
return accessMap;
} catch (error) {
logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
return accessMap;
}
};
/**
* Retrieves files matching a given filter, sorted by the most recently updated.
* @param {Object} filter - The filter criteria to apply.
* @param {Object} [_sortOptions] - Optional sort parameters.
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
* Default excludes the 'text' field.
* @param {Object} [options] - Additional options
* @param {string} [options.userId] - User ID for access control
* @param {string} [options.agentId] - Agent ID that might grant access to files
* @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
*/
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => {
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
const sortOptions = { updatedAt: -1, ..._sortOptions };
const files = await File.find(filter).select(selectFields).sort(sortOptions).lean();
// If userId and agentId are provided, filter files based on access
if (options.userId && options.agentId) {
// Collect file IDs that need access check
const filesToCheck = [];
const ownedFiles = [];
for (const file of files) {
if (file.user && file.user.toString() === options.userId) {
ownedFiles.push(file);
} else {
filesToCheck.push(file);
}
}
if (filesToCheck.length === 0) {
return ownedFiles;
}
// Batch check access for all non-owned files
const fileIds = filesToCheck.map((f) => f.file_id);
const accessMap = await hasAccessToFilesViaAgent(
options.userId,
fileIds,
options.agentId,
false,
);
// Filter files based on access
const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
return [...ownedFiles, ...accessibleFiles];
}
return files;
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
};
/**
@@ -141,19 +32,19 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, option
* @returns {Promise<Array<MongoFile>>} Files that match the criteria
*/
const getToolFilesByIds = async (fileIds, toolResourceSet) => {
if (!fileIds || !fileIds.length || !toolResourceSet?.size) {
if (!fileIds || !fileIds.length) {
return [];
}
try {
const filter = {
file_id: { $in: fileIds },
$or: [],
};
if (toolResourceSet.has(EToolResources.ocr)) {
filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
if (toolResourceSet.size) {
filter.$or = [];
}
if (toolResourceSet.has(EToolResources.file_search)) {
filter.$or.push({ embedded: true });
}
@@ -285,5 +176,4 @@ module.exports = {
deleteFiles,
deleteFileByFilter,
batchUpdateFiles,
hasAccessToFilesViaAgent,
};

View File

@@ -1,264 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { fileSchema } = require('@librechat/data-schemas');
const { agentSchema } = require('@librechat/data-schemas');
const { projectSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { getFiles, createFile } = require('./File');
const { getProjectByName } = require('./Project');
const { createAgent } = require('./Agent');
let File;
let Agent;
let Project;
describe('File Access Control', () => {
let mongoServer;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
File = mongoose.models.File || mongoose.model('File', fileSchema);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await File.deleteMany({});
await Agent.deleteMany({});
await Project.deleteMany({});
});
describe('hasAccessToFilesViaAgent', () => {
it('should efficiently check access for multiple files at once', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
// Create files
for (const fileId of fileIds) {
await createFile({
user: authorId,
file_id: fileId,
filename: `file-${fileId}.txt`,
filepath: `/uploads/${fileId}`,
});
}
// Create agent with only first two files attached
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileIds[0], fileIds[1]],
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for all files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have access only to the first two files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(false);
expect(accessMap.get(fileIds[3])).toBe(false);
});
it('should grant access to all files when user is the agent author', async () => {
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
// Create agent
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: [fileIds[0]], // Only one file attached
},
},
});
// Check access as the author
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);
// Author should have access to all files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(true);
});
it('should handle non-existent agent gracefully', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const fileIds = [uuidv4(), uuidv4()];
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');
// Should have no access to any files
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should deny access when agent is not collaborative', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4()];
// Create agent with files but isCollaborative: false
await createAgent({
id: agentId,
name: 'Non-Collaborative Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: false,
tool_resources: {
file_search: {
file_ids: fileIds,
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have no access to any files when isCollaborative is false
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
});
describe('getFiles with agent access control', () => {
test('should return files owned by user and files accessible through agent', async () => {
const authorId = new mongoose.Types.ObjectId();
const userId = new mongoose.Types.ObjectId();
const agentId = `agent_${uuidv4()}`;
const ownedFileId = `file_${uuidv4()}`;
const sharedFileId = `file_${uuidv4()}`;
const inaccessibleFileId = `file_${uuidv4()}`;
// Create/get global project using getProjectByName which will upsert
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);
// Create agent with shared file
await createAgent({
id: agentId,
name: 'Shared Agent',
provider: 'test',
model: 'test-model',
author: authorId,
projectIds: [globalProject._id],
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [sharedFileId],
},
},
});
// Create files
await createFile({
file_id: ownedFileId,
user: userId,
filename: 'owned.txt',
filepath: '/uploads/owned.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: sharedFileId,
user: authorId,
filename: 'shared.txt',
filepath: '/uploads/shared.txt',
type: 'text/plain',
bytes: 200,
embedded: true,
});
await createFile({
file_id: inaccessibleFileId,
user: authorId,
filename: 'inaccessible.txt',
filepath: '/uploads/inaccessible.txt',
type: 'text/plain',
bytes: 300,
});
// Get files with access control
const files = await getFiles(
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
null,
{ text: 0 },
{ userId: userId.toString(), agentId },
);
expect(files).toHaveLength(2);
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
});
test('should return all files when no userId/agentId provided', async () => {
const userId = new mongoose.Types.ObjectId();
const fileId1 = `file_${uuidv4()}`;
const fileId2 = `file_${uuidv4()}`;
await createFile({
file_id: fileId1,
user: userId,
filename: 'file1.txt',
filepath: '/uploads/file1.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: fileId2,
user: new mongoose.Types.ObjectId(),
filename: 'file2.txt',
filepath: '/uploads/file2.txt',
type: 'text/plain',
bytes: 200,
});
const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
expect(files).toHaveLength(2);
});
});
});

View File

@@ -1,7 +1,5 @@
const { z } = require('zod');
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { Message } = require('~/db/models');
const idSchema = z.string().uuid();
@@ -56,14 +54,9 @@ async function saveMessage(req, params, metadata) {
};
if (req?.body?.isTemporary) {
try {
const customConfig = await getCustomConfig();
update.expiredAt = createTempChatExpirationDate(customConfig);
} catch (err) {
logger.error('Error creating temporary chat expiration date:', err);
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
update.expiredAt = null;
}
const expiredAt = new Date();
expiredAt.setDate(expiredAt.getDate() + 30);
update.expiredAt = expiredAt;
} else {
update.expiredAt = null;
}

View File

@@ -135,11 +135,10 @@ const tokenValues = Object.assign(
'grok-2-1212': { prompt: 2.0, completion: 10.0 },
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
'grok-2': { prompt: 2.0, completion: 10.0 },
'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
'grok-3-mini': { prompt: 0.3, completion: 0.5 },
'grok-3-fast': { prompt: 5.0, completion: 25.0 },
'grok-3': { prompt: 3.0, completion: 15.0 },
'grok-4': { prompt: 3.0, completion: 15.0 },
'grok-beta': { prompt: 5.0, completion: 15.0 },
'mistral-large': { prompt: 2.0, completion: 6.0 },
'pixtral-large': { prompt: 2.0, completion: 6.0 },

View File

@@ -636,15 +636,6 @@ describe('Grok Model Tests - Pricing', () => {
);
});
test('should return correct prompt and completion rates for Grok 4 model', () => {
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
tokenValues['grok-3'].prompt,
@@ -671,15 +662,6 @@ describe('Grok Model Tests - Pricing', () => {
tokenValues['grok-3-mini-fast'].completion,
);
});
test('should return correct prompt and completion rates for Grok 4 model with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
});
});

View File

@@ -1,6 +1,6 @@
{
"name": "@librechat/backend",
"version": "v0.7.9-rc1",
"version": "v0.7.8",
"description": "",
"scripts": {
"start": "echo 'please run this from the root directory'",
@@ -44,20 +44,20 @@
"@googleapis/youtube": "^20.0.0",
"@keyv/redis": "^4.3.3",
"@langchain/community": "^0.3.47",
"@langchain/core": "^0.3.62",
"@langchain/core": "^0.3.60",
"@langchain/google-genai": "^0.2.13",
"@langchain/google-vertexai": "^0.2.13",
"@langchain/openai": "^0.5.18",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.63",
"@librechat/agents": "^2.4.41",
"@librechat/api": "*",
"@librechat/data-schemas": "*",
"@node-saml/passport-saml": "^5.0.0",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.8.2",
"bcryptjs": "^2.4.3",
"compression": "^1.8.1",
"connect-redis": "^8.1.0",
"cohere-ai": "^7.9.1",
"compression": "^1.7.4",
"connect-redis": "^7.1.0",
"cookie": "^0.7.2",
"cookie-parser": "^1.4.7",
"cors": "^2.8.5",
@@ -67,7 +67,7 @@
"express": "^4.21.2",
"express-mongo-sanitize": "^2.2.0",
"express-rate-limit": "^7.4.1",
"express-session": "^1.18.2",
"express-session": "^1.18.1",
"express-static-gzip": "^2.2.0",
"file-type": "^18.7.0",
"firebase": "^11.0.2",
@@ -88,7 +88,7 @@
"mime": "^3.0.0",
"module-alias": "^2.2.3",
"mongoose": "^8.12.1",
"multer": "^2.0.2",
"multer": "^2.0.1",
"nanoid": "^3.3.7",
"node-fetch": "^2.7.0",
"nodemailer": "^6.9.15",

View File

@@ -169,6 +169,9 @@ function disposeClient(client) {
client.isGenerativeModel = null;
}
// Properties specific to OpenAIClient
if (client.ChatGPTClient) {
client.ChatGPTClient = null;
}
if (client.completionsUrl) {
client.completionsUrl = null;
}

View File

@@ -0,0 +1,282 @@
const { getResponseSender, Constants } = require('librechat-data-provider');
const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const {
disposeClient,
processReqData,
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const AskController = async (req, res, next, initializeClient, addTitle) => {
let {
text,
endpointOption,
conversationId,
modelDisplayLabel,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
let client = null;
let abortKey = null;
let cleanupHandlers = [];
let clientRef = null;
logger.debug('[AskController]', {
text,
conversationId,
...endpointOption,
modelsConfig: endpointOption?.modelsConfig ? 'exists' : '',
});
let userMessage = null;
let userMessagePromise = null;
let promptTokens = null;
let userMessageId = null;
let responseMessageId = null;
let getAbortData = null;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
modelDisplayLabel,
});
const initialConversationId = conversationId;
const newConvo = !initialConversationId;
const userId = req.user.id;
let reqDataContext = {
userMessage,
userMessagePromise,
responseMessageId,
promptTokens,
conversationId,
userMessageId,
};
const updateReqData = (data = {}) => {
reqDataContext = processReqData(data, reqDataContext);
abortKey = reqDataContext.abortKey;
userMessage = reqDataContext.userMessage;
userMessagePromise = reqDataContext.userMessagePromise;
responseMessageId = reqDataContext.responseMessageId;
promptTokens = reqDataContext.promptTokens;
conversationId = reqDataContext.conversationId;
userMessageId = reqDataContext.userMessageId;
};
let { onProgress: progressCallback, getPartialText } = createOnProgress();
const performCleanup = () => {
logger.debug('[AskController] Performing cleanup');
if (Array.isArray(cleanupHandlers)) {
for (const handler of cleanupHandlers) {
try {
if (typeof handler === 'function') {
handler();
}
} catch (e) {
// Ignore
}
}
}
if (abortKey) {
logger.debug('[AskController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
if (client) {
disposeClient(client);
client = null;
}
reqDataContext = null;
userMessage = null;
userMessagePromise = null;
promptTokens = null;
getAbortData = null;
progressCallback = null;
endpointOption = null;
cleanupHandlers = null;
addTitle = null;
if (requestDataMap.has(req)) {
requestDataMap.delete(req);
}
logger.debug('[AskController] Cleanup completed');
};
try {
({ client } = await initializeClient({ req, res, endpointOption }));
if (clientRegistry && client) {
clientRegistry.register(client, { userId }, client);
}
if (client) {
requestDataMap.set(req, { client });
}
clientRef = new WeakRef(client);
getAbortData = () => {
const currentClient = clientRef?.deref();
const currentText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
return {
sender,
conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: currentText,
userMessage: userMessage,
userMessagePromise: userMessagePromise,
promptTokens: reqDataContext.promptTokens,
};
};
const { onStart, abortController } = createAbortController(
req,
res,
getAbortData,
updateReqData,
);
const closeHandler = () => {
logger.debug('[AskController] Request closed');
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
return;
}
abortController.abort();
logger.debug('[AskController] Request aborted on close');
};
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
// Ignore
}
});
const messageOptions = {
user: userId,
parentMessageId,
conversationId: reqDataContext.conversationId,
overrideParentMessageId,
getReqData: updateReqData,
onStart,
abortController,
progressCallback,
progressOptions: {
res,
},
};
/** @type {TMessage} */
let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint;
const databasePromise = response.databasePromise;
delete response.databasePromise;
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
const latestUserMessage = reqDataContext.userMessage;
if (client?.options?.attachments && latestUserMessage) {
latestUserMessage.files = client.options.attachments;
if (endpointOption?.modelOptions?.model) {
conversation.model = endpointOption.modelOptions.model;
}
delete latestUserMessage.image_urls;
}
if (!abortController.signal.aborted) {
const finalResponseMessage = { ...response };
sendMessage(res, {
final: true,
conversation,
title: conversation.title,
requestMessage: latestUserMessage,
responseMessage: finalResponseMessage,
});
res.end();
if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) {
await saveMessage(
req,
{ ...finalResponseMessage, user: userId },
{ context: 'api/server/controllers/AskController.js - response end' },
);
}
}
if (!client?.skipSaveUserMessage && latestUserMessage) {
await saveMessage(req, latestUserMessage, {
context: "api/server/controllers/AskController.js - don't skip saving user message",
});
}
if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {
text,
response: { ...response },
client,
})
.then(() => {
logger.debug('[AskController] Title generation started');
})
.catch((err) => {
logger.error('[AskController] Error in title generation', err);
})
.finally(() => {
logger.debug('[AskController] Title generation completed');
performCleanup();
});
} else {
performCleanup();
}
} catch (error) {
logger.error('[AskController] Error handling request', error);
let partialText = '';
try {
const currentClient = clientRef?.deref();
partialText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
} catch (getTextError) {
logger.error('[AskController] Error calling getText() during error handling', getTextError);
}
handleAbortError(res, req, error, {
sender,
partialText,
conversationId: reqDataContext.conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId,
userMessageId: reqDataContext.userMessageId,
})
.catch((err) => {
logger.error('[AskController] Error in `handleAbortError` during catch block', err);
})
.finally(() => {
performCleanup();
});
}
};
module.exports = AskController;

View File

@@ -1,17 +1,17 @@
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const openIdClient = require('openid-client');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
requestPasswordReset,
setOpenIDAuthTokens,
registerUser,
resetPassword,
setAuthTokens,
registerUser,
requestPasswordReset,
setOpenIDAuthTokens,
} = require('~/server/services/AuthService');
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
const { getOpenIdConfig } = require('~/strategies');
const { isEnabled } = require('~/server/utils');
const registrationController = async (req, res) => {
try {

View File

@@ -1,5 +1,3 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { getResponseSender } = require('librechat-data-provider');
const {
handleAbortError,
@@ -12,8 +10,9 @@ const {
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { createOnProgress } = require('~/server/utils');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const EditController = async (req, res, next, initializeClient) => {
let {
@@ -85,7 +84,7 @@ const EditController = async (req, res, next, initializeClient) => {
}
if (abortKey) {
logger.debug('[EditController] Cleaning up abort controller');
logger.debug('[AskController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
@@ -199,7 +198,7 @@ const EditController = async (req, res, next, initializeClient) => {
const finalUserMessage = reqDataContext.userMessage;
const finalResponseMessage = { ...response };
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
title: conversation.title,

View File

@@ -24,23 +24,17 @@ const handleValidationError = (err, res) => {
}
};
module.exports = (err, _req, res, _next) => {
// eslint-disable-next-line no-unused-vars
module.exports = (err, req, res, next) => {
try {
if (err.name === 'ValidationError') {
return handleValidationError(err, res);
return (err = handleValidationError(err, res));
}
if (err.code && err.code == 11000) {
return handleDuplicateKeyError(err, res);
return (err = handleDuplicateKeyError(err, res));
}
// Special handling for errors like SyntaxError
if (err.statusCode && err.body) {
return res.status(err.statusCode).send(err.body);
}
logger.error('ErrorController => error', err);
return res.status(500).send('An unknown error occurred.');
} catch (err) {
logger.error('ErrorController => processing error', err);
return res.status(500).send('Processing error in ErrorController.');
logger.error('ErrorController => error', err);
res.status(500).send('An unknown error occurred.');
}
};

View File

@@ -1,241 +0,0 @@
const errorController = require('./ErrorController');
const { logger } = require('~/config');
// Mock the logger
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
},
}));
describe('ErrorController', () => {
let mockReq, mockRes, mockNext;
beforeEach(() => {
mockReq = {};
mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn(),
};
mockNext = jest.fn();
logger.error.mockClear();
});
describe('ValidationError handling', () => {
it('should handle ValidationError with single error', () => {
const validationError = {
name: 'ValidationError',
errors: {
email: { message: 'Email is required', path: 'email' },
},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '["Email is required"]',
fields: '["email"]',
});
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
});
it('should handle ValidationError with multiple errors', () => {
const validationError = {
name: 'ValidationError',
errors: {
email: { message: 'Email is required', path: 'email' },
password: { message: 'Password is required', path: 'password' },
},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '"Email is required Password is required"',
fields: '["email","password"]',
});
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
});
it('should handle ValidationError with empty errors object', () => {
const validationError = {
name: 'ValidationError',
errors: {},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '[]',
fields: '[]',
});
});
});
describe('Duplicate key error handling', () => {
it('should handle duplicate key error (code 11000)', () => {
const duplicateKeyError = {
code: 11000,
keyValue: { email: 'test@example.com' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email"] already exists.',
fields: '["email"]',
});
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
});
it('should handle duplicate key error with multiple fields', () => {
const duplicateKeyError = {
code: 11000,
keyValue: { email: 'test@example.com', username: 'testuser' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email","username"] already exists.',
fields: '["email","username"]',
});
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
});
it('should handle error with code 11000 as string', () => {
const duplicateKeyError = {
code: '11000',
keyValue: { email: 'test@example.com' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email"] already exists.',
fields: '["email"]',
});
});
});
describe('SyntaxError handling', () => {
it('should handle errors with statusCode and body', () => {
const syntaxError = {
statusCode: 400,
body: 'Invalid JSON syntax',
};
errorController(syntaxError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
});
it('should handle errors with different statusCode and body', () => {
const customError = {
statusCode: 422,
body: { error: 'Unprocessable entity' },
};
errorController(customError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(422);
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
});
it('should handle error with statusCode but no body', () => {
const partialError = {
statusCode: 400,
};
errorController(partialError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
});
it('should handle error with body but no statusCode', () => {
const partialError = {
body: 'Some error message',
};
errorController(partialError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
});
});
describe('Unknown error handling', () => {
it('should handle unknown errors', () => {
const unknownError = new Error('Some unknown error');
errorController(unknownError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError);
});
it('should handle errors with code other than 11000', () => {
const mongoError = {
code: 11100,
message: 'Some MongoDB error',
};
errorController(mongoError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError);
});
it('should handle null/undefined errors', () => {
errorController(null, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
expect(logger.error).toHaveBeenCalledWith(
'ErrorController => processing error',
expect.any(Error),
);
});
});
describe('Catch block handling', () => {
beforeEach(() => {
// Restore logger mock to normal behavior for these tests
logger.error.mockRestore();
logger.error = jest.fn();
});
it('should handle errors when logger.error throws', () => {
// Create fresh mocks for this test
const freshMockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn(),
};
// Mock logger to throw on the first call, succeed on the second
logger.error
.mockImplementationOnce(() => {
throw new Error('Logger error');
})
.mockImplementation(() => {});
const testError = new Error('Test error');
errorController(testError, mockReq, freshMockRes, mockNext);
expect(freshMockRes.status).toHaveBeenCalledWith(500);
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
expect(logger.error).toHaveBeenCalledTimes(2);
});
});
});

View File

@@ -1,10 +1,11 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, AuthType, Constants } = require('librechat-data-provider');
const { CacheKeys, AuthType } = require('librechat-data-provider');
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
const { getToolkitKey } = require('~/server/services/ToolService');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { availableTools } = require('~/app/clients/tools');
const { getLogStores } = require('~/cache');
const { Constants } = require('librechat-data-provider');
/**
* Filters out duplicate plugins from the list of plugins.
@@ -139,9 +140,9 @@ function createGetServerTools() {
const getAvailableTools = async (req, res) => {
try {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
if (cachedToolsArray) {
res.status(200).json(cachedToolsArray);
const cachedTools = await cache.get(CacheKeys.TOOLS);
if (cachedTools) {
res.status(200).json(cachedTools);
return;
}
@@ -172,7 +173,7 @@ const getAvailableTools = async (req, res) => {
}
});
const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {};
const toolDefinitions = await getCachedTools({ includeGlobal: true });
const toolsOutput = [];
for (const plugin of authenticatedPlugins) {

View File

@@ -1,5 +1,11 @@
const {
Tools,
Constants,
FileSources,
webSearchKeys,
extractWebSearchEnvVars,
} = require('librechat-data-provider');
const { logger } = require('@librechat/data-schemas');
const { webSearchKeys, extractWebSearchEnvVars } = require('@librechat/api');
const {
getFiles,
updateUser,
@@ -14,7 +20,6 @@ const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/service
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { Tools, Constants, FileSources } = require('librechat-data-provider');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { Transaction, Balance, User } = require('~/db/models');
const { deleteToolCalls } = require('~/models/ToolCall');

View File

@@ -1,195 +0,0 @@
const { duplicateAgent } = require('../v1');
const { getAgent, createAgent } = require('~/models/Agent');
const { getActions } = require('~/models/Action');
const { nanoid } = require('nanoid');
jest.mock('~/models/Agent');
jest.mock('~/models/Action');
jest.mock('nanoid');
describe('duplicateAgent', () => {
let req, res;
beforeEach(() => {
req = {
params: { id: 'agent_123' },
user: { id: 'user_456' },
};
res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
};
jest.clearAllMocks();
});
it('should duplicate an agent successfully', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
author: 'user_789',
versions: [{ name: 'Test Agent', version: 1 }],
__v: 0,
};
const mockNewAgent = {
id: 'agent_new_123',
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
author: 'user_456',
versions: [
{
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
createdAt: new Date(),
updatedAt: new Date(),
},
],
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue(mockNewAgent);
await duplicateAgent(req, res);
expect(getAgent).toHaveBeenCalledWith({ id: 'agent_123' });
expect(getActions).toHaveBeenCalledWith({ agent_id: 'agent_123' }, true);
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
id: 'agent_new_123',
author: 'user_456',
name: expect.stringContaining('Test Agent ('),
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
}),
);
expect(createAgent).toHaveBeenCalledWith(
expect.not.objectContaining({
versions: expect.anything(),
__v: expect.anything(),
}),
);
expect(res.status).toHaveBeenCalledWith(201);
expect(res.json).toHaveBeenCalledWith({
agent: mockNewAgent,
actions: [],
});
});
it('should ensure duplicated agent has clean versions array without nested fields', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
description: 'Test Description',
versions: [
{
name: 'Test Agent',
versions: [{ name: 'Nested' }],
__v: 1,
},
],
__v: 2,
};
const mockNewAgent = {
id: 'agent_new_123',
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
versions: [
{
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
createdAt: new Date(),
updatedAt: new Date(),
},
],
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue(mockNewAgent);
await duplicateAgent(req, res);
expect(mockNewAgent.versions).toHaveLength(1);
const firstVersion = mockNewAgent.versions[0];
expect(firstVersion).not.toHaveProperty('versions');
expect(firstVersion).not.toHaveProperty('__v');
expect(mockNewAgent).not.toHaveProperty('__v');
expect(res.status).toHaveBeenCalledWith(201);
});
it('should return 404 if agent not found', async () => {
getAgent.mockResolvedValue(null);
await duplicateAgent(req, res);
expect(res.status).toHaveBeenCalledWith(404);
expect(res.json).toHaveBeenCalledWith({
error: 'Agent not found',
status: 'error',
});
});
it('should handle tool_resources.ocr correctly', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
tool_resources: {
ocr: { enabled: true, config: 'test' },
other: { should: 'not be copied' },
},
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue({ id: 'agent_new_123' });
await duplicateAgent(req, res);
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
tool_resources: {
ocr: { enabled: true, config: 'test' },
},
}),
);
});
it('should handle errors gracefully', async () => {
getAgent.mockRejectedValue(new Error('Database error'));
await duplicateAgent(req, res);
expect(res.status).toHaveBeenCalledWith(500);
expect(res.json).toHaveBeenCalledWith({ error: 'Database error' });
});
});

View File

@@ -1,22 +1,18 @@
require('events').EventEmitter.defaultMaxListeners = 100;
const { logger } = require('@librechat/data-schemas');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const {
sendEvent,
createRun,
Tokenizer,
checkAccess,
memoryInstructions,
formatContentStrings,
createMemoryProcessor,
} = require('@librechat/api');
const {
Callback,
Providers,
GraphEvents,
formatMessage,
formatAgentMessages,
formatContentStrings,
getTokenCountForMessage,
createMetadataAggregator,
} = require('@librechat/agents');
@@ -26,41 +22,31 @@ const {
VisionModes,
ContentTypes,
EModelEndpoint,
KnownEndpoints,
PermissionTypes,
isAgentsEndpoint,
AgentCapabilities,
bedrockInputSchema,
removeNullishValues,
} = require('librechat-data-provider');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const {
findPluginAuthsByKeys,
getFormattedMemories,
deleteMemory,
setMemory,
} = require('~/models');
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
getCustomEndpointConfig,
createGetMCPAuthMap,
checkCapability,
} = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { getProviderConfig } = require('~/server/services/Endpoints');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const { checkAccess } = require('~/server/middleware/roles/access');
const BaseClient = require('~/app/clients/BaseClient');
const { getRoleByName } = require('~/models/Role');
const { loadAgent } = require('~/models/Agent');
const { getMCPManager } = require('~/config');
const omitTitleOptions = new Set([
'stream',
'thinking',
'streaming',
'clientOptions',
'thinkingConfig',
'thinkingBudget',
'includeThoughts',
'maxOutputTokens',
'additionalModelRequestFields',
]);
/**
* @param {ServerRequest} req
* @param {Agent} agent
@@ -75,6 +61,8 @@ const payloadParser = ({ req, agent, endpoint }) => {
return req.body.endpointOption.model_parameters;
};
const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
function createTokenCounter(encoding) {
@@ -405,12 +393,7 @@ class AgentClient extends BaseClient {
if (user.personalization?.memories === false) {
return;
}
const hasAccess = await checkAccess({
user,
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE],
getRoleByName,
});
const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]);
if (!hasAccess) {
logger.debug(
@@ -455,12 +438,6 @@ class AgentClient extends BaseClient {
res: this.options.res,
agent: prelimAgent,
allowedProviders,
endpointOption: {
endpoint:
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
? EModelEndpoint.agents
: memoryConfig.agent?.provider,
},
});
if (!agent) {
@@ -534,10 +511,7 @@ class AgentClient extends BaseClient {
messagesToProcess = [...messages.slice(-messageWindowSize)];
}
}
const bufferString = getBufferString(messagesToProcess);
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
return await this.processMemory([bufferMessage]);
return await this.processMemory(messagesToProcess);
} catch (error) {
logger.error('Memory Agent failed to process memory', error);
}
@@ -703,18 +677,23 @@ class AgentClient extends BaseClient {
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
user: this.options.req.user,
},
recursionLimit: agentsEConfig?.recursionLimit ?? 25,
recursionLimit: agentsEConfig?.recursionLimit,
signal: abortController.signal,
streamMode: 'values',
version: 'v2',
};
const getUserMCPAuthMap = await createGetMCPAuthMap();
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
payload,
this.indexTokenCountMap,
toolSet,
);
if (legacyContentEndpoints.has(this.options.agent.endpoint?.toLowerCase())) {
initialMessages = formatContentStrings(initialMessages);
}
/**
*
@@ -778,9 +757,6 @@ class AgentClient extends BaseClient {
}
let messages = _messages;
if (agent.useLegacyContent === true) {
messages = formatContentStrings(messages);
}
if (
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
'prompt-caching',
@@ -829,11 +805,10 @@ class AgentClient extends BaseClient {
}
try {
if (await hasCustomUserVars()) {
config.configurable.userMCPAuthMap = await getMCPAuthMap({
if (getUserMCPAuthMap) {
config.configurable.userMCPAuthMap = await getUserMCPAuthMap({
tools: agent.tools,
userId: this.options.req.user.id,
findPluginAuthsByKeys,
});
}
} catch (err) {
@@ -1008,26 +983,23 @@ class AgentClient extends BaseClient {
throw new Error('Run not initialized');
}
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
const { req, res, agent } = this.options;
const endpoint = agent.endpoint;
const endpoint = this.options.agent.endpoint;
const { req, res } = this.options;
/** @type {import('@librechat/agents').ClientOptions} */
let clientOptions = {
maxTokens: 75,
model: agent.model_parameters.model,
};
const { getOptions, overrideProvider, customEndpointConfig } =
await getProviderConfig(endpoint);
/** @type {TEndpoint | undefined} */
const endpointConfig = req.app.locals[endpoint] ?? customEndpointConfig;
let endpointConfig = req.app.locals[endpoint];
if (!endpointConfig) {
logger.warn(
'[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config',
);
try {
endpointConfig = await getCustomEndpointConfig(endpoint);
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
err,
);
}
}
if (
endpointConfig &&
endpointConfig.titleModel &&
@@ -1035,56 +1007,30 @@ class AgentClient extends BaseClient {
) {
clientOptions.model = endpointConfig.titleModel;
}
const options = await getOptions({
req,
res,
optionsOnly: true,
overrideEndpoint: endpoint,
overrideModel: clientOptions.model,
endpointOption: { model_parameters: clientOptions },
});
let provider = options.provider ?? overrideProvider ?? agent.provider;
if (
endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName == null
clientOptions.model &&
this.options.agent.model_parameters.model !== clientOptions.model
) {
provider = Providers.OPENAI;
} else if (
endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName != null &&
provider !== Providers.AZURE
) {
provider = Providers.AZURE;
clientOptions =
(
await initOpenAI({
req,
res,
optionsOnly: true,
overrideModel: clientOptions.model,
overrideEndpoint: endpoint,
endpointOption: {
model_parameters: clientOptions,
},
})
)?.llmConfig ?? clientOptions;
}
/** @type {import('@librechat/agents').ClientOptions} */
clientOptions = { ...options.llmConfig };
if (options.configOptions) {
clientOptions.configuration = options.configOptions;
}
// Ensure maxTokens is set for non-o1 models
if (!/\b(o\d)\b/i.test(clientOptions.model) && !clientOptions.maxTokens) {
clientOptions.maxTokens = 75;
} else if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
delete clientOptions.maxTokens;
}
clientOptions = Object.assign(
Object.fromEntries(
Object.entries(clientOptions).filter(([key]) => !omitTitleOptions.has(key)),
),
);
if (provider === Providers.GOOGLE) {
clientOptions.json = true;
}
try {
const titleResult = await this.run.generateTitle({
provider,
inputText: text,
contentParts: this.contentParts,
clientOptions,
@@ -1102,10 +1048,8 @@ class AgentClient extends BaseClient {
let input_tokens, output_tokens;
if (item.usage) {
input_tokens =
item.usage.prompt_tokens || item.usage.input_tokens || item.usage.inputTokens;
output_tokens =
item.usage.completion_tokens || item.usage.output_tokens || item.usage.outputTokens;
input_tokens = item.usage.input_tokens || item.usage.inputTokens;
output_tokens = item.usage.output_tokens || item.usage.outputTokens;
} else if (item.tokenUsage) {
input_tokens = item.tokenUsage.promptTokens;
output_tokens = item.tokenUsage.completionTokens;
@@ -1135,52 +1079,8 @@ class AgentClient extends BaseClient {
}
}
/**
* @param {object} params
* @param {number} params.promptTokens
* @param {number} params.completionTokens
* @param {OpenAIUsageMetadata} [params.usage]
* @param {string} [params.model]
* @param {string} [params.context='message']
* @returns {Promise<void>}
*/
async recordTokenUsage({ model, promptTokens, completionTokens, usage, context = 'message' }) {
try {
await spendTokens(
{
model,
context,
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ promptTokens, completionTokens },
);
if (
usage &&
typeof usage === 'object' &&
'reasoning_tokens' in usage &&
typeof usage.reasoning_tokens === 'number'
) {
await spendTokens(
{
model,
context: 'reasoning',
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ completionTokens: usage.reasoning_tokens },
);
}
} catch (error) {
logger.error(
'[api/server/controllers/agents/client.js #recordTokenUsage] Error recording token usage',
error,
);
}
}
/** Silent method, as `recordCollectedUsage` is used instead */
async recordTokenUsage() {}
getEncoding() {
return 'o200k_base';

View File

@@ -1,10 +1,10 @@
// errorHandler.js
const { logger } = require('@librechat/data-schemas');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { sendResponse } = require('~/server/middleware/error');
const { recordUsage } = require('~/server/services/Threads');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendResponse } = require('~/server/utils');
/**
* @typedef {Object} ErrorHandlerContext
@@ -75,7 +75,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);

View File

@@ -0,0 +1,106 @@
const { HttpsProxyAgent } = require('https-proxy-agent');
const { resolveHeaders } = require('librechat-data-provider');
const { createLLM } = require('~/app/clients/llm');
/**
* Initializes and returns a Language Learning Model (LLM) instance.
*
* @param {Object} options - Configuration options for the LLM.
* @param {string} options.model - The model identifier.
* @param {string} options.modelName - The specific name of the model.
* @param {number} options.temperature - The temperature setting for the model.
* @param {number} options.presence_penalty - The presence penalty for the model.
* @param {number} options.frequency_penalty - The frequency penalty for the model.
* @param {number} options.max_tokens - The maximum number of tokens for the model output.
* @param {boolean} options.streaming - Whether to use streaming for the model output.
* @param {Object} options.context - The context for the conversation.
* @param {number} options.tokenBuffer - The token buffer size.
* @param {number} options.initialMessageCount - The initial message count.
* @param {string} options.conversationId - The ID of the conversation.
* @param {string} options.user - The user identifier.
* @param {string} options.langchainProxy - The langchain proxy URL.
* @param {boolean} options.useOpenRouter - Whether to use OpenRouter.
* @param {Object} options.options - Additional options.
* @param {Object} options.options.headers - Custom headers for the request.
* @param {string} options.options.proxy - Proxy URL.
* @param {Object} options.options.req - The request object.
* @param {Object} options.options.res - The response object.
* @param {boolean} options.options.debug - Whether to enable debug mode.
* @param {string} options.apiKey - The API key for authentication.
* @param {Object} options.azure - Azure-specific configuration.
* @param {Object} options.abortController - The AbortController instance.
* @returns {Object} The initialized LLM instance.
*/
function initializeLLM(options) {
const {
model,
modelName,
temperature,
presence_penalty,
frequency_penalty,
max_tokens,
streaming,
user,
langchainProxy,
useOpenRouter,
options: { headers, proxy },
apiKey,
azure,
} = options;
const modelOptions = {
modelName: modelName || model,
temperature,
presence_penalty,
frequency_penalty,
user,
};
if (max_tokens) {
modelOptions.max_tokens = max_tokens;
}
const configOptions = {};
if (langchainProxy) {
configOptions.basePath = langchainProxy;
}
if (useOpenRouter) {
configOptions.basePath = 'https://openrouter.ai/api/v1';
configOptions.baseOptions = {
headers: {
'HTTP-Referer': 'https://librechat.ai',
'X-Title': 'LibreChat',
},
};
}
if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
configOptions.baseOptions = {
headers: resolveHeaders({
...headers,
...configOptions?.baseOptions?.headers,
}),
};
}
if (proxy) {
configOptions.httpAgent = new HttpsProxyAgent(proxy);
configOptions.httpsAgent = new HttpsProxyAgent(proxy);
}
const llm = createLLM({
modelOptions,
configOptions,
openAIApiKey: apiKey,
azure,
streaming,
});
return llm;
}
module.exports = {
initializeLLM,
};

View File

@@ -1,5 +1,3 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { Constants } = require('librechat-data-provider');
const {
handleAbortError,
@@ -7,19 +5,17 @@ const {
cleanupAbortController,
} = require('~/server/middleware');
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { sendMessage } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const AgentController = async (req, res, next, initializeClient, addTitle) => {
let {
text,
isRegenerate,
endpointOption,
conversationId,
isContinued = false,
editedContent = null,
parentMessageId = null,
overrideParentMessageId = null,
responseMessageId: editedResponseMessageId = null,
} = req.body;
let sender;
@@ -71,7 +67,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
handler();
}
} catch (e) {
logger.error('[AgentController] Error in cleanup handler', e);
// Ignore cleanup errors
}
}
}
@@ -159,7 +155,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
logger.error('[AgentController] Error removing close listener', e);
// Ignore
}
});
@@ -167,15 +163,10 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
user: userId,
onStart,
getReqData,
isContinued,
isRegenerate,
editedContent,
conversationId,
parentMessageId,
abortController,
overrideParentMessageId,
isEdited: !!editedContent,
responseMessageId: editedResponseMessageId,
progressOptions: {
res,
},
@@ -215,7 +206,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
// Create a new response object with minimal copies
const finalResponse = { ...response };
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
title: conversation.title,

View File

@@ -1,8 +1,6 @@
const { z } = require('zod');
const fs = require('fs').promises;
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
const {
Tools,
Constants,
@@ -10,7 +8,6 @@ const {
SystemRoles,
EToolResources,
actionDelimiter,
removeNullishValues,
} = require('librechat-data-provider');
const {
getAgent,
@@ -33,7 +30,6 @@ const { deleteFileByFilter } = require('~/models/File');
const systemTools = {
[Tools.execute_code]: true,
[Tools.file_search]: true,
[Tools.web_search]: true,
};
/**
@@ -46,13 +42,9 @@ const systemTools = {
*/
const createAgentHandler = async (req, res) => {
try {
const validatedData = agentCreateSchema.parse(req.body);
const { tools = [], ...agentData } = removeNullishValues(validatedData);
const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body;
const { id: userId } = req.user;
agentData.id = `agent_${nanoid()}`;
agentData.author = userId;
agentData.tools = [];
const availableTools = await getCachedTools({ includeGlobal: true });
@@ -66,13 +58,19 @@ const createAgentHandler = async (req, res) => {
}
}
Object.assign(agentData, {
author: userId,
name,
description,
instructions,
provider,
model,
});
agentData.id = `agent_${nanoid()}`;
const agent = await createAgent(agentData);
res.status(201).json(agent);
} catch (error) {
if (error instanceof z.ZodError) {
logger.error('[/Agents] Validation error', error.errors);
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
}
logger.error('[/Agents] Error creating agent', error);
res.status(500).json({ error: error.message });
}
@@ -156,16 +154,14 @@ const getAgentHandler = async (req, res) => {
const updateAgentHandler = async (req, res) => {
try {
const id = req.params.id;
const validatedData = agentUpdateSchema.parse(req.body);
const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
const { projectIds, removeProjectIds, ...updateData } = req.body;
const isAdmin = req.user.role === SystemRoles.ADMIN;
const existingAgent = await getAgent({ id });
const isAuthor = existingAgent.author.toString() === req.user.id;
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
const isAuthor = existingAgent.author.toString() === req.user.id;
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) {
@@ -204,11 +200,6 @@ const updateAgentHandler = async (req, res) => {
return res.json(updatedAgent);
} catch (error) {
if (error instanceof z.ZodError) {
logger.error('[/Agents/:id] Validation error', error.errors);
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
}
logger.error('[/Agents/:id] Error updating Agent', error);
if (error.statusCode === 409) {
@@ -251,8 +242,6 @@ const duplicateAgentHandler = async (req, res) => {
createdAt: _createdAt,
updatedAt: _updatedAt,
tool_resources: _tool_resources = {},
versions: _versions,
__v: _v,
...cloneData
} = agent;
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
@@ -391,22 +380,6 @@ const uploadAgentAvatarHandler = async (req, res) => {
return res.status(400).json({ message: 'Agent ID is required' });
}
const isAdmin = req.user.role === SystemRoles.ADMIN;
const existingAgent = await getAgent({ id: agent_id });
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
const isAuthor = existingAgent.author.toString() === req.user.id;
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) {
return res.status(403).json({
error: 'You do not have permission to modify this non-collaborative agent',
});
}
const buffer = await fs.readFile(req.file.path);
const fileStrategy = req.app.locals.fileStrategy;
@@ -429,7 +402,14 @@ const uploadAgentAvatarHandler = async (req, res) => {
source: fileStrategy,
};
let _avatar = existingAgent.avatar;
let _avatar;
try {
const agent = await getAgent({ id: agent_id });
_avatar = agent.avatar;
} catch (error) {
logger.error('[/:agent_id/avatar] Error fetching agent', error);
_avatar = {};
}
if (_avatar && _avatar.source) {
const { deleteFile } = getStrategyFunctions(_avatar.source);
@@ -451,7 +431,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
};
promises.push(
await updateAgent({ id: agent_id }, data, {
await updateAgent({ id: agent_id, author: req.user.id }, data, {
updatingUserId: req.user.id,
}),
);

View File

@@ -1,659 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { agentSchema } = require('@librechat/data-schemas');
// Only mock the dependencies that are not database-related
jest.mock('~/server/services/Config', () => ({
getCachedTools: jest.fn().mockResolvedValue({
web_search: true,
execute_code: true,
file_search: true,
}),
}));
jest.mock('~/models/Project', () => ({
getProjectByName: jest.fn().mockResolvedValue(null),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(),
}));
jest.mock('~/server/services/Files/images/avatar', () => ({
resizeAvatar: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
refreshS3Url: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
filterFile: jest.fn(),
}));
jest.mock('~/models/Action', () => ({
updateAction: jest.fn(),
getActions: jest.fn().mockResolvedValue([]),
}));
jest.mock('~/models/File', () => ({
deleteFileByFilter: jest.fn(),
}));
const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');
/**
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
*/
let Agent;
describe('Agent Controllers - Mass Assignment Protection', () => {
let mongoServer;
let mockReq;
let mockRes;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
}, 20000);
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await Agent.deleteMany({});
// Reset all mocks
jest.clearAllMocks();
// Setup mock request and response objects
mockReq = {
user: {
id: new mongoose.Types.ObjectId().toString(),
role: 'USER',
},
body: {},
params: {},
app: {
locals: {
fileStrategy: 'local',
},
},
};
mockRes = {
status: jest.fn().mockReturnThis(),
json: jest.fn().mockReturnThis(),
};
});
describe('createAgentHandler', () => {
test('should create agent with allowed fields only', async () => {
const validData = {
name: 'Test Agent',
description: 'A test agent',
instructions: 'Be helpful',
provider: 'openai',
model: 'gpt-4',
tools: ['web_search'],
model_parameters: { temperature: 0.7 },
tool_resources: {
file_search: { file_ids: ['file1', 'file2'] },
},
};
mockReq.body = validData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
expect(mockRes.json).toHaveBeenCalled();
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.name).toBe('Test Agent');
expect(createdAgent.description).toBe('A test agent');
expect(createdAgent.provider).toBe('openai');
expect(createdAgent.model).toBe('gpt-4');
expect(createdAgent.author.toString()).toBe(mockReq.user.id);
expect(createdAgent.tools).toContain('web_search');
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb).toBeDefined();
expect(agentInDb.name).toBe('Test Agent');
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
});
test('should reject creation with unauthorized fields (mass assignment protection)', async () => {
const maliciousData = {
// Required fields
provider: 'openai',
model: 'gpt-4',
name: 'Malicious Agent',
// Unauthorized fields that should be stripped
author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
authorName: 'Hacker', // Should be stripped
isCollaborative: true, // Should be stripped on creation
versions: [], // Should be stripped
_id: new mongoose.Types.ObjectId(), // Should be stripped
id: 'custom_agent_id', // Should be overridden
createdAt: new Date('2020-01-01'), // Should be stripped
updatedAt: new Date('2020-01-01'), // Should be stripped
};
mockReq.body = maliciousData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify unauthorized fields were not set
expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
expect(createdAgent.authorName).toBeUndefined();
expect(createdAgent.isCollaborative).toBeFalsy();
expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix
// Verify timestamps are recent (not the malicious dates)
const createdTime = new Date(createdAgent.createdAt).getTime();
const now = Date.now();
expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
expect(agentInDb.authorName).toBeUndefined();
});
test('should validate required fields', async () => {
const invalidData = {
name: 'Missing Required Fields',
// Missing provider and model
};
mockReq.body = invalidData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
details: expect.any(Array),
}),
);
// Verify nothing was created in database
const count = await Agent.countDocuments();
expect(count).toBe(0);
});
test('should handle tool_resources validation', async () => {
const dataWithInvalidToolResources = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Tool Resources',
tool_resources: {
// Valid resources
file_search: {
file_ids: ['file1', 'file2'],
vector_store_ids: ['vs1'],
},
execute_code: {
file_ids: ['file3'],
},
// Invalid resource (should be stripped by schema)
invalid_resource: {
file_ids: ['file4'],
},
},
};
mockReq.body = dataWithInvalidToolResources;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.tool_resources).toBeDefined();
expect(createdAgent.tool_resources.file_search).toBeDefined();
expect(createdAgent.tool_resources.execute_code).toBeDefined();
expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
});
test('should handle avatar validation', async () => {
const dataWithAvatar = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Avatar',
avatar: {
filepath: 'https://example.com/avatar.png',
source: 's3',
},
};
mockReq.body = dataWithAvatar;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.avatar).toEqual({
filepath: 'https://example.com/avatar.png',
source: 's3',
});
});
test('should handle invalid avatar format', async () => {
const dataWithInvalidAvatar = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Invalid Avatar',
avatar: 'just-a-string', // Invalid format
};
mockReq.body = dataWithInvalidAvatar;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
}),
);
});
});
describe('updateAgentHandler', () => {
let existingAgentId;
let existingAgentAuthorId;
beforeEach(async () => {
// Create an existing agent for update tests
existingAgentAuthorId = new mongoose.Types.ObjectId();
const agent = await Agent.create({
id: `agent_${uuidv4()}`,
name: 'Original Agent',
provider: 'openai',
model: 'gpt-3.5-turbo',
author: existingAgentAuthorId,
description: 'Original description',
isCollaborative: false,
versions: [
{
name: 'Original Agent',
provider: 'openai',
model: 'gpt-3.5-turbo',
description: 'Original description',
createdAt: new Date(),
updatedAt: new Date(),
},
],
});
existingAgentId = agent.id;
});
test('should update agent with allowed fields only', async () => {
mockReq.user.id = existingAgentAuthorId.toString(); // Set as author
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Updated Agent',
description: 'Updated description',
model: 'gpt-4',
isCollaborative: true, // This IS allowed in updates
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(400);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Updated Agent');
expect(updatedAgent.description).toBe('Updated description');
expect(updatedAgent.model).toBe('gpt-4');
expect(updatedAgent.isCollaborative).toBe(true);
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Updated Agent');
expect(agentInDb.isCollaborative).toBe(true);
});
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Updated Name',
// Unauthorized fields that should be stripped
author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author
authorName: 'Hacker', // Should be stripped
id: 'different_agent_id', // Should be stripped
_id: new mongoose.Types.ObjectId(), // Should be stripped
versions: [], // Should be stripped
createdAt: new Date('2020-01-01'), // Should be stripped
updatedAt: new Date('2020-01-01'), // Should be stripped
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
// Verify unauthorized fields were not changed
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed
expect(updatedAgent.authorName).toBeUndefined();
expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed
expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString());
expect(agentInDb.id).toBe(existingAgentId);
});
test('should reject update from non-author when not collaborative', async () => {
const differentUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = differentUserId; // Different user
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Unauthorized Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalledWith({
error: 'You do not have permission to modify this non-collaborative agent',
});
// Verify agent was not modified in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Original Agent');
});
test('should allow update from non-author when collaborative', async () => {
// First make the agent collaborative
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
const differentUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = differentUserId; // Different user
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Collaborative Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Collaborative Update');
// Author field should be removed for non-author
expect(updatedAgent.author).toBeUndefined();
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Collaborative Update');
});
test('should allow admin to update any agent', async () => {
const adminUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = adminUserId;
mockReq.user.role = 'ADMIN'; // Set as admin
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Admin Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Admin Update');
});
test('should handle projectIds updates', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
const projectId1 = new mongoose.Types.ObjectId().toString();
const projectId2 = new mongoose.Types.ObjectId().toString();
mockReq.body = {
projectIds: [projectId1, projectId2],
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent).toBeDefined();
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
});
test('should validate tool_resources in updates', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
tool_resources: {
ocr: {
file_ids: ['ocr1', 'ocr2'],
},
execute_code: {
file_ids: ['img1'],
},
// Invalid tool resource
invalid_tool: {
file_ids: ['invalid'],
},
},
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.tool_resources).toBeDefined();
expect(updatedAgent.tool_resources.ocr).toBeDefined();
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
});
test('should return 404 for non-existent agent', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
mockReq.body = {
name: 'Update Non-existent',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(404);
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
});
test('should handle validation errors properly', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
model_parameters: 'invalid-not-an-object', // Should be an object
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
details: expect.any(Array),
}),
);
});
});
describe('Mass Assignment Attack Scenarios', () => {
test('should prevent setting system fields during creation', async () => {
const systemFields = {
provider: 'openai',
model: 'gpt-4',
name: 'System Fields Test',
// System fields that should never be settable by users
__v: 99,
_id: new mongoose.Types.ObjectId(),
versions: [
{
name: 'Fake Version',
provider: 'fake',
model: 'fake-model',
},
],
};
mockReq.body = systemFields;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify system fields were not affected
expect(createdAgent.__v).not.toBe(99);
expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version
expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation
expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.__v).not.toBe(99);
});
test('should prevent privilege escalation through isCollaborative', async () => {
// Create a non-collaborative agent
const authorId = new mongoose.Types.ObjectId();
const agent = await Agent.create({
id: `agent_${uuidv4()}`,
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
isCollaborative: false,
versions: [
{
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
createdAt: new Date(),
updatedAt: new Date(),
},
],
});
// Try to make it collaborative as a different user
const attackerId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = attackerId;
mockReq.params.id = agent.id;
mockReq.body = {
isCollaborative: true, // Trying to escalate privileges
};
await updateAgentHandler(mockReq, mockRes);
// Should be rejected
expect(mockRes.status).toHaveBeenCalledWith(403);
// Verify in database that it's still not collaborative
const agentInDb = await Agent.findOne({ id: agent.id });
expect(agentInDb.isCollaborative).toBe(false);
});
test('should prevent author hijacking', async () => {
const originalAuthorId = new mongoose.Types.ObjectId();
const attackerId = new mongoose.Types.ObjectId();
// Admin creates an agent
mockReq.user.id = originalAuthorId.toString();
mockReq.user.role = 'ADMIN';
mockReq.body = {
provider: 'openai',
model: 'gpt-4',
name: 'Admin Agent',
author: attackerId.toString(), // Trying to set different author
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Author should be the actual user, not the attempted value
expect(createdAgent.author.toString()).toBe(originalAuthorId.toString());
expect(createdAgent.author.toString()).not.toBe(attackerId.toString());
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.author.toString()).toBe(originalAuthorId.toString());
});
test('should strip unknown fields to prevent future vulnerabilities', async () => {
mockReq.body = {
provider: 'openai',
model: 'gpt-4',
name: 'Future Proof Test',
// Unknown fields that might be added in future
superAdminAccess: true,
bypassAllChecks: true,
internalFlag: 'secret',
futureFeature: 'exploit',
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify unknown fields were stripped
expect(createdAgent.superAdminAccess).toBeUndefined();
expect(createdAgent.bypassAllChecks).toBeUndefined();
expect(createdAgent.internalFlag).toBeUndefined();
expect(createdAgent.futureFeature).toBeUndefined();
// Also check in database
const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean();
expect(agentInDb.superAdminAccess).toBeUndefined();
expect(agentInDb.bypassAllChecks).toBeUndefined();
expect(agentInDb.internalFlag).toBeUndefined();
expect(agentInDb.futureFeature).toBeUndefined();
});
});
});

View File

@@ -1,7 +1,4 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Time,
Constants,
@@ -22,20 +19,20 @@ const {
addThreadMetadata,
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { createRunBody } = require('~/server/services/createRunBody');
const { sendResponse } = require('~/server/middleware/error');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
/**
* @route POST /
@@ -474,7 +471,7 @@ const chatV1 = async (req, res) => {
await Promise.all(promises);
const sendInitialResponse = () => {
sendEvent(res, {
sendMessage(res, {
sync: true,
conversationId,
// messages: previousMessages,
@@ -590,7 +587,7 @@ const chatV1 = async (req, res) => {
iconURL: endpointOption.iconURL,
};
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
requestMessage: {

View File

@@ -1,7 +1,4 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Time,
Constants,
@@ -25,14 +22,15 @@ const { createErrorHandler } = require('~/server/controllers/assistants/errors')
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { sendMessage, sleep, countTokens } = require('~/server/utils');
const { createRunBody } = require('~/server/services/createRunBody');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
/**
* @route POST /
@@ -311,7 +309,7 @@ const chatV2 = async (req, res) => {
await Promise.all(promises);
const sendInitialResponse = () => {
sendEvent(res, {
sendMessage(res, {
sync: true,
conversationId,
// messages: previousMessages,
@@ -434,7 +432,7 @@ const chatV2 = async (req, res) => {
iconURL: endpointOption.iconURL,
};
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
requestMessage: {

View File

@@ -1,10 +1,10 @@
// errorHandler.js
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
const { sendResponse } = require('~/server/middleware/error');
const { getConvo } = require('~/models/Conversation');
const { sendResponse } = require('~/server/utils');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { getConvo } = require('~/models/Conversation');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
/**
* @typedef {Object} ErrorHandlerContext
@@ -78,7 +78,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);

View File

@@ -1,21 +1,21 @@
const { nanoid } = require('nanoid');
const { EnvVar } = require('@librechat/agents');
const { logger } = require('@librechat/data-schemas');
const { checkAccess, loadWebSearchAuth } = require('@librechat/api');
const {
Tools,
AuthType,
Permissions,
ToolCallTypes,
PermissionTypes,
loadWebSearchAuth,
} = require('librechat-data-provider');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { processCodeOutput } = require('~/server/services/Files/Code/process');
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { loadTools } = require('~/app/clients/tools/util');
const { getRoleByName } = require('~/models/Role');
const { checkAccess } = require('~/server/middleware');
const { getMessage } = require('~/models/Message');
const { logger } = require('~/config');
const fieldsMap = {
[Tools.execute_code]: [EnvVar.CODE_API_KEY],
@@ -79,7 +79,6 @@ const verifyToolAuth = async (req, res) => {
throwError: false,
});
} catch (error) {
logger.error('Error loading auth values', error);
res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED });
return;
}
@@ -133,12 +132,7 @@ const callTool = async (req, res) => {
logger.debug(`[${toolId}/call] User: ${req.user.id}`);
let hasAccess = true;
if (toolAccessPermType[toolId]) {
hasAccess = await checkAccess({
user: req.user,
permissionType: toolAccessPermType[toolId],
permissions: [Permissions.USE],
getRoleByName,
});
hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]);
}
if (!hasAccess) {
logger.warn(

View File

@@ -55,6 +55,7 @@ const startServer = async () => {
/* Middleware */
app.use(noIndex);
app.use(errorController);
app.use(express.json({ limit: '3mb' }));
app.use(express.urlencoded({ extended: true, limit: '3mb' }));
app.use(mongoSanitize());
@@ -96,6 +97,7 @@ const startServer = async () => {
app.use('/api/actions', routes.actions);
app.use('/api/keys', routes.keys);
app.use('/api/user', routes.user);
app.use('/api/ask', routes.ask);
app.use('/api/search', routes.search);
app.use('/api/edit', routes.edit);
app.use('/api/messages', routes.messages);
@@ -116,13 +118,11 @@ const startServer = async () => {
app.use('/api/roles', routes.roles);
app.use('/api/agents', routes.agents);
app.use('/api/banner', routes.banner);
app.use('/api/bedrock', routes.bedrock);
app.use('/api/memories', routes.memories);
app.use('/api/tags', routes.tags);
app.use('/api/mcp', routes.mcp);
// Add the error controller one more time after all routes
app.use(errorController);
app.use((req, res) => {
res.set({
'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',

View File

@@ -1,4 +1,5 @@
const fs = require('fs');
const path = require('path');
const request = require('supertest');
const { MongoMemoryServer } = require('mongodb-memory-server');
const mongoose = require('mongoose');
@@ -58,30 +59,6 @@ describe('Server Configuration', () => {
expect(response.headers['pragma']).toBe('no-cache');
expect(response.headers['expires']).toBe('0');
});
it('should return 500 for unknown errors via ErrorController', async () => {
// Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
// Mock MongoDB operations to fail
const originalFindOne = mongoose.models.User.findOne;
const mockError = new Error('MongoDB operation failed');
mongoose.models.User.findOne = jest.fn().mockImplementation(() => {
throw mockError;
});
try {
const response = await request(app).post('/api/auth/login').send({
email: 'test@example.com',
password: 'password123',
});
expect(response.status).toBe(500);
expect(response.text).toBe('An unknown error occurred.');
} finally {
// Restore original function
mongoose.models.User.findOne = originalFindOne;
}
});
});
// Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely

View File

@@ -1,13 +1,13 @@
const { logger } = require('@librechat/data-schemas');
const { countTokens, isEnabled, sendEvent } = require('@librechat/api');
// abortMiddleware.js
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const clearPendingReq = require('~/cache/clearPendingReq');
const { sendError } = require('~/server/middleware/error');
const { spendTokens } = require('~/models/spendTokens');
const abortControllers = require('./abortControllers');
const { saveMessage, getConvo } = require('~/models');
const { abortRun } = require('./abortRun');
const { logger } = require('~/config');
const abortDataMap = new WeakMap();
@@ -101,7 +101,7 @@ async function abortMessage(req, res) {
cleanupAbortController(abortKey);
if (res.headersSent && finalEvent) {
return sendEvent(res, finalEvent);
return sendMessage(res, finalEvent);
}
res.setHeader('Content-Type', 'application/json');
@@ -174,7 +174,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
* @param {string} responseMessageId
*/
const onStart = (userMessage, responseMessageId) => {
sendEvent(res, { message: userMessage, created: true });
sendMessage(res, { message: userMessage, created: true });
const abortKey = userMessage?.conversationId ?? req.user.id;
getReqData({ abortKey });

View File

@@ -1,11 +1,11 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
const { deleteMessages } = require('~/models/Message');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendMessage } = require('~/server/utils');
const { logger } = require('~/config');
const three_minutes = 1000 * 60 * 3;
@@ -34,7 +34,7 @@ async function abortRun(req, res) {
const [thread_id, run_id] = runValues.split(':');
if (!run_id) {
logger.warn("[abortRun] Couldn't find run for cancel request", { thread_id });
logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id });
return res.status(204).send({ message: 'Run not found' });
} else if (run_id === 'cancelled') {
logger.warn('[abortRun] Run already cancelled', { thread_id });
@@ -93,7 +93,7 @@ async function abortRun(req, res) {
};
if (res.headersSent && finalEvent) {
return sendEvent(res, finalEvent);
return sendMessage(res, finalEvent);
}
res.json(finalEvent);

View File

@@ -1,13 +1,13 @@
const { handleError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
EndpointURLs,
parseCompactConvo,
EModelEndpoint,
isAgentsEndpoint,
parseCompactConvo,
EndpointURLs,
} = require('librechat-data-provider');
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
const { getModelsConfig } = require('~/server/controllers/ModelController');
const assistants = require('~/server/services/Endpoints/assistants');
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
const { processFiles } = require('~/server/services/Files/process');
const anthropic = require('~/server/services/Endpoints/anthropic');
const bedrock = require('~/server/services/Endpoints/bedrock');
@@ -15,6 +15,7 @@ const openAI = require('~/server/services/Endpoints/openAI');
const agents = require('~/server/services/Endpoints/agents');
const custom = require('~/server/services/Endpoints/custom');
const google = require('~/server/services/Endpoints/google');
const { handleError } = require('~/server/utils');
const buildFunction = {
[EModelEndpoint.openAI]: openAI.buildOptions,
@@ -24,6 +25,7 @@ const buildFunction = {
[EModelEndpoint.bedrock]: bedrock.buildOptions,
[EModelEndpoint.azureOpenAI]: openAI.buildOptions,
[EModelEndpoint.anthropic]: anthropic.buildOptions,
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
[EModelEndpoint.assistants]: assistants.buildOptions,
[EModelEndpoint.azureAssistants]: azureAssistants.buildOptions,
};
@@ -34,9 +36,6 @@ async function buildEndpointOption(req, res, next) {
try {
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
} catch (error) {
logger.warn(
`Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`,
);
return handleError(res, { text: 'Error parsing conversation' });
}
@@ -58,6 +57,15 @@ async function buildEndpointOption(req, res, next) {
return handleError(res, { text: 'Model spec mismatch' });
}
if (
currentModelSpec.preset.endpoint !== EModelEndpoint.gptPlugins &&
currentModelSpec.preset.tools
) {
return handleError(res, {
text: `Only the "${EModelEndpoint.gptPlugins}" endpoint can have tools defined in the preset`,
});
}
try {
currentModelSpec.preset.spec = spec;
if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') {
@@ -69,7 +77,6 @@ async function buildEndpointOption(req, res, next) {
conversation: currentModelSpec.preset,
});
} catch (error) {
logger.error(`Error parsing model spec for endpoint ${endpoint}`, error);
return handleError(res, { text: 'Error parsing model spec' });
}
}
@@ -77,23 +84,20 @@ async function buildEndpointOption(req, res, next) {
try {
const isAgents =
isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]);
const builder = isAgents
? (...args) => buildFunction[EModelEndpoint.agents](req, ...args)
: buildFunction[endpointType ?? endpoint];
const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)];
const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn;
// TODO: use object params
req.body.endpointOption = await builder(endpoint, parsedBody, endpointType);
// TODO: use `getModelsConfig` only when necessary
const modelsConfig = await getModelsConfig(req);
req.body.endpointOption.modelsConfig = modelsConfig;
if (req.body.files && !isAgents) {
req.body.endpointOption.attachments = processFiles(req.body.files);
}
next();
} catch (error) {
logger.error(
`Error building endpoint option for endpoint ${endpoint} with type ${endpointType}`,
error,
);
return handleError(res, { text: 'Error building endpoint option' });
}
}

View File

@@ -18,6 +18,7 @@ const message = 'Your account has been temporarily banned due to violations of o
* @function
* @param {Object} req - Express Request object.
* @param {Object} res - Express Response object.
* @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request.
*
* @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
*/
@@ -134,7 +135,6 @@ const checkBan = async (req, res, next = () => {}) => {
return await banResponse(req, res);
} catch (error) {
logger.error('Error in checkBan middleware:', error);
return next(error);
}
};

View File

@@ -1,4 +1,4 @@
const { Time, CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { Time, CacheKeys } = require('librechat-data-provider');
const clearPendingReq = require('~/cache/clearPendingReq');
const { logViolation, getLogStores } = require('~/cache');
const { isEnabled } = require('~/server/utils');
@@ -37,7 +37,7 @@ const concurrentLimiter = async (req, res, next) => {
const userId = req.user?.id ?? req.user?._id ?? '';
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
const type = ViolationTypes.CONCURRENT;
const type = 'concurrent';
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
const pendingRequests = +((await cache.get(key)) ?? 0);

View File

@@ -1,7 +1,6 @@
const crypto = require('crypto');
const { sendEvent } = require('@librechat/api');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { sendError } = require('~/server/middleware/error');
const { sendMessage, sendError } = require('~/server/utils');
const { saveMessage } = require('~/models');
/**
@@ -37,7 +36,7 @@ const denyRequest = async (req, res, errorMessage) => {
isCreatedByUser: true,
text,
};
sendEvent(res, { message: userMessage, created: true });
sendMessage(res, { message: userMessage, created: true });
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;

View File

@@ -1,79 +0,0 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const logViolation = require('~/cache/logViolation');
const getEnvironmentVariables = () => {
const FORK_IP_MAX = parseInt(process.env.FORK_IP_MAX) || 30;
const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1;
const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7;
const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1;
const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE;
const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000;
const forkIpMax = FORK_IP_MAX;
const forkIpWindowInMinutes = forkIpWindowMs / 60000;
const forkUserWindowMs = FORK_USER_WINDOW * 60 * 1000;
const forkUserMax = FORK_USER_MAX;
const forkUserWindowInMinutes = forkUserWindowMs / 60000;
return {
forkIpWindowMs,
forkIpMax,
forkIpWindowInMinutes,
forkUserWindowMs,
forkUserMax,
forkUserWindowInMinutes,
forkViolationScore: FORK_VIOLATION_SCORE,
};
};
const createForkHandler = (ip = true) => {
const {
forkIpMax,
forkUserMax,
forkViolationScore,
forkIpWindowInMinutes,
forkUserWindowInMinutes,
} = getEnvironmentVariables();
return async (req, res) => {
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
const errorMessage = {
type,
max: ip ? forkIpMax : forkUserMax,
limiter: ip ? 'ip' : 'user',
windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, forkViolationScore);
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
};
};
const createForkLimiters = () => {
const { forkIpWindowMs, forkIpMax, forkUserWindowMs, forkUserMax } = getEnvironmentVariables();
const ipLimiterOptions = {
windowMs: forkIpWindowMs,
max: forkIpMax,
handler: createForkHandler(),
store: limiterCache('fork_ip_limiter'),
};
const userLimiterOptions = {
windowMs: forkUserWindowMs,
max: forkUserMax,
handler: createForkHandler(false),
keyGenerator: function (req) {
return req.user?.id;
},
store: limiterCache('fork_user_limiter'),
};
const forkIpLimiter = rateLimit(ipLimiterOptions);
const forkUserLimiter = rateLimit(userLimiterOptions);
return { forkIpLimiter, forkUserLimiter };
};
module.exports = { createForkLimiters };

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE;
const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
const importIpMax = IMPORT_IP_MAX;
@@ -25,18 +27,12 @@ const getEnvironmentVariables = () => {
importUserWindowMs,
importUserMax,
importUserWindowInMinutes,
importViolationScore: IMPORT_VIOLATION_SCORE,
};
};
const createImportHandler = (ip = true) => {
const {
importIpMax,
importUserMax,
importViolationScore,
importIpWindowInMinutes,
importUserWindowInMinutes,
} = getEnvironmentVariables();
const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
@@ -47,7 +43,7 @@ const createImportHandler = (ip = true) => {
windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, importViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
};
};
@@ -60,7 +56,6 @@ const createImportLimiters = () => {
windowMs: importIpWindowMs,
max: importIpMax,
handler: createImportHandler(),
store: limiterCache('import_ip_limiter'),
};
const userLimiterOptions = {
windowMs: importUserWindowMs,
@@ -69,9 +64,23 @@ const createImportLimiters = () => {
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
store: limiterCache('import_user_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for import rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'import_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'import_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const importIpLimiter = rateLimit(ipLimiterOptions);
const importUserLimiter = rateLimit(userLimiterOptions);
return { importIpLimiter, importUserLimiter };

View File

@@ -4,7 +4,6 @@ const createSTTLimiters = require('./sttLimiters');
const loginLimiter = require('./loginLimiter');
const importLimiters = require('./importLimiters');
const uploadLimiters = require('./uploadLimiters');
const forkLimiters = require('./forkLimiters');
const registerLimiter = require('./registerLimiter');
const toolCallLimiter = require('./toolCallLimiter');
const messageLimiters = require('./messageLimiters');
@@ -15,7 +14,6 @@ module.exports = {
...uploadLimiters,
...importLimiters,
...messageLimiters,
...forkLimiters,
loginLimiter,
registerLimiter,
toolCallLimiter,

View File

@@ -1,8 +1,9 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory');
const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
const windowMs = LOGIN_WINDOW * 60 * 1000;
@@ -11,7 +12,7 @@ const windowInMinutes = windowMs / 60000;
const message = `Too many login attempts, please try again after ${windowInMinutes} minutes.`;
const handler = async (req, res) => {
const type = ViolationTypes.LOGINS;
const type = 'logins';
const errorMessage = {
type,
max,
@@ -27,9 +28,17 @@ const limiterOptions = {
max,
handler,
keyGenerator: removePorts,
store: limiterCache('login_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for login rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'login_limiter:',
});
limiterOptions.store = store;
}
const loginLimiter = rateLimit(limiterOptions);
module.exports = loginLimiter;

View File

@@ -1,15 +1,16 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { RedisStore } = require('rate-limit-redis');
const denyRequest = require('~/server/middleware/denyRequest');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const { isEnabled } = require('~/server/utils');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const {
MESSAGE_IP_MAX = 40,
MESSAGE_IP_WINDOW = 1,
MESSAGE_USER_MAX = 40,
MESSAGE_USER_WINDOW = 1,
MESSAGE_VIOLATION_SCORE: score,
} = process.env;
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
@@ -30,7 +31,7 @@ const userWindowInMinutes = userWindowMs / 60000;
*/
const createHandler = (ip = true) => {
return async (req, res) => {
const type = ViolationTypes.MESSAGE_LIMIT;
const type = 'message_limit';
const errorMessage = {
type,
max: ip ? ipMax : userMax,
@@ -38,7 +39,7 @@ const createHandler = (ip = true) => {
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, score);
await logViolation(req, res, type, errorMessage);
return await denyRequest(req, res, errorMessage);
};
};
@@ -50,7 +51,6 @@ const ipLimiterOptions = {
windowMs: ipWindowMs,
max: ipMax,
handler: createHandler(),
store: limiterCache('message_ip_limiter'),
};
const userLimiterOptions = {
@@ -60,9 +60,23 @@ const userLimiterOptions = {
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
store: limiterCache('message_user_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for message rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'message_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'message_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
/**
* Message request rate limiter by IP
*/

View File

@@ -1,8 +1,9 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory');
const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
const windowMs = REGISTER_WINDOW * 60 * 1000;
@@ -11,7 +12,7 @@ const windowInMinutes = windowMs / 60000;
const message = `Too many accounts created, please try again after ${windowInMinutes} minutes`;
const handler = async (req, res) => {
const type = ViolationTypes.REGISTRATIONS;
const type = 'registrations';
const errorMessage = {
type,
max,
@@ -27,9 +28,17 @@ const limiterOptions = {
max,
handler,
keyGenerator: removePorts,
store: limiterCache('register_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for register rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'register_limiter:',
});
limiterOptions.store = store;
}
const registerLimiter = rateLimit(limiterOptions);
module.exports = registerLimiter;

View File

@@ -1,8 +1,10 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory');
const { removePorts, isEnabled } = require('~/server/utils');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const {
RESET_PASSWORD_WINDOW = 2,
@@ -31,9 +33,17 @@ const limiterOptions = {
max,
handler,
keyGenerator: removePorts,
store: limiterCache('reset_password_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for reset password rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'reset_password_limiter:',
});
limiterOptions.store = store;
}
const resetPasswordLimiter = rateLimit(limiterOptions);
module.exports = resetPasswordLimiter;

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE;
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
const sttIpMax = STT_IP_MAX;
@@ -25,12 +27,11 @@ const getEnvironmentVariables = () => {
sttUserWindowMs,
sttUserMax,
sttUserWindowInMinutes,
sttViolationScore: STT_VIOLATION_SCORE,
};
};
const createSTTHandler = (ip = true) => {
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } =
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
@@ -42,7 +43,7 @@ const createSTTHandler = (ip = true) => {
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, sttViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many STT requests. Try again later' });
};
};
@@ -54,7 +55,6 @@ const createSTTLimiters = () => {
windowMs: sttIpWindowMs,
max: sttIpMax,
handler: createSTTHandler(),
store: limiterCache('stt_ip_limiter'),
};
const userLimiterOptions = {
@@ -64,9 +64,23 @@ const createSTTLimiters = () => {
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
store: limiterCache('stt_user_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for STT rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'stt_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'stt_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const sttIpLimiter = rateLimit(ipLimiterOptions);
const sttUserLimiter = rateLimit(userLimiterOptions);

View File

@@ -1,9 +1,10 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { TOOL_CALL_VIOLATION_SCORE: score } = process.env;
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const handler = async (req, res) => {
const type = ViolationTypes.TOOL_CALL_LIMIT;
@@ -14,7 +15,7 @@ const handler = async (req, res) => {
windowInMinutes: 1,
};
await logViolation(req, res, type, errorMessage, score);
await logViolation(req, res, type, errorMessage, 0);
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
};
@@ -25,9 +26,17 @@ const limiterOptions = {
keyGenerator: function (req) {
return req.user?.id;
},
store: limiterCache('tool_call_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for tool call rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'tool_call_limiter:',
});
limiterOptions.store = store;
}
const toolCallLimiter = rateLimit(limiterOptions);
module.exports = toolCallLimiter;

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { limiterCache } = require('~/cache/cacheFactory');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE;
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
const ttsIpMax = TTS_IP_MAX;
@@ -25,12 +27,11 @@ const getEnvironmentVariables = () => {
ttsUserWindowMs,
ttsUserMax,
ttsUserWindowInMinutes,
ttsViolationScore: TTS_VIOLATION_SCORE,
};
};
const createTTSHandler = (ip = true) => {
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } =
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
@@ -42,7 +43,7 @@ const createTTSHandler = (ip = true) => {
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, ttsViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
};
};
@@ -54,19 +55,32 @@ const createTTSLimiters = () => {
windowMs: ttsIpWindowMs,
max: ttsIpMax,
handler: createTTSHandler(),
store: limiterCache('tts_ip_limiter'),
};
const userLimiterOptions = {
windowMs: ttsUserWindowMs,
max: ttsUserMax,
handler: createTTSHandler(false),
store: limiterCache('tts_user_limiter'),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for TTS rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'tts_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'tts_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const ttsIpLimiter = rateLimit(ipLimiterOptions);
const ttsUserLimiter = rateLimit(userLimiterOptions);

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100;
const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE;
const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
const fileUploadIpMax = FILE_UPLOAD_IP_MAX;
@@ -25,7 +27,6 @@ const getEnvironmentVariables = () => {
fileUploadUserWindowMs,
fileUploadUserMax,
fileUploadUserWindowInMinutes,
fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE,
};
};
@@ -35,7 +36,6 @@ const createFileUploadHandler = (ip = true) => {
fileUploadIpWindowInMinutes,
fileUploadUserMax,
fileUploadUserWindowInMinutes,
fileUploadViolationScore,
} = getEnvironmentVariables();
return async (req, res) => {
@@ -47,7 +47,7 @@ const createFileUploadHandler = (ip = true) => {
windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, fileUploadViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many file upload requests. Try again later' });
};
};
@@ -60,7 +60,6 @@ const createFileLimiters = () => {
windowMs: fileUploadIpWindowMs,
max: fileUploadIpMax,
handler: createFileUploadHandler(),
store: limiterCache('file_upload_ip_limiter'),
};
const userLimiterOptions = {
@@ -70,9 +69,23 @@ const createFileLimiters = () => {
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
store: limiterCache('file_upload_user_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for file upload rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'file_upload_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'file_upload_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const fileUploadIpLimiter = rateLimit(ipLimiterOptions);
const fileUploadUserLimiter = rateLimit(userLimiterOptions);

View File

@@ -1,8 +1,10 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory');
const { removePorts, isEnabled } = require('~/server/utils');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const {
VERIFY_EMAIL_WINDOW = 2,
@@ -31,9 +33,17 @@ const limiterOptions = {
max,
handler,
keyGenerator: removePorts,
store: limiterCache('verify_email_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for verify email rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'verify_email_limiter:',
});
limiterOptions.store = store;
}
const verifyEmailLimiter = rateLimit(limiterOptions);
module.exports = verifyEmailLimiter;

View File

@@ -0,0 +1,78 @@
const { getRoleByName } = require('~/models/Role');
const { logger } = require('~/config');
/**
* Core function to check if a user has one or more required permissions
*
* @param {object} user - The user object
* @param {PermissionTypes} permissionType - The type of permission to check
* @param {Permissions[]} permissions - The list of specific permissions to check
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
* @param {object} [checkObject] - The object to check properties against
* @returns {Promise<boolean>} Whether the user has the required permissions
*/
const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
if (!user) {
return false;
}
const role = await getRoleByName(user.role);
if (role && role.permissions && role.permissions[permissionType]) {
const hasAnyPermission = permissions.some((permission) => {
if (role.permissions[permissionType][permission]) {
return true;
}
if (bodyProps[permission] && checkObject) {
return bodyProps[permission].some((prop) =>
Object.prototype.hasOwnProperty.call(checkObject, prop),
);
}
return false;
});
return hasAnyPermission;
}
return false;
};
/**
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
*
* @param {PermissionTypes} permissionType - The type of permission to check.
* @param {Permissions[]} permissions - The list of specific permissions to check.
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
* @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
*/
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
return async (req, res, next) => {
try {
const hasAccess = await checkAccess(
req.user,
permissionType,
permissions,
bodyProps,
req.body,
);
if (hasAccess) {
return next();
}
logger.warn(
`[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`,
);
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
} catch (error) {
logger.error(error);
return res.status(500).json({ message: `Server error: ${error.message}` });
}
};
};
module.exports = {
checkAccess,
generateCheckAccess,
};

View File

@@ -1,5 +1,8 @@
const checkAdmin = require('./admin');
const { checkAccess, generateCheckAccess } = require('./access');
module.exports = {
checkAdmin,
checkAccess,
generateCheckAccess,
};

View File

@@ -1,6 +1,5 @@
const uap = require('ua-parser-js');
const { ViolationTypes } = require('librechat-data-provider');
const { handleError } = require('@librechat/api');
const { handleError } = require('../utils');
const { logViolation } = require('../../cache');
/**
@@ -22,7 +21,7 @@ async function uaParser(req, res, next) {
const ua = uap(req.headers['user-agent']);
if (!ua.browser.name) {
const type = ViolationTypes.NON_BROWSER;
const type = 'non_browser';
await logViolation(req, res, type, { type }, score);
return handleError(res, { message: 'Illegal request' });
}

View File

@@ -1,4 +1,4 @@
const { handleError } = require('@librechat/api');
const { handleError } = require('../utils');
function validateEndpoint(req, res, next) {
const { endpoint: _endpoint, endpointType } = req.body;

View File

@@ -1,6 +1,6 @@
const { handleError } = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider');
const { getModelsConfig } = require('~/server/controllers/ModelController');
const { handleError } = require('~/server/utils');
const { logViolation } = require('~/cache');
/**
* Validates the model of the request.

View File

@@ -1,162 +0,0 @@
const fs = require('fs');
const path = require('path');
const express = require('express');
const request = require('supertest');
const zlib = require('zlib');
// Create test setup
const mockTestDir = path.join(__dirname, 'test-static-route');
// Mock the paths module to point to our test directory
jest.mock('~/config/paths', () => ({
imageOutput: mockTestDir,
}));
describe('Static Route Integration', () => {
let app;
let staticRoute;
let testDir;
let testImagePath;
beforeAll(() => {
// Create a test directory and files
testDir = mockTestDir;
testImagePath = path.join(testDir, 'test-image.jpg');
if (!fs.existsSync(testDir)) {
fs.mkdirSync(testDir, { recursive: true });
}
// Create a test image file
fs.writeFileSync(testImagePath, 'fake-image-data');
// Create a gzipped version of the test image (for gzip scanning tests)
fs.writeFileSync(testImagePath + '.gz', zlib.gzipSync('fake-image-data'));
});
afterAll(() => {
// Clean up test files
if (fs.existsSync(testDir)) {
fs.rmSync(testDir, { recursive: true, force: true });
}
});
// Helper function to set up static route with specific config
const setupStaticRoute = (skipGzipScan = false) => {
if (skipGzipScan) {
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
} else {
process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN = 'true';
}
staticRoute = require('../static');
app.use('/images', staticRoute);
};
beforeEach(() => {
// Clear the module cache to get fresh imports
jest.resetModules();
app = express();
// Clear environment variables
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
delete process.env.NODE_ENV;
});
describe('route functionality', () => {
it('should serve static image files', async () => {
process.env.NODE_ENV = 'production';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.body.toString()).toBe('fake-image-data');
});
it('should return 404 for non-existent files', async () => {
setupStaticRoute();
const response = await request(app).get('/images/nonexistent.jpg');
expect(response.status).toBe(404);
});
});
describe('cache behavior', () => {
it('should set cache headers for images in production', async () => {
process.env.NODE_ENV = 'production';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
});
it('should not set cache headers in development', async () => {
process.env.NODE_ENV = 'development';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
// Our middleware should not set the production cache-control header in development
expect(response.headers['cache-control']).not.toBe('public, max-age=172800, s-maxage=86400');
});
});
describe('gzip compression behavior', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should serve gzipped files when gzip scanning is enabled', async () => {
setupStaticRoute(false); // Enable gzip scanning
const response = await request(app)
.get('/images/test-image.jpg')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.body.toString()).toBe('fake-image-data');
});
it('should not serve gzipped files when gzip scanning is disabled', async () => {
setupStaticRoute(true); // Disable gzip scanning
const response = await request(app)
.get('/images/test-image.jpg')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.body.toString()).toBe('fake-image-data');
});
});
describe('path configuration', () => {
it('should use the configured imageOutput path', async () => {
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.body.toString()).toBe('fake-image-data');
});
it('should serve from subdirectories', async () => {
// Create a subdirectory with a file
const subDir = path.join(testDir, 'thumbs');
fs.mkdirSync(subDir, { recursive: true });
const thumbPath = path.join(subDir, 'thumb.jpg');
fs.writeFileSync(thumbPath, 'thumbnail-data');
setupStaticRoute();
const response = await request(app).get('/images/thumbs/thumb.jpg').expect(200);
expect(response.body.toString()).toBe('thumbnail-data');
// Clean up
fs.rmSync(subDir, { recursive: true, force: true });
});
});
});

View File

@@ -1,28 +1,14 @@
const express = require('express');
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const { generateCheckAccess } = require('@librechat/api');
const {
SystemRoles,
Permissions,
PermissionTypes,
actionDelimiter,
removeNullishValues,
} = require('librechat-data-provider');
const { actionDelimiter, SystemRoles, removeNullishValues } = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { getAgent, updateAgent } = require('~/models/Agent');
const { getRoleByName } = require('~/models/Role');
const { logger } = require('~/config');
const router = express.Router();
const checkAgentCreate = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
// If the user has ADMIN role
// then action edition is possible even if not owner of the assistant
const isAdmin = (req) => {
@@ -55,7 +41,7 @@ router.get('/', async (req, res) => {
* @param {ActionMetadata} req.body.metadata - Metadata for the action.
* @returns {Object} 200 - success response - application/json
*/
router.post('/:agent_id', checkAgentCreate, async (req, res) => {
router.post('/:agent_id', async (req, res) => {
try {
const { agent_id } = req.params;
@@ -163,7 +149,7 @@ router.post('/:agent_id', checkAgentCreate, async (req, res) => {
* @param {string} req.params.action_id - The ID of the action to delete.
* @returns {Object} 200 - success response - application/json
*/
router.delete('/:agent_id/:action_id', checkAgentCreate, async (req, res) => {
router.delete('/:agent_id/:action_id', async (req, res) => {
try {
const { agent_id, action_id } = req.params;
const admin = isAdmin(req);

View File

@@ -1,28 +1,22 @@
const express = require('express');
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const {
setHeaders,
moderateText,
// validateModel,
generateCheckAccess,
validateConvoAccess,
buildEndpointOption,
} = require('~/server/middleware');
const { initializeClient } = require('~/server/services/Endpoints/agents');
const AgentController = require('~/server/controllers/agents/request');
const addTitle = require('~/server/services/Endpoints/agents/title');
const { getRoleByName } = require('~/models/Role');
const router = express.Router();
router.use(moderateText);
const checkAgentAccess = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
skipCheck: skipAgentCheck,
getRoleByName,
});
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
router.use(checkAgentAccess);
router.use(validateConvoAccess);

View File

@@ -1,36 +1,29 @@
const express = require('express');
const { generateCheckAccess } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { requireJwtAuth } = require('~/server/middleware');
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const v1 = require('~/server/controllers/agents/v1');
const { getRoleByName } = require('~/models/Role');
const actions = require('./actions');
const tools = require('./tools');
const router = express.Router();
const avatar = express.Router();
const checkAgentAccess = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
const checkAgentCreate = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [
Permissions.USE,
Permissions.CREATE,
]);
const checkGlobalAgentShare = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
bodyProps: {
const checkGlobalAgentShare = generateCheckAccess(
PermissionTypes.AGENTS,
[Permissions.USE, Permissions.CREATE],
{
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
},
getRoleByName,
});
);
router.use(requireJwtAuth);
router.use(checkAgentAccess);
/**
* Agent actions route.

View File

@@ -0,0 +1,63 @@
const { Keyv } = require('keyv');
const { KeyvFile } = require('keyv-file');
const { logger } = require('~/config');
const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessage }) => {
try {
const conversationsCache = new Keyv({
store: new KeyvFile({ filename: './data/cache.json' }),
namespace: 'chatgpt', // should be 'bing' for bing/sydney
});
const {
conversationId,
messageId: userMessageId,
parentMessageId: userParentMessageId,
text: userText,
} = userMessage;
const {
messageId: responseMessageId,
parentMessageId: responseParentMessageId,
text: responseText,
} = responseMessage;
let conversation = await conversationsCache.get(conversationId);
// used to generate a title for the conversation if none exists
// let isNewConversation = false;
if (!conversation) {
conversation = {
messages: [],
createdAt: Date.now(),
};
// isNewConversation = true;
}
const roles = (options) => {
if (endpoint === 'openAI') {
return options?.chatGptLabel || 'ChatGPT';
}
};
let _userMessage = {
id: userMessageId,
parentMessageId: userParentMessageId,
role: 'User',
message: userText,
};
let _responseMessage = {
id: responseMessageId,
parentMessageId: responseParentMessageId,
role: roles(endpointOption),
message: responseText,
};
conversation.messages.push(_userMessage, _responseMessage);
await conversationsCache.set(conversationId, conversation);
} catch (error) {
logger.error('[addToCache] Error adding conversation to cache', error);
}
};
module.exports = addToCache;

View File

@@ -0,0 +1,25 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/anthropic');
const {
setHeaders,
handleAbort,
validateModel,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,25 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/custom');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const {
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,24 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
const {
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,241 @@
const express = require('express');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const { saveMessage, updateMessage } = require('~/models');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
moderateText,
} = require('~/server/middleware');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
const router = express.Router();
router.use(moderateText);
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
});
const newConvo = !conversationId;
const user = req.user.id;
const plugins = [];
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
};
let streaming = null;
let timer = null;
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
onProgress: () => {
if (timer) {
clearTimeout(timer);
}
streaming = new Promise((resolve) => {
timer = setTimeout(() => {
resolve();
}, 250);
});
},
});
const pluginMap = new Map();
const onAgentAction = async (action, runId) => {
pluginMap.set(runId, action.tool);
sendIntermediateMessage(res, {
plugins,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
};
const onToolStart = async (tool, input, runId, parentRunId) => {
const pluginName = pluginMap.get(parentRunId);
const latestPlugin = {
runId,
loading: true,
inputs: [input],
latest: pluginName,
outputs: null,
};
if (streaming) {
await streaming;
}
const extraTokens = ':::plugin:::\n';
plugins.push(latestPlugin);
sendIntermediateMessage(
res,
{ plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId },
extraTokens,
);
};
const onToolEnd = async (output, runId) => {
if (streaming) {
await streaming;
}
const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId);
if (pluginIndex !== -1) {
plugins[pluginIndex].loading = false;
plugins[pluginIndex].outputs = output;
}
};
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugins: plugins.map((p) => ({ ...p, loading: false })),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
const onChainEnd = () => {
if (!client.skipSaveUserMessage) {
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
);
}
sendIntermediateMessage(res, {
plugins,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
};
let response = await client.sendMessage(text, {
user,
conversationId,
parentMessageId,
overrideParentMessageId,
getReqData,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
onStart,
getPartialText,
...endpointOption,
progressCallback,
progressOptions: {
res,
// parentMessageId: overrideParentMessageId || userMessageId,
plugins,
},
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
logger.debug('[/ask/gptPlugins]', response);
const { conversation = {} } = await response.databasePromise;
delete response.databasePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: conversation.title,
final: true,
conversation,
requestMessage: userMessage,
responseMessage: response,
});
res.end();
if (parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {
text,
response,
client,
});
}
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
if (response.plugins?.length > 0) {
await updateMessage(
req,
{ ...response, user },
{ context: 'api/server/routes/ask/gptPlugins.js - save plugins used' },
);
}
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
},
);
module.exports = router;

View File

@@ -0,0 +1,47 @@
const express = require('express');
const { EModelEndpoint } = require('librechat-data-provider');
const {
uaParser,
checkBan,
requireJwtAuth,
messageIpLimiter,
concurrentLimiter,
messageUserLimiter,
validateConvoAccess,
} = require('~/server/middleware');
const { isEnabled } = require('~/server/utils');
const gptPlugins = require('./gptPlugins');
const anthropic = require('./anthropic');
const custom = require('./custom');
const google = require('./google');
const openAI = require('./openAI');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
router.use(concurrentLimiter);
}
if (isEnabled(LIMIT_MESSAGE_IP)) {
router.use(messageIpLimiter);
}
if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use(validateConvoAccess);
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
router.use(`/${EModelEndpoint.google}`, google);
router.use(`/${EModelEndpoint.custom}`, custom);
module.exports = router;

View File

@@ -0,0 +1,27 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI');
const {
handleAbort,
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
moderateText,
} = require('~/server/middleware');
const router = express.Router();
router.use(moderateText);
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,37 @@
const express = require('express');
const router = express.Router();
const {
setHeaders,
handleAbort,
moderateText,
// validateModel,
// validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const { initializeClient } = require('~/server/services/Endpoints/bedrock');
const AgentController = require('~/server/controllers/agents/request');
const addTitle = require('~/server/services/Endpoints/agents/title');
router.use(moderateText);
/**
* @route POST /
* @desc Chat with an assistant
* @access Public
* @param {express.Request} req - The request object, containing the request data.
* @param {express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
router.post(
'/',
// validateModel,
// validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AgentController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,35 @@
const express = require('express');
const {
uaParser,
checkBan,
requireJwtAuth,
messageIpLimiter,
concurrentLimiter,
messageUserLimiter,
} = require('~/server/middleware');
const { isEnabled } = require('~/server/utils');
const chat = require('./chat');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
router.use(concurrentLimiter);
}
if (isEnabled(LIMIT_MESSAGE_IP)) {
router.use(messageIpLimiter);
}
if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use('/chat', chat);
module.exports = router;

View File

@@ -1,17 +1,16 @@
const multer = require('multer');
const express = require('express');
const { sleep } = require('@librechat/agents');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
const { storage, importFileFilter } = require('~/server/routes/files/multer');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { importConversations } = require('~/server/utils/import');
const { createImportLimiters } = require('~/server/middleware');
const { deleteToolCalls } = require('~/models/ToolCall');
const { isEnabled, sleep } = require('~/server/utils');
const getLogStores = require('~/cache/getLogStores');
const { logger } = require('~/config');
const assistantClients = {
[EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'),
@@ -44,7 +43,6 @@ router.get('/', async (req, res) => {
});
res.status(200).json(result);
} catch (error) {
logger.error('Error fetching conversations', error);
res.status(500).json({ error: 'Error fetching conversations' });
}
});
@@ -158,7 +156,6 @@ router.post('/update', async (req, res) => {
});
const { importIpLimiter, importUserLimiter } = createImportLimiters();
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
const upload = multer({ storage: storage, fileFilter: importFileFilter });
/**
@@ -192,7 +189,7 @@ router.post(
* @param {express.Response<TForkConvoResponse>} res - Express response object.
* @returns {Promise<void>} - The response after forking the conversation.
*/
router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => {
router.post('/fork', async (req, res) => {
try {
/** @type {TForkConvoRequest} */
const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body;

View File

@@ -0,0 +1,207 @@
const express = require('express');
const { getResponseSender } = require('librechat-data-provider');
const {
setHeaders,
moderateText,
validateModel,
handleAbortError,
validateEndpoint,
buildEndpointOption,
createAbortController,
} = require('~/server/middleware');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, updateMessage } = require('~/models');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
const router = express.Router();
router.use(moderateText);
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
logger.debug('[/edit/gptPlugins]', {
text,
generation,
isContinued,
conversationId,
...endpointOption,
});
let userMessage;
let userMessagePromise;
let promptTokens;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
});
const userMessageId = parentMessageId;
const user = req.user.id;
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
}
}
};
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
generation,
onProgress: () => {
if (plugin.loading === true) {
plugin.loading = false;
}
},
});
const onChainEnd = (data) => {
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
);
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
// logger.debug('CHAIN END', plugin.outputs);
};
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start && !client.skipSaveUserMessage) {
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onAgentAction' },
);
}
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
// logger.debug('PLUGIN ACTION', formattedAction);
};
let response = await client.sendMessage(text, {
user,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getReqData,
onAgentAction,
onChainEnd,
onStart,
...endpointOption,
progressCallback,
progressOptions: {
res,
plugin,
// parentMessageId: overrideParentMessageId || userMessageId,
},
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
const { conversation = {} } = await response.databasePromise;
delete response.databasePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: conversation.title,
final: true,
conversation,
requestMessage: userMessage,
responseMessage: response,
});
res.end();
response.plugin = { ...plugin, loading: false };
await updateMessage(
req,
{ ...response, user },
{ context: 'api/server/routes/edit/gptPlugins.js' },
);
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
},
);
module.exports = router;

View File

@@ -3,6 +3,7 @@ const openAI = require('./openAI');
const custom = require('./custom');
const google = require('./google');
const anthropic = require('./anthropic');
const gptPlugins = require('./gptPlugins');
const { isEnabled } = require('~/server/utils');
const { EModelEndpoint } = require('librechat-data-provider');
const {
@@ -38,6 +39,7 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(validateConvoAccess);
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
router.use(`/${EModelEndpoint.google}`, google);
router.use(`/${EModelEndpoint.custom}`, custom);

Some files were not shown because too many files have changed in this diff Show More