Compare commits

..

8 Commits

Author SHA1 Message Date
Ruben Talstra
13d3257801 🔧 fix: Add support for configurable MongoDB database name in connection logic 2025-06-15 12:50:08 +02:00
Ruben Talstra
cde930e0d6 🔧 fix: Make MongoDB database name configurable via environment variable 2025-06-15 12:07:02 +02:00
Ruben Talstra
c94d4d9ae1 Merge branch 'dev' into explicit-mongo-dbname 2025-06-15 12:05:18 +02:00
Ruben Talstra
6a811c708b 🔧 refactor: Move MongoDB connection logic to a separate module and make database name optional 2025-06-15 12:04:55 +02:00
Ruben Talstra
cffd66054e 🔧 refactor: Make MongoDB name optional in connection options 2025-05-29 16:21:52 +02:00
Jason Koh
d7f4f8ef4a update env vars for testing 2025-05-27 16:25:57 -07:00
Jason Koh
50e36e67f0 clean up 2025-05-27 14:51:31 -07:00
Jason Koh
12ac302a68 define mongodb name 2025-05-27 14:32:41 -07:00
441 changed files with 8385 additions and 24060 deletions

View File

@@ -15,6 +15,7 @@ HOST=localhost
PORT=3080
MONGO_URI=mongodb://127.0.0.1:27017/LibreChat
# MONGO_DB_NAME=LibreChat
DOMAIN_CLIENT=http://localhost:3080
DOMAIN_SERVER=http://localhost:3080
@@ -58,7 +59,7 @@ DEBUG_CONSOLE=false
# Endpoints #
#===================================================#
# ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic
# ENDPOINTS=openAI,assistants,azureOpenAI,google,gptPlugins,anthropic
PROXY=
@@ -142,10 +143,10 @@ GOOGLE_KEY=user_provided
# GOOGLE_AUTH_HEADER=true
# Gemini API (AI Studio)
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# Vertex AI
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
@@ -349,11 +350,6 @@ REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20
TTS_VIOLATION_SCORE=0
STT_VIOLATION_SCORE=0
FORK_VIOLATION_SCORE=0
IMPORT_VIOLATION_SCORE=0
FILE_UPLOAD_VIOLATION_SCORE=0
LOGIN_MAX=7
LOGIN_WINDOW=5
@@ -458,8 +454,8 @@ OPENID_REUSE_TOKENS=
OPENID_JWKS_URL_CACHE_ENABLED=
OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching
#Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint.
OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED=
OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API
OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED=
OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API
# Set to true to use the OpenID Connect end session endpoint for logout
OPENID_USE_END_SESSION_ENDPOINT=
@@ -580,10 +576,6 @@ ALLOW_SHARED_LINKS_PUBLIC=true
# If you have another service in front of your LibreChat doing compression, disable express based compression here
# DISABLE_COMPRESSION=true
# If you have gzipped version of uploaded image images in the same folder, this will enable gzip scan and serving of these images
# Note: The images folder will be scanned on startup and a ma kept in memory. Be careful for large number of images.
# ENABLE_IMAGE_OUTPUT_GZIP_SCAN=true
#===================================================#
# UI #
#===================================================#
@@ -666,4 +658,4 @@ OPENWEATHER_API_KEY=
# Reranker (Required)
# JINA_API_KEY=your_jina_api_key
# or
# COHERE_API_KEY=your_cohere_api_key
# COHERE_API_KEY=your_cohere_api_key

View File

@@ -98,8 +98,6 @@ jobs:
cd client
UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "")
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../client_used_deps.txt ../client_used_code.txt | sort) || echo "")
# Filter out false positives
UNUSED=$(echo "$UNUSED" | grep -v "^micromark-extension-llm-math$" || echo "")
echo "CLIENT_UNUSED<<EOF" >> $GITHUB_ENV
echo "$UNUSED" >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV

4
.gitignore vendored
View File

@@ -55,8 +55,6 @@ bower_components/
# AI
.clineignore
.cursor
.aider*
CLAUDE.md
# Floobits
.floo
@@ -125,4 +123,4 @@ helm/**/.values.yaml
!/client/src/@types/i18next.d.ts
# SAML Idp cert
*.cert
*.cert

View File

@@ -1,4 +1,4 @@
# v0.7.9-rc1
# v0.7.8
# Base node image
FROM node:20-alpine AS node

View File

@@ -1,5 +1,5 @@
# Dockerfile.multi
# v0.7.9-rc1
# v0.7.8
# Base for all builds
FROM node:20-alpine AS base-min

View File

@@ -52,7 +52,7 @@
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
- 🤖 **AI Model Selection**:
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure)
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
@@ -66,9 +66,10 @@
- 🔦 **Agents & Tools Integration**:
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
- Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
- Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions
- 🔍 **Web Search**:
- Search the internet and retrieve relevant information to enhance your AI context

View File

@@ -190,11 +190,10 @@ class AnthropicClient extends BaseClient {
reverseProxyUrl: this.options.reverseProxyUrl,
}),
apiKey: this.apiKey,
fetchOptions: {},
};
if (this.options.proxy) {
options.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy);
options.httpAgent = new HttpsProxyAgent(this.options.proxy);
}
if (this.options.reverseProxyUrl) {

View File

@@ -13,6 +13,7 @@ const {
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
const { checkBalance } = require('~/models/balanceMethods');
const { truncateToolCallOutputs } = require('./prompts');
const { addSpaceIfNeeded } = require('~/server/utils');
const { getFiles } = require('~/models/File');
const TextStream = require('./TextStream');
const { logger } = require('~/config');
@@ -571,7 +572,7 @@ class BaseClient {
});
}
const { editedContent } = opts;
const { generation = '' } = opts;
// It's not necessary to push to currentMessages
// depending on subclass implementation of handling messages
@@ -586,21 +587,11 @@ class BaseClient {
isCreatedByUser: false,
model: this.modelOptions?.model ?? this.model,
sender: this.sender,
text: generation,
};
this.currentMessages.push(userMessage, latestMessage);
} else if (editedContent != null) {
// Handle editedContent for content parts
if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
const { index, text, type } = editedContent;
if (index >= 0 && index < latestMessage.content.length) {
const contentPart = latestMessage.content[index];
if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
contentPart[ContentTypes.THINK] = text;
} else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
contentPart[ContentTypes.TEXT] = text;
}
}
}
} else {
latestMessage.text = generation;
}
this.continued = true;
} else {
@@ -681,32 +672,16 @@ class BaseClient {
};
if (typeof completion === 'string') {
responseMessage.text = completion;
responseMessage.text = addSpaceIfNeeded(generation) + completion;
} else if (
Array.isArray(completion) &&
(this.clientName === EModelEndpoint.agents ||
isParamEndpoint(this.options.endpoint, this.options.endpointType))
) {
responseMessage.text = '';
if (!opts.editedContent || this.currentMessages.length === 0) {
responseMessage.content = completion;
} else {
const latestMessage = this.currentMessages[this.currentMessages.length - 1];
if (!latestMessage?.content) {
responseMessage.content = completion;
} else {
const existingContent = [...latestMessage.content];
const { type: editedType } = opts.editedContent;
responseMessage.content = this.mergeEditedContent(
existingContent,
completion,
editedType,
);
}
}
responseMessage.content = completion;
} else if (Array.isArray(completion)) {
responseMessage.text = completion.join('');
responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
}
if (
@@ -817,8 +792,7 @@ class BaseClient {
userMessage.tokenCount = userMessageTokenCount;
/*
Note: `AgentController` saves the user message if not saved here
(noted by `savedMessageIds`), so we update the count of its `userMessage` reference
Note: `AskController` saves the user message, so we update the count of its `userMessage` reference
*/
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
@@ -827,8 +801,7 @@ class BaseClient {
}
/*
Note: we update the user message to be sure it gets the calculated token count;
though `AgentController` saves the user message if not saved here
(noted by `savedMessageIds`), EditController does not
though `AskController` saves the user message, EditController does not
*/
await userMessagePromise;
await this.updateMessageInDatabase({
@@ -1120,50 +1093,6 @@ class BaseClient {
return numTokens;
}
/**
* Merges completion content with existing content when editing TEXT or THINK types
* @param {Array} existingContent - The existing content array
* @param {Array} newCompletion - The new completion content
* @param {string} editedType - The type of content being edited
* @returns {Array} The merged content array
*/
mergeEditedContent(existingContent, newCompletion, editedType) {
if (!newCompletion.length) {
return existingContent.concat(newCompletion);
}
if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
return existingContent.concat(newCompletion);
}
const lastIndex = existingContent.length - 1;
const lastExisting = existingContent[lastIndex];
const firstNew = newCompletion[0];
if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
return existingContent.concat(newCompletion);
}
const mergedContent = [...existingContent];
if (editedType === ContentTypes.TEXT) {
mergedContent[lastIndex] = {
...mergedContent[lastIndex],
[ContentTypes.TEXT]:
(mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
};
} else {
mergedContent[lastIndex] = {
...mergedContent[lastIndex],
[ContentTypes.THINK]:
(mergedContent[lastIndex][ContentTypes.THINK] || '') +
(firstNew[ContentTypes.THINK] || ''),
};
}
// Add remaining completion items
return mergedContent.concat(newCompletion.slice(1));
}
async sendPayload(payload, opts = {}) {
if (opts && typeof opts === 'object') {
this.setOptions(opts);

View File

@@ -0,0 +1,804 @@
const { Keyv } = require('keyv');
const crypto = require('crypto');
const { CohereClient } = require('cohere-ai');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { constructAzureURL, genAzureChatCompletion } = require('@librechat/api');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
ImageDetail,
EModelEndpoint,
resolveHeaders,
CohereConstants,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { createContextHandlers } = require('./prompts');
const { createCoherePayload } = require('./llm');
const { extractBaseURL } = require('~/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const CHATGPT_MODEL = 'gpt-3.5-turbo';
const tokenizersCache = {};
class ChatGPTClient extends BaseClient {
constructor(apiKey, options = {}, cacheOptions = {}) {
super(apiKey, options, cacheOptions);
cacheOptions.namespace = cacheOptions.namespace || 'chatgpt';
this.conversationsCache = new Keyv(cacheOptions);
this.setOptions(options);
}
setOptions(options) {
if (this.options && !this.options.replaceOptions) {
// nested options aren't spread properly, so we need to do this manually
this.options.modelOptions = {
...this.options.modelOptions,
...options.modelOptions,
};
delete options.modelOptions;
// now we can merge options
this.options = {
...this.options,
...options,
};
} else {
this.options = options;
}
if (this.options.openaiApiKey) {
this.apiKey = this.options.openaiApiKey;
}
const modelOptions = this.options.modelOptions || {};
this.modelOptions = {
...modelOptions,
// set some good defaults (check for undefined in some cases because they may be 0)
model: modelOptions.model || CHATGPT_MODEL,
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
presence_penalty:
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
stop: modelOptions.stop,
};
this.isChatGptModel = this.modelOptions.model.includes('gpt-');
const { isChatGptModel } = this;
this.isUnofficialChatGptModel =
this.modelOptions.model.startsWith('text-chat') ||
this.modelOptions.model.startsWith('text-davinci-002-render');
const { isUnofficialChatGptModel } = this;
// Davinci models have a max context length of 4097 tokens.
this.maxContextTokens = this.options.maxContextTokens || (isChatGptModel ? 4095 : 4097);
// I decided to reserve 1024 tokens for the response.
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
// Earlier messages will be dropped until the prompt is within the limit.
this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
throw new Error(
`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
);
}
this.userLabel = this.options.userLabel || 'User';
this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
if (isChatGptModel) {
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
// without tripping the stop sequences, so I'm using "||>" instead.
this.startToken = '||>';
this.endToken = '';
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
} else if (isUnofficialChatGptModel) {
this.startToken = '<|im_start|>';
this.endToken = '<|im_end|>';
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
});
} else {
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
// as a single token. So we're using this instead.
this.startToken = '||>';
this.endToken = '';
try {
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
} catch {
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
}
}
if (!this.modelOptions.stop) {
const stopTokens = [this.startToken];
if (this.endToken && this.endToken !== this.startToken) {
stopTokens.push(this.endToken);
}
stopTokens.push(`\n${this.userLabel}:`);
stopTokens.push('<|diff_marker|>');
// I chose not to do one for `chatGptLabel` because I've never seen it happen
this.modelOptions.stop = stopTokens;
}
if (this.options.reverseProxyUrl) {
this.completionsUrl = this.options.reverseProxyUrl;
} else if (isChatGptModel) {
this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
} else {
this.completionsUrl = 'https://api.openai.com/v1/completions';
}
return this;
}
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
if (tokenizersCache[encoding]) {
return tokenizersCache[encoding];
}
let tokenizer;
if (isModelName) {
tokenizer = encodingForModel(encoding, extendSpecialTokens);
} else {
tokenizer = getEncoding(encoding, extendSpecialTokens);
}
tokenizersCache[encoding] = tokenizer;
return tokenizer;
}
/** @type {getCompletion} */
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
let modelOptions = { ...this.modelOptions };
if (typeof onProgress === 'function') {
modelOptions.stream = true;
}
if (this.isChatGptModel) {
modelOptions.messages = input;
} else {
modelOptions.prompt = input;
}
if (this.useOpenRouter && modelOptions.prompt) {
delete modelOptions.stop;
}
const { debug } = this.options;
let baseURL = this.completionsUrl;
if (debug) {
console.debug();
console.debug(baseURL);
console.debug(modelOptions);
console.debug();
}
const opts = {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
};
if (this.isVisionModel) {
modelOptions.max_tokens = 4000;
}
/** @type {TAzureConfig | undefined} */
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
const isAzure = this.azure || this.options.azure;
if (
(isAzure && this.isVisionModel && azureConfig) ||
(azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
) {
const { modelGroupMap, groupMap } = azureConfig;
const {
azureOptions,
baseURL,
headers = {},
serverless,
} = mapModelToAzureConfig({
modelName: modelOptions.model,
modelGroupMap,
groupMap,
});
opts.headers = resolveHeaders(headers);
this.langchainProxy = extractBaseURL(baseURL);
this.apiKey = azureOptions.azureOpenAIApiKey;
const groupName = modelGroupMap[modelOptions.model].group;
this.options.addParams = azureConfig.groupMap[groupName].addParams;
this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
// Note: `forcePrompt` not re-assigned as only chat models are vision models
this.azure = !serverless && azureOptions;
this.azureEndpoint =
!serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
if (serverless === true) {
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
this.options.headers['api-key'] = this.apiKey;
}
}
if (this.options.defaultQuery) {
opts.defaultQuery = this.options.defaultQuery;
}
if (this.options.headers) {
opts.headers = { ...opts.headers, ...this.options.headers };
}
if (isAzure) {
// Azure does not accept `model` in the body, so we need to remove it.
delete modelOptions.model;
baseURL = this.langchainProxy
? constructAzureURL({
baseURL: this.langchainProxy,
azureOptions: this.azure,
})
: this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
if (this.options.forcePrompt) {
baseURL += '/completions';
} else {
baseURL += '/chat/completions';
}
opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
opts.headers = { ...opts.headers, 'api-key': this.apiKey };
} else if (this.apiKey) {
opts.headers.Authorization = `Bearer ${this.apiKey}`;
}
if (process.env.OPENAI_ORGANIZATION) {
opts.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION;
}
if (this.useOpenRouter) {
opts.headers['HTTP-Referer'] = 'https://librechat.ai';
opts.headers['X-Title'] = 'LibreChat';
}
/* hacky fixes for Mistral AI API:
- Re-orders system message to the top of the messages payload, as not allowed anywhere else
- If there is only one message and it's a system message, change the role to user
*/
if (baseURL.includes('https://api.mistral.ai/v1') && modelOptions.messages) {
const { messages } = modelOptions;
const systemMessageIndex = messages.findIndex((msg) => msg.role === 'system');
if (systemMessageIndex > 0) {
const [systemMessage] = messages.splice(systemMessageIndex, 1);
messages.unshift(systemMessage);
}
modelOptions.messages = messages;
if (messages.length === 1 && messages[0].role === 'system') {
modelOptions.messages[0].role = 'user';
}
}
if (this.options.addParams && typeof this.options.addParams === 'object') {
modelOptions = {
...modelOptions,
...this.options.addParams,
};
logger.debug('[ChatGPTClient] chatCompletion: added params', {
addParams: this.options.addParams,
modelOptions,
});
}
if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
this.options.dropParams.forEach((param) => {
delete modelOptions[param];
});
logger.debug('[ChatGPTClient] chatCompletion: dropped params', {
dropParams: this.options.dropParams,
modelOptions,
});
}
if (baseURL.startsWith(CohereConstants.API_URL)) {
const payload = createCoherePayload({ modelOptions });
return await this.cohereChatCompletion({ payload, onTokenProgress });
}
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
baseURL = baseURL.split('v1')[0] + 'v1/completions';
} else if (
baseURL.includes('v1') &&
!baseURL.includes('/chat/completions') &&
this.isChatCompletion
) {
baseURL = baseURL.split('v1')[0] + 'v1/chat/completions';
}
const BASE_URL = new URL(baseURL);
if (opts.defaultQuery) {
Object.entries(opts.defaultQuery).forEach(([key, value]) => {
BASE_URL.searchParams.append(key, value);
});
delete opts.defaultQuery;
}
const completionsURL = BASE_URL.toString();
opts.body = JSON.stringify(modelOptions);
if (modelOptions.stream) {
return new Promise(async (resolve, reject) => {
try {
let done = false;
await fetchEventSource(completionsURL, {
...opts,
signal: abortController.signal,
async onopen(response) {
if (response.status === 200) {
return;
}
if (debug) {
console.debug(response);
}
let error;
try {
const body = await response.text();
error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
error.json = JSON.parse(body);
} catch {
error = error || new Error(`Failed to send message. HTTP ${response.status}`);
}
throw error;
},
onclose() {
if (debug) {
console.debug('Server closed the connection unexpectedly, returning...');
}
// workaround for private API not sending [DONE] event
if (!done) {
onProgress('[DONE]');
resolve();
}
},
onerror(err) {
if (debug) {
console.debug(err);
}
// rethrow to stop the operation
throw err;
},
onmessage(message) {
if (debug) {
console.debug(message);
}
if (!message.data || message.event === 'ping') {
return;
}
if (message.data === '[DONE]') {
onProgress('[DONE]');
resolve();
done = true;
return;
}
onProgress(JSON.parse(message.data));
},
});
} catch (err) {
reject(err);
}
});
}
const response = await fetch(completionsURL, {
...opts,
signal: abortController.signal,
});
if (response.status !== 200) {
const body = await response.text();
const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
try {
error.json = JSON.parse(body);
} catch {
error.body = body;
}
throw error;
}
return response.json();
}
/** @type {cohereChatCompletion} */
async cohereChatCompletion({ payload, onTokenProgress }) {
const cohere = new CohereClient({
token: this.apiKey,
environment: this.completionsUrl,
});
if (!payload.stream) {
const chatResponse = await cohere.chat(payload);
return chatResponse.text;
}
const chatStream = await cohere.chatStream(payload);
let reply = '';
for await (const message of chatStream) {
if (!message) {
continue;
}
if (message.eventType === 'text-generation' && message.text) {
onTokenProgress(message.text);
reply += message.text;
}
/*
Cohere API Chinese Unicode character replacement hotfix.
Should be un-commented when the following issue is resolved:
https://github.com/cohere-ai/cohere-typescript/issues/151
else if (message.eventType === 'stream-end' && message.response) {
reply = message.response.text;
}
*/
}
return reply;
}
async generateTitle(userMessage, botMessage) {
const instructionsPayload = {
role: 'system',
content: `Write an extremely concise subtitle for this conversation with no more than a few words. All words should be capitalized. Exclude punctuation.
||>Message:
${userMessage.message}
||>Response:
${botMessage.message}
||>Title:`,
};
const titleGenClientOptions = JSON.parse(JSON.stringify(this.options));
titleGenClientOptions.modelOptions = {
model: 'gpt-3.5-turbo',
temperature: 0,
presence_penalty: 0,
frequency_penalty: 0,
};
const titleGenClient = new ChatGPTClient(this.apiKey, titleGenClientOptions);
const result = await titleGenClient.getCompletion([instructionsPayload], null);
// remove any non-alphanumeric characters, replace multiple spaces with 1, and then trim
return result.choices[0].message.content
.replace(/[^a-zA-Z0-9' ]/g, '')
.replace(/\s+/g, ' ')
.trim();
}
async sendMessage(message, opts = {}) {
if (opts.clientOptions && typeof opts.clientOptions === 'object') {
this.setOptions(opts.clientOptions);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || crypto.randomUUID();
let conversation =
typeof opts.conversation === 'object'
? opts.conversation
: await this.conversationsCache.get(conversationId);
let isNewConversation = false;
if (!conversation) {
conversation = {
messages: [],
createdAt: Date.now(),
};
isNewConversation = true;
}
const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation;
const userMessage = {
id: crypto.randomUUID(),
parentMessageId,
role: 'User',
message,
};
conversation.messages.push(userMessage);
// Doing it this way instead of having each message be a separate element in the array seems to be more reliable,
// especially when it comes to keeping the AI in character. It also seems to improve coherency and context retention.
const { prompt: payload, context } = await this.buildPrompt(
conversation.messages,
userMessage.id,
{
isChatGptModel: this.isChatGptModel,
promptPrefix: opts.promptPrefix,
},
);
if (this.options.keepNecessaryMessagesOnly) {
conversation.messages = context;
}
let reply = '';
let result = null;
if (typeof opts.onProgress === 'function') {
await this.getCompletion(
payload,
(progressMessage) => {
if (progressMessage === '[DONE]') {
return;
}
const token = this.isChatGptModel
? progressMessage.choices[0].delta.content
: progressMessage.choices[0].text;
// first event's delta content is always undefined
if (!token) {
return;
}
if (this.options.debug) {
console.debug(token);
}
if (token === this.endToken) {
return;
}
opts.onProgress(token);
reply += token;
},
opts.abortController || new AbortController(),
);
} else {
result = await this.getCompletion(
payload,
null,
opts.abortController || new AbortController(),
);
if (this.options.debug) {
console.debug(JSON.stringify(result));
}
if (this.isChatGptModel) {
reply = result.choices[0].message.content;
} else {
reply = result.choices[0].text.replace(this.endToken, '');
}
}
// avoids some rendering issues when using the CLI app
if (this.options.debug) {
console.debug();
}
reply = reply.trim();
const replyMessage = {
id: crypto.randomUUID(),
parentMessageId: userMessage.id,
role: 'ChatGPT',
message: reply,
};
conversation.messages.push(replyMessage);
const returnData = {
response: replyMessage.message,
conversationId,
parentMessageId: replyMessage.parentMessageId,
messageId: replyMessage.id,
details: result || {},
};
if (shouldGenerateTitle) {
conversation.title = await this.generateTitle(userMessage, replyMessage);
returnData.title = conversation.title;
}
await this.conversationsCache.set(conversationId, conversation);
if (this.options.returnConversation) {
returnData.conversation = conversation;
}
return returnData;
}
async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) {
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
// Handle attachments and create augmentedPrompt
if (this.options.attachments) {
const attachments = await this.options.attachments;
const lastMessage = messages[messages.length - 1];
if (this.message_file_map) {
this.message_file_map[lastMessage.messageId] = attachments;
} else {
this.message_file_map = {
[lastMessage.messageId]: attachments,
};
}
const files = await this.addImageURLs(lastMessage, attachments);
this.options.attachments = files;
this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text);
}
if (this.message_file_map) {
this.contextHandlers = createContextHandlers(
this.options.req,
messages[messages.length - 1].text,
);
}
// Calculate image token cost and process embedded files
messages.forEach((message, i) => {
if (this.message_file_map && this.message_file_map[message.messageId]) {
const attachments = this.message_file_map[message.messageId];
for (const file of attachments) {
if (file.embedded) {
this.contextHandlers?.processFile(file);
continue;
}
messages[i].tokenCount =
(messages[i].tokenCount || 0) +
this.calculateImageTokenCost({
width: file.width,
height: file.height,
detail: this.options.imageDetail ?? ImageDetail.auto,
});
}
}
});
if (this.contextHandlers) {
this.augmentedPrompt = await this.contextHandlers.createContext();
promptPrefix = this.augmentedPrompt + promptPrefix;
}
if (promptPrefix) {
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
}
const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond.
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
let currentTokenCount;
if (isChatGptModel) {
currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
} else {
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
}
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
const context = [];
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && messages.length > 0) {
const message = messages.pop();
const roleLabel =
message?.isCreatedByUser || message?.role?.toLowerCase() === 'user'
? this.userLabel
: this.chatGptLabel;
const messageString = `${this.startToken}${roleLabel}:\n${
message?.text ?? message?.message
}${this.endToken}\n`;
let newPromptBody;
if (promptBody || isChatGptModel) {
newPromptBody = `${messageString}${promptBody}`;
} else {
// Always insert prompt prefix before the last user message, if not gpt-3.5-turbo.
// This makes the AI obey the prompt instructions better, which is important for custom instructions.
// After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things
// like "what's the last thing I wrote?".
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
}
context.unshift(message);
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setImmediate(resolve));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = `${promptBody}${promptSuffix}`;
if (isChatGptModel) {
messagePayload.content = prompt;
// Add 3 tokens for Assistant Label priming after all messages have been counted.
currentTokenCount += 3;
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (isChatGptModel) {
return { prompt: [instructionsPayload, messagePayload], context };
}
return { prompt, context, promptTokens: currentTokenCount };
}
getTokenCount(text) {
return this.gptEncoder.encode(text, 'all').length;
}
/**
* Algorithm adapted from "6. Counting tokens for chat API calls" of
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
*
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
*
* @param {Object} message
*/
getTokenCountForMessage(message) {
// Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
let tokensPerMessage = 3;
let tokensPerName = 1;
if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
tokensPerMessage = 4;
tokensPerName = -1;
}
let numTokens = tokensPerMessage;
for (let [key, value] of Object.entries(message)) {
numTokens += this.getTokenCount(value);
if (key === 'name') {
numTokens += tokensPerName;
}
}
return numTokens;
}
}
module.exports = ChatGPTClient;

View File

@@ -1,7 +1,7 @@
const { google } = require('googleapis');
const { Tokenizer } = require('@librechat/api');
const { concat } = require('@langchain/core/utils/stream');
const { ChatVertexAI } = require('@langchain/google-vertexai');
const { Tokenizer, getSafetySettings } = require('@librechat/api');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
@@ -12,13 +12,13 @@ const {
endpointSettings,
parseTextParts,
EModelEndpoint,
googleSettings,
ContentTypes,
VisionModes,
ErrorTypes,
Constants,
AuthKeys,
} = require('librechat-data-provider');
const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { spendTokens } = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
@@ -166,16 +166,6 @@ class GoogleClient extends BaseClient {
);
}
// Add thinking configuration
this.modelOptions.thinkingConfig = {
thinkingBudget:
(this.modelOptions.thinking ?? googleSettings.thinking.default)
? this.modelOptions.thinkingBudget
: 0,
};
delete this.modelOptions.thinking;
delete this.modelOptions.thinkingBudget;
this.sender =
this.options.sender ??
getResponseSender({

View File

@@ -5,7 +5,6 @@ const {
isEnabled,
Tokenizer,
createFetch,
resolveHeaders,
constructAzureURL,
genAzureChatCompletion,
createStreamEventHandlers,
@@ -16,6 +15,7 @@ const {
ContentTypes,
parseTextParts,
EModelEndpoint,
resolveHeaders,
KnownEndpoints,
openAISettings,
ImageDetailCost,
@@ -37,6 +37,7 @@ const { addSpaceIfNeeded, sleep } = require('~/server/utils');
const { spendTokens } = require('~/models/spendTokens');
const { handleOpenAIErrors } = require('./tools/util');
const { createLLM, RunManager } = require('./llm');
const ChatGPTClient = require('./ChatGPTClient');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document');
@@ -46,6 +47,12 @@ const { logger } = require('~/config');
class OpenAIClient extends BaseClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
this.ChatGPTClient = new ChatGPTClient();
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
/** @type {getCompletion} */
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
/** @type {cohereChatCompletion} */
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
this.contextStrategy = options.contextStrategy
? options.contextStrategy.toLowerCase()
: 'discard';
@@ -372,12 +379,23 @@ class OpenAIClient extends BaseClient {
return files;
}
async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) {
async buildMessages(
messages,
parentMessageId,
{ isChatCompletion = false, promptPrefix = null },
opts,
) {
let orderedMessages = this.constructor.getMessagesForConversation({
messages,
parentMessageId,
summary: this.shouldSummarize,
});
if (!isChatCompletion) {
return await this.buildPrompt(orderedMessages, {
isChatGptModel: isChatCompletion,
promptPrefix,
});
}
let payload;
let instructions;
@@ -1141,7 +1159,6 @@ ${convo}
logger.debug('[OpenAIClient] chatCompletion', { baseURL, modelOptions });
const opts = {
baseURL,
fetchOptions: {},
};
if (this.useOpenRouter) {
@@ -1160,7 +1177,7 @@ ${convo}
}
if (this.options.proxy) {
opts.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy);
opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
}
/** @type {TAzureConfig | undefined} */
@@ -1378,7 +1395,7 @@ ${convo}
...modelOptions,
stream: true,
};
const stream = await openai.chat.completions
const stream = await openai.beta.chat.completions
.stream(params)
.on('abort', () => {
/* Do nothing here */

View File

@@ -0,0 +1,542 @@
const OpenAIClient = require('./OpenAIClient');
const { CallbackManager } = require('@langchain/core/callbacks/manager');
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
const { processFileURL } = require('~/server/services/Files/process');
const { EModelEndpoint } = require('librechat-data-provider');
const { checkBalance } = require('~/models/balanceMethods');
const { formatLangChainMessages } = require('./prompts');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
const { logger } = require('~/config');
class PluginsClient extends OpenAIClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
this.sender = options.sender ?? 'Assistant';
this.tools = [];
this.actions = [];
this.setOptions(options);
this.openAIApiKey = this.apiKey;
this.executor = null;
}
setOptions(options) {
this.agentOptions = { ...options.agentOptions };
this.functionsAgent = this.agentOptions?.agent === 'functions';
this.agentIsGpt3 = this.agentOptions?.model?.includes('gpt-3');
super.setOptions(options);
this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');
if (this.options.reverseProxyUrl) {
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
}
}
getSaveOptions() {
return {
artifacts: this.options.artifacts,
chatGptLabel: this.options.chatGptLabel,
modelLabel: this.options.modelLabel,
promptPrefix: this.options.promptPrefix,
tools: this.options.tools,
...this.modelOptions,
agentOptions: this.agentOptions,
iconURL: this.options.iconURL,
greeting: this.options.greeting,
spec: this.options.spec,
};
}
saveLatestAction(action) {
this.actions.push(action);
}
getFunctionModelName(input) {
if (/-(?!0314)\d{4}/.test(input)) {
return input;
} else if (input.includes('gpt-3.5-turbo')) {
return 'gpt-3.5-turbo';
} else if (input.includes('gpt-4')) {
return 'gpt-4';
} else {
return 'gpt-3.5-turbo';
}
}
getBuildMessagesOptions(opts) {
return {
isChatCompletion: true,
promptPrefix: opts.promptPrefix,
abortController: opts.abortController,
};
}
async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
const modelOptions = {
modelName: this.agentOptions.model,
temperature: this.agentOptions.temperature,
};
const model = this.initializeLLM({
...modelOptions,
context: 'plugins',
initialMessageCount: this.currentMessages.length + 1,
});
logger.debug(
`[PluginsClient] Agent Model: ${model.modelName} | Temp: ${model.temperature} | Functions: ${this.functionsAgent}`,
);
// Map Messages to Langchain format
const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), {
userName: this.options?.name,
});
logger.debug('[PluginsClient] pastMessages: ' + pastMessages.length);
// TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS)
const memory = new BufferMemory({
llm: model,
chatHistory: new ChatMessageHistory(pastMessages),
});
const { loadedTools } = await loadTools({
user,
model,
tools: this.options.tools,
functions: this.functionsAgent,
options: {
memory,
signal: this.abortController.signal,
openAIApiKey: this.openAIApiKey,
conversationId: this.conversationId,
fileStrategy: this.options.req.app.locals.fileStrategy,
processFileURL,
message,
},
useSpecs: true,
});
if (loadedTools.length === 0) {
return;
}
this.tools = loadedTools;
logger.debug('[PluginsClient] Requested Tools', this.options.tools);
logger.debug(
'[PluginsClient] Loaded Tools',
this.tools.map((tool) => tool.name),
);
const handleAction = (action, runId, callback = null) => {
this.saveLatestAction(action);
logger.debug('[PluginsClient] Latest Agent Action ', this.actions[this.actions.length - 1]);
if (typeof callback === 'function') {
callback(action, runId);
}
};
// initialize agent
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
let customInstructions = (this.options.promptPrefix ?? '').trim();
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
customInstructions = `${customInstructions ?? ''}\n${this.options.artifactsPrompt}`.trim();
}
this.executor = await initializer({
model,
signal,
pastMessages,
tools: this.tools,
customInstructions,
verbose: this.options.debug,
returnIntermediateSteps: true,
customName: this.options.chatGptLabel,
currentDateString: this.currentDateString,
callbackManager: CallbackManager.fromHandlers({
async handleAgentAction(action, runId) {
handleAction(action, runId, onAgentAction);
},
async handleChainEnd(action) {
if (typeof onChainEnd === 'function') {
onChainEnd(action);
}
},
}),
});
logger.debug('[PluginsClient] Loaded agent.');
}
async executorCall(message, { signal, stream, onToolStart, onToolEnd }) {
let errorMessage = '';
const maxAttempts = 1;
for (let attempts = 1; attempts <= maxAttempts; attempts++) {
const errorInput = buildErrorInput({
message,
errorMessage,
actions: this.actions,
functionsAgent: this.functionsAgent,
});
const input = attempts > 1 ? errorInput : message;
logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`);
if (errorMessage.length > 0) {
logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));
}
try {
this.result = await this.executor.call({ input, signal }, [
{
async handleToolStart(...args) {
await onToolStart(...args);
},
async handleToolEnd(...args) {
await onToolEnd(...args);
},
async handleLLMEnd(output) {
const { generations } = output;
const { text } = generations[0][0];
if (text && typeof stream === 'function') {
await stream(text);
}
},
},
]);
break; // Exit the loop if the function call is successful
} catch (err) {
logger.error('[PluginsClient] executorCall error:', err);
if (attempts === maxAttempts) {
const { run } = this.runManager.getRunByConversationId(this.conversationId);
const defaultOutput = `Encountered an error while attempting to respond: ${err.message}`;
this.result.output = run && run.error ? run.error : defaultOutput;
this.result.errorMessage = run && run.error ? run.error : err.message;
this.result.intermediateSteps = this.actions;
break;
}
}
}
}
/**
*
* @param {TMessage} responseMessage
* @param {Partial<TMessage>} saveOptions
* @param {string} user
* @returns
*/
async handleResponseMessage(responseMessage, saveOptions, user) {
const { output, errorMessage, ...result } = this.result;
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
output,
errorMessage,
...result,
});
const { error } = responseMessage;
if (!error) {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
responseMessage.completionTokens = this.getTokenCount(responseMessage.text);
}
// Record usage only when completion is skipped as it is already recorded in the agent phase.
if (!this.agentOptions.skipCompletion && !error) {
await this.recordTokenUsage(responseMessage);
}
const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return { ...responseMessage, ...result, databasePromise };
}
async sendMessage(message, opts = {}) {
/** @type {Promise<TMessage>} */
let userMessagePromise;
/** @type {{ filteredTools: string[], includedTools: string[] }} */
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
if (includedTools.length > 0) {
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
this.options.tools = tools;
} else {
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
this.options.tools = tools;
}
// If a message is edited, no tools can be used.
const completionMode = this.options.tools.length === 0 || opts.isEdited;
if (completionMode) {
this.setOptions(opts);
return super.sendMessage(message, opts);
}
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
const {
user,
conversationId,
responseMessageId,
saveOptions,
userMessage,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
} = await this.handleStartMethods(message, opts);
if (opts.progressCallback) {
opts.onProgress = opts.progressCallback.call(null, {
...(opts.progressOptions ?? {}),
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
}
this.currentMessages.push(userMessage);
let {
prompt: payload,
tokenCountMap,
promptTokens,
} = await this.buildMessages(
this.currentMessages,
userMessage.messageId,
this.getBuildMessagesOptions({
promptPrefix: null,
abortController: this.abortController,
}),
);
if (tokenCountMap) {
logger.debug('[PluginsClient] tokenCountMap', { tokenCountMap });
if (tokenCountMap[userMessage.messageId]) {
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
logger.debug('[PluginsClient] userMessage.tokenCount', userMessage.tokenCount);
}
this.handleTokenCountMap(tokenCountMap);
}
this.result = {};
if (payload) {
this.currentMessages = payload;
}
if (!this.skipSaveUserMessage) {
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise,
});
}
}
const balance = this.options.req?.app?.locals?.balance;
if (balance?.enabled) {
await checkBalance({
req: this.options.req,
res: this.options.res,
txData: {
user: this.user,
tokenType: 'prompt',
amount: promptTokens,
debug: this.options.debug,
model: this.modelOptions.model,
endpoint: EModelEndpoint.openAI,
},
});
}
const responseMessage = {
endpoint: EModelEndpoint.gptPlugins,
iconURL: this.options.iconURL,
messageId: responseMessageId,
conversationId,
parentMessageId: userMessage.messageId,
isCreatedByUser: false,
model: this.modelOptions.model,
sender: this.sender,
promptTokens,
};
await this.initialize({
user,
message,
onAgentAction,
onChainEnd,
signal: this.abortController.signal,
onProgress: opts.onProgress,
});
// const stream = async (text) => {
// await this.generateTextStream.call(this, text, opts.onProgress, { delay: 1 });
// };
await this.executorCall(message, {
signal: this.abortController.signal,
// stream,
onToolStart,
onToolEnd,
});
// If message was aborted mid-generation
if (this.result?.errorMessage?.length > 0 && this.result?.errorMessage?.includes('cancel')) {
responseMessage.text = 'Cancelled.';
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
// If error occurred during generation (likely token_balance)
if (this.result?.errorMessage?.length > 0) {
responseMessage.error = true;
responseMessage.text = this.result.output;
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) {
const partialText = opts.getPartialText();
const trimmedPartial = opts.getPartialText().replaceAll(':::plugin:::\n', '');
responseMessage.text =
trimmedPartial.length === 0 ? `${partialText}${this.result.output}` : partialText;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output) {
responseMessage.text = this.result.output;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
logger.debug('[PluginsClient] Completion phase: this.result', this.result);
const promptPrefix = buildPromptPrefix({
result: this.result,
message,
functionsAgent: this.functionsAgent,
});
logger.debug('[PluginsClient]', { promptPrefix });
payload = await this.buildCompletionPrompt({
messages: this.currentMessages,
promptPrefix,
});
logger.debug('[PluginsClient] buildCompletionPrompt Payload', payload);
responseMessage.text = await this.sendCompletion(payload, opts);
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
async buildCompletionPrompt({ messages, promptPrefix: _promptPrefix }) {
logger.debug('[PluginsClient] buildCompletionPrompt messages', messages);
const orderedMessages = messages;
let promptPrefix = _promptPrefix.trim();
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
const promptSuffix = `${this.startToken}${this.chatGptLabel ?? 'Assistant'}:\n`;
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
if (this.isGpt3) {
instructionsPayload.role = 'user';
messagePayload.role = 'user';
instructionsPayload.content += `\n${promptSuffix}`;
}
// testing if this works with browser endpoint
if (!this.isGpt3 && this.options.reverseProxyUrl) {
instructionsPayload.role = 'user';
}
let currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
const message = orderedMessages.pop();
const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
let messageString = `${this.startToken}${roleLabel}:\n${
message.text ?? message.content ?? ''
}${this.endToken}\n`;
let newPromptBody = `${messageString}${promptBody}`;
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setTimeout(resolve, 0));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = promptBody;
messagePayload.content = prompt;
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;
if (this.isGpt3 && messagePayload.content.length > 0) {
const context = 'Chat History:\n';
messagePayload.content = `${context}${prompt}`;
currentTokenCount += this.getTokenCount(context);
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (this.isGpt3) {
messagePayload.content += promptSuffix;
return [instructionsPayload, messagePayload];
}
const result = [messagePayload, instructionsPayload];
if (this.functionsAgent && !this.isGpt3) {
result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`;
}
return result.filter((message) => message.content.length > 0);
}
}
module.exports = PluginsClient;

View File

@@ -1,11 +1,15 @@
const ChatGPTClient = require('./ChatGPTClient');
const OpenAIClient = require('./OpenAIClient');
const PluginsClient = require('./PluginsClient');
const GoogleClient = require('./GoogleClient');
const TextStream = require('./TextStream');
const AnthropicClient = require('./AnthropicClient');
const toolUtils = require('./tools/util');
module.exports = {
ChatGPTClient,
OpenAIClient,
PluginsClient,
GoogleClient,
TextStream,
AnthropicClient,

View File

@@ -1,7 +1,6 @@
const axios = require('axios');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const footer = `Use the context as your learned knowledge to better answer the user.
@@ -19,7 +18,7 @@ function createContextHandlers(req, userMessageContent) {
const queryPromises = [];
const processedFiles = [];
const processedIds = new Set();
const jwtToken = generateShortLivedToken(req.user.id);
const jwtToken = req.headers.authorization.split(' ')[1];
const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);
const query = async (file) => {
@@ -97,35 +96,35 @@ function createContextHandlers(req, userMessageContent) {
resolvedQueries.length === 0
? '\n\tThe semantic search did not return any results.'
: resolvedQueries
.map((queryResult, index) => {
const file = processedFiles[index];
let contextItems = queryResult.data;
.map((queryResult, index) => {
const file = processedFiles[index];
let contextItems = queryResult.data;
const generateContext = (currentContext) =>
`
const generateContext = (currentContext) =>
`
<file>
<filename>${file.filename}</filename>
<context>${currentContext}
</context>
</file>`;
if (useFullContext) {
return generateContext(`\n${contextItems}`);
}
if (useFullContext) {
return generateContext(`\n${contextItems}`);
}
contextItems = queryResult.data
.map((item) => {
const pageContent = item[0].page_content;
return `
contextItems = queryResult.data
.map((item) => {
const pageContent = item[0].page_content;
return `
<contextItem>
<![CDATA[${pageContent?.trim()}]]>
</contextItem>`;
})
.join('');
})
.join('');
return generateContext(contextItems);
})
.join('');
return generateContext(contextItems);
})
.join('');
if (useFullContext) {
const prompt = `${header}

View File

@@ -309,7 +309,7 @@ describe('AnthropicClient', () => {
};
client.setOptions({ modelOptions, promptCache: true });
const anthropicClient = client.getClient(modelOptions);
expect(anthropicClient._options.defaultHeaders).toBeUndefined();
expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta');
});
it('should not add beta header for other models', () => {
@@ -320,7 +320,7 @@ describe('AnthropicClient', () => {
},
});
const anthropicClient = client.getClient();
expect(anthropicClient._options.defaultHeaders).toBeUndefined();
expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta');
});
});

View File

@@ -531,6 +531,44 @@ describe('OpenAIClient', () => {
});
});
describe('sendMessage/getCompletion/chatCompletion', () => {
afterEach(() => {
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
});
it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => {
const model = 'text-davinci-003';
const onProgress = jest.fn().mockImplementation(() => ({}));
const testClient = new OpenAIClient('test-api-key', {
...defaultOptions,
modelOptions: { model },
});
const getCompletion = jest.spyOn(testClient, 'getCompletion');
await testClient.sendMessage('Hi mom!', { onProgress });
expect(getCompletion).toHaveBeenCalled();
expect(getCompletion.mock.calls.length).toBe(1);
expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n');
expect(fetchEventSource).toHaveBeenCalled();
expect(fetchEventSource.mock.calls.length).toBe(1);
// Check if the first argument (url) is correct
const firstCallArgs = fetchEventSource.mock.calls[0];
const expectedURL = 'https://api.openai.com/v1/completions';
expect(firstCallArgs[0]).toBe(expectedURL);
const requestBody = JSON.parse(firstCallArgs[1].body);
expect(requestBody).toHaveProperty('model');
expect(requestBody.model).toBe(model);
});
});
describe('checkVisionRequest functionality', () => {
let client;
const attachments = [{ type: 'image/png' }];

View File

@@ -0,0 +1,314 @@
const crypto = require('crypto');
const { Constants } = require('librechat-data-provider');
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
const PluginsClient = require('../PluginsClient');
jest.mock('~/db/connect');
jest.mock('~/models/Conversation', () => {
return function () {
return {
save: jest.fn(),
deleteConvos: jest.fn(),
};
};
});
const defaultAzureOptions = {
azureOpenAIApiInstanceName: 'your-instance-name',
azureOpenAIApiDeploymentName: 'your-deployment-name',
azureOpenAIApiVersion: '2020-07-01-preview',
};
describe('PluginsClient', () => {
let TestAgent;
let options = {
tools: [],
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
};
let parentMessageId;
let conversationId;
const fakeMessages = [];
const userMessage = 'Hello, ChatGPT!';
const apiKey = 'fake-api-key';
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, options);
TestAgent.loadHistory = jest
.fn()
.mockImplementation((conversationId, parentMessageId = null) => {
if (!conversationId) {
TestAgent.currentMessages = [];
return Promise.resolve([]);
}
const orderedMessages = TestAgent.constructor.getMessagesForConversation({
messages: fakeMessages,
parentMessageId,
});
const chatMessages = orderedMessages.map((msg) =>
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
? new HumanMessage(msg.text)
: new AIMessage(msg.text),
);
TestAgent.currentMessages = orderedMessages;
return Promise.resolve(chatMessages);
});
TestAgent.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => {
if (opts && typeof opts === 'object') {
TestAgent.setOptions(opts);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || Constants.NO_PARENT;
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
this.pastMessages = await TestAgent.loadHistory(
conversationId,
TestAgent.options?.parentMessageId,
);
const userMessage = {
text: message,
sender: 'ChatGPT',
isCreatedByUser: true,
messageId: userMessageId,
parentMessageId,
conversationId,
};
const response = {
sender: 'ChatGPT',
text: 'Hello, User!',
isCreatedByUser: false,
messageId: crypto.randomUUID(),
parentMessageId: userMessage.messageId,
conversationId,
};
fakeMessages.push(userMessage);
fakeMessages.push(response);
return response;
});
});
test('initializes PluginsClient without crashing', () => {
expect(TestAgent).toBeInstanceOf(PluginsClient);
});
test('check setOptions function', () => {
expect(TestAgent.agentIsGpt3).toBe(true);
});
describe('sendMessage', () => {
test('sendMessage should return a response message', async () => {
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: expect.any(String),
});
const response = await TestAgent.sendMessage(userMessage);
parentMessageId = response.messageId;
conversationId = response.conversationId;
expect(response).toEqual(expectedResult);
});
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
const userMessage = 'Second message in the conversation';
const opts = {
conversationId,
parentMessageId,
};
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: opts.conversationId,
});
const response = await TestAgent.sendMessage(userMessage, opts);
parentMessageId = response.messageId;
expect(response.conversationId).toEqual(conversationId);
expect(response).toEqual(expectedResult);
});
test('should return chat history', async () => {
const chatMessages = await TestAgent.loadHistory(conversationId, parentMessageId);
expect(TestAgent.currentMessages).toHaveLength(4);
expect(chatMessages[0].text).toEqual(userMessage);
});
});
describe('getFunctionModelName', () => {
let client;
beforeEach(() => {
client = new PluginsClient('dummy_api_key');
});
test('should return the input when it includes a dash followed by four digits', () => {
expect(client.getFunctionModelName('-1234')).toBe('-1234');
expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview');
});
test('should return the input for all function-capable models (`0613` models and above)', () => {
expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613');
expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview');
expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106');
});
test('should return the corresponding model if input is non-function capable (`0314` models)', () => {
expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => {
expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-4" when the input includes "gpt-4"', () => {
expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4');
});
test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => {
expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
});
});
describe('Azure OpenAI tests specific to Plugins', () => {
// TODO: add more tests for Azure OpenAI integration with Plugins
// let client;
// beforeEach(() => {
// client = new PluginsClient('dummy_api_key');
// });
test('should not call getFunctionModelName when azure options are set', () => {
const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName');
const model = 'gpt-4-turbo';
// note, without the azure change in PR #1766, `getFunctionModelName` is called twice
const testClient = new PluginsClient('dummy_api_key', {
agentOptions: {
model,
agent: 'functions',
},
azure: defaultAzureOptions,
});
expect(spy).not.toHaveBeenCalled();
expect(testClient.agentOptions.model).toBe(model);
spy.mockRestore();
});
});
describe('sendMessage with filtered tools', () => {
let TestAgent;
const apiKey = 'fake-api-key';
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, {
tools: mockTools,
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
});
TestAgent.options.req = {
app: {
locals: {},
},
};
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
if (includedTools.length > 0) {
const tools = TestAgent.options.tools.filter((plugin) =>
includedTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
} else {
const tools = TestAgent.options.tools.filter(
(plugin) => !filteredTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
}
return {
text: 'Mocked response',
tools: TestAgent.options.tools,
};
});
});
test('should filter out tools when filteredTools is provided', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should only include specified tools when includedTools is provided', async () => {
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should prioritize includedTools over filteredTools', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool1' }),
expect.objectContaining({ name: 'tool2' }),
]),
);
});
test('should not modify tools when no filters are provided', async () => {
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(4);
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
});
});
});

View File

@@ -107,12 +107,6 @@ const getImageEditPromptDescription = () => {
return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
};
function createAbortHandler() {
return function () {
logger.debug('[ImageGenOAI] Image generation aborted');
};
}
/**
* Creates OpenAI Image tools (generation and editing)
* @param {Object} fields - Configuration fields
@@ -207,18 +201,10 @@ function createOpenAIImageTools(fields = {}) {
}
let resp;
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try {
if (runnableConfig?.signal) {
derivedSignal = AbortSignal.any([runnableConfig.signal]);
abortHandler = createAbortHandler();
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
const derivedSignal = runnableConfig?.signal
? AbortSignal.any([runnableConfig.signal])
: undefined;
resp = await openai.images.generate(
{
model: 'gpt-image-1',
@@ -242,10 +228,6 @@ function createOpenAIImageTools(fields = {}) {
logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable:
Error Message: ${error.message}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
}
if (!resp) {
@@ -427,17 +409,10 @@ Error Message: ${error.message}`);
headers['Authorization'] = `Bearer ${apiKey}`;
}
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try {
if (runnableConfig?.signal) {
derivedSignal = AbortSignal.any([runnableConfig.signal]);
abortHandler = createAbortHandler();
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
const derivedSignal = runnableConfig?.signal
? AbortSignal.any([runnableConfig.signal])
: undefined;
/** @type {import('axios').AxiosRequestConfig} */
const axiosConfig = {
@@ -492,10 +467,6 @@ Error Message: ${error.message}`);
logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable:
Error Message: ${error.message || 'Unknown error'}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
}
},
{

View File

@@ -1,35 +1,26 @@
const { z } = require('zod');
const axios = require('axios');
const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Tools, EToolResources } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
/**
*
* @param {Object} options
* @param {ServerRequest} options.req
* @param {Agent['tool_resources']} options.tool_resources
* @param {string} [options.agentId] - The agent ID for file access control
* @returns {Promise<{
* files: Array<{ file_id: string; filename: string }>,
* toolContext: string
* }>}
*/
const primeFiles = async (options) => {
const { tool_resources, req, agentId } = options;
const { tool_resources } = options;
const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
const agentResourceIds = new Set(file_ids);
const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
const dbFiles = (
(await getFiles(
{ file_id: { $in: file_ids } },
null,
{ text: 0 },
{ userId: req?.user?.id, agentId },
)) ?? []
).concat(resourceFiles);
const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;
@@ -68,7 +59,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
if (files.length === 0) {
return 'No files to search. Instruct the user to add files for the search.';
}
const jwtToken = generateShortLivedToken(req.user.id);
const jwtToken = req.headers.authorization.split(' ')[1];
if (!jwtToken) {
return 'There was an error authenticating the file search request.';
}

View File

@@ -1,14 +1,14 @@
const { mcpToolPattern } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
const {
Tools,
Constants,
EToolResources,
loadWebSearchAuth,
replaceSpecialVars,
} = require('librechat-data-provider');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const {
availableTools,
manifestToolMap,
@@ -28,10 +28,11 @@ const {
} = require('../');
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { getCachedTools } = require('~/server/services/Config');
const { createMCPTool } = require('~/server/services/MCP');
const { logger } = require('~/config');
const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`);
/**
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
@@ -92,7 +93,7 @@ const validateTools = async (user, tools = []) => {
return Array.from(validToolsSet.values());
} catch (err) {
logger.error('[validateTools] There was a problem validating tools', err);
throw new Error(err);
throw new Error('There was a problem validating tools');
}
};
@@ -235,7 +236,7 @@ const loadTools = async ({
/** @type {Record<string, string>} */
const toolContextMap = {};
const appTools = (await getCachedTools({ includeGlobal: true })) ?? {};
const appTools = options.req?.app?.locals?.availableTools ?? {};
for (const tool of tools) {
if (tool === Tools.execute_code) {
@@ -245,13 +246,7 @@ const loadTools = async ({
authFields: [EnvVar.CODE_API_KEY],
});
const codeApiKey = authValues[EnvVar.CODE_API_KEY];
const { files, toolContext } = await primeCodeFiles(
{
...options,
agentId: agent?.id,
},
codeApiKey,
);
const { files, toolContext } = await primeCodeFiles(options, codeApiKey);
if (toolContext) {
toolContextMap[tool] = toolContext;
}
@@ -266,10 +261,7 @@ const loadTools = async ({
continue;
} else if (tool === Tools.file_search) {
requestedTools[tool] = async () => {
const { files, toolContext } = await primeSearchFiles({
...options,
agentId: agent?.id,
});
const { files, toolContext } = await primeSearchFiles(options);
if (toolContext) {
toolContextMap[tool] = toolContext;
}
@@ -307,7 +299,6 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
requestedTools[tool] = async () =>
createMCPTool({
req: options.req,
res: options.res,
toolKey: tool,
model: agent?.model ?? model,
provider: agent?.provider ?? endpoint,

View File

@@ -1,8 +1,7 @@
const { logger } = require('@librechat/data-schemas');
const { isEnabled, math } = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, math, removePorts } = require('~/server/utils');
const { deleteAllUserSessions } = require('~/models');
const { removePorts } = require('~/server/utils');
const getLogStores = require('./getLogStores');
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};

View File

@@ -1,7 +1,7 @@
const { Keyv } = require('keyv');
const { isEnabled, math } = require('@librechat/api');
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
const { logFile, violationFile } = require('./keyvFiles');
const { isEnabled, math } = require('~/server/utils');
const keyvRedis = require('./keyvRedis');
const keyvMongo = require('./keyvMongo');
@@ -29,10 +29,6 @@ const roles = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ROLES });
const mcpTools = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.MCP_TOOLS });
const audioRuns = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
@@ -71,7 +67,6 @@ const openIdExchangedTokensCache = isRedisEnabled
const namespaces = {
[CacheKeys.ROLES]: roles,
[CacheKeys.MCP_TOOLS]: mcpTools,
[CacheKeys.CONFIG_STORE]: config,
[CacheKeys.PENDING_REQ]: pending_req,
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),

View File

@@ -9,7 +9,7 @@ const banViolation = require('./banViolation');
* @param {Object} res - Express response object.
* @param {string} type - The type of violation.
* @param {Object} errorMessage - The error message to log.
* @param {number | string} [score=1] - The severity of the violation. Defaults to 1
* @param {number} [score=1] - The severity of the violation. Defaults to 1
*/
const logViolation = async (req, res, type, errorMessage, score = 1) => {
const userId = req.user?.id ?? req.user?._id;

View File

@@ -15,7 +15,7 @@ let flowManager = null;
*/
function getMCPManager(userId) {
if (!mcpManager) {
mcpManager = MCPManager.getInstance();
mcpManager = MCPManager.getInstance(logger);
} else {
mcpManager.checkIdleConnections(userId);
}
@@ -30,6 +30,7 @@ function getFlowStateManager(flowsCache) {
if (!flowManager) {
flowManager = new FlowStateManager(flowsCache, {
ttl: Time.ONE_MINUTE * 3,
logger,
});
}
return flowManager;

View File

@@ -1,6 +1,7 @@
require('dotenv').config();
const mongoose = require('mongoose');
const MONGO_URI = process.env.MONGO_URI;
const MONGO_DB_NAME = process.env.MONGO_DB_NAME;
if (!MONGO_URI) {
throw new Error('Please define the MONGO_URI environment variable');
@@ -33,6 +34,10 @@ async function connectDb() {
// useCreateIndex: true
};
if (MONGO_DB_NAME) {
opts.dbName = MONGO_DB_NAME;
}
mongoose.set('strictQuery', true);
cached.promise = mongoose.connect(MONGO_URI, opts).then((mongoose) => {
return mongoose;
@@ -43,6 +48,4 @@ async function connectDb() {
return cached.conn;
}
module.exports = {
connectDb,
};
module.exports = { connectDb };

102
api/db/connect.spec.js Normal file
View File

@@ -0,0 +1,102 @@
jest.mock('dotenv');
require('dotenv').config();
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
// Helper to clear the connect module from require cache
function clearConnectModule() {
delete require.cache[require.resolve('./connect')];
}
beforeEach(() => {
// Reset environment and global state
clearConnectModule();
delete process.env.MONGO_URI;
delete process.env.MONGO_DB_NAME;
global.mongoose = undefined;
});
describe('connectDb Module', () => {
it('throws if MONGO_URI is not defined', () => {
expect(() => require('./connect')).toThrow('Please define the MONGO_URI environment variable');
});
describe('with in-memory MongoDB', () => {
let mongoServer;
let uri;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
uri = mongoServer.getUri();
});
afterAll(async () => {
// Ensure real mongoose connection is torn down
await mongoose.disconnect();
await mongoServer.stop();
});
it('connects using only MONGO_URI (no dbName override)', async () => {
process.env.MONGO_URI = uri;
clearConnectModule();
const { connectDb } = require('./connect');
const mongooseInstance = await connectDb();
expect(mongooseInstance).toBe(mongoose);
// The default DB is whatever follows the last slash in the URI
const defaultDbNameMatch = uri.match(/\/([\w-]+)(\?.*)?$/);
const defaultDbName = defaultDbNameMatch ? defaultDbNameMatch[1] : mongoose.connection.name;
expect(mongooseInstance.connection.name).toBe(defaultDbName);
});
it('strips any DB name in URI when MONGO_DB_NAME is set', async () => {
process.env.MONGO_URI = uri;
process.env.MONGO_DB_NAME = 'ignoredDbName';
clearConnectModule();
const connectSpy = jest.spyOn(mongoose, 'connect');
const { connectDb } = require('./connect');
await connectDb();
const [calledUri, calledOpts] = connectSpy.mock.calls[0];
// Should have removed the trailing "/<dbname>"
expect(calledUri).not.toMatch(/\/[\w-]+$/);
// And since we strip it rather than passing dbName, opts.dbName remains unset
expect(calledOpts).not.toHaveProperty('dbName');
connectSpy.mockRestore();
});
it('caches the connection between calls', async () => {
process.env.MONGO_URI = uri;
clearConnectModule();
const { connectDb } = require('./connect');
const first = await connectDb();
// simulate still-connected
mongoose.connection.readyState = 1;
const second = await connectDb();
expect(second).toBe(first);
});
it('reconnects if cached conn is disconnected', async () => {
process.env.MONGO_URI = uri;
clearConnectModule();
const { connectDb } = require('./connect');
// First, establish a connection
await connectDb();
expect(mongoose.connection.readyState).toBe(1);
// Now actually tear it down
await mongoose.disconnect();
expect(mongoose.connection.readyState).toBe(0);
// Calling again should reconnect
await connectDb();
expect(mongoose.connection.readyState).toBe(1);
});
});
});

View File

@@ -1,11 +1,8 @@
const mongoose = require('mongoose');
const { MeiliSearch } = require('meilisearch');
const { logger } = require('@librechat/data-schemas');
const { FlowStateManager } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const { isEnabled } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const Conversation = mongoose.models.Conversation;
const Message = mongoose.models.Message;
@@ -31,123 +28,43 @@ class MeiliSearchClient {
}
}
/**
* Performs the actual sync operations for messages and conversations
*/
async function performSync() {
const client = MeiliSearchClient.getInstance();
const { status } = await client.health();
if (status !== 'available') {
throw new Error('Meilisearch not available');
}
if (indexingDisabled === true) {
logger.info('[indexSync] Indexing is disabled, skipping...');
return { messagesSync: false, convosSync: false };
}
let messagesSync = false;
let convosSync = false;
// Check if we need to sync messages
const messageProgress = await Message.getSyncProgress();
if (!messageProgress.isComplete) {
logger.info(
`[indexSync] Messages need syncing: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments} indexed`,
);
// Check if we should do a full sync or incremental
const messageCount = await Message.countDocuments();
const messagesIndexed = messageProgress.totalProcessed;
const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10);
if (messageCount - messagesIndexed > syncThreshold) {
logger.info('[indexSync] Starting full message sync due to large difference');
await Message.syncWithMeili();
messagesSync = true;
} else if (messageCount !== messagesIndexed) {
logger.warn('[indexSync] Messages out of sync, performing incremental sync');
await Message.syncWithMeili();
messagesSync = true;
}
} else {
logger.info(
`[indexSync] Messages are fully synced: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments}`,
);
}
// Check if we need to sync conversations
const convoProgress = await Conversation.getSyncProgress();
if (!convoProgress.isComplete) {
logger.info(
`[indexSync] Conversations need syncing: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments} indexed`,
);
const convoCount = await Conversation.countDocuments();
const convosIndexed = convoProgress.totalProcessed;
const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10);
if (convoCount - convosIndexed > syncThreshold) {
logger.info('[indexSync] Starting full conversation sync due to large difference');
await Conversation.syncWithMeili();
convosSync = true;
} else if (convoCount !== convosIndexed) {
logger.warn('[indexSync] Convos out of sync, performing incremental sync');
await Conversation.syncWithMeili();
convosSync = true;
}
} else {
logger.info(
`[indexSync] Conversations are fully synced: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments}`,
);
}
return { messagesSync, convosSync };
}
/**
* Main index sync function that uses FlowStateManager to prevent concurrent execution
*/
async function indexSync() {
if (!searchEnabled) {
return;
}
logger.info('[indexSync] Starting index synchronization check...');
try {
// Get or create FlowStateManager instance
const flowsCache = getLogStores(CacheKeys.FLOWS);
if (!flowsCache) {
logger.warn('[indexSync] Flows cache not available, falling back to direct sync');
return await performSync();
const client = MeiliSearchClient.getInstance();
const { status } = await client.health();
if (status !== 'available') {
throw new Error('Meilisearch not available');
}
const flowManager = new FlowStateManager(flowsCache, {
ttl: 60000 * 10, // 10 minutes TTL for sync operations
});
// Use a unique flow ID for the sync operation
const flowId = 'meili-index-sync';
const flowType = 'MEILI_SYNC';
// This will only execute the handler if no other instance is running the sync
const result = await flowManager.createFlowWithHandler(flowId, flowType, performSync);
if (result.messagesSync || result.convosSync) {
logger.info('[indexSync] Sync completed successfully');
} else {
logger.debug('[indexSync] No sync was needed');
}
return result;
} catch (err) {
if (err.message.includes('flow already exists')) {
logger.info('[indexSync] Sync already running on another instance');
if (indexingDisabled === true) {
logger.info('[indexSync] Indexing is disabled, skipping...');
return;
}
const messageCount = await Message.countDocuments();
const convoCount = await Conversation.countDocuments();
const messages = await client.index('messages').getStats();
const convos = await client.index('convos').getStats();
const messagesIndexed = messages.numberOfDocuments;
const convosIndexed = convos.numberOfDocuments;
logger.debug(`[indexSync] There are ${messageCount} messages and ${messagesIndexed} indexed`);
logger.debug(`[indexSync] There are ${convoCount} convos and ${convosIndexed} indexed`);
if (messageCount !== messagesIndexed) {
logger.debug('[indexSync] Messages out of sync, indexing');
Message.syncWithMeili();
}
if (convoCount !== convosIndexed) {
logger.debug('[indexSync] Convos out of sync, indexing');
Conversation.syncWithMeili();
}
} catch (err) {
if (err.message.includes('not found')) {
logger.debug('[indexSync] Creating indices...');
currentTimeout = setTimeout(async () => {

View File

@@ -11,7 +11,6 @@ const {
removeAgentIdsFromProject,
removeAgentFromAllProjects,
} = require('./Project');
const { getCachedTools } = require('~/server/services/Config');
const getLogStores = require('~/cache/getLogStores');
const { getActions } = require('./Action');
const { Agent } = require('~/db/models');
@@ -56,12 +55,12 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter)
* @param {string} params.agent_id
* @param {string} params.endpoint
* @param {import('@librechat/agents').ClientOptions} [params.model_parameters]
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
* @returns {Agent|null} The agent document as a plain object, or null if not found.
*/
const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => {
const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) => {
const { model, ...model_parameters } = _m;
/** @type {Record<string, FunctionTool>} */
const availableTools = await getCachedTools({ includeGlobal: true });
const availableTools = req.app.locals.availableTools;
/** @type {TEphemeralAgent | null} */
const ephemeralAgent = req.body.ephemeralAgent;
const mcpServers = new Set(ephemeralAgent?.mcp);
@@ -70,9 +69,6 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
if (ephemeralAgent?.execute_code === true) {
tools.push(Tools.execute_code);
}
if (ephemeralAgent?.file_search === true) {
tools.push(Tools.file_search);
}
if (ephemeralAgent?.web_search === true) {
tools.push(Tools.web_search);
}
@@ -90,7 +86,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
}
const instructions = req.body.promptPrefix;
const result = {
return {
id: agent_id,
instructions,
provider: endpoint,
@@ -98,11 +94,6 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
model,
tools,
};
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
result.artifacts = ephemeralAgent.artifacts;
}
return result;
};
/**
@@ -120,7 +111,7 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => {
return null;
}
if (agent_id === EPHEMERAL_AGENT_ID) {
return await loadEphemeralAgent({ req, agent_id, endpoint, model_parameters });
return loadEphemeralAgent({ req, agent_id, endpoint, model_parameters });
}
const agent = await getAgent({
id: agent_id,

View File

@@ -6,10 +6,6 @@ const originalEnv = {
process.env.CREDS_KEY = '0123456789abcdef0123456789abcdef';
process.env.CREDS_IV = '0123456789abcdef';
jest.mock('~/server/services/Config', () => ({
getCachedTools: jest.fn(),
}));
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { agentSchema } = require('@librechat/data-schemas');
@@ -27,7 +23,6 @@ const {
generateActionMetadataHash,
revertAgentVersion,
} = require('./Agent');
const { getCachedTools } = require('~/server/services/Config');
/**
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
@@ -43,7 +38,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -411,9 +406,8 @@ describe('models/Agent', () => {
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -670,7 +664,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -1332,7 +1326,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -1514,7 +1508,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -1552,12 +1546,6 @@ describe('models/Agent', () => {
test('should test ephemeral agent loading logic', async () => {
const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants;
getCachedTools.mockResolvedValue({
tool1_mcp_server1: {},
tool2_mcp_server2: {},
another_tool: {},
});
const mockReq = {
user: { id: 'user123' },
body: {
@@ -1568,6 +1556,15 @@ describe('models/Agent', () => {
mcp: ['server1', 'server2'],
},
},
app: {
locals: {
availableTools: {
tool1_mcp_server1: {},
tool2_mcp_server2: {},
another_tool: {},
},
},
},
};
const result = await loadAgent({
@@ -1660,8 +1657,6 @@ describe('models/Agent', () => {
test('should handle ephemeral agent with no MCP servers', async () => {
const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants;
getCachedTools.mockResolvedValue({});
const mockReq = {
user: { id: 'user123' },
body: {
@@ -1672,6 +1667,11 @@ describe('models/Agent', () => {
mcp: [],
},
},
app: {
locals: {
availableTools: {},
},
},
};
const result = await loadAgent({
@@ -1692,13 +1692,16 @@ describe('models/Agent', () => {
test('should handle ephemeral agent with undefined ephemeralAgent in body', async () => {
const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants;
getCachedTools.mockResolvedValue({});
const mockReq = {
user: { id: 'user123' },
body: {
promptPrefix: 'Basic instructions',
},
app: {
locals: {
availableTools: {},
},
},
};
const result = await loadAgent({
@@ -1731,13 +1734,6 @@ describe('models/Agent', () => {
const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants;
const largeToolList = Array.from({ length: 100 }, (_, i) => `tool_${i}_mcp_server1`);
const availableTools = largeToolList.reduce((acc, tool) => {
acc[tool] = {};
return acc;
}, {});
getCachedTools.mockResolvedValue(availableTools);
const mockReq = {
user: { id: 'user123' },
body: {
@@ -1748,6 +1744,14 @@ describe('models/Agent', () => {
mcp: ['server1'],
},
},
app: {
locals: {
availableTools: largeToolList.reduce((acc, tool) => {
acc[tool] = {};
return acc;
}, {}),
},
},
};
const result = await loadAgent({
@@ -1798,7 +1802,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();
@@ -2268,13 +2272,6 @@ describe('models/Agent', () => {
test('should handle loadEphemeralAgent with malformed MCP tool names', async () => {
const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants;
getCachedTools.mockResolvedValue({
malformed_tool_name: {}, // No mcp delimiter
tool__server1: {}, // Wrong delimiter
tool_mcp_server1: {}, // Correct format
tool_mcp_server2: {}, // Different server
});
const mockReq = {
user: { id: 'user123' },
body: {
@@ -2285,6 +2282,16 @@ describe('models/Agent', () => {
mcp: ['server1'],
},
},
app: {
locals: {
availableTools: {
malformed_tool_name: {}, // No mcp delimiter
tool__server1: {}, // Wrong delimiter
tool_mcp_server1: {}, // Correct format
tool_mcp_server2: {}, // Different server
},
},
},
};
const result = await loadAgent({
@@ -2350,7 +2357,7 @@ describe('models/Agent', () => {
const mongoUri = mongoServer.getUri();
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
await mongoose.connect(mongoUri);
}, 20000);
});
afterAll(async () => {
await mongoose.disconnect();

View File

@@ -1,6 +1,4 @@
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getMessages, deleteMessages } = require('./Message');
const { Conversation } = require('~/db/models');
@@ -100,15 +98,10 @@ module.exports = {
update.conversationId = newConversationId;
}
if (req?.body?.isTemporary) {
try {
const customConfig = await getCustomConfig();
update.expiredAt = createTempChatExpirationDate(customConfig);
} catch (err) {
logger.error('Error creating temporary chat expiration date:', err);
logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
update.expiredAt = null;
}
if (req.body.isTemporary) {
const expiredAt = new Date();
expiredAt.setDate(expiredAt.getDate() + 30);
update.expiredAt = expiredAt;
} else {
update.expiredAt = null;
}

View File

@@ -1,7 +1,5 @@
const { logger } = require('@librechat/data-schemas');
const { EToolResources, FileContext, Constants } = require('librechat-data-provider');
const { getProjectByName } = require('./Project');
const { getAgent } = require('./Agent');
const { EToolResources } = require('librechat-data-provider');
const { File } = require('~/db/models');
/**
@@ -14,119 +12,17 @@ const findFileById = async (file_id, options = {}) => {
return await File.findOne({ file_id, ...options }).lean();
};
/**
* Checks if a user has access to multiple files through a shared agent (batch operation)
* @param {string} userId - The user ID to check access for
* @param {string[]} fileIds - Array of file IDs to check
* @param {string} agentId - The agent ID that might grant access
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
*/
const hasAccessToFilesViaAgent = async (userId, fileIds, agentId) => {
const accessMap = new Map();
// Initialize all files as no access
fileIds.forEach((fileId) => accessMap.set(fileId, false));
try {
const agent = await getAgent({ id: agentId });
if (!agent) {
return accessMap;
}
// Check if user is the author - if so, grant access to all files
if (agent.author.toString() === userId) {
fileIds.forEach((fileId) => accessMap.set(fileId, true));
return accessMap;
}
// Check if agent is shared with the user via projects
if (!agent.projectIds || agent.projectIds.length === 0) {
return accessMap;
}
// Check if agent is in global project
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
if (
!globalProject ||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
) {
return accessMap;
}
// Agent is globally shared - check if it's collaborative
if (!agent.isCollaborative) {
return accessMap;
}
// Agent is globally shared and collaborative - check which files are actually attached
const attachedFileIds = new Set();
if (agent.tool_resources) {
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
}
}
}
// Grant access only to files that are attached to this agent
fileIds.forEach((fileId) => {
if (attachedFileIds.has(fileId)) {
accessMap.set(fileId, true);
}
});
return accessMap;
} catch (error) {
logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
return accessMap;
}
};
/**
* Retrieves files matching a given filter, sorted by the most recently updated.
* @param {Object} filter - The filter criteria to apply.
* @param {Object} [_sortOptions] - Optional sort parameters.
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
* Default excludes the 'text' field.
* @param {Object} [options] - Additional options
* @param {string} [options.userId] - User ID for access control
* @param {string} [options.agentId] - Agent ID that might grant access to files
* @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
*/
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => {
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
const sortOptions = { updatedAt: -1, ..._sortOptions };
const files = await File.find(filter).select(selectFields).sort(sortOptions).lean();
// If userId and agentId are provided, filter files based on access
if (options.userId && options.agentId) {
// Collect file IDs that need access check
const filesToCheck = [];
const ownedFiles = [];
for (const file of files) {
if (file.user && file.user.toString() === options.userId) {
ownedFiles.push(file);
} else {
filesToCheck.push(file);
}
}
if (filesToCheck.length === 0) {
return ownedFiles;
}
// Batch check access for all non-owned files
const fileIds = filesToCheck.map((f) => f.file_id);
const accessMap = await hasAccessToFilesViaAgent(options.userId, fileIds, options.agentId);
// Filter files based on access
const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
return [...ownedFiles, ...accessibleFiles];
}
return files;
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
};
/**
@@ -136,19 +32,19 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, option
* @returns {Promise<Array<MongoFile>>} Files that match the criteria
*/
const getToolFilesByIds = async (fileIds, toolResourceSet) => {
if (!fileIds || !fileIds.length || !toolResourceSet?.size) {
if (!fileIds || !fileIds.length) {
return [];
}
try {
const filter = {
file_id: { $in: fileIds },
$or: [],
};
if (toolResourceSet.has(EToolResources.ocr)) {
filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
if (toolResourceSet.size) {
filter.$or = [];
}
if (toolResourceSet.has(EToolResources.file_search)) {
filter.$or.push({ embedded: true });
}
@@ -280,5 +176,4 @@ module.exports = {
deleteFiles,
deleteFileByFilter,
batchUpdateFiles,
hasAccessToFilesViaAgent,
};

View File

@@ -1,264 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { fileSchema } = require('@librechat/data-schemas');
const { agentSchema } = require('@librechat/data-schemas');
const { projectSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { getFiles, createFile } = require('./File');
const { getProjectByName } = require('./Project');
const { createAgent } = require('./Agent');
let File;
let Agent;
let Project;
describe('File Access Control', () => {
let mongoServer;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
File = mongoose.models.File || mongoose.model('File', fileSchema);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await File.deleteMany({});
await Agent.deleteMany({});
await Project.deleteMany({});
});
describe('hasAccessToFilesViaAgent', () => {
it('should efficiently check access for multiple files at once', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
// Create files
for (const fileId of fileIds) {
await createFile({
user: authorId,
file_id: fileId,
filename: `file-${fileId}.txt`,
filepath: `/uploads/${fileId}`,
});
}
// Create agent with only first two files attached
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileIds[0], fileIds[1]],
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for all files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have access only to the first two files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(false);
expect(accessMap.get(fileIds[3])).toBe(false);
});
it('should grant access to all files when user is the agent author', async () => {
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
// Create agent
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: [fileIds[0]], // Only one file attached
},
},
});
// Check access as the author
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);
// Author should have access to all files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(true);
});
it('should handle non-existent agent gracefully', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const fileIds = [uuidv4(), uuidv4()];
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');
// Should have no access to any files
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should deny access when agent is not collaborative', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4()];
// Create agent with files but isCollaborative: false
await createAgent({
id: agentId,
name: 'Non-Collaborative Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: false,
tool_resources: {
file_search: {
file_ids: fileIds,
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have no access to any files when isCollaborative is false
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
});
describe('getFiles with agent access control', () => {
test('should return files owned by user and files accessible through agent', async () => {
const authorId = new mongoose.Types.ObjectId();
const userId = new mongoose.Types.ObjectId();
const agentId = `agent_${uuidv4()}`;
const ownedFileId = `file_${uuidv4()}`;
const sharedFileId = `file_${uuidv4()}`;
const inaccessibleFileId = `file_${uuidv4()}`;
// Create/get global project using getProjectByName which will upsert
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);
// Create agent with shared file
await createAgent({
id: agentId,
name: 'Shared Agent',
provider: 'test',
model: 'test-model',
author: authorId,
projectIds: [globalProject._id],
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [sharedFileId],
},
},
});
// Create files
await createFile({
file_id: ownedFileId,
user: userId,
filename: 'owned.txt',
filepath: '/uploads/owned.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: sharedFileId,
user: authorId,
filename: 'shared.txt',
filepath: '/uploads/shared.txt',
type: 'text/plain',
bytes: 200,
embedded: true,
});
await createFile({
file_id: inaccessibleFileId,
user: authorId,
filename: 'inaccessible.txt',
filepath: '/uploads/inaccessible.txt',
type: 'text/plain',
bytes: 300,
});
// Get files with access control
const files = await getFiles(
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
null,
{ text: 0 },
{ userId: userId.toString(), agentId },
);
expect(files).toHaveLength(2);
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
});
test('should return all files when no userId/agentId provided', async () => {
const userId = new mongoose.Types.ObjectId();
const fileId1 = `file_${uuidv4()}`;
const fileId2 = `file_${uuidv4()}`;
await createFile({
file_id: fileId1,
user: userId,
filename: 'file1.txt',
filepath: '/uploads/file1.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: fileId2,
user: new mongoose.Types.ObjectId(),
filename: 'file2.txt',
filepath: '/uploads/file2.txt',
type: 'text/plain',
bytes: 200,
});
const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
expect(files).toHaveLength(2);
});
});
});

View File

@@ -1,7 +1,5 @@
const { z } = require('zod');
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { Message } = require('~/db/models');
const idSchema = z.string().uuid();
@@ -56,14 +54,9 @@ async function saveMessage(req, params, metadata) {
};
if (req?.body?.isTemporary) {
try {
const customConfig = await getCustomConfig();
update.expiredAt = createTempChatExpirationDate(customConfig);
} catch (err) {
logger.error('Error creating temporary chat expiration date:', err);
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
update.expiredAt = null;
}
const expiredAt = new Date();
expiredAt.setDate(expiredAt.getDate() + 30);
update.expiredAt = expiredAt;
} else {
update.expiredAt = null;
}

42
api/models/Token.js Normal file
View File

@@ -0,0 +1,42 @@
const { findToken, updateToken, createToken } = require('~/models');
const { encryptV2 } = require('~/server/utils/crypto');
/**
* Handles the OAuth token by creating or updating the token.
* @param {object} fields
* @param {string} fields.userId - The user's ID.
* @param {string} fields.token - The full token to store.
* @param {string} fields.identifier - Unique, alternative identifier for the token.
* @param {number} fields.expiresIn - The number of seconds until the token expires.
* @param {object} fields.metadata - Additional metadata to store with the token.
* @param {string} [fields.type="oauth"] - The type of token. Default is 'oauth'.
*/
async function handleOAuthToken({
token,
userId,
identifier,
expiresIn,
metadata,
type = 'oauth',
}) {
const encrypedToken = await encryptV2(token);
const tokenData = {
type,
userId,
metadata,
identifier,
token: encrypedToken,
expiresIn: parseInt(expiresIn, 10) || 3600,
};
const existingToken = await findToken({ userId, identifier });
if (existingToken) {
return await updateToken({ identifier }, tokenData);
} else {
return await createToken(tokenData);
}
}
module.exports = {
handleOAuthToken,
};

View File

@@ -1,6 +1,6 @@
const mongoose = require('mongoose');
const { getRandomValues } = require('@librechat/api');
const { logger, hashToken } = require('@librechat/data-schemas');
const { getRandomValues } = require('~/server/utils/crypto');
const { createToken, findToken } = require('~/models');
/**

View File

@@ -78,7 +78,7 @@ const tokenValues = Object.assign(
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
'o4-mini': { prompt: 1.1, completion: 4.4 },
'o3-mini': { prompt: 1.1, completion: 4.4 },
o3: { prompt: 2, completion: 8 },
o3: { prompt: 10, completion: 40 },
'o1-mini': { prompt: 1.1, completion: 4.4 },
'o1-preview': { prompt: 15, completion: 60 },
o1: { prompt: 15, completion: 60 },
@@ -135,11 +135,10 @@ const tokenValues = Object.assign(
'grok-2-1212': { prompt: 2.0, completion: 10.0 },
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
'grok-2': { prompt: 2.0, completion: 10.0 },
'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
'grok-3-mini': { prompt: 0.3, completion: 0.5 },
'grok-3-fast': { prompt: 5.0, completion: 25.0 },
'grok-3': { prompt: 3.0, completion: 15.0 },
'grok-4': { prompt: 3.0, completion: 15.0 },
'grok-beta': { prompt: 5.0, completion: 15.0 },
'mistral-large': { prompt: 2.0, completion: 6.0 },
'pixtral-large': { prompt: 2.0, completion: 6.0 },

View File

@@ -636,15 +636,6 @@ describe('Grok Model Tests - Pricing', () => {
);
});
test('should return correct prompt and completion rates for Grok 4 model', () => {
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
tokenValues['grok-3'].prompt,
@@ -671,15 +662,6 @@ describe('Grok Model Tests - Pricing', () => {
tokenValues['grok-3-mini-fast'].completion,
);
});
test('should return correct prompt and completion rates for Grok 4 model with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
});
});

View File

@@ -1,6 +1,6 @@
{
"name": "@librechat/backend",
"version": "v0.7.9-rc1",
"version": "v0.7.8",
"description": "",
"scripts": {
"start": "echo 'please run this from the root directory'",
@@ -34,27 +34,28 @@
},
"homepage": "https://librechat.ai",
"dependencies": {
"@anthropic-ai/sdk": "^0.52.0",
"@anthropic-ai/sdk": "^0.37.0",
"@aws-sdk/client-s3": "^3.758.0",
"@aws-sdk/s3-request-presigner": "^3.758.0",
"@azure/identity": "^4.7.0",
"@azure/search-documents": "^12.0.0",
"@azure/storage-blob": "^12.27.0",
"@google/generative-ai": "^0.24.0",
"@google/generative-ai": "^0.23.0",
"@googleapis/youtube": "^20.0.0",
"@keyv/redis": "^4.3.3",
"@langchain/community": "^0.3.47",
"@langchain/core": "^0.3.60",
"@langchain/google-genai": "^0.2.13",
"@langchain/google-vertexai": "^0.2.13",
"@langchain/community": "^0.3.44",
"@langchain/core": "^0.3.57",
"@langchain/google-genai": "^0.2.9",
"@langchain/google-vertexai": "^0.2.9",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.59",
"@librechat/agents": "^2.4.38",
"@librechat/api": "*",
"@librechat/data-schemas": "*",
"@node-saml/passport-saml": "^5.0.0",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.8.2",
"bcryptjs": "^2.4.3",
"cohere-ai": "^7.9.1",
"compression": "^1.7.4",
"connect-redis": "^7.1.0",
"cookie": "^0.7.2",

View File

@@ -169,6 +169,9 @@ function disposeClient(client) {
client.isGenerativeModel = null;
}
// Properties specific to OpenAIClient
if (client.ChatGPTClient) {
client.ChatGPTClient = null;
}
if (client.completionsUrl) {
client.completionsUrl = null;
}

View File

@@ -0,0 +1,282 @@
const { getResponseSender, Constants } = require('librechat-data-provider');
const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const {
disposeClient,
processReqData,
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const AskController = async (req, res, next, initializeClient, addTitle) => {
let {
text,
endpointOption,
conversationId,
modelDisplayLabel,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
let client = null;
let abortKey = null;
let cleanupHandlers = [];
let clientRef = null;
logger.debug('[AskController]', {
text,
conversationId,
...endpointOption,
modelsConfig: endpointOption?.modelsConfig ? 'exists' : '',
});
let userMessage = null;
let userMessagePromise = null;
let promptTokens = null;
let userMessageId = null;
let responseMessageId = null;
let getAbortData = null;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
modelDisplayLabel,
});
const initialConversationId = conversationId;
const newConvo = !initialConversationId;
const userId = req.user.id;
let reqDataContext = {
userMessage,
userMessagePromise,
responseMessageId,
promptTokens,
conversationId,
userMessageId,
};
const updateReqData = (data = {}) => {
reqDataContext = processReqData(data, reqDataContext);
abortKey = reqDataContext.abortKey;
userMessage = reqDataContext.userMessage;
userMessagePromise = reqDataContext.userMessagePromise;
responseMessageId = reqDataContext.responseMessageId;
promptTokens = reqDataContext.promptTokens;
conversationId = reqDataContext.conversationId;
userMessageId = reqDataContext.userMessageId;
};
let { onProgress: progressCallback, getPartialText } = createOnProgress();
const performCleanup = () => {
logger.debug('[AskController] Performing cleanup');
if (Array.isArray(cleanupHandlers)) {
for (const handler of cleanupHandlers) {
try {
if (typeof handler === 'function') {
handler();
}
} catch (e) {
// Ignore
}
}
}
if (abortKey) {
logger.debug('[AskController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
if (client) {
disposeClient(client);
client = null;
}
reqDataContext = null;
userMessage = null;
userMessagePromise = null;
promptTokens = null;
getAbortData = null;
progressCallback = null;
endpointOption = null;
cleanupHandlers = null;
addTitle = null;
if (requestDataMap.has(req)) {
requestDataMap.delete(req);
}
logger.debug('[AskController] Cleanup completed');
};
try {
({ client } = await initializeClient({ req, res, endpointOption }));
if (clientRegistry && client) {
clientRegistry.register(client, { userId }, client);
}
if (client) {
requestDataMap.set(req, { client });
}
clientRef = new WeakRef(client);
getAbortData = () => {
const currentClient = clientRef?.deref();
const currentText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
return {
sender,
conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: currentText,
userMessage: userMessage,
userMessagePromise: userMessagePromise,
promptTokens: reqDataContext.promptTokens,
};
};
const { onStart, abortController } = createAbortController(
req,
res,
getAbortData,
updateReqData,
);
const closeHandler = () => {
logger.debug('[AskController] Request closed');
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
return;
}
abortController.abort();
logger.debug('[AskController] Request aborted on close');
};
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
// Ignore
}
});
const messageOptions = {
user: userId,
parentMessageId,
conversationId: reqDataContext.conversationId,
overrideParentMessageId,
getReqData: updateReqData,
onStart,
abortController,
progressCallback,
progressOptions: {
res,
},
};
/** @type {TMessage} */
let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint;
const databasePromise = response.databasePromise;
delete response.databasePromise;
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
const latestUserMessage = reqDataContext.userMessage;
if (client?.options?.attachments && latestUserMessage) {
latestUserMessage.files = client.options.attachments;
if (endpointOption?.modelOptions?.model) {
conversation.model = endpointOption.modelOptions.model;
}
delete latestUserMessage.image_urls;
}
if (!abortController.signal.aborted) {
const finalResponseMessage = { ...response };
sendMessage(res, {
final: true,
conversation,
title: conversation.title,
requestMessage: latestUserMessage,
responseMessage: finalResponseMessage,
});
res.end();
if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) {
await saveMessage(
req,
{ ...finalResponseMessage, user: userId },
{ context: 'api/server/controllers/AskController.js - response end' },
);
}
}
if (!client?.skipSaveUserMessage && latestUserMessage) {
await saveMessage(req, latestUserMessage, {
context: "api/server/controllers/AskController.js - don't skip saving user message",
});
}
if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {
text,
response: { ...response },
client,
})
.then(() => {
logger.debug('[AskController] Title generation started');
})
.catch((err) => {
logger.error('[AskController] Error in title generation', err);
})
.finally(() => {
logger.debug('[AskController] Title generation completed');
performCleanup();
});
} else {
performCleanup();
}
} catch (error) {
logger.error('[AskController] Error handling request', error);
let partialText = '';
try {
const currentClient = clientRef?.deref();
partialText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
} catch (getTextError) {
logger.error('[AskController] Error calling getText() during error handling', getTextError);
}
handleAbortError(res, req, error, {
sender,
partialText,
conversationId: reqDataContext.conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId,
userMessageId: reqDataContext.userMessageId,
})
.catch((err) => {
logger.error('[AskController] Error in `handleAbortError` during catch block', err);
})
.finally(() => {
performCleanup();
});
}
};
module.exports = AskController;

View File

@@ -1,17 +1,17 @@
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const openIdClient = require('openid-client');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
requestPasswordReset,
setOpenIDAuthTokens,
registerUser,
resetPassword,
setAuthTokens,
registerUser,
requestPasswordReset,
setOpenIDAuthTokens,
} = require('~/server/services/AuthService');
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
const { getOpenIdConfig } = require('~/strategies');
const { isEnabled } = require('~/server/utils');
const registrationController = async (req, res) => {
try {

View File

@@ -1,5 +1,3 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { getResponseSender } = require('librechat-data-provider');
const {
handleAbortError,
@@ -12,8 +10,9 @@ const {
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { createOnProgress } = require('~/server/utils');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const EditController = async (req, res, next, initializeClient) => {
let {
@@ -85,7 +84,7 @@ const EditController = async (req, res, next, initializeClient) => {
}
if (abortKey) {
logger.debug('[EditController] Cleaning up abort controller');
logger.debug('[AskController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
@@ -199,7 +198,7 @@ const EditController = async (req, res, next, initializeClient) => {
const finalUserMessage = reqDataContext.userMessage;
const finalResponseMessage = { ...response };
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
title: conversation.title,

View File

@@ -24,23 +24,17 @@ const handleValidationError = (err, res) => {
}
};
module.exports = (err, _req, res, _next) => {
// eslint-disable-next-line no-unused-vars
module.exports = (err, req, res, next) => {
try {
if (err.name === 'ValidationError') {
return handleValidationError(err, res);
return (err = handleValidationError(err, res));
}
if (err.code && err.code == 11000) {
return handleDuplicateKeyError(err, res);
return (err = handleDuplicateKeyError(err, res));
}
// Special handling for errors like SyntaxError
if (err.statusCode && err.body) {
return res.status(err.statusCode).send(err.body);
}
logger.error('ErrorController => error', err);
return res.status(500).send('An unknown error occurred.');
} catch (err) {
logger.error('ErrorController => processing error', err);
return res.status(500).send('Processing error in ErrorController.');
logger.error('ErrorController => error', err);
res.status(500).send('An unknown error occurred.');
}
};

View File

@@ -1,241 +0,0 @@
const errorController = require('./ErrorController');
const { logger } = require('~/config');
// Mock the logger
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
},
}));
describe('ErrorController', () => {
let mockReq, mockRes, mockNext;
beforeEach(() => {
mockReq = {};
mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn(),
};
mockNext = jest.fn();
logger.error.mockClear();
});
describe('ValidationError handling', () => {
it('should handle ValidationError with single error', () => {
const validationError = {
name: 'ValidationError',
errors: {
email: { message: 'Email is required', path: 'email' },
},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '["Email is required"]',
fields: '["email"]',
});
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
});
it('should handle ValidationError with multiple errors', () => {
const validationError = {
name: 'ValidationError',
errors: {
email: { message: 'Email is required', path: 'email' },
password: { message: 'Password is required', path: 'password' },
},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '"Email is required Password is required"',
fields: '["email","password"]',
});
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
});
it('should handle ValidationError with empty errors object', () => {
const validationError = {
name: 'ValidationError',
errors: {},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '[]',
fields: '[]',
});
});
});
describe('Duplicate key error handling', () => {
it('should handle duplicate key error (code 11000)', () => {
const duplicateKeyError = {
code: 11000,
keyValue: { email: 'test@example.com' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email"] already exists.',
fields: '["email"]',
});
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
});
it('should handle duplicate key error with multiple fields', () => {
const duplicateKeyError = {
code: 11000,
keyValue: { email: 'test@example.com', username: 'testuser' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email","username"] already exists.',
fields: '["email","username"]',
});
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
});
it('should handle error with code 11000 as string', () => {
const duplicateKeyError = {
code: '11000',
keyValue: { email: 'test@example.com' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email"] already exists.',
fields: '["email"]',
});
});
});
describe('SyntaxError handling', () => {
it('should handle errors with statusCode and body', () => {
const syntaxError = {
statusCode: 400,
body: 'Invalid JSON syntax',
};
errorController(syntaxError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
});
it('should handle errors with different statusCode and body', () => {
const customError = {
statusCode: 422,
body: { error: 'Unprocessable entity' },
};
errorController(customError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(422);
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
});
it('should handle error with statusCode but no body', () => {
const partialError = {
statusCode: 400,
};
errorController(partialError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
});
it('should handle error with body but no statusCode', () => {
const partialError = {
body: 'Some error message',
};
errorController(partialError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
});
});
describe('Unknown error handling', () => {
it('should handle unknown errors', () => {
const unknownError = new Error('Some unknown error');
errorController(unknownError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError);
});
it('should handle errors with code other than 11000', () => {
const mongoError = {
code: 11100,
message: 'Some MongoDB error',
};
errorController(mongoError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError);
});
it('should handle null/undefined errors', () => {
errorController(null, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
expect(logger.error).toHaveBeenCalledWith(
'ErrorController => processing error',
expect.any(Error),
);
});
});
describe('Catch block handling', () => {
beforeEach(() => {
// Restore logger mock to normal behavior for these tests
logger.error.mockRestore();
logger.error = jest.fn();
});
it('should handle errors when logger.error throws', () => {
// Create fresh mocks for this test
const freshMockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn(),
};
// Mock logger to throw on the first call, succeed on the second
logger.error
.mockImplementationOnce(() => {
throw new Error('Logger error');
})
.mockImplementation(() => {});
const testError = new Error('Test error');
errorController(testError, mockReq, freshMockRes, mockNext);
expect(freshMockRes.status).toHaveBeenCalledWith(500);
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
expect(logger.error).toHaveBeenCalledTimes(2);
});
});
});

View File

@@ -1,9 +1,8 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, AuthType, Constants } = require('librechat-data-provider');
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
const { CacheKeys, AuthType } = require('librechat-data-provider');
const { getToolkitKey } = require('~/server/services/ToolService');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCustomConfig } = require('~/server/services/Config');
const { availableTools } = require('~/app/clients/tools');
const { getMCPManager } = require('~/config');
const { getLogStores } = require('~/cache');
/**
@@ -85,45 +84,6 @@ const getAvailablePluginsController = async (req, res) => {
}
};
function createServerToolsCallback() {
/**
* @param {string} serverName
* @param {TPlugin[] | null} serverTools
*/
return async function (serverName, serverTools) {
try {
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
if (!serverName || !mcpToolsCache) {
return;
}
await mcpToolsCache.set(serverName, serverTools);
logger.debug(`MCP tools for ${serverName} added to cache.`);
} catch (error) {
logger.error('Error retrieving MCP tools from cache:', error);
}
};
}
function createGetServerTools() {
/**
* Retrieves cached server tools
* @param {string} serverName
* @returns {Promise<TPlugin[] | null>}
*/
return async function (serverName) {
try {
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
if (!mcpToolsCache) {
return null;
}
return await mcpToolsCache.get(serverName);
} catch (error) {
logger.error('Error retrieving MCP tools from cache:', error);
return null;
}
};
}
/**
* Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file.
*
@@ -139,9 +99,9 @@ function createGetServerTools() {
const getAvailableTools = async (req, res) => {
try {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
if (cachedToolsArray) {
res.status(200).json(cachedToolsArray);
const cachedTools = await cache.get(CacheKeys.TOOLS);
if (cachedTools) {
res.status(200).json(cachedTools);
return;
}
@@ -149,16 +109,7 @@ const getAvailableTools = async (req, res) => {
const customConfig = await getCustomConfig();
if (customConfig?.mcpServers != null) {
const mcpManager = getMCPManager();
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
const serverToolsCallback = createServerToolsCallback();
const getServerTools = createGetServerTools();
const mcpTools = await mcpManager.loadManifestTools({
flowManager,
serverToolsCallback,
getServerTools,
});
pluginManifest = [...mcpTools, ...pluginManifest];
pluginManifest = await mcpManager.loadManifestTools(pluginManifest);
}
/** @type {TPlugin[]} */
@@ -172,57 +123,17 @@ const getAvailableTools = async (req, res) => {
}
});
const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {};
const toolDefinitions = req.app.locals.availableTools;
const tools = authenticatedPlugins.filter(
(plugin) =>
toolDefinitions[plugin.pluginKey] !== undefined ||
(plugin.toolkit === true &&
Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey)),
);
const toolsOutput = [];
for (const plugin of authenticatedPlugins) {
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
const isToolkit =
plugin.toolkit === true &&
Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey);
if (!isToolDefined && !isToolkit) {
continue;
}
const toolToAdd = { ...plugin };
if (!plugin.pluginKey.includes(Constants.mcp_delimiter)) {
toolsOutput.push(toolToAdd);
continue;
}
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
const serverName = parts[parts.length - 1];
const serverConfig = customConfig?.mcpServers?.[serverName];
if (!serverConfig?.customUserVars) {
toolsOutput.push(toolToAdd);
continue;
}
const customVarKeys = Object.keys(serverConfig.customUserVars);
if (customVarKeys.length === 0) {
toolToAdd.authConfig = [];
toolToAdd.authenticated = true;
} else {
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}));
toolToAdd.authenticated = false;
}
toolsOutput.push(toolToAdd);
}
const finalTools = filterUniquePlugins(toolsOutput);
await cache.set(CacheKeys.TOOLS, finalTools);
res.status(200).json(finalTools);
await cache.set(CacheKeys.TOOLS, tools);
res.status(200).json(tools);
} catch (error) {
logger.error('[getAvailableTools]', error);
res.status(500).json({ message: error.message });
}
};

View File

@@ -1,4 +1,3 @@
const { encryptV3 } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
verifyTOTP,
@@ -8,6 +7,7 @@ const {
generateBackupCodes,
} = require('~/server/services/twoFactorService');
const { getUserById, updateUser } = require('~/models');
const { encryptV3 } = require('~/server/utils/crypto');
const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');

View File

@@ -1,6 +1,5 @@
const {
Tools,
Constants,
FileSources,
webSearchKeys,
extractWebSearchEnvVars,
@@ -24,7 +23,6 @@ const { processDeleteRequest } = require('~/server/services/Files/process');
const { Transaction, Balance, User } = require('~/db/models');
const { deleteToolCalls } = require('~/models/ToolCall');
const { deleteAllSharedLinks } = require('~/models');
const { getMCPManager } = require('~/config');
const getUserController = async (req, res) => {
/** @type {MongoUser} */
@@ -104,22 +102,10 @@ const updateUserPluginsController = async (req, res) => {
}
let keys = Object.keys(auth);
const values = Object.values(auth); // Used in 'install' block
const isMCPTool = pluginKey.startsWith('mcp_') || pluginKey.includes(Constants.mcp_delimiter);
// Early exit condition:
// If keys are empty (meaning auth: {} was likely sent for uninstall, or auth was empty for install)
// AND it's not web_search (which has special key handling to populate `keys` for uninstall)
// AND it's NOT (an uninstall action FOR an MCP tool - we need to proceed for this case to clear all its auth)
// THEN return.
if (
keys.length === 0 &&
pluginKey !== Tools.web_search &&
!(action === 'uninstall' && isMCPTool)
) {
if (keys.length === 0 && pluginKey !== Tools.web_search) {
return res.status(200).send();
}
const values = Object.values(auth);
/** @type {number} */
let status = 200;
@@ -146,53 +132,16 @@ const updateUserPluginsController = async (req, res) => {
}
}
} else if (action === 'uninstall') {
// const isMCPTool was defined earlier
if (isMCPTool && keys.length === 0) {
// This handles the case where auth: {} is sent for an MCP tool uninstall.
// It means "delete all credentials associated with this MCP pluginKey".
authService = await deleteUserPluginAuth(user.id, null, true, pluginKey);
for (let i = 0; i < keys.length; i++) {
authService = await deleteUserPluginAuth(user.id, keys[i]);
if (authService instanceof Error) {
logger.error(
`[authService] Error deleting all auth for MCP tool ${pluginKey}:`,
authService,
);
logger.error('[authService]', authService);
({ status, message } = authService);
}
} else {
// This handles:
// 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}).
// 2. Other tools uninstall (if keys were provided).
// 3. MCP tool uninstall if specific keys were provided in `auth` (not current frontend behavior).
// If keys is empty for non-MCP tools (and not web_search), this loop won't run, and nothing is deleted.
for (let i = 0; i < keys.length; i++) {
authService = await deleteUserPluginAuth(user.id, keys[i]); // Deletes by authField name
if (authService instanceof Error) {
logger.error('[authService] Error deleting specific auth key:', authService);
({ status, message } = authService);
}
}
}
}
if (status === 200) {
// If auth was updated successfully, disconnect MCP sessions as they might use these credentials
if (pluginKey.startsWith(Constants.mcp_prefix)) {
try {
const mcpManager = getMCPManager(user.id);
if (mcpManager) {
logger.info(
`[updateUserPluginsController] Disconnecting MCP connections for user ${user.id} after plugin auth update for ${pluginKey}.`,
);
await mcpManager.disconnectUserConnections(user.id);
}
} catch (disconnectError) {
logger.error(
`[updateUserPluginsController] Error disconnecting MCP connections for user ${user.id} after plugin auth update:`,
disconnectError,
);
// Do not fail the request for this, but log it.
}
}
return res.status(status).send();
}

View File

@@ -1,195 +0,0 @@
const { duplicateAgent } = require('../v1');
const { getAgent, createAgent } = require('~/models/Agent');
const { getActions } = require('~/models/Action');
const { nanoid } = require('nanoid');
jest.mock('~/models/Agent');
jest.mock('~/models/Action');
jest.mock('nanoid');
describe('duplicateAgent', () => {
let req, res;
beforeEach(() => {
req = {
params: { id: 'agent_123' },
user: { id: 'user_456' },
};
res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
};
jest.clearAllMocks();
});
it('should duplicate an agent successfully', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
author: 'user_789',
versions: [{ name: 'Test Agent', version: 1 }],
__v: 0,
};
const mockNewAgent = {
id: 'agent_new_123',
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
author: 'user_456',
versions: [
{
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
createdAt: new Date(),
updatedAt: new Date(),
},
],
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue(mockNewAgent);
await duplicateAgent(req, res);
expect(getAgent).toHaveBeenCalledWith({ id: 'agent_123' });
expect(getActions).toHaveBeenCalledWith({ agent_id: 'agent_123' }, true);
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
id: 'agent_new_123',
author: 'user_456',
name: expect.stringContaining('Test Agent ('),
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
}),
);
expect(createAgent).toHaveBeenCalledWith(
expect.not.objectContaining({
versions: expect.anything(),
__v: expect.anything(),
}),
);
expect(res.status).toHaveBeenCalledWith(201);
expect(res.json).toHaveBeenCalledWith({
agent: mockNewAgent,
actions: [],
});
});
it('should ensure duplicated agent has clean versions array without nested fields', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
description: 'Test Description',
versions: [
{
name: 'Test Agent',
versions: [{ name: 'Nested' }],
__v: 1,
},
],
__v: 2,
};
const mockNewAgent = {
id: 'agent_new_123',
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
versions: [
{
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
createdAt: new Date(),
updatedAt: new Date(),
},
],
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue(mockNewAgent);
await duplicateAgent(req, res);
expect(mockNewAgent.versions).toHaveLength(1);
const firstVersion = mockNewAgent.versions[0];
expect(firstVersion).not.toHaveProperty('versions');
expect(firstVersion).not.toHaveProperty('__v');
expect(mockNewAgent).not.toHaveProperty('__v');
expect(res.status).toHaveBeenCalledWith(201);
});
it('should return 404 if agent not found', async () => {
getAgent.mockResolvedValue(null);
await duplicateAgent(req, res);
expect(res.status).toHaveBeenCalledWith(404);
expect(res.json).toHaveBeenCalledWith({
error: 'Agent not found',
status: 'error',
});
});
it('should handle tool_resources.ocr correctly', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
tool_resources: {
ocr: { enabled: true, config: 'test' },
other: { should: 'not be copied' },
},
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue({ id: 'agent_new_123' });
await duplicateAgent(req, res);
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
tool_resources: {
ocr: { enabled: true, config: 'test' },
},
}),
);
});
it('should handle errors gracefully', async () => {
getAgent.mockRejectedValue(new Error('Database error'));
await duplicateAgent(req, res);
expect(res.status).toHaveBeenCalledWith(500);
expect(res.json).toHaveBeenCalledWith({ error: 'Database error' });
});
});

View File

@@ -1,18 +1,14 @@
require('events').EventEmitter.defaultMaxListeners = 100;
const { logger } = require('@librechat/data-schemas');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const {
sendEvent,
createRun,
Tokenizer,
checkAccess,
memoryInstructions,
createMemoryProcessor,
} = require('@librechat/api');
const {
Callback,
Providers,
GraphEvents,
formatMessage,
formatAgentMessages,
@@ -33,35 +29,20 @@ const {
bedrockInputSchema,
removeNullishValues,
} = require('librechat-data-provider');
const {
findPluginAuthsByKeys,
getFormattedMemories,
deleteMemory,
setMemory,
} = require('~/models');
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { setMemory, deleteMemory, getFormattedMemories } = require('~/models');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { getProviderConfig } = require('~/server/services/Endpoints');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const { checkAccess } = require('~/server/middleware/roles/access');
const BaseClient = require('~/app/clients/BaseClient');
const { getRoleByName } = require('~/models/Role');
const { loadAgent } = require('~/models/Agent');
const { getMCPManager } = require('~/config');
const omitTitleOptions = new Set([
'stream',
'thinking',
'streaming',
'clientOptions',
'thinkingConfig',
'thinkingBudget',
'includeThoughts',
'maxOutputTokens',
'additionalModelRequestFields',
]);
/**
* @param {ServerRequest} req
* @param {Agent} agent
@@ -408,12 +389,7 @@ class AgentClient extends BaseClient {
if (user.personalization?.memories === false) {
return;
}
const hasAccess = await checkAccess({
user,
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE],
getRoleByName,
});
const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]);
if (!hasAccess) {
logger.debug(
@@ -531,10 +507,7 @@ class AgentClient extends BaseClient {
messagesToProcess = [...messages.slice(-messageWindowSize)];
}
}
const bufferString = getBufferString(messagesToProcess);
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
return await this.processMemory([bufferMessage]);
return await this.processMemory(messagesToProcess);
} catch (error) {
logger.error('Memory Agent failed to process memory', error);
}
@@ -700,7 +673,7 @@ class AgentClient extends BaseClient {
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
user: this.options.req.user,
},
recursionLimit: agentsEConfig?.recursionLimit ?? 25,
recursionLimit: agentsEConfig?.recursionLimit,
signal: abortController.signal,
streamMode: 'values',
version: 'v2',
@@ -825,21 +798,6 @@ class AgentClient extends BaseClient {
run.Graph.contentData = contentData;
}
try {
if (await hasCustomUserVars()) {
config.configurable.userMCPAuthMap = await getMCPAuthMap({
tools: agent.tools,
userId: this.options.req.user.id,
findPluginAuthsByKeys,
});
}
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`,
err,
);
}
await run.processStream({ messages }, config, {
keepContent: i !== 0,
tokenCounter: createTokenCounter(this.getEncoding()),
@@ -1005,26 +963,23 @@ class AgentClient extends BaseClient {
throw new Error('Run not initialized');
}
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
const { req, res, agent } = this.options;
const endpoint = agent.endpoint;
const endpoint = this.options.agent.endpoint;
const { req, res } = this.options;
/** @type {import('@librechat/agents').ClientOptions} */
let clientOptions = {
maxTokens: 75,
model: agent.model_parameters.model,
};
const { getOptions, overrideProvider, customEndpointConfig } =
await getProviderConfig(endpoint);
/** @type {TEndpoint | undefined} */
const endpointConfig = req.app.locals[endpoint] ?? customEndpointConfig;
let endpointConfig = req.app.locals[endpoint];
if (!endpointConfig) {
logger.warn(
'[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config',
);
try {
endpointConfig = await getCustomEndpointConfig(endpoint);
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
err,
);
}
}
if (
endpointConfig &&
endpointConfig.titleModel &&
@@ -1032,56 +987,30 @@ class AgentClient extends BaseClient {
) {
clientOptions.model = endpointConfig.titleModel;
}
const options = await getOptions({
req,
res,
optionsOnly: true,
overrideEndpoint: endpoint,
overrideModel: clientOptions.model,
endpointOption: { model_parameters: clientOptions },
});
let provider = options.provider ?? overrideProvider ?? agent.provider;
if (
endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName == null
clientOptions.model &&
this.options.agent.model_parameters.model !== clientOptions.model
) {
provider = Providers.OPENAI;
} else if (
endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName != null &&
provider !== Providers.AZURE
) {
provider = Providers.AZURE;
clientOptions =
(
await initOpenAI({
req,
res,
optionsOnly: true,
overrideModel: clientOptions.model,
overrideEndpoint: endpoint,
endpointOption: {
model_parameters: clientOptions,
},
})
)?.llmConfig ?? clientOptions;
}
/** @type {import('@librechat/agents').ClientOptions} */
clientOptions = { ...options.llmConfig };
if (options.configOptions) {
clientOptions.configuration = options.configOptions;
}
// Ensure maxTokens is set for non-o1 models
if (!/\b(o\d)\b/i.test(clientOptions.model) && !clientOptions.maxTokens) {
clientOptions.maxTokens = 75;
} else if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
delete clientOptions.maxTokens;
}
clientOptions = Object.assign(
Object.fromEntries(
Object.entries(clientOptions).filter(([key]) => !omitTitleOptions.has(key)),
),
);
if (provider === Providers.GOOGLE) {
clientOptions.json = true;
}
try {
const titleResult = await this.run.generateTitle({
provider,
inputText: text,
contentParts: this.contentParts,
clientOptions,
@@ -1099,10 +1028,8 @@ class AgentClient extends BaseClient {
let input_tokens, output_tokens;
if (item.usage) {
input_tokens =
item.usage.prompt_tokens || item.usage.input_tokens || item.usage.inputTokens;
output_tokens =
item.usage.completion_tokens || item.usage.output_tokens || item.usage.outputTokens;
input_tokens = item.usage.input_tokens || item.usage.inputTokens;
output_tokens = item.usage.output_tokens || item.usage.outputTokens;
} else if (item.tokenUsage) {
input_tokens = item.tokenUsage.promptTokens;
output_tokens = item.tokenUsage.completionTokens;

View File

@@ -1,10 +1,10 @@
// errorHandler.js
const { logger } = require('@librechat/data-schemas');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { sendResponse } = require('~/server/middleware/error');
const { recordUsage } = require('~/server/services/Threads');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendResponse } = require('~/server/utils');
/**
* @typedef {Object} ErrorHandlerContext
@@ -75,7 +75,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);

View File

@@ -0,0 +1,106 @@
const { HttpsProxyAgent } = require('https-proxy-agent');
const { resolveHeaders } = require('librechat-data-provider');
const { createLLM } = require('~/app/clients/llm');
/**
* Initializes and returns a Language Learning Model (LLM) instance.
*
* @param {Object} options - Configuration options for the LLM.
* @param {string} options.model - The model identifier.
* @param {string} options.modelName - The specific name of the model.
* @param {number} options.temperature - The temperature setting for the model.
* @param {number} options.presence_penalty - The presence penalty for the model.
* @param {number} options.frequency_penalty - The frequency penalty for the model.
* @param {number} options.max_tokens - The maximum number of tokens for the model output.
* @param {boolean} options.streaming - Whether to use streaming for the model output.
* @param {Object} options.context - The context for the conversation.
* @param {number} options.tokenBuffer - The token buffer size.
* @param {number} options.initialMessageCount - The initial message count.
* @param {string} options.conversationId - The ID of the conversation.
* @param {string} options.user - The user identifier.
* @param {string} options.langchainProxy - The langchain proxy URL.
* @param {boolean} options.useOpenRouter - Whether to use OpenRouter.
* @param {Object} options.options - Additional options.
* @param {Object} options.options.headers - Custom headers for the request.
* @param {string} options.options.proxy - Proxy URL.
* @param {Object} options.options.req - The request object.
* @param {Object} options.options.res - The response object.
* @param {boolean} options.options.debug - Whether to enable debug mode.
* @param {string} options.apiKey - The API key for authentication.
* @param {Object} options.azure - Azure-specific configuration.
* @param {Object} options.abortController - The AbortController instance.
* @returns {Object} The initialized LLM instance.
*/
function initializeLLM(options) {
const {
model,
modelName,
temperature,
presence_penalty,
frequency_penalty,
max_tokens,
streaming,
user,
langchainProxy,
useOpenRouter,
options: { headers, proxy },
apiKey,
azure,
} = options;
const modelOptions = {
modelName: modelName || model,
temperature,
presence_penalty,
frequency_penalty,
user,
};
if (max_tokens) {
modelOptions.max_tokens = max_tokens;
}
const configOptions = {};
if (langchainProxy) {
configOptions.basePath = langchainProxy;
}
if (useOpenRouter) {
configOptions.basePath = 'https://openrouter.ai/api/v1';
configOptions.baseOptions = {
headers: {
'HTTP-Referer': 'https://librechat.ai',
'X-Title': 'LibreChat',
},
};
}
if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
configOptions.baseOptions = {
headers: resolveHeaders({
...headers,
...configOptions?.baseOptions?.headers,
}),
};
}
if (proxy) {
configOptions.httpAgent = new HttpsProxyAgent(proxy);
configOptions.httpsAgent = new HttpsProxyAgent(proxy);
}
const llm = createLLM({
modelOptions,
configOptions,
openAIApiKey: apiKey,
azure,
streaming,
});
return llm;
}
module.exports = {
initializeLLM,
};

View File

@@ -1,5 +1,3 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { Constants } = require('librechat-data-provider');
const {
handleAbortError,
@@ -7,18 +5,17 @@ const {
cleanupAbortController,
} = require('~/server/middleware');
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { sendMessage } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const AgentController = async (req, res, next, initializeClient, addTitle) => {
let {
text,
endpointOption,
conversationId,
isContinued = false,
editedContent = null,
parentMessageId = null,
overrideParentMessageId = null,
responseMessageId: editedResponseMessageId = null,
} = req.body;
let sender;
@@ -70,7 +67,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
handler();
}
} catch (e) {
logger.error('[AgentController] Error in cleanup handler', e);
// Ignore cleanup errors
}
}
}
@@ -158,7 +155,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
logger.error('[AgentController] Error removing close listener', e);
// Ignore
}
});
@@ -166,14 +163,10 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
user: userId,
onStart,
getReqData,
isContinued,
editedContent,
conversationId,
parentMessageId,
abortController,
overrideParentMessageId,
isEdited: !!editedContent,
responseMessageId: editedResponseMessageId,
progressOptions: {
res,
},
@@ -213,7 +206,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
// Create a new response object with minimal copies
const finalResponse = { ...response };
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
title: conversation.title,

View File

@@ -1,16 +1,13 @@
const { z } = require('zod');
const fs = require('fs').promises;
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
const {
Tools,
Constants,
FileContext,
FileSources,
SystemRoles,
EToolResources,
actionDelimiter,
removeNullishValues,
} = require('librechat-data-provider');
const {
getAgent,
@@ -19,21 +16,20 @@ const {
deleteAgent,
getListAgents,
} = require('~/models/Agent');
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
const { filterFile } = require('~/server/services/Files/process');
const { updateAction, getActions } = require('~/models/Action');
const { getCachedTools } = require('~/server/services/Config');
const { updateAgentProjects } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
const { revertAgentVersion } = require('~/models/Agent');
const { deleteFileByFilter } = require('~/models/File');
const { revertAgentVersion } = require('~/models/Agent');
const { logger } = require('~/config');
const systemTools = {
[Tools.execute_code]: true,
[Tools.file_search]: true,
[Tools.web_search]: true,
};
/**
@@ -46,18 +42,13 @@ const systemTools = {
*/
const createAgentHandler = async (req, res) => {
try {
const validatedData = agentCreateSchema.parse(req.body);
const { tools = [], ...agentData } = removeNullishValues(validatedData);
const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body;
const { id: userId } = req.user;
agentData.id = `agent_${nanoid()}`;
agentData.author = userId;
agentData.tools = [];
const availableTools = await getCachedTools({ includeGlobal: true });
for (const tool of tools) {
if (availableTools[tool]) {
if (req.app.locals.availableTools[tool]) {
agentData.tools.push(tool);
}
@@ -66,13 +57,19 @@ const createAgentHandler = async (req, res) => {
}
}
Object.assign(agentData, {
author: userId,
name,
description,
instructions,
provider,
model,
});
agentData.id = `agent_${nanoid()}`;
const agent = await createAgent(agentData);
res.status(201).json(agent);
} catch (error) {
if (error instanceof z.ZodError) {
logger.error('[/Agents] Validation error', error.errors);
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
}
logger.error('[/Agents] Error creating agent', error);
res.status(500).json({ error: error.message });
}
@@ -156,16 +153,14 @@ const getAgentHandler = async (req, res) => {
const updateAgentHandler = async (req, res) => {
try {
const id = req.params.id;
const validatedData = agentUpdateSchema.parse(req.body);
const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
const { projectIds, removeProjectIds, ...updateData } = req.body;
const isAdmin = req.user.role === SystemRoles.ADMIN;
const existingAgent = await getAgent({ id });
const isAuthor = existingAgent.author.toString() === req.user.id;
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
const isAuthor = existingAgent.author.toString() === req.user.id;
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) {
@@ -204,11 +199,6 @@ const updateAgentHandler = async (req, res) => {
return res.json(updatedAgent);
} catch (error) {
if (error instanceof z.ZodError) {
logger.error('[/Agents/:id] Validation error', error.errors);
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
}
logger.error('[/Agents/:id] Error updating Agent', error);
if (error.statusCode === 409) {
@@ -251,8 +241,6 @@ const duplicateAgentHandler = async (req, res) => {
createdAt: _createdAt,
updatedAt: _updatedAt,
tool_resources: _tool_resources = {},
versions: _versions,
__v: _v,
...cloneData
} = agent;
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
@@ -391,22 +379,6 @@ const uploadAgentAvatarHandler = async (req, res) => {
return res.status(400).json({ message: 'Agent ID is required' });
}
const isAdmin = req.user.role === SystemRoles.ADMIN;
const existingAgent = await getAgent({ id: agent_id });
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
const isAuthor = existingAgent.author.toString() === req.user.id;
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) {
return res.status(403).json({
error: 'You do not have permission to modify this non-collaborative agent',
});
}
const buffer = await fs.readFile(req.file.path);
const fileStrategy = req.app.locals.fileStrategy;
@@ -429,7 +401,14 @@ const uploadAgentAvatarHandler = async (req, res) => {
source: fileStrategy,
};
let _avatar = existingAgent.avatar;
let _avatar;
try {
const agent = await getAgent({ id: agent_id });
_avatar = agent.avatar;
} catch (error) {
logger.error('[/:agent_id/avatar] Error fetching agent', error);
_avatar = {};
}
if (_avatar && _avatar.source) {
const { deleteFile } = getStrategyFunctions(_avatar.source);
@@ -451,7 +430,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
};
promises.push(
await updateAgent({ id: agent_id }, data, {
await updateAgent({ id: agent_id, author: req.user.id }, data, {
updatingUserId: req.user.id,
}),
);
@@ -466,7 +445,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
try {
await fs.unlink(req.file.path);
logger.debug('[/:agent_id/avatar] Temp. image upload file deleted');
} catch {
} catch (error) {
logger.debug('[/:agent_id/avatar] Temp. image upload file already deleted');
}
}

View File

@@ -1,659 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { agentSchema } = require('@librechat/data-schemas');
// Only mock the dependencies that are not database-related
jest.mock('~/server/services/Config', () => ({
getCachedTools: jest.fn().mockResolvedValue({
web_search: true,
execute_code: true,
file_search: true,
}),
}));
jest.mock('~/models/Project', () => ({
getProjectByName: jest.fn().mockResolvedValue(null),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(),
}));
jest.mock('~/server/services/Files/images/avatar', () => ({
resizeAvatar: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
refreshS3Url: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
filterFile: jest.fn(),
}));
jest.mock('~/models/Action', () => ({
updateAction: jest.fn(),
getActions: jest.fn().mockResolvedValue([]),
}));
jest.mock('~/models/File', () => ({
deleteFileByFilter: jest.fn(),
}));
const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');
/**
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
*/
let Agent;
describe('Agent Controllers - Mass Assignment Protection', () => {
let mongoServer;
let mockReq;
let mockRes;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
}, 20000);
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await Agent.deleteMany({});
// Reset all mocks
jest.clearAllMocks();
// Setup mock request and response objects
mockReq = {
user: {
id: new mongoose.Types.ObjectId().toString(),
role: 'USER',
},
body: {},
params: {},
app: {
locals: {
fileStrategy: 'local',
},
},
};
mockRes = {
status: jest.fn().mockReturnThis(),
json: jest.fn().mockReturnThis(),
};
});
describe('createAgentHandler', () => {
test('should create agent with allowed fields only', async () => {
const validData = {
name: 'Test Agent',
description: 'A test agent',
instructions: 'Be helpful',
provider: 'openai',
model: 'gpt-4',
tools: ['web_search'],
model_parameters: { temperature: 0.7 },
tool_resources: {
file_search: { file_ids: ['file1', 'file2'] },
},
};
mockReq.body = validData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
expect(mockRes.json).toHaveBeenCalled();
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.name).toBe('Test Agent');
expect(createdAgent.description).toBe('A test agent');
expect(createdAgent.provider).toBe('openai');
expect(createdAgent.model).toBe('gpt-4');
expect(createdAgent.author.toString()).toBe(mockReq.user.id);
expect(createdAgent.tools).toContain('web_search');
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb).toBeDefined();
expect(agentInDb.name).toBe('Test Agent');
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
});
test('should reject creation with unauthorized fields (mass assignment protection)', async () => {
const maliciousData = {
// Required fields
provider: 'openai',
model: 'gpt-4',
name: 'Malicious Agent',
// Unauthorized fields that should be stripped
author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
authorName: 'Hacker', // Should be stripped
isCollaborative: true, // Should be stripped on creation
versions: [], // Should be stripped
_id: new mongoose.Types.ObjectId(), // Should be stripped
id: 'custom_agent_id', // Should be overridden
createdAt: new Date('2020-01-01'), // Should be stripped
updatedAt: new Date('2020-01-01'), // Should be stripped
};
mockReq.body = maliciousData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify unauthorized fields were not set
expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
expect(createdAgent.authorName).toBeUndefined();
expect(createdAgent.isCollaborative).toBeFalsy();
expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix
// Verify timestamps are recent (not the malicious dates)
const createdTime = new Date(createdAgent.createdAt).getTime();
const now = Date.now();
expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
expect(agentInDb.authorName).toBeUndefined();
});
test('should validate required fields', async () => {
const invalidData = {
name: 'Missing Required Fields',
// Missing provider and model
};
mockReq.body = invalidData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
details: expect.any(Array),
}),
);
// Verify nothing was created in database
const count = await Agent.countDocuments();
expect(count).toBe(0);
});
test('should handle tool_resources validation', async () => {
const dataWithInvalidToolResources = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Tool Resources',
tool_resources: {
// Valid resources
file_search: {
file_ids: ['file1', 'file2'],
vector_store_ids: ['vs1'],
},
execute_code: {
file_ids: ['file3'],
},
// Invalid resource (should be stripped by schema)
invalid_resource: {
file_ids: ['file4'],
},
},
};
mockReq.body = dataWithInvalidToolResources;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.tool_resources).toBeDefined();
expect(createdAgent.tool_resources.file_search).toBeDefined();
expect(createdAgent.tool_resources.execute_code).toBeDefined();
expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
});
test('should handle avatar validation', async () => {
const dataWithAvatar = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Avatar',
avatar: {
filepath: 'https://example.com/avatar.png',
source: 's3',
},
};
mockReq.body = dataWithAvatar;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.avatar).toEqual({
filepath: 'https://example.com/avatar.png',
source: 's3',
});
});
test('should handle invalid avatar format', async () => {
const dataWithInvalidAvatar = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Invalid Avatar',
avatar: 'just-a-string', // Invalid format
};
mockReq.body = dataWithInvalidAvatar;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
}),
);
});
});
describe('updateAgentHandler', () => {
let existingAgentId;
let existingAgentAuthorId;
beforeEach(async () => {
// Create an existing agent for update tests
existingAgentAuthorId = new mongoose.Types.ObjectId();
const agent = await Agent.create({
id: `agent_${uuidv4()}`,
name: 'Original Agent',
provider: 'openai',
model: 'gpt-3.5-turbo',
author: existingAgentAuthorId,
description: 'Original description',
isCollaborative: false,
versions: [
{
name: 'Original Agent',
provider: 'openai',
model: 'gpt-3.5-turbo',
description: 'Original description',
createdAt: new Date(),
updatedAt: new Date(),
},
],
});
existingAgentId = agent.id;
});
test('should update agent with allowed fields only', async () => {
mockReq.user.id = existingAgentAuthorId.toString(); // Set as author
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Updated Agent',
description: 'Updated description',
model: 'gpt-4',
isCollaborative: true, // This IS allowed in updates
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(400);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Updated Agent');
expect(updatedAgent.description).toBe('Updated description');
expect(updatedAgent.model).toBe('gpt-4');
expect(updatedAgent.isCollaborative).toBe(true);
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Updated Agent');
expect(agentInDb.isCollaborative).toBe(true);
});
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Updated Name',
// Unauthorized fields that should be stripped
author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author
authorName: 'Hacker', // Should be stripped
id: 'different_agent_id', // Should be stripped
_id: new mongoose.Types.ObjectId(), // Should be stripped
versions: [], // Should be stripped
createdAt: new Date('2020-01-01'), // Should be stripped
updatedAt: new Date('2020-01-01'), // Should be stripped
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
// Verify unauthorized fields were not changed
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed
expect(updatedAgent.authorName).toBeUndefined();
expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed
expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString());
expect(agentInDb.id).toBe(existingAgentId);
});
test('should reject update from non-author when not collaborative', async () => {
const differentUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = differentUserId; // Different user
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Unauthorized Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalledWith({
error: 'You do not have permission to modify this non-collaborative agent',
});
// Verify agent was not modified in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Original Agent');
});
test('should allow update from non-author when collaborative', async () => {
// First make the agent collaborative
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
const differentUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = differentUserId; // Different user
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Collaborative Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Collaborative Update');
// Author field should be removed for non-author
expect(updatedAgent.author).toBeUndefined();
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Collaborative Update');
});
test('should allow admin to update any agent', async () => {
const adminUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = adminUserId;
mockReq.user.role = 'ADMIN'; // Set as admin
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Admin Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Admin Update');
});
test('should handle projectIds updates', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
const projectId1 = new mongoose.Types.ObjectId().toString();
const projectId2 = new mongoose.Types.ObjectId().toString();
mockReq.body = {
projectIds: [projectId1, projectId2],
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent).toBeDefined();
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
});
test('should validate tool_resources in updates', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
tool_resources: {
ocr: {
file_ids: ['ocr1', 'ocr2'],
},
execute_code: {
file_ids: ['img1'],
},
// Invalid tool resource
invalid_tool: {
file_ids: ['invalid'],
},
},
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.tool_resources).toBeDefined();
expect(updatedAgent.tool_resources.ocr).toBeDefined();
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
});
test('should return 404 for non-existent agent', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
mockReq.body = {
name: 'Update Non-existent',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(404);
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
});
test('should handle validation errors properly', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
model_parameters: 'invalid-not-an-object', // Should be an object
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
details: expect.any(Array),
}),
);
});
});
describe('Mass Assignment Attack Scenarios', () => {
test('should prevent setting system fields during creation', async () => {
const systemFields = {
provider: 'openai',
model: 'gpt-4',
name: 'System Fields Test',
// System fields that should never be settable by users
__v: 99,
_id: new mongoose.Types.ObjectId(),
versions: [
{
name: 'Fake Version',
provider: 'fake',
model: 'fake-model',
},
],
};
mockReq.body = systemFields;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify system fields were not affected
expect(createdAgent.__v).not.toBe(99);
expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version
expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation
expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.__v).not.toBe(99);
});
test('should prevent privilege escalation through isCollaborative', async () => {
// Create a non-collaborative agent
const authorId = new mongoose.Types.ObjectId();
const agent = await Agent.create({
id: `agent_${uuidv4()}`,
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
isCollaborative: false,
versions: [
{
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
createdAt: new Date(),
updatedAt: new Date(),
},
],
});
// Try to make it collaborative as a different user
const attackerId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = attackerId;
mockReq.params.id = agent.id;
mockReq.body = {
isCollaborative: true, // Trying to escalate privileges
};
await updateAgentHandler(mockReq, mockRes);
// Should be rejected
expect(mockRes.status).toHaveBeenCalledWith(403);
// Verify in database that it's still not collaborative
const agentInDb = await Agent.findOne({ id: agent.id });
expect(agentInDb.isCollaborative).toBe(false);
});
test('should prevent author hijacking', async () => {
const originalAuthorId = new mongoose.Types.ObjectId();
const attackerId = new mongoose.Types.ObjectId();
// Admin creates an agent
mockReq.user.id = originalAuthorId.toString();
mockReq.user.role = 'ADMIN';
mockReq.body = {
provider: 'openai',
model: 'gpt-4',
name: 'Admin Agent',
author: attackerId.toString(), // Trying to set different author
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Author should be the actual user, not the attempted value
expect(createdAgent.author.toString()).toBe(originalAuthorId.toString());
expect(createdAgent.author.toString()).not.toBe(attackerId.toString());
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.author.toString()).toBe(originalAuthorId.toString());
});
test('should strip unknown fields to prevent future vulnerabilities', async () => {
mockReq.body = {
provider: 'openai',
model: 'gpt-4',
name: 'Future Proof Test',
// Unknown fields that might be added in future
superAdminAccess: true,
bypassAllChecks: true,
internalFlag: 'secret',
futureFeature: 'exploit',
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify unknown fields were stripped
expect(createdAgent.superAdminAccess).toBeUndefined();
expect(createdAgent.bypassAllChecks).toBeUndefined();
expect(createdAgent.internalFlag).toBeUndefined();
expect(createdAgent.futureFeature).toBeUndefined();
// Also check in database
const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean();
expect(agentInDb.superAdminAccess).toBeUndefined();
expect(agentInDb.bypassAllChecks).toBeUndefined();
expect(agentInDb.internalFlag).toBeUndefined();
expect(agentInDb.futureFeature).toBeUndefined();
});
});
});

View File

@@ -1,7 +1,4 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Time,
Constants,
@@ -22,20 +19,20 @@ const {
addThreadMetadata,
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { createRunBody } = require('~/server/services/createRunBody');
const { sendResponse } = require('~/server/middleware/error');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
/**
* @route POST /
@@ -474,7 +471,7 @@ const chatV1 = async (req, res) => {
await Promise.all(promises);
const sendInitialResponse = () => {
sendEvent(res, {
sendMessage(res, {
sync: true,
conversationId,
// messages: previousMessages,
@@ -590,7 +587,7 @@ const chatV1 = async (req, res) => {
iconURL: endpointOption.iconURL,
};
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
requestMessage: {

View File

@@ -1,7 +1,4 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Time,
Constants,
@@ -25,14 +22,15 @@ const { createErrorHandler } = require('~/server/controllers/assistants/errors')
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { sendMessage, sleep, countTokens } = require('~/server/utils');
const { createRunBody } = require('~/server/services/createRunBody');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
/**
* @route POST /
@@ -311,7 +309,7 @@ const chatV2 = async (req, res) => {
await Promise.all(promises);
const sendInitialResponse = () => {
sendEvent(res, {
sendMessage(res, {
sync: true,
conversationId,
// messages: previousMessages,
@@ -434,7 +432,7 @@ const chatV2 = async (req, res) => {
iconURL: endpointOption.iconURL,
};
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
requestMessage: {

View File

@@ -1,10 +1,10 @@
// errorHandler.js
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
const { sendResponse } = require('~/server/middleware/error');
const { getConvo } = require('~/models/Conversation');
const { sendResponse } = require('~/server/utils');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { getConvo } = require('~/models/Conversation');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
/**
* @typedef {Object} ErrorHandlerContext
@@ -78,7 +78,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);

View File

@@ -1,5 +1,4 @@
const fs = require('fs').promises;
const { logger } = require('@librechat/data-schemas');
const { FileContext } = require('librechat-data-provider');
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
@@ -7,9 +6,9 @@ const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { deleteAssistantActions } = require('~/server/services/ActionService');
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
const { getOpenAIClient, fetchAssistants } = require('./helpers');
const { getCachedTools } = require('~/server/services/Config');
const { manifestToolMap } = require('~/app/clients/tools');
const { deleteFileByFilter } = require('~/models/File');
const { logger } = require('~/config');
/**
* Create an assistant.
@@ -31,20 +30,21 @@ const createAssistant = async (req, res) => {
delete assistantData.conversation_starters;
delete assistantData.append_current_datetime;
const toolDefinitions = await getCachedTools({ includeGlobal: true });
assistantData.tools = tools
.map((tool) => {
if (typeof tool !== 'string') {
return tool;
}
const toolDefinitions = req.app.locals.availableTools;
const toolDef = toolDefinitions[tool];
if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
return Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
.map(([_, val]) => val);
return (
Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
// eslint-disable-next-line no-unused-vars
.map(([_, val]) => val)
);
}
return toolDef;
@@ -135,21 +135,21 @@ const patchAssistant = async (req, res) => {
append_current_datetime,
...updateData
} = req.body;
const toolDefinitions = await getCachedTools({ includeGlobal: true });
updateData.tools = (updateData.tools ?? [])
.map((tool) => {
if (typeof tool !== 'string') {
return tool;
}
const toolDefinitions = req.app.locals.availableTools;
const toolDef = toolDefinitions[tool];
if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
return Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
.map(([_, val]) => val);
return (
Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
// eslint-disable-next-line no-unused-vars
.map(([_, val]) => val)
);
}
return toolDef;

View File

@@ -1,11 +1,10 @@
const { logger } = require('@librechat/data-schemas');
const { ToolCallTypes } = require('librechat-data-provider');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { validateAndUpdateTool } = require('~/server/services/ActionService');
const { getCachedTools } = require('~/server/services/Config');
const { updateAssistantDoc } = require('~/models/Assistant');
const { manifestToolMap } = require('~/app/clients/tools');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
/**
* Create an assistant.
@@ -28,20 +27,21 @@ const createAssistant = async (req, res) => {
delete assistantData.conversation_starters;
delete assistantData.append_current_datetime;
const toolDefinitions = await getCachedTools({ includeGlobal: true });
assistantData.tools = tools
.map((tool) => {
if (typeof tool !== 'string') {
return tool;
}
const toolDefinitions = req.app.locals.availableTools;
const toolDef = toolDefinitions[tool];
if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
return Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
.map(([_, val]) => val);
return (
Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
// eslint-disable-next-line no-unused-vars
.map(([_, val]) => val)
);
}
return toolDef;
@@ -125,13 +125,13 @@ const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
let hasFileSearch = false;
for (const tool of updateData.tools ?? []) {
const toolDefinitions = await getCachedTools({ includeGlobal: true });
const toolDefinitions = req.app.locals.availableTools;
let actualTool = typeof tool === 'string' ? toolDefinitions[tool] : tool;
if (!actualTool && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
actualTool = Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
// eslint-disable-next-line no-unused-vars
.map(([_, val]) => val);
} else if (!actualTool) {
continue;

View File

@@ -1,7 +1,5 @@
const { nanoid } = require('nanoid');
const { EnvVar } = require('@librechat/agents');
const { checkAccess } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Tools,
AuthType,
@@ -15,8 +13,9 @@ const { processCodeOutput } = require('~/server/services/Files/Code/process');
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { loadTools } = require('~/app/clients/tools/util');
const { getRoleByName } = require('~/models/Role');
const { checkAccess } = require('~/server/middleware');
const { getMessage } = require('~/models/Message');
const { logger } = require('~/config');
const fieldsMap = {
[Tools.execute_code]: [EnvVar.CODE_API_KEY],
@@ -80,7 +79,6 @@ const verifyToolAuth = async (req, res) => {
throwError: false,
});
} catch (error) {
logger.error('Error loading auth values', error);
res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED });
return;
}
@@ -134,12 +132,7 @@ const callTool = async (req, res) => {
logger.debug(`[${toolId}/call] User: ${req.user.id}`);
let hasAccess = true;
if (toolAccessPermType[toolId]) {
hasAccess = await checkAccess({
user: req.user,
permissionType: toolAccessPermType[toolId],
permissions: [Permissions.USE],
getRoleByName,
});
hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]);
}
if (!hasAccess) {
logger.warn(

View File

@@ -1,22 +1,22 @@
require('dotenv').config();
const fs = require('fs');
const path = require('path');
require('module-alias')({ base: path.resolve(__dirname, '..') });
const cors = require('cors');
const axios = require('axios');
const express = require('express');
const passport = require('passport');
const compression = require('compression');
const cookieParser = require('cookie-parser');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const passport = require('passport');
const mongoSanitize = require('express-mongo-sanitize');
const fs = require('fs');
const cookieParser = require('cookie-parser');
const { connectDb, indexSync } = require('~/db');
const { jwtLogin, passportLogin } = require('~/strategies');
const { isEnabled } = require('~/server/utils');
const { ldapLogin } = require('~/strategies');
const { logger } = require('~/config');
const validateImageRequest = require('./middleware/validateImageRequest');
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
const errorController = require('./controllers/ErrorController');
const initializeMCP = require('./services/initializeMCP');
const configureSocialLogins = require('./socialLogins');
const AppService = require('./services/AppService');
const staticCache = require('./utils/staticCache');
@@ -39,9 +39,7 @@ const startServer = async () => {
await connectDb();
logger.info('Connected to MongoDB');
indexSync().catch((err) => {
logger.error('[indexSync] Background sync failed:', err);
});
await indexSync();
app.disable('x-powered-by');
app.set('trust proxy', trusted_proxy);
@@ -55,6 +53,7 @@ const startServer = async () => {
/* Middleware */
app.use(noIndex);
app.use(errorController);
app.use(express.json({ limit: '3mb' }));
app.use(express.urlencoded({ extended: true, limit: '3mb' }));
app.use(mongoSanitize());
@@ -96,6 +95,7 @@ const startServer = async () => {
app.use('/api/actions', routes.actions);
app.use('/api/keys', routes.keys);
app.use('/api/user', routes.user);
app.use('/api/ask', routes.ask);
app.use('/api/search', routes.search);
app.use('/api/edit', routes.edit);
app.use('/api/messages', routes.messages);
@@ -116,12 +116,9 @@ const startServer = async () => {
app.use('/api/roles', routes.roles);
app.use('/api/agents', routes.agents);
app.use('/api/banner', routes.banner);
app.use('/api/bedrock', routes.bedrock);
app.use('/api/memories', routes.memories);
app.use('/api/tags', routes.tags);
app.use('/api/mcp', routes.mcp);
// Add the error controller one more time after all routes
app.use(errorController);
app.use((req, res) => {
res.set({
@@ -145,8 +142,6 @@ const startServer = async () => {
} else {
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
}
initializeMCP(app);
});
};
@@ -189,5 +184,5 @@ process.on('uncaughtException', (err) => {
process.exit(1);
});
/** Export app for easier testing purposes */
// export app for easier testing purposes
module.exports = app;

View File

@@ -1,4 +1,5 @@
const fs = require('fs');
const path = require('path');
const request = require('supertest');
const { MongoMemoryServer } = require('mongodb-memory-server');
const mongoose = require('mongoose');
@@ -58,30 +59,6 @@ describe('Server Configuration', () => {
expect(response.headers['pragma']).toBe('no-cache');
expect(response.headers['expires']).toBe('0');
});
it('should return 500 for unknown errors via ErrorController', async () => {
// Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
// Mock MongoDB operations to fail
const originalFindOne = mongoose.models.User.findOne;
const mockError = new Error('MongoDB operation failed');
mongoose.models.User.findOne = jest.fn().mockImplementation(() => {
throw mockError;
});
try {
const response = await request(app).post('/api/auth/login').send({
email: 'test@example.com',
password: 'password123',
});
expect(response.status).toBe(500);
expect(response.text).toBe('An unknown error occurred.');
} finally {
// Restore original function
mongoose.models.User.findOne = originalFindOne;
}
});
});
// Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely

View File

@@ -1,13 +1,13 @@
const { logger } = require('@librechat/data-schemas');
const { countTokens, isEnabled, sendEvent } = require('@librechat/api');
// abortMiddleware.js
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const clearPendingReq = require('~/cache/clearPendingReq');
const { sendError } = require('~/server/middleware/error');
const { spendTokens } = require('~/models/spendTokens');
const abortControllers = require('./abortControllers');
const { saveMessage, getConvo } = require('~/models');
const { abortRun } = require('./abortRun');
const { logger } = require('~/config');
const abortDataMap = new WeakMap();
@@ -101,7 +101,7 @@ async function abortMessage(req, res) {
cleanupAbortController(abortKey);
if (res.headersSent && finalEvent) {
return sendEvent(res, finalEvent);
return sendMessage(res, finalEvent);
}
res.setHeader('Content-Type', 'application/json');
@@ -174,7 +174,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
* @param {string} responseMessageId
*/
const onStart = (userMessage, responseMessageId) => {
sendEvent(res, { message: userMessage, created: true });
sendMessage(res, { message: userMessage, created: true });
const abortKey = userMessage?.conversationId ?? req.user.id;
getReqData({ abortKey });

View File

@@ -1,11 +1,11 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
const { deleteMessages } = require('~/models/Message');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendMessage } = require('~/server/utils');
const { logger } = require('~/config');
const three_minutes = 1000 * 60 * 3;
@@ -34,7 +34,7 @@ async function abortRun(req, res) {
const [thread_id, run_id] = runValues.split(':');
if (!run_id) {
logger.warn("[abortRun] Couldn't find run for cancel request", { thread_id });
logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id });
return res.status(204).send({ message: 'Run not found' });
} else if (run_id === 'cancelled') {
logger.warn('[abortRun] Run already cancelled', { thread_id });
@@ -93,7 +93,7 @@ async function abortRun(req, res) {
};
if (res.headersSent && finalEvent) {
return sendEvent(res, finalEvent);
return sendMessage(res, finalEvent);
}
res.json(finalEvent);

View File

@@ -1,13 +1,13 @@
const { handleError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
EndpointURLs,
parseCompactConvo,
EModelEndpoint,
isAgentsEndpoint,
parseCompactConvo,
EndpointURLs,
} = require('librechat-data-provider');
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
const { getModelsConfig } = require('~/server/controllers/ModelController');
const assistants = require('~/server/services/Endpoints/assistants');
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
const { processFiles } = require('~/server/services/Files/process');
const anthropic = require('~/server/services/Endpoints/anthropic');
const bedrock = require('~/server/services/Endpoints/bedrock');
@@ -15,6 +15,7 @@ const openAI = require('~/server/services/Endpoints/openAI');
const agents = require('~/server/services/Endpoints/agents');
const custom = require('~/server/services/Endpoints/custom');
const google = require('~/server/services/Endpoints/google');
const { handleError } = require('~/server/utils');
const buildFunction = {
[EModelEndpoint.openAI]: openAI.buildOptions,
@@ -24,6 +25,7 @@ const buildFunction = {
[EModelEndpoint.bedrock]: bedrock.buildOptions,
[EModelEndpoint.azureOpenAI]: openAI.buildOptions,
[EModelEndpoint.anthropic]: anthropic.buildOptions,
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
[EModelEndpoint.assistants]: assistants.buildOptions,
[EModelEndpoint.azureAssistants]: azureAssistants.buildOptions,
};
@@ -34,9 +36,6 @@ async function buildEndpointOption(req, res, next) {
try {
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
} catch (error) {
logger.warn(
`Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`,
);
return handleError(res, { text: 'Error parsing conversation' });
}
@@ -58,6 +57,15 @@ async function buildEndpointOption(req, res, next) {
return handleError(res, { text: 'Model spec mismatch' });
}
if (
currentModelSpec.preset.endpoint !== EModelEndpoint.gptPlugins &&
currentModelSpec.preset.tools
) {
return handleError(res, {
text: `Only the "${EModelEndpoint.gptPlugins}" endpoint can have tools defined in the preset`,
});
}
try {
currentModelSpec.preset.spec = spec;
if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') {
@@ -69,7 +77,6 @@ async function buildEndpointOption(req, res, next) {
conversation: currentModelSpec.preset,
});
} catch (error) {
logger.error(`Error parsing model spec for endpoint ${endpoint}`, error);
return handleError(res, { text: 'Error parsing model spec' });
}
}
@@ -77,23 +84,20 @@ async function buildEndpointOption(req, res, next) {
try {
const isAgents =
isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]);
const builder = isAgents
? (...args) => buildFunction[EModelEndpoint.agents](req, ...args)
: buildFunction[endpointType ?? endpoint];
const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)];
const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn;
// TODO: use object params
req.body.endpointOption = await builder(endpoint, parsedBody, endpointType);
// TODO: use `getModelsConfig` only when necessary
const modelsConfig = await getModelsConfig(req);
req.body.endpointOption.modelsConfig = modelsConfig;
if (req.body.files && !isAgents) {
req.body.endpointOption.attachments = processFiles(req.body.files);
}
next();
} catch (error) {
logger.error(
`Error building endpoint option for endpoint ${endpoint} with type ${endpointType}`,
error,
);
return handleError(res, { text: 'Error building endpoint option' });
}
}

View File

@@ -18,6 +18,7 @@ const message = 'Your account has been temporarily banned due to violations of o
* @function
* @param {Object} req - Express Request object.
* @param {Object} res - Express Response object.
* @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request.
*
* @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
*/
@@ -134,7 +135,6 @@ const checkBan = async (req, res, next = () => {}) => {
return await banResponse(req, res);
} catch (error) {
logger.error('Error in checkBan middleware:', error);
return next(error);
}
};

View File

@@ -1,7 +1,6 @@
const crypto = require('crypto');
const { sendEvent } = require('@librechat/api');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { sendError } = require('~/server/middleware/error');
const { sendMessage, sendError } = require('~/server/utils');
const { saveMessage } = require('~/models');
/**
@@ -37,7 +36,7 @@ const denyRequest = async (req, res, errorMessage) => {
isCreatedByUser: true,
text,
};
sendEvent(res, { message: userMessage, created: true });
sendMessage(res, { message: userMessage, created: true });
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;

View File

@@ -1,95 +0,0 @@
const rateLimit = require('express-rate-limit');
const { isEnabled } = require('@librechat/api');
const { RedisStore } = require('rate-limit-redis');
const { logger } = require('@librechat/data-schemas');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const getEnvironmentVariables = () => {
const FORK_IP_MAX = parseInt(process.env.FORK_IP_MAX) || 30;
const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1;
const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7;
const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1;
const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE;
const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000;
const forkIpMax = FORK_IP_MAX;
const forkIpWindowInMinutes = forkIpWindowMs / 60000;
const forkUserWindowMs = FORK_USER_WINDOW * 60 * 1000;
const forkUserMax = FORK_USER_MAX;
const forkUserWindowInMinutes = forkUserWindowMs / 60000;
return {
forkIpWindowMs,
forkIpMax,
forkIpWindowInMinutes,
forkUserWindowMs,
forkUserMax,
forkUserWindowInMinutes,
forkViolationScore: FORK_VIOLATION_SCORE,
};
};
const createForkHandler = (ip = true) => {
const {
forkIpMax,
forkUserMax,
forkViolationScore,
forkIpWindowInMinutes,
forkUserWindowInMinutes,
} = getEnvironmentVariables();
return async (req, res) => {
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
const errorMessage = {
type,
max: ip ? forkIpMax : forkUserMax,
limiter: ip ? 'ip' : 'user',
windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, forkViolationScore);
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
};
};
const createForkLimiters = () => {
const { forkIpWindowMs, forkIpMax, forkUserWindowMs, forkUserMax } = getEnvironmentVariables();
const ipLimiterOptions = {
windowMs: forkIpWindowMs,
max: forkIpMax,
handler: createForkHandler(),
};
const userLimiterOptions = {
windowMs: forkUserWindowMs,
max: forkUserMax,
handler: createForkHandler(false),
keyGenerator: function (req) {
return req.user?.id;
},
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for fork rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'fork_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'fork_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const forkIpLimiter = rateLimit(ipLimiterOptions);
const forkUserLimiter = rateLimit(userLimiterOptions);
return { forkIpLimiter, forkUserLimiter };
};
module.exports = { createForkLimiters };

View File

@@ -1,17 +1,16 @@
const rateLimit = require('express-rate-limit');
const { isEnabled } = require('@librechat/api');
const { RedisStore } = require('rate-limit-redis');
const { logger } = require('@librechat/data-schemas');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE;
const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
const importIpMax = IMPORT_IP_MAX;
@@ -28,18 +27,12 @@ const getEnvironmentVariables = () => {
importUserWindowMs,
importUserMax,
importUserWindowInMinutes,
importViolationScore: IMPORT_VIOLATION_SCORE,
};
};
const createImportHandler = (ip = true) => {
const {
importIpMax,
importUserMax,
importViolationScore,
importIpWindowInMinutes,
importUserWindowInMinutes,
} = getEnvironmentVariables();
const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
@@ -50,7 +43,7 @@ const createImportHandler = (ip = true) => {
windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, importViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
};
};

View File

@@ -4,7 +4,6 @@ const createSTTLimiters = require('./sttLimiters');
const loginLimiter = require('./loginLimiter');
const importLimiters = require('./importLimiters');
const uploadLimiters = require('./uploadLimiters');
const forkLimiters = require('./forkLimiters');
const registerLimiter = require('./registerLimiter');
const toolCallLimiter = require('./toolCallLimiter');
const messageLimiters = require('./messageLimiters');
@@ -15,7 +14,6 @@ module.exports = {
...uploadLimiters,
...importLimiters,
...messageLimiters,
...forkLimiters,
loginLimiter,
registerLimiter,
toolCallLimiter,

View File

@@ -11,7 +11,6 @@ const {
MESSAGE_IP_WINDOW = 1,
MESSAGE_USER_MAX = 40,
MESSAGE_USER_WINDOW = 1,
MESSAGE_VIOLATION_SCORE: score,
} = process.env;
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
@@ -40,7 +39,7 @@ const createHandler = (ip = true) => {
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, score);
await logViolation(req, res, type, errorMessage);
return await denyRequest(req, res, errorMessage);
};
};

View File

@@ -11,7 +11,6 @@ const getEnvironmentVariables = () => {
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE;
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
const sttIpMax = STT_IP_MAX;
@@ -28,12 +27,11 @@ const getEnvironmentVariables = () => {
sttUserWindowMs,
sttUserMax,
sttUserWindowInMinutes,
sttViolationScore: STT_VIOLATION_SCORE,
};
};
const createSTTHandler = (ip = true) => {
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } =
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
@@ -45,7 +43,7 @@ const createSTTHandler = (ip = true) => {
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, sttViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many STT requests. Try again later' });
};
};

View File

@@ -6,8 +6,6 @@ const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const { TOOL_CALL_VIOLATION_SCORE: score } = process.env;
const handler = async (req, res) => {
const type = ViolationTypes.TOOL_CALL_LIMIT;
const errorMessage = {
@@ -17,7 +15,7 @@ const handler = async (req, res) => {
windowInMinutes: 1,
};
await logViolation(req, res, type, errorMessage, score);
await logViolation(req, res, type, errorMessage, 0);
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
};

View File

@@ -11,7 +11,6 @@ const getEnvironmentVariables = () => {
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE;
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
const ttsIpMax = TTS_IP_MAX;
@@ -28,12 +27,11 @@ const getEnvironmentVariables = () => {
ttsUserWindowMs,
ttsUserMax,
ttsUserWindowInMinutes,
ttsViolationScore: TTS_VIOLATION_SCORE,
};
};
const createTTSHandler = (ip = true) => {
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } =
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
@@ -45,7 +43,7 @@ const createTTSHandler = (ip = true) => {
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, ttsViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
};
};

View File

@@ -11,7 +11,6 @@ const getEnvironmentVariables = () => {
const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE;
const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
const fileUploadIpMax = FILE_UPLOAD_IP_MAX;
@@ -28,7 +27,6 @@ const getEnvironmentVariables = () => {
fileUploadUserWindowMs,
fileUploadUserMax,
fileUploadUserWindowInMinutes,
fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE,
};
};
@@ -38,7 +36,6 @@ const createFileUploadHandler = (ip = true) => {
fileUploadIpWindowInMinutes,
fileUploadUserMax,
fileUploadUserWindowInMinutes,
fileUploadViolationScore,
} = getEnvironmentVariables();
return async (req, res) => {
@@ -50,7 +47,7 @@ const createFileUploadHandler = (ip = true) => {
windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, fileUploadViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many file upload requests. Try again later' });
};
};

View File

@@ -0,0 +1,78 @@
const { getRoleByName } = require('~/models/Role');
const { logger } = require('~/config');
/**
* Core function to check if a user has one or more required permissions
*
* @param {object} user - The user object
* @param {PermissionTypes} permissionType - The type of permission to check
* @param {Permissions[]} permissions - The list of specific permissions to check
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
* @param {object} [checkObject] - The object to check properties against
* @returns {Promise<boolean>} Whether the user has the required permissions
*/
const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
if (!user) {
return false;
}
const role = await getRoleByName(user.role);
if (role && role.permissions && role.permissions[permissionType]) {
const hasAnyPermission = permissions.some((permission) => {
if (role.permissions[permissionType][permission]) {
return true;
}
if (bodyProps[permission] && checkObject) {
return bodyProps[permission].some((prop) =>
Object.prototype.hasOwnProperty.call(checkObject, prop),
);
}
return false;
});
return hasAnyPermission;
}
return false;
};
/**
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
*
* @param {PermissionTypes} permissionType - The type of permission to check.
* @param {Permissions[]} permissions - The list of specific permissions to check.
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
* @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
*/
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
return async (req, res, next) => {
try {
const hasAccess = await checkAccess(
req.user,
permissionType,
permissions,
bodyProps,
req.body,
);
if (hasAccess) {
return next();
}
logger.warn(
`[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`,
);
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
} catch (error) {
logger.error(error);
return res.status(500).json({ message: `Server error: ${error.message}` });
}
};
};
module.exports = {
checkAccess,
generateCheckAccess,
};

View File

@@ -1,5 +1,8 @@
const checkAdmin = require('./admin');
const { checkAccess, generateCheckAccess } = require('./access');
module.exports = {
checkAdmin,
checkAccess,
generateCheckAccess,
};

View File

@@ -1,5 +1,5 @@
const uap = require('ua-parser-js');
const { handleError } = require('@librechat/api');
const { handleError } = require('../utils');
const { logViolation } = require('../../cache');
/**

View File

@@ -1,4 +1,4 @@
const { handleError } = require('@librechat/api');
const { handleError } = require('../utils');
function validateEndpoint(req, res, next) {
const { endpoint: _endpoint, endpointType } = req.body;

View File

@@ -1,6 +1,6 @@
const { handleError } = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider');
const { getModelsConfig } = require('~/server/controllers/ModelController');
const { handleError } = require('~/server/utils');
const { logViolation } = require('~/cache');
/**
* Validates the model of the request.

View File

@@ -1,162 +0,0 @@
const fs = require('fs');
const path = require('path');
const express = require('express');
const request = require('supertest');
const zlib = require('zlib');
// Create test setup
const mockTestDir = path.join(__dirname, 'test-static-route');
// Mock the paths module to point to our test directory
jest.mock('~/config/paths', () => ({
imageOutput: mockTestDir,
}));
describe('Static Route Integration', () => {
let app;
let staticRoute;
let testDir;
let testImagePath;
beforeAll(() => {
// Create a test directory and files
testDir = mockTestDir;
testImagePath = path.join(testDir, 'test-image.jpg');
if (!fs.existsSync(testDir)) {
fs.mkdirSync(testDir, { recursive: true });
}
// Create a test image file
fs.writeFileSync(testImagePath, 'fake-image-data');
// Create a gzipped version of the test image (for gzip scanning tests)
fs.writeFileSync(testImagePath + '.gz', zlib.gzipSync('fake-image-data'));
});
afterAll(() => {
// Clean up test files
if (fs.existsSync(testDir)) {
fs.rmSync(testDir, { recursive: true, force: true });
}
});
// Helper function to set up static route with specific config
const setupStaticRoute = (skipGzipScan = false) => {
if (skipGzipScan) {
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
} else {
process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN = 'true';
}
staticRoute = require('../static');
app.use('/images', staticRoute);
};
beforeEach(() => {
// Clear the module cache to get fresh imports
jest.resetModules();
app = express();
// Clear environment variables
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
delete process.env.NODE_ENV;
});
describe('route functionality', () => {
it('should serve static image files', async () => {
process.env.NODE_ENV = 'production';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.body.toString()).toBe('fake-image-data');
});
it('should return 404 for non-existent files', async () => {
setupStaticRoute();
const response = await request(app).get('/images/nonexistent.jpg');
expect(response.status).toBe(404);
});
});
describe('cache behavior', () => {
it('should set cache headers for images in production', async () => {
process.env.NODE_ENV = 'production';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
});
it('should not set cache headers in development', async () => {
process.env.NODE_ENV = 'development';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
// Our middleware should not set the production cache-control header in development
expect(response.headers['cache-control']).not.toBe('public, max-age=172800, s-maxage=86400');
});
});
describe('gzip compression behavior', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should serve gzipped files when gzip scanning is enabled', async () => {
setupStaticRoute(false); // Enable gzip scanning
const response = await request(app)
.get('/images/test-image.jpg')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.body.toString()).toBe('fake-image-data');
});
it('should not serve gzipped files when gzip scanning is disabled', async () => {
setupStaticRoute(true); // Disable gzip scanning
const response = await request(app)
.get('/images/test-image.jpg')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.body.toString()).toBe('fake-image-data');
});
});
describe('path configuration', () => {
it('should use the configured imageOutput path', async () => {
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.body.toString()).toBe('fake-image-data');
});
it('should serve from subdirectories', async () => {
// Create a subdirectory with a file
const subDir = path.join(testDir, 'thumbs');
fs.mkdirSync(subDir, { recursive: true });
const thumbPath = path.join(subDir, 'thumb.jpg');
fs.writeFileSync(thumbPath, 'thumbnail-data');
setupStaticRoute();
const response = await request(app).get('/images/thumbs/thumb.jpg').expect(200);
expect(response.body.toString()).toBe('thumbnail-data');
// Clean up
fs.rmSync(subDir, { recursive: true, force: true });
});
});
});

View File

@@ -1,10 +1,8 @@
const express = require('express');
const jwt = require('jsonwebtoken');
const { getAccessToken } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { findToken, updateToken, createToken } = require('~/models');
const { getFlowStateManager } = require('~/config');
const { getAccessToken } = require('~/server/services/TokenService');
const { logger, getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache');
const router = express.Router();
@@ -30,19 +28,18 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
try {
decodedState = jwt.verify(state, JWT_SECRET);
} catch (err) {
logger.error('Error verifying state parameter:', err);
await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter');
return res.redirect('/oauth/error?error=invalid_state');
return res.status(400).send('Invalid or expired state parameter');
}
if (decodedState.action_id !== action_id) {
await flowManager.failFlow(identifier, 'oauth', 'Mismatched action ID in state parameter');
return res.redirect('/oauth/error?error=invalid_state');
return res.status(400).send('Mismatched action ID in state parameter');
}
if (!decodedState.user) {
await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter');
return res.redirect('/oauth/error?error=invalid_state');
return res.status(400).send('Invalid user ID in state parameter');
}
identifier = `${decodedState.user}:${action_id}`;
const flowState = await flowManager.getFlowState(identifier, 'oauth');
@@ -50,34 +47,91 @@ router.get('/:action_id/oauth/callback', async (req, res) => {
throw new Error('OAuth flow not found');
}
const tokenData = await getAccessToken(
{
code,
userId: decodedState.user,
identifier,
client_url: flowState.metadata.client_url,
redirect_uri: flowState.metadata.redirect_uri,
token_exchange_method: flowState.metadata.token_exchange_method,
/** Encrypted values */
encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id,
encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret,
},
{
findToken,
updateToken,
createToken,
},
);
const tokenData = await getAccessToken({
code,
userId: decodedState.user,
identifier,
client_url: flowState.metadata.client_url,
redirect_uri: flowState.metadata.redirect_uri,
token_exchange_method: flowState.metadata.token_exchange_method,
/** Encrypted values */
encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id,
encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret,
});
await flowManager.completeFlow(identifier, 'oauth', tokenData);
/** Redirect to React success page */
const serverName = flowState.metadata?.action_name || `Action ${action_id}`;
const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`;
res.redirect(redirectUrl);
res.send(`
<!DOCTYPE html>
<html>
<head>
<title>Authentication Successful</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<style>
body {
font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont;
background-color: rgb(249, 250, 251);
margin: 0;
padding: 2rem;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
}
.card {
background-color: white;
border-radius: 0.5rem;
padding: 2rem;
max-width: 28rem;
width: 100%;
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
text-align: center;
}
.heading {
color: rgb(17, 24, 39);
font-size: 1.875rem;
font-weight: 700;
margin: 0 0 1rem;
}
.description {
color: rgb(75, 85, 99);
font-size: 0.875rem;
margin: 0.5rem 0;
}
.countdown {
color: rgb(99, 102, 241);
font-weight: 500;
}
</style>
</head>
<body>
<div class="card">
<h1 class="heading">Authentication Successful</h1>
<p class="description">
Your authentication was successful. This window will close in
<span class="countdown" id="countdown">3</span> seconds.
</p>
</div>
<script>
let secondsLeft = 3;
const countdownElement = document.getElementById('countdown');
const countdown = setInterval(() => {
secondsLeft--;
countdownElement.textContent = secondsLeft;
if (secondsLeft <= 0) {
clearInterval(countdown);
window.close();
}
}, 1000);
</script>
</body>
</html>
`);
} catch (error) {
logger.error('Error in OAuth callback:', error);
await flowManager.failFlow(identifier, 'oauth', error);
res.redirect('/oauth/error?error=callback_failed');
res.status(500).send('Authentication failed. Please try again.');
}
});

View File

@@ -1,28 +1,14 @@
const express = require('express');
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const { generateCheckAccess } = require('@librechat/api');
const {
SystemRoles,
Permissions,
PermissionTypes,
actionDelimiter,
removeNullishValues,
} = require('librechat-data-provider');
const { actionDelimiter, SystemRoles, removeNullishValues } = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { getAgent, updateAgent } = require('~/models/Agent');
const { getRoleByName } = require('~/models/Role');
const { logger } = require('~/config');
const router = express.Router();
const checkAgentCreate = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
// If the user has ADMIN role
// then action edition is possible even if not owner of the assistant
const isAdmin = (req) => {
@@ -55,7 +41,7 @@ router.get('/', async (req, res) => {
* @param {ActionMetadata} req.body.metadata - Metadata for the action.
* @returns {Object} 200 - success response - application/json
*/
router.post('/:agent_id', checkAgentCreate, async (req, res) => {
router.post('/:agent_id', async (req, res) => {
try {
const { agent_id } = req.params;
@@ -163,7 +149,7 @@ router.post('/:agent_id', checkAgentCreate, async (req, res) => {
* @param {string} req.params.action_id - The ID of the action to delete.
* @returns {Object} 200 - success response - application/json
*/
router.delete('/:agent_id/:action_id', checkAgentCreate, async (req, res) => {
router.delete('/:agent_id/:action_id', async (req, res) => {
try {
const { agent_id, action_id } = req.params;
const admin = isAdmin(req);

View File

@@ -1,28 +1,22 @@
const express = require('express');
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const {
setHeaders,
moderateText,
// validateModel,
generateCheckAccess,
validateConvoAccess,
buildEndpointOption,
} = require('~/server/middleware');
const { initializeClient } = require('~/server/services/Endpoints/agents');
const AgentController = require('~/server/controllers/agents/request');
const addTitle = require('~/server/services/Endpoints/agents/title');
const { getRoleByName } = require('~/models/Role');
const router = express.Router();
router.use(moderateText);
const checkAgentAccess = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
skipCheck: skipAgentCheck,
getRoleByName,
});
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
router.use(checkAgentAccess);
router.use(validateConvoAccess);

View File

@@ -1,36 +1,29 @@
const express = require('express');
const { generateCheckAccess } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { requireJwtAuth } = require('~/server/middleware');
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const v1 = require('~/server/controllers/agents/v1');
const { getRoleByName } = require('~/models/Role');
const actions = require('./actions');
const tools = require('./tools');
const router = express.Router();
const avatar = express.Router();
const checkAgentAccess = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
const checkAgentCreate = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [
Permissions.USE,
Permissions.CREATE,
]);
const checkGlobalAgentShare = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
bodyProps: {
const checkGlobalAgentShare = generateCheckAccess(
PermissionTypes.AGENTS,
[Permissions.USE, Permissions.CREATE],
{
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
},
getRoleByName,
});
);
router.use(requireJwtAuth);
router.use(checkAgentAccess);
/**
* Agent actions route.

View File

@@ -0,0 +1,63 @@
const { Keyv } = require('keyv');
const { KeyvFile } = require('keyv-file');
const { logger } = require('~/config');
const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessage }) => {
try {
const conversationsCache = new Keyv({
store: new KeyvFile({ filename: './data/cache.json' }),
namespace: 'chatgpt', // should be 'bing' for bing/sydney
});
const {
conversationId,
messageId: userMessageId,
parentMessageId: userParentMessageId,
text: userText,
} = userMessage;
const {
messageId: responseMessageId,
parentMessageId: responseParentMessageId,
text: responseText,
} = responseMessage;
let conversation = await conversationsCache.get(conversationId);
// used to generate a title for the conversation if none exists
// let isNewConversation = false;
if (!conversation) {
conversation = {
messages: [],
createdAt: Date.now(),
};
// isNewConversation = true;
}
const roles = (options) => {
if (endpoint === 'openAI') {
return options?.chatGptLabel || 'ChatGPT';
}
};
let _userMessage = {
id: userMessageId,
parentMessageId: userParentMessageId,
role: 'User',
message: userText,
};
let _responseMessage = {
id: responseMessageId,
parentMessageId: responseParentMessageId,
role: roles(endpointOption),
message: responseText,
};
conversation.messages.push(_userMessage, _responseMessage);
await conversationsCache.set(conversationId, conversation);
} catch (error) {
logger.error('[addToCache] Error adding conversation to cache', error);
}
};
module.exports = addToCache;

View File

@@ -0,0 +1,25 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/anthropic');
const {
setHeaders,
handleAbort,
validateModel,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,25 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/custom');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const {
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,24 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
const {
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const router = express.Router();
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,241 @@
const express = require('express');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const { saveMessage, updateMessage } = require('~/models');
const {
handleAbort,
createAbortController,
handleAbortError,
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
moderateText,
} = require('~/server/middleware');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
const router = express.Router();
router.use(moderateText);
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
endpointOption,
conversationId,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
});
const newConvo = !conversationId;
const user = req.user.id;
const plugins = [];
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
} else if (!conversationId && key === 'conversationId') {
conversationId = data[key];
}
}
};
let streaming = null;
let timer = null;
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
onProgress: () => {
if (timer) {
clearTimeout(timer);
}
streaming = new Promise((resolve) => {
timer = setTimeout(() => {
resolve();
}, 250);
});
},
});
const pluginMap = new Map();
const onAgentAction = async (action, runId) => {
pluginMap.set(runId, action.tool);
sendIntermediateMessage(res, {
plugins,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
};
const onToolStart = async (tool, input, runId, parentRunId) => {
const pluginName = pluginMap.get(parentRunId);
const latestPlugin = {
runId,
loading: true,
inputs: [input],
latest: pluginName,
outputs: null,
};
if (streaming) {
await streaming;
}
const extraTokens = ':::plugin:::\n';
plugins.push(latestPlugin);
sendIntermediateMessage(
res,
{ plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId },
extraTokens,
);
};
const onToolEnd = async (output, runId) => {
if (streaming) {
await streaming;
}
const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId);
if (pluginIndex !== -1) {
plugins[pluginIndex].loading = false;
plugins[pluginIndex].outputs = output;
}
};
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugins: plugins.map((p) => ({ ...p, loading: false })),
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
const onChainEnd = () => {
if (!client.skipSaveUserMessage) {
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
);
}
sendIntermediateMessage(res, {
plugins,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
};
let response = await client.sendMessage(text, {
user,
conversationId,
parentMessageId,
overrideParentMessageId,
getReqData,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
onStart,
getPartialText,
...endpointOption,
progressCallback,
progressOptions: {
res,
// parentMessageId: overrideParentMessageId || userMessageId,
plugins,
},
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
logger.debug('[/ask/gptPlugins]', response);
const { conversation = {} } = await response.databasePromise;
delete response.databasePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: conversation.title,
final: true,
conversation,
requestMessage: userMessage,
responseMessage: response,
});
res.end();
if (parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {
text,
response,
client,
});
}
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
if (response.plugins?.length > 0) {
await updateMessage(
req,
{ ...response, user },
{ context: 'api/server/routes/ask/gptPlugins.js - save plugins used' },
);
}
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
},
);
module.exports = router;

View File

@@ -0,0 +1,47 @@
const express = require('express');
const { EModelEndpoint } = require('librechat-data-provider');
const {
uaParser,
checkBan,
requireJwtAuth,
messageIpLimiter,
concurrentLimiter,
messageUserLimiter,
validateConvoAccess,
} = require('~/server/middleware');
const { isEnabled } = require('~/server/utils');
const gptPlugins = require('./gptPlugins');
const anthropic = require('./anthropic');
const custom = require('./custom');
const google = require('./google');
const openAI = require('./openAI');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
router.use(concurrentLimiter);
}
if (isEnabled(LIMIT_MESSAGE_IP)) {
router.use(messageIpLimiter);
}
if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use(validateConvoAccess);
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
router.use(`/${EModelEndpoint.google}`, google);
router.use(`/${EModelEndpoint.custom}`, custom);
module.exports = router;

View File

@@ -0,0 +1,27 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI');
const {
handleAbort,
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
moderateText,
} = require('~/server/middleware');
const router = express.Router();
router.use(moderateText);
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,37 @@
const express = require('express');
const router = express.Router();
const {
setHeaders,
handleAbort,
moderateText,
// validateModel,
// validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const { initializeClient } = require('~/server/services/Endpoints/bedrock');
const AgentController = require('~/server/controllers/agents/request');
const addTitle = require('~/server/services/Endpoints/agents/title');
router.use(moderateText);
/**
* @route POST /
* @desc Chat with an assistant
* @access Public
* @param {express.Request} req - The request object, containing the request data.
* @param {express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
router.post(
'/',
// validateModel,
// validateEndpoint,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AgentController(req, res, next, initializeClient, addTitle);
},
);
module.exports = router;

View File

@@ -0,0 +1,35 @@
const express = require('express');
const {
uaParser,
checkBan,
requireJwtAuth,
messageIpLimiter,
concurrentLimiter,
messageUserLimiter,
} = require('~/server/middleware');
const { isEnabled } = require('~/server/utils');
const chat = require('./chat');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
const router = express.Router();
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
router.use(concurrentLimiter);
}
if (isEnabled(LIMIT_MESSAGE_IP)) {
router.use(messageIpLimiter);
}
if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use('/chat', chat);
module.exports = router;

View File

@@ -1,11 +1,10 @@
const express = require('express');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider');
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
const { getLdapConfig } = require('~/server/services/Config/ldap');
const { getProjectByName } = require('~/models/Project');
const { isEnabled } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
const router = express.Router();
const emailLoginEnabled =
@@ -22,7 +21,6 @@ const publicSharedLinksEnabled =
router.get('/', async function (req, res) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
if (cachedStartupConfig) {
res.send(cachedStartupConfig);
@@ -98,18 +96,6 @@ router.get('/', async function (req, res) {
bundlerURL: process.env.SANDPACK_BUNDLER_URL,
staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL,
};
payload.mcpServers = {};
const config = await getCustomConfig();
if (config?.mcpServers != null) {
for (const serverName in config.mcpServers) {
const serverConfig = config.mcpServers[serverName];
payload.mcpServers[serverName] = {
customUserVars: serverConfig?.customUserVars || {},
};
}
}
/** @type {TCustomConfig['webSearch']} */
const webSearchConfig = req.app.locals.webSearch;
if (

View File

@@ -1,17 +1,16 @@
const multer = require('multer');
const express = require('express');
const { sleep } = require('@librechat/agents');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
const { storage, importFileFilter } = require('~/server/routes/files/multer');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { importConversations } = require('~/server/utils/import');
const { createImportLimiters } = require('~/server/middleware');
const { deleteToolCalls } = require('~/models/ToolCall');
const { isEnabled, sleep } = require('~/server/utils');
const getLogStores = require('~/cache/getLogStores');
const { logger } = require('~/config');
const assistantClients = {
[EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'),
@@ -44,7 +43,6 @@ router.get('/', async (req, res) => {
});
res.status(200).json(result);
} catch (error) {
logger.error('Error fetching conversations', error);
res.status(500).json({ error: 'Error fetching conversations' });
}
});
@@ -158,7 +156,6 @@ router.post('/update', async (req, res) => {
});
const { importIpLimiter, importUserLimiter } = createImportLimiters();
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
const upload = multer({ storage: storage, fileFilter: importFileFilter });
/**
@@ -192,7 +189,7 @@ router.post(
* @param {express.Response<TForkConvoResponse>} res - Express response object.
* @returns {Promise<void>} - The response after forking the conversation.
*/
router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => {
router.post('/fork', async (req, res) => {
try {
/** @type {TForkConvoRequest} */
const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body;

View File

@@ -0,0 +1,207 @@
const express = require('express');
const { getResponseSender } = require('librechat-data-provider');
const {
setHeaders,
moderateText,
validateModel,
handleAbortError,
validateEndpoint,
buildEndpointOption,
createAbortController,
} = require('~/server/middleware');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, updateMessage } = require('~/models');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
const router = express.Router();
router.use(moderateText);
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
logger.debug('[/edit/gptPlugins]', {
text,
generation,
isContinued,
conversationId,
...endpointOption,
});
let userMessage;
let userMessagePromise;
let promptTokens;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
});
const userMessageId = parentMessageId;
const user = req.user.id;
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
}
}
};
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
generation,
onProgress: () => {
if (plugin.loading === true) {
plugin.loading = false;
}
},
});
const onChainEnd = (data) => {
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
);
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
// logger.debug('CHAIN END', plugin.outputs);
};
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start && !client.skipSaveUserMessage) {
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onAgentAction' },
);
}
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
// logger.debug('PLUGIN ACTION', formattedAction);
};
let response = await client.sendMessage(text, {
user,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getReqData,
onAgentAction,
onChainEnd,
onStart,
...endpointOption,
progressCallback,
progressOptions: {
res,
plugin,
// parentMessageId: overrideParentMessageId || userMessageId,
},
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
const { conversation = {} } = await response.databasePromise;
delete response.databasePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: conversation.title,
final: true,
conversation,
requestMessage: userMessage,
responseMessage: response,
});
res.end();
response.plugin = { ...plugin, loading: false };
await updateMessage(
req,
{ ...response, user },
{ context: 'api/server/routes/edit/gptPlugins.js' },
);
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
},
);
module.exports = router;

View File

@@ -3,6 +3,7 @@ const openAI = require('./openAI');
const custom = require('./custom');
const google = require('./google');
const anthropic = require('./anthropic');
const gptPlugins = require('./gptPlugins');
const { isEnabled } = require('~/server/utils');
const { EModelEndpoint } = require('librechat-data-provider');
const {
@@ -38,6 +39,7 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(validateConvoAccess);
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
router.use(`/${EModelEndpoint.google}`, google);
router.use(`/${EModelEndpoint.custom}`, custom);

View File

@@ -1,282 +0,0 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
// Mock dependencies
jest.mock('~/server/services/Files/process', () => ({
processDeleteRequest: jest.fn().mockResolvedValue({}),
filterFile: jest.fn(),
processFileUpload: jest.fn(),
processAgentFileUpload: jest.fn(),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(() => ({})),
}));
jest.mock('~/server/controllers/assistants/helpers', () => ({
getOpenAIClient: jest.fn(),
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
refreshS3FileUrls: jest.fn(),
}));
jest.mock('~/cache', () => ({
getLogStores: jest.fn(() => ({
get: jest.fn(),
set: jest.fn(),
})),
}));
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
},
}));
const { createFile } = require('~/models/File');
const { createAgent } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
// Import the router after mocks
const router = require('./files');
describe('File Routes - Agent Files Endpoint', () => {
let app;
let mongoServer;
let authorId;
let otherUserId;
let agentId;
let fileId1;
let fileId2;
let fileId3;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
// Initialize models
require('~/db/models');
app = express();
app.use(express.json());
// Mock authentication middleware
app.use((req, res, next) => {
req.user = { id: otherUserId || 'default-user' };
req.app = { locals: {} };
next();
});
app.use('/files', router);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
jest.clearAllMocks();
// Clear database
const collections = mongoose.connection.collections;
for (const key in collections) {
await collections[key].deleteMany({});
}
authorId = new mongoose.Types.ObjectId().toString();
otherUserId = new mongoose.Types.ObjectId().toString();
agentId = uuidv4();
fileId1 = uuidv4();
fileId2 = uuidv4();
fileId3 = uuidv4();
// Create files
await createFile({
user: authorId,
file_id: fileId1,
filename: 'agent-file1.txt',
filepath: `/uploads/${authorId}/${fileId1}`,
bytes: 1024,
type: 'text/plain',
});
await createFile({
user: authorId,
file_id: fileId2,
filename: 'agent-file2.txt',
filepath: `/uploads/${authorId}/${fileId2}`,
bytes: 2048,
type: 'text/plain',
});
await createFile({
user: otherUserId,
file_id: fileId3,
filename: 'user-file.txt',
filepath: `/uploads/${otherUserId}/${fileId3}`,
bytes: 512,
type: 'text/plain',
});
// Create an agent with files attached
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileId1, fileId2],
},
},
});
// Share the agent globally
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
if (globalProject) {
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
}
});
describe('GET /files/agent/:agent_id', () => {
it('should return files accessible through the agent for non-author', async () => {
const response = await request(app).get(`/files/agent/${agentId}`);
expect(response.status).toBe(200);
expect(response.body).toHaveLength(2); // Only agent files, not user-owned files
const fileIds = response.body.map((f) => f.file_id);
expect(fileIds).toContain(fileId1);
expect(fileIds).toContain(fileId2);
expect(fileIds).not.toContain(fileId3); // User's own file not included
});
it('should return 400 when agent_id is not provided', async () => {
const response = await request(app).get('/files/agent/');
expect(response.status).toBe(404); // Express returns 404 for missing route parameter
});
it('should return empty array for non-existent agent', async () => {
const response = await request(app).get('/files/agent/non-existent-agent');
expect(response.status).toBe(200);
expect(response.body).toEqual([]); // Empty array for non-existent agent
});
it('should return empty array when agent is not collaborative', async () => {
// Create a non-collaborative agent
const nonCollabAgentId = uuidv4();
await createAgent({
id: nonCollabAgentId,
name: 'Non-Collaborative Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: false,
tool_resources: {
file_search: {
file_ids: [fileId1],
},
},
});
// Share it globally
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
if (globalProject) {
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: nonCollabAgentId }, { projectIds: [globalProject._id] });
}
const response = await request(app).get(`/files/agent/${nonCollabAgentId}`);
expect(response.status).toBe(200);
expect(response.body).toEqual([]); // Empty array when not collaborative
});
it('should return agent files for agent author', async () => {
// Create a new app instance with author authentication
const authorApp = express();
authorApp.use(express.json());
authorApp.use((req, res, next) => {
req.user = { id: authorId };
req.app = { locals: {} };
next();
});
authorApp.use('/files', router);
const response = await request(authorApp).get(`/files/agent/${agentId}`);
expect(response.status).toBe(200);
expect(response.body).toHaveLength(2); // Agent files for author
const fileIds = response.body.map((f) => f.file_id);
expect(fileIds).toContain(fileId1);
expect(fileIds).toContain(fileId2);
expect(fileIds).not.toContain(fileId3); // User's own file not included
});
it('should return files uploaded by other users to shared agent for author', async () => {
// Create a file uploaded by another user
const otherUserFileId = uuidv4();
const anotherUserId = new mongoose.Types.ObjectId().toString();
await createFile({
user: anotherUserId,
file_id: otherUserFileId,
filename: 'other-user-file.txt',
filepath: `/uploads/${anotherUserId}/${otherUserFileId}`,
bytes: 4096,
type: 'text/plain',
});
// Update agent to include the file uploaded by another user
const { updateAgent } = require('~/models/Agent');
await updateAgent(
{ id: agentId },
{
tool_resources: {
file_search: {
file_ids: [fileId1, fileId2, otherUserFileId],
},
},
},
);
// Create app instance with author authentication
const authorApp = express();
authorApp.use(express.json());
authorApp.use((req, res, next) => {
req.user = { id: authorId };
req.app = { locals: {} };
next();
});
authorApp.use('/files', router);
const response = await request(authorApp).get(`/files/agent/${agentId}`);
expect(response.status).toBe(200);
expect(response.body).toHaveLength(3); // Including file from another user
const fileIds = response.body.map((f) => f.file_id);
expect(fileIds).toContain(fileId1);
expect(fileIds).toContain(fileId2);
expect(fileIds).toContain(otherUserFileId); // File uploaded by another user
});
});
});

View File

@@ -5,7 +5,6 @@ const {
Time,
isUUID,
CacheKeys,
Constants,
FileSources,
EModelEndpoint,
isAgentsEndpoint,
@@ -17,12 +16,11 @@ const {
processDeleteRequest,
processAgentFileUpload,
} = require('~/server/services/Files/process');
const { getFiles, batchUpdateFiles, hasAccessToFilesViaAgent } = require('~/models/File');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
const { getProjectByName } = require('~/models/Project');
const { getFiles, batchUpdateFiles } = require('~/models/File');
const { getAssistant } = require('~/models/Assistant');
const { getAgent } = require('~/models/Agent');
const { getLogStores } = require('~/cache');
@@ -52,68 +50,6 @@ router.get('/', async (req, res) => {
}
});
/**
* Get files specific to an agent
* @route GET /files/agent/:agent_id
* @param {string} agent_id - The agent ID to get files for
* @returns {Promise<TFile[]>} Array of files attached to the agent
*/
router.get('/agent/:agent_id', async (req, res) => {
try {
const { agent_id } = req.params;
const userId = req.user.id;
if (!agent_id) {
return res.status(400).json({ error: 'Agent ID is required' });
}
// Get the agent to check ownership and attached files
const agent = await getAgent({ id: agent_id });
if (!agent) {
// No agent found, return empty array
return res.status(200).json([]);
}
// Check if user has access to the agent
if (agent.author.toString() !== userId) {
// Non-authors need the agent to be globally shared and collaborative
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
if (
!globalProject ||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString()) ||
!agent.isCollaborative
) {
return res.status(200).json([]);
}
}
// Collect all file IDs from agent's tool resources
const agentFileIds = [];
if (agent.tool_resources) {
for (const [, resource] of Object.entries(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
agentFileIds.push(...resource.file_ids);
}
}
}
// If no files attached to agent, return empty array
if (agentFileIds.length === 0) {
return res.status(200).json([]);
}
// Get only the files attached to this agent
const files = await getFiles({ file_id: { $in: agentFileIds } }, null, { text: 0 });
res.status(200).json(files);
} catch (error) {
logger.error('[/files/agent/:agent_id] Error fetching agent files:', error);
res.status(500).json({ error: 'Failed to fetch agent files' });
}
});
router.get('/config', async (req, res) => {
try {
res.status(200).json(req.app.locals.fileConfig);
@@ -150,62 +86,11 @@ router.delete('/', async (req, res) => {
const fileIds = files.map((file) => file.file_id);
const dbFiles = await getFiles({ file_id: { $in: fileIds } });
const ownedFiles = [];
const nonOwnedFiles = [];
const fileMap = new Map();
for (const file of dbFiles) {
fileMap.set(file.file_id, file);
if (file.user.toString() === req.user.id) {
ownedFiles.push(file);
} else {
nonOwnedFiles.push(file);
}
}
// If all files are owned by the user, no need for further checks
if (nonOwnedFiles.length === 0) {
await processDeleteRequest({ req, files: ownedFiles });
logger.debug(
`[/files] Files deleted successfully: ${ownedFiles
.filter((f) => f.file_id)
.map((f) => f.file_id)
.join(', ')}`,
);
res.status(200).json({ message: 'Files deleted successfully' });
return;
}
// Check access for non-owned files
let authorizedFiles = [...ownedFiles];
let unauthorizedFiles = [];
if (req.body.agent_id && nonOwnedFiles.length > 0) {
// Batch check access for all non-owned files
const nonOwnedFileIds = nonOwnedFiles.map((f) => f.file_id);
const accessMap = await hasAccessToFilesViaAgent(
req.user.id,
nonOwnedFileIds,
req.body.agent_id,
);
// Separate authorized and unauthorized files
for (const file of nonOwnedFiles) {
if (accessMap.get(file.file_id)) {
authorizedFiles.push(file);
} else {
unauthorizedFiles.push(file);
}
}
} else {
// No agent context, all non-owned files are unauthorized
unauthorizedFiles = nonOwnedFiles;
}
const unauthorizedFiles = dbFiles.filter((file) => file.user.toString() !== req.user.id);
if (unauthorizedFiles.length > 0) {
return res.status(403).json({
message: 'You can only delete files you have access to',
message: 'You can only delete your own files',
unauthorizedFiles: unauthorizedFiles.map((f) => f.file_id),
});
}
@@ -246,10 +131,10 @@ router.delete('/', async (req, res) => {
.json({ message: 'File associations removed successfully from Azure Assistant' });
}
await processDeleteRequest({ req, files: authorizedFiles });
await processDeleteRequest({ req, files: dbFiles });
logger.debug(
`[/files] Files deleted successfully: ${authorizedFiles
`[/files] Files deleted successfully: ${files
.filter((f) => f.file_id)
.map((f) => f.file_id)
.join(', ')}`,
@@ -398,10 +283,7 @@ router.post('/', async (req, res) => {
message += ': ' + error.message;
}
if (
error.message?.includes('Invalid file format') ||
error.message?.includes('No OCR result')
) {
if (error.message?.includes('Invalid file format')) {
message = error.message;
}

Some files were not shown because too many files have changed in this diff Show More