Compare commits
62 Commits
feat/compo
...
feat/promp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e1af9d21f0 | ||
|
|
4918899c8d | ||
|
|
7e37211458 | ||
|
|
e57fc83d40 | ||
|
|
550610dba9 | ||
|
|
916cd46221 | ||
|
|
12b08183ff | ||
|
|
f4d97e1672 | ||
|
|
035fa081c1 | ||
|
|
aecf8f19a6 | ||
|
|
35f548a94d | ||
|
|
e60c0cf201 | ||
|
|
5b392f9cb0 | ||
|
|
e0f468da20 | ||
|
|
91a2df4759 | ||
|
|
97a99985fa | ||
|
|
3554625a06 | ||
|
|
a37bf6719c | ||
|
|
e513f50c08 | ||
|
|
f5511e4a4e | ||
|
|
a288ad1d9c | ||
|
|
458580ec87 | ||
|
|
4285d5841c | ||
|
|
5ee55cda4f | ||
|
|
404d40cbef | ||
|
|
f4680b016c | ||
|
|
077224b351 | ||
|
|
9c70d1db96 | ||
|
|
543281da6c | ||
|
|
24800bfbeb | ||
|
|
07e08143e4 | ||
|
|
8ba61a86f4 | ||
|
|
56ad92fb1c | ||
|
|
1ceb52d2b5 | ||
|
|
5d267aa8e2 | ||
|
|
59d00e99f3 | ||
|
|
738d04fac4 | ||
|
|
8a5dbac0f9 | ||
|
|
434289fe92 | ||
|
|
a648ad3d13 | ||
|
|
55d63caaf4 | ||
|
|
313539d1ed | ||
|
|
f869d772f7 | ||
|
|
20100e120b | ||
|
|
3f3cfefc52 | ||
|
|
3e1591d404 | ||
|
|
1060ae8040 | ||
|
|
dd67e463e4 | ||
|
|
d60ad61325 | ||
|
|
452151e408 | ||
|
|
33b4a97b42 | ||
|
|
9cdc62b655 | ||
|
|
799f0e5810 | ||
|
|
cbda3cb529 | ||
|
|
3ab1bd65e5 | ||
|
|
c551ba21f5 | ||
|
|
c87422a1e0 | ||
|
|
b169306096 | ||
|
|
3d261a969d | ||
|
|
84f62eb70c | ||
|
|
3e698338aa | ||
|
|
0e26df0390 |
15
.env.example
15
.env.example
@@ -142,10 +142,10 @@ GOOGLE_KEY=user_provided
|
||||
# GOOGLE_AUTH_HEADER=true
|
||||
|
||||
# Gemini API (AI Studio)
|
||||
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
|
||||
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite
|
||||
|
||||
# Vertex AI
|
||||
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
|
||||
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
|
||||
|
||||
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
|
||||
|
||||
@@ -349,6 +349,11 @@ REGISTRATION_VIOLATION_SCORE=1
|
||||
CONCURRENT_VIOLATION_SCORE=1
|
||||
MESSAGE_VIOLATION_SCORE=1
|
||||
NON_BROWSER_VIOLATION_SCORE=20
|
||||
TTS_VIOLATION_SCORE=0
|
||||
STT_VIOLATION_SCORE=0
|
||||
FORK_VIOLATION_SCORE=0
|
||||
IMPORT_VIOLATION_SCORE=0
|
||||
FILE_UPLOAD_VIOLATION_SCORE=0
|
||||
|
||||
LOGIN_MAX=7
|
||||
LOGIN_WINDOW=5
|
||||
@@ -453,8 +458,8 @@ OPENID_REUSE_TOKENS=
|
||||
OPENID_JWKS_URL_CACHE_ENABLED=
|
||||
OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching
|
||||
#Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint.
|
||||
OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED=
|
||||
OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API
|
||||
OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED=
|
||||
OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API
|
||||
# Set to true to use the OpenID Connect end session endpoint for logout
|
||||
OPENID_USE_END_SESSION_ENDPOINT=
|
||||
|
||||
@@ -657,4 +662,4 @@ OPENWEATHER_API_KEY=
|
||||
# Reranker (Required)
|
||||
# JINA_API_KEY=your_jina_api_key
|
||||
# or
|
||||
# COHERE_API_KEY=your_cohere_api_key
|
||||
# COHERE_API_KEY=your_cohere_api_key
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# v0.7.8
|
||||
# v0.7.9-rc1
|
||||
|
||||
# Base node image
|
||||
FROM node:20-alpine AS node
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Dockerfile.multi
|
||||
# v0.7.8
|
||||
# v0.7.9-rc1
|
||||
|
||||
# Base for all builds
|
||||
FROM node:20-alpine AS base-min
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
|
||||
|
||||
- 🤖 **AI Model Selection**:
|
||||
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
|
||||
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure)
|
||||
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
|
||||
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
|
||||
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
|
||||
@@ -66,10 +66,9 @@
|
||||
- 🔦 **Agents & Tools Integration**:
|
||||
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
|
||||
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
|
||||
- Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
|
||||
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
|
||||
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
|
||||
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
|
||||
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
|
||||
- Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions
|
||||
|
||||
- 🔍 **Web Search**:
|
||||
- Search the internet and retrieve relevant information to enhance your AI context
|
||||
|
||||
@@ -13,7 +13,6 @@ const {
|
||||
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { truncateToolCallOutputs } = require('./prompts');
|
||||
const { addSpaceIfNeeded } = require('~/server/utils');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const TextStream = require('./TextStream');
|
||||
const { logger } = require('~/config');
|
||||
@@ -572,7 +571,7 @@ class BaseClient {
|
||||
});
|
||||
}
|
||||
|
||||
const { generation = '' } = opts;
|
||||
const { editedContent } = opts;
|
||||
|
||||
// It's not necessary to push to currentMessages
|
||||
// depending on subclass implementation of handling messages
|
||||
@@ -587,11 +586,21 @@ class BaseClient {
|
||||
isCreatedByUser: false,
|
||||
model: this.modelOptions?.model ?? this.model,
|
||||
sender: this.sender,
|
||||
text: generation,
|
||||
};
|
||||
this.currentMessages.push(userMessage, latestMessage);
|
||||
} else {
|
||||
latestMessage.text = generation;
|
||||
} else if (editedContent != null) {
|
||||
// Handle editedContent for content parts
|
||||
if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
|
||||
const { index, text, type } = editedContent;
|
||||
if (index >= 0 && index < latestMessage.content.length) {
|
||||
const contentPart = latestMessage.content[index];
|
||||
if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
|
||||
contentPart[ContentTypes.THINK] = text;
|
||||
} else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
|
||||
contentPart[ContentTypes.TEXT] = text;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
this.continued = true;
|
||||
} else {
|
||||
@@ -672,16 +681,32 @@ class BaseClient {
|
||||
};
|
||||
|
||||
if (typeof completion === 'string') {
|
||||
responseMessage.text = addSpaceIfNeeded(generation) + completion;
|
||||
responseMessage.text = completion;
|
||||
} else if (
|
||||
Array.isArray(completion) &&
|
||||
(this.clientName === EModelEndpoint.agents ||
|
||||
isParamEndpoint(this.options.endpoint, this.options.endpointType))
|
||||
) {
|
||||
responseMessage.text = '';
|
||||
responseMessage.content = completion;
|
||||
|
||||
if (!opts.editedContent || this.currentMessages.length === 0) {
|
||||
responseMessage.content = completion;
|
||||
} else {
|
||||
const latestMessage = this.currentMessages[this.currentMessages.length - 1];
|
||||
if (!latestMessage?.content) {
|
||||
responseMessage.content = completion;
|
||||
} else {
|
||||
const existingContent = [...latestMessage.content];
|
||||
const { type: editedType } = opts.editedContent;
|
||||
responseMessage.content = this.mergeEditedContent(
|
||||
existingContent,
|
||||
completion,
|
||||
editedType,
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if (Array.isArray(completion)) {
|
||||
responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
|
||||
responseMessage.text = completion.join('');
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -1095,6 +1120,50 @@ class BaseClient {
|
||||
return numTokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges completion content with existing content when editing TEXT or THINK types
|
||||
* @param {Array} existingContent - The existing content array
|
||||
* @param {Array} newCompletion - The new completion content
|
||||
* @param {string} editedType - The type of content being edited
|
||||
* @returns {Array} The merged content array
|
||||
*/
|
||||
mergeEditedContent(existingContent, newCompletion, editedType) {
|
||||
if (!newCompletion.length) {
|
||||
return existingContent.concat(newCompletion);
|
||||
}
|
||||
|
||||
if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
|
||||
return existingContent.concat(newCompletion);
|
||||
}
|
||||
|
||||
const lastIndex = existingContent.length - 1;
|
||||
const lastExisting = existingContent[lastIndex];
|
||||
const firstNew = newCompletion[0];
|
||||
|
||||
if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
|
||||
return existingContent.concat(newCompletion);
|
||||
}
|
||||
|
||||
const mergedContent = [...existingContent];
|
||||
if (editedType === ContentTypes.TEXT) {
|
||||
mergedContent[lastIndex] = {
|
||||
...mergedContent[lastIndex],
|
||||
[ContentTypes.TEXT]:
|
||||
(mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
|
||||
};
|
||||
} else {
|
||||
mergedContent[lastIndex] = {
|
||||
...mergedContent[lastIndex],
|
||||
[ContentTypes.THINK]:
|
||||
(mergedContent[lastIndex][ContentTypes.THINK] || '') +
|
||||
(firstNew[ContentTypes.THINK] || ''),
|
||||
};
|
||||
}
|
||||
|
||||
// Add remaining completion items
|
||||
return mergedContent.concat(newCompletion.slice(1));
|
||||
}
|
||||
|
||||
async sendPayload(payload, opts = {}) {
|
||||
if (opts && typeof opts === 'object') {
|
||||
this.setOptions(opts);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { google } = require('googleapis');
|
||||
const { Tokenizer } = require('@librechat/api');
|
||||
const { concat } = require('@langchain/core/utils/stream');
|
||||
const { ChatVertexAI } = require('@langchain/google-vertexai');
|
||||
const { Tokenizer, getSafetySettings } = require('@librechat/api');
|
||||
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
|
||||
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
|
||||
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
|
||||
@@ -12,13 +12,13 @@ const {
|
||||
endpointSettings,
|
||||
parseTextParts,
|
||||
EModelEndpoint,
|
||||
googleSettings,
|
||||
ContentTypes,
|
||||
VisionModes,
|
||||
ErrorTypes,
|
||||
Constants,
|
||||
AuthKeys,
|
||||
} = require('librechat-data-provider');
|
||||
const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
@@ -166,6 +166,16 @@ class GoogleClient extends BaseClient {
|
||||
);
|
||||
}
|
||||
|
||||
// Add thinking configuration
|
||||
this.modelOptions.thinkingConfig = {
|
||||
thinkingBudget:
|
||||
(this.modelOptions.thinking ?? googleSettings.thinking.default)
|
||||
? this.modelOptions.thinkingBudget
|
||||
: 0,
|
||||
};
|
||||
delete this.modelOptions.thinking;
|
||||
delete this.modelOptions.thinkingBudget;
|
||||
|
||||
this.sender =
|
||||
this.options.sender ??
|
||||
getResponseSender({
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
const axios = require('axios');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||
|
||||
const footer = `Use the context as your learned knowledge to better answer the user.
|
||||
|
||||
@@ -18,7 +19,7 @@ function createContextHandlers(req, userMessageContent) {
|
||||
const queryPromises = [];
|
||||
const processedFiles = [];
|
||||
const processedIds = new Set();
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
const jwtToken = generateShortLivedToken(req.user.id);
|
||||
const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);
|
||||
|
||||
const query = async (file) => {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
const { z } = require('zod');
|
||||
const axios = require('axios');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Tools, EToolResources } = require('librechat-data-provider');
|
||||
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
*
|
||||
@@ -59,7 +60,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
if (files.length === 0) {
|
||||
return 'No files to search. Instruct the user to add files for the search.';
|
||||
}
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
const jwtToken = generateShortLivedToken(req.user.id);
|
||||
if (!jwtToken) {
|
||||
return 'There was an error authenticating the file search request.';
|
||||
}
|
||||
|
||||
3
api/cache/banViolation.js
vendored
3
api/cache/banViolation.js
vendored
@@ -1,7 +1,8 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, math } = require('@librechat/api');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, math, removePorts } = require('~/server/utils');
|
||||
const { deleteAllUserSessions } = require('~/models');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const getLogStores = require('./getLogStores');
|
||||
|
||||
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
|
||||
|
||||
2
api/cache/getLogStores.js
vendored
2
api/cache/getLogStores.js
vendored
@@ -1,7 +1,7 @@
|
||||
const { Keyv } = require('keyv');
|
||||
const { isEnabled, math } = require('@librechat/api');
|
||||
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
|
||||
const { logFile, violationFile } = require('./keyvFiles');
|
||||
const { isEnabled, math } = require('~/server/utils');
|
||||
const keyvRedis = require('./keyvRedis');
|
||||
const keyvMongo = require('./keyvMongo');
|
||||
|
||||
|
||||
2
api/cache/logViolation.js
vendored
2
api/cache/logViolation.js
vendored
@@ -9,7 +9,7 @@ const banViolation = require('./banViolation');
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {string} type - The type of violation.
|
||||
* @param {Object} errorMessage - The error message to log.
|
||||
* @param {number} [score=1] - The severity of the violation. Defaults to 1
|
||||
* @param {number | string} [score=1] - The severity of the violation. Defaults to 1
|
||||
*/
|
||||
const logViolation = async (req, res, type, errorMessage, score = 1) => {
|
||||
const userId = req.user?.id ?? req.user?._id;
|
||||
|
||||
@@ -90,7 +90,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
}
|
||||
|
||||
const instructions = req.body.promptPrefix;
|
||||
return {
|
||||
const result = {
|
||||
id: agent_id,
|
||||
instructions,
|
||||
provider: endpoint,
|
||||
@@ -98,6 +98,11 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
const { Conversation } = require('~/db/models');
|
||||
|
||||
@@ -98,10 +100,15 @@ module.exports = {
|
||||
update.conversationId = newConversationId;
|
||||
}
|
||||
|
||||
if (req.body.isTemporary) {
|
||||
const expiredAt = new Date();
|
||||
expiredAt.setDate(expiredAt.getDate() + 30);
|
||||
update.expiredAt = expiredAt;
|
||||
if (req?.body?.isTemporary) {
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
update.expiredAt = createTempChatExpirationDate(customConfig);
|
||||
} catch (err) {
|
||||
logger.error('Error creating temporary chat expiration date:', err);
|
||||
logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
|
||||
update.expiredAt = null;
|
||||
}
|
||||
} else {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const { z } = require('zod');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { Message } = require('~/db/models');
|
||||
|
||||
const idSchema = z.string().uuid();
|
||||
@@ -54,9 +56,14 @@ async function saveMessage(req, params, metadata) {
|
||||
};
|
||||
|
||||
if (req?.body?.isTemporary) {
|
||||
const expiredAt = new Date();
|
||||
expiredAt.setDate(expiredAt.getDate() + 30);
|
||||
update.expiredAt = expiredAt;
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
update.expiredAt = createTempChatExpirationDate(customConfig);
|
||||
} catch (err) {
|
||||
logger.error('Error creating temporary chat expiration date:', err);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
update.expiredAt = null;
|
||||
}
|
||||
} else {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "v0.7.8",
|
||||
"version": "v0.7.9-rc1",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
@@ -48,7 +48,7 @@
|
||||
"@langchain/google-genai": "^0.2.13",
|
||||
"@langchain/google-vertexai": "^0.2.13",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.4.42",
|
||||
"@librechat/agents": "^2.4.56",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@node-saml/passport-saml": "^5.0.0",
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
const cookies = require('cookie');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const openIdClient = require('openid-client');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
registerUser,
|
||||
resetPassword,
|
||||
setAuthTokens,
|
||||
requestPasswordReset,
|
||||
setOpenIDAuthTokens,
|
||||
resetPassword,
|
||||
setAuthTokens,
|
||||
registerUser,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getResponseSender } = require('librechat-data-provider');
|
||||
const {
|
||||
handleAbortError,
|
||||
@@ -10,9 +12,8 @@ const {
|
||||
clientRegistry,
|
||||
requestDataMap,
|
||||
} = require('~/server/cleanup');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const EditController = async (req, res, next, initializeClient) => {
|
||||
let {
|
||||
@@ -198,7 +199,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
const finalUserMessage = reqDataContext.userMessage;
|
||||
const finalResponseMessage = { ...response };
|
||||
|
||||
sendMessage(res, {
|
||||
sendEvent(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
|
||||
@@ -24,17 +24,23 @@ const handleValidationError = (err, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
module.exports = (err, req, res, next) => {
|
||||
module.exports = (err, _req, res, _next) => {
|
||||
try {
|
||||
if (err.name === 'ValidationError') {
|
||||
return (err = handleValidationError(err, res));
|
||||
return handleValidationError(err, res);
|
||||
}
|
||||
if (err.code && err.code == 11000) {
|
||||
return (err = handleDuplicateKeyError(err, res));
|
||||
return handleDuplicateKeyError(err, res);
|
||||
}
|
||||
} catch (err) {
|
||||
// Special handling for errors like SyntaxError
|
||||
if (err.statusCode && err.body) {
|
||||
return res.status(err.statusCode).send(err.body);
|
||||
}
|
||||
|
||||
logger.error('ErrorController => error', err);
|
||||
res.status(500).send('An unknown error occurred.');
|
||||
return res.status(500).send('An unknown error occurred.');
|
||||
} catch (err) {
|
||||
logger.error('ErrorController => processing error', err);
|
||||
return res.status(500).send('Processing error in ErrorController.');
|
||||
}
|
||||
};
|
||||
|
||||
241
api/server/controllers/ErrorController.spec.js
Normal file
241
api/server/controllers/ErrorController.spec.js
Normal file
@@ -0,0 +1,241 @@
|
||||
const errorController = require('./ErrorController');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
// Mock the logger
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('ErrorController', () => {
|
||||
let mockReq, mockRes, mockNext;
|
||||
|
||||
beforeEach(() => {
|
||||
mockReq = {};
|
||||
mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
send: jest.fn(),
|
||||
};
|
||||
mockNext = jest.fn();
|
||||
logger.error.mockClear();
|
||||
});
|
||||
|
||||
describe('ValidationError handling', () => {
|
||||
it('should handle ValidationError with single error', () => {
|
||||
const validationError = {
|
||||
name: 'ValidationError',
|
||||
errors: {
|
||||
email: { message: 'Email is required', path: 'email' },
|
||||
},
|
||||
};
|
||||
|
||||
errorController(validationError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: '["Email is required"]',
|
||||
fields: '["email"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
|
||||
});
|
||||
|
||||
it('should handle ValidationError with multiple errors', () => {
|
||||
const validationError = {
|
||||
name: 'ValidationError',
|
||||
errors: {
|
||||
email: { message: 'Email is required', path: 'email' },
|
||||
password: { message: 'Password is required', path: 'password' },
|
||||
},
|
||||
};
|
||||
|
||||
errorController(validationError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: '"Email is required Password is required"',
|
||||
fields: '["email","password"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
|
||||
});
|
||||
|
||||
it('should handle ValidationError with empty errors object', () => {
|
||||
const validationError = {
|
||||
name: 'ValidationError',
|
||||
errors: {},
|
||||
};
|
||||
|
||||
errorController(validationError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: '[]',
|
||||
fields: '[]',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Duplicate key error handling', () => {
|
||||
it('should handle duplicate key error (code 11000)', () => {
|
||||
const duplicateKeyError = {
|
||||
code: 11000,
|
||||
keyValue: { email: 'test@example.com' },
|
||||
};
|
||||
|
||||
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: 'An document with that ["email"] already exists.',
|
||||
fields: '["email"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
|
||||
});
|
||||
|
||||
it('should handle duplicate key error with multiple fields', () => {
|
||||
const duplicateKeyError = {
|
||||
code: 11000,
|
||||
keyValue: { email: 'test@example.com', username: 'testuser' },
|
||||
};
|
||||
|
||||
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: 'An document with that ["email","username"] already exists.',
|
||||
fields: '["email","username"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
|
||||
});
|
||||
|
||||
it('should handle error with code 11000 as string', () => {
|
||||
const duplicateKeyError = {
|
||||
code: '11000',
|
||||
keyValue: { email: 'test@example.com' },
|
||||
};
|
||||
|
||||
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: 'An document with that ["email"] already exists.',
|
||||
fields: '["email"]',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('SyntaxError handling', () => {
|
||||
it('should handle errors with statusCode and body', () => {
|
||||
const syntaxError = {
|
||||
statusCode: 400,
|
||||
body: 'Invalid JSON syntax',
|
||||
};
|
||||
|
||||
errorController(syntaxError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
|
||||
});
|
||||
|
||||
it('should handle errors with different statusCode and body', () => {
|
||||
const customError = {
|
||||
statusCode: 422,
|
||||
body: { error: 'Unprocessable entity' },
|
||||
};
|
||||
|
||||
errorController(customError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(422);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
|
||||
});
|
||||
|
||||
it('should handle error with statusCode but no body', () => {
|
||||
const partialError = {
|
||||
statusCode: 400,
|
||||
};
|
||||
|
||||
errorController(partialError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
});
|
||||
|
||||
it('should handle error with body but no statusCode', () => {
|
||||
const partialError = {
|
||||
body: 'Some error message',
|
||||
};
|
||||
|
||||
errorController(partialError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Unknown error handling', () => {
|
||||
it('should handle unknown errors', () => {
|
||||
const unknownError = new Error('Some unknown error');
|
||||
|
||||
errorController(unknownError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError);
|
||||
});
|
||||
|
||||
it('should handle errors with code other than 11000', () => {
|
||||
const mongoError = {
|
||||
code: 11100,
|
||||
message: 'Some MongoDB error',
|
||||
};
|
||||
|
||||
errorController(mongoError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError);
|
||||
});
|
||||
|
||||
it('should handle null/undefined errors', () => {
|
||||
errorController(null, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'ErrorController => processing error',
|
||||
expect.any(Error),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Catch block handling', () => {
|
||||
beforeEach(() => {
|
||||
// Restore logger mock to normal behavior for these tests
|
||||
logger.error.mockRestore();
|
||||
logger.error = jest.fn();
|
||||
});
|
||||
|
||||
it('should handle errors when logger.error throws', () => {
|
||||
// Create fresh mocks for this test
|
||||
const freshMockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
send: jest.fn(),
|
||||
};
|
||||
|
||||
// Mock logger to throw on the first call, succeed on the second
|
||||
logger.error
|
||||
.mockImplementationOnce(() => {
|
||||
throw new Error('Logger error');
|
||||
})
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const testError = new Error('Test error');
|
||||
|
||||
errorController(testError, mockReq, freshMockRes, mockNext);
|
||||
|
||||
expect(freshMockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
||||
expect(logger.error).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
195
api/server/controllers/agents/__tests__/v1.spec.js
Normal file
195
api/server/controllers/agents/__tests__/v1.spec.js
Normal file
@@ -0,0 +1,195 @@
|
||||
const { duplicateAgent } = require('../v1');
|
||||
const { getAgent, createAgent } = require('~/models/Agent');
|
||||
const { getActions } = require('~/models/Action');
|
||||
const { nanoid } = require('nanoid');
|
||||
|
||||
jest.mock('~/models/Agent');
|
||||
jest.mock('~/models/Action');
|
||||
jest.mock('nanoid');
|
||||
|
||||
describe('duplicateAgent', () => {
|
||||
let req, res;
|
||||
|
||||
beforeEach(() => {
|
||||
req = {
|
||||
params: { id: 'agent_123' },
|
||||
user: { id: 'user_456' },
|
||||
};
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should duplicate an agent successfully', async () => {
|
||||
const mockAgent = {
|
||||
id: 'agent_123',
|
||||
name: 'Test Agent',
|
||||
description: 'Test Description',
|
||||
instructions: 'Test Instructions',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['file_search'],
|
||||
actions: [],
|
||||
author: 'user_789',
|
||||
versions: [{ name: 'Test Agent', version: 1 }],
|
||||
__v: 0,
|
||||
};
|
||||
|
||||
const mockNewAgent = {
|
||||
id: 'agent_new_123',
|
||||
name: 'Test Agent (1/2/23, 12:34)',
|
||||
description: 'Test Description',
|
||||
instructions: 'Test Instructions',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['file_search'],
|
||||
actions: [],
|
||||
author: 'user_456',
|
||||
versions: [
|
||||
{
|
||||
name: 'Test Agent (1/2/23, 12:34)',
|
||||
description: 'Test Description',
|
||||
instructions: 'Test Instructions',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['file_search'],
|
||||
actions: [],
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
getAgent.mockResolvedValue(mockAgent);
|
||||
getActions.mockResolvedValue([]);
|
||||
nanoid.mockReturnValue('new_123');
|
||||
createAgent.mockResolvedValue(mockNewAgent);
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(getAgent).toHaveBeenCalledWith({ id: 'agent_123' });
|
||||
expect(getActions).toHaveBeenCalledWith({ agent_id: 'agent_123' }, true);
|
||||
expect(createAgent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
id: 'agent_new_123',
|
||||
author: 'user_456',
|
||||
name: expect.stringContaining('Test Agent ('),
|
||||
description: 'Test Description',
|
||||
instructions: 'Test Instructions',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['file_search'],
|
||||
actions: [],
|
||||
}),
|
||||
);
|
||||
|
||||
expect(createAgent).toHaveBeenCalledWith(
|
||||
expect.not.objectContaining({
|
||||
versions: expect.anything(),
|
||||
__v: expect.anything(),
|
||||
}),
|
||||
);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(201);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
agent: mockNewAgent,
|
||||
actions: [],
|
||||
});
|
||||
});
|
||||
|
||||
it('should ensure duplicated agent has clean versions array without nested fields', async () => {
|
||||
const mockAgent = {
|
||||
id: 'agent_123',
|
||||
name: 'Test Agent',
|
||||
description: 'Test Description',
|
||||
versions: [
|
||||
{
|
||||
name: 'Test Agent',
|
||||
versions: [{ name: 'Nested' }],
|
||||
__v: 1,
|
||||
},
|
||||
],
|
||||
__v: 2,
|
||||
};
|
||||
|
||||
const mockNewAgent = {
|
||||
id: 'agent_new_123',
|
||||
name: 'Test Agent (1/2/23, 12:34)',
|
||||
description: 'Test Description',
|
||||
versions: [
|
||||
{
|
||||
name: 'Test Agent (1/2/23, 12:34)',
|
||||
description: 'Test Description',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
getAgent.mockResolvedValue(mockAgent);
|
||||
getActions.mockResolvedValue([]);
|
||||
nanoid.mockReturnValue('new_123');
|
||||
createAgent.mockResolvedValue(mockNewAgent);
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(mockNewAgent.versions).toHaveLength(1);
|
||||
|
||||
const firstVersion = mockNewAgent.versions[0];
|
||||
expect(firstVersion).not.toHaveProperty('versions');
|
||||
expect(firstVersion).not.toHaveProperty('__v');
|
||||
|
||||
expect(mockNewAgent).not.toHaveProperty('__v');
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(201);
|
||||
});
|
||||
|
||||
it('should return 404 if agent not found', async () => {
|
||||
getAgent.mockResolvedValue(null);
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(404);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Agent not found',
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle tool_resources.ocr correctly', async () => {
|
||||
const mockAgent = {
|
||||
id: 'agent_123',
|
||||
name: 'Test Agent',
|
||||
tool_resources: {
|
||||
ocr: { enabled: true, config: 'test' },
|
||||
other: { should: 'not be copied' },
|
||||
},
|
||||
};
|
||||
|
||||
getAgent.mockResolvedValue(mockAgent);
|
||||
getActions.mockResolvedValue([]);
|
||||
nanoid.mockReturnValue('new_123');
|
||||
createAgent.mockResolvedValue({ id: 'agent_new_123' });
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(createAgent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tool_resources: {
|
||||
ocr: { enabled: true, config: 'test' },
|
||||
},
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
getAgent.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(500);
|
||||
expect(res.json).toHaveBeenCalledWith({ error: 'Database error' });
|
||||
});
|
||||
});
|
||||
@@ -4,6 +4,7 @@ const {
|
||||
sendEvent,
|
||||
createRun,
|
||||
Tokenizer,
|
||||
checkAccess,
|
||||
memoryInstructions,
|
||||
createMemoryProcessor,
|
||||
} = require('@librechat/api');
|
||||
@@ -39,11 +40,22 @@ const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||
const { checkAccess } = require('~/server/middleware/roles/access');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { loadAgent } = require('~/models/Agent');
|
||||
const { getMCPManager } = require('~/config');
|
||||
|
||||
const omitTitleOptions = new Set([
|
||||
'stream',
|
||||
'thinking',
|
||||
'streaming',
|
||||
'clientOptions',
|
||||
'thinkingConfig',
|
||||
'thinkingBudget',
|
||||
'includeThoughts',
|
||||
'maxOutputTokens',
|
||||
]);
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
* @param {Agent} agent
|
||||
@@ -390,7 +402,12 @@ class AgentClient extends BaseClient {
|
||||
if (user.personalization?.memories === false) {
|
||||
return;
|
||||
}
|
||||
const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]);
|
||||
const hasAccess = await checkAccess({
|
||||
user,
|
||||
permissionType: PermissionTypes.MEMORIES,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
if (!hasAccess) {
|
||||
logger.debug(
|
||||
@@ -508,7 +525,10 @@ class AgentClient extends BaseClient {
|
||||
messagesToProcess = [...messages.slice(-messageWindowSize)];
|
||||
}
|
||||
}
|
||||
return await this.processMemory(messagesToProcess);
|
||||
|
||||
const bufferString = getBufferString(messagesToProcess);
|
||||
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
|
||||
return await this.processMemory([bufferMessage]);
|
||||
} catch (error) {
|
||||
logger.error('Memory Agent failed to process memory', error);
|
||||
}
|
||||
@@ -1038,6 +1058,16 @@ class AgentClient extends BaseClient {
|
||||
delete clientOptions.maxTokens;
|
||||
}
|
||||
|
||||
clientOptions = Object.assign(
|
||||
Object.fromEntries(
|
||||
Object.entries(clientOptions).filter(([key]) => !omitTitleOptions.has(key)),
|
||||
),
|
||||
);
|
||||
|
||||
if (provider === Providers.GOOGLE) {
|
||||
clientOptions.json = true;
|
||||
}
|
||||
|
||||
try {
|
||||
const titleResult = await this.run.generateTitle({
|
||||
provider,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// errorHandler.js
|
||||
const { logger } = require('~/config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const { sendResponse } = require('~/server/middleware/error');
|
||||
const { recordUsage } = require('~/server/services/Threads');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const { sendResponse } = require('~/server/utils');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerContext
|
||||
@@ -75,7 +75,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
} else if (/Files.*are invalid/.test(error.message)) {
|
||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||
endpoint === 'azureAssistants'
|
||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
handleAbortError,
|
||||
@@ -5,17 +7,18 @@ const {
|
||||
cleanupAbortController,
|
||||
} = require('~/server/middleware');
|
||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||
const { sendMessage } = require('~/server/utils');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
text,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
isContinued = false,
|
||||
editedContent = null,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
responseMessageId: editedResponseMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
let sender;
|
||||
@@ -67,7 +70,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
handler();
|
||||
}
|
||||
} catch (e) {
|
||||
// Ignore cleanup errors
|
||||
logger.error('[AgentController] Error in cleanup handler', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -155,7 +158,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
try {
|
||||
res.removeListener('close', closeHandler);
|
||||
} catch (e) {
|
||||
// Ignore
|
||||
logger.error('[AgentController] Error removing close listener', e);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -163,10 +166,14 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
user: userId,
|
||||
onStart,
|
||||
getReqData,
|
||||
isContinued,
|
||||
editedContent,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
responseMessageId: editedResponseMessageId,
|
||||
progressOptions: {
|
||||
res,
|
||||
},
|
||||
@@ -206,7 +213,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
// Create a new response object with minimal copies
|
||||
const finalResponse = { ...response };
|
||||
|
||||
sendMessage(res, {
|
||||
sendEvent(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
const { z } = require('zod');
|
||||
const fs = require('fs').promises;
|
||||
const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
@@ -8,6 +10,7 @@ const {
|
||||
SystemRoles,
|
||||
EToolResources,
|
||||
actionDelimiter,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getAgent,
|
||||
@@ -30,6 +33,7 @@ const { deleteFileByFilter } = require('~/models/File');
|
||||
const systemTools = {
|
||||
[Tools.execute_code]: true,
|
||||
[Tools.file_search]: true,
|
||||
[Tools.web_search]: true,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -42,9 +46,13 @@ const systemTools = {
|
||||
*/
|
||||
const createAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body;
|
||||
const validatedData = agentCreateSchema.parse(req.body);
|
||||
const { tools = [], ...agentData } = removeNullishValues(validatedData);
|
||||
|
||||
const { id: userId } = req.user;
|
||||
|
||||
agentData.id = `agent_${nanoid()}`;
|
||||
agentData.author = userId;
|
||||
agentData.tools = [];
|
||||
|
||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||
@@ -58,19 +66,13 @@ const createAgentHandler = async (req, res) => {
|
||||
}
|
||||
}
|
||||
|
||||
Object.assign(agentData, {
|
||||
author: userId,
|
||||
name,
|
||||
description,
|
||||
instructions,
|
||||
provider,
|
||||
model,
|
||||
});
|
||||
|
||||
agentData.id = `agent_${nanoid()}`;
|
||||
const agent = await createAgent(agentData);
|
||||
res.status(201).json(agent);
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.error('[/Agents] Validation error', error.errors);
|
||||
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
|
||||
}
|
||||
logger.error('[/Agents] Error creating agent', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
@@ -154,14 +156,16 @@ const getAgentHandler = async (req, res) => {
|
||||
const updateAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const id = req.params.id;
|
||||
const { projectIds, removeProjectIds, ...updateData } = req.body;
|
||||
const validatedData = agentUpdateSchema.parse(req.body);
|
||||
const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
|
||||
const isAdmin = req.user.role === SystemRoles.ADMIN;
|
||||
const existingAgent = await getAgent({ id });
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||
|
||||
if (!hasEditPermission) {
|
||||
@@ -200,6 +204,11 @@ const updateAgentHandler = async (req, res) => {
|
||||
|
||||
return res.json(updatedAgent);
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.error('[/Agents/:id] Validation error', error.errors);
|
||||
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
|
||||
}
|
||||
|
||||
logger.error('[/Agents/:id] Error updating Agent', error);
|
||||
|
||||
if (error.statusCode === 409) {
|
||||
@@ -242,6 +251,8 @@ const duplicateAgentHandler = async (req, res) => {
|
||||
createdAt: _createdAt,
|
||||
updatedAt: _updatedAt,
|
||||
tool_resources: _tool_resources = {},
|
||||
versions: _versions,
|
||||
__v: _v,
|
||||
...cloneData
|
||||
} = agent;
|
||||
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
|
||||
|
||||
659
api/server/controllers/agents/v1.spec.js
Normal file
659
api/server/controllers/agents/v1.spec.js
Normal file
@@ -0,0 +1,659 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
|
||||
// Only mock the dependencies that are not database-related
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCachedTools: jest.fn().mockResolvedValue({
|
||||
web_search: true,
|
||||
execute_code: true,
|
||||
file_search: true,
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Project', () => ({
|
||||
getProjectByName: jest.fn().mockResolvedValue(null),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/images/avatar', () => ({
|
||||
resizeAvatar: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||
refreshS3Url: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
filterFile: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Action', () => ({
|
||||
updateAction: jest.fn(),
|
||||
getActions: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/File', () => ({
|
||||
deleteFileByFilter: jest.fn(),
|
||||
}));
|
||||
|
||||
const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
|
||||
*/
|
||||
let Agent;
|
||||
|
||||
describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
let mongoServer;
|
||||
let mockReq;
|
||||
let mockRes;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Agent.deleteMany({});
|
||||
|
||||
// Reset all mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Setup mock request and response objects
|
||||
mockReq = {
|
||||
user: {
|
||||
id: new mongoose.Types.ObjectId().toString(),
|
||||
role: 'USER',
|
||||
},
|
||||
body: {},
|
||||
params: {},
|
||||
app: {
|
||||
locals: {
|
||||
fileStrategy: 'local',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('createAgentHandler', () => {
|
||||
test('should create agent with allowed fields only', async () => {
|
||||
const validData = {
|
||||
name: 'Test Agent',
|
||||
description: 'A test agent',
|
||||
instructions: 'Be helpful',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['web_search'],
|
||||
model_parameters: { temperature: 0.7 },
|
||||
tool_resources: {
|
||||
file_search: { file_ids: ['file1', 'file2'] },
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = validData;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.name).toBe('Test Agent');
|
||||
expect(createdAgent.description).toBe('A test agent');
|
||||
expect(createdAgent.provider).toBe('openai');
|
||||
expect(createdAgent.model).toBe('gpt-4');
|
||||
expect(createdAgent.author.toString()).toBe(mockReq.user.id);
|
||||
expect(createdAgent.tools).toContain('web_search');
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb).toBeDefined();
|
||||
expect(agentInDb.name).toBe('Test Agent');
|
||||
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
|
||||
});
|
||||
|
||||
test('should reject creation with unauthorized fields (mass assignment protection)', async () => {
|
||||
const maliciousData = {
|
||||
// Required fields
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Malicious Agent',
|
||||
|
||||
// Unauthorized fields that should be stripped
|
||||
author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
|
||||
authorName: 'Hacker', // Should be stripped
|
||||
isCollaborative: true, // Should be stripped on creation
|
||||
versions: [], // Should be stripped
|
||||
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||
id: 'custom_agent_id', // Should be overridden
|
||||
createdAt: new Date('2020-01-01'), // Should be stripped
|
||||
updatedAt: new Date('2020-01-01'), // Should be stripped
|
||||
};
|
||||
|
||||
mockReq.body = maliciousData;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unauthorized fields were not set
|
||||
expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
|
||||
expect(createdAgent.authorName).toBeUndefined();
|
||||
expect(createdAgent.isCollaborative).toBeFalsy();
|
||||
expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
|
||||
expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
|
||||
expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix
|
||||
|
||||
// Verify timestamps are recent (not the malicious dates)
|
||||
const createdTime = new Date(createdAgent.createdAt).getTime();
|
||||
const now = Date.now();
|
||||
expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
|
||||
expect(agentInDb.authorName).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should validate required fields', async () => {
|
||||
const invalidData = {
|
||||
name: 'Missing Required Fields',
|
||||
// Missing provider and model
|
||||
};
|
||||
|
||||
mockReq.body = invalidData;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
details: expect.any(Array),
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify nothing was created in database
|
||||
const count = await Agent.countDocuments();
|
||||
expect(count).toBe(0);
|
||||
});
|
||||
|
||||
test('should handle tool_resources validation', async () => {
|
||||
const dataWithInvalidToolResources = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Tool Resources',
|
||||
tool_resources: {
|
||||
// Valid resources
|
||||
file_search: {
|
||||
file_ids: ['file1', 'file2'],
|
||||
vector_store_ids: ['vs1'],
|
||||
},
|
||||
execute_code: {
|
||||
file_ids: ['file3'],
|
||||
},
|
||||
// Invalid resource (should be stripped by schema)
|
||||
invalid_resource: {
|
||||
file_ids: ['file4'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithInvalidToolResources;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.tool_resources).toBeDefined();
|
||||
expect(createdAgent.tool_resources.file_search).toBeDefined();
|
||||
expect(createdAgent.tool_resources.execute_code).toBeDefined();
|
||||
expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should handle avatar validation', async () => {
|
||||
const dataWithAvatar = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Avatar',
|
||||
avatar: {
|
||||
filepath: 'https://example.com/avatar.png',
|
||||
source: 's3',
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithAvatar;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.avatar).toEqual({
|
||||
filepath: 'https://example.com/avatar.png',
|
||||
source: 's3',
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle invalid avatar format', async () => {
|
||||
const dataWithInvalidAvatar = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Invalid Avatar',
|
||||
avatar: 'just-a-string', // Invalid format
|
||||
};
|
||||
|
||||
mockReq.body = dataWithInvalidAvatar;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateAgentHandler', () => {
|
||||
let existingAgentId;
|
||||
let existingAgentAuthorId;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create an existing agent for update tests
|
||||
existingAgentAuthorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
author: existingAgentAuthorId,
|
||||
description: 'Original description',
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
description: 'Original description',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
existingAgentId = agent.id;
|
||||
});
|
||||
|
||||
test('should update agent with allowed fields only', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString(); // Set as author
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated Agent',
|
||||
description: 'Updated description',
|
||||
model: 'gpt-4',
|
||||
isCollaborative: true, // This IS allowed in updates
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(400);
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Updated Agent');
|
||||
expect(updatedAgent.description).toBe('Updated description');
|
||||
expect(updatedAgent.model).toBe('gpt-4');
|
||||
expect(updatedAgent.isCollaborative).toBe(true);
|
||||
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Updated Agent');
|
||||
expect(agentInDb.isCollaborative).toBe(true);
|
||||
});
|
||||
|
||||
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated Name',
|
||||
|
||||
// Unauthorized fields that should be stripped
|
||||
author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author
|
||||
authorName: 'Hacker', // Should be stripped
|
||||
id: 'different_agent_id', // Should be stripped
|
||||
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||
versions: [], // Should be stripped
|
||||
createdAt: new Date('2020-01-01'), // Should be stripped
|
||||
updatedAt: new Date('2020-01-01'), // Should be stripped
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unauthorized fields were not changed
|
||||
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed
|
||||
expect(updatedAgent.authorName).toBeUndefined();
|
||||
expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed
|
||||
expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString());
|
||||
expect(agentInDb.id).toBe(existingAgentId);
|
||||
});
|
||||
|
||||
test('should reject update from non-author when not collaborative', async () => {
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Unauthorized Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'You do not have permission to modify this non-collaborative agent',
|
||||
});
|
||||
|
||||
// Verify agent was not modified in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Original Agent');
|
||||
});
|
||||
|
||||
test('should allow update from non-author when collaborative', async () => {
|
||||
// First make the agent collaborative
|
||||
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
|
||||
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Collaborative Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Collaborative Update');
|
||||
// Author field should be removed for non-author
|
||||
expect(updatedAgent.author).toBeUndefined();
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Collaborative Update');
|
||||
});
|
||||
|
||||
test('should allow admin to update any agent', async () => {
|
||||
const adminUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = adminUserId;
|
||||
mockReq.user.role = 'ADMIN'; // Set as admin
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Admin Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Admin Update');
|
||||
});
|
||||
|
||||
test('should handle projectIds updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
|
||||
const projectId1 = new mongoose.Types.ObjectId().toString();
|
||||
const projectId2 = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
mockReq.body = {
|
||||
projectIds: [projectId1, projectId2],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent).toBeDefined();
|
||||
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
|
||||
});
|
||||
|
||||
test('should validate tool_resources in updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
tool_resources: {
|
||||
ocr: {
|
||||
file_ids: ['ocr1', 'ocr2'],
|
||||
},
|
||||
execute_code: {
|
||||
file_ids: ['img1'],
|
||||
},
|
||||
// Invalid tool resource
|
||||
invalid_tool: {
|
||||
file_ids: ['invalid'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tool_resources).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.ocr).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should return 404 for non-existent agent', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
|
||||
mockReq.body = {
|
||||
name: 'Update Non-existent',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(404);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
|
||||
});
|
||||
|
||||
test('should handle validation errors properly', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
model_parameters: 'invalid-not-an-object', // Should be an object
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
details: expect.any(Array),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Mass Assignment Attack Scenarios', () => {
|
||||
test('should prevent setting system fields during creation', async () => {
|
||||
const systemFields = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'System Fields Test',
|
||||
|
||||
// System fields that should never be settable by users
|
||||
__v: 99,
|
||||
_id: new mongoose.Types.ObjectId(),
|
||||
versions: [
|
||||
{
|
||||
name: 'Fake Version',
|
||||
provider: 'fake',
|
||||
model: 'fake-model',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockReq.body = systemFields;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify system fields were not affected
|
||||
expect(createdAgent.__v).not.toBe(99);
|
||||
expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version
|
||||
expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation
|
||||
expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.__v).not.toBe(99);
|
||||
});
|
||||
|
||||
test('should prevent privilege escalation through isCollaborative', async () => {
|
||||
// Create a non-collaborative agent
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// Try to make it collaborative as a different user
|
||||
const attackerId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = attackerId;
|
||||
mockReq.params.id = agent.id;
|
||||
mockReq.body = {
|
||||
isCollaborative: true, // Trying to escalate privileges
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
// Should be rejected
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
|
||||
// Verify in database that it's still not collaborative
|
||||
const agentInDb = await Agent.findOne({ id: agent.id });
|
||||
expect(agentInDb.isCollaborative).toBe(false);
|
||||
});
|
||||
|
||||
test('should prevent author hijacking', async () => {
|
||||
const originalAuthorId = new mongoose.Types.ObjectId();
|
||||
const attackerId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Admin creates an agent
|
||||
mockReq.user.id = originalAuthorId.toString();
|
||||
mockReq.user.role = 'ADMIN';
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Admin Agent',
|
||||
author: attackerId.toString(), // Trying to set different author
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Author should be the actual user, not the attempted value
|
||||
expect(createdAgent.author.toString()).toBe(originalAuthorId.toString());
|
||||
expect(createdAgent.author.toString()).not.toBe(attackerId.toString());
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.author.toString()).toBe(originalAuthorId.toString());
|
||||
});
|
||||
|
||||
test('should strip unknown fields to prevent future vulnerabilities', async () => {
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Future Proof Test',
|
||||
|
||||
// Unknown fields that might be added in future
|
||||
superAdminAccess: true,
|
||||
bypassAllChecks: true,
|
||||
internalFlag: 'secret',
|
||||
futureFeature: 'exploit',
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unknown fields were stripped
|
||||
expect(createdAgent.superAdminAccess).toBeUndefined();
|
||||
expect(createdAgent.bypassAllChecks).toBeUndefined();
|
||||
expect(createdAgent.internalFlag).toBeUndefined();
|
||||
expect(createdAgent.futureFeature).toBeUndefined();
|
||||
|
||||
// Also check in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean();
|
||||
expect(agentInDb.superAdminAccess).toBeUndefined();
|
||||
expect(agentInDb.bypassAllChecks).toBeUndefined();
|
||||
expect(agentInDb.internalFlag).toBeUndefined();
|
||||
expect(agentInDb.futureFeature).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,7 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -19,20 +22,20 @@ const {
|
||||
addThreadMetadata,
|
||||
saveAssistantMessage,
|
||||
} = require('~/server/services/Threads');
|
||||
const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { createRunBody } = require('~/server/services/createRunBody');
|
||||
const { sendResponse } = require('~/server/middleware/error');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
@@ -471,7 +474,7 @@ const chatV1 = async (req, res) => {
|
||||
await Promise.all(promises);
|
||||
|
||||
const sendInitialResponse = () => {
|
||||
sendMessage(res, {
|
||||
sendEvent(res, {
|
||||
sync: true,
|
||||
conversationId,
|
||||
// messages: previousMessages,
|
||||
@@ -587,7 +590,7 @@ const chatV1 = async (req, res) => {
|
||||
iconURL: endpointOption.iconURL,
|
||||
};
|
||||
|
||||
sendMessage(res, {
|
||||
sendEvent(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: {
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -22,15 +25,14 @@ const { createErrorHandler } = require('~/server/controllers/assistants/errors')
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { sendMessage, sleep, countTokens } = require('~/server/utils');
|
||||
const { createRunBody } = require('~/server/services/createRunBody');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
@@ -309,7 +311,7 @@ const chatV2 = async (req, res) => {
|
||||
await Promise.all(promises);
|
||||
|
||||
const sendInitialResponse = () => {
|
||||
sendMessage(res, {
|
||||
sendEvent(res, {
|
||||
sync: true,
|
||||
conversationId,
|
||||
// messages: previousMessages,
|
||||
@@ -432,7 +434,7 @@ const chatV2 = async (req, res) => {
|
||||
iconURL: endpointOption.iconURL,
|
||||
};
|
||||
|
||||
sendMessage(res, {
|
||||
sendEvent(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// errorHandler.js
|
||||
const { sendResponse } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
|
||||
const { sendResponse } = require('~/server/middleware/error');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerContext
|
||||
@@ -78,7 +78,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
} else if (/Files.*are invalid/.test(error.message)) {
|
||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||
endpoint === 'azureAssistants'
|
||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const { nanoid } = require('nanoid');
|
||||
const { EnvVar } = require('@librechat/agents');
|
||||
const { checkAccess } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Tools,
|
||||
AuthType,
|
||||
@@ -13,9 +15,8 @@ const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
const { checkAccess } = require('~/server/middleware');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { getMessage } = require('~/models/Message');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const fieldsMap = {
|
||||
[Tools.execute_code]: [EnvVar.CODE_API_KEY],
|
||||
@@ -79,6 +80,7 @@ const verifyToolAuth = async (req, res) => {
|
||||
throwError: false,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error loading auth values', error);
|
||||
res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED });
|
||||
return;
|
||||
}
|
||||
@@ -132,7 +134,12 @@ const callTool = async (req, res) => {
|
||||
logger.debug(`[${toolId}/call] User: ${req.user.id}`);
|
||||
let hasAccess = true;
|
||||
if (toolAccessPermType[toolId]) {
|
||||
hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]);
|
||||
hasAccess = await checkAccess({
|
||||
user: req.user,
|
||||
permissionType: toolAccessPermType[toolId],
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
}
|
||||
if (!hasAccess) {
|
||||
logger.warn(
|
||||
|
||||
@@ -55,7 +55,6 @@ const startServer = async () => {
|
||||
|
||||
/* Middleware */
|
||||
app.use(noIndex);
|
||||
app.use(errorController);
|
||||
app.use(express.json({ limit: '3mb' }));
|
||||
app.use(express.urlencoded({ extended: true, limit: '3mb' }));
|
||||
app.use(mongoSanitize());
|
||||
@@ -121,6 +120,9 @@ const startServer = async () => {
|
||||
app.use('/api/tags', routes.tags);
|
||||
app.use('/api/mcp', routes.mcp);
|
||||
|
||||
// Add the error controller one more time after all routes
|
||||
app.use(errorController);
|
||||
|
||||
app.use((req, res) => {
|
||||
res.set({
|
||||
'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const request = require('supertest');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const mongoose = require('mongoose');
|
||||
@@ -59,6 +58,30 @@ describe('Server Configuration', () => {
|
||||
expect(response.headers['pragma']).toBe('no-cache');
|
||||
expect(response.headers['expires']).toBe('0');
|
||||
});
|
||||
|
||||
it('should return 500 for unknown errors via ErrorController', async () => {
|
||||
// Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
|
||||
|
||||
// Mock MongoDB operations to fail
|
||||
const originalFindOne = mongoose.models.User.findOne;
|
||||
const mockError = new Error('MongoDB operation failed');
|
||||
mongoose.models.User.findOne = jest.fn().mockImplementation(() => {
|
||||
throw mockError;
|
||||
});
|
||||
|
||||
try {
|
||||
const response = await request(app).post('/api/auth/login').send({
|
||||
email: 'test@example.com',
|
||||
password: 'password123',
|
||||
});
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
expect(response.text).toBe('An unknown error occurred.');
|
||||
} finally {
|
||||
// Restore original function
|
||||
mongoose.models.User.findOne = originalFindOne;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
// abortMiddleware.js
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { countTokens, isEnabled, sendEvent } = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const abortControllers = require('./abortControllers');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { abortRun } = require('./abortRun');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const abortDataMap = new WeakMap();
|
||||
|
||||
@@ -101,7 +101,7 @@ async function abortMessage(req, res) {
|
||||
cleanupAbortController(abortKey);
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
return sendMessage(res, finalEvent);
|
||||
return sendEvent(res, finalEvent);
|
||||
}
|
||||
|
||||
res.setHeader('Content-Type', 'application/json');
|
||||
@@ -174,7 +174,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
* @param {string} responseMessageId
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId) => {
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
sendEvent(res, { message: userMessage, created: true });
|
||||
|
||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||
getReqData({ abortKey });
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
|
||||
const { deleteMessages } = require('~/models/Message');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { sendMessage } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const three_minutes = 1000 * 60 * 3;
|
||||
|
||||
@@ -34,7 +34,7 @@ async function abortRun(req, res) {
|
||||
const [thread_id, run_id] = runValues.split(':');
|
||||
|
||||
if (!run_id) {
|
||||
logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id });
|
||||
logger.warn("[abortRun] Couldn't find run for cancel request", { thread_id });
|
||||
return res.status(204).send({ message: 'Run not found' });
|
||||
} else if (run_id === 'cancelled') {
|
||||
logger.warn('[abortRun] Run already cancelled', { thread_id });
|
||||
@@ -93,7 +93,7 @@ async function abortRun(req, res) {
|
||||
};
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
return sendMessage(res, finalEvent);
|
||||
return sendEvent(res, finalEvent);
|
||||
}
|
||||
|
||||
res.json(finalEvent);
|
||||
|
||||
@@ -18,7 +18,6 @@ const message = 'Your account has been temporarily banned due to violations of o
|
||||
* @function
|
||||
* @param {Object} req - Express Request object.
|
||||
* @param {Object} res - Express Response object.
|
||||
* @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request.
|
||||
*
|
||||
* @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
|
||||
*/
|
||||
@@ -135,6 +134,7 @@ const checkBan = async (req, res, next = () => {}) => {
|
||||
return await banResponse(req, res);
|
||||
} catch (error) {
|
||||
logger.error('Error in checkBan middleware:', error);
|
||||
return next(error);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
const crypto = require('crypto');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
||||
const { sendMessage, sendError } = require('~/server/utils');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
const { saveMessage } = require('~/models');
|
||||
|
||||
/**
|
||||
@@ -36,7 +37,7 @@ const denyRequest = async (req, res, errorMessage) => {
|
||||
isCreatedByUser: true,
|
||||
text,
|
||||
};
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
sendEvent(res, { message: userMessage, created: true });
|
||||
|
||||
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;
|
||||
|
||||
|
||||
@@ -1,31 +1,9 @@
|
||||
const crypto = require('crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { parseConvo } = require('librechat-data-provider');
|
||||
const { sendEvent, handleError } = require('@librechat/api');
|
||||
const { saveMessage, getMessages } = require('~/models/Message');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Sends error data in Server Sent Events format and ends the response.
|
||||
* @param {object} res - The server response.
|
||||
* @param {string} message - The error message.
|
||||
*/
|
||||
const handleError = (res, message) => {
|
||||
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
|
||||
res.end();
|
||||
};
|
||||
|
||||
/**
|
||||
* Sends message data in Server Sent Events format.
|
||||
* @param {Express.Response} res - - The server response.
|
||||
* @param {string | Object} message - The message to be sent.
|
||||
* @param {'message' | 'error' | 'cancel'} event - [Optional] The type of event. Default is 'message'.
|
||||
*/
|
||||
const sendMessage = (res, message, event = 'message') => {
|
||||
if (typeof message === 'string' && message.length === 0) {
|
||||
return;
|
||||
}
|
||||
res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`);
|
||||
};
|
||||
|
||||
/**
|
||||
* Processes an error with provided options, saves the error message and sends a corresponding SSE response
|
||||
@@ -91,7 +69,7 @@ const sendError = async (req, res, options, callback) => {
|
||||
convo = parseConvo(errorMessage);
|
||||
}
|
||||
|
||||
return sendMessage(res, {
|
||||
return sendEvent(res, {
|
||||
final: true,
|
||||
requestMessage: query?.[0] ? query[0] : requestMessage,
|
||||
responseMessage: errorMessage,
|
||||
@@ -120,12 +98,10 @@ const sendResponse = (req, res, data, errorMessage) => {
|
||||
if (errorMessage) {
|
||||
return sendError(req, res, { ...data, text: errorMessage });
|
||||
}
|
||||
return sendMessage(res, data);
|
||||
return sendEvent(res, data);
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
sendResponse,
|
||||
handleError,
|
||||
sendMessage,
|
||||
sendError,
|
||||
sendResponse,
|
||||
};
|
||||
95
api/server/middleware/limiters/forkLimiters.js
Normal file
95
api/server/middleware/limiters/forkLimiters.js
Normal file
@@ -0,0 +1,95 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const FORK_IP_MAX = parseInt(process.env.FORK_IP_MAX) || 30;
|
||||
const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1;
|
||||
const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7;
|
||||
const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1;
|
||||
const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE;
|
||||
|
||||
const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000;
|
||||
const forkIpMax = FORK_IP_MAX;
|
||||
const forkIpWindowInMinutes = forkIpWindowMs / 60000;
|
||||
|
||||
const forkUserWindowMs = FORK_USER_WINDOW * 60 * 1000;
|
||||
const forkUserMax = FORK_USER_MAX;
|
||||
const forkUserWindowInMinutes = forkUserWindowMs / 60000;
|
||||
|
||||
return {
|
||||
forkIpWindowMs,
|
||||
forkIpMax,
|
||||
forkIpWindowInMinutes,
|
||||
forkUserWindowMs,
|
||||
forkUserMax,
|
||||
forkUserWindowInMinutes,
|
||||
forkViolationScore: FORK_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createForkHandler = (ip = true) => {
|
||||
const {
|
||||
forkIpMax,
|
||||
forkUserMax,
|
||||
forkViolationScore,
|
||||
forkIpWindowInMinutes,
|
||||
forkUserWindowInMinutes,
|
||||
} = getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: ip ? forkIpMax : forkUserMax,
|
||||
limiter: ip ? 'ip' : 'user',
|
||||
windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, forkViolationScore);
|
||||
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
|
||||
};
|
||||
};
|
||||
|
||||
const createForkLimiters = () => {
|
||||
const { forkIpWindowMs, forkIpMax, forkUserWindowMs, forkUserMax } = getEnvironmentVariables();
|
||||
|
||||
const ipLimiterOptions = {
|
||||
windowMs: forkIpWindowMs,
|
||||
max: forkIpMax,
|
||||
handler: createForkHandler(),
|
||||
};
|
||||
const userLimiterOptions = {
|
||||
windowMs: forkUserWindowMs,
|
||||
max: forkUserMax,
|
||||
handler: createForkHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id;
|
||||
},
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for fork rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'fork_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'fork_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const forkIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const forkUserLimiter = rateLimit(userLimiterOptions);
|
||||
return { forkIpLimiter, forkUserLimiter };
|
||||
};
|
||||
|
||||
module.exports = { createForkLimiters };
|
||||
@@ -1,16 +1,17 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
|
||||
const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
|
||||
const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
|
||||
const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
|
||||
const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE;
|
||||
|
||||
const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
|
||||
const importIpMax = IMPORT_IP_MAX;
|
||||
@@ -27,12 +28,18 @@ const getEnvironmentVariables = () => {
|
||||
importUserWindowMs,
|
||||
importUserMax,
|
||||
importUserWindowInMinutes,
|
||||
importViolationScore: IMPORT_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createImportHandler = (ip = true) => {
|
||||
const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } =
|
||||
getEnvironmentVariables();
|
||||
const {
|
||||
importIpMax,
|
||||
importUserMax,
|
||||
importViolationScore,
|
||||
importIpWindowInMinutes,
|
||||
importUserWindowInMinutes,
|
||||
} = getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
||||
@@ -43,7 +50,7 @@ const createImportHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, importViolationScore);
|
||||
res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
|
||||
};
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@ const createSTTLimiters = require('./sttLimiters');
|
||||
const loginLimiter = require('./loginLimiter');
|
||||
const importLimiters = require('./importLimiters');
|
||||
const uploadLimiters = require('./uploadLimiters');
|
||||
const forkLimiters = require('./forkLimiters');
|
||||
const registerLimiter = require('./registerLimiter');
|
||||
const toolCallLimiter = require('./toolCallLimiter');
|
||||
const messageLimiters = require('./messageLimiters');
|
||||
@@ -14,6 +15,7 @@ module.exports = {
|
||||
...uploadLimiters,
|
||||
...importLimiters,
|
||||
...messageLimiters,
|
||||
...forkLimiters,
|
||||
loginLimiter,
|
||||
registerLimiter,
|
||||
toolCallLimiter,
|
||||
|
||||
@@ -11,6 +11,7 @@ const {
|
||||
MESSAGE_IP_WINDOW = 1,
|
||||
MESSAGE_USER_MAX = 40,
|
||||
MESSAGE_USER_WINDOW = 1,
|
||||
MESSAGE_VIOLATION_SCORE: score,
|
||||
} = process.env;
|
||||
|
||||
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
|
||||
@@ -39,7 +40,7 @@ const createHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
return await denyRequest(req, res, errorMessage);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -11,6 +11,7 @@ const getEnvironmentVariables = () => {
|
||||
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
|
||||
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
|
||||
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
|
||||
const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE;
|
||||
|
||||
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
|
||||
const sttIpMax = STT_IP_MAX;
|
||||
@@ -27,11 +28,12 @@ const getEnvironmentVariables = () => {
|
||||
sttUserWindowMs,
|
||||
sttUserMax,
|
||||
sttUserWindowInMinutes,
|
||||
sttViolationScore: STT_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createSTTHandler = (ip = true) => {
|
||||
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
|
||||
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
@@ -43,7 +45,7 @@ const createSTTHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, sttViolationScore);
|
||||
res.status(429).json({ message: 'Too many STT requests. Try again later' });
|
||||
};
|
||||
};
|
||||
|
||||
@@ -6,6 +6,8 @@ const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { TOOL_CALL_VIOLATION_SCORE: score } = process.env;
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const type = ViolationTypes.TOOL_CALL_LIMIT;
|
||||
const errorMessage = {
|
||||
@@ -15,7 +17,7 @@ const handler = async (req, res) => {
|
||||
windowInMinutes: 1,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, 0);
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
|
||||
};
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ const getEnvironmentVariables = () => {
|
||||
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
|
||||
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
|
||||
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
|
||||
const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE;
|
||||
|
||||
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
|
||||
const ttsIpMax = TTS_IP_MAX;
|
||||
@@ -27,11 +28,12 @@ const getEnvironmentVariables = () => {
|
||||
ttsUserWindowMs,
|
||||
ttsUserMax,
|
||||
ttsUserWindowInMinutes,
|
||||
ttsViolationScore: TTS_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createTTSHandler = (ip = true) => {
|
||||
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
|
||||
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
@@ -43,7 +45,7 @@ const createTTSHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, ttsViolationScore);
|
||||
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
|
||||
};
|
||||
};
|
||||
|
||||
@@ -11,6 +11,7 @@ const getEnvironmentVariables = () => {
|
||||
const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
|
||||
const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
|
||||
const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
|
||||
const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE;
|
||||
|
||||
const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
|
||||
const fileUploadIpMax = FILE_UPLOAD_IP_MAX;
|
||||
@@ -27,6 +28,7 @@ const getEnvironmentVariables = () => {
|
||||
fileUploadUserWindowMs,
|
||||
fileUploadUserMax,
|
||||
fileUploadUserWindowInMinutes,
|
||||
fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -36,6 +38,7 @@ const createFileUploadHandler = (ip = true) => {
|
||||
fileUploadIpWindowInMinutes,
|
||||
fileUploadUserMax,
|
||||
fileUploadUserWindowInMinutes,
|
||||
fileUploadViolationScore,
|
||||
} = getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
@@ -47,7 +50,7 @@ const createFileUploadHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, fileUploadViolationScore);
|
||||
res.status(429).json({ message: 'Too many file upload requests. Try again later' });
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Core function to check if a user has one or more required permissions
|
||||
*
|
||||
* @param {object} user - The user object
|
||||
* @param {PermissionTypes} permissionType - The type of permission to check
|
||||
* @param {Permissions[]} permissions - The list of specific permissions to check
|
||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
|
||||
* @param {object} [checkObject] - The object to check properties against
|
||||
* @returns {Promise<boolean>} Whether the user has the required permissions
|
||||
*/
|
||||
const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
|
||||
if (!user) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const role = await getRoleByName(user.role);
|
||||
if (role && role.permissions && role.permissions[permissionType]) {
|
||||
const hasAnyPermission = permissions.some((permission) => {
|
||||
if (role.permissions[permissionType][permission]) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (bodyProps[permission] && checkObject) {
|
||||
return bodyProps[permission].some((prop) =>
|
||||
Object.prototype.hasOwnProperty.call(checkObject, prop),
|
||||
);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
return hasAnyPermission;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
|
||||
*
|
||||
* @param {PermissionTypes} permissionType - The type of permission to check.
|
||||
* @param {Permissions[]} permissions - The list of specific permissions to check.
|
||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
|
||||
* @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
|
||||
*/
|
||||
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
const hasAccess = await checkAccess(
|
||||
req.user,
|
||||
permissionType,
|
||||
permissions,
|
||||
bodyProps,
|
||||
req.body,
|
||||
);
|
||||
|
||||
if (hasAccess) {
|
||||
return next();
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`,
|
||||
);
|
||||
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return res.status(500).json({ message: `Server error: ${error.message}` });
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
checkAccess,
|
||||
generateCheckAccess,
|
||||
};
|
||||
@@ -1,8 +1,5 @@
|
||||
const checkAdmin = require('./admin');
|
||||
const { checkAccess, generateCheckAccess } = require('./access');
|
||||
|
||||
module.exports = {
|
||||
checkAdmin,
|
||||
checkAccess,
|
||||
generateCheckAccess,
|
||||
};
|
||||
|
||||
@@ -1,14 +1,28 @@
|
||||
const express = require('express');
|
||||
const { nanoid } = require('nanoid');
|
||||
const { actionDelimiter, SystemRoles, removeNullishValues } = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateCheckAccess } = require('@librechat/api');
|
||||
const {
|
||||
SystemRoles,
|
||||
Permissions,
|
||||
PermissionTypes,
|
||||
actionDelimiter,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
|
||||
const { updateAction, getActions, deleteAction } = require('~/models/Action');
|
||||
const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||
const { getAgent, updateAgent } = require('~/models/Agent');
|
||||
const { logger } = require('~/config');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const checkAgentCreate = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE],
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
// If the user has ADMIN role
|
||||
// then action edition is possible even if not owner of the assistant
|
||||
const isAdmin = (req) => {
|
||||
@@ -41,7 +55,7 @@ router.get('/', async (req, res) => {
|
||||
* @param {ActionMetadata} req.body.metadata - Metadata for the action.
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
router.post('/:agent_id', async (req, res) => {
|
||||
router.post('/:agent_id', checkAgentCreate, async (req, res) => {
|
||||
try {
|
||||
const { agent_id } = req.params;
|
||||
|
||||
@@ -149,7 +163,7 @@ router.post('/:agent_id', async (req, res) => {
|
||||
* @param {string} req.params.action_id - The ID of the action to delete.
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
router.delete('/:agent_id/:action_id', async (req, res) => {
|
||||
router.delete('/:agent_id/:action_id', checkAgentCreate, async (req, res) => {
|
||||
try {
|
||||
const { agent_id, action_id } = req.params;
|
||||
const admin = isAdmin(req);
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
const express = require('express');
|
||||
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const {
|
||||
setHeaders,
|
||||
moderateText,
|
||||
// validateModel,
|
||||
generateCheckAccess,
|
||||
validateConvoAccess,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/agents');
|
||||
const AgentController = require('~/server/controllers/agents/request');
|
||||
const addTitle = require('~/server/services/Endpoints/agents/title');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(moderateText);
|
||||
|
||||
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
const checkAgentAccess = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
skipCheck: skipAgentCheck,
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
router.use(checkAgentAccess);
|
||||
router.use(validateConvoAccess);
|
||||
|
||||
@@ -1,29 +1,36 @@
|
||||
const express = require('express');
|
||||
const { generateCheckAccess } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const v1 = require('~/server/controllers/agents/v1');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const actions = require('./actions');
|
||||
const tools = require('./tools');
|
||||
|
||||
const router = express.Router();
|
||||
const avatar = express.Router();
|
||||
|
||||
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [
|
||||
Permissions.USE,
|
||||
Permissions.CREATE,
|
||||
]);
|
||||
const checkAgentAccess = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
const checkAgentCreate = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE],
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
const checkGlobalAgentShare = generateCheckAccess(
|
||||
PermissionTypes.AGENTS,
|
||||
[Permissions.USE, Permissions.CREATE],
|
||||
{
|
||||
const checkGlobalAgentShare = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE],
|
||||
bodyProps: {
|
||||
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
||||
},
|
||||
);
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkAgentAccess);
|
||||
|
||||
/**
|
||||
* Agent actions route.
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
const multer = require('multer');
|
||||
const express = require('express');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
|
||||
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
|
||||
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
|
||||
const { storage, importFileFilter } = require('~/server/routes/files/multer');
|
||||
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||
const { importConversations } = require('~/server/utils/import');
|
||||
const { createImportLimiters } = require('~/server/middleware');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { isEnabled, sleep } = require('~/server/utils');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const assistantClients = {
|
||||
[EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'),
|
||||
@@ -43,6 +44,7 @@ router.get('/', async (req, res) => {
|
||||
});
|
||||
res.status(200).json(result);
|
||||
} catch (error) {
|
||||
logger.error('Error fetching conversations', error);
|
||||
res.status(500).json({ error: 'Error fetching conversations' });
|
||||
}
|
||||
});
|
||||
@@ -156,6 +158,7 @@ router.post('/update', async (req, res) => {
|
||||
});
|
||||
|
||||
const { importIpLimiter, importUserLimiter } = createImportLimiters();
|
||||
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
|
||||
const upload = multer({ storage: storage, fileFilter: importFileFilter });
|
||||
|
||||
/**
|
||||
@@ -189,7 +192,7 @@ router.post(
|
||||
* @param {express.Response<TForkConvoResponse>} res - Express response object.
|
||||
* @returns {Promise<void>} - The response after forking the conversation.
|
||||
*/
|
||||
router.post('/fork', async (req, res) => {
|
||||
router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => {
|
||||
try {
|
||||
/** @type {TForkConvoRequest} */
|
||||
const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body;
|
||||
|
||||
@@ -477,7 +477,9 @@ describe('Multer Configuration', () => {
|
||||
done(new Error('Expected mkdirSync to throw an error but no error was thrown'));
|
||||
} catch (error) {
|
||||
// This is the expected behavior - mkdirSync throws synchronously for invalid paths
|
||||
expect(error.code).toBe('EACCES');
|
||||
// On Linux, this typically returns EACCES (permission denied)
|
||||
// On macOS/Darwin, this returns ENOENT (no such file or directory)
|
||||
expect(['EACCES', 'ENOENT']).toContain(error.code);
|
||||
done();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,37 +1,43 @@
|
||||
const express = require('express');
|
||||
const { Tokenizer } = require('@librechat/api');
|
||||
const { Tokenizer, generateCheckAccess } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const {
|
||||
getAllUserMemories,
|
||||
toggleUserMemories,
|
||||
createMemory,
|
||||
setMemory,
|
||||
deleteMemory,
|
||||
setMemory,
|
||||
} = require('~/models');
|
||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.READ,
|
||||
]);
|
||||
const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.CREATE,
|
||||
]);
|
||||
const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.UPDATE,
|
||||
]);
|
||||
const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.UPDATE,
|
||||
]);
|
||||
const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [
|
||||
Permissions.USE,
|
||||
Permissions.OPT_OUT,
|
||||
]);
|
||||
const checkMemoryRead = generateCheckAccess({
|
||||
permissionType: PermissionTypes.MEMORIES,
|
||||
permissions: [Permissions.USE, Permissions.READ],
|
||||
getRoleByName,
|
||||
});
|
||||
const checkMemoryCreate = generateCheckAccess({
|
||||
permissionType: PermissionTypes.MEMORIES,
|
||||
permissions: [Permissions.USE, Permissions.CREATE],
|
||||
getRoleByName,
|
||||
});
|
||||
const checkMemoryUpdate = generateCheckAccess({
|
||||
permissionType: PermissionTypes.MEMORIES,
|
||||
permissions: [Permissions.USE, Permissions.UPDATE],
|
||||
getRoleByName,
|
||||
});
|
||||
const checkMemoryDelete = generateCheckAccess({
|
||||
permissionType: PermissionTypes.MEMORIES,
|
||||
permissions: [Permissions.USE, Permissions.UPDATE],
|
||||
getRoleByName,
|
||||
});
|
||||
const checkMemoryOptOut = generateCheckAccess({
|
||||
permissionType: PermissionTypes.MEMORIES,
|
||||
permissions: [Permissions.USE, Permissions.OPT_OUT],
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
|
||||
@@ -166,40 +172,68 @@ router.patch('/preferences', checkMemoryOptOut, async (req, res) => {
|
||||
/**
|
||||
* PATCH /memories/:key
|
||||
* Updates the value of an existing memory entry for the authenticated user.
|
||||
* Body: { value: string }
|
||||
* Body: { key?: string, value: string }
|
||||
* Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
|
||||
*/
|
||||
router.patch('/:key', checkMemoryUpdate, async (req, res) => {
|
||||
const { key } = req.params;
|
||||
const { value } = req.body || {};
|
||||
const { key: urlKey } = req.params;
|
||||
const { key: bodyKey, value } = req.body || {};
|
||||
|
||||
if (typeof value !== 'string' || value.trim() === '') {
|
||||
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
|
||||
}
|
||||
|
||||
// Use the key from the body if provided, otherwise use the key from the URL
|
||||
const newKey = bodyKey || urlKey;
|
||||
|
||||
try {
|
||||
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
|
||||
|
||||
const memories = await getAllUserMemories(req.user.id);
|
||||
const existingMemory = memories.find((m) => m.key === key);
|
||||
const existingMemory = memories.find((m) => m.key === urlKey);
|
||||
|
||||
if (!existingMemory) {
|
||||
return res.status(404).json({ error: 'Memory not found.' });
|
||||
}
|
||||
|
||||
const result = await setMemory({
|
||||
userId: req.user.id,
|
||||
key,
|
||||
value,
|
||||
tokenCount,
|
||||
});
|
||||
// If the key is changing, we need to handle it specially
|
||||
if (newKey !== urlKey) {
|
||||
const keyExists = memories.find((m) => m.key === newKey);
|
||||
if (keyExists) {
|
||||
return res.status(409).json({ error: 'Memory with this key already exists.' });
|
||||
}
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(500).json({ error: 'Failed to update memory.' });
|
||||
const createResult = await createMemory({
|
||||
userId: req.user.id,
|
||||
key: newKey,
|
||||
value,
|
||||
tokenCount,
|
||||
});
|
||||
|
||||
if (!createResult.ok) {
|
||||
return res.status(500).json({ error: 'Failed to create new memory.' });
|
||||
}
|
||||
|
||||
const deleteResult = await deleteMemory({ userId: req.user.id, key: urlKey });
|
||||
if (!deleteResult.ok) {
|
||||
return res.status(500).json({ error: 'Failed to delete old memory.' });
|
||||
}
|
||||
} else {
|
||||
// Key is not changing, just update the value
|
||||
const result = await setMemory({
|
||||
userId: req.user.id,
|
||||
key: newKey,
|
||||
value,
|
||||
tokenCount,
|
||||
});
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(500).json({ error: 'Failed to update memory.' });
|
||||
}
|
||||
}
|
||||
|
||||
const updatedMemories = await getAllUserMemories(req.user.id);
|
||||
const updatedMemory = updatedMemories.find((m) => m.key === key);
|
||||
const updatedMemory = updatedMemories.find((m) => m.key === newKey);
|
||||
|
||||
res.json({ updated: true, memory: updatedMemory });
|
||||
} catch (error) {
|
||||
|
||||
@@ -235,12 +235,13 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) =
|
||||
return res.status(400).json({ error: 'Content part not found' });
|
||||
}
|
||||
|
||||
if (updatedContent[index].type !== ContentTypes.TEXT) {
|
||||
const currentPartType = updatedContent[index].type;
|
||||
if (currentPartType !== ContentTypes.TEXT && currentPartType !== ContentTypes.THINK) {
|
||||
return res.status(400).json({ error: 'Cannot update non-text content' });
|
||||
}
|
||||
|
||||
const oldText = updatedContent[index].text;
|
||||
updatedContent[index] = { type: ContentTypes.TEXT, text };
|
||||
const oldText = updatedContent[index][currentPartType];
|
||||
updatedContent[index] = { type: currentPartType, [currentPartType]: text };
|
||||
|
||||
let tokenCount = message.tokenCount;
|
||||
if (tokenCount !== undefined) {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const express = require('express');
|
||||
const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateCheckAccess } = require('@librechat/api');
|
||||
const { Permissions, SystemRoles, PermissionTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
getPrompt,
|
||||
getPrompts,
|
||||
@@ -15,23 +17,31 @@ const {
|
||||
makePromptProduction,
|
||||
} = require('~/models/Prompt');
|
||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||
const { getUserById, updateUser } = require('~/models');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]);
|
||||
const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [
|
||||
Permissions.USE,
|
||||
Permissions.CREATE,
|
||||
]);
|
||||
const checkPromptAccess = generateCheckAccess({
|
||||
permissionType: PermissionTypes.PROMPTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
const checkPromptCreate = generateCheckAccess({
|
||||
permissionType: PermissionTypes.PROMPTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE],
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
const checkGlobalPromptShare = generateCheckAccess(
|
||||
PermissionTypes.PROMPTS,
|
||||
[Permissions.USE, Permissions.CREATE],
|
||||
{
|
||||
const checkGlobalPromptShare = generateCheckAccess({
|
||||
permissionType: PermissionTypes.PROMPTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE],
|
||||
bodyProps: {
|
||||
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
||||
},
|
||||
);
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkPromptAccess);
|
||||
@@ -173,6 +183,28 @@ router.patch('/:promptId/tags/production', checkPromptCreate, async (req, res) =
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Route to get user's prompt preferences (favorites and rankings)
|
||||
* GET /preferences
|
||||
*/
|
||||
router.get('/preferences', async (req, res) => {
|
||||
try {
|
||||
const user = await getUserById(req.user.id);
|
||||
|
||||
if (!user) {
|
||||
return res.status(404).json({ message: 'User not found' });
|
||||
}
|
||||
|
||||
res.json({
|
||||
favorites: user.promptFavorites || [],
|
||||
rankings: user.promptRanking || [],
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error getting user preferences', error);
|
||||
res.status(500).json({ message: 'Error getting user preferences' });
|
||||
}
|
||||
});
|
||||
|
||||
router.get('/:promptId', async (req, res) => {
|
||||
const { promptId } = req.params;
|
||||
const author = req.user.id;
|
||||
@@ -243,4 +275,79 @@ const deletePromptGroupController = async (req, res) => {
|
||||
router.delete('/:promptId', checkPromptCreate, deletePromptController);
|
||||
router.delete('/groups/:groupId', checkPromptCreate, deletePromptGroupController);
|
||||
|
||||
/**
|
||||
* Route to toggle favorite status for a prompt group
|
||||
* POST /favorites/:groupId
|
||||
*/
|
||||
router.post('/favorites/:groupId', async (req, res) => {
|
||||
try {
|
||||
const { groupId } = req.params;
|
||||
|
||||
const user = await getUserById(req.user.id);
|
||||
|
||||
if (!user) {
|
||||
return res.status(404).json({ message: 'User not found' });
|
||||
}
|
||||
|
||||
const favorites = user.promptFavorites || [];
|
||||
const isFavorite = favorites.some((id) => id.toString() === groupId.toString());
|
||||
|
||||
let updatedFavorites;
|
||||
if (isFavorite) {
|
||||
updatedFavorites = favorites.filter((id) => id.toString() !== groupId.toString());
|
||||
} else {
|
||||
updatedFavorites = [...favorites, groupId];
|
||||
}
|
||||
|
||||
await updateUser(req.user.id, { promptFavorites: updatedFavorites });
|
||||
|
||||
const response = {
|
||||
promptGroupId: groupId,
|
||||
isFavorite: !isFavorite,
|
||||
};
|
||||
|
||||
res.json(response);
|
||||
} catch (error) {
|
||||
logger.error('Error toggling favorite status', error);
|
||||
res.status(500).json({ message: 'Error updating favorite status' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Route to update prompt group rankings
|
||||
* PUT /rankings
|
||||
*/
|
||||
router.put('/rankings', async (req, res) => {
|
||||
try {
|
||||
const { rankings } = req.body;
|
||||
|
||||
if (!Array.isArray(rankings)) {
|
||||
return res.status(400).json({ message: 'Rankings must be an array' });
|
||||
}
|
||||
|
||||
const user = await getUserById(req.user.id);
|
||||
|
||||
if (!user) {
|
||||
return res.status(404).json({ message: 'User not found' });
|
||||
}
|
||||
|
||||
const promptRanking = rankings
|
||||
.filter(({ promptGroupId, order }) => promptGroupId && !isNaN(parseInt(order, 10)))
|
||||
.map(({ promptGroupId, order }) => ({
|
||||
promptGroupId,
|
||||
order: parseInt(order, 10),
|
||||
}));
|
||||
|
||||
const updatedUser = await updateUser(req.user.id, { promptRanking });
|
||||
|
||||
res.json({
|
||||
message: 'Rankings updated successfully',
|
||||
rankings: updatedUser?.promptRanking || [],
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error updating rankings', error);
|
||||
res.status(500).json({ message: 'Error updating rankings' });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
const express = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateCheckAccess } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const {
|
||||
getConversationTags,
|
||||
updateTagsForConversation,
|
||||
updateConversationTag,
|
||||
createConversationTag,
|
||||
deleteConversationTag,
|
||||
updateTagsForConversation,
|
||||
getConversationTags,
|
||||
} = require('~/models/ConversationTag');
|
||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||
const { logger } = require('~/config');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const checkBookmarkAccess = generateCheckAccess(PermissionTypes.BOOKMARKS, [Permissions.USE]);
|
||||
const checkBookmarkAccess = generateCheckAccess({
|
||||
permissionType: PermissionTypes.BOOKMARKS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkBookmarkAccess);
|
||||
|
||||
@@ -152,12 +152,14 @@ describe('AppService', () => {
|
||||
filteredTools: undefined,
|
||||
includedTools: undefined,
|
||||
webSearch: {
|
||||
safeSearch: 1,
|
||||
jinaApiKey: '${JINA_API_KEY}',
|
||||
cohereApiKey: '${COHERE_API_KEY}',
|
||||
serperApiKey: '${SERPER_API_KEY}',
|
||||
searxngApiKey: '${SEARXNG_API_KEY}',
|
||||
firecrawlApiKey: '${FIRECRAWL_API_KEY}',
|
||||
firecrawlApiUrl: '${FIRECRAWL_API_URL}',
|
||||
jinaApiKey: '${JINA_API_KEY}',
|
||||
safeSearch: 1,
|
||||
serperApiKey: '${SERPER_API_KEY}',
|
||||
searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}',
|
||||
},
|
||||
memory: undefined,
|
||||
agents: {
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
const { klona } = require('klona');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
StepTypes,
|
||||
RunStatus,
|
||||
@@ -11,11 +14,10 @@ const {
|
||||
} = require('librechat-data-provider');
|
||||
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
|
||||
const { processRequiredActions } = require('~/server/services/ToolService');
|
||||
const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
|
||||
const { RunManager, waitForRun } = require('~/server/services/Runs');
|
||||
const { processMessages } = require('~/server/services/Threads');
|
||||
const { createOnProgress } = require('~/server/utils');
|
||||
const { TextStream } = require('~/app/clients');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Sorts, processes, and flattens messages to a single string.
|
||||
@@ -64,7 +66,7 @@ async function createOnTextProgress({
|
||||
};
|
||||
|
||||
logger.debug('Content data:', contentData);
|
||||
sendMessage(openai.res, contentData);
|
||||
sendEvent(openai.res, contentData);
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const bcrypt = require('bcryptjs');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { webcrypto } = require('node:crypto');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
@@ -499,6 +500,18 @@ const resendVerificationEmail = async (req) => {
|
||||
};
|
||||
}
|
||||
};
|
||||
/**
|
||||
* Generate a short-lived JWT token
|
||||
* @param {String} userId - The ID of the user
|
||||
* @param {String} [expireIn='5m'] - The expiration time for the token (default is 5 minutes)
|
||||
* @returns {String} - The generated JWT token
|
||||
*/
|
||||
const generateShortLivedToken = (userId, expireIn = '5m') => {
|
||||
return jwt.sign({ id: userId }, process.env.JWT_SECRET, {
|
||||
expiresIn: expireIn,
|
||||
algorithm: 'HS256',
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
logoutUser,
|
||||
@@ -506,7 +519,8 @@ module.exports = {
|
||||
registerUser,
|
||||
setAuthTokens,
|
||||
resetPassword,
|
||||
setOpenIDAuthTokens,
|
||||
requestPasswordReset,
|
||||
resendVerificationEmail,
|
||||
setOpenIDAuthTokens,
|
||||
generateShortLivedToken,
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
const { isUserProvided } = require('@librechat/api');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { isUserProvided, generateConfig } = require('~/server/utils');
|
||||
const { generateConfig } = require('~/server/utils/handleText');
|
||||
|
||||
const {
|
||||
OPENAI_API_KEY: openAIApiKey,
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider');
|
||||
const {
|
||||
CacheKeys,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
orderEndpointsConfig,
|
||||
defaultAgentCapabilities,
|
||||
} = require('librechat-data-provider');
|
||||
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
|
||||
const loadConfigEndpoints = require('./loadConfigEndpoints');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
@@ -80,8 +86,12 @@ async function getEndpointsConfig(req) {
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const checkCapability = async (req, capability) => {
|
||||
const isAgents = isAgentsEndpoint(req.body?.original_endpoint || req.body?.endpoint);
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
|
||||
const capabilities =
|
||||
isAgents || endpointsConfig?.[EModelEndpoint.agents]?.capabilities != null
|
||||
? (endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [])
|
||||
: defaultAgentCapabilities;
|
||||
return capabilities.includes(capability);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const path = require('path');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { loadServiceKey, isUserProvided } = require('@librechat/api');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
const { config } = require('./EndpointService');
|
||||
|
||||
const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = config;
|
||||
@@ -9,37 +11,41 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go
|
||||
* @param {Express.Request} req - The request object
|
||||
*/
|
||||
async function loadAsyncEndpoints(req) {
|
||||
let i = 0;
|
||||
let serviceKey, googleUserProvides;
|
||||
try {
|
||||
serviceKey = require('~/data/auth.json');
|
||||
} catch (e) {
|
||||
if (i === 0) {
|
||||
i++;
|
||||
|
||||
/** Check if GOOGLE_KEY is provided at all(including 'user_provided') */
|
||||
const isGoogleKeyProvided = googleKey && googleKey.trim() !== '';
|
||||
|
||||
if (isGoogleKeyProvided) {
|
||||
/** If GOOGLE_KEY is provided, check if it's user_provided */
|
||||
googleUserProvides = isUserProvided(googleKey);
|
||||
} else {
|
||||
/** Only attempt to load service key if GOOGLE_KEY is not provided */
|
||||
const serviceKeyPath =
|
||||
process.env.GOOGLE_SERVICE_KEY_FILE || path.join(__dirname, '../../..', 'data', 'auth.json');
|
||||
|
||||
try {
|
||||
serviceKey = await loadServiceKey(serviceKeyPath);
|
||||
} catch (error) {
|
||||
logger.error('Error loading service key', error);
|
||||
serviceKey = null;
|
||||
}
|
||||
}
|
||||
|
||||
if (isUserProvided(googleKey)) {
|
||||
googleUserProvides = true;
|
||||
if (i <= 1) {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;
|
||||
const google = serviceKey || isGoogleKeyProvided ? { userProvide: googleUserProvides } : false;
|
||||
|
||||
const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins;
|
||||
const gptPlugins =
|
||||
useAzure || openAIApiKey || azureOpenAIApiKey
|
||||
? {
|
||||
availableAgents: ['classic', 'functions'],
|
||||
userProvide: useAzure ? false : userProvidedOpenAI,
|
||||
userProvideURL: useAzure
|
||||
? false
|
||||
: config[EModelEndpoint.openAI]?.userProvideURL ||
|
||||
availableAgents: ['classic', 'functions'],
|
||||
userProvide: useAzure ? false : userProvidedOpenAI,
|
||||
userProvideURL: useAzure
|
||||
? false
|
||||
: config[EModelEndpoint.openAI]?.userProvideURL ||
|
||||
config[EModelEndpoint.azureOpenAI]?.userProvideURL,
|
||||
azure: useAzurePlugins || useAzure,
|
||||
}
|
||||
azure: useAzurePlugins || useAzure,
|
||||
}
|
||||
: false;
|
||||
|
||||
return { google, gptPlugins };
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
const path = require('path');
|
||||
const {
|
||||
CacheKeys,
|
||||
configSchema,
|
||||
EImageOutputType,
|
||||
validateSettingDefinitions,
|
||||
agentParamSettings,
|
||||
paramSettings,
|
||||
} = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const loadYaml = require('~/utils/loadYaml');
|
||||
const { logger } = require('~/config');
|
||||
const axios = require('axios');
|
||||
const yaml = require('js-yaml');
|
||||
const keyBy = require('lodash/keyBy');
|
||||
const { loadYaml } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
CacheKeys,
|
||||
configSchema,
|
||||
paramSettings,
|
||||
EImageOutputType,
|
||||
agentParamSettings,
|
||||
validateSettingDefinitions,
|
||||
} = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
const projectRoot = path.resolve(__dirname, '..', '..', '..', '..');
|
||||
const defaultConfigPath = path.resolve(projectRoot, 'librechat.yaml');
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
jest.mock('axios');
|
||||
jest.mock('~/cache/getLogStores');
|
||||
jest.mock('~/utils/loadYaml');
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
loadYaml: jest.fn(),
|
||||
}));
|
||||
jest.mock('librechat-data-provider', () => {
|
||||
const actual = jest.requireActual('librechat-data-provider');
|
||||
return {
|
||||
@@ -30,11 +33,22 @@ jest.mock('librechat-data-provider', () => {
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
return {
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const axios = require('axios');
|
||||
const { loadYaml } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const loadYaml = require('~/utils/loadYaml');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
describe('loadCustomConfig', () => {
|
||||
const mockSet = jest.fn();
|
||||
|
||||
@@ -85,7 +85,7 @@ const initializeAgent = async ({
|
||||
});
|
||||
|
||||
const provider = agent.provider;
|
||||
const { tools, toolContextMap } =
|
||||
const { tools: structuredTools, toolContextMap } =
|
||||
(await loadTools?.({
|
||||
req,
|
||||
res,
|
||||
@@ -140,6 +140,24 @@ const initializeAgent = async ({
|
||||
agent.provider = options.provider;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').GenericTool[]} */
|
||||
let tools = options.tools?.length ? options.tools : structuredTools;
|
||||
if (
|
||||
(agent.provider === Providers.GOOGLE || agent.provider === Providers.VERTEXAI) &&
|
||||
options.tools?.length &&
|
||||
structuredTools?.length
|
||||
) {
|
||||
throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`);
|
||||
} else if (
|
||||
(agent.provider === Providers.OPENAI ||
|
||||
agent.provider === Providers.AZURE ||
|
||||
agent.provider === Providers.ANTHROPIC) &&
|
||||
options.tools?.length &&
|
||||
structuredTools?.length
|
||||
) {
|
||||
tools = structuredTools.concat(options.tools);
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
agent.model_parameters = { ...options.llmConfig };
|
||||
if (options.configOptions) {
|
||||
@@ -162,10 +180,10 @@ const initializeAgent = async ({
|
||||
|
||||
return {
|
||||
...agent,
|
||||
tools,
|
||||
attachments,
|
||||
resendFiles,
|
||||
toolContextMap,
|
||||
tools,
|
||||
maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -78,7 +78,17 @@ function getLLMConfig(apiKey, options = {}) {
|
||||
requestOptions.anthropicApiUrl = options.reverseProxyUrl;
|
||||
}
|
||||
|
||||
const tools = [];
|
||||
|
||||
if (mergedOptions.web_search) {
|
||||
tools.push({
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
tools,
|
||||
/** @type {AnthropicClientOptions} */
|
||||
llmConfig: removeNullishValues(requestOptions),
|
||||
};
|
||||
|
||||
93
api/server/services/Endpoints/custom/initialize.spec.js
Normal file
93
api/server/services/Endpoints/custom/initialize.spec.js
Normal file
@@ -0,0 +1,93 @@
|
||||
const initializeClient = require('./initialize');
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
resolveHeaders: jest.fn(),
|
||||
getOpenAIConfig: jest.fn(),
|
||||
createHandleLLMNewToken: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
CacheKeys: { TOKEN_CONFIG: 'token_config' },
|
||||
ErrorTypes: { NO_USER_KEY: 'NO_USER_KEY', NO_BASE_URL: 'NO_BASE_URL' },
|
||||
envVarRegex: /\$\{([^}]+)\}/,
|
||||
FetchTokenConfig: {},
|
||||
extractEnvVariable: jest.fn((value) => value),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/agents', () => ({
|
||||
Providers: { OLLAMA: 'ollama' },
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/UserService', () => ({
|
||||
getUserKeyValues: jest.fn(),
|
||||
checkUserKeyExpiry: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCustomEndpointConfig: jest.fn().mockResolvedValue({
|
||||
apiKey: 'test-key',
|
||||
baseURL: 'https://test.com',
|
||||
headers: { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
|
||||
models: { default: ['test-model'] },
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/ModelService', () => ({
|
||||
fetchModels: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/app/clients/OpenAIClient', () => {
|
||||
return jest.fn().mockImplementation(() => ({
|
||||
options: {},
|
||||
}));
|
||||
});
|
||||
|
||||
jest.mock('~/server/utils', () => ({
|
||||
isUserProvided: jest.fn().mockReturnValue(false),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache/getLogStores', () =>
|
||||
jest.fn().mockReturnValue({
|
||||
get: jest.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
describe('custom/initializeClient', () => {
|
||||
const mockRequest = {
|
||||
body: { endpoint: 'test-endpoint' },
|
||||
user: { id: 'user-123', email: 'test@example.com' },
|
||||
app: { locals: {} },
|
||||
};
|
||||
const mockResponse = {};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('calls resolveHeaders with headers and user', async () => {
|
||||
const { resolveHeaders } = require('@librechat/api');
|
||||
await initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true });
|
||||
expect(resolveHeaders).toHaveBeenCalledWith(
|
||||
{ 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
|
||||
{ id: 'user-123', email: 'test@example.com' },
|
||||
);
|
||||
});
|
||||
|
||||
it('throws if endpoint config is missing', async () => {
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
getCustomEndpointConfig.mockResolvedValueOnce(null);
|
||||
await expect(
|
||||
initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true }),
|
||||
).rejects.toThrow('Config not found for the test-endpoint custom endpoint.');
|
||||
});
|
||||
|
||||
it('throws if user is missing', async () => {
|
||||
await expect(
|
||||
initializeClient({
|
||||
req: { ...mockRequest, user: undefined },
|
||||
res: mockResponse,
|
||||
optionsOnly: true,
|
||||
}),
|
||||
).rejects.toThrow("Cannot read properties of undefined (reading 'id')");
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,7 @@
|
||||
const path = require('path');
|
||||
const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
|
||||
const { getGoogleConfig, isEnabled, loadServiceKey } = require('@librechat/api');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/google/llm');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { GoogleClient } = require('~/app');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
|
||||
@@ -16,10 +16,25 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||
}
|
||||
|
||||
let serviceKey = {};
|
||||
try {
|
||||
serviceKey = require('~/data/auth.json');
|
||||
} catch (_e) {
|
||||
// Do nothing
|
||||
|
||||
/** Check if GOOGLE_KEY is provided at all (including 'user_provided') */
|
||||
const isGoogleKeyProvided =
|
||||
(GOOGLE_KEY && GOOGLE_KEY.trim() !== '') || (isUserProvided && userKey != null);
|
||||
|
||||
if (!isGoogleKeyProvided) {
|
||||
/** Only attempt to load service key if GOOGLE_KEY is not provided */
|
||||
try {
|
||||
const serviceKeyPath =
|
||||
process.env.GOOGLE_SERVICE_KEY_FILE ||
|
||||
path.join(__dirname, '../../../..', 'data', 'auth.json');
|
||||
serviceKey = await loadServiceKey(serviceKeyPath);
|
||||
if (!serviceKey) {
|
||||
serviceKey = {};
|
||||
}
|
||||
} catch (_e) {
|
||||
// Service key loading failed, but that's okay if not required
|
||||
serviceKey = {};
|
||||
}
|
||||
}
|
||||
|
||||
const credentials = isUserProvided
|
||||
@@ -65,7 +80,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||
if (overrideModel) {
|
||||
clientOptions.modelOptions.model = overrideModel;
|
||||
}
|
||||
return getLLMConfig(credentials, clientOptions);
|
||||
return getGoogleConfig(credentials, clientOptions);
|
||||
}
|
||||
|
||||
const client = new GoogleClient(credentials, clientOptions);
|
||||
|
||||
@@ -7,6 +7,16 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize');
|
||||
const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
|
||||
/** Check if the provider is a known custom provider
|
||||
* @param {string | undefined} [provider] - The provider string
|
||||
* @returns {boolean} - True if the provider is a known custom provider, false otherwise
|
||||
*/
|
||||
function isKnownCustomProvider(provider) {
|
||||
return [Providers.XAI, Providers.OLLAMA, Providers.DEEPSEEK, Providers.OPENROUTER].includes(
|
||||
provider?.toLowerCase() || '',
|
||||
);
|
||||
}
|
||||
|
||||
const providerConfigMap = {
|
||||
[Providers.XAI]: initCustom,
|
||||
[Providers.OLLAMA]: initCustom,
|
||||
@@ -46,6 +56,13 @@ async function getProviderConfig(provider) {
|
||||
overrideProvider = Providers.OPENAI;
|
||||
}
|
||||
|
||||
if (isKnownCustomProvider(overrideProvider || provider) && !customEndpointConfig) {
|
||||
customEndpointConfig = await getCustomEndpointConfig(provider);
|
||||
if (!customEndpointConfig) {
|
||||
throw new Error(`Provider ${provider} not supported`);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
getOptions,
|
||||
overrideProvider,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const axios = require('axios');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||
const { getBufferMetadata } = require('~/server/utils');
|
||||
const paths = require('~/config/paths');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Saves a file to a specified output path with a new filename.
|
||||
@@ -206,7 +207,7 @@ const deleteLocalFile = async (req, file) => {
|
||||
const cleanFilepath = file.filepath.split('?')[0];
|
||||
|
||||
if (file.embedded && process.env.RAG_API_URL) {
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
const jwtToken = generateShortLivedToken(req.user.id);
|
||||
axios.delete(`${process.env.RAG_API_URL}/documents`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${jwtToken}`,
|
||||
|
||||
@@ -4,6 +4,7 @@ const FormData = require('form-data');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { FileSources } = require('librechat-data-provider');
|
||||
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||
|
||||
/**
|
||||
* Deletes a file from the vector database. This function takes a file object, constructs the full path, and
|
||||
@@ -23,7 +24,8 @@ const deleteVectors = async (req, file) => {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
const jwtToken = generateShortLivedToken(req.user.id);
|
||||
|
||||
return await axios.delete(`${process.env.RAG_API_URL}/documents`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${jwtToken}`,
|
||||
@@ -70,7 +72,7 @@ async function uploadVectors({ req, file, file_id, entity_id }) {
|
||||
}
|
||||
|
||||
try {
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
const jwtToken = generateShortLivedToken(req.user.id);
|
||||
const formData = new FormData();
|
||||
formData.append('file_id', file_id);
|
||||
formData.append('file', fs.createReadStream(file.path));
|
||||
|
||||
@@ -55,7 +55,9 @@ const processFiles = async (files, fileIds) => {
|
||||
}
|
||||
|
||||
if (!fileIds) {
|
||||
return await Promise.all(promises);
|
||||
const results = await Promise.all(promises);
|
||||
// Filter out null results from failed updateFileUsage calls
|
||||
return results.filter((result) => result != null);
|
||||
}
|
||||
|
||||
for (let file_id of fileIds) {
|
||||
@@ -67,7 +69,9 @@ const processFiles = async (files, fileIds) => {
|
||||
}
|
||||
|
||||
// TODO: calculate token cost when image is first uploaded
|
||||
return await Promise.all(promises);
|
||||
const results = await Promise.all(promises);
|
||||
// Filter out null results from failed updateFileUsage calls
|
||||
return results.filter((result) => result != null);
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
208
api/server/services/Files/processFiles.test.js
Normal file
208
api/server/services/Files/processFiles.test.js
Normal file
@@ -0,0 +1,208 @@
|
||||
// Mock the updateFileUsage function before importing the actual processFiles
|
||||
jest.mock('~/models/File', () => ({
|
||||
updateFileUsage: jest.fn(),
|
||||
}));
|
||||
|
||||
// Mock winston and logger configuration to avoid dependency issues
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock all other dependencies that might cause issues
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
isUUID: { parse: jest.fn() },
|
||||
megabyte: 1024 * 1024,
|
||||
FileContext: { message_attachment: 'message_attachment' },
|
||||
FileSources: { local: 'local' },
|
||||
EModelEndpoint: { assistants: 'assistants' },
|
||||
EToolResources: { file_search: 'file_search' },
|
||||
mergeFileConfig: jest.fn(),
|
||||
removeNullishValues: jest.fn((obj) => obj),
|
||||
isAssistantsEndpoint: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/images', () => ({
|
||||
convertImage: jest.fn(),
|
||||
resizeAndConvert: jest.fn(),
|
||||
resizeImageBuffer: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/assistants/v2', () => ({
|
||||
addResourceFileId: jest.fn(),
|
||||
deleteResourceFileId: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Agent', () => ({
|
||||
addAgentResourceFile: jest.fn(),
|
||||
removeAgentResourceFiles: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/assistants/helpers', () => ({
|
||||
getOpenAIClient: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
checkCapability: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/utils/queue', () => ({
|
||||
LB_QueueAsyncCall: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('./strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/utils', () => ({
|
||||
determineFileType: jest.fn(),
|
||||
}));
|
||||
|
||||
// Import the actual processFiles function after all mocks are set up
|
||||
const { processFiles } = require('./process');
|
||||
const { updateFileUsage } = require('~/models/File');
|
||||
|
||||
describe('processFiles', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('null filtering functionality', () => {
|
||||
it('should filter out null results from updateFileUsage when files do not exist', async () => {
|
||||
const mockFiles = [
|
||||
{ file_id: 'existing-file-1' },
|
||||
{ file_id: 'non-existent-file' },
|
||||
{ file_id: 'existing-file-2' },
|
||||
];
|
||||
|
||||
// Mock updateFileUsage to return null for non-existent files
|
||||
updateFileUsage.mockImplementation(({ file_id }) => {
|
||||
if (file_id === 'non-existent-file') {
|
||||
return Promise.resolve(null); // Simulate file not found in the database
|
||||
}
|
||||
return Promise.resolve({ file_id, usage: 1 });
|
||||
});
|
||||
|
||||
const result = await processFiles(mockFiles);
|
||||
|
||||
expect(updateFileUsage).toHaveBeenCalledTimes(3);
|
||||
expect(result).toEqual([
|
||||
{ file_id: 'existing-file-1', usage: 1 },
|
||||
{ file_id: 'existing-file-2', usage: 1 },
|
||||
]);
|
||||
|
||||
// Critical test - ensure no null values in result
|
||||
expect(result).not.toContain(null);
|
||||
expect(result).not.toContain(undefined);
|
||||
expect(result.length).toBe(2); // Only valid files should be returned
|
||||
});
|
||||
|
||||
it('should return empty array when all updateFileUsage calls return null', async () => {
|
||||
const mockFiles = [{ file_id: 'non-existent-1' }, { file_id: 'non-existent-2' }];
|
||||
|
||||
// All updateFileUsage calls return null
|
||||
updateFileUsage.mockResolvedValue(null);
|
||||
|
||||
const result = await processFiles(mockFiles);
|
||||
|
||||
expect(updateFileUsage).toHaveBeenCalledTimes(2);
|
||||
expect(result).toEqual([]);
|
||||
expect(result).not.toContain(null);
|
||||
expect(result.length).toBe(0);
|
||||
});
|
||||
|
||||
it('should work correctly when all files exist', async () => {
|
||||
const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }];
|
||||
|
||||
updateFileUsage.mockImplementation(({ file_id }) => {
|
||||
return Promise.resolve({ file_id, usage: 1 });
|
||||
});
|
||||
|
||||
const result = await processFiles(mockFiles);
|
||||
|
||||
expect(result).toEqual([
|
||||
{ file_id: 'file-1', usage: 1 },
|
||||
{ file_id: 'file-2', usage: 1 },
|
||||
]);
|
||||
expect(result).not.toContain(null);
|
||||
expect(result.length).toBe(2);
|
||||
});
|
||||
|
||||
it('should handle fileIds parameter and filter nulls correctly', async () => {
|
||||
const mockFiles = [{ file_id: 'file-1' }];
|
||||
const mockFileIds = ['file-2', 'non-existent-file'];
|
||||
|
||||
updateFileUsage.mockImplementation(({ file_id }) => {
|
||||
if (file_id === 'non-existent-file') {
|
||||
return Promise.resolve(null);
|
||||
}
|
||||
return Promise.resolve({ file_id, usage: 1 });
|
||||
});
|
||||
|
||||
const result = await processFiles(mockFiles, mockFileIds);
|
||||
|
||||
expect(result).toEqual([
|
||||
{ file_id: 'file-1', usage: 1 },
|
||||
{ file_id: 'file-2', usage: 1 },
|
||||
]);
|
||||
expect(result).not.toContain(null);
|
||||
expect(result).not.toContain(undefined);
|
||||
expect(result.length).toBe(2);
|
||||
});
|
||||
|
||||
it('should handle duplicate file_ids correctly', async () => {
|
||||
const mockFiles = [
|
||||
{ file_id: 'duplicate-file' },
|
||||
{ file_id: 'duplicate-file' }, // Duplicate should be ignored
|
||||
{ file_id: 'unique-file' },
|
||||
];
|
||||
|
||||
updateFileUsage.mockImplementation(({ file_id }) => {
|
||||
return Promise.resolve({ file_id, usage: 1 });
|
||||
});
|
||||
|
||||
const result = await processFiles(mockFiles);
|
||||
|
||||
// Should only call updateFileUsage twice (duplicate ignored)
|
||||
expect(updateFileUsage).toHaveBeenCalledTimes(2);
|
||||
expect(result).toEqual([
|
||||
{ file_id: 'duplicate-file', usage: 1 },
|
||||
{ file_id: 'unique-file', usage: 1 },
|
||||
]);
|
||||
expect(result.length).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle empty files array', async () => {
|
||||
const result = await processFiles([]);
|
||||
expect(result).toEqual([]);
|
||||
expect(updateFileUsage).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle mixed null and undefined returns from updateFileUsage', async () => {
|
||||
const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }, { file_id: 'file-3' }];
|
||||
|
||||
updateFileUsage.mockImplementation(({ file_id }) => {
|
||||
if (file_id === 'file-1') return Promise.resolve(null);
|
||||
if (file_id === 'file-2') return Promise.resolve(undefined);
|
||||
return Promise.resolve({ file_id, usage: 1 });
|
||||
});
|
||||
|
||||
const result = await processFiles(mockFiles);
|
||||
|
||||
expect(result).toEqual([{ file_id: 'file-3', usage: 1 }]);
|
||||
expect(result).not.toContain(null);
|
||||
expect(result).not.toContain(undefined);
|
||||
expect(result.length).toBe(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,9 @@
|
||||
const { FileSources } = require('librechat-data-provider');
|
||||
const { uploadMistralOCR, uploadAzureMistralOCR } = require('@librechat/api');
|
||||
const {
|
||||
uploadMistralOCR,
|
||||
uploadAzureMistralOCR,
|
||||
uploadGoogleVertexMistralOCR,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
getFirebaseURL,
|
||||
prepareImageURL,
|
||||
@@ -222,6 +226,26 @@ const azureMistralOCRStrategy = () => ({
|
||||
handleFileUpload: uploadAzureMistralOCR,
|
||||
});
|
||||
|
||||
const vertexMistralOCRStrategy = () => ({
|
||||
/** @type {typeof saveFileFromURL | null} */
|
||||
saveURL: null,
|
||||
/** @type {typeof getLocalFileURL | null} */
|
||||
getFileURL: null,
|
||||
/** @type {typeof saveLocalBuffer | null} */
|
||||
saveBuffer: null,
|
||||
/** @type {typeof processLocalAvatar | null} */
|
||||
processAvatar: null,
|
||||
/** @type {typeof uploadLocalImage | null} */
|
||||
handleImageUpload: null,
|
||||
/** @type {typeof prepareImagesLocal | null} */
|
||||
prepareImagePayload: null,
|
||||
/** @type {typeof deleteLocalFile | null} */
|
||||
deleteFile: null,
|
||||
/** @type {typeof getLocalFileStream | null} */
|
||||
getDownloadStream: null,
|
||||
handleFileUpload: uploadGoogleVertexMistralOCR,
|
||||
});
|
||||
|
||||
// Strategy Selector
|
||||
const getStrategyFunctions = (fileSource) => {
|
||||
if (fileSource === FileSources.firebase) {
|
||||
@@ -244,6 +268,8 @@ const getStrategyFunctions = (fileSource) => {
|
||||
return mistralOCRStrategy();
|
||||
} else if (fileSource === FileSources.azure_mistral_ocr) {
|
||||
return azureMistralOCRStrategy();
|
||||
} else if (fileSource === FileSources.vertexai_mistral_ocr) {
|
||||
return vertexMistralOCRStrategy();
|
||||
} else {
|
||||
throw new Error('Invalid file source');
|
||||
}
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Constants,
|
||||
StepTypes,
|
||||
@@ -8,9 +11,8 @@ const {
|
||||
} = require('librechat-data-provider');
|
||||
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
|
||||
const { processRequiredActions } = require('~/server/services/ToolService');
|
||||
const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
|
||||
const { processMessages } = require('~/server/services/Threads');
|
||||
const { logger } = require('~/config');
|
||||
const { createOnProgress } = require('~/server/utils');
|
||||
|
||||
/**
|
||||
* Implements the StreamRunManager functionality for managing the streaming
|
||||
@@ -126,7 +128,7 @@ class StreamRunManager {
|
||||
conversationId: this.finalMessage.conversationId,
|
||||
};
|
||||
|
||||
sendMessage(this.res, contentData);
|
||||
sendEvent(this.res, contentData);
|
||||
}
|
||||
|
||||
/* <------------------ Misc. Helpers ------------------> */
|
||||
@@ -302,7 +304,7 @@ class StreamRunManager {
|
||||
|
||||
for (const d of delta[key]) {
|
||||
if (typeof d === 'object' && !Object.prototype.hasOwnProperty.call(d, 'index')) {
|
||||
logger.warn('Expected an object with an \'index\' for array updates but got:', d);
|
||||
logger.warn("Expected an object with an 'index' for array updates but got:", d);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@ const {
|
||||
defaultAssistantsVersion,
|
||||
defaultAgentCapabilities,
|
||||
} = require('librechat-data-provider');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const partialRight = require('lodash/partialRight');
|
||||
const { sendMessage } = require('./streamResponse');
|
||||
|
||||
/** Helper function to escape special characters in regex
|
||||
* @param {string} string - The string to escape.
|
||||
@@ -37,7 +37,7 @@ const createOnProgress = (
|
||||
basePayload.text = basePayload.text + chunk;
|
||||
|
||||
const payload = Object.assign({}, basePayload, rest);
|
||||
sendMessage(res, payload);
|
||||
sendEvent(res, payload);
|
||||
if (_onProgress) {
|
||||
_onProgress(payload);
|
||||
}
|
||||
@@ -50,7 +50,7 @@ const createOnProgress = (
|
||||
const sendIntermediateMessage = (res, payload, extraTokens = '') => {
|
||||
basePayload.text = basePayload.text + extraTokens;
|
||||
const message = Object.assign({}, basePayload, payload);
|
||||
sendMessage(res, message);
|
||||
sendEvent(res, message);
|
||||
if (i === 0) {
|
||||
basePayload.initial = false;
|
||||
}
|
||||
|
||||
280
api/server/utils/import/importers-timestamp.spec.js
Normal file
280
api/server/utils/import/importers-timestamp.spec.js
Normal file
@@ -0,0 +1,280 @@
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { ImportBatchBuilder } = require('./importBatchBuilder');
|
||||
const { getImporter } = require('./importers');
|
||||
|
||||
// Mock the database methods
|
||||
jest.mock('~/models/Conversation', () => ({
|
||||
bulkSaveConvos: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Message', () => ({
|
||||
bulkSaveMessages: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/cache/getLogStores');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const mockedCacheGet = jest.fn();
|
||||
getLogStores.mockImplementation(() => ({
|
||||
get: mockedCacheGet,
|
||||
}));
|
||||
|
||||
describe('Import Timestamp Ordering', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockedCacheGet.mockResolvedValue(null);
|
||||
});
|
||||
|
||||
describe('LibreChat Import - Timestamp Issues', () => {
|
||||
test('should maintain proper timestamp order between parent and child messages', async () => {
|
||||
// Create a LibreChat export with out-of-order timestamps
|
||||
const jsonData = {
|
||||
conversationId: 'test-convo-123',
|
||||
title: 'Test Conversation',
|
||||
messages: [
|
||||
{
|
||||
messageId: 'parent-1',
|
||||
parentMessageId: Constants.NO_PARENT,
|
||||
text: 'Parent Message',
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
createdAt: '2023-01-01T00:02:00Z', // Parent created AFTER child
|
||||
},
|
||||
{
|
||||
messageId: 'child-1',
|
||||
parentMessageId: 'parent-1',
|
||||
text: 'Child Message',
|
||||
sender: 'assistant',
|
||||
isCreatedByUser: false,
|
||||
createdAt: '2023-01-01T00:01:00Z', // Child created BEFORE parent
|
||||
},
|
||||
{
|
||||
messageId: 'grandchild-1',
|
||||
parentMessageId: 'child-1',
|
||||
text: 'Grandchild Message',
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
createdAt: '2023-01-01T00:00:30Z', // Even earlier
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
// Check the actual messages stored in the builder
|
||||
const savedMessages = importBatchBuilder.messages;
|
||||
|
||||
const parent = savedMessages.find((msg) => msg.text === 'Parent Message');
|
||||
const child = savedMessages.find((msg) => msg.text === 'Child Message');
|
||||
const grandchild = savedMessages.find((msg) => msg.text === 'Grandchild Message');
|
||||
|
||||
// Verify all messages were found
|
||||
expect(parent).toBeDefined();
|
||||
expect(child).toBeDefined();
|
||||
expect(grandchild).toBeDefined();
|
||||
|
||||
// FIXED behavior: timestamps ARE corrected
|
||||
expect(new Date(child.createdAt).getTime()).toBeGreaterThan(
|
||||
new Date(parent.createdAt).getTime(),
|
||||
);
|
||||
expect(new Date(grandchild.createdAt).getTime()).toBeGreaterThan(
|
||||
new Date(child.createdAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
test('should handle complex multi-branch scenario with out-of-order timestamps', async () => {
|
||||
const jsonData = {
|
||||
conversationId: 'complex-test-123',
|
||||
title: 'Complex Test',
|
||||
messages: [
|
||||
// Branch 1: Root -> A -> B with reversed timestamps
|
||||
{
|
||||
messageId: 'root-1',
|
||||
parentMessageId: Constants.NO_PARENT,
|
||||
text: 'Root 1',
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
createdAt: '2023-01-01T00:03:00Z',
|
||||
},
|
||||
{
|
||||
messageId: 'a-1',
|
||||
parentMessageId: 'root-1',
|
||||
text: 'A1',
|
||||
sender: 'assistant',
|
||||
isCreatedByUser: false,
|
||||
createdAt: '2023-01-01T00:02:00Z', // Before parent
|
||||
},
|
||||
{
|
||||
messageId: 'b-1',
|
||||
parentMessageId: 'a-1',
|
||||
text: 'B1',
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
createdAt: '2023-01-01T00:01:00Z', // Before grandparent
|
||||
},
|
||||
// Branch 2: Root -> C -> D with mixed timestamps
|
||||
{
|
||||
messageId: 'root-2',
|
||||
parentMessageId: Constants.NO_PARENT,
|
||||
text: 'Root 2',
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
createdAt: '2023-01-01T00:00:30Z', // Earlier than branch 1
|
||||
},
|
||||
{
|
||||
messageId: 'c-2',
|
||||
parentMessageId: 'root-2',
|
||||
text: 'C2',
|
||||
sender: 'assistant',
|
||||
isCreatedByUser: false,
|
||||
createdAt: '2023-01-01T00:04:00Z', // Much later
|
||||
},
|
||||
{
|
||||
messageId: 'd-2',
|
||||
parentMessageId: 'c-2',
|
||||
text: 'D2',
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
createdAt: '2023-01-01T00:02:30Z', // Between root and parent
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.messages;
|
||||
|
||||
// Verify that timestamps are preserved as-is (not corrected)
|
||||
const root1 = savedMessages.find((msg) => msg.text === 'Root 1');
|
||||
const a1 = savedMessages.find((msg) => msg.text === 'A1');
|
||||
const b1 = savedMessages.find((msg) => msg.text === 'B1');
|
||||
const root2 = savedMessages.find((msg) => msg.text === 'Root 2');
|
||||
const c2 = savedMessages.find((msg) => msg.text === 'C2');
|
||||
const d2 = savedMessages.find((msg) => msg.text === 'D2');
|
||||
|
||||
// Branch 1: timestamps should now be in correct order
|
||||
expect(new Date(a1.createdAt).getTime()).toBeGreaterThan(new Date(root1.createdAt).getTime());
|
||||
expect(new Date(b1.createdAt).getTime()).toBeGreaterThan(new Date(a1.createdAt).getTime());
|
||||
|
||||
// Branch 2: all timestamps should be properly ordered
|
||||
expect(new Date(c2.createdAt).getTime()).toBeGreaterThan(new Date(root2.createdAt).getTime());
|
||||
expect(new Date(d2.createdAt).getTime()).toBeGreaterThan(new Date(c2.createdAt).getTime());
|
||||
});
|
||||
|
||||
test('recursive format should NOW have timestamp protection', async () => {
|
||||
// Create a recursive LibreChat export with out-of-order timestamps
|
||||
const jsonData = {
|
||||
conversationId: 'recursive-test-123',
|
||||
title: 'Recursive Test',
|
||||
recursive: true,
|
||||
messages: [
|
||||
{
|
||||
messageId: 'parent-1',
|
||||
parentMessageId: Constants.NO_PARENT,
|
||||
text: 'Parent Message',
|
||||
sender: 'User',
|
||||
isCreatedByUser: true,
|
||||
createdAt: '2023-01-01T00:02:00Z', // Parent created AFTER child
|
||||
children: [
|
||||
{
|
||||
messageId: 'child-1',
|
||||
parentMessageId: 'parent-1',
|
||||
text: 'Child Message',
|
||||
sender: 'Assistant',
|
||||
isCreatedByUser: false,
|
||||
createdAt: '2023-01-01T00:01:00Z', // Child created BEFORE parent
|
||||
children: [
|
||||
{
|
||||
messageId: 'grandchild-1',
|
||||
parentMessageId: 'child-1',
|
||||
text: 'Grandchild Message',
|
||||
sender: 'User',
|
||||
isCreatedByUser: true,
|
||||
createdAt: '2023-01-01T00:00:30Z', // Even earlier
|
||||
children: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.messages;
|
||||
|
||||
// Messages should be saved
|
||||
expect(savedMessages).toHaveLength(3);
|
||||
|
||||
// In recursive format, timestamps are NOT included in the saved messages
|
||||
// The saveMessage method doesn't receive createdAt for recursive imports
|
||||
const parent = savedMessages.find((msg) => msg.text === 'Parent Message');
|
||||
const child = savedMessages.find((msg) => msg.text === 'Child Message');
|
||||
const grandchild = savedMessages.find((msg) => msg.text === 'Grandchild Message');
|
||||
|
||||
expect(parent).toBeDefined();
|
||||
expect(child).toBeDefined();
|
||||
expect(grandchild).toBeDefined();
|
||||
|
||||
// Recursive imports NOW preserve and correct timestamps
|
||||
expect(parent.createdAt).toBeDefined();
|
||||
expect(child.createdAt).toBeDefined();
|
||||
expect(grandchild.createdAt).toBeDefined();
|
||||
|
||||
// Timestamps should be corrected to maintain proper order
|
||||
expect(new Date(child.createdAt).getTime()).toBeGreaterThan(
|
||||
new Date(parent.createdAt).getTime(),
|
||||
);
|
||||
expect(new Date(grandchild.createdAt).getTime()).toBeGreaterThan(
|
||||
new Date(child.createdAt).getTime(),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Comparison with Fork Functionality', () => {
|
||||
test('fork functionality correctly handles timestamp issues (for comparison)', async () => {
|
||||
const { cloneMessagesWithTimestamps } = require('./fork');
|
||||
|
||||
const messagesToClone = [
|
||||
{
|
||||
messageId: 'parent',
|
||||
parentMessageId: Constants.NO_PARENT,
|
||||
text: 'Parent Message',
|
||||
createdAt: '2023-01-01T00:02:00Z', // Parent created AFTER child
|
||||
},
|
||||
{
|
||||
messageId: 'child',
|
||||
parentMessageId: 'parent',
|
||||
text: 'Child Message',
|
||||
createdAt: '2023-01-01T00:01:00Z', // Child created BEFORE parent
|
||||
},
|
||||
];
|
||||
|
||||
const importBatchBuilder = new ImportBatchBuilder('user-123');
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
cloneMessagesWithTimestamps(messagesToClone, importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.messages;
|
||||
const parent = savedMessages.find((msg) => msg.text === 'Parent Message');
|
||||
const child = savedMessages.find((msg) => msg.text === 'Child Message');
|
||||
|
||||
// Fork functionality DOES correct the timestamps
|
||||
expect(new Date(child.createdAt).getTime()).toBeGreaterThan(
|
||||
new Date(parent.createdAt).getTime(),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,7 @@
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
|
||||
const { createImportBatchBuilder } = require('./importBatchBuilder');
|
||||
const { cloneMessagesWithTimestamps } = require('./fork');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
@@ -107,67 +108,47 @@ async function importLibreChatConvo(
|
||||
|
||||
if (jsonData.recursive) {
|
||||
/**
|
||||
* Recursively traverse the messages tree and save each message to the database.
|
||||
* Flatten the recursive message tree into a flat array
|
||||
* @param {TMessage[]} messages
|
||||
* @param {string} parentMessageId
|
||||
* @param {TMessage[]} flatMessages
|
||||
*/
|
||||
const traverseMessages = async (messages, parentMessageId = null) => {
|
||||
const flattenMessages = (
|
||||
messages,
|
||||
parentMessageId = Constants.NO_PARENT,
|
||||
flatMessages = [],
|
||||
) => {
|
||||
for (const message of messages) {
|
||||
if (!message.text && !message.content) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let savedMessage;
|
||||
if (message.sender?.toLowerCase() === 'user' || message.isCreatedByUser) {
|
||||
savedMessage = await importBatchBuilder.saveMessage({
|
||||
text: message.text,
|
||||
content: message.content,
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
parentMessageId: parentMessageId,
|
||||
});
|
||||
} else {
|
||||
savedMessage = await importBatchBuilder.saveMessage({
|
||||
text: message.text,
|
||||
content: message.content,
|
||||
sender: message.sender,
|
||||
isCreatedByUser: false,
|
||||
model: options.model,
|
||||
parentMessageId: parentMessageId,
|
||||
});
|
||||
}
|
||||
const flatMessage = {
|
||||
...message,
|
||||
parentMessageId: parentMessageId,
|
||||
children: undefined, // Remove children from flat structure
|
||||
};
|
||||
flatMessages.push(flatMessage);
|
||||
|
||||
if (!firstMessageDate && message.createdAt) {
|
||||
firstMessageDate = new Date(message.createdAt);
|
||||
}
|
||||
|
||||
if (message.children && message.children.length > 0) {
|
||||
await traverseMessages(message.children, savedMessage.messageId);
|
||||
flattenMessages(message.children, message.messageId, flatMessages);
|
||||
}
|
||||
}
|
||||
return flatMessages;
|
||||
};
|
||||
|
||||
await traverseMessages(messagesToImport);
|
||||
const flatMessages = flattenMessages(messagesToImport);
|
||||
cloneMessagesWithTimestamps(flatMessages, importBatchBuilder);
|
||||
} else if (messagesToImport) {
|
||||
const idMapping = new Map();
|
||||
|
||||
cloneMessagesWithTimestamps(messagesToImport, importBatchBuilder);
|
||||
for (const message of messagesToImport) {
|
||||
if (!firstMessageDate && message.createdAt) {
|
||||
firstMessageDate = new Date(message.createdAt);
|
||||
}
|
||||
const newMessageId = uuidv4();
|
||||
idMapping.set(message.messageId, newMessageId);
|
||||
|
||||
const clonedMessage = {
|
||||
...message,
|
||||
messageId: newMessageId,
|
||||
parentMessageId:
|
||||
message.parentMessageId && message.parentMessageId !== Constants.NO_PARENT
|
||||
? idMapping.get(message.parentMessageId) || Constants.NO_PARENT
|
||||
: Constants.NO_PARENT,
|
||||
};
|
||||
|
||||
importBatchBuilder.saveMessage(clonedMessage);
|
||||
}
|
||||
} else {
|
||||
throw new Error('Invalid LibreChat file format');
|
||||
|
||||
@@ -175,36 +175,60 @@ describe('importLibreChatConvo', () => {
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
jest.spyOn(importBatchBuilder, 'saveBatch');
|
||||
|
||||
// When
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
// Create a map to track original message IDs to new UUIDs
|
||||
const idToUUIDMap = new Map();
|
||||
importBatchBuilder.saveMessage.mock.calls.forEach((call) => {
|
||||
const message = call[0];
|
||||
idToUUIDMap.set(message.originalMessageId, message.messageId);
|
||||
// Get the imported messages
|
||||
const messages = importBatchBuilder.messages;
|
||||
expect(messages.length).toBeGreaterThan(0);
|
||||
|
||||
// Build maps for verification
|
||||
const textToMessageMap = new Map();
|
||||
const messageIdToMessage = new Map();
|
||||
messages.forEach((msg) => {
|
||||
if (msg.text) {
|
||||
// For recursive imports, text might be very long, so just use the first 100 chars as key
|
||||
const textKey = msg.text.substring(0, 100);
|
||||
textToMessageMap.set(textKey, msg);
|
||||
}
|
||||
messageIdToMessage.set(msg.messageId, msg);
|
||||
});
|
||||
|
||||
const checkChildren = (children, parentId) => {
|
||||
children.forEach((child) => {
|
||||
const childUUID = idToUUIDMap.get(child.messageId);
|
||||
const expectedParentId = idToUUIDMap.get(parentId) ?? null;
|
||||
const messageCall = importBatchBuilder.saveMessage.mock.calls.find(
|
||||
(call) => call[0].messageId === childUUID,
|
||||
);
|
||||
|
||||
const actualParentId = messageCall[0].parentMessageId;
|
||||
expect(actualParentId).toBe(expectedParentId);
|
||||
|
||||
if (child.children && child.children.length > 0) {
|
||||
checkChildren(child.children, child.messageId);
|
||||
// Count expected messages from the tree
|
||||
const countMessagesInTree = (nodes) => {
|
||||
let count = 0;
|
||||
nodes.forEach((node) => {
|
||||
if (node.text || node.content) {
|
||||
count++;
|
||||
}
|
||||
if (node.children && node.children.length > 0) {
|
||||
count += countMessagesInTree(node.children);
|
||||
}
|
||||
});
|
||||
return count;
|
||||
};
|
||||
|
||||
// Start hierarchy validation from root messages
|
||||
checkChildren(jsonData.messages, null);
|
||||
const expectedMessageCount = countMessagesInTree(jsonData.messages);
|
||||
expect(messages.length).toBe(expectedMessageCount);
|
||||
|
||||
// Verify all messages have valid parent relationships
|
||||
messages.forEach((msg) => {
|
||||
if (msg.parentMessageId !== Constants.NO_PARENT) {
|
||||
const parent = messageIdToMessage.get(msg.parentMessageId);
|
||||
expect(parent).toBeDefined();
|
||||
|
||||
// Verify timestamp ordering
|
||||
if (msg.createdAt && parent.createdAt) {
|
||||
expect(new Date(msg.createdAt).getTime()).toBeGreaterThanOrEqual(
|
||||
new Date(parent.createdAt).getTime(),
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Verify at least one root message exists
|
||||
const rootMessages = messages.filter((msg) => msg.parentMessageId === Constants.NO_PARENT);
|
||||
expect(rootMessages.length).toBeGreaterThan(0);
|
||||
|
||||
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
const streamResponse = require('./streamResponse');
|
||||
const removePorts = require('./removePorts');
|
||||
const countTokens = require('./countTokens');
|
||||
const handleText = require('./handleText');
|
||||
const sendEmail = require('./sendEmail');
|
||||
const queue = require('./queue');
|
||||
const files = require('./files');
|
||||
const math = require('./math');
|
||||
|
||||
/**
|
||||
* Check if email configuration is set
|
||||
@@ -28,7 +26,6 @@ function checkEmailConfig() {
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
...streamResponse,
|
||||
checkEmailConfig,
|
||||
...handleText,
|
||||
countTokens,
|
||||
@@ -36,5 +33,4 @@ module.exports = {
|
||||
sendEmail,
|
||||
...files,
|
||||
...queue,
|
||||
math,
|
||||
};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
|
||||
const { updateUser, findUser } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
@@ -13,17 +14,23 @@ const { isEnabled } = require('~/server/utils');
|
||||
* The strategy extracts the JWT from the Authorization header as a Bearer token.
|
||||
* The JWT is then verified using the signing key, and the user is retrieved from the database.
|
||||
*/
|
||||
const openIdJwtLogin = (openIdConfig) =>
|
||||
new JwtStrategy(
|
||||
const openIdJwtLogin = (openIdConfig) => {
|
||||
let jwksRsaOptions = {
|
||||
cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true,
|
||||
cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME
|
||||
? eval(process.env.OPENID_JWKS_URL_CACHE_TIME)
|
||||
: 60000,
|
||||
jwksUri: openIdConfig.serverMetadata().jwks_uri,
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
jwksRsaOptions.requestAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
return new JwtStrategy(
|
||||
{
|
||||
jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
|
||||
secretOrKeyProvider: jwksRsa.passportJwtSecret({
|
||||
cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true,
|
||||
cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME
|
||||
? eval(process.env.OPENID_JWKS_URL_CACHE_TIME)
|
||||
: 60000,
|
||||
jwksUri: openIdConfig.serverMetadata().jwks_uri,
|
||||
}),
|
||||
secretOrKeyProvider: jwksRsa.passportJwtSecret(jwksRsaOptions),
|
||||
},
|
||||
async (payload, done) => {
|
||||
try {
|
||||
@@ -48,5 +55,6 @@ const openIdJwtLogin = (openIdConfig) =>
|
||||
}
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
module.exports = openIdJwtLogin;
|
||||
|
||||
@@ -49,7 +49,7 @@ async function customFetch(url, options) {
|
||||
logger.info(`[openidStrategy] proxy agent configured: ${process.env.PROXY}`);
|
||||
fetchOptions = {
|
||||
...options,
|
||||
dispatcher: new HttpsProxyAgent(process.env.PROXY),
|
||||
dispatcher: new undici.ProxyAgent(process.env.PROXY),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ class CustomOpenIDStrategy extends OpenIDStrategy {
|
||||
*/
|
||||
const exchangeAccessTokenIfNeeded = async (config, accessToken, sub, fromCache = false) => {
|
||||
const tokensCache = getLogStores(CacheKeys.OPENID_EXCHANGED_TOKENS);
|
||||
const onBehalfFlowRequired = isEnabled(process.env.OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED);
|
||||
const onBehalfFlowRequired = isEnabled(process.env.OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED);
|
||||
if (onBehalfFlowRequired) {
|
||||
if (fromCache) {
|
||||
const cachedToken = await tokensCache.get(sub);
|
||||
@@ -130,7 +130,7 @@ const exchangeAccessTokenIfNeeded = async (config, accessToken, sub, fromCache =
|
||||
config,
|
||||
'urn:ietf:params:oauth:grant-type:jwt-bearer',
|
||||
{
|
||||
scope: process.env.OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE || 'user.read',
|
||||
scope: process.env.OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE || 'user.read',
|
||||
assertion: accessToken,
|
||||
requested_token_use: 'on_behalf_of',
|
||||
},
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
const loadYaml = require('./loadYaml');
|
||||
const tokenHelpers = require('./tokens');
|
||||
const deriveBaseURL = require('./deriveBaseURL');
|
||||
const extractBaseURL = require('./extractBaseURL');
|
||||
const findMessageContent = require('./findMessageContent');
|
||||
|
||||
module.exports = {
|
||||
loadYaml,
|
||||
deriveBaseURL,
|
||||
extractBaseURL,
|
||||
...tokenHelpers,
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
const fs = require('fs');
|
||||
const yaml = require('js-yaml');
|
||||
|
||||
function loadYaml(filepath) {
|
||||
try {
|
||||
let fileContents = fs.readFileSync(filepath, 'utf8');
|
||||
return yaml.load(fileContents);
|
||||
} catch (e) {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = loadYaml;
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/frontend",
|
||||
"version": "v0.7.8",
|
||||
"version": "v0.7.9-rc1",
|
||||
"description": "",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
||||
37
client/src/Providers/ActivePanelContext.tsx
Normal file
37
client/src/Providers/ActivePanelContext.tsx
Normal file
@@ -0,0 +1,37 @@
|
||||
import { createContext, useContext, useState, ReactNode } from 'react';
|
||||
|
||||
interface ActivePanelContextType {
|
||||
active: string | undefined;
|
||||
setActive: (id: string) => void;
|
||||
}
|
||||
|
||||
const ActivePanelContext = createContext<ActivePanelContextType | undefined>(undefined);
|
||||
|
||||
export function ActivePanelProvider({
|
||||
children,
|
||||
defaultActive,
|
||||
}: {
|
||||
children: ReactNode;
|
||||
defaultActive?: string;
|
||||
}) {
|
||||
const [active, _setActive] = useState<string | undefined>(defaultActive);
|
||||
|
||||
const setActive = (id: string) => {
|
||||
localStorage.setItem('side:active-panel', id);
|
||||
_setActive(id);
|
||||
};
|
||||
|
||||
return (
|
||||
<ActivePanelContext.Provider value={{ active, setActive }}>
|
||||
{children}
|
||||
</ActivePanelContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useActivePanel() {
|
||||
const context = useContext(ActivePanelContext);
|
||||
if (context === undefined) {
|
||||
throw new Error('useActivePanel must be used within an ActivePanelProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
import React, { createContext, useContext, useState } from 'react';
|
||||
import { Constants, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { TPlugin, AgentToolType, Action, MCP } from 'librechat-data-provider';
|
||||
import type { MCP, Action, TPlugin, AgentToolType } from 'librechat-data-provider';
|
||||
import type { AgentPanelContextType } from '~/common';
|
||||
import { useAvailableToolsQuery, useGetActionsQuery } from '~/data-provider';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { useLocalize, useGetAgentsConfig } from '~/hooks';
|
||||
import { Panel } from '~/common';
|
||||
|
||||
const AgentPanelContext = createContext<AgentPanelContextType | undefined>(undefined);
|
||||
@@ -40,57 +40,60 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
|
||||
agent_id: agent_id || '',
|
||||
})) || [];
|
||||
|
||||
const groupedTools =
|
||||
tools?.reduce(
|
||||
(acc, tool) => {
|
||||
if (tool.tool_id.includes(Constants.mcp_delimiter)) {
|
||||
const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter);
|
||||
const groupKey = `${serverName.toLowerCase()}`;
|
||||
if (!acc[groupKey]) {
|
||||
acc[groupKey] = {
|
||||
tool_id: groupKey,
|
||||
metadata: {
|
||||
name: `${serverName}`,
|
||||
pluginKey: groupKey,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
|
||||
icon: tool.metadata.icon || '',
|
||||
} as TPlugin,
|
||||
agent_id: agent_id || '',
|
||||
tools: [],
|
||||
};
|
||||
}
|
||||
acc[groupKey].tools?.push({
|
||||
tool_id: tool.tool_id,
|
||||
metadata: tool.metadata,
|
||||
agent_id: agent_id || '',
|
||||
});
|
||||
} else {
|
||||
acc[tool.tool_id] = {
|
||||
tool_id: tool.tool_id,
|
||||
metadata: tool.metadata,
|
||||
const groupedTools = tools?.reduce(
|
||||
(acc, tool) => {
|
||||
if (tool.tool_id.includes(Constants.mcp_delimiter)) {
|
||||
const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter);
|
||||
const groupKey = `${serverName.toLowerCase()}`;
|
||||
if (!acc[groupKey]) {
|
||||
acc[groupKey] = {
|
||||
tool_id: groupKey,
|
||||
metadata: {
|
||||
name: `${serverName}`,
|
||||
pluginKey: groupKey,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
|
||||
icon: tool.metadata.icon || '',
|
||||
} as TPlugin,
|
||||
agent_id: agent_id || '',
|
||||
tools: [],
|
||||
};
|
||||
}
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, AgentToolType & { tools?: AgentToolType[] }>,
|
||||
) || {};
|
||||
acc[groupKey].tools?.push({
|
||||
tool_id: tool.tool_id,
|
||||
metadata: tool.metadata,
|
||||
agent_id: agent_id || '',
|
||||
});
|
||||
} else {
|
||||
acc[tool.tool_id] = {
|
||||
tool_id: tool.tool_id,
|
||||
metadata: tool.metadata,
|
||||
agent_id: agent_id || '',
|
||||
};
|
||||
}
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, AgentToolType & { tools?: AgentToolType[] }>,
|
||||
);
|
||||
|
||||
const value = {
|
||||
action,
|
||||
setAction,
|
||||
const { agentsConfig, endpointsConfig } = useGetAgentsConfig();
|
||||
|
||||
const value: AgentPanelContextType = {
|
||||
mcp,
|
||||
setMcp,
|
||||
mcps,
|
||||
setMcps,
|
||||
activePanel,
|
||||
setActivePanel,
|
||||
setCurrentAgentId,
|
||||
agent_id,
|
||||
groupedTools,
|
||||
/** Query data for actions and tools */
|
||||
actions,
|
||||
tools,
|
||||
action,
|
||||
setMcp,
|
||||
actions,
|
||||
setMcps,
|
||||
agent_id,
|
||||
setAction,
|
||||
activePanel,
|
||||
groupedTools,
|
||||
agentsConfig,
|
||||
setActivePanel,
|
||||
endpointsConfig,
|
||||
setCurrentAgentId,
|
||||
};
|
||||
|
||||
return <AgentPanelContext.Provider value={value}>{children}</AgentPanelContext.Provider>;
|
||||
|
||||
@@ -1,14 +1,25 @@
|
||||
import React, { createContext, useContext } from 'react';
|
||||
import { Tools, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import { useMCPSelect, useToolToggle, useCodeApiKeyForm, useSearchApiKeyForm } from '~/hooks';
|
||||
import React, { createContext, useContext, useEffect, useRef } from 'react';
|
||||
import { useSetRecoilState } from 'recoil';
|
||||
import { Tools, Constants, LocalStorageKeys, AgentCapabilities } from 'librechat-data-provider';
|
||||
import type { TAgentsEndpoint } from 'librechat-data-provider';
|
||||
import {
|
||||
useSearchApiKeyForm,
|
||||
useGetAgentsConfig,
|
||||
useCodeApiKeyForm,
|
||||
useToolToggle,
|
||||
useMCPSelect,
|
||||
} from '~/hooks';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
|
||||
interface BadgeRowContextType {
|
||||
conversationId?: string | null;
|
||||
agentsConfig?: TAgentsEndpoint | null;
|
||||
mcpSelect: ReturnType<typeof useMCPSelect>;
|
||||
webSearch: ReturnType<typeof useToolToggle>;
|
||||
codeInterpreter: ReturnType<typeof useToolToggle>;
|
||||
artifacts: ReturnType<typeof useToolToggle>;
|
||||
fileSearch: ReturnType<typeof useToolToggle>;
|
||||
codeInterpreter: ReturnType<typeof useToolToggle>;
|
||||
codeApiKeyForm: ReturnType<typeof useCodeApiKeyForm>;
|
||||
searchApiKeyForm: ReturnType<typeof useSearchApiKeyForm>;
|
||||
startupConfig: ReturnType<typeof useGetStartupConfig>['data'];
|
||||
@@ -26,10 +37,88 @@ export function useBadgeRowContext() {
|
||||
|
||||
interface BadgeRowProviderProps {
|
||||
children: React.ReactNode;
|
||||
isSubmitting?: boolean;
|
||||
conversationId?: string | null;
|
||||
}
|
||||
|
||||
export default function BadgeRowProvider({ children, conversationId }: BadgeRowProviderProps) {
|
||||
export default function BadgeRowProvider({
|
||||
children,
|
||||
isSubmitting,
|
||||
conversationId,
|
||||
}: BadgeRowProviderProps) {
|
||||
const hasInitializedRef = useRef(false);
|
||||
const lastKeyRef = useRef<string>('');
|
||||
const { agentsConfig } = useGetAgentsConfig();
|
||||
const key = conversationId ?? Constants.NEW_CONVO;
|
||||
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(key));
|
||||
|
||||
/** Initialize ephemeralAgent from localStorage on mount and when conversation changes */
|
||||
useEffect(() => {
|
||||
if (isSubmitting) {
|
||||
return;
|
||||
}
|
||||
// Check if this is a new conversation or the first load
|
||||
if (!hasInitializedRef.current || lastKeyRef.current !== key) {
|
||||
hasInitializedRef.current = true;
|
||||
lastKeyRef.current = key;
|
||||
|
||||
// Load all localStorage values
|
||||
const codeToggleKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${key}`;
|
||||
const webSearchToggleKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${key}`;
|
||||
const fileSearchToggleKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${key}`;
|
||||
const artifactsToggleKey = `${LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_}${key}`;
|
||||
|
||||
const codeToggleValue = localStorage.getItem(codeToggleKey);
|
||||
const webSearchToggleValue = localStorage.getItem(webSearchToggleKey);
|
||||
const fileSearchToggleValue = localStorage.getItem(fileSearchToggleKey);
|
||||
const artifactsToggleValue = localStorage.getItem(artifactsToggleKey);
|
||||
|
||||
const initialValues: Record<string, any> = {};
|
||||
|
||||
if (codeToggleValue !== null) {
|
||||
try {
|
||||
initialValues[Tools.execute_code] = JSON.parse(codeToggleValue);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse code toggle value:', e);
|
||||
}
|
||||
}
|
||||
|
||||
if (webSearchToggleValue !== null) {
|
||||
try {
|
||||
initialValues[Tools.web_search] = JSON.parse(webSearchToggleValue);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse web search toggle value:', e);
|
||||
}
|
||||
}
|
||||
|
||||
if (fileSearchToggleValue !== null) {
|
||||
try {
|
||||
initialValues[Tools.file_search] = JSON.parse(fileSearchToggleValue);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse file search toggle value:', e);
|
||||
}
|
||||
}
|
||||
|
||||
if (artifactsToggleValue !== null) {
|
||||
try {
|
||||
initialValues[AgentCapabilities.artifacts] = JSON.parse(artifactsToggleValue);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse artifacts toggle value:', e);
|
||||
}
|
||||
}
|
||||
|
||||
// Always set values for all tools (use defaults if not in localStorage)
|
||||
// If ephemeralAgent is null, create a new object with just our tool values
|
||||
setEphemeralAgent((prev) => ({
|
||||
...(prev || {}),
|
||||
[Tools.execute_code]: initialValues[Tools.execute_code] ?? false,
|
||||
[Tools.web_search]: initialValues[Tools.web_search] ?? false,
|
||||
[Tools.file_search]: initialValues[Tools.file_search] ?? false,
|
||||
[AgentCapabilities.artifacts]: initialValues[AgentCapabilities.artifacts] ?? false,
|
||||
}));
|
||||
}
|
||||
}, [key, isSubmitting, setEphemeralAgent]);
|
||||
|
||||
/** Startup config */
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
|
||||
@@ -74,10 +163,20 @@ export default function BadgeRowProvider({ children, conversationId }: BadgeRowP
|
||||
isAuthenticated: true,
|
||||
});
|
||||
|
||||
/** Artifacts hook - using a custom key since it's not a Tool but a capability */
|
||||
const artifacts = useToolToggle({
|
||||
conversationId,
|
||||
toolKey: AgentCapabilities.artifacts,
|
||||
localStorageKey: LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_,
|
||||
isAuthenticated: true,
|
||||
});
|
||||
|
||||
const value: BadgeRowContextType = {
|
||||
mcpSelect,
|
||||
webSearch,
|
||||
artifacts,
|
||||
fileSearch,
|
||||
agentsConfig,
|
||||
startupConfig,
|
||||
conversationId,
|
||||
codeApiKeyForm,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
export { default as AssistantsProvider } from './AssistantsContext';
|
||||
export { default as AgentsProvider } from './AgentsContext';
|
||||
export { default as ToastProvider } from './ToastContext';
|
||||
export * from './ActivePanelContext';
|
||||
export * from './AgentPanelContext';
|
||||
export * from './ChatContext';
|
||||
export * from './ShareContext';
|
||||
|
||||
@@ -206,9 +206,7 @@ export type AgentPanelProps = {
|
||||
setActivePanel: React.Dispatch<React.SetStateAction<Panel>>;
|
||||
setMcp: React.Dispatch<React.SetStateAction<t.MCP | undefined>>;
|
||||
setAction: React.Dispatch<React.SetStateAction<t.Action | undefined>>;
|
||||
endpointsConfig?: t.TEndpointsConfig;
|
||||
setCurrentAgentId: React.Dispatch<React.SetStateAction<string | undefined>>;
|
||||
agentsConfig?: t.TAgentsEndpoint | null;
|
||||
};
|
||||
|
||||
export type AgentPanelContextType = {
|
||||
@@ -219,12 +217,14 @@ export type AgentPanelContextType = {
|
||||
mcps?: t.MCP[];
|
||||
setMcp: React.Dispatch<React.SetStateAction<t.MCP | undefined>>;
|
||||
setMcps: React.Dispatch<React.SetStateAction<t.MCP[] | undefined>>;
|
||||
groupedTools: Record<string, t.AgentToolType & { tools?: t.AgentToolType[] }>;
|
||||
tools: t.AgentToolType[];
|
||||
activePanel?: string;
|
||||
setActivePanel: React.Dispatch<React.SetStateAction<Panel>>;
|
||||
setCurrentAgentId: React.Dispatch<React.SetStateAction<string | undefined>>;
|
||||
groupedTools?: Record<string, t.AgentToolType & { tools?: t.AgentToolType[] }>;
|
||||
agent_id?: string;
|
||||
agentsConfig?: t.TAgentsEndpoint | null;
|
||||
endpointsConfig?: t.TEndpointsConfig | null;
|
||||
};
|
||||
|
||||
export type AgentModelPanelProps = {
|
||||
@@ -336,6 +336,11 @@ export type TAskProps = {
|
||||
export type TOptions = {
|
||||
editedMessageId?: string | null;
|
||||
editedText?: string | null;
|
||||
editedContent?: {
|
||||
index: number;
|
||||
text: string;
|
||||
type: 'text' | 'think';
|
||||
};
|
||||
isRegenerate?: boolean;
|
||||
isContinued?: boolean;
|
||||
isEdited?: boolean;
|
||||
|
||||
152
client/src/components/Chat/Input/Artifacts.tsx
Normal file
152
client/src/components/Chat/Input/Artifacts.tsx
Normal file
@@ -0,0 +1,152 @@
|
||||
import React, { memo, useState, useCallback, useMemo } from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { ArtifactModes } from 'librechat-data-provider';
|
||||
import { WandSparkles, ChevronDown } from 'lucide-react';
|
||||
import CheckboxButton from '~/components/ui/CheckboxButton';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface ArtifactsToggleState {
|
||||
enabled: boolean;
|
||||
mode: string;
|
||||
}
|
||||
|
||||
function Artifacts() {
|
||||
const localize = useLocalize();
|
||||
const { artifacts } = useBadgeRowContext();
|
||||
const { toggleState, debouncedChange, isPinned } = artifacts;
|
||||
|
||||
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
|
||||
|
||||
const currentState = useMemo<ArtifactsToggleState>(() => {
|
||||
if (typeof toggleState === 'string' && toggleState) {
|
||||
return { enabled: true, mode: toggleState };
|
||||
}
|
||||
return { enabled: false, mode: '' };
|
||||
}, [toggleState]);
|
||||
|
||||
const isEnabled = currentState.enabled;
|
||||
const isShadcnEnabled = currentState.mode === ArtifactModes.SHADCNUI;
|
||||
const isCustomEnabled = currentState.mode === ArtifactModes.CUSTOM;
|
||||
|
||||
const handleToggle = useCallback(() => {
|
||||
if (isEnabled) {
|
||||
debouncedChange({ value: '' });
|
||||
} else {
|
||||
debouncedChange({ value: ArtifactModes.DEFAULT });
|
||||
}
|
||||
}, [isEnabled, debouncedChange]);
|
||||
|
||||
const handleShadcnToggle = useCallback(() => {
|
||||
if (isShadcnEnabled) {
|
||||
debouncedChange({ value: ArtifactModes.DEFAULT });
|
||||
} else {
|
||||
debouncedChange({ value: ArtifactModes.SHADCNUI });
|
||||
}
|
||||
}, [isShadcnEnabled, debouncedChange]);
|
||||
|
||||
const handleCustomToggle = useCallback(() => {
|
||||
if (isCustomEnabled) {
|
||||
debouncedChange({ value: ArtifactModes.DEFAULT });
|
||||
} else {
|
||||
debouncedChange({ value: ArtifactModes.CUSTOM });
|
||||
}
|
||||
}, [isCustomEnabled, debouncedChange]);
|
||||
|
||||
if (!isEnabled && !isPinned) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<CheckboxButton
|
||||
className={cn('max-w-fit', isEnabled && 'rounded-r-none border-r-0')}
|
||||
checked={isEnabled}
|
||||
setValue={handleToggle}
|
||||
label={localize('com_ui_artifacts')}
|
||||
isCheckedClassName="border-amber-600/40 bg-amber-500/10 hover:bg-amber-700/10"
|
||||
icon={<WandSparkles className="icon-md" />}
|
||||
/>
|
||||
|
||||
{isEnabled && (
|
||||
<Ariakit.MenuProvider open={isPopoverOpen} setOpen={setIsPopoverOpen}>
|
||||
<Ariakit.MenuButton
|
||||
className={cn(
|
||||
'w-7 rounded-l-none rounded-r-full border-b border-l-0 border-r border-t border-border-light md:w-6',
|
||||
'border-amber-600/40 bg-amber-500/10 hover:bg-amber-700/10',
|
||||
'transition-colors',
|
||||
)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<ChevronDown className="ml-1 h-4 w-4 text-text-secondary md:ml-0" />
|
||||
</Ariakit.MenuButton>
|
||||
|
||||
<Ariakit.Menu
|
||||
gutter={8}
|
||||
className={cn(
|
||||
'animate-popover z-50 flex max-h-[300px]',
|
||||
'flex-col overflow-auto overscroll-contain rounded-xl',
|
||||
'bg-surface-secondary px-1.5 py-1 text-text-primary shadow-lg',
|
||||
'border border-border-light',
|
||||
'min-w-[250px] outline-none',
|
||||
)}
|
||||
portal
|
||||
>
|
||||
<div className="px-2 py-1.5">
|
||||
<div className="mb-2 text-xs font-medium text-text-secondary">
|
||||
{localize('com_ui_artifacts_options')}
|
||||
</div>
|
||||
|
||||
{/* Include shadcn/ui Option */}
|
||||
<Ariakit.MenuItem
|
||||
hideOnClick={false}
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
handleShadcnToggle();
|
||||
}}
|
||||
disabled={isCustomEnabled}
|
||||
className={cn(
|
||||
'mb-1 flex items-center justify-between rounded-lg px-2 py-2',
|
||||
'cursor-pointer outline-none transition-colors',
|
||||
'hover:bg-black/[0.075] dark:hover:bg-white/10',
|
||||
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
|
||||
isCustomEnabled && 'cursor-not-allowed opacity-50',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<Ariakit.MenuItemCheck checked={isShadcnEnabled} />
|
||||
<span className="text-sm">{localize('com_ui_include_shadcnui' as any)}</span>
|
||||
</div>
|
||||
</Ariakit.MenuItem>
|
||||
|
||||
{/* Custom Prompt Mode Option */}
|
||||
<Ariakit.MenuItem
|
||||
hideOnClick={false}
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
handleCustomToggle();
|
||||
}}
|
||||
className={cn(
|
||||
'flex items-center justify-between rounded-lg px-2 py-2',
|
||||
'cursor-pointer outline-none transition-colors',
|
||||
'hover:bg-black/[0.075] dark:hover:bg-white/10',
|
||||
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<Ariakit.MenuItemCheck checked={isCustomEnabled} />
|
||||
<span className="text-sm">{localize('com_ui_custom_prompt_mode' as any)}</span>
|
||||
</div>
|
||||
</Ariakit.MenuItem>
|
||||
</div>
|
||||
</Ariakit.Menu>
|
||||
</Ariakit.MenuProvider>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default memo(Artifacts);
|
||||
147
client/src/components/Chat/Input/ArtifactsSubMenu.tsx
Normal file
147
client/src/components/Chat/Input/ArtifactsSubMenu.tsx
Normal file
@@ -0,0 +1,147 @@
|
||||
import React from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { ChevronRight, WandSparkles } from 'lucide-react';
|
||||
import { ArtifactModes } from 'librechat-data-provider';
|
||||
import { PinIcon } from '~/components/svg';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface ArtifactsSubMenuProps {
|
||||
isArtifactsPinned: boolean;
|
||||
setIsArtifactsPinned: (value: boolean) => void;
|
||||
artifactsMode: string;
|
||||
handleArtifactsToggle: () => void;
|
||||
handleShadcnToggle: () => void;
|
||||
handleCustomToggle: () => void;
|
||||
}
|
||||
|
||||
const ArtifactsSubMenu = ({
|
||||
isArtifactsPinned,
|
||||
setIsArtifactsPinned,
|
||||
artifactsMode,
|
||||
handleArtifactsToggle,
|
||||
handleShadcnToggle,
|
||||
handleCustomToggle,
|
||||
...props
|
||||
}: ArtifactsSubMenuProps) => {
|
||||
const localize = useLocalize();
|
||||
|
||||
const menuStore = Ariakit.useMenuStore({
|
||||
focusLoop: true,
|
||||
showTimeout: 100,
|
||||
placement: 'right',
|
||||
});
|
||||
|
||||
const isEnabled = artifactsMode !== '' && artifactsMode !== undefined;
|
||||
const isShadcnEnabled = artifactsMode === ArtifactModes.SHADCNUI;
|
||||
const isCustomEnabled = artifactsMode === ArtifactModes.CUSTOM;
|
||||
|
||||
return (
|
||||
<Ariakit.MenuProvider store={menuStore}>
|
||||
<Ariakit.MenuItem
|
||||
{...props}
|
||||
hideOnClick={false}
|
||||
render={
|
||||
<Ariakit.MenuButton
|
||||
onClick={(e: React.MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
handleArtifactsToggle();
|
||||
}}
|
||||
onMouseEnter={() => {
|
||||
if (isEnabled) {
|
||||
menuStore.show();
|
||||
}
|
||||
}}
|
||||
className="flex w-full cursor-pointer items-center justify-between rounded-lg p-2 hover:bg-surface-hover"
|
||||
/>
|
||||
}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<WandSparkles className="icon-md" />
|
||||
<span>{localize('com_ui_artifacts')}</span>
|
||||
{isEnabled && <ChevronRight className="ml-auto h-3 w-3" />}
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsArtifactsPinned(!isArtifactsPinned);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-tertiary hover:shadow-sm',
|
||||
!isArtifactsPinned && 'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isArtifactsPinned ? 'Unpin' : 'Pin'}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isArtifactsPinned} />
|
||||
</div>
|
||||
</button>
|
||||
</Ariakit.MenuItem>
|
||||
|
||||
{isEnabled && (
|
||||
<Ariakit.Menu
|
||||
portal={true}
|
||||
unmountOnHide={true}
|
||||
className={cn(
|
||||
'animate-popover-left z-50 ml-3 flex min-w-[250px] flex-col rounded-xl',
|
||||
'border border-border-light bg-surface-secondary px-1.5 py-1 shadow-lg',
|
||||
)}
|
||||
>
|
||||
<div className="px-2 py-1.5">
|
||||
<div className="mb-2 text-xs font-medium text-text-secondary">
|
||||
{localize('com_ui_artifacts_options')}
|
||||
</div>
|
||||
|
||||
{/* Include shadcn/ui Option */}
|
||||
<Ariakit.MenuItem
|
||||
hideOnClick={false}
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
handleShadcnToggle();
|
||||
}}
|
||||
disabled={isCustomEnabled}
|
||||
className={cn(
|
||||
'mb-1 flex items-center justify-between rounded-lg px-2 py-2',
|
||||
'cursor-pointer text-text-primary outline-none transition-colors',
|
||||
'hover:bg-black/[0.075] dark:hover:bg-white/10',
|
||||
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
|
||||
isCustomEnabled && 'cursor-not-allowed opacity-50',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<Ariakit.MenuItemCheck checked={isShadcnEnabled} />
|
||||
<span className="text-sm">{localize('com_ui_include_shadcnui' as any)}</span>
|
||||
</div>
|
||||
</Ariakit.MenuItem>
|
||||
|
||||
{/* Custom Prompt Mode Option */}
|
||||
<Ariakit.MenuItem
|
||||
hideOnClick={false}
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
handleCustomToggle();
|
||||
}}
|
||||
className={cn(
|
||||
'flex items-center justify-between rounded-lg px-2 py-2',
|
||||
'cursor-pointer text-text-primary outline-none transition-colors',
|
||||
'hover:bg-black/[0.075] dark:hover:bg-white/10',
|
||||
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<Ariakit.MenuItemCheck checked={isCustomEnabled} />
|
||||
<span className="text-sm">{localize('com_ui_custom_prompt_mode' as any)}</span>
|
||||
</div>
|
||||
</Ariakit.MenuItem>
|
||||
</div>
|
||||
</Ariakit.Menu>
|
||||
)}
|
||||
</Ariakit.MenuProvider>
|
||||
);
|
||||
};
|
||||
|
||||
export default React.memo(ArtifactsSubMenu);
|
||||
@@ -18,6 +18,7 @@ import { useChatBadges } from '~/hooks';
|
||||
import { Badge } from '~/components/ui';
|
||||
import ToolDialogs from './ToolDialogs';
|
||||
import FileSearch from './FileSearch';
|
||||
import Artifacts from './Artifacts';
|
||||
import MCPSelect from './MCPSelect';
|
||||
import WebSearch from './WebSearch';
|
||||
import store from '~/store';
|
||||
@@ -27,6 +28,7 @@ interface BadgeRowProps {
|
||||
onChange: (badges: Pick<BadgeItem, 'id'>[]) => void;
|
||||
onToggle?: (badgeId: string, currentActive: boolean) => void;
|
||||
conversationId?: string | null;
|
||||
isSubmitting?: boolean;
|
||||
isInChat: boolean;
|
||||
}
|
||||
|
||||
@@ -140,6 +142,7 @@ const dragReducer = (state: DragState, action: DragAction): DragState => {
|
||||
function BadgeRow({
|
||||
showEphemeralBadges,
|
||||
conversationId,
|
||||
isSubmitting,
|
||||
onChange,
|
||||
onToggle,
|
||||
isInChat,
|
||||
@@ -317,7 +320,7 @@ function BadgeRow({
|
||||
}, [dragState.draggedBadge, handleMouseMove, handleMouseUp]);
|
||||
|
||||
return (
|
||||
<BadgeRowProvider conversationId={conversationId}>
|
||||
<BadgeRowProvider conversationId={conversationId} isSubmitting={isSubmitting}>
|
||||
<div ref={containerRef} className="relative flex flex-wrap items-center gap-2">
|
||||
{showEphemeralBadges === true && <ToolsDropdown />}
|
||||
{tempBadges.map((badge, index) => (
|
||||
@@ -364,6 +367,7 @@ function BadgeRow({
|
||||
<WebSearch />
|
||||
<CodeInterpreter />
|
||||
<FileSearch />
|
||||
<Artifacts />
|
||||
<MCPSelect />
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -305,6 +305,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
|
||||
</div>
|
||||
<BadgeRow
|
||||
showEphemeralBadges={!isAgentsEndpoint(endpoint) && !isAssistantsEndpoint(endpoint)}
|
||||
isSubmitting={isSubmitting || isSubmittingAdded}
|
||||
conversationId={conversationId}
|
||||
onChange={setBadges}
|
||||
isInChat={
|
||||
|
||||
@@ -4,34 +4,41 @@ import {
|
||||
supportsFiles,
|
||||
mergeFileConfig,
|
||||
isAgentsEndpoint,
|
||||
EndpointFileConfig,
|
||||
isAssistantsEndpoint,
|
||||
fileConfig as defaultFileConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import type { EndpointFileConfig } from 'librechat-data-provider';
|
||||
import { useGetFileConfig } from '~/data-provider';
|
||||
import AttachFileMenu from './AttachFileMenu';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import AttachFile from './AttachFile';
|
||||
|
||||
function AttachFileChat({ disableInputs }: { disableInputs: boolean }) {
|
||||
const { conversation } = useChatContext();
|
||||
const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
|
||||
const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null };
|
||||
const isAgents = useMemo(() => isAgentsEndpoint(_endpoint), [_endpoint]);
|
||||
const { endpoint, endpointType } = conversation ?? { endpoint: null };
|
||||
const isAgents = useMemo(() => isAgentsEndpoint(endpoint), [endpoint]);
|
||||
const isAssistants = useMemo(() => isAssistantsEndpoint(endpoint), [endpoint]);
|
||||
|
||||
const { data: fileConfig = defaultFileConfig } = useGetFileConfig({
|
||||
select: (data) => mergeFileConfig(data),
|
||||
});
|
||||
|
||||
const endpointFileConfig = fileConfig.endpoints[_endpoint ?? ''] as
|
||||
| EndpointFileConfig
|
||||
| undefined;
|
||||
|
||||
const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? _endpoint ?? ''] ?? false;
|
||||
const endpointFileConfig = fileConfig.endpoints[endpoint ?? ''] as EndpointFileConfig | undefined;
|
||||
const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? endpoint ?? ''] ?? false;
|
||||
const isUploadDisabled = (disableInputs || endpointFileConfig?.disabled) ?? false;
|
||||
|
||||
if (isAgents || (endpointSupportsFiles && !isUploadDisabled)) {
|
||||
return <AttachFileMenu disabled={disableInputs} conversationId={conversationId} />;
|
||||
if (isAssistants && endpointSupportsFiles && !isUploadDisabled) {
|
||||
return <AttachFile disabled={disableInputs} />;
|
||||
} else if (isAgents || (endpointSupportsFiles && !isUploadDisabled)) {
|
||||
return (
|
||||
<AttachFileMenu
|
||||
disabled={disableInputs}
|
||||
conversationId={conversationId}
|
||||
endpointFileConfig={endpointFileConfig}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,38 +2,37 @@ import { useSetRecoilState } from 'recoil';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import React, { useRef, useState, useMemo } from 'react';
|
||||
import { FileSearch, ImageUpIcon, TerminalSquareIcon, FileType2Icon } from 'lucide-react';
|
||||
import { EToolResources, EModelEndpoint, defaultAgentCapabilities } from 'librechat-data-provider';
|
||||
import type { EndpointFileConfig } from 'librechat-data-provider';
|
||||
import { useLocalize, useGetAgentsConfig, useFileHandling, useAgentCapabilities } from '~/hooks';
|
||||
import { FileUpload, TooltipAnchor, DropdownPopup, AttachmentIcon } from '~/components';
|
||||
import { EToolResources, EModelEndpoint } from 'librechat-data-provider';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { useLocalize, useFileHandling } from '~/hooks';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface AttachFileMenuProps {
|
||||
conversationId: string;
|
||||
disabled?: boolean | null;
|
||||
endpointFileConfig?: EndpointFileConfig;
|
||||
}
|
||||
|
||||
const AttachFileMenu = ({ disabled, conversationId }: AttachFileMenuProps) => {
|
||||
const AttachFileMenu = ({ disabled, conversationId, endpointFileConfig }: AttachFileMenuProps) => {
|
||||
const localize = useLocalize();
|
||||
const isUploadDisabled = disabled ?? false;
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [isPopoverActive, setIsPopoverActive] = useState(false);
|
||||
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId));
|
||||
const [toolResource, setToolResource] = useState<EToolResources | undefined>();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const { handleFileChange } = useFileHandling({
|
||||
overrideEndpoint: EModelEndpoint.agents,
|
||||
overrideEndpointFileConfig: endpointFileConfig,
|
||||
});
|
||||
|
||||
const { agentsConfig } = useGetAgentsConfig();
|
||||
/** TODO: Ephemeral Agent Capabilities
|
||||
* Allow defining agent capabilities on a per-endpoint basis
|
||||
* Use definition for agents endpoint for ephemeral agents
|
||||
* */
|
||||
const capabilities = useMemo(
|
||||
() => endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [],
|
||||
[endpointsConfig],
|
||||
);
|
||||
const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
|
||||
|
||||
const handleUploadClick = (isImage?: boolean) => {
|
||||
if (!inputRef.current) {
|
||||
@@ -57,7 +56,7 @@ const AttachFileMenu = ({ disabled, conversationId }: AttachFileMenuProps) => {
|
||||
},
|
||||
];
|
||||
|
||||
if (capabilities.includes(EToolResources.ocr)) {
|
||||
if (capabilities.ocrEnabled) {
|
||||
items.push({
|
||||
label: localize('com_ui_upload_ocr_text'),
|
||||
onClick: () => {
|
||||
@@ -68,7 +67,7 @@ const AttachFileMenu = ({ disabled, conversationId }: AttachFileMenuProps) => {
|
||||
});
|
||||
}
|
||||
|
||||
if (capabilities.includes(EToolResources.file_search)) {
|
||||
if (capabilities.fileSearchEnabled) {
|
||||
items.push({
|
||||
label: localize('com_ui_upload_file_search'),
|
||||
onClick: () => {
|
||||
@@ -80,7 +79,7 @@ const AttachFileMenu = ({ disabled, conversationId }: AttachFileMenuProps) => {
|
||||
});
|
||||
}
|
||||
|
||||
if (capabilities.includes(EToolResources.execute_code)) {
|
||||
if (capabilities.codeEnabled) {
|
||||
items.push({
|
||||
label: localize('com_ui_upload_code_files'),
|
||||
onClick: () => {
|
||||
|
||||
@@ -2,11 +2,18 @@ import React, { useState, useMemo, useCallback } from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { Globe, Settings, Settings2, TerminalSquareIcon } from 'lucide-react';
|
||||
import type { MenuItemProps } from '~/common';
|
||||
import { Permissions, PermissionTypes, AuthType } from 'librechat-data-provider';
|
||||
import {
|
||||
AuthType,
|
||||
Permissions,
|
||||
ArtifactModes,
|
||||
PermissionTypes,
|
||||
defaultAgentCapabilities,
|
||||
} from 'librechat-data-provider';
|
||||
import { TooltipAnchor, DropdownPopup } from '~/components';
|
||||
import { useLocalize, useHasAccess, useAgentCapabilities } from '~/hooks';
|
||||
import ArtifactsSubMenu from '~/components/Chat/Input/ArtifactsSubMenu';
|
||||
import MCPSubMenu from '~/components/Chat/Input/MCPSubMenu';
|
||||
import { PinIcon, VectorIcon } from '~/components/svg';
|
||||
import { useLocalize, useHasAccess } from '~/hooks';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
@@ -21,12 +28,17 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
|
||||
const {
|
||||
webSearch,
|
||||
mcpSelect,
|
||||
artifacts,
|
||||
fileSearch,
|
||||
agentsConfig,
|
||||
startupConfig,
|
||||
codeApiKeyForm,
|
||||
codeInterpreter,
|
||||
searchApiKeyForm,
|
||||
} = useBadgeRowContext();
|
||||
const { codeEnabled, webSearchEnabled, artifactsEnabled, fileSearchEnabled } =
|
||||
useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
|
||||
|
||||
const { setIsDialogOpen: setIsCodeDialogOpen, menuTriggerRef: codeMenuTriggerRef } =
|
||||
codeApiKeyForm;
|
||||
const { setIsDialogOpen: setIsSearchDialogOpen, menuTriggerRef: searchMenuTriggerRef } =
|
||||
@@ -42,6 +54,7 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
|
||||
authData: codeAuthData,
|
||||
} = codeInterpreter;
|
||||
const { isPinned: isFileSearchPinned, setIsPinned: setIsFileSearchPinned } = fileSearch;
|
||||
const { isPinned: isArtifactsPinned, setIsPinned: setIsArtifactsPinned } = artifacts;
|
||||
const {
|
||||
mcpValues,
|
||||
mcpServerNames,
|
||||
@@ -72,19 +85,46 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
|
||||
|
||||
const handleWebSearchToggle = useCallback(() => {
|
||||
const newValue = !webSearch.toggleState;
|
||||
webSearch.debouncedChange({ isChecked: newValue });
|
||||
webSearch.debouncedChange({ value: newValue });
|
||||
}, [webSearch]);
|
||||
|
||||
const handleCodeInterpreterToggle = useCallback(() => {
|
||||
const newValue = !codeInterpreter.toggleState;
|
||||
codeInterpreter.debouncedChange({ isChecked: newValue });
|
||||
codeInterpreter.debouncedChange({ value: newValue });
|
||||
}, [codeInterpreter]);
|
||||
|
||||
const handleFileSearchToggle = useCallback(() => {
|
||||
const newValue = !fileSearch.toggleState;
|
||||
fileSearch.debouncedChange({ isChecked: newValue });
|
||||
fileSearch.debouncedChange({ value: newValue });
|
||||
}, [fileSearch]);
|
||||
|
||||
const handleArtifactsToggle = useCallback(() => {
|
||||
const currentState = artifacts.toggleState;
|
||||
if (!currentState || currentState === '') {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.DEFAULT });
|
||||
} else {
|
||||
artifacts.debouncedChange({ value: '' });
|
||||
}
|
||||
}, [artifacts]);
|
||||
|
||||
const handleShadcnToggle = useCallback(() => {
|
||||
const currentState = artifacts.toggleState;
|
||||
if (currentState === ArtifactModes.SHADCNUI) {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.DEFAULT });
|
||||
} else {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.SHADCNUI });
|
||||
}
|
||||
}, [artifacts]);
|
||||
|
||||
const handleCustomToggle = useCallback(() => {
|
||||
const currentState = artifacts.toggleState;
|
||||
if (currentState === ArtifactModes.CUSTOM) {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.DEFAULT });
|
||||
} else {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.CUSTOM });
|
||||
}
|
||||
}, [artifacts]);
|
||||
|
||||
const handleMCPToggle = useCallback(
|
||||
(serverName: string) => {
|
||||
const currentValues = mcpSelect.mcpValues ?? [];
|
||||
@@ -98,9 +138,10 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
|
||||
|
||||
const mcpPlaceholder = startupConfig?.interface?.mcpServers?.placeholder;
|
||||
|
||||
const dropdownItems = useMemo(() => {
|
||||
const items: MenuItemProps[] = [];
|
||||
items.push({
|
||||
const dropdownItems: MenuItemProps[] = [];
|
||||
|
||||
if (fileSearchEnabled) {
|
||||
dropdownItems.push({
|
||||
onClick: handleFileSearchToggle,
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
@@ -129,159 +170,149 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
if (canUseWebSearch) {
|
||||
items.push({
|
||||
onClick: handleWebSearchToggle,
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<div {...props}>
|
||||
<div className="flex items-center gap-2">
|
||||
<Globe className="icon-md" />
|
||||
<span>{localize('com_ui_web_search')}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
{showWebSearchSettings && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsSearchDialogOpen(true);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label="Configure web search"
|
||||
ref={searchMenuTriggerRef}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<Settings className="h-4 w-4" />
|
||||
</div>
|
||||
</button>
|
||||
)}
|
||||
if (canUseWebSearch && webSearchEnabled) {
|
||||
dropdownItems.push({
|
||||
onClick: handleWebSearchToggle,
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<div {...props}>
|
||||
<div className="flex items-center gap-2">
|
||||
<Globe className="icon-md" />
|
||||
<span>{localize('com_ui_web_search')}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
{showWebSearchSettings && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsSearchPinned(!isSearchPinned);
|
||||
setIsSearchDialogOpen(true);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
!isSearchPinned && 'text-text-secondary hover:text-text-primary',
|
||||
'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isSearchPinned ? 'Unpin' : 'Pin'}
|
||||
aria-label="Configure web search"
|
||||
ref={searchMenuTriggerRef}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isSearchPinned} />
|
||||
<Settings className="h-4 w-4" />
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
if (canRunCode) {
|
||||
items.push({
|
||||
onClick: handleCodeInterpreterToggle,
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<div {...props}>
|
||||
<div className="flex items-center gap-2">
|
||||
<TerminalSquareIcon className="icon-md" />
|
||||
<span>{localize('com_assistants_code_interpreter')}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
{showCodeSettings && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsCodeDialogOpen(true);
|
||||
}}
|
||||
ref={codeMenuTriggerRef}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label="Configure code interpreter"
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<Settings className="h-4 w-4" />
|
||||
</div>
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsSearchPinned(!isSearchPinned);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
!isSearchPinned && 'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isSearchPinned ? 'Unpin' : 'Pin'}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isSearchPinned} />
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
if (canRunCode && codeEnabled) {
|
||||
dropdownItems.push({
|
||||
onClick: handleCodeInterpreterToggle,
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<div {...props}>
|
||||
<div className="flex items-center gap-2">
|
||||
<TerminalSquareIcon className="icon-md" />
|
||||
<span>{localize('com_assistants_code_interpreter')}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
{showCodeSettings && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsCodePinned(!isCodePinned);
|
||||
setIsCodeDialogOpen(true);
|
||||
}}
|
||||
ref={codeMenuTriggerRef}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
!isCodePinned && 'text-text-primary hover:text-text-primary',
|
||||
'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isCodePinned ? 'Unpin' : 'Pin'}
|
||||
aria-label="Configure code interpreter"
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isCodePinned} />
|
||||
<Settings className="h-4 w-4" />
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsCodePinned(!isCodePinned);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
!isCodePinned && 'text-text-primary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isCodePinned ? 'Unpin' : 'Pin'}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isCodePinned} />
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
if (mcpServerNames && mcpServerNames.length > 0) {
|
||||
items.push({
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<MCPSubMenu
|
||||
{...props}
|
||||
mcpValues={mcpValues}
|
||||
isMCPPinned={isMCPPinned}
|
||||
placeholder={mcpPlaceholder}
|
||||
mcpServerNames={mcpServerNames}
|
||||
setIsMCPPinned={setIsMCPPinned}
|
||||
handleMCPToggle={handleMCPToggle}
|
||||
/>
|
||||
),
|
||||
});
|
||||
}
|
||||
if (artifactsEnabled) {
|
||||
dropdownItems.push({
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<ArtifactsSubMenu
|
||||
{...props}
|
||||
isArtifactsPinned={isArtifactsPinned}
|
||||
setIsArtifactsPinned={setIsArtifactsPinned}
|
||||
artifactsMode={artifacts.toggleState as string}
|
||||
handleArtifactsToggle={handleArtifactsToggle}
|
||||
handleShadcnToggle={handleShadcnToggle}
|
||||
handleCustomToggle={handleCustomToggle}
|
||||
/>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
return items;
|
||||
}, [
|
||||
localize,
|
||||
mcpValues,
|
||||
canRunCode,
|
||||
isMCPPinned,
|
||||
isCodePinned,
|
||||
mcpPlaceholder,
|
||||
mcpServerNames,
|
||||
isSearchPinned,
|
||||
setIsMCPPinned,
|
||||
canUseWebSearch,
|
||||
setIsCodePinned,
|
||||
handleMCPToggle,
|
||||
showCodeSettings,
|
||||
setIsSearchPinned,
|
||||
isFileSearchPinned,
|
||||
codeMenuTriggerRef,
|
||||
setIsCodeDialogOpen,
|
||||
searchMenuTriggerRef,
|
||||
showWebSearchSettings,
|
||||
setIsFileSearchPinned,
|
||||
handleWebSearchToggle,
|
||||
setIsSearchDialogOpen,
|
||||
handleFileSearchToggle,
|
||||
handleCodeInterpreterToggle,
|
||||
]);
|
||||
if (mcpServerNames && mcpServerNames.length > 0) {
|
||||
dropdownItems.push({
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<MCPSubMenu
|
||||
{...props}
|
||||
mcpValues={mcpValues}
|
||||
isMCPPinned={isMCPPinned}
|
||||
placeholder={mcpPlaceholder}
|
||||
mcpServerNames={mcpServerNames}
|
||||
setIsMCPPinned={setIsMCPPinned}
|
||||
handleMCPToggle={handleMCPToggle}
|
||||
/>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
const menuTrigger = (
|
||||
<TooltipAnchor
|
||||
|
||||
@@ -8,7 +8,7 @@ import { useBadgeRowContext } from '~/Providers';
|
||||
function WebSearch() {
|
||||
const localize = useLocalize();
|
||||
const { webSearch: webSearchData, searchApiKeyForm } = useBadgeRowContext();
|
||||
const { toggleState: webSearch, debouncedChange, isPinned } = webSearchData;
|
||||
const { toggleState: webSearch, debouncedChange, isPinned, authData } = webSearchData;
|
||||
const { badgeTriggerRef } = searchApiKeyForm;
|
||||
|
||||
const canUseWebSearch = useHasAccess({
|
||||
@@ -21,7 +21,7 @@ function WebSearch() {
|
||||
}
|
||||
|
||||
return (
|
||||
(webSearch || isPinned) && (
|
||||
(isPinned || (webSearch && authData?.authenticated)) && (
|
||||
<CheckboxButton
|
||||
ref={badgeTriggerRef}
|
||||
className="max-w-fit"
|
||||
|
||||
@@ -81,14 +81,23 @@ const ContentParts = memo(
|
||||
return (
|
||||
<>
|
||||
{content.map((part, idx) => {
|
||||
if (part?.type !== ContentTypes.TEXT || typeof part.text !== 'string') {
|
||||
if (!part) {
|
||||
return null;
|
||||
}
|
||||
const isTextPart =
|
||||
part?.type === ContentTypes.TEXT ||
|
||||
typeof (part as unknown as Agents.MessageContentText)?.text !== 'string';
|
||||
const isThinkPart =
|
||||
part?.type === ContentTypes.THINK ||
|
||||
typeof (part as unknown as Agents.ReasoningDeltaUpdate)?.think !== 'string';
|
||||
if (!isTextPart && !isThinkPart) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<EditTextPart
|
||||
index={idx}
|
||||
text={part.text}
|
||||
part={part as Agents.MessageContentText | Agents.ReasoningDeltaUpdate}
|
||||
messageId={messageId}
|
||||
isSubmitting={isSubmitting}
|
||||
enterEdit={enterEdit}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { useRef, useEffect, useCallback, useMemo } from 'react';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { ContentTypes } from 'librechat-data-provider';
|
||||
import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { useRef, useEffect, useCallback, useMemo } from 'react';
|
||||
import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query';
|
||||
import type { Agents } from 'librechat-data-provider';
|
||||
import type { TEditProps } from '~/common';
|
||||
import Container from '~/components/Chat/Messages/Content/Container';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
@@ -12,18 +13,19 @@ import { useLocalize } from '~/hooks';
|
||||
import store from '~/store';
|
||||
|
||||
const EditTextPart = ({
|
||||
text,
|
||||
part,
|
||||
index,
|
||||
messageId,
|
||||
isSubmitting,
|
||||
enterEdit,
|
||||
}: Omit<TEditProps, 'message' | 'ask'> & {
|
||||
}: Omit<TEditProps, 'message' | 'ask' | 'text'> & {
|
||||
index: number;
|
||||
messageId: string;
|
||||
part: Agents.MessageContentText | Agents.ReasoningDeltaUpdate;
|
||||
}) => {
|
||||
const localize = useLocalize();
|
||||
const { addedIndex } = useAddedChatContext();
|
||||
const { getMessages, setMessages, conversation } = useChatContext();
|
||||
const { ask, getMessages, setMessages, conversation } = useChatContext();
|
||||
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
|
||||
store.latestMessageFamily(addedIndex),
|
||||
);
|
||||
@@ -34,15 +36,16 @@ const EditTextPart = ({
|
||||
[getMessages, messageId],
|
||||
);
|
||||
|
||||
const chatDirection = useRecoilValue(store.chatDirection);
|
||||
|
||||
const textAreaRef = useRef<HTMLTextAreaElement | null>(null);
|
||||
const updateMessageContentMutation = useUpdateMessageContentMutation(conversationId ?? '');
|
||||
|
||||
const chatDirection = useRecoilValue(store.chatDirection).toLowerCase();
|
||||
const isRTL = chatDirection === 'rtl';
|
||||
const isRTL = chatDirection?.toLowerCase() === 'rtl';
|
||||
|
||||
const { register, handleSubmit, setValue } = useForm({
|
||||
defaultValues: {
|
||||
text: text ?? '',
|
||||
text: (ContentTypes.THINK in part ? part.think : part.text) || '',
|
||||
},
|
||||
});
|
||||
|
||||
@@ -55,15 +58,7 @@ const EditTextPart = ({
|
||||
}
|
||||
}, []);
|
||||
|
||||
/*
|
||||
const resubmitMessage = () => {
|
||||
showToast({
|
||||
status: 'warning',
|
||||
message: localize('com_warning_resubmit_unsupported'),
|
||||
});
|
||||
|
||||
// const resubmitMessage = (data: { text: string }) => {
|
||||
// Not supported by AWS Bedrock
|
||||
const resubmitMessage = (data: { text: string }) => {
|
||||
const messages = getMessages();
|
||||
const parentMessage = messages?.find((msg) => msg.messageId === message?.parentMessageId);
|
||||
|
||||
@@ -73,17 +68,19 @@ const EditTextPart = ({
|
||||
ask(
|
||||
{ ...parentMessage },
|
||||
{
|
||||
editedText: data.text,
|
||||
editedContent: {
|
||||
index,
|
||||
text: data.text,
|
||||
type: part.type,
|
||||
},
|
||||
editedMessageId: messageId,
|
||||
isRegenerate: true,
|
||||
isEdited: true,
|
||||
},
|
||||
);
|
||||
|
||||
setSiblingIdx((siblingIdx ?? 0) - 1);
|
||||
enterEdit(true);
|
||||
};
|
||||
*/
|
||||
|
||||
const updateMessage = (data: { text: string }) => {
|
||||
const messages = getMessages();
|
||||
@@ -167,13 +164,13 @@ const EditTextPart = ({
|
||||
/>
|
||||
</div>
|
||||
<div className="mt-2 flex w-full justify-center text-center">
|
||||
{/* <button
|
||||
<button
|
||||
className="btn btn-primary relative mr-2"
|
||||
disabled={isSubmitting}
|
||||
onClick={handleSubmit(resubmitMessage)}
|
||||
>
|
||||
{localize('com_ui_save_submit')}
|
||||
</button> */}
|
||||
</button>
|
||||
<button
|
||||
className="btn btn-secondary relative mr-2"
|
||||
disabled={isSubmitting}
|
||||
|
||||
@@ -233,9 +233,17 @@ export default function Fork({
|
||||
status: 'info',
|
||||
});
|
||||
},
|
||||
onError: () => {
|
||||
onError: (error) => {
|
||||
/** Rate limit error (429 status code) */
|
||||
const isRateLimitError =
|
||||
(error as any)?.response?.status === 429 ||
|
||||
(error as any)?.status === 429 ||
|
||||
(error as any)?.statusCode === 429;
|
||||
|
||||
showToast({
|
||||
message: localize('com_ui_fork_error'),
|
||||
message: isRateLimitError
|
||||
? localize('com_ui_fork_error_rate_limit')
|
||||
: localize('com_ui_fork_error'),
|
||||
status: 'error',
|
||||
});
|
||||
},
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user