Compare commits
63 Commits
feat/mcp-p...feat/merma

| Author | SHA1 | Date |
|---|---|---|
| | f67dd1b1b7 | |
| | 4136dda7c7 | |
| | c53bdc1fef | |
| | 170cc340d8 | |
| | f1b29ffb45 | |
| | 6aa4bb5a4a | |
| | 9f44187351 | |
| | d2e1ca4c4a | |
| | 8e869f2274 | |
| | 2e1874e596 | |
| | 929b433662 | |
| | 1e4f1f780c | |
| | 4733f10e41 | |
| | 110984b48f | |
| | 19320f2296 | |
| | 8523074e87 | |
| | e4531d682d | |
| | 4bbdc4c402 | |
| | 8ca4cf3d2f | |
| | 13a9bcdd48 | |
| | 4b32ec42c6 | |
| | 4918899c8d | |
| | 7e37211458 | |
| | e57fc83d40 | |
| | 550610dba9 | |
| | 916cd46221 | |
| | 12b08183ff | |
| | f4d97e1672 | |
| | 035fa081c1 | |
| | aecf8f19a6 | |
| | 35f548a94d | |
| | e60c0cf201 | |
| | 5b392f9cb0 | |
| | e0f468da20 | |
| | 91a2df4759 | |
| | 97a99985fa | |
| | 3554625a06 | |
| | a37bf6719c | |
| | e513f50c08 | |
| | f5511e4a4e | |
| | a288ad1d9c | |
| | 458580ec87 | |
| | 4285d5841c | |
| | 5ee55cda4f | |
| | 404d40cbef | |
| | f4680b016c | |
| | 077224b351 | |
| | 9c70d1db96 | |
| | 543281da6c | |
| | 24800bfbeb | |
| | 07e08143e4 | |
| | 8ba61a86f4 | |
| | 56ad92fb1c | |
| | 1ceb52d2b5 | |
| | 5d267aa8e2 | |
| | 59d00e99f3 | |
| | 738d04fac4 | |
| | 8a5dbac0f9 | |
| | 434289fe92 | |
| | a648ad3d13 | |
| | 55d63caaf4 | |
| | 313539d1ed | |
| | f869d772f7 | |
@@ -349,6 +349,11 @@ REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20
+TTS_VIOLATION_SCORE=0
+STT_VIOLATION_SCORE=0
+FORK_VIOLATION_SCORE=0
+IMPORT_VIOLATION_SCORE=0
+FILE_UPLOAD_VIOLATION_SCORE=0

LOGIN_MAX=7
LOGIN_WINDOW=5
@@ -575,6 +580,10 @@ ALLOW_SHARED_LINKS_PUBLIC=true
# If you have another service in front of your LibreChat doing compression, disable express based compression here
# DISABLE_COMPRESSION=true

+# If you have gzipped versions of uploaded images in the same folder, this will enable gzip scan and serving of these images
+# Note: The images folder will be scanned on startup and a map kept in memory. Be careful with a large number of images.
+# ENABLE_IMAGE_OUTPUT_GZIP_SCAN=true
+
#===================================================#
#                        UI                         #
#===================================================#
3 .gitignore vendored
@@ -56,6 +56,7 @@ bower_components/
.clineignore
.cursor
.aider*
+CLAUDE.md

# Floobits
.floo
@@ -124,4 +125,4 @@ helm/**/.values.yaml
!/client/src/@types/i18next.d.ts

# SAML Idp cert
-*.cert
+*.cert
@@ -1,4 +1,4 @@
-# v0.7.8
+# v0.7.9-rc1

# Base node image
FROM node:20-alpine AS node
@@ -1,5 +1,5 @@
# Dockerfile.multi
-# v0.7.8
+# v0.7.9-rc1

# Base for all builds
FROM node:20-alpine AS base-min
@@ -52,7 +52,7 @@
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features

- 🤖 **AI Model Selection**:
-  - Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
+  - Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure)
  - [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
  - Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
    - Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
@@ -66,10 +66,9 @@
- 🔦 **Agents & Tools Integration**:
  - **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
    - No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
-    - Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
-    - Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
+    - Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
+    - Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
  - [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
-  - Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions

- 🔍 **Web Search**:
  - Search the internet and retrieve relevant information to enhance your AI context
@@ -13,7 +13,6 @@ const {
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
const { checkBalance } = require('~/models/balanceMethods');
const { truncateToolCallOutputs } = require('./prompts');
-const { addSpaceIfNeeded } = require('~/server/utils');
const { getFiles } = require('~/models/File');
const TextStream = require('./TextStream');
const { logger } = require('~/config');
@@ -572,7 +571,7 @@ class BaseClient {
      });
    }

    const { generation = '' } = opts;
+   const { editedContent } = opts;

    // It's not necessary to push to currentMessages
    // depending on subclass implementation of handling messages
@@ -587,11 +586,21 @@ class BaseClient {
        isCreatedByUser: false,
        model: this.modelOptions?.model ?? this.model,
        sender: this.sender,
        text: generation,
      };
      this.currentMessages.push(userMessage, latestMessage);
-   } else {
-     latestMessage.text = generation;
+   } else if (editedContent != null) {
+     // Handle editedContent for content parts
+     if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
+       const { index, text, type } = editedContent;
+       if (index >= 0 && index < latestMessage.content.length) {
+         const contentPart = latestMessage.content[index];
+         if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
+           contentPart[ContentTypes.THINK] = text;
+         } else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
+           contentPart[ContentTypes.TEXT] = text;
+         }
+       }
+     }
    }
    this.continued = true;
  } else {
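The shape of the editedContent payload here is { index, text, type }. A minimal sketch of the replacement step, assuming ContentTypes.TEXT and ContentTypes.THINK resolve to the strings 'text' and 'think' (the real enum lives in librechat-data-provider):

// Sketch: replacing one content part in place via an editedContent payload.
const latestMessage = {
  content: [
    { type: 'think', think: 'original reasoning' },
    { type: 'text', text: 'original answer' },
  ],
};
const editedContent = { index: 1, type: 'text', text: 'edited answer' };
const { index, text, type } = editedContent;
const part = latestMessage.content[index];
if (part && part.type === type) {
  part[type] = text; // only the matching part changes; type mismatches are ignored
}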
@@ -672,16 +681,32 @@ class BaseClient {
    };

    if (typeof completion === 'string') {
-     responseMessage.text = addSpaceIfNeeded(generation) + completion;
+     responseMessage.text = completion;
    } else if (
      Array.isArray(completion) &&
      (this.clientName === EModelEndpoint.agents ||
        isParamEndpoint(this.options.endpoint, this.options.endpointType))
    ) {
      responseMessage.text = '';
-     responseMessage.content = completion;
+
+     if (!opts.editedContent || this.currentMessages.length === 0) {
+       responseMessage.content = completion;
+     } else {
+       const latestMessage = this.currentMessages[this.currentMessages.length - 1];
+       if (!latestMessage?.content) {
+         responseMessage.content = completion;
+       } else {
+         const existingContent = [...latestMessage.content];
+         const { type: editedType } = opts.editedContent;
+         responseMessage.content = this.mergeEditedContent(
+           existingContent,
+           completion,
+           editedType,
+         );
+       }
+     }
    } else if (Array.isArray(completion)) {
-     responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
+     responseMessage.text = completion.join('');
    }

    if (
@@ -1095,6 +1120,50 @@ class BaseClient {
    return numTokens;
  }

+ /**
+  * Merges completion content with existing content when editing TEXT or THINK types
+  * @param {Array} existingContent - The existing content array
+  * @param {Array} newCompletion - The new completion content
+  * @param {string} editedType - The type of content being edited
+  * @returns {Array} The merged content array
+  */
+ mergeEditedContent(existingContent, newCompletion, editedType) {
+   if (!newCompletion.length) {
+     return existingContent.concat(newCompletion);
+   }
+
+   if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
+     return existingContent.concat(newCompletion);
+   }
+
+   const lastIndex = existingContent.length - 1;
+   const lastExisting = existingContent[lastIndex];
+   const firstNew = newCompletion[0];
+
+   if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
+     return existingContent.concat(newCompletion);
+   }
+
+   const mergedContent = [...existingContent];
+   if (editedType === ContentTypes.TEXT) {
+     mergedContent[lastIndex] = {
+       ...mergedContent[lastIndex],
+       [ContentTypes.TEXT]:
+         (mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
+     };
+   } else {
+     mergedContent[lastIndex] = {
+       ...mergedContent[lastIndex],
+       [ContentTypes.THINK]:
+         (mergedContent[lastIndex][ContentTypes.THINK] || '') +
+         (firstNew[ContentTypes.THINK] || ''),
+     };
+   }
+
+   // Add remaining completion items
+   return mergedContent.concat(newCompletion.slice(1));
+ }
+
  async sendPayload(payload, opts = {}) {
    if (opts && typeof opts === 'object') {
      this.setOptions(opts);
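A worked example of the new merge behavior may help. This is a sketch, assuming ContentTypes.TEXT and ContentTypes.THINK are the strings 'text' and 'think', and `client` stands for any BaseClient instance:

// Sketch: when the edited type matches both the last existing part and the
// first completion part, the two are concatenated into one part.
const existing = [{ type: 'text', text: 'Hello' }];
const completion = [
  { type: 'text', text: ', world' },
  { type: 'think', think: 'afterthought' },
];
client.mergeEditedContent(existing, completion, 'text');
// => [{ type: 'text', text: 'Hello, world' }, { type: 'think', think: 'afterthought' }]
// Any type mismatch, or a non-TEXT/THINK editedType, falls back to
// existingContent.concat(newCompletion) with no merging.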
@@ -11,17 +11,25 @@ const { getFiles } = require('~/models/File');
 * @param {Object} options
 * @param {ServerRequest} options.req
 * @param {Agent['tool_resources']} options.tool_resources
+* @param {string} [options.agentId] - The agent ID for file access control
 * @returns {Promise<{
 *   files: Array<{ file_id: string; filename: string }>,
 *   toolContext: string
 * }>}
 */
const primeFiles = async (options) => {
- const { tool_resources } = options;
+ const { tool_resources, req, agentId } = options;
  const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
  const agentResourceIds = new Set(file_ids);
  const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
- const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
+ const dbFiles = (
+   (await getFiles(
+     { file_id: { $in: file_ids } },
+     null,
+     { text: 0 },
+     { userId: req?.user?.id, agentId },
+   )) ?? []
+ ).concat(resourceFiles);

  let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;
||||
@@ -245,7 +245,13 @@ const loadTools = async ({
        authFields: [EnvVar.CODE_API_KEY],
      });
      const codeApiKey = authValues[EnvVar.CODE_API_KEY];
-     const { files, toolContext } = await primeCodeFiles(options, codeApiKey);
+     const { files, toolContext } = await primeCodeFiles(
+       {
+         ...options,
+         agentId: agent?.id,
+       },
+       codeApiKey,
+     );
      if (toolContext) {
        toolContextMap[tool] = toolContext;
      }
@@ -260,7 +266,10 @@ const loadTools = async ({
      continue;
    } else if (tool === Tools.file_search) {
      requestedTools[tool] = async () => {
-       const { files, toolContext } = await primeSearchFiles(options);
+       const { files, toolContext } = await primeSearchFiles({
+         ...options,
+         agentId: agent?.id,
+       });
        if (toolContext) {
          toolContextMap[tool] = toolContext;
        }
2 api/cache/logViolation.js vendored
@@ -9,7 +9,7 @@ const banViolation = require('./banViolation');
 * @param {Object} res - Express response object.
 * @param {string} type - The type of violation.
 * @param {Object} errorMessage - The error message to log.
-* @param {number} [score=1] - The severity of the violation. Defaults to 1
+* @param {number | string} [score=1] - The severity of the violation. Defaults to 1
 */
const logViolation = async (req, res, type, errorMessage, score = 1) => {
  const userId = req.user?.id ?? req.user?._id;
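The widened number | string annotation matches how scores typically arrive: read straight from the environment, where every value is a string. A hypothetical call site (the 'tts' violation type and env lookup are illustrative, not taken from this diff):

// Hypothetical call site: process.env values are strings, so score may be '0'.
const score = process.env.TTS_VIOLATION_SCORE ?? 1;
await logViolation(req, res, 'tts', { type: 'tts' }, score);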
@@ -90,7 +90,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
  }

  const instructions = req.body.promptPrefix;
- return {
+ const result = {
    id: agent_id,
    instructions,
    provider: endpoint,
@@ -98,6 +98,11 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
    model,
    tools,
  };
+
+ if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
+   result.artifacts = ephemeralAgent.artifacts;
+ }
+ return result;
};

/**
@@ -1,6 +1,6 @@
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
-const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
+const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getMessages, deleteMessages } = require('./Message');
const { Conversation } = require('~/db/models');
@@ -1,5 +1,7 @@
const { logger } = require('@librechat/data-schemas');
-const { EToolResources, FileContext } = require('librechat-data-provider');
+const { EToolResources, FileContext, Constants } = require('librechat-data-provider');
+const { getProjectByName } = require('./Project');
+const { getAgent } = require('./Agent');
const { File } = require('~/db/models');

/**
@@ -12,17 +14,119 @@ const findFileById = async (file_id, options = {}) => {
  return await File.findOne({ file_id, ...options }).lean();
};

+/**
+ * Checks if a user has access to multiple files through a shared agent (batch operation)
+ * @param {string} userId - The user ID to check access for
+ * @param {string[]} fileIds - Array of file IDs to check
+ * @param {string} agentId - The agent ID that might grant access
+ * @returns {Promise<Map<string, boolean>>} Map of fileId to access status
+ */
+const hasAccessToFilesViaAgent = async (userId, fileIds, agentId) => {
+  const accessMap = new Map();
+
+  // Initialize all files as no access
+  fileIds.forEach((fileId) => accessMap.set(fileId, false));
+
+  try {
+    const agent = await getAgent({ id: agentId });
+
+    if (!agent) {
+      return accessMap;
+    }
+
+    // Check if user is the author - if so, grant access to all files
+    if (agent.author.toString() === userId) {
+      fileIds.forEach((fileId) => accessMap.set(fileId, true));
+      return accessMap;
+    }
+
+    // Check if agent is shared with the user via projects
+    if (!agent.projectIds || agent.projectIds.length === 0) {
+      return accessMap;
+    }
+
+    // Check if agent is in global project
+    const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
+    if (
+      !globalProject ||
+      !agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
+    ) {
+      return accessMap;
+    }
+
+    // Agent is globally shared - check if it's collaborative
+    if (!agent.isCollaborative) {
+      return accessMap;
+    }
+
+    // Agent is globally shared and collaborative - check which files are actually attached
+    const attachedFileIds = new Set();
+    if (agent.tool_resources) {
+      for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
+        if (resource?.file_ids && Array.isArray(resource.file_ids)) {
+          resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
+        }
+      }
+    }
+
+    // Grant access only to files that are attached to this agent
+    fileIds.forEach((fileId) => {
+      if (attachedFileIds.has(fileId)) {
+        accessMap.set(fileId, true);
+      }
+    });
+
+    return accessMap;
+  } catch (error) {
+    logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
+    return accessMap;
+  }
+};
+
/**
 * Retrieves files matching a given filter, sorted by the most recently updated.
 * @param {Object} filter - The filter criteria to apply.
 * @param {Object} [_sortOptions] - Optional sort parameters.
 * @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
 *   Default excludes the 'text' field.
+ * @param {Object} [options] - Additional options
+ * @param {string} [options.userId] - User ID for access control
+ * @param {string} [options.agentId] - Agent ID that might grant access to files
 * @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
 */
-const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
+const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => {
  const sortOptions = { updatedAt: -1, ..._sortOptions };
- return await File.find(filter).select(selectFields).sort(sortOptions).lean();
+ const files = await File.find(filter).select(selectFields).sort(sortOptions).lean();
+
+ // If userId and agentId are provided, filter files based on access
+ if (options.userId && options.agentId) {
+   // Collect file IDs that need access check
+   const filesToCheck = [];
+   const ownedFiles = [];
+
+   for (const file of files) {
+     if (file.user && file.user.toString() === options.userId) {
+       ownedFiles.push(file);
+     } else {
+       filesToCheck.push(file);
+     }
+   }
+
+   if (filesToCheck.length === 0) {
+     return ownedFiles;
+   }
+
+   // Batch check access for all non-owned files
+   const fileIds = filesToCheck.map((f) => f.file_id);
+   const accessMap = await hasAccessToFilesViaAgent(options.userId, fileIds, options.agentId);
+
+   // Filter files based on access
+   const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
+
+   return [...ownedFiles, ...accessibleFiles];
+ }
+
+ return files;
};

/**
@@ -176,4 +280,5 @@ module.exports = {
  deleteFiles,
  deleteFileByFilter,
  batchUpdateFiles,
+ hasAccessToFilesViaAgent,
};
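A usage sketch of the extended getFiles signature, mirroring the call added in the file_search tool above: when both userId and agentId are supplied, non-owned files survive the filter only if the agent is globally shared, collaborative, and actually has the file attached.

// Sketch: fetch files visible to a user in the context of a shared agent.
const visibleFiles = await getFiles(
  { file_id: { $in: fileIds } },              // filter
  null,                                       // default sort ({ updatedAt: -1 })
  { text: 0 },                                // omit the heavy 'text' field
  { userId: req.user.id, agentId: agent.id }, // enables per-agent access control
);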
264 api/models/File.spec.js Normal file
@@ -0,0 +1,264 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { fileSchema } = require('@librechat/data-schemas');
const { agentSchema } = require('@librechat/data-schemas');
const { projectSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { getFiles, createFile } = require('./File');
const { getProjectByName } = require('./Project');
const { createAgent } = require('./Agent');

let File;
let Agent;
let Project;

describe('File Access Control', () => {
  let mongoServer;

  beforeAll(async () => {
    mongoServer = await MongoMemoryServer.create();
    const mongoUri = mongoServer.getUri();
    File = mongoose.models.File || mongoose.model('File', fileSchema);
    Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
    Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
    await mongoose.connect(mongoUri);
  });

  afterAll(async () => {
    await mongoose.disconnect();
    await mongoServer.stop();
  });

  beforeEach(async () => {
    await File.deleteMany({});
    await Agent.deleteMany({});
    await Project.deleteMany({});
  });

  describe('hasAccessToFilesViaAgent', () => {
    it('should efficiently check access for multiple files at once', async () => {
      const userId = new mongoose.Types.ObjectId().toString();
      const authorId = new mongoose.Types.ObjectId().toString();
      const agentId = uuidv4();
      const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];

      // Create files
      for (const fileId of fileIds) {
        await createFile({
          user: authorId,
          file_id: fileId,
          filename: `file-${fileId}.txt`,
          filepath: `/uploads/${fileId}`,
        });
      }

      // Create agent with only first two files attached
      await createAgent({
        id: agentId,
        name: 'Test Agent',
        author: authorId,
        model: 'gpt-4',
        provider: 'openai',
        isCollaborative: true,
        tool_resources: {
          file_search: {
            file_ids: [fileIds[0], fileIds[1]],
          },
        },
      });

      // Get or create global project
      const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');

      // Share agent globally
      await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });

      // Check access for all files
      const { hasAccessToFilesViaAgent } = require('./File');
      const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);

      // Should have access only to the first two files
      expect(accessMap.get(fileIds[0])).toBe(true);
      expect(accessMap.get(fileIds[1])).toBe(true);
      expect(accessMap.get(fileIds[2])).toBe(false);
      expect(accessMap.get(fileIds[3])).toBe(false);
    });

    it('should grant access to all files when user is the agent author', async () => {
      const authorId = new mongoose.Types.ObjectId().toString();
      const agentId = uuidv4();
      const fileIds = [uuidv4(), uuidv4(), uuidv4()];

      // Create agent
      await createAgent({
        id: agentId,
        name: 'Test Agent',
        author: authorId,
        model: 'gpt-4',
        provider: 'openai',
        tool_resources: {
          file_search: {
            file_ids: [fileIds[0]], // Only one file attached
          },
        },
      });

      // Check access as the author
      const { hasAccessToFilesViaAgent } = require('./File');
      const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);

      // Author should have access to all files
      expect(accessMap.get(fileIds[0])).toBe(true);
      expect(accessMap.get(fileIds[1])).toBe(true);
      expect(accessMap.get(fileIds[2])).toBe(true);
    });

    it('should handle non-existent agent gracefully', async () => {
      const userId = new mongoose.Types.ObjectId().toString();
      const fileIds = [uuidv4(), uuidv4()];

      const { hasAccessToFilesViaAgent } = require('./File');
      const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');

      // Should have no access to any files
      expect(accessMap.get(fileIds[0])).toBe(false);
      expect(accessMap.get(fileIds[1])).toBe(false);
    });

    it('should deny access when agent is not collaborative', async () => {
      const userId = new mongoose.Types.ObjectId().toString();
      const authorId = new mongoose.Types.ObjectId().toString();
      const agentId = uuidv4();
      const fileIds = [uuidv4(), uuidv4()];

      // Create agent with files but isCollaborative: false
      await createAgent({
        id: agentId,
        name: 'Non-Collaborative Agent',
        author: authorId,
        model: 'gpt-4',
        provider: 'openai',
        isCollaborative: false,
        tool_resources: {
          file_search: {
            file_ids: fileIds,
          },
        },
      });

      // Get or create global project
      const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');

      // Share agent globally
      await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });

      // Check access for files
      const { hasAccessToFilesViaAgent } = require('./File');
      const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);

      // Should have no access to any files when isCollaborative is false
      expect(accessMap.get(fileIds[0])).toBe(false);
      expect(accessMap.get(fileIds[1])).toBe(false);
    });
  });

  describe('getFiles with agent access control', () => {
    test('should return files owned by user and files accessible through agent', async () => {
      const authorId = new mongoose.Types.ObjectId();
      const userId = new mongoose.Types.ObjectId();
      const agentId = `agent_${uuidv4()}`;
      const ownedFileId = `file_${uuidv4()}`;
      const sharedFileId = `file_${uuidv4()}`;
      const inaccessibleFileId = `file_${uuidv4()}`;

      // Create/get global project using getProjectByName which will upsert
      const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);

      // Create agent with shared file
      await createAgent({
        id: agentId,
        name: 'Shared Agent',
        provider: 'test',
        model: 'test-model',
        author: authorId,
        projectIds: [globalProject._id],
        isCollaborative: true,
        tool_resources: {
          file_search: {
            file_ids: [sharedFileId],
          },
        },
      });

      // Create files
      await createFile({
        file_id: ownedFileId,
        user: userId,
        filename: 'owned.txt',
        filepath: '/uploads/owned.txt',
        type: 'text/plain',
        bytes: 100,
      });

      await createFile({
        file_id: sharedFileId,
        user: authorId,
        filename: 'shared.txt',
        filepath: '/uploads/shared.txt',
        type: 'text/plain',
        bytes: 200,
        embedded: true,
      });

      await createFile({
        file_id: inaccessibleFileId,
        user: authorId,
        filename: 'inaccessible.txt',
        filepath: '/uploads/inaccessible.txt',
        type: 'text/plain',
        bytes: 300,
      });

      // Get files with access control
      const files = await getFiles(
        { file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
        null,
        { text: 0 },
        { userId: userId.toString(), agentId },
      );

      expect(files).toHaveLength(2);
      expect(files.map((f) => f.file_id)).toContain(ownedFileId);
      expect(files.map((f) => f.file_id)).toContain(sharedFileId);
      expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
    });

    test('should return all files when no userId/agentId provided', async () => {
      const userId = new mongoose.Types.ObjectId();
      const fileId1 = `file_${uuidv4()}`;
      const fileId2 = `file_${uuidv4()}`;

      await createFile({
        file_id: fileId1,
        user: userId,
        filename: 'file1.txt',
        filepath: '/uploads/file1.txt',
        type: 'text/plain',
        bytes: 100,
      });

      await createFile({
        file_id: fileId2,
        user: new mongoose.Types.ObjectId(),
        filename: 'file2.txt',
        filepath: '/uploads/file2.txt',
        type: 'text/plain',
        bytes: 200,
      });

      const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
      expect(files).toHaveLength(2);
    });
  });
});
@@ -1,7 +1,7 @@
const { z } = require('zod');
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
-const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
+const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { Message } = require('~/db/models');

const idSchema = z.string().uuid();
@@ -135,10 +135,11 @@ const tokenValues = Object.assign(
  'grok-2-1212': { prompt: 2.0, completion: 10.0 },
  'grok-2-latest': { prompt: 2.0, completion: 10.0 },
  'grok-2': { prompt: 2.0, completion: 10.0 },
- 'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
+ 'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
  'grok-3-mini': { prompt: 0.3, completion: 0.5 },
  'grok-3-fast': { prompt: 5.0, completion: 25.0 },
  'grok-3': { prompt: 3.0, completion: 15.0 },
+ 'grok-4': { prompt: 3.0, completion: 15.0 },
  'grok-beta': { prompt: 5.0, completion: 15.0 },
  'mistral-large': { prompt: 2.0, completion: 6.0 },
  'pixtral-large': { prompt: 2.0, completion: 6.0 },
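Assuming these multipliers are USD per million tokens, as with the other entries in tx.js, the new grok-4 entry prices a request like this:

// Sketch: cost of a grok-4 call at prompt 3.0 / completion 15.0 per 1M tokens.
const promptTokens = 10_000;
const completionTokens = 2_000;
const cost = (promptTokens / 1e6) * 3.0 + (completionTokens / 1e6) * 15.0;
// => 0.03 + 0.03 = 0.06 USD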
@@ -636,6 +636,15 @@ describe('Grok Model Tests - Pricing', () => {
      );
    });

+   test('should return correct prompt and completion rates for Grok 4 model', () => {
+     expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'prompt' })).toBe(
+       tokenValues['grok-4'].prompt,
+     );
+     expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'completion' })).toBe(
+       tokenValues['grok-4'].completion,
+     );
+   });
+
    test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
      expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
        tokenValues['grok-3'].prompt,
@@ -662,6 +671,15 @@ describe('Grok Model Tests - Pricing', () => {
        tokenValues['grok-3-mini-fast'].completion,
      );
    });

+   test('should return correct prompt and completion rates for Grok 4 model with prefixes', () => {
+     expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'prompt' })).toBe(
+       tokenValues['grok-4'].prompt,
+     );
+     expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'completion' })).toBe(
+       tokenValues['grok-4'].completion,
+     );
+   });
  });
});
@@ -1,6 +1,6 @@
{
  "name": "@librechat/backend",
- "version": "v0.7.8",
+ "version": "v0.7.9-rc1",
  "description": "",
  "scripts": {
    "start": "echo 'please run this from the root directory'",
@@ -48,7 +48,7 @@
    "@langchain/google-genai": "^0.2.13",
    "@langchain/google-vertexai": "^0.2.13",
    "@langchain/textsplitters": "^0.1.0",
-   "@librechat/agents": "^2.4.46",
+   "@librechat/agents": "^2.4.59",
    "@librechat/api": "*",
    "@librechat/data-schemas": "*",
    "@node-saml/passport-saml": "^5.0.0",
@@ -24,17 +24,23 @@ const handleValidationError = (err, res) => {
  }
};

-// eslint-disable-next-line no-unused-vars
-module.exports = (err, req, res, next) => {
+module.exports = (err, _req, res, _next) => {
  try {
    if (err.name === 'ValidationError') {
-     return (err = handleValidationError(err, res));
+     return handleValidationError(err, res);
    }
    if (err.code && err.code == 11000) {
-     return (err = handleDuplicateKeyError(err, res));
+     return handleDuplicateKeyError(err, res);
    }
- } catch (err) {
    // Special handling for errors like SyntaxError
    if (err.statusCode && err.body) {
      return res.status(err.statusCode).send(err.body);
    }

    logger.error('ErrorController => error', err);
-   res.status(500).send('An unknown error occurred.');
+   return res.status(500).send('An unknown error occurred.');
+ } catch (err) {
+   logger.error('ErrorController => processing error', err);
+   return res.status(500).send('Processing error in ErrorController.');
  }
};
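Since every branch now returns a response, the controller can safely sit as the final middleware in the chain. A minimal wiring sketch:

// Sketch: mounting the controller as the last Express error handler.
const express = require('express');
const errorController = require('./ErrorController');

const app = express();
// ...routes and other middleware...
app.use(errorController); // the four-argument signature marks it as an error handler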
241 api/server/controllers/ErrorController.spec.js Normal file
@@ -0,0 +1,241 @@
const errorController = require('./ErrorController');
const { logger } = require('~/config');

// Mock the logger
jest.mock('~/config', () => ({
  logger: {
    error: jest.fn(),
  },
}));

describe('ErrorController', () => {
  let mockReq, mockRes, mockNext;

  beforeEach(() => {
    mockReq = {};
    mockRes = {
      status: jest.fn().mockReturnThis(),
      send: jest.fn(),
    };
    mockNext = jest.fn();
    logger.error.mockClear();
  });

  describe('ValidationError handling', () => {
    it('should handle ValidationError with single error', () => {
      const validationError = {
        name: 'ValidationError',
        errors: {
          email: { message: 'Email is required', path: 'email' },
        },
      };

      errorController(validationError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(400);
      expect(mockRes.send).toHaveBeenCalledWith({
        messages: '["Email is required"]',
        fields: '["email"]',
      });
      expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
    });

    it('should handle ValidationError with multiple errors', () => {
      const validationError = {
        name: 'ValidationError',
        errors: {
          email: { message: 'Email is required', path: 'email' },
          password: { message: 'Password is required', path: 'password' },
        },
      };

      errorController(validationError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(400);
      expect(mockRes.send).toHaveBeenCalledWith({
        messages: '"Email is required Password is required"',
        fields: '["email","password"]',
      });
      expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
    });

    it('should handle ValidationError with empty errors object', () => {
      const validationError = {
        name: 'ValidationError',
        errors: {},
      };

      errorController(validationError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(400);
      expect(mockRes.send).toHaveBeenCalledWith({
        messages: '[]',
        fields: '[]',
      });
    });
  });

  describe('Duplicate key error handling', () => {
    it('should handle duplicate key error (code 11000)', () => {
      const duplicateKeyError = {
        code: 11000,
        keyValue: { email: 'test@example.com' },
      };

      errorController(duplicateKeyError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(409);
      expect(mockRes.send).toHaveBeenCalledWith({
        messages: 'An document with that ["email"] already exists.',
        fields: '["email"]',
      });
      expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
    });

    it('should handle duplicate key error with multiple fields', () => {
      const duplicateKeyError = {
        code: 11000,
        keyValue: { email: 'test@example.com', username: 'testuser' },
      };

      errorController(duplicateKeyError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(409);
      expect(mockRes.send).toHaveBeenCalledWith({
        messages: 'An document with that ["email","username"] already exists.',
        fields: '["email","username"]',
      });
      expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
    });

    it('should handle error with code 11000 as string', () => {
      const duplicateKeyError = {
        code: '11000',
        keyValue: { email: 'test@example.com' },
      };

      errorController(duplicateKeyError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(409);
      expect(mockRes.send).toHaveBeenCalledWith({
        messages: 'An document with that ["email"] already exists.',
        fields: '["email"]',
      });
    });
  });

  describe('SyntaxError handling', () => {
    it('should handle errors with statusCode and body', () => {
      const syntaxError = {
        statusCode: 400,
        body: 'Invalid JSON syntax',
      };

      errorController(syntaxError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(400);
      expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
    });

    it('should handle errors with different statusCode and body', () => {
      const customError = {
        statusCode: 422,
        body: { error: 'Unprocessable entity' },
      };

      errorController(customError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(422);
      expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
    });

    it('should handle error with statusCode but no body', () => {
      const partialError = {
        statusCode: 400,
      };

      errorController(partialError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(500);
      expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
    });

    it('should handle error with body but no statusCode', () => {
      const partialError = {
        body: 'Some error message',
      };

      errorController(partialError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(500);
      expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
    });
  });

  describe('Unknown error handling', () => {
    it('should handle unknown errors', () => {
      const unknownError = new Error('Some unknown error');

      errorController(unknownError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(500);
      expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
      expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError);
    });

    it('should handle errors with code other than 11000', () => {
      const mongoError = {
        code: 11100,
        message: 'Some MongoDB error',
      };

      errorController(mongoError, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(500);
      expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
      expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError);
    });

    it('should handle null/undefined errors', () => {
      errorController(null, mockReq, mockRes, mockNext);

      expect(mockRes.status).toHaveBeenCalledWith(500);
      expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
      expect(logger.error).toHaveBeenCalledWith(
        'ErrorController => processing error',
        expect.any(Error),
      );
    });
  });

  describe('Catch block handling', () => {
    beforeEach(() => {
      // Restore logger mock to normal behavior for these tests
      logger.error.mockRestore();
      logger.error = jest.fn();
    });

    it('should handle errors when logger.error throws', () => {
      // Create fresh mocks for this test
      const freshMockRes = {
        status: jest.fn().mockReturnThis(),
        send: jest.fn(),
      };

      // Mock logger to throw on the first call, succeed on the second
      logger.error
        .mockImplementationOnce(() => {
          throw new Error('Logger error');
        })
        .mockImplementation(() => {});

      const testError = new Error('Test error');

      errorController(testError, mockReq, freshMockRes, mockNext);

      expect(freshMockRes.status).toHaveBeenCalledWith(500);
      expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
      expect(logger.error).toHaveBeenCalledTimes(2);
    });
  });
});
@@ -1,11 +1,10 @@
const { logger } = require('@librechat/data-schemas');
-const { CacheKeys, AuthType } = require('librechat-data-provider');
+const { CacheKeys, AuthType, Constants } = require('librechat-data-provider');
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
const { getToolkitKey } = require('~/server/services/ToolService');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { availableTools } = require('~/app/clients/tools');
const { getLogStores } = require('~/cache');
-const { Constants } = require('librechat-data-provider');

/**
 * Filters out duplicate plugins from the list of plugins.
@@ -140,9 +139,9 @@ function createGetServerTools() {
const getAvailableTools = async (req, res) => {
  try {
    const cache = getLogStores(CacheKeys.CONFIG_STORE);
-   const cachedTools = await cache.get(CacheKeys.TOOLS);
-   if (cachedTools) {
-     res.status(200).json(cachedTools);
+   const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
+   if (cachedToolsArray) {
+     res.status(200).json(cachedToolsArray);
      return;
    }

@@ -173,7 +172,7 @@ const getAvailableTools = async (req, res) => {
    }
  });

- const toolDefinitions = await getCachedTools({ includeGlobal: true });
+ const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {};

  const toolsOutput = [];
  for (const plugin of authenticatedPlugins) {
@@ -1,5 +1,7 @@
require('events').EventEmitter.defaultMaxListeners = 100;
const { logger } = require('@librechat/data-schemas');
+const { DynamicStructuredTool } = require('@langchain/core/tools');
+const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const {
  sendEvent,
  createRun,
@@ -31,13 +33,16 @@ const {
  bedrockInputSchema,
  removeNullishValues,
} = require('librechat-data-provider');
-const { DynamicStructuredTool } = require('@langchain/core/tools');
-const { getBufferString, HumanMessage } = require('@langchain/core/messages');
-const { createGetMCPAuthMap, checkCapability } = require('~/server/services/Config');
+const {
+  findPluginAuthsByKeys,
+  getFormattedMemories,
+  deleteMemory,
+  setMemory,
+} = require('~/models');
+const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
-const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { getProviderConfig } = require('~/server/services/Endpoints');
const BaseClient = require('~/app/clients/BaseClient');
@@ -54,6 +59,7 @@ const omitTitleOptions = new Set([
  'thinkingBudget',
  'includeThoughts',
  'maxOutputTokens',
+ 'additionalModelRequestFields',
]);

/**
@@ -525,7 +531,10 @@ class AgentClient extends BaseClient {
        messagesToProcess = [...messages.slice(-messageWindowSize)];
      }
    }
-   return await this.processMemory(messagesToProcess);
+
+   const bufferString = getBufferString(messagesToProcess);
+   const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
+   return await this.processMemory([bufferMessage]);
  } catch (error) {
    logger.error('Memory Agent failed to process memory', error);
  }
@@ -697,8 +706,6 @@ class AgentClient extends BaseClient {
      version: 'v2',
    };

-   const getUserMCPAuthMap = await createGetMCPAuthMap();
-
    const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
    let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
      payload,
@@ -819,10 +826,11 @@ class AgentClient extends BaseClient {
    }

    try {
-     if (getUserMCPAuthMap) {
-       config.configurable.userMCPAuthMap = await getUserMCPAuthMap({
+     if (await hasCustomUserVars()) {
+       config.configurable.userMCPAuthMap = await getMCPAuthMap({
          tools: agent.tools,
          userId: this.options.req.user.id,
+         findPluginAuthsByKeys,
        });
      }
    } catch (err) {
@@ -1040,6 +1048,12 @@ class AgentClient extends BaseClient {
      options.llmConfig?.azureOpenAIApiInstanceName == null
    ) {
      provider = Providers.OPENAI;
+   } else if (
+     endpoint === EModelEndpoint.azureOpenAI &&
+     options.llmConfig?.azureOpenAIApiInstanceName != null &&
+     provider !== Providers.AZURE
+   ) {
+     provider = Providers.AZURE;
    }

    /** @type {import('@librechat/agents').ClientOptions} */
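getBufferString from @langchain/core/messages flattens a message array into a "Human: … / AI: …" transcript, so the memory agent now receives the whole window as a single HumanMessage rather than the raw message array. A rough sketch of the transformation:

// Sketch: what processMemory now receives (the transcript format is approximate).
const { getBufferString, HumanMessage, AIMessage } = require('@langchain/core/messages');

const bufferString = getBufferString([
  new HumanMessage('What is LibreChat?'),
  new AIMessage('An open-source chat UI.'),
]);
// bufferString ≈ 'Human: What is LibreChat?\nAI: An open-source chat UI.'
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);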
@@ -14,8 +14,11 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
    text,
    endpointOption,
    conversationId,
+   isContinued = false,
+   editedContent = null,
    parentMessageId = null,
    overrideParentMessageId = null,
+   responseMessageId: editedResponseMessageId = null,
  } = req.body;

  let sender;
@@ -67,7 +70,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
          handler();
        }
      } catch (e) {
-       // Ignore cleanup errors
+       logger.error('[AgentController] Error in cleanup handler', e);
      }
    }
  }
@@ -155,7 +158,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
    try {
      res.removeListener('close', closeHandler);
    } catch (e) {
-     // Ignore
+     logger.error('[AgentController] Error removing close listener', e);
    }
  });

@@ -163,10 +166,14 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
    user: userId,
    onStart,
    getReqData,
+   isContinued,
+   editedContent,
    conversationId,
    parentMessageId,
    abortController,
    overrideParentMessageId,
+   isEdited: !!editedContent,
+   responseMessageId: editedResponseMessageId,
    progressOptions: {
      res,
    },
@@ -1,6 +1,8 @@
+const { z } = require('zod');
const fs = require('fs').promises;
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
+const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
const {
  Tools,
  Constants,
@@ -8,6 +10,7 @@ const {
  SystemRoles,
  EToolResources,
  actionDelimiter,
+ removeNullishValues,
} = require('librechat-data-provider');
const {
  getAgent,
@@ -30,6 +33,7 @@ const { deleteFileByFilter } = require('~/models/File');
const systemTools = {
  [Tools.execute_code]: true,
  [Tools.file_search]: true,
+ [Tools.web_search]: true,
};

/**
@@ -42,9 +46,13 @@ const systemTools = {
 */
const createAgentHandler = async (req, res) => {
  try {
-   const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body;
+   const validatedData = agentCreateSchema.parse(req.body);
+   const { tools = [], ...agentData } = removeNullishValues(validatedData);
+
    const { id: userId } = req.user;

+   agentData.id = `agent_${nanoid()}`;
+   agentData.author = userId;
    agentData.tools = [];

    const availableTools = await getCachedTools({ includeGlobal: true });
@@ -58,19 +66,13 @@ const createAgentHandler = async (req, res) => {
      }
    }

-   Object.assign(agentData, {
-     author: userId,
-     name,
-     description,
-     instructions,
-     provider,
-     model,
-   });
-
-   agentData.id = `agent_${nanoid()}`;
    const agent = await createAgent(agentData);
    res.status(201).json(agent);
  } catch (error) {
+   if (error instanceof z.ZodError) {
+     logger.error('[/Agents] Validation error', error.errors);
+     return res.status(400).json({ error: 'Invalid request data', details: error.errors });
+   }
    logger.error('[/Agents] Error creating agent', error);
    res.status(500).json({ error: error.message });
  }
@@ -154,14 +156,16 @@ const getAgentHandler = async (req, res) => {
const updateAgentHandler = async (req, res) => {
  try {
    const id = req.params.id;
-   const { projectIds, removeProjectIds, ...updateData } = req.body;
+   const validatedData = agentUpdateSchema.parse(req.body);
+   const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
    const isAdmin = req.user.role === SystemRoles.ADMIN;
    const existingAgent = await getAgent({ id });
-   const isAuthor = existingAgent.author.toString() === req.user.id;

    if (!existingAgent) {
      return res.status(404).json({ error: 'Agent not found' });
    }

+   const isAuthor = existingAgent.author.toString() === req.user.id;
    const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;

    if (!hasEditPermission) {
@@ -200,6 +204,11 @@ const updateAgentHandler = async (req, res) => {

    return res.json(updatedAgent);
  } catch (error) {
+   if (error instanceof z.ZodError) {
+     logger.error('[/Agents/:id] Validation error', error.errors);
+     return res.status(400).json({ error: 'Invalid request data', details: error.errors });
+   }
+
    logger.error('[/Agents/:id] Error updating Agent', error);

    if (error.statusCode === 409) {
@@ -382,6 +391,22 @@ const uploadAgentAvatarHandler = async (req, res) => {
      return res.status(400).json({ message: 'Agent ID is required' });
    }

+   const isAdmin = req.user.role === SystemRoles.ADMIN;
+   const existingAgent = await getAgent({ id: agent_id });
+
+   if (!existingAgent) {
+     return res.status(404).json({ error: 'Agent not found' });
+   }
+
+   const isAuthor = existingAgent.author.toString() === req.user.id;
+   const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
+
+   if (!hasEditPermission) {
+     return res.status(403).json({
+       error: 'You do not have permission to modify this non-collaborative agent',
+     });
+   }
+
    const buffer = await fs.readFile(req.file.path);

    const fileStrategy = req.app.locals.fileStrategy;
@@ -404,14 +429,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
      source: fileStrategy,
    };

-   let _avatar;
-   try {
-     const agent = await getAgent({ id: agent_id });
-     _avatar = agent.avatar;
-   } catch (error) {
-     logger.error('[/:agent_id/avatar] Error fetching agent', error);
-     _avatar = {};
-   }
+   let _avatar = existingAgent.avatar;

    if (_avatar && _avatar.source) {
      const { deleteFile } = getStrategyFunctions(_avatar.source);
@@ -433,7 +451,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
    };

    promises.push(
-     await updateAgent({ id: agent_id, author: req.user.id }, data, {
+     await updateAgent({ id: agent_id }, data, {
        updatingUserId: req.user.id,
      }),
    );
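The mass-assignment protection rests on zod's default behavior of stripping unknown object keys during parse. A hypothetical schema in the same spirit (the real agentCreateSchema lives in @librechat/api and is not shown in this diff):

// Hypothetical schema illustrating the stripping pattern; not the actual agentCreateSchema.
const { z } = require('zod');

const exampleCreateSchema = z.object({
  name: z.string().optional(),
  provider: z.string(),
  model: z.string(),
  tools: z.array(z.string()).optional(),
});

const parsed = exampleCreateSchema.parse({
  provider: 'openai',
  model: 'gpt-4',
  author: 'attacker-controlled-id', // dropped: not declared in the schema
});
// parsed => { provider: 'openai', model: 'gpt-4' } and 'author' never reaches the model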
659 api/server/controllers/agents/v1.spec.js Normal file
@@ -0,0 +1,659 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { agentSchema } = require('@librechat/data-schemas');

// Only mock the dependencies that are not database-related
jest.mock('~/server/services/Config', () => ({
  getCachedTools: jest.fn().mockResolvedValue({
    web_search: true,
    execute_code: true,
    file_search: true,
  }),
}));

jest.mock('~/models/Project', () => ({
  getProjectByName: jest.fn().mockResolvedValue(null),
}));

jest.mock('~/server/services/Files/strategies', () => ({
  getStrategyFunctions: jest.fn(),
}));

jest.mock('~/server/services/Files/images/avatar', () => ({
  resizeAvatar: jest.fn(),
}));

jest.mock('~/server/services/Files/S3/crud', () => ({
  refreshS3Url: jest.fn(),
}));

jest.mock('~/server/services/Files/process', () => ({
  filterFile: jest.fn(),
}));

jest.mock('~/models/Action', () => ({
  updateAction: jest.fn(),
  getActions: jest.fn().mockResolvedValue([]),
}));

jest.mock('~/models/File', () => ({
  deleteFileByFilter: jest.fn(),
}));

const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');

/**
 * @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
 */
let Agent;

describe('Agent Controllers - Mass Assignment Protection', () => {
  let mongoServer;
  let mockReq;
  let mockRes;

  beforeAll(async () => {
    mongoServer = await MongoMemoryServer.create();
    const mongoUri = mongoServer.getUri();
    await mongoose.connect(mongoUri);
    Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
  }, 20000);

  afterAll(async () => {
    await mongoose.disconnect();
    await mongoServer.stop();
  });

  beforeEach(async () => {
    await Agent.deleteMany({});

    // Reset all mocks
    jest.clearAllMocks();

    // Setup mock request and response objects
    mockReq = {
      user: {
        id: new mongoose.Types.ObjectId().toString(),
        role: 'USER',
      },
      body: {},
      params: {},
      app: {
        locals: {
          fileStrategy: 'local',
        },
      },
    };

    mockRes = {
      status: jest.fn().mockReturnThis(),
      json: jest.fn().mockReturnThis(),
    };
  });

  describe('createAgentHandler', () => {
    test('should create agent with allowed fields only', async () => {
      const validData = {
        name: 'Test Agent',
        description: 'A test agent',
        instructions: 'Be helpful',
        provider: 'openai',
        model: 'gpt-4',
        tools: ['web_search'],
        model_parameters: { temperature: 0.7 },
        tool_resources: {
          file_search: { file_ids: ['file1', 'file2'] },
        },
      };

      mockReq.body = validData;

      await createAgentHandler(mockReq, mockRes);

      expect(mockRes.status).toHaveBeenCalledWith(201);
      expect(mockRes.json).toHaveBeenCalled();

      const createdAgent = mockRes.json.mock.calls[0][0];
      expect(createdAgent.name).toBe('Test Agent');
      expect(createdAgent.description).toBe('A test agent');
      expect(createdAgent.provider).toBe('openai');
      expect(createdAgent.model).toBe('gpt-4');
      expect(createdAgent.author.toString()).toBe(mockReq.user.id);
      expect(createdAgent.tools).toContain('web_search');

      // Verify in database
      const agentInDb = await Agent.findOne({ id: createdAgent.id });
      expect(agentInDb).toBeDefined();
      expect(agentInDb.name).toBe('Test Agent');
      expect(agentInDb.author.toString()).toBe(mockReq.user.id);
    });

    test('should reject creation with unauthorized fields (mass assignment protection)', async () => {
      const maliciousData = {
        // Required fields
        provider: 'openai',
        model: 'gpt-4',
        name: 'Malicious Agent',

        // Unauthorized fields that should be stripped
        author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
        authorName: 'Hacker', // Should be stripped
        isCollaborative: true, // Should be stripped on creation
        versions: [], // Should be stripped
        _id: new mongoose.Types.ObjectId(), // Should be stripped
        id: 'custom_agent_id', // Should be overridden
        createdAt: new Date('2020-01-01'), // Should be stripped
        updatedAt: new Date('2020-01-01'), // Should be stripped
      };

      mockReq.body = maliciousData;

      await createAgentHandler(mockReq, mockRes);

      expect(mockRes.status).toHaveBeenCalledWith(201);

      const createdAgent = mockRes.json.mock.calls[0][0];

      // Verify unauthorized fields were not set
      expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
      expect(createdAgent.authorName).toBeUndefined();
      expect(createdAgent.isCollaborative).toBeFalsy();
      expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
      expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
      expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix

      // Verify timestamps are recent (not the malicious dates)
      const createdTime = new Date(createdAgent.createdAt).getTime();
      const now = Date.now();
      expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds

      // Verify in database
      const agentInDb = await Agent.findOne({ id: createdAgent.id });
      expect(agentInDb.author.toString()).toBe(mockReq.user.id);
      expect(agentInDb.authorName).toBeUndefined();
    });

    test('should validate required fields', async () => {
      const invalidData = {
        name: 'Missing Required Fields',
        // Missing provider and model
      };

      mockReq.body = invalidData;

      await createAgentHandler(mockReq, mockRes);

      expect(mockRes.status).toHaveBeenCalledWith(400);
      expect(mockRes.json).toHaveBeenCalledWith(
        expect.objectContaining({
          error: 'Invalid request data',
          details: expect.any(Array),
        }),
      );

      // Verify nothing was created in database
      const count = await Agent.countDocuments();
      expect(count).toBe(0);
    });

    test('should handle tool_resources validation', async () => {
      const dataWithInvalidToolResources = {
        provider: 'openai',
        model: 'gpt-4',
        name: 'Agent with Tool Resources',
        tool_resources: {
          // Valid resources
          file_search: {
            file_ids: ['file1', 'file2'],
            vector_store_ids: ['vs1'],
          },
          execute_code: {
            file_ids: ['file3'],
          },
          // Invalid resource (should be stripped by schema)
          invalid_resource: {
            file_ids: ['file4'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithInvalidToolResources;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.tool_resources).toBeDefined();
|
||||
expect(createdAgent.tool_resources.file_search).toBeDefined();
|
||||
expect(createdAgent.tool_resources.execute_code).toBeDefined();
|
||||
expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should handle avatar validation', async () => {
|
||||
const dataWithAvatar = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Avatar',
|
||||
avatar: {
|
||||
filepath: 'https://example.com/avatar.png',
|
||||
source: 's3',
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithAvatar;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.avatar).toEqual({
|
||||
filepath: 'https://example.com/avatar.png',
|
||||
source: 's3',
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle invalid avatar format', async () => {
|
||||
const dataWithInvalidAvatar = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Invalid Avatar',
|
||||
avatar: 'just-a-string', // Invalid format
|
||||
};
|
||||
|
||||
mockReq.body = dataWithInvalidAvatar;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateAgentHandler', () => {
|
||||
let existingAgentId;
|
||||
let existingAgentAuthorId;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create an existing agent for update tests
|
||||
existingAgentAuthorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
author: existingAgentAuthorId,
|
||||
description: 'Original description',
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
description: 'Original description',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
existingAgentId = agent.id;
|
||||
});
|
||||
|
||||
test('should update agent with allowed fields only', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString(); // Set as author
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated Agent',
|
||||
description: 'Updated description',
|
||||
model: 'gpt-4',
|
||||
isCollaborative: true, // This IS allowed in updates
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(400);
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Updated Agent');
|
||||
expect(updatedAgent.description).toBe('Updated description');
|
||||
expect(updatedAgent.model).toBe('gpt-4');
|
||||
expect(updatedAgent.isCollaborative).toBe(true);
|
||||
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Updated Agent');
|
||||
expect(agentInDb.isCollaborative).toBe(true);
|
||||
});
|
||||
|
||||
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated Name',
|
||||
|
||||
// Unauthorized fields that should be stripped
|
||||
author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author
|
||||
authorName: 'Hacker', // Should be stripped
|
||||
id: 'different_agent_id', // Should be stripped
|
||||
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||
versions: [], // Should be stripped
|
||||
createdAt: new Date('2020-01-01'), // Should be stripped
|
||||
updatedAt: new Date('2020-01-01'), // Should be stripped
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unauthorized fields were not changed
|
||||
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed
|
||||
expect(updatedAgent.authorName).toBeUndefined();
|
||||
expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed
|
||||
expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString());
|
||||
expect(agentInDb.id).toBe(existingAgentId);
|
||||
});
|
||||
|
||||
test('should reject update from non-author when not collaborative', async () => {
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Unauthorized Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'You do not have permission to modify this non-collaborative agent',
|
||||
});
|
||||
|
||||
// Verify agent was not modified in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Original Agent');
|
||||
});
|
||||
|
||||
test('should allow update from non-author when collaborative', async () => {
|
||||
// First make the agent collaborative
|
||||
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
|
||||
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Collaborative Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Collaborative Update');
|
||||
// Author field should be removed for non-author
|
||||
expect(updatedAgent.author).toBeUndefined();
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Collaborative Update');
|
||||
});
|
||||
|
||||
test('should allow admin to update any agent', async () => {
|
||||
const adminUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = adminUserId;
|
||||
mockReq.user.role = 'ADMIN'; // Set as admin
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Admin Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Admin Update');
|
||||
});
|
||||
|
||||
test('should handle projectIds updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
|
||||
const projectId1 = new mongoose.Types.ObjectId().toString();
|
||||
const projectId2 = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
mockReq.body = {
|
||||
projectIds: [projectId1, projectId2],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent).toBeDefined();
|
||||
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
|
||||
});
|
||||
|
||||
test('should validate tool_resources in updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
tool_resources: {
|
||||
ocr: {
|
||||
file_ids: ['ocr1', 'ocr2'],
|
||||
},
|
||||
execute_code: {
|
||||
file_ids: ['img1'],
|
||||
},
|
||||
// Invalid tool resource
|
||||
invalid_tool: {
|
||||
file_ids: ['invalid'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tool_resources).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.ocr).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should return 404 for non-existent agent', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
|
||||
mockReq.body = {
|
||||
name: 'Update Non-existent',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(404);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
|
||||
});
|
||||
|
||||
test('should handle validation errors properly', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
model_parameters: 'invalid-not-an-object', // Should be an object
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
details: expect.any(Array),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Mass Assignment Attack Scenarios', () => {
|
||||
test('should prevent setting system fields during creation', async () => {
|
||||
const systemFields = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'System Fields Test',
|
||||
|
||||
// System fields that should never be settable by users
|
||||
__v: 99,
|
||||
_id: new mongoose.Types.ObjectId(),
|
||||
versions: [
|
||||
{
|
||||
name: 'Fake Version',
|
||||
provider: 'fake',
|
||||
model: 'fake-model',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockReq.body = systemFields;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify system fields were not affected
|
||||
expect(createdAgent.__v).not.toBe(99);
|
||||
expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version
|
||||
expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation
|
||||
expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.__v).not.toBe(99);
|
||||
});
|
||||
|
||||
test('should prevent privilege escalation through isCollaborative', async () => {
|
||||
// Create a non-collaborative agent
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// Try to make it collaborative as a different user
|
||||
const attackerId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = attackerId;
|
||||
mockReq.params.id = agent.id;
|
||||
mockReq.body = {
|
||||
isCollaborative: true, // Trying to escalate privileges
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
// Should be rejected
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
|
||||
// Verify in database that it's still not collaborative
|
||||
const agentInDb = await Agent.findOne({ id: agent.id });
|
||||
expect(agentInDb.isCollaborative).toBe(false);
|
||||
});
|
||||
|
||||
test('should prevent author hijacking', async () => {
|
||||
const originalAuthorId = new mongoose.Types.ObjectId();
|
||||
const attackerId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Admin creates an agent
|
||||
mockReq.user.id = originalAuthorId.toString();
|
||||
mockReq.user.role = 'ADMIN';
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Admin Agent',
|
||||
author: attackerId.toString(), // Trying to set different author
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Author should be the actual user, not the attempted value
|
||||
expect(createdAgent.author.toString()).toBe(originalAuthorId.toString());
|
||||
expect(createdAgent.author.toString()).not.toBe(attackerId.toString());
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.author.toString()).toBe(originalAuthorId.toString());
|
||||
});
|
||||
|
||||
test('should strip unknown fields to prevent future vulnerabilities', async () => {
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Future Proof Test',
|
||||
|
||||
// Unknown fields that might be added in future
|
||||
superAdminAccess: true,
|
||||
bypassAllChecks: true,
|
||||
internalFlag: 'secret',
|
||||
futureFeature: 'exploit',
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unknown fields were stripped
|
||||
expect(createdAgent.superAdminAccess).toBeUndefined();
|
||||
expect(createdAgent.bypassAllChecks).toBeUndefined();
|
||||
expect(createdAgent.internalFlag).toBeUndefined();
|
||||
expect(createdAgent.futureFeature).toBeUndefined();
|
||||
|
||||
// Also check in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean();
|
||||
expect(agentInDb.superAdminAccess).toBeUndefined();
|
||||
expect(agentInDb.bypassAllChecks).toBeUndefined();
|
||||
expect(agentInDb.internalFlag).toBeUndefined();
|
||||
expect(agentInDb.futureFeature).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
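The allow-list behavior these tests exercise is typically enforced with a strict request schema before anything reaches Mongoose. A minimal sketch of the pattern, assuming a zod-based schema; the names here are hypothetical, not the exact ones in the handlers under test:

const crypto = require('node:crypto');
const { z } = require('zod');

// Hypothetical allow-list: only these keys survive parsing.
const agentCreateSchema = z
  .object({
    provider: z.string(),
    model: z.string(),
    name: z.string().optional(),
    description: z.string().optional(),
    instructions: z.string().optional(),
    tools: z.array(z.string()).optional(),
    model_parameters: z.record(z.unknown()).optional(),
  })
  .strip(); // unknown keys (author, versions, _id, __v, ...) are silently dropped

function buildCreatePayload(req) {
  const validated = agentCreateSchema.parse(req.body); // throws -> the 400 'Invalid request data' path
  // `author` and `id` always come from the server side, never from the body:
  return { ...validated, author: req.user.id, id: `agent_${crypto.randomUUID()}` };
}

With `.strip()`, a malicious `author` or `versions` field never reaches the model layer at all, which is exactly what the mass-assignment tests above assert.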

@@ -55,7 +55,6 @@ const startServer = async () => {

   /* Middleware */
   app.use(noIndex);
-  app.use(errorController);
   app.use(express.json({ limit: '3mb' }));
   app.use(express.urlencoded({ extended: true, limit: '3mb' }));
   app.use(mongoSanitize());

@@ -121,6 +120,9 @@ const startServer = async () => {
   app.use('/api/tags', routes.tags);
   app.use('/api/mcp', routes.mcp);

+  // Add the error controller one more time after all routes
+  app.use(errorController);
+
   app.use((req, res) => {
     res.set({
       'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',

@@ -1,5 +1,4 @@
 const fs = require('fs');
 const path = require('path');
 const request = require('supertest');
 const { MongoMemoryServer } = require('mongodb-memory-server');
 const mongoose = require('mongoose');

@@ -59,6 +58,30 @@ describe('Server Configuration', () => {
     expect(response.headers['pragma']).toBe('no-cache');
     expect(response.headers['expires']).toBe('0');
   });

+  it('should return 500 for unknown errors via ErrorController', async () => {
+    // Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
+
+    // Mock MongoDB operations to fail
+    const originalFindOne = mongoose.models.User.findOne;
+    const mockError = new Error('MongoDB operation failed');
+    mongoose.models.User.findOne = jest.fn().mockImplementation(() => {
+      throw mockError;
+    });
+
+    try {
+      const response = await request(app).post('/api/auth/login').send({
+        email: 'test@example.com',
+        password: 'password123',
+      });
+
+      expect(response.status).toBe(500);
+      expect(response.text).toBe('An unknown error occurred.');
+    } finally {
+      // Restore original function
+      mongoose.models.User.findOne = originalFindOne;
+    }
+  });
 });

 // Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely
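The polling helper itself is cut off in this excerpt; a minimal sketch consistent with that comment, assuming Node 18+'s global fetch and a default LibreChat port (both assumptions):

const waitForServer = async (host = `http://localhost:${process.env.PORT || 3080}`) => {
  const deadline = Date.now() + 10_000; // give up after 10 seconds
  while (Date.now() < deadline) {
    try {
      const res = await fetch(`${host}/health`);
      if (res.ok) {
        return; // server is up
      }
    } catch {
      // server not accepting connections yet; keep polling
    }
    await new Promise((resolve) => setTimeout(resolve, 30)); // poll every 30ms
  }
  throw new Error('Server did not become healthy within 10 seconds');
};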

@@ -1,3 +1,4 @@
+const { handleError } = require('@librechat/api');
 const { logger } = require('@librechat/data-schemas');
 const {
   EndpointURLs,

@@ -14,7 +15,6 @@ const openAI = require('~/server/services/Endpoints/openAI');
 const agents = require('~/server/services/Endpoints/agents');
 const custom = require('~/server/services/Endpoints/custom');
 const google = require('~/server/services/Endpoints/google');
-const { handleError } = require('~/server/utils');

 const buildFunction = {
   [EModelEndpoint.openAI]: openAI.buildOptions,

@@ -18,7 +18,6 @@ const message = 'Your account has been temporarily banned due to violations of o
  * @function
  * @param {Object} req - Express Request object.
  * @param {Object} res - Express Response object.
- * @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request.
  *
  * @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
  */

@@ -135,6 +134,7 @@ const checkBan = async (req, res, next = () => {}) => {
     return await banResponse(req, res);
   } catch (error) {
+    logger.error('Error in checkBan middleware:', error);
     return next(error);
   }
 };

api/server/middleware/limiters/forkLimiters.js (new file, 95 lines)
@@ -0,0 +1,95 @@
const rateLimit = require('express-rate-limit');
const { isEnabled } = require('@librechat/api');
const { RedisStore } = require('rate-limit-redis');
const { logger } = require('@librechat/data-schemas');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');

const getEnvironmentVariables = () => {
  const FORK_IP_MAX = parseInt(process.env.FORK_IP_MAX) || 30;
  const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1;
  const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7;
  const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1;
  const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE;

  const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000;
  const forkIpMax = FORK_IP_MAX;
  const forkIpWindowInMinutes = forkIpWindowMs / 60000;

  const forkUserWindowMs = FORK_USER_WINDOW * 60 * 1000;
  const forkUserMax = FORK_USER_MAX;
  const forkUserWindowInMinutes = forkUserWindowMs / 60000;

  return {
    forkIpWindowMs,
    forkIpMax,
    forkIpWindowInMinutes,
    forkUserWindowMs,
    forkUserMax,
    forkUserWindowInMinutes,
    forkViolationScore: FORK_VIOLATION_SCORE,
  };
};

const createForkHandler = (ip = true) => {
  const {
    forkIpMax,
    forkUserMax,
    forkViolationScore,
    forkIpWindowInMinutes,
    forkUserWindowInMinutes,
  } = getEnvironmentVariables();

  return async (req, res) => {
    const type = ViolationTypes.FILE_UPLOAD_LIMIT;
    const errorMessage = {
      type,
      max: ip ? forkIpMax : forkUserMax,
      limiter: ip ? 'ip' : 'user',
      windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes,
    };

    await logViolation(req, res, type, errorMessage, forkViolationScore);
    res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
  };
};

const createForkLimiters = () => {
  const { forkIpWindowMs, forkIpMax, forkUserWindowMs, forkUserMax } = getEnvironmentVariables();

  const ipLimiterOptions = {
    windowMs: forkIpWindowMs,
    max: forkIpMax,
    handler: createForkHandler(),
  };
  const userLimiterOptions = {
    windowMs: forkUserWindowMs,
    max: forkUserMax,
    handler: createForkHandler(false),
    keyGenerator: function (req) {
      return req.user?.id;
    },
  };

  if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
    logger.debug('Using Redis for fork rate limiters.');
    const sendCommand = (...args) => ioredisClient.call(...args);
    const ipStore = new RedisStore({
      sendCommand,
      prefix: 'fork_ip_limiter:',
    });
    const userStore = new RedisStore({
      sendCommand,
      prefix: 'fork_user_limiter:',
    });
    ipLimiterOptions.store = ipStore;
    userLimiterOptions.store = userStore;
  }

  const forkIpLimiter = rateLimit(ipLimiterOptions);
  const forkUserLimiter = rateLimit(userLimiterOptions);
  return { forkIpLimiter, forkUserLimiter };
};

module.exports = { createForkLimiters };
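A minimal wiring sketch for these factories; the convos.js hunk further below shows the actual integration on POST /fork:

const express = require('express');
const { createForkLimiters } = require('~/server/middleware/limiters/forkLimiters');

const router = express.Router();
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();

// The IP limiter keys on req.ip (express-rate-limit's default), while the
// user limiter keys on req.user?.id via the custom keyGenerator above, so
// both must sit behind the JWT auth middleware to be meaningful.
router.post('/fork', forkIpLimiter, forkUserLimiter, (req, res) => {
  res.status(200).json({ ok: true });
});

Note the handler logs violations under ViolationTypes.FILE_UPLOAD_LIMIT, mirroring the existing import limiter rather than introducing a fork-specific violation type.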

@@ -1,16 +1,17 @@
 const rateLimit = require('express-rate-limit');
+const { isEnabled } = require('@librechat/api');
 const { RedisStore } = require('rate-limit-redis');
+const { logger } = require('@librechat/data-schemas');
 const { ViolationTypes } = require('librechat-data-provider');
 const ioredisClient = require('~/cache/ioredisClient');
 const logViolation = require('~/cache/logViolation');
-const { isEnabled } = require('~/server/utils');
-const { logger } = require('~/config');

 const getEnvironmentVariables = () => {
   const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
   const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
   const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
   const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
+  const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE;

   const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
   const importIpMax = IMPORT_IP_MAX;

@@ -27,12 +28,18 @@ const getEnvironmentVariables = () => {
     importUserWindowMs,
     importUserMax,
     importUserWindowInMinutes,
+    importViolationScore: IMPORT_VIOLATION_SCORE,
   };
 };

 const createImportHandler = (ip = true) => {
-  const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } =
-    getEnvironmentVariables();
+  const {
+    importIpMax,
+    importUserMax,
+    importViolationScore,
+    importIpWindowInMinutes,
+    importUserWindowInMinutes,
+  } = getEnvironmentVariables();

   return async (req, res) => {
     const type = ViolationTypes.FILE_UPLOAD_LIMIT;

@@ -43,7 +50,7 @@ const createImportHandler = (ip = true) => {
       windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
     };

-    await logViolation(req, res, type, errorMessage);
+    await logViolation(req, res, type, errorMessage, importViolationScore);
     res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
   };
 };

@@ -4,6 +4,7 @@ const createSTTLimiters = require('./sttLimiters');
 const loginLimiter = require('./loginLimiter');
 const importLimiters = require('./importLimiters');
 const uploadLimiters = require('./uploadLimiters');
+const forkLimiters = require('./forkLimiters');
 const registerLimiter = require('./registerLimiter');
 const toolCallLimiter = require('./toolCallLimiter');
 const messageLimiters = require('./messageLimiters');

@@ -14,6 +15,7 @@ module.exports = {
   ...uploadLimiters,
   ...importLimiters,
   ...messageLimiters,
   ...forkLimiters,
   loginLimiter,
   registerLimiter,
   toolCallLimiter,
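Because this barrel spreads each limiter module's exports, consumers can destructure the factories straight from the middleware package, which is exactly what the convos.js hunk below relies on:

// e.g. in a route module:
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');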

@@ -11,6 +11,7 @@ const {
   MESSAGE_IP_WINDOW = 1,
   MESSAGE_USER_MAX = 40,
   MESSAGE_USER_WINDOW = 1,
+  MESSAGE_VIOLATION_SCORE: score,
 } = process.env;

 const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;

@@ -39,7 +40,7 @@ const createHandler = (ip = true) => {
       windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
     };

-    await logViolation(req, res, type, errorMessage);
+    await logViolation(req, res, type, errorMessage, score);
     return await denyRequest(req, res, errorMessage);
   };
 };

@@ -11,6 +11,7 @@ const getEnvironmentVariables = () => {
   const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
   const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
   const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
+  const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE;

   const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
   const sttIpMax = STT_IP_MAX;

@@ -27,11 +28,12 @@ const getEnvironmentVariables = () => {
     sttUserWindowMs,
     sttUserMax,
     sttUserWindowInMinutes,
+    sttViolationScore: STT_VIOLATION_SCORE,
   };
 };

 const createSTTHandler = (ip = true) => {
-  const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
+  const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } =
     getEnvironmentVariables();

   return async (req, res) => {

@@ -43,7 +45,7 @@ const createSTTHandler = (ip = true) => {
       windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
     };

-    await logViolation(req, res, type, errorMessage);
+    await logViolation(req, res, type, errorMessage, sttViolationScore);
     res.status(429).json({ message: 'Too many STT requests. Try again later' });
   };
 };

@@ -6,6 +6,8 @@ const logViolation = require('~/cache/logViolation');
 const { isEnabled } = require('~/server/utils');
 const { logger } = require('~/config');

+const { TOOL_CALL_VIOLATION_SCORE: score } = process.env;
+
 const handler = async (req, res) => {
   const type = ViolationTypes.TOOL_CALL_LIMIT;
   const errorMessage = {

@@ -15,7 +17,7 @@ const handler = async (req, res) => {
     windowInMinutes: 1,
   };

-  await logViolation(req, res, type, errorMessage, 0);
+  await logViolation(req, res, type, errorMessage, score);
   res.status(429).json({ message: 'Too many tool call requests. Try again later' });
 };

@@ -11,6 +11,7 @@ const getEnvironmentVariables = () => {
   const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
   const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
   const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
+  const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE;

   const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
   const ttsIpMax = TTS_IP_MAX;

@@ -27,11 +28,12 @@ const getEnvironmentVariables = () => {
     ttsUserWindowMs,
     ttsUserMax,
     ttsUserWindowInMinutes,
+    ttsViolationScore: TTS_VIOLATION_SCORE,
   };
 };

 const createTTSHandler = (ip = true) => {
-  const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
+  const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } =
     getEnvironmentVariables();

   return async (req, res) => {

@@ -43,7 +45,7 @@ const createTTSHandler = (ip = true) => {
       windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
     };

-    await logViolation(req, res, type, errorMessage);
+    await logViolation(req, res, type, errorMessage, ttsViolationScore);
     res.status(429).json({ message: 'Too many TTS requests. Try again later' });
   };
 };

@@ -11,6 +11,7 @@ const getEnvironmentVariables = () => {
   const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
   const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
   const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
+  const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE;

   const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
   const fileUploadIpMax = FILE_UPLOAD_IP_MAX;

@@ -27,6 +28,7 @@ const getEnvironmentVariables = () => {
     fileUploadUserWindowMs,
     fileUploadUserMax,
     fileUploadUserWindowInMinutes,
+    fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE,
   };
 };

@@ -36,6 +38,7 @@ const createFileUploadHandler = (ip = true) => {
     fileUploadIpWindowInMinutes,
     fileUploadUserMax,
     fileUploadUserWindowInMinutes,
+    fileUploadViolationScore,
   } = getEnvironmentVariables();

   return async (req, res) => {

@@ -47,7 +50,7 @@ const createFileUploadHandler = (ip = true) => {
       windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
     };

-    await logViolation(req, res, type, errorMessage);
+    await logViolation(req, res, type, errorMessage, fileUploadViolationScore);
     res.status(429).json({ message: 'Too many file upload requests. Try again later' });
   };
 };

@@ -1,5 +1,5 @@
 const uap = require('ua-parser-js');
-const { handleError } = require('../utils');
+const { handleError } = require('@librechat/api');
 const { logViolation } = require('../../cache');

 /**

@@ -1,4 +1,4 @@
-const { handleError } = require('../utils');
+const { handleError } = require('@librechat/api');

 function validateEndpoint(req, res, next) {
   const { endpoint: _endpoint, endpointType } = req.body;

@@ -1,6 +1,6 @@
+const { handleError } = require('@librechat/api');
 const { ViolationTypes } = require('librechat-data-provider');
 const { getModelsConfig } = require('~/server/controllers/ModelController');
-const { handleError } = require('~/server/utils');
 const { logViolation } = require('~/cache');
 /**
  * Validates the model of the request.

api/server/routes/__tests__/static.spec.js (new file, 162 lines)
@@ -0,0 +1,162 @@
const fs = require('fs');
const path = require('path');
const express = require('express');
const request = require('supertest');
const zlib = require('zlib');

// Create test setup
const mockTestDir = path.join(__dirname, 'test-static-route');

// Mock the paths module to point to our test directory
jest.mock('~/config/paths', () => ({
  imageOutput: mockTestDir,
}));

describe('Static Route Integration', () => {
  let app;
  let staticRoute;
  let testDir;
  let testImagePath;

  beforeAll(() => {
    // Create a test directory and files
    testDir = mockTestDir;
    testImagePath = path.join(testDir, 'test-image.jpg');

    if (!fs.existsSync(testDir)) {
      fs.mkdirSync(testDir, { recursive: true });
    }

    // Create a test image file
    fs.writeFileSync(testImagePath, 'fake-image-data');

    // Create a gzipped version of the test image (for gzip scanning tests)
    fs.writeFileSync(testImagePath + '.gz', zlib.gzipSync('fake-image-data'));
  });

  afterAll(() => {
    // Clean up test files
    if (fs.existsSync(testDir)) {
      fs.rmSync(testDir, { recursive: true, force: true });
    }
  });

  // Helper function to set up static route with specific config
  const setupStaticRoute = (skipGzipScan = false) => {
    if (skipGzipScan) {
      delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
    } else {
      process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN = 'true';
    }

    staticRoute = require('../static');
    app.use('/images', staticRoute);
  };

  beforeEach(() => {
    // Clear the module cache to get fresh imports
    jest.resetModules();

    app = express();

    // Clear environment variables
    delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
    delete process.env.NODE_ENV;
  });

  describe('route functionality', () => {
    it('should serve static image files', async () => {
      process.env.NODE_ENV = 'production';
      setupStaticRoute();

      const response = await request(app).get('/images/test-image.jpg').expect(200);

      expect(response.body.toString()).toBe('fake-image-data');
    });

    it('should return 404 for non-existent files', async () => {
      setupStaticRoute();

      const response = await request(app).get('/images/nonexistent.jpg');
      expect(response.status).toBe(404);
    });
  });

  describe('cache behavior', () => {
    it('should set cache headers for images in production', async () => {
      process.env.NODE_ENV = 'production';
      setupStaticRoute();

      const response = await request(app).get('/images/test-image.jpg').expect(200);

      expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
    });

    it('should not set cache headers in development', async () => {
      process.env.NODE_ENV = 'development';
      setupStaticRoute();

      const response = await request(app).get('/images/test-image.jpg').expect(200);

      // Our middleware should not set the production cache-control header in development
      expect(response.headers['cache-control']).not.toBe('public, max-age=172800, s-maxage=86400');
    });
  });

  describe('gzip compression behavior', () => {
    beforeEach(() => {
      process.env.NODE_ENV = 'production';
    });

    it('should serve gzipped files when gzip scanning is enabled', async () => {
      setupStaticRoute(false); // Enable gzip scanning

      const response = await request(app)
        .get('/images/test-image.jpg')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      expect(response.headers['content-encoding']).toBe('gzip');
      expect(response.body.toString()).toBe('fake-image-data');
    });

    it('should not serve gzipped files when gzip scanning is disabled', async () => {
      setupStaticRoute(true); // Disable gzip scanning

      const response = await request(app)
        .get('/images/test-image.jpg')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      expect(response.headers['content-encoding']).toBeUndefined();
      expect(response.body.toString()).toBe('fake-image-data');
    });
  });

  describe('path configuration', () => {
    it('should use the configured imageOutput path', async () => {
      setupStaticRoute();

      const response = await request(app).get('/images/test-image.jpg').expect(200);

      expect(response.body.toString()).toBe('fake-image-data');
    });

    it('should serve from subdirectories', async () => {
      // Create a subdirectory with a file
      const subDir = path.join(testDir, 'thumbs');
      fs.mkdirSync(subDir, { recursive: true });
      const thumbPath = path.join(subDir, 'thumb.jpg');
      fs.writeFileSync(thumbPath, 'thumbnail-data');

      setupStaticRoute();

      const response = await request(app).get('/images/thumbs/thumb.jpg').expect(200);

      expect(response.body.toString()).toBe('thumbnail-data');

      // Clean up
      fs.rmSync(subDir, { recursive: true, force: true });
    });
  });
});
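ENABLE_IMAGE_OUTPUT_GZIP_SCAN only scans for and serves `.gz` siblings that already exist next to the images; producing them is left to the operator. A minimal sketch for pre-compressing an image directory out of band (the target path is an assumption; adjust it to your deployment):

const fs = require('fs');
const path = require('path');
const zlib = require('zlib');

function gzipImagesInDir(dir) {
  for (const entry of fs.readdirSync(dir, { withFileTypes: true })) {
    const full = path.join(dir, entry.name);
    if (entry.isDirectory()) {
      gzipImagesInDir(full); // recurse into subdirectories such as thumbs/
    } else if (/\.(png|jpe?g|webp)$/i.test(entry.name)) {
      // Write the ".gz" sibling the startup scan will pick up
      fs.writeFileSync(full + '.gz', zlib.gzipSync(fs.readFileSync(full)));
    }
  }
}

gzipImagesInDir('/app/client/public/images'); // hypothetical path

Per the warning in the .env comment, the scan builds an in-memory map of these files on startup, so very large image folders trade memory for the faster responses.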

@@ -1,16 +1,17 @@
 const multer = require('multer');
 const express = require('express');
+const { sleep } = require('@librechat/agents');
+const { isEnabled } = require('@librechat/api');
+const { logger } = require('@librechat/data-schemas');
 const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
 const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
 const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
+const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
 const { storage, importFileFilter } = require('~/server/routes/files/multer');
 const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
 const { importConversations } = require('~/server/utils/import');
-const { createImportLimiters } = require('~/server/middleware');
 const { deleteToolCalls } = require('~/models/ToolCall');
-const { isEnabled, sleep } = require('~/server/utils');
 const getLogStores = require('~/cache/getLogStores');
-const { logger } = require('~/config');

 const assistantClients = {
   [EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'),

@@ -43,6 +44,7 @@ router.get('/', async (req, res) => {
     });
     res.status(200).json(result);
   } catch (error) {
+    logger.error('Error fetching conversations', error);
     res.status(500).json({ error: 'Error fetching conversations' });
   }
 });

@@ -156,6 +158,7 @@ router.post('/update', async (req, res) => {
 });

 const { importIpLimiter, importUserLimiter } = createImportLimiters();
+const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
 const upload = multer({ storage: storage, fileFilter: importFileFilter });

 /**

@@ -189,7 +192,7 @@ router.post(
  * @param {express.Response<TForkConvoResponse>} res - Express response object.
  * @returns {Promise<void>} - The response after forking the conversation.
  */
-router.post('/fork', async (req, res) => {
+router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => {
   try {
     /** @type {TForkConvoRequest} */
     const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body;
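A hedged sketch of how the newly limited route behaves under test, assuming the router is mounted at /api/convos and an authenticated app instance (both assumptions here, not shown in the diff):

const request = require('supertest');

async function triggerForkLimit(app, conversationId) {
  let last;
  // FORK_USER_MAX defaults to 7, so the 8th request within the window
  // should be answered by createForkHandler with a 429.
  for (let i = 0; i < 8; i++) {
    last = await request(app).post('/api/convos/fork').send({ conversationId });
  }
  return last.status; // expected: 429 once the user window is exhausted
}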

api/server/routes/files/files.agents.test.js (new file, 282 lines)
@@ -0,0 +1,282 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;

// Mock dependencies
jest.mock('~/server/services/Files/process', () => ({
  processDeleteRequest: jest.fn().mockResolvedValue({}),
  filterFile: jest.fn(),
  processFileUpload: jest.fn(),
  processAgentFileUpload: jest.fn(),
}));

jest.mock('~/server/services/Files/strategies', () => ({
  getStrategyFunctions: jest.fn(() => ({})),
}));

jest.mock('~/server/controllers/assistants/helpers', () => ({
  getOpenAIClient: jest.fn(),
}));

jest.mock('~/server/services/Tools/credentials', () => ({
  loadAuthValues: jest.fn(),
}));

jest.mock('~/server/services/Files/S3/crud', () => ({
  refreshS3FileUrls: jest.fn(),
}));

jest.mock('~/cache', () => ({
  getLogStores: jest.fn(() => ({
    get: jest.fn(),
    set: jest.fn(),
  })),
}));

jest.mock('~/config', () => ({
  logger: {
    error: jest.fn(),
    warn: jest.fn(),
    debug: jest.fn(),
  },
}));

const { createFile } = require('~/models/File');
const { createAgent } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');

// Import the router after mocks
const router = require('./files');

describe('File Routes - Agent Files Endpoint', () => {
  let app;
  let mongoServer;
  let authorId;
  let otherUserId;
  let agentId;
  let fileId1;
  let fileId2;
  let fileId3;

  beforeAll(async () => {
    mongoServer = await MongoMemoryServer.create();
    await mongoose.connect(mongoServer.getUri());

    // Initialize models
    require('~/db/models');

    app = express();
    app.use(express.json());

    // Mock authentication middleware
    app.use((req, res, next) => {
      req.user = { id: otherUserId || 'default-user' };
      req.app = { locals: {} };
      next();
    });

    app.use('/files', router);
  });

  afterAll(async () => {
    await mongoose.disconnect();
    await mongoServer.stop();
  });

  beforeEach(async () => {
    jest.clearAllMocks();

    // Clear database
    const collections = mongoose.connection.collections;
    for (const key in collections) {
      await collections[key].deleteMany({});
    }

    authorId = new mongoose.Types.ObjectId().toString();
    otherUserId = new mongoose.Types.ObjectId().toString();
    agentId = uuidv4();
    fileId1 = uuidv4();
    fileId2 = uuidv4();
    fileId3 = uuidv4();

    // Create files
    await createFile({
      user: authorId,
      file_id: fileId1,
      filename: 'agent-file1.txt',
      filepath: `/uploads/${authorId}/${fileId1}`,
      bytes: 1024,
      type: 'text/plain',
    });

    await createFile({
      user: authorId,
      file_id: fileId2,
      filename: 'agent-file2.txt',
      filepath: `/uploads/${authorId}/${fileId2}`,
      bytes: 2048,
      type: 'text/plain',
    });

    await createFile({
      user: otherUserId,
      file_id: fileId3,
      filename: 'user-file.txt',
      filepath: `/uploads/${otherUserId}/${fileId3}`,
      bytes: 512,
      type: 'text/plain',
    });

    // Create an agent with files attached
    await createAgent({
      id: agentId,
      name: 'Test Agent',
      author: authorId,
      model: 'gpt-4',
      provider: 'openai',
      isCollaborative: true,
      tool_resources: {
        file_search: {
          file_ids: [fileId1, fileId2],
        },
      },
    });

    // Share the agent globally
    const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
    if (globalProject) {
      const { updateAgent } = require('~/models/Agent');
      await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
    }
  });

  describe('GET /files/agent/:agent_id', () => {
    it('should return files accessible through the agent for non-author', async () => {
      const response = await request(app).get(`/files/agent/${agentId}`);

      expect(response.status).toBe(200);
      expect(response.body).toHaveLength(2); // Only agent files, not user-owned files

      const fileIds = response.body.map((f) => f.file_id);
      expect(fileIds).toContain(fileId1);
      expect(fileIds).toContain(fileId2);
      expect(fileIds).not.toContain(fileId3); // User's own file not included
    });

    it('should return 400 when agent_id is not provided', async () => {
      const response = await request(app).get('/files/agent/');

      expect(response.status).toBe(404); // Express returns 404 for missing route parameter
    });

    it('should return empty array for non-existent agent', async () => {
      const response = await request(app).get('/files/agent/non-existent-agent');

      expect(response.status).toBe(200);
      expect(response.body).toEqual([]); // Empty array for non-existent agent
    });

    it('should return empty array when agent is not collaborative', async () => {
      // Create a non-collaborative agent
      const nonCollabAgentId = uuidv4();
      await createAgent({
        id: nonCollabAgentId,
        name: 'Non-Collaborative Agent',
        author: authorId,
        model: 'gpt-4',
        provider: 'openai',
        isCollaborative: false,
        tool_resources: {
          file_search: {
            file_ids: [fileId1],
          },
        },
      });

      // Share it globally
      const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
      if (globalProject) {
        const { updateAgent } = require('~/models/Agent');
        await updateAgent({ id: nonCollabAgentId }, { projectIds: [globalProject._id] });
      }

      const response = await request(app).get(`/files/agent/${nonCollabAgentId}`);

      expect(response.status).toBe(200);
      expect(response.body).toEqual([]); // Empty array when not collaborative
    });

    it('should return agent files for agent author', async () => {
      // Create a new app instance with author authentication
      const authorApp = express();
      authorApp.use(express.json());
      authorApp.use((req, res, next) => {
        req.user = { id: authorId };
        req.app = { locals: {} };
        next();
      });
      authorApp.use('/files', router);

      const response = await request(authorApp).get(`/files/agent/${agentId}`);

      expect(response.status).toBe(200);
      expect(response.body).toHaveLength(2); // Agent files for author

      const fileIds = response.body.map((f) => f.file_id);
      expect(fileIds).toContain(fileId1);
      expect(fileIds).toContain(fileId2);
      expect(fileIds).not.toContain(fileId3); // User's own file not included
    });

    it('should return files uploaded by other users to shared agent for author', async () => {
      // Create a file uploaded by another user
      const otherUserFileId = uuidv4();
      const anotherUserId = new mongoose.Types.ObjectId().toString();

      await createFile({
        user: anotherUserId,
        file_id: otherUserFileId,
        filename: 'other-user-file.txt',
        filepath: `/uploads/${anotherUserId}/${otherUserFileId}`,
        bytes: 4096,
        type: 'text/plain',
      });

      // Update agent to include the file uploaded by another user
      const { updateAgent } = require('~/models/Agent');
      await updateAgent(
        { id: agentId },
        {
          tool_resources: {
            file_search: {
              file_ids: [fileId1, fileId2, otherUserFileId],
            },
          },
        },
      );

      // Create app instance with author authentication
      const authorApp = express();
      authorApp.use(express.json());
      authorApp.use((req, res, next) => {
        req.user = { id: authorId };
        req.app = { locals: {} };
        next();
      });
      authorApp.use('/files', router);

      const response = await request(authorApp).get(`/files/agent/${agentId}`);

      expect(response.status).toBe(200);
      expect(response.body).toHaveLength(3); // Including file from another user

      const fileIds = response.body.map((f) => f.file_id);
      expect(fileIds).toContain(fileId1);
      expect(fileIds).toContain(fileId2);
      expect(fileIds).toContain(otherUserFileId); // File uploaded by another user
    });
  });
});
|
||||
@@ -5,6 +5,7 @@ const {
|
||||
Time,
|
||||
isUUID,
|
||||
CacheKeys,
|
||||
Constants,
|
||||
FileSources,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
@@ -16,11 +17,12 @@ const {
|
||||
processDeleteRequest,
|
||||
processAgentFileUpload,
|
||||
} = require('~/server/services/Files/process');
|
||||
const { getFiles, batchUpdateFiles, hasAccessToFilesViaAgent } = require('~/models/File');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
|
||||
const { getFiles, batchUpdateFiles } = require('~/models/File');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { getAssistant } = require('~/models/Assistant');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getLogStores } = require('~/cache');
|
||||
@@ -50,6 +52,68 @@ router.get('/', async (req, res) => {
  }
});

/**
 * Get files specific to an agent
 * @route GET /files/agent/:agent_id
 * @param {string} agent_id - The agent ID to get files for
 * @returns {Promise<TFile[]>} Array of files attached to the agent
 */
router.get('/agent/:agent_id', async (req, res) => {
  try {
    const { agent_id } = req.params;
    const userId = req.user.id;

    if (!agent_id) {
      return res.status(400).json({ error: 'Agent ID is required' });
    }

    // Get the agent to check ownership and attached files
    const agent = await getAgent({ id: agent_id });

    if (!agent) {
      // No agent found, return empty array
      return res.status(200).json([]);
    }

    // Check if user has access to the agent
    if (agent.author.toString() !== userId) {
      // Non-authors need the agent to be globally shared and collaborative
      const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');

      if (
        !globalProject ||
        !agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString()) ||
        !agent.isCollaborative
      ) {
        return res.status(200).json([]);
      }
    }

    // Collect all file IDs from agent's tool resources
    const agentFileIds = [];
    if (agent.tool_resources) {
      for (const [, resource] of Object.entries(agent.tool_resources)) {
        if (resource?.file_ids && Array.isArray(resource.file_ids)) {
          agentFileIds.push(...resource.file_ids);
        }
      }
    }

    // If no files are attached to the agent, return an empty array
    if (agentFileIds.length === 0) {
      return res.status(200).json([]);
    }

    // Get only the files attached to this agent
    const files = await getFiles({ file_id: { $in: agentFileIds } }, null, { text: 0 });

    res.status(200).json(files);
  } catch (error) {
    logger.error('[/files/agent/:agent_id] Error fetching agent files:', error);
    res.status(500).json({ error: 'Failed to fetch agent files' });
  }
});
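
For orientation, here is a minimal client-side sketch of calling the new route. The `/api` prefix and cookie-based auth are assumptions about a typical LibreChat deployment, not part of this diff:

// Hypothetical usage of the new agent files route (agentId is illustrative)
const res = await fetch(`/api/files/agent/${agentId}`, { credentials: 'include' });
const agentFiles = await res.json(); // [] when the agent is missing or not shared collaboratively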

router.get('/config', async (req, res) => {
  try {
    res.status(200).json(req.app.locals.fileConfig);
@@ -86,11 +150,62 @@ router.delete('/', async (req, res) => {

  const fileIds = files.map((file) => file.file_id);
  const dbFiles = await getFiles({ file_id: { $in: fileIds } });
  const unauthorizedFiles = dbFiles.filter((file) => file.user.toString() !== req.user.id);

  const ownedFiles = [];
  const nonOwnedFiles = [];
  const fileMap = new Map();

  for (const file of dbFiles) {
    fileMap.set(file.file_id, file);
    if (file.user.toString() === req.user.id) {
      ownedFiles.push(file);
    } else {
      nonOwnedFiles.push(file);
    }
  }

  // If all files are owned by the user, no need for further checks
  if (nonOwnedFiles.length === 0) {
    await processDeleteRequest({ req, files: ownedFiles });
    logger.debug(
      `[/files] Files deleted successfully: ${ownedFiles
        .filter((f) => f.file_id)
        .map((f) => f.file_id)
        .join(', ')}`,
    );
    res.status(200).json({ message: 'Files deleted successfully' });
    return;
  }

  // Check access for non-owned files
  let authorizedFiles = [...ownedFiles];
  let unauthorizedFiles = [];

  if (req.body.agent_id && nonOwnedFiles.length > 0) {
    // Batch check access for all non-owned files
    const nonOwnedFileIds = nonOwnedFiles.map((f) => f.file_id);
    const accessMap = await hasAccessToFilesViaAgent(
      req.user.id,
      nonOwnedFileIds,
      req.body.agent_id,
    );

    // Separate authorized and unauthorized files
    for (const file of nonOwnedFiles) {
      if (accessMap.get(file.file_id)) {
        authorizedFiles.push(file);
      } else {
        unauthorizedFiles.push(file);
      }
    }
  } else {
    // No agent context, all non-owned files are unauthorized
    unauthorizedFiles = nonOwnedFiles;
  }

  if (unauthorizedFiles.length > 0) {
    return res.status(403).json({
      message: 'You can only delete your own files',
      message: 'You can only delete files you have access to',
      unauthorizedFiles: unauthorizedFiles.map((f) => f.file_id),
    });
  }
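
A sketch of the request shape the new branch expects; the field names (`agent_id`, `files`) come from this diff, while the `/api` prefix and concrete ids are illustrative:

// DELETE /api/files — agent_id enables the hasAccessToFilesViaAgent check for non-owned files
await fetch('/api/files', {
  method: 'DELETE',
  headers: { 'Content-Type': 'application/json' },
  credentials: 'include',
  body: JSON.stringify({
    agent_id: 'agent_abc123', // hypothetical id of a shared, collaborative agent
    files: [{ file_id: 'uuid-here', filepath: '/uploads/author/uuid-here' }],
  }),
});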

@@ -131,10 +246,10 @@ router.delete('/', async (req, res) => {
      .json({ message: 'File associations removed successfully from Azure Assistant' });
  }

  await processDeleteRequest({ req, files: dbFiles });
  await processDeleteRequest({ req, files: authorizedFiles });

  logger.debug(
    `[/files] Files deleted successfully: ${files
    `[/files] Files deleted successfully: ${authorizedFiles
      .filter((f) => f.file_id)
      .map((f) => f.file_id)
      .join(', ')}`,
api/server/routes/files/files.test.js (new file, 302 lines)
@@ -0,0 +1,302 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;

// Mock dependencies
jest.mock('~/server/services/Files/process', () => ({
  processDeleteRequest: jest.fn().mockResolvedValue({}),
  filterFile: jest.fn(),
  processFileUpload: jest.fn(),
  processAgentFileUpload: jest.fn(),
}));

jest.mock('~/server/services/Files/strategies', () => ({
  getStrategyFunctions: jest.fn(() => ({})),
}));

jest.mock('~/server/controllers/assistants/helpers', () => ({
  getOpenAIClient: jest.fn(),
}));

jest.mock('~/server/services/Tools/credentials', () => ({
  loadAuthValues: jest.fn(),
}));

jest.mock('~/server/services/Files/S3/crud', () => ({
  refreshS3FileUrls: jest.fn(),
}));

jest.mock('~/cache', () => ({
  getLogStores: jest.fn(() => ({
    get: jest.fn(),
    set: jest.fn(),
  })),
}));

jest.mock('~/config', () => ({
  logger: {
    error: jest.fn(),
    warn: jest.fn(),
    debug: jest.fn(),
  },
}));

const { createFile } = require('~/models/File');
const { createAgent } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
const { processDeleteRequest } = require('~/server/services/Files/process');

// Import the router after mocks
const router = require('./files');

describe('File Routes - Delete with Agent Access', () => {
  let app;
  let mongoServer;
  let authorId;
  let otherUserId;
  let agentId;
  let fileId;

  beforeAll(async () => {
    mongoServer = await MongoMemoryServer.create();
    await mongoose.connect(mongoServer.getUri());

    // Initialize models
    require('~/db/models');

    app = express();
    app.use(express.json());

    // Mock authentication middleware
    app.use((req, res, next) => {
      req.user = { id: otherUserId || 'default-user' };
      req.app = { locals: {} };
      next();
    });

    app.use('/files', router);
  });

  afterAll(async () => {
    await mongoose.disconnect();
    await mongoServer.stop();
  });

  beforeEach(async () => {
    jest.clearAllMocks();

    // Clear database
    const collections = mongoose.connection.collections;
    for (const key in collections) {
      await collections[key].deleteMany({});
    }

    authorId = new mongoose.Types.ObjectId().toString();
    otherUserId = new mongoose.Types.ObjectId().toString();
    fileId = uuidv4();

    // Create a file owned by the author
    await createFile({
      user: authorId,
      file_id: fileId,
      filename: 'test.txt',
      filepath: `/uploads/${authorId}/${fileId}`,
      bytes: 1024,
      type: 'text/plain',
    });

    // Create an agent with the file attached
    const agent = await createAgent({
      id: uuidv4(),
      name: 'Test Agent',
      author: authorId,
      model: 'gpt-4',
      provider: 'openai',
      isCollaborative: true,
      tool_resources: {
        file_search: {
          file_ids: [fileId],
        },
      },
    });
    agentId = agent.id;

    // Share the agent globally
    const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
    if (globalProject) {
      const { updateAgent } = require('~/models/Agent');
      await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
    }
  });

  describe('DELETE /files', () => {
    it('should allow deleting files owned by the user', async () => {
      // Create a file owned by the current user
      const userFileId = uuidv4();
      await createFile({
        user: otherUserId,
        file_id: userFileId,
        filename: 'user-file.txt',
        filepath: `/uploads/${otherUserId}/${userFileId}`,
        bytes: 1024,
        type: 'text/plain',
      });

      const response = await request(app)
        .delete('/files')
        .send({
          files: [
            {
              file_id: userFileId,
              filepath: `/uploads/${otherUserId}/${userFileId}`,
            },
          ],
        });

      expect(response.status).toBe(200);
      expect(response.body.message).toBe('Files deleted successfully');
      expect(processDeleteRequest).toHaveBeenCalled();
    });

    it('should prevent deleting files not owned by user without agent context', async () => {
      const response = await request(app)
        .delete('/files')
        .send({
          files: [
            {
              file_id: fileId,
              filepath: `/uploads/${authorId}/${fileId}`,
            },
          ],
        });

      expect(response.status).toBe(403);
      expect(response.body.message).toBe('You can only delete files you have access to');
      expect(response.body.unauthorizedFiles).toContain(fileId);
      expect(processDeleteRequest).not.toHaveBeenCalled();
    });

    it('should allow deleting files accessible through shared agent', async () => {
      const response = await request(app)
        .delete('/files')
        .send({
          agent_id: agentId,
          files: [
            {
              file_id: fileId,
              filepath: `/uploads/${authorId}/${fileId}`,
            },
          ],
        });

      expect(response.status).toBe(200);
      expect(response.body.message).toBe('Files deleted successfully');
      expect(processDeleteRequest).toHaveBeenCalled();
    });

    it('should prevent deleting files not attached to the specified agent', async () => {
      // Create another file not attached to the agent
      const unattachedFileId = uuidv4();
      await createFile({
        user: authorId,
        file_id: unattachedFileId,
        filename: 'unattached.txt',
        filepath: `/uploads/${authorId}/${unattachedFileId}`,
        bytes: 1024,
        type: 'text/plain',
      });

      const response = await request(app)
        .delete('/files')
        .send({
          agent_id: agentId,
          files: [
            {
              file_id: unattachedFileId,
              filepath: `/uploads/${authorId}/${unattachedFileId}`,
            },
          ],
        });

      expect(response.status).toBe(403);
      expect(response.body.message).toBe('You can only delete files you have access to');
      expect(response.body.unauthorizedFiles).toContain(unattachedFileId);
    });

    it('should handle mixed authorized and unauthorized files', async () => {
      // Create a file owned by the current user
      const userFileId = uuidv4();
      await createFile({
        user: otherUserId,
        file_id: userFileId,
        filename: 'user-file.txt',
        filepath: `/uploads/${otherUserId}/${userFileId}`,
        bytes: 1024,
        type: 'text/plain',
      });

      // Create an unauthorized file
      const unauthorizedFileId = uuidv4();
      await createFile({
        user: authorId,
        file_id: unauthorizedFileId,
        filename: 'unauthorized.txt',
        filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
        bytes: 1024,
        type: 'text/plain',
      });

      const response = await request(app)
        .delete('/files')
        .send({
          agent_id: agentId,
          files: [
            {
              file_id: fileId, // Authorized through agent
              filepath: `/uploads/${authorId}/${fileId}`,
            },
            {
              file_id: userFileId, // Owned by user
              filepath: `/uploads/${otherUserId}/${userFileId}`,
            },
            {
              file_id: unauthorizedFileId, // Not authorized
              filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
            },
          ],
        });

      expect(response.status).toBe(403);
      expect(response.body.message).toBe('You can only delete files you have access to');
      expect(response.body.unauthorizedFiles).toContain(unauthorizedFileId);
      expect(response.body.unauthorizedFiles).not.toContain(fileId);
      expect(response.body.unauthorizedFiles).not.toContain(userFileId);
    });

    it('should prevent deleting files when agent is not collaborative', async () => {
      // Update the agent to be non-collaborative
      const { updateAgent } = require('~/models/Agent');
      await updateAgent({ id: agentId }, { isCollaborative: false });

      const response = await request(app)
        .delete('/files')
        .send({
          agent_id: agentId,
          files: [
            {
              file_id: fileId,
              filepath: `/uploads/${authorId}/${fileId}`,
            },
          ],
        });

      expect(response.status).toBe(403);
      expect(response.body.message).toBe('You can only delete files you have access to');
      expect(response.body.unauthorizedFiles).toContain(fileId);
      expect(processDeleteRequest).not.toHaveBeenCalled();
    });
  });
});
@@ -477,7 +477,9 @@ describe('Multer Configuration', () => {
        done(new Error('Expected mkdirSync to throw an error but no error was thrown'));
      } catch (error) {
        // This is the expected behavior - mkdirSync throws synchronously for invalid paths
        expect(error.code).toBe('EACCES');
        // On Linux, this typically returns EACCES (permission denied)
        // On macOS/Darwin, this returns ENOENT (no such file or directory)
        expect(['EACCES', 'ENOENT']).toContain(error.code);
        done();
      }
    });
@@ -172,40 +172,68 @@ router.patch('/preferences', checkMemoryOptOut, async (req, res) => {
/**
 * PATCH /memories/:key
 * Updates the value of an existing memory entry for the authenticated user.
 * Body: { value: string }
 * Body: { key?: string, value: string }
 * Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
 */
router.patch('/:key', checkMemoryUpdate, async (req, res) => {
  const { key } = req.params;
  const { value } = req.body || {};
  const { key: urlKey } = req.params;
  const { key: bodyKey, value } = req.body || {};

  if (typeof value !== 'string' || value.trim() === '') {
    return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
  }

  // Use the key from the body if provided, otherwise use the key from the URL
  const newKey = bodyKey || urlKey;

  try {
    const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');

    const memories = await getAllUserMemories(req.user.id);
    const existingMemory = memories.find((m) => m.key === key);
    const existingMemory = memories.find((m) => m.key === urlKey);

    if (!existingMemory) {
      return res.status(404).json({ error: 'Memory not found.' });
    }

    const result = await setMemory({
      userId: req.user.id,
      key,
      value,
      tokenCount,
    });
    // If the key is changing, we need to handle it specially
    if (newKey !== urlKey) {
      const keyExists = memories.find((m) => m.key === newKey);
      if (keyExists) {
        return res.status(409).json({ error: 'Memory with this key already exists.' });
      }

    if (!result.ok) {
      return res.status(500).json({ error: 'Failed to update memory.' });
      const createResult = await createMemory({
        userId: req.user.id,
        key: newKey,
        value,
        tokenCount,
      });

      if (!createResult.ok) {
        return res.status(500).json({ error: 'Failed to create new memory.' });
      }

      const deleteResult = await deleteMemory({ userId: req.user.id, key: urlKey });
      if (!deleteResult.ok) {
        return res.status(500).json({ error: 'Failed to delete old memory.' });
      }
    } else {
      // Key is not changing, just update the value
      const result = await setMemory({
        userId: req.user.id,
        key: newKey,
        value,
        tokenCount,
      });

      if (!result.ok) {
        return res.status(500).json({ error: 'Failed to update memory.' });
      }
    }

    const updatedMemories = await getAllUserMemories(req.user.id);
    const updatedMemory = updatedMemories.find((m) => m.key === key);
    const updatedMemory = updatedMemories.find((m) => m.key === newKey);

    res.json({ updated: true, memory: updatedMemory });
  } catch (error) {
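
From the client side, the rename flow above might be exercised like this (the path prefix and key names are illustrative; the `key` body field is what triggers the create-then-delete path):

// PATCH /api/memories/:key — the URL carries the old key, the body optionally the new one
await fetch('/api/memories/old-key', {
  method: 'PATCH',
  headers: { 'Content-Type': 'application/json' },
  credentials: 'include',
  body: JSON.stringify({ key: 'new-key', value: 'updated value' }),
}); // responds 409 if 'new-key' already exists, 404 if 'old-key' does not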

@@ -235,12 +235,13 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) =
      return res.status(400).json({ error: 'Content part not found' });
    }

    if (updatedContent[index].type !== ContentTypes.TEXT) {
    const currentPartType = updatedContent[index].type;
    if (currentPartType !== ContentTypes.TEXT && currentPartType !== ContentTypes.THINK) {
      return res.status(400).json({ error: 'Cannot update non-text content' });
    }

    const oldText = updatedContent[index].text;
    updatedContent[index] = { type: ContentTypes.TEXT, text };
    const oldText = updatedContent[index][currentPartType];
    updatedContent[index] = { type: currentPartType, [currentPartType]: text };

    let tokenCount = message.tokenCount;
    if (tokenCount !== undefined) {
@@ -1,8 +1,11 @@
const express = require('express');
const staticCache = require('../utils/staticCache');
const paths = require('~/config/paths');
const { isEnabled } = require('~/server/utils');

const skipGzipScan = !isEnabled(process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN);

const router = express.Router();
router.use(staticCache(paths.imageOutput));
router.use(staticCache(paths.imageOutput, { skipGzipScan }));

module.exports = router;
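
Since `isEnabled` treats an unset variable as falsy, the gzip scan stays off unless explicitly opted in; a minimal sketch of the gate:

// Illustration only: with the variable unset, skipGzipScan stays true and plain
// express.static behavior is kept; setting it to 'true' enables the gzip scan.
const { isEnabled } = require('~/server/utils');
const skip = !isEnabled(process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN); // true when unset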

@@ -152,12 +152,14 @@ describe('AppService', () => {
        filteredTools: undefined,
        includedTools: undefined,
        webSearch: {
          safeSearch: 1,
          jinaApiKey: '${JINA_API_KEY}',
          cohereApiKey: '${COHERE_API_KEY}',
          serperApiKey: '${SERPER_API_KEY}',
          searxngApiKey: '${SEARXNG_API_KEY}',
          firecrawlApiKey: '${FIRECRAWL_API_KEY}',
          firecrawlApiUrl: '${FIRECRAWL_API_URL}',
          jinaApiKey: '${JINA_API_KEY}',
          safeSearch: 1,
          serperApiKey: '${SERPER_API_KEY}',
          searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}',
        },
        memory: undefined,
        agents: {
@@ -1,10 +1,9 @@
const { logger } = require('@librechat/data-schemas');
const { getUserMCPAuthMap } = require('@librechat/api');
const { isEnabled, getUserMCPAuthMap } = require('@librechat/api');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName, isEnabled } = require('~/server/utils');
const { normalizeEndpointName } = require('~/server/utils');
const loadCustomConfig = require('./loadCustomConfig');
const { getCachedTools } = require('./getCachedTools');
const { findPluginAuthsByKeys } = require('~/models');
const getLogStores = require('~/cache/getLogStores');

/**
@@ -55,46 +54,48 @@ const getCustomEndpointConfig = async (endpoint) => {
  );
};

async function createGetMCPAuthMap() {
/**
 * @param {Object} params
 * @param {string} params.userId
 * @param {GenericTool[]} [params.tools]
 * @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys
 * @returns {Promise<Record<string, Record<string, string>> | undefined>}
 */
async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
  try {
    if (!tools || tools.length === 0) {
      return;
    }
    const appTools = await getCachedTools({
      userId,
    });
    return await getUserMCPAuthMap({
      tools,
      userId,
      appTools,
      findPluginAuthsByKeys,
    });
  } catch (err) {
    logger.error(
      `[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
      err,
    );
  }
}

/**
 * @returns {Promise<boolean>}
 */
async function hasCustomUserVars() {
  const customConfig = await getCustomConfig();
  const mcpServers = customConfig?.mcpServers;
  const hasCustomUserVars = Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
  if (!hasCustomUserVars) {
    return;
  }

  /**
   * @param {Object} params
   * @param {GenericTool[]} [params.tools]
   * @param {string} params.userId
   * @returns {Promise<Record<string, Record<string, string>> | undefined>}
   */
  return async function ({ tools, userId }) {
    try {
      if (!tools || tools.length === 0) {
        return;
      }
      const appTools = await getCachedTools({
        userId,
      });
      return await getUserMCPAuthMap({
        tools,
        userId,
        appTools,
        findPluginAuthsByKeys,
      });
    } catch (err) {
      logger.error(
        `[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
        err,
      );
    }
  };
  return Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
}

module.exports = {
  getMCPAuthMap,
  getCustomConfig,
  getBalanceConfig,
  createGetMCPAuthMap,
  hasCustomUserVars,
  getCustomEndpointConfig,
};
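
A hedged usage sketch of the refactored pair: `getMCPAuthMap` is now a plain async function with `findPluginAuthsByKeys` injected by the caller, and `hasCustomUserVars` reduces to a boolean check (argument shapes taken from the signatures above):

// Caller-side sketch of the new API
const authMap = await getMCPAuthMap({ userId, tools, findPluginAuthsByKeys });
if (await hasCustomUserVars()) {
  // at least one configured MCP server declares customUserVars
}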

@@ -1,4 +1,10 @@
const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider');
const {
  CacheKeys,
  EModelEndpoint,
  isAgentsEndpoint,
  orderEndpointsConfig,
  defaultAgentCapabilities,
} = require('librechat-data-provider');
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
const loadConfigEndpoints = require('./loadConfigEndpoints');
const getLogStores = require('~/cache/getLogStores');
@@ -80,8 +86,12 @@ async function getEndpointsConfig(req) {
 * @returns {Promise<boolean>}
 */
const checkCapability = async (req, capability) => {
  const isAgents = isAgentsEndpoint(req.body?.original_endpoint || req.body?.endpoint);
  const endpointsConfig = await getEndpointsConfig(req);
  const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
  const capabilities =
    isAgents || endpointsConfig?.[EModelEndpoint.agents]?.capabilities != null
      ? (endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [])
      : defaultAgentCapabilities;
  return capabilities.includes(capability);
};
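
In effect, a request that is not on the agents endpoint and has no agents block configured now falls back to the library defaults instead of an empty list; a hedged sketch (the capability name is illustrative):

// Before: a missing agents config meant capabilities = [] and every check failed.
// After: the same request passes any check covered by defaultAgentCapabilities.
const allowed = await checkCapability(req, 'file_search');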

@@ -1,7 +1,7 @@
const fs = require('fs');
const path = require('path');
const { logger } = require('@librechat/data-schemas');
const { loadServiceKey, isUserProvided } = require('@librechat/api');
const { EModelEndpoint } = require('librechat-data-provider');
const { isUserProvided } = require('~/server/utils');
const { config } = require('./EndpointService');

const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = config;
@@ -11,36 +11,28 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go
 * @param {Express.Request} req - The request object
 */
async function loadAsyncEndpoints(req) {
  let i = 0;
  let serviceKey, googleUserProvides;
  const serviceKeyPath =
    process.env.GOOGLE_SERVICE_KEY_FILE_PATH ||
    path.join(__dirname, '../../..', 'data', 'auth.json');

  try {
    if (process.env.GOOGLE_SERVICE_KEY_FILE_PATH) {
      const absolutePath = path.isAbsolute(serviceKeyPath)
        ? serviceKeyPath
        : path.resolve(serviceKeyPath);
      const fileContent = fs.readFileSync(absolutePath, 'utf8');
      serviceKey = JSON.parse(fileContent);
    } else {
      serviceKey = require('~/data/auth.json');
    }
  } catch {
    if (i === 0) {
      i++;
  /** Check if GOOGLE_KEY is provided at all (including 'user_provided') */
  const isGoogleKeyProvided = googleKey && googleKey.trim() !== '';

  if (isGoogleKeyProvided) {
    /** If GOOGLE_KEY is provided, check if it's user_provided */
    googleUserProvides = isUserProvided(googleKey);
  } else {
    /** Only attempt to load service key if GOOGLE_KEY is not provided */
    const serviceKeyPath =
      process.env.GOOGLE_SERVICE_KEY_FILE || path.join(__dirname, '../../..', 'data', 'auth.json');

    try {
      serviceKey = await loadServiceKey(serviceKeyPath);
    } catch (error) {
      logger.error('Error loading service key', error);
      serviceKey = null;
    }
  }

  if (isUserProvided(googleKey)) {
    googleUserProvides = true;
    if (i <= 1) {
      i++;
    }
  }

  const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;
  const google = serviceKey || isGoogleKeyProvided ? { userProvide: googleUserProvides } : false;

  const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins;
  const gptPlugins =
@@ -11,8 +11,8 @@ const {
  replaceSpecialVars,
  providerEndpointMap,
} = require('librechat-data-provider');
const { getProviderConfig } = require('~/server/services/Endpoints');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { getProviderConfig } = require('~/server/services/Endpoints');
const { processFiles } = require('~/server/services/Files/process');
const { getFiles, getToolFilesByIds } = require('~/models/File');
const { getConvoFiles } = require('~/models/Conversation');
@@ -82,10 +82,11 @@ const initializeAgent = async ({
    attachments: currentFiles,
    tool_resources: agent.tool_resources,
    requestFileSet: new Set(requestFiles?.map((file) => file.file_id)),
    agentId: agent.id,
  });

  const provider = agent.provider;
  const { tools, toolContextMap } =
  const { tools: structuredTools, toolContextMap } =
    (await loadTools?.({
      req,
      res,
@@ -140,6 +141,24 @@ const initializeAgent = async ({
    agent.provider = options.provider;
  }

  /** @type {import('@librechat/agents').GenericTool[]} */
  let tools = options.tools?.length ? options.tools : structuredTools;
  if (
    (agent.provider === Providers.GOOGLE || agent.provider === Providers.VERTEXAI) &&
    options.tools?.length &&
    structuredTools?.length
  ) {
    throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`);
  } else if (
    (agent.provider === Providers.OPENAI ||
      agent.provider === Providers.AZURE ||
      agent.provider === Providers.ANTHROPIC) &&
    options.tools?.length &&
    structuredTools?.length
  ) {
    tools = structuredTools.concat(options.tools);
  }

  /** @type {import('@librechat/agents').ClientOptions} */
  agent.model_parameters = { ...options.llmConfig };
  if (options.configOptions) {
@@ -162,10 +181,10 @@ const initializeAgent = async ({

  return {
    ...agent,
    tools,
    attachments,
    resendFiles,
    toolContextMap,
    tools,
    maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9,
  };
};
@@ -78,7 +78,17 @@ function getLLMConfig(apiKey, options = {}) {
    requestOptions.anthropicApiUrl = options.reverseProxyUrl;
  }

  const tools = [];

  if (mergedOptions.web_search) {
    tools.push({
      type: 'web_search_20250305',
      name: 'web_search',
    });
  }

  return {
    tools,
    /** @type {AnthropicClientOptions} */
    llmConfig: removeNullishValues(requestOptions),
  };
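
For reference, a hedged sketch of what callers now receive when web search is enabled (how `mergedOptions` is assembled from `options` is not shown in this hunk, so the second argument below is an assumption):

// Assumed invocation; the tools array shape comes directly from the hunk above
const { tools, llmConfig } = getLLMConfig(process.env.ANTHROPIC_API_KEY, { web_search: true });
// tools -> [{ type: 'web_search_20250305', name: 'web_search' }]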

@@ -1,7 +1,6 @@
const fs = require('fs');
const path = require('path');
const { getGoogleConfig, isEnabled } = require('@librechat/api');
const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
const { getGoogleConfig, isEnabled, loadServiceKey } = require('@librechat/api');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { GoogleClient } = require('~/app');

@@ -18,21 +17,24 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio

  let serviceKey = {};

  try {
    if (process.env.GOOGLE_SERVICE_KEY_FILE_PATH) {
  /** Check if GOOGLE_KEY is provided at all (including 'user_provided') */
  const isGoogleKeyProvided =
    (GOOGLE_KEY && GOOGLE_KEY.trim() !== '') || (isUserProvided && userKey != null);

  if (!isGoogleKeyProvided) {
    /** Only attempt to load service key if GOOGLE_KEY is not provided */
    try {
      const serviceKeyPath =
        process.env.GOOGLE_SERVICE_KEY_FILE_PATH ||
        path.join(__dirname, '../../../../..', 'data', 'auth.json');
      const absolutePath = path.isAbsolute(serviceKeyPath)
        ? serviceKeyPath
        : path.resolve(serviceKeyPath);
      const fileContent = fs.readFileSync(absolutePath, 'utf8');
      serviceKey = JSON.parse(fileContent);
    } else {
      serviceKey = require('~/data/auth.json');
        process.env.GOOGLE_SERVICE_KEY_FILE ||
        path.join(__dirname, '../../../..', 'data', 'auth.json');
      serviceKey = await loadServiceKey(serviceKeyPath);
      if (!serviceKey) {
        serviceKey = {};
      }
    } catch (_e) {
      // Service key loading failed, but that's okay if not required
      serviceKey = {};
    }
  } catch (_e) {
    // Do nothing
  }

  const credentials = isUserProvided
@@ -7,6 +7,16 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize');
const initGoogle = require('~/server/services/Endpoints/google/initialize');
const { getCustomEndpointConfig } = require('~/server/services/Config');

/** Check if the provider is a known custom provider
 * @param {string | undefined} [provider] - The provider string
 * @returns {boolean} - True if the provider is a known custom provider, false otherwise
 */
function isKnownCustomProvider(provider) {
  return [Providers.XAI, Providers.OLLAMA, Providers.DEEPSEEK, Providers.OPENROUTER].includes(
    provider?.toLowerCase() || '',
  );
}

const providerConfigMap = {
  [Providers.XAI]: initCustom,
  [Providers.OLLAMA]: initCustom,
@@ -46,6 +56,13 @@ async function getProviderConfig(provider) {
    overrideProvider = Providers.OPENAI;
  }

  if (isKnownCustomProvider(overrideProvider || provider) && !customEndpointConfig) {
    customEndpointConfig = await getCustomEndpointConfig(provider);
    if (!customEndpointConfig) {
      throw new Error(`Provider ${provider} not supported`);
    }
  }

  return {
    getOptions,
    overrideProvider,
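
The practical effect of the new guard, as a hedged sketch: requesting a known custom provider that has no matching entry in the custom endpoints config now fails fast instead of proceeding with a broken setup:

// Throws 'Provider xai not supported' when no custom endpoint config resolves
const { getOptions, overrideProvider } = await getProviderConfig('xai');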

@@ -152,6 +152,7 @@ async function getSessionInfo(fileIdentifier, apiKey) {
 * @param {Object} options
 * @param {ServerRequest} options.req
 * @param {Agent['tool_resources']} options.tool_resources
 * @param {string} [options.agentId] - The agent ID for file access control
 * @param {string} apiKey
 * @returns {Promise<{
 *   files: Array<{ id: string; session_id: string; name: string }>,
@@ -159,11 +160,18 @@ async function getSessionInfo(fileIdentifier, apiKey) {
 * }>}
 */
const primeFiles = async (options, apiKey) => {
  const { tool_resources } = options;
  const { tool_resources, req, agentId } = options;
  const file_ids = tool_resources?.[EToolResources.execute_code]?.file_ids ?? [];
  const agentResourceIds = new Set(file_ids);
  const resourceFiles = tool_resources?.[EToolResources.execute_code]?.files ?? [];
  const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
  const dbFiles = (
    (await getFiles(
      { file_id: { $in: file_ids } },
      null,
      { text: 0 },
      { userId: req?.user?.id, agentId },
    )) ?? []
  ).concat(resourceFiles);

  const files = [];
  const sessions = new Map();
@@ -2,16 +2,16 @@ const { z } = require('zod');
const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Time, CacheKeys, StepTypes } = require('librechat-data-provider');
const { sendEvent, normalizeServerName, MCPOAuthHandler } = require('@librechat/api');
const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents');
const { Constants, ContentTypes, isAssistantsEndpoint } = require('librechat-data-provider');
const {
  Constants,
  ContentTypes,
  isAssistantsEndpoint,
  convertJsonSchemaToZod,
} = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config');
const {
  sendEvent,
  MCPOAuthHandler,
  normalizeServerName,
  convertWithResolvedRefs,
} = require('@librechat/api');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools } = require('./Config');
const { getLogStores } = require('~/cache');

@@ -113,7 +113,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
  /** @type {LCTool} */
  const { description, parameters } = toolDefinition;
  const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
  let schema = convertJsonSchemaToZod(parameters, {
  let schema = convertWithResolvedRefs(parameters, {
    allowEmptyObject: !isGoogle,
    transformOneOfAnyOf: true,
  });
@@ -44,6 +44,9 @@ async function initializeMCP(app) {
    await mcpManager.mapAvailableTools(toolsCopy, flowManager);
    await setCachedTools(toolsCopy, { isGlobal: true });

    const cache = getLogStores(CacheKeys.CONFIG_STORE);
    await cache.delete(CacheKeys.TOOLS);
    logger.debug('Cleared tools array cache after MCP initialization');
    logger.info('MCP servers initialized successfully');
  } catch (error) {
    logger.error('Failed to initialize MCP servers:', error);
api/server/utils/__tests__/staticCache.spec.js (new file, 407 lines)
@@ -0,0 +1,407 @@
const fs = require('fs');
const path = require('path');
const express = require('express');
const request = require('supertest');
const zlib = require('zlib');
const staticCache = require('../staticCache');

describe('staticCache', () => {
  let app;
  let testDir;
  let testFile;
  let indexFile;
  let manifestFile;
  let swFile;

  beforeAll(() => {
    // Create a test directory and files
    testDir = path.join(__dirname, 'test-static');
    if (!fs.existsSync(testDir)) {
      fs.mkdirSync(testDir, { recursive: true });
    }

    // Create test files
    testFile = path.join(testDir, 'test.js');
    indexFile = path.join(testDir, 'index.html');
    manifestFile = path.join(testDir, 'manifest.json');
    swFile = path.join(testDir, 'sw.js');

    const jsContent = 'console.log("test");';
    const htmlContent = '<html><body>Test</body></html>';
    const jsonContent = '{"name": "test"}';
    const swContent = 'self.addEventListener("install", () => {});';

    fs.writeFileSync(testFile, jsContent);
    fs.writeFileSync(indexFile, htmlContent);
    fs.writeFileSync(manifestFile, jsonContent);
    fs.writeFileSync(swFile, swContent);

    // Create gzipped versions of some files
    fs.writeFileSync(testFile + '.gz', zlib.gzipSync(jsContent));
    fs.writeFileSync(path.join(testDir, 'test.css'), 'body { color: red; }');
    fs.writeFileSync(path.join(testDir, 'test.css.gz'), zlib.gzipSync('body { color: red; }'));

    // Create a file that only exists in gzipped form
    fs.writeFileSync(
      path.join(testDir, 'only-gzipped.js.gz'),
      zlib.gzipSync('console.log("only gzipped");'),
    );

    // Create a subdirectory for dist/images testing
    const distImagesDir = path.join(testDir, 'dist', 'images');
    fs.mkdirSync(distImagesDir, { recursive: true });
    fs.writeFileSync(path.join(distImagesDir, 'logo.png'), 'fake-png-data');
  });

  afterAll(() => {
    // Clean up test files
    if (fs.existsSync(testDir)) {
      fs.rmSync(testDir, { recursive: true, force: true });
    }
  });

  beforeEach(() => {
    app = express();

    // Clear environment variables
    delete process.env.NODE_ENV;
    delete process.env.STATIC_CACHE_S_MAX_AGE;
    delete process.env.STATIC_CACHE_MAX_AGE;
  });

  describe('cache headers in production', () => {
    beforeEach(() => {
      process.env.NODE_ENV = 'production';
    });

    it('should set standard cache headers for regular files', async () => {
      app.use(staticCache(testDir));

      const response = await request(app).get('/test.js').expect(200);

      expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
    });

    it('should set no-cache headers for index.html', async () => {
      app.use(staticCache(testDir));

      const response = await request(app).get('/index.html').expect(200);

      expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
    });

    it('should set no-cache headers for manifest.json', async () => {
      app.use(staticCache(testDir));

      const response = await request(app).get('/manifest.json').expect(200);

      expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
    });

    it('should set no-cache headers for sw.js', async () => {
      app.use(staticCache(testDir));

      const response = await request(app).get('/sw.js').expect(200);

      expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
    });

    it('should not set cache headers for /dist/images/ files', async () => {
      app.use(staticCache(testDir));

      const response = await request(app).get('/dist/images/logo.png').expect(200);

      expect(response.headers['cache-control']).toBe('public, max-age=0');
    });

    it('should set no-cache headers when noCache option is true', async () => {
      app.use(staticCache(testDir, { noCache: true }));

      const response = await request(app).get('/test.js').expect(200);

      expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
    });
  });

  describe('cache headers in non-production', () => {
    beforeEach(() => {
      process.env.NODE_ENV = 'development';
    });

    it('should not set cache headers in development', async () => {
      app.use(staticCache(testDir));

      const response = await request(app).get('/test.js').expect(200);

      // Our middleware should not set cache-control in non-production
      // Express static might set its own default headers
      const cacheControl = response.headers['cache-control'];
      expect(cacheControl).toBe('public, max-age=0');
    });
  });

  describe('environment variable configuration', () => {
    beforeEach(() => {
      process.env.NODE_ENV = 'production';
    });

    it('should use custom s-maxage from environment', async () => {
      process.env.STATIC_CACHE_S_MAX_AGE = '3600';

      // Need to re-require to pick up new env vars
      jest.resetModules();
      const freshStaticCache = require('../staticCache');

      app.use(freshStaticCache(testDir));

      const response = await request(app).get('/test.js').expect(200);

      expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=3600');
    });

    it('should use custom max-age from environment', async () => {
      process.env.STATIC_CACHE_MAX_AGE = '7200';

      // Need to re-require to pick up new env vars
      jest.resetModules();
      const freshStaticCache = require('../staticCache');

      app.use(freshStaticCache(testDir));

      const response = await request(app).get('/test.js').expect(200);

      expect(response.headers['cache-control']).toBe('public, max-age=7200, s-maxage=86400');
    });

    it('should use both custom values from environment', async () => {
      process.env.STATIC_CACHE_S_MAX_AGE = '1800';
      process.env.STATIC_CACHE_MAX_AGE = '3600';

      // Need to re-require to pick up new env vars
      jest.resetModules();
      const freshStaticCache = require('../staticCache');

      app.use(freshStaticCache(testDir));

      const response = await request(app).get('/test.js').expect(200);

      expect(response.headers['cache-control']).toBe('public, max-age=3600, s-maxage=1800');
    });
  });

  describe('express-static-gzip behavior', () => {
    beforeEach(() => {
      process.env.NODE_ENV = 'production';
    });

    it('should serve gzipped files when client accepts gzip encoding', async () => {
      app.use(staticCache(testDir, { skipGzipScan: false }));

      const response = await request(app)
        .get('/test.js')
        .set('Accept-Encoding', 'gzip, deflate')
        .expect(200);

      expect(response.headers['content-encoding']).toBe('gzip');
      expect(response.headers['content-type']).toMatch(/javascript/);
      expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
      // Content should be decompressed by supertest
      expect(response.text).toBe('console.log("test");');
    });

    it('should fall back to uncompressed files when client does not accept gzip', async () => {
      app.use(staticCache(testDir, { skipGzipScan: false }));

      const response = await request(app)
        .get('/test.js')
        .set('Accept-Encoding', 'identity')
        .expect(200);

      expect(response.headers['content-encoding']).toBeUndefined();
      expect(response.headers['content-type']).toMatch(/javascript/);
      expect(response.text).toBe('console.log("test");');
    });

    it('should serve gzipped CSS files with correct content-type', async () => {
      app.use(staticCache(testDir, { skipGzipScan: false }));

      const response = await request(app)
        .get('/test.css')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      expect(response.headers['content-encoding']).toBe('gzip');
      expect(response.headers['content-type']).toMatch(/css/);
      expect(response.text).toBe('body { color: red; }');
    });

    it('should serve uncompressed files when no gzipped version exists', async () => {
      app.use(staticCache(testDir, { skipGzipScan: false }));

      const response = await request(app)
        .get('/manifest.json')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      expect(response.headers['content-encoding']).toBeUndefined();
      expect(response.headers['content-type']).toMatch(/json/);
      expect(response.text).toBe('{"name": "test"}');
    });

    it('should handle files that only exist in gzipped form', async () => {
      app.use(staticCache(testDir, { skipGzipScan: false }));

      const response = await request(app)
        .get('/only-gzipped.js')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      expect(response.headers['content-encoding']).toBe('gzip');
      expect(response.headers['content-type']).toMatch(/javascript/);
      expect(response.text).toBe('console.log("only gzipped");');
    });

    it('should return 404 for gzip-only files when client does not accept gzip', async () => {
      app.use(staticCache(testDir, { skipGzipScan: false }));

      const response = await request(app)
        .get('/only-gzipped.js')
        .set('Accept-Encoding', 'identity');
      expect(response.status).toBe(404);
    });

    it('should handle cache headers correctly for gzipped content', async () => {
      app.use(staticCache(testDir, { skipGzipScan: false }));

      const response = await request(app)
        .get('/test.js')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      expect(response.headers['content-encoding']).toBe('gzip');
      expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
      expect(response.headers['content-type']).toMatch(/javascript/);
    });

    it('should preserve original MIME types for gzipped files', async () => {
      app.use(staticCache(testDir, { skipGzipScan: false }));

      const jsResponse = await request(app)
        .get('/test.js')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      const cssResponse = await request(app)
        .get('/test.css')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      expect(jsResponse.headers['content-type']).toMatch(/javascript/);
      expect(cssResponse.headers['content-type']).toMatch(/css/);
      expect(jsResponse.headers['content-encoding']).toBe('gzip');
      expect(cssResponse.headers['content-encoding']).toBe('gzip');
    });
  });

  describe('skipGzipScan option comparison', () => {
    beforeEach(() => {
      process.env.NODE_ENV = 'production';
    });

    it('should use express.static (no gzip) when skipGzipScan is true', async () => {
      app.use(staticCache(testDir, { skipGzipScan: true }));

      const response = await request(app)
        .get('/test.js')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      // Should NOT serve gzipped version even though client accepts it
      expect(response.headers['content-encoding']).toBeUndefined();
      expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
      expect(response.text).toBe('console.log("test");');
    });

    it('should use expressStaticGzip when skipGzipScan is false', async () => {
      app.use(staticCache(testDir));

      const response = await request(app)
        .get('/test.js')
        .set('Accept-Encoding', 'gzip')
        .expect(200);

      // Should serve gzipped version when client accepts it
      expect(response.headers['content-encoding']).toBe('gzip');
      expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
      expect(response.text).toBe('console.log("test");');
    });
  });

  describe('file serving', () => {
    beforeEach(() => {
      process.env.NODE_ENV = 'production';
    });

    it('should serve files correctly', async () => {
      app.use(staticCache(testDir));

      const response = await request(app).get('/test.js').expect(200);

      expect(response.text).toBe('console.log("test");');
      expect(response.headers['content-type']).toMatch(/javascript|text/);
    });

    it('should return 404 for non-existent files', async () => {
      app.use(staticCache(testDir));

      const response = await request(app).get('/nonexistent.js');
      expect(response.status).toBe(404);
    });

    it('should serve HTML files', async () => {
      app.use(staticCache(testDir));

      const response = await request(app).get('/index.html').expect(200);

      expect(response.text).toBe('<html><body>Test</body></html>');
      expect(response.headers['content-type']).toMatch(/html/);
    });
  });

  describe('edge cases', () => {
    beforeEach(() => {
      process.env.NODE_ENV = 'production';
    });

    it('should handle webmanifest files', async () => {
      // Create a webmanifest file
      const webmanifestFile = path.join(testDir, 'site.webmanifest');
      fs.writeFileSync(webmanifestFile, '{"name": "test app"}');

      app.use(staticCache(testDir));

      const response = await request(app).get('/site.webmanifest').expect(200);

      expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');

      // Clean up
      fs.unlinkSync(webmanifestFile);
    });

    it('should handle files in subdirectories', async () => {
      const subDir = path.join(testDir, 'subdir');
      fs.mkdirSync(subDir, { recursive: true });
      const subFile = path.join(subDir, 'nested.js');
      fs.writeFileSync(subFile, 'console.log("nested");');

      app.use(staticCache(testDir));

      const response = await request(app).get('/subdir/nested.js').expect(200);

      expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
      expect(response.text).toBe('console.log("nested");');

      // Clean up
      fs.rmSync(subDir, { recursive: true, force: true });
    });
  });
});
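
Per the environment-variable tests above, the cache lifetimes are read once at module load, so overriding them requires a fresh require; a hedged sketch:

// Values are in seconds; jest.resetModules() plus a fresh require is needed
// because the module captures STATIC_CACHE_* at load time (path is illustrative).
process.env.STATIC_CACHE_MAX_AGE = '7200';
process.env.STATIC_CACHE_S_MAX_AGE = '3600';
jest.resetModules();
const freshStaticCache = require('../staticCache'); // -> 'public, max-age=7200, s-maxage=3600'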
api/server/utils/import/importers-timestamp.spec.js (new file, 280 lines)
@@ -0,0 +1,280 @@
const { Constants } = require('librechat-data-provider');
const { ImportBatchBuilder } = require('./importBatchBuilder');
const { getImporter } = require('./importers');

// Mock the database methods
jest.mock('~/models/Conversation', () => ({
  bulkSaveConvos: jest.fn(),
}));
jest.mock('~/models/Message', () => ({
  bulkSaveMessages: jest.fn(),
}));
jest.mock('~/cache/getLogStores');
const getLogStores = require('~/cache/getLogStores');
const mockedCacheGet = jest.fn();
getLogStores.mockImplementation(() => ({
  get: mockedCacheGet,
}));

describe('Import Timestamp Ordering', () => {
  beforeEach(() => {
    jest.clearAllMocks();
    mockedCacheGet.mockResolvedValue(null);
  });

  describe('LibreChat Import - Timestamp Issues', () => {
    test('should maintain proper timestamp order between parent and child messages', async () => {
      // Create a LibreChat export with out-of-order timestamps
      const jsonData = {
        conversationId: 'test-convo-123',
        title: 'Test Conversation',
        messages: [
          {
            messageId: 'parent-1',
            parentMessageId: Constants.NO_PARENT,
            text: 'Parent Message',
            sender: 'user',
            isCreatedByUser: true,
            createdAt: '2023-01-01T00:02:00Z', // Parent created AFTER child
          },
          {
            messageId: 'child-1',
            parentMessageId: 'parent-1',
            text: 'Child Message',
            sender: 'assistant',
            isCreatedByUser: false,
            createdAt: '2023-01-01T00:01:00Z', // Child created BEFORE parent
          },
          {
            messageId: 'grandchild-1',
            parentMessageId: 'child-1',
            text: 'Grandchild Message',
            sender: 'user',
            isCreatedByUser: true,
            createdAt: '2023-01-01T00:00:30Z', // Even earlier
          },
        ],
      };

      const requestUserId = 'user-123';
      const importBatchBuilder = new ImportBatchBuilder(requestUserId);
      jest.spyOn(importBatchBuilder, 'saveMessage');

      const importer = getImporter(jsonData);
      await importer(jsonData, requestUserId, () => importBatchBuilder);

      // Check the actual messages stored in the builder
      const savedMessages = importBatchBuilder.messages;

      const parent = savedMessages.find((msg) => msg.text === 'Parent Message');
      const child = savedMessages.find((msg) => msg.text === 'Child Message');
      const grandchild = savedMessages.find((msg) => msg.text === 'Grandchild Message');

      // Verify all messages were found
      expect(parent).toBeDefined();
      expect(child).toBeDefined();
      expect(grandchild).toBeDefined();

      // FIXED behavior: timestamps ARE corrected
      expect(new Date(child.createdAt).getTime()).toBeGreaterThan(
        new Date(parent.createdAt).getTime(),
      );
      expect(new Date(grandchild.createdAt).getTime()).toBeGreaterThan(
        new Date(child.createdAt).getTime(),
      );
    });

    test('should handle complex multi-branch scenario with out-of-order timestamps', async () => {
      const jsonData = {
        conversationId: 'complex-test-123',
        title: 'Complex Test',
        messages: [
          // Branch 1: Root -> A -> B with reversed timestamps
          {
            messageId: 'root-1',
            parentMessageId: Constants.NO_PARENT,
            text: 'Root 1',
            sender: 'user',
            isCreatedByUser: true,
            createdAt: '2023-01-01T00:03:00Z',
          },
          {
            messageId: 'a-1',
            parentMessageId: 'root-1',
            text: 'A1',
            sender: 'assistant',
            isCreatedByUser: false,
            createdAt: '2023-01-01T00:02:00Z', // Before parent
          },
          {
            messageId: 'b-1',
            parentMessageId: 'a-1',
            text: 'B1',
            sender: 'user',
            isCreatedByUser: true,
            createdAt: '2023-01-01T00:01:00Z', // Before grandparent
          },
          // Branch 2: Root -> C -> D with mixed timestamps
          {
            messageId: 'root-2',
            parentMessageId: Constants.NO_PARENT,
            text: 'Root 2',
            sender: 'user',
            isCreatedByUser: true,
            createdAt: '2023-01-01T00:00:30Z', // Earlier than branch 1
          },
          {
            messageId: 'c-2',
            parentMessageId: 'root-2',
            text: 'C2',
            sender: 'assistant',
            isCreatedByUser: false,
            createdAt: '2023-01-01T00:04:00Z', // Much later
          },
          {
            messageId: 'd-2',
            parentMessageId: 'c-2',
            text: 'D2',
            sender: 'user',
            isCreatedByUser: true,
            createdAt: '2023-01-01T00:02:30Z', // Between root and parent
          },
        ],
      };

      const requestUserId = 'user-123';
      const importBatchBuilder = new ImportBatchBuilder(requestUserId);
      jest.spyOn(importBatchBuilder, 'saveMessage');

      const importer = getImporter(jsonData);
      await importer(jsonData, requestUserId, () => importBatchBuilder);

      const savedMessages = importBatchBuilder.messages;

      // Verify that out-of-order timestamps were corrected
      const root1 = savedMessages.find((msg) => msg.text === 'Root 1');
      const a1 = savedMessages.find((msg) => msg.text === 'A1');
      const b1 = savedMessages.find((msg) => msg.text === 'B1');
      const root2 = savedMessages.find((msg) => msg.text === 'Root 2');
      const c2 = savedMessages.find((msg) => msg.text === 'C2');
      const d2 = savedMessages.find((msg) => msg.text === 'D2');

      // Branch 1: timestamps should now be in correct order
      expect(new Date(a1.createdAt).getTime()).toBeGreaterThan(new Date(root1.createdAt).getTime());
      expect(new Date(b1.createdAt).getTime()).toBeGreaterThan(new Date(a1.createdAt).getTime());

      // Branch 2: all timestamps should be properly ordered
      expect(new Date(c2.createdAt).getTime()).toBeGreaterThan(new Date(root2.createdAt).getTime());
      expect(new Date(d2.createdAt).getTime()).toBeGreaterThan(new Date(c2.createdAt).getTime());
    });

    test('recursive format should NOW have timestamp protection', async () => {
      // Create a recursive LibreChat export with out-of-order timestamps
      const jsonData = {
        conversationId: 'recursive-test-123',
        title: 'Recursive Test',
        recursive: true,
        messages: [
          {
            messageId: 'parent-1',
            parentMessageId: Constants.NO_PARENT,
            text: 'Parent Message',
            sender: 'User',
            isCreatedByUser: true,
            createdAt: '2023-01-01T00:02:00Z', // Parent created AFTER child
            children: [
              {
                messageId: 'child-1',
                parentMessageId: 'parent-1',
                text: 'Child Message',
                sender: 'Assistant',
                isCreatedByUser: false,
                createdAt: '2023-01-01T00:01:00Z', // Child created BEFORE parent
                children: [
                  {
                    messageId: 'grandchild-1',
                    parentMessageId: 'child-1',
                    text: 'Grandchild Message',
                    sender: 'User',
                    isCreatedByUser: true,
                    createdAt: '2023-01-01T00:00:30Z', // Even earlier
                    children: [],
                  },
                ],
              },
            ],
          },
        ],
      };

      const requestUserId = 'user-123';
      const importBatchBuilder = new ImportBatchBuilder(requestUserId);

      const importer = getImporter(jsonData);
      await importer(jsonData, requestUserId, () => importBatchBuilder);

      const savedMessages = importBatchBuilder.messages;

      // Messages should be saved
      expect(savedMessages).toHaveLength(3);

      // Recursive imports now carry createdAt through to the saved messages
      const parent = savedMessages.find((msg) => msg.text === 'Parent Message');
      const child = savedMessages.find((msg) => msg.text === 'Child Message');
      const grandchild = savedMessages.find((msg) => msg.text === 'Grandchild Message');

      expect(parent).toBeDefined();
      expect(child).toBeDefined();
      expect(grandchild).toBeDefined();

      // Recursive imports NOW preserve and correct timestamps
      expect(parent.createdAt).toBeDefined();
      expect(child.createdAt).toBeDefined();
      expect(grandchild.createdAt).toBeDefined();

      // Timestamps should be corrected to maintain proper order
      expect(new Date(child.createdAt).getTime()).toBeGreaterThan(
        new Date(parent.createdAt).getTime(),
      );
      expect(new Date(grandchild.createdAt).getTime()).toBeGreaterThan(
        new Date(child.createdAt).getTime(),
      );
    });
  });

  describe('Comparison with Fork Functionality', () => {
    test('fork functionality correctly handles timestamp issues (for comparison)', async () => {
      const { cloneMessagesWithTimestamps } = require('./fork');

      const messagesToClone = [
        {
          messageId: 'parent',
          parentMessageId: Constants.NO_PARENT,
          text: 'Parent Message',
          createdAt: '2023-01-01T00:02:00Z', // Parent created AFTER child
        },
        {
          messageId: 'child',
          parentMessageId: 'parent',
          text: 'Child Message',
          createdAt: '2023-01-01T00:01:00Z', // Child created BEFORE parent
        },
      ];

      const importBatchBuilder = new ImportBatchBuilder('user-123');
      jest.spyOn(importBatchBuilder, 'saveMessage');

      cloneMessagesWithTimestamps(messagesToClone, importBatchBuilder);

      const savedMessages = importBatchBuilder.messages;
      const parent = savedMessages.find((msg) => msg.text === 'Parent Message');
      const child = savedMessages.find((msg) => msg.text === 'Child Message');

      // Fork functionality DOES correct the timestamps
      expect(new Date(child.createdAt).getTime()).toBeGreaterThan(
        new Date(parent.createdAt).getTime(),
      );
    });
  });
});
|
||||
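These tests pin down the correction behavior the import path now shares with fork: whenever a child message carries an earlier createdAt than its parent, the clone step reassigns timestamps so each branch stays chronologically ordered. A minimal TypeScript sketch of that correction, assuming messages arrive parent-before-child (names are illustrative; the real cloneMessagesWithTimestamps in ./fork also remaps message IDs):

// Minimal sketch of parent/child timestamp correction, assuming each
// message's parent appears before it in the input array.
type Msg = { messageId: string; parentMessageId: string | null; createdAt: string };

function fixTimestamps(messages: Msg[]): Msg[] {
  const byId = new Map(messages.map((m) => [m.messageId, m]));
  return messages.map((m) => {
    const parent = m.parentMessageId ? byId.get(m.parentMessageId) : undefined;
    if (parent && new Date(m.createdAt) <= new Date(parent.createdAt)) {
      // Nudge the child one second past its (possibly already corrected) parent
      const corrected = new Date(new Date(parent.createdAt).getTime() + 1000);
      const fixed = { ...m, createdAt: corrected.toISOString() };
      byId.set(m.messageId, fixed);
      return fixed;
    }
    return m;
  });
}
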
@@ -1,6 +1,7 @@
const { v4: uuidv4 } = require('uuid');
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
const { createImportBatchBuilder } = require('./importBatchBuilder');
const { cloneMessagesWithTimestamps } = require('./fork');
const getLogStores = require('~/cache/getLogStores');
const logger = require('~/config/winston');

@@ -107,67 +108,47 @@ async function importLibreChatConvo(

  if (jsonData.recursive) {
    /**
     * Recursively traverse the messages tree and save each message to the database.
     * Flatten the recursive message tree into a flat array
     * @param {TMessage[]} messages
     * @param {string} parentMessageId
     * @param {TMessage[]} flatMessages
     */
    const traverseMessages = async (messages, parentMessageId = null) => {
    const flattenMessages = (
      messages,
      parentMessageId = Constants.NO_PARENT,
      flatMessages = [],
    ) => {
      for (const message of messages) {
        if (!message.text && !message.content) {
          continue;
        }

        let savedMessage;
        if (message.sender?.toLowerCase() === 'user' || message.isCreatedByUser) {
          savedMessage = await importBatchBuilder.saveMessage({
            text: message.text,
            content: message.content,
            sender: 'user',
            isCreatedByUser: true,
            parentMessageId: parentMessageId,
          });
        } else {
          savedMessage = await importBatchBuilder.saveMessage({
            text: message.text,
            content: message.content,
            sender: message.sender,
            isCreatedByUser: false,
            model: options.model,
            parentMessageId: parentMessageId,
          });
        }
        const flatMessage = {
          ...message,
          parentMessageId: parentMessageId,
          children: undefined, // Remove children from flat structure
        };
        flatMessages.push(flatMessage);

        if (!firstMessageDate && message.createdAt) {
          firstMessageDate = new Date(message.createdAt);
        }

        if (message.children && message.children.length > 0) {
          await traverseMessages(message.children, savedMessage.messageId);
          flattenMessages(message.children, message.messageId, flatMessages);
        }
      }
      return flatMessages;
    };

    await traverseMessages(messagesToImport);
    const flatMessages = flattenMessages(messagesToImport);
    cloneMessagesWithTimestamps(flatMessages, importBatchBuilder);
  } else if (messagesToImport) {
    const idMapping = new Map();

    cloneMessagesWithTimestamps(messagesToImport, importBatchBuilder);
    for (const message of messagesToImport) {
      if (!firstMessageDate && message.createdAt) {
        firstMessageDate = new Date(message.createdAt);
      }
      const newMessageId = uuidv4();
      idMapping.set(message.messageId, newMessageId);

      const clonedMessage = {
        ...message,
        messageId: newMessageId,
        parentMessageId:
          message.parentMessageId && message.parentMessageId !== Constants.NO_PARENT
            ? idMapping.get(message.parentMessageId) || Constants.NO_PARENT
            : Constants.NO_PARENT,
      };

      importBatchBuilder.saveMessage(clonedMessage);
    }
  } else {
    throw new Error('Invalid LibreChat file format');

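The refactor above stops saving messages mid-traversal: the tree is flattened first, then the flat list goes through the same timestamp-aware clone that fork uses. A small illustration of the shape flattenMessages produces (fields abbreviated; output shown as comments):

// Illustrative input/output for flattenMessages (fields abbreviated).
const tree = [
  {
    messageId: 'parent-1',
    text: 'Parent',
    children: [{ messageId: 'child-1', text: 'Child', children: [] }],
  },
];
// flattenMessages(tree) would yield, with children stripped:
// [
//   { messageId: 'parent-1', text: 'Parent', parentMessageId: Constants.NO_PARENT, children: undefined },
//   { messageId: 'child-1',  text: 'Child',  parentMessageId: 'parent-1',          children: undefined },
// ]
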
@@ -175,36 +175,60 @@ describe('importLibreChatConvo', () => {
    jest.spyOn(importBatchBuilder, 'saveMessage');
    jest.spyOn(importBatchBuilder, 'saveBatch');

    // When
    const importer = getImporter(jsonData);
    await importer(jsonData, requestUserId, () => importBatchBuilder);

    // Create a map to track original message IDs to new UUIDs
    const idToUUIDMap = new Map();
    importBatchBuilder.saveMessage.mock.calls.forEach((call) => {
      const message = call[0];
      idToUUIDMap.set(message.originalMessageId, message.messageId);
    // Get the imported messages
    const messages = importBatchBuilder.messages;
    expect(messages.length).toBeGreaterThan(0);

    // Build maps for verification
    const textToMessageMap = new Map();
    const messageIdToMessage = new Map();
    messages.forEach((msg) => {
      if (msg.text) {
        // For recursive imports, text might be very long, so just use the first 100 chars as key
        const textKey = msg.text.substring(0, 100);
        textToMessageMap.set(textKey, msg);
      }
      messageIdToMessage.set(msg.messageId, msg);
    });

    const checkChildren = (children, parentId) => {
      children.forEach((child) => {
        const childUUID = idToUUIDMap.get(child.messageId);
        const expectedParentId = idToUUIDMap.get(parentId) ?? null;
        const messageCall = importBatchBuilder.saveMessage.mock.calls.find(
          (call) => call[0].messageId === childUUID,
        );

        const actualParentId = messageCall[0].parentMessageId;
        expect(actualParentId).toBe(expectedParentId);

        if (child.children && child.children.length > 0) {
          checkChildren(child.children, child.messageId);
    // Count expected messages from the tree
    const countMessagesInTree = (nodes) => {
      let count = 0;
      nodes.forEach((node) => {
        if (node.text || node.content) {
          count++;
        }
        if (node.children && node.children.length > 0) {
          count += countMessagesInTree(node.children);
        }
      });
      return count;
    };

    // Start hierarchy validation from root messages
    checkChildren(jsonData.messages, null);
    const expectedMessageCount = countMessagesInTree(jsonData.messages);
    expect(messages.length).toBe(expectedMessageCount);

    // Verify all messages have valid parent relationships
    messages.forEach((msg) => {
      if (msg.parentMessageId !== Constants.NO_PARENT) {
        const parent = messageIdToMessage.get(msg.parentMessageId);
        expect(parent).toBeDefined();

        // Verify timestamp ordering
        if (msg.createdAt && parent.createdAt) {
          expect(new Date(msg.createdAt).getTime()).toBeGreaterThanOrEqual(
            new Date(parent.createdAt).getTime(),
          );
        }
      }
    });

    // Verify at least one root message exists
    const rootMessages = messages.filter((msg) => msg.parentMessageId === Constants.NO_PARENT);
    expect(rootMessages.length).toBeGreaterThan(0);

    expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
  });

@@ -1,4 +1,5 @@
const path = require('path');
const express = require('express');
const expressStaticGzip = require('express-static-gzip');

const oneDayInSeconds = 24 * 60 * 60;

@@ -7,44 +8,55 @@ const sMaxAge = process.env.STATIC_CACHE_S_MAX_AGE || oneDayInSeconds;
const maxAge = process.env.STATIC_CACHE_MAX_AGE || oneDayInSeconds * 2;

/**
 * Creates an Express static middleware with gzip compression and configurable caching
 * Creates an Express static middleware with optional gzip compression and configurable caching
 *
 * @param {string} staticPath - The file system path to serve static files from
 * @param {Object} [options={}] - Configuration options
 * @param {boolean} [options.noCache=false] - If true, disables caching entirely for all files
 * @returns {ReturnType<expressStaticGzip>} Express middleware function for serving static files
 * @param {boolean} [options.skipGzipScan=false] - If true, skips expressStaticGzip middleware
 * @returns {ReturnType<expressStaticGzip>|ReturnType<express.static>} Express middleware function for serving static files
 */
function staticCache(staticPath, options = {}) {
  const { noCache = false } = options;
  return expressStaticGzip(staticPath, {
    enableBrotli: false,
    orderPreference: ['gz'],
    setHeaders: (res, filePath) => {
      if (process.env.NODE_ENV?.toLowerCase() !== 'production') {
        return;
      }
      if (noCache) {
        res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate');
        return;
      }
      if (filePath.includes('/dist/images/')) {
        return;
      }
      const fileName = path.basename(filePath);
  const { noCache = false, skipGzipScan = false } = options;

      if (
        fileName === 'index.html' ||
        fileName.endsWith('.webmanifest') ||
        fileName === 'manifest.json' ||
        fileName === 'sw.js'
      ) {
        res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate');
      } else {
        res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`);
      }
    },
    index: false,
  });
  const setHeaders = (res, filePath) => {
    if (process.env.NODE_ENV?.toLowerCase() !== 'production') {
      return;
    }
    if (noCache) {
      res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate');
      return;
    }
    if (filePath && filePath.includes('/dist/images/')) {
      return;
    }
    const fileName = filePath ? path.basename(filePath) : '';

    if (
      fileName === 'index.html' ||
      fileName.endsWith('.webmanifest') ||
      fileName === 'manifest.json' ||
      fileName === 'sw.js'
    ) {
      res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate');
    } else {
      res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`);
    }
  };

  if (skipGzipScan) {
    return express.static(staticPath, {
      setHeaders,
      index: false,
    });
  } else {
    return expressStaticGzip(staticPath, {
      enableBrotli: false,
      orderPreference: ['gz'],
      setHeaders,
      index: false,
    });
  }
}

module.exports = staticCache;

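With skipGzipScan set, the function returns plain express.static and avoids the startup directory walk that express-static-gzip performs to discover precompressed .gz siblings, which matters for large image folders. A hedged usage sketch (paths are illustrative; ESM-style imports for brevity):

// Usage sketch; paths are illustrative assumptions.
import express from 'express';
import staticCache from './staticCache'; // the module defined above

const app = express();
// Client bundle: let express-static-gzip serve prebuilt .gz files
app.use(staticCache('client/dist'));
// Uploaded images: no .gz siblings expected, so skip the scan entirely
app.use('/images', staticCache('uploads/images', { skipGzipScan: true }));
app.listen(3080);
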
@@ -1,4 +1,5 @@
const { SystemRoles } = require('librechat-data-provider');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
const { updateUser, findUser } = require('~/models');
const { logger } = require('~/config');
@@ -13,17 +14,23 @@ const { isEnabled } = require('~/server/utils');
 * The strategy extracts the JWT from the Authorization header as a Bearer token.
 * The JWT is then verified using the signing key, and the user is retrieved from the database.
 */
const openIdJwtLogin = (openIdConfig) =>
  new JwtStrategy(
const openIdJwtLogin = (openIdConfig) => {
  let jwksRsaOptions = {
    cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true,
    cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME
      ? eval(process.env.OPENID_JWKS_URL_CACHE_TIME)
      : 60000,
    jwksUri: openIdConfig.serverMetadata().jwks_uri,
  };

  if (process.env.PROXY) {
    jwksRsaOptions.requestAgent = new HttpsProxyAgent(process.env.PROXY);
  }

  return new JwtStrategy(
    {
      jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
      secretOrKeyProvider: jwksRsa.passportJwtSecret({
        cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true,
        cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME
          ? eval(process.env.OPENID_JWKS_URL_CACHE_TIME)
          : 60000,
        jwksUri: openIdConfig.serverMetadata().jwks_uri,
      }),
      secretOrKeyProvider: jwksRsa.passportJwtSecret(jwksRsaOptions),
    },
    async (payload, done) => {
      try {
@@ -48,5 +55,6 @@ const openIdJwtLogin = (openIdConfig) =>
      }
    },
  );
};

module.exports = openIdJwtLogin;

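The hoisted jwksRsaOptions object exists so a requestAgent can be attached when PROXY is set; the verification logic itself is unchanged. Wiring the strategy up would look roughly like this (the strategy name 'openidJwt' is an assumption for illustration):

// Hedged sketch: registering the strategy with passport.
import passport from 'passport';

passport.use('openidJwt', openIdJwtLogin(openIdConfig));
// Protected route example (meHandler is hypothetical):
// app.get('/api/me', passport.authenticate('openidJwt', { session: false }), meHandler);
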
@@ -49,7 +49,7 @@ async function customFetch(url, options) {
    logger.info(`[openidStrategy] proxy agent configured: ${process.env.PROXY}`);
    fetchOptions = {
      ...options,
      dispatcher: new HttpsProxyAgent(process.env.PROXY),
      dispatcher: new undici.ProxyAgent(process.env.PROXY),
    };
  }

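This one-line change matters because fetch only honors undici dispatchers; an https-proxy-agent instance passed as dispatcher is silently ignored. A minimal sketch of the working pattern:

import { ProxyAgent } from 'undici';

async function fetchViaProxy(url: string): Promise<Response> {
  // undici's ProxyAgent is the dispatcher type fetch actually routes through
  const dispatcher = process.env.PROXY ? new ProxyAgent(process.env.PROXY) : undefined;
  // RequestInit does not type `dispatcher`, hence the cast
  return fetch(url, { dispatcher } as RequestInit);
}
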
@@ -1074,7 +1074,7 @@

/**
 * @exports JsonSchemaType
 * @typedef {import('librechat-data-provider').JsonSchemaType} JsonSchemaType
 * @typedef {import('@librechat/api').JsonSchemaType} JsonSchemaType
 * @memberof typedefs
 */

@@ -223,6 +223,7 @@ const xAIModels = {
  'grok-3-fast': 131072,
  'grok-3-mini': 131072,
  'grok-3-mini-fast': 131072,
  'grok-4': 256000, // 256K context
};

const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels, ...xAIModels };

@@ -386,7 +386,7 @@ describe('matchModelName', () => {
  });

  it('should return the closest matching key for gpt-4-1106 partial matches', () => {
    expect(matchModelName('something/gpt-4-1106')).toBe('gpt-4-1106');
    expect(matchModelName('gpt-4-1106/something')).toBe('gpt-4-1106');
    expect(matchModelName('gpt-4-1106-preview')).toBe('gpt-4-1106');
    expect(matchModelName('gpt-4-1106-vision-preview')).toBe('gpt-4-1106');
  });
@@ -589,6 +589,10 @@ describe('Grok Model Tests - Tokens', () => {
    expect(getModelMaxTokens('grok-3-mini-fast')).toBe(131072);
  });

  test('should return correct tokens for Grok 4 model', () => {
    expect(getModelMaxTokens('grok-4-0709')).toBe(256000);
  });

  test('should handle partial matches for Grok models with prefixes', () => {
    // Vision models should match before general models
    expect(getModelMaxTokens('xai/grok-2-vision-1212')).toBe(32768);
@@ -606,6 +610,8 @@ describe('Grok Model Tests - Tokens', () => {
    expect(getModelMaxTokens('xai/grok-3-fast')).toBe(131072);
    expect(getModelMaxTokens('xai/grok-3-mini')).toBe(131072);
    expect(getModelMaxTokens('xai/grok-3-mini-fast')).toBe(131072);
    // Grok 4 model
    expect(getModelMaxTokens('xai/grok-4-0709')).toBe(256000);
  });
});

@@ -627,6 +633,8 @@ describe('Grok Model Tests - Tokens', () => {
    expect(matchModelName('grok-3-fast')).toBe('grok-3-fast');
    expect(matchModelName('grok-3-mini')).toBe('grok-3-mini');
    expect(matchModelName('grok-3-mini-fast')).toBe('grok-3-mini-fast');
    // Grok 4 model
    expect(matchModelName('grok-4-0709')).toBe('grok-4');
  });

  test('should match Grok model variations with prefixes', () => {
@@ -646,6 +654,8 @@ describe('Grok Model Tests - Tokens', () => {
    expect(matchModelName('xai/grok-3-fast')).toBe('grok-3-fast');
    expect(matchModelName('xai/grok-3-mini')).toBe('grok-3-mini');
    expect(matchModelName('xai/grok-3-mini-fast')).toBe('grok-3-mini-fast');
    // Grok 4 model
    expect(matchModelName('xai/grok-4-0709')).toBe('grok-4');
  });
});
});

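The assertions above exercise longest-key partial matching: 'xai/grok-4-0709' resolves to the 'grok-4' entry, while an exact key such as 'grok-3-mini-fast' wins over its shorter 'grok-3-mini' prefix. A minimal sketch of that lookup against a flat token map (the real matchModelName/getModelMaxTokens handle endpoint-specific maps and more edge cases):

const maxTokensMap: Record<string, number> = {
  'grok-3-mini': 131072,
  'grok-3-mini-fast': 131072,
  'grok-4': 256000,
};

// Prefer an exact key; otherwise take the longest key contained in the name.
function matchKey(modelName: string): string | undefined {
  if (maxTokensMap[modelName] !== undefined) {
    return modelName;
  }
  const candidates = Object.keys(maxTokensMap).filter((k) => modelName.includes(k));
  return candidates.sort((a, b) => b.length - a.length)[0];
}

// matchKey('xai/grok-4-0709')       === 'grok-4'
// matchKey('xai/grok-3-mini-fast')  === 'grok-3-mini-fast' (longest match wins)
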
@@ -1,6 +1,6 @@
{
  "name": "@librechat/frontend",
  "version": "v0.7.8",
  "version": "v0.7.9-rc1",
  "description": "",
  "type": "module",
  "scripts": {
@@ -72,7 +72,7 @@
    "input-otp": "^1.4.2",
    "js-cookie": "^3.0.5",
    "librechat-data-provider": "*",
    "lodash": "^4.17.21",
    "lodash-es": "^4.17.21",
    "lucide-react": "^0.394.0",
    "match-sorter": "^6.3.4",
    "micromark-extension-llm-math": "^3.1.0",

@@ -1,9 +1,9 @@
import React, { createContext, useContext, useState } from 'react';
import { Constants, EModelEndpoint } from 'librechat-data-provider';
import type { TPlugin, AgentToolType, Action, MCP } from 'librechat-data-provider';
import type { MCP, Action, TPlugin, AgentToolType } from 'librechat-data-provider';
import type { AgentPanelContextType } from '~/common';
import { useAvailableToolsQuery, useGetActionsQuery } from '~/data-provider';
import { useLocalize } from '~/hooks';
import { useLocalize, useGetAgentsConfig } from '~/hooks';
import { Panel } from '~/common';

const AgentPanelContext = createContext<AgentPanelContextType | undefined>(undefined);
@@ -75,21 +75,25 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
    {} as Record<string, AgentToolType & { tools?: AgentToolType[] }>,
  );

  const value = {
    action,
    setAction,
  const { agentsConfig, endpointsConfig } = useGetAgentsConfig();

  const value: AgentPanelContextType = {
    mcp,
    setMcp,
    mcps,
    setMcps,
    activePanel,
    setActivePanel,
    setCurrentAgentId,
    agent_id,
    groupedTools,
    /** Query data for actions and tools */
    actions,
    tools,
    action,
    setMcp,
    actions,
    setMcps,
    agent_id,
    setAction,
    activePanel,
    groupedTools,
    agentsConfig,
    setActivePanel,
    endpointsConfig,
    setCurrentAgentId,
  };

  return <AgentPanelContext.Provider value={value}>{children}</AgentPanelContext.Provider>;

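With agentsConfig and endpointsConfig owned by the provider, panels read them from context instead of props (matching the AgentPanelProps cleanup later in this diff). Consumption typically goes through a guard hook; a hedged sketch, assuming the repo follows the usual pattern:

// Hedged sketch of the usual context-consumer guard for this provider.
export function useAgentPanelContext(): AgentPanelContextType {
  const context = useContext(AgentPanelContext);
  if (context === undefined) {
    throw new Error('useAgentPanelContext must be used within an AgentPanelProvider');
  }
  return context;
}
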
@@ -1,14 +1,25 @@
import React, { createContext, useContext } from 'react';
import { Tools, LocalStorageKeys } from 'librechat-data-provider';
import { useMCPSelect, useToolToggle, useCodeApiKeyForm, useSearchApiKeyForm } from '~/hooks';
import React, { createContext, useContext, useEffect, useRef } from 'react';
import { useSetRecoilState } from 'recoil';
import { Tools, Constants, LocalStorageKeys, AgentCapabilities } from 'librechat-data-provider';
import type { TAgentsEndpoint } from 'librechat-data-provider';
import {
  useSearchApiKeyForm,
  useGetAgentsConfig,
  useCodeApiKeyForm,
  useToolToggle,
  useMCPSelect,
} from '~/hooks';
import { useGetStartupConfig } from '~/data-provider';
import { ephemeralAgentByConvoId } from '~/store';

interface BadgeRowContextType {
  conversationId?: string | null;
  agentsConfig?: TAgentsEndpoint | null;
  mcpSelect: ReturnType<typeof useMCPSelect>;
  webSearch: ReturnType<typeof useToolToggle>;
  codeInterpreter: ReturnType<typeof useToolToggle>;
  artifacts: ReturnType<typeof useToolToggle>;
  fileSearch: ReturnType<typeof useToolToggle>;
  codeInterpreter: ReturnType<typeof useToolToggle>;
  codeApiKeyForm: ReturnType<typeof useCodeApiKeyForm>;
  searchApiKeyForm: ReturnType<typeof useSearchApiKeyForm>;
  startupConfig: ReturnType<typeof useGetStartupConfig>['data'];
@@ -26,10 +37,88 @@ export function useBadgeRowContext() {

interface BadgeRowProviderProps {
  children: React.ReactNode;
  isSubmitting?: boolean;
  conversationId?: string | null;
}

export default function BadgeRowProvider({ children, conversationId }: BadgeRowProviderProps) {
export default function BadgeRowProvider({
  children,
  isSubmitting,
  conversationId,
}: BadgeRowProviderProps) {
  const hasInitializedRef = useRef(false);
  const lastKeyRef = useRef<string>('');
  const { agentsConfig } = useGetAgentsConfig();
  const key = conversationId ?? Constants.NEW_CONVO;
  const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(key));

  /** Initialize ephemeralAgent from localStorage on mount and when conversation changes */
  useEffect(() => {
    if (isSubmitting) {
      return;
    }
    // Check if this is a new conversation or the first load
    if (!hasInitializedRef.current || lastKeyRef.current !== key) {
      hasInitializedRef.current = true;
      lastKeyRef.current = key;

      // Load all localStorage values
      const codeToggleKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${key}`;
      const webSearchToggleKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${key}`;
      const fileSearchToggleKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${key}`;
      const artifactsToggleKey = `${LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_}${key}`;

      const codeToggleValue = localStorage.getItem(codeToggleKey);
      const webSearchToggleValue = localStorage.getItem(webSearchToggleKey);
      const fileSearchToggleValue = localStorage.getItem(fileSearchToggleKey);
      const artifactsToggleValue = localStorage.getItem(artifactsToggleKey);

      const initialValues: Record<string, any> = {};

      if (codeToggleValue !== null) {
        try {
          initialValues[Tools.execute_code] = JSON.parse(codeToggleValue);
        } catch (e) {
          console.error('Failed to parse code toggle value:', e);
        }
      }

      if (webSearchToggleValue !== null) {
        try {
          initialValues[Tools.web_search] = JSON.parse(webSearchToggleValue);
        } catch (e) {
          console.error('Failed to parse web search toggle value:', e);
        }
      }

      if (fileSearchToggleValue !== null) {
        try {
          initialValues[Tools.file_search] = JSON.parse(fileSearchToggleValue);
        } catch (e) {
          console.error('Failed to parse file search toggle value:', e);
        }
      }

      if (artifactsToggleValue !== null) {
        try {
          initialValues[AgentCapabilities.artifacts] = JSON.parse(artifactsToggleValue);
        } catch (e) {
          console.error('Failed to parse artifacts toggle value:', e);
        }
      }

      // Always set values for all tools (use defaults if not in localStorage)
      // If ephemeralAgent is null, create a new object with just our tool values
      setEphemeralAgent((prev) => ({
        ...(prev || {}),
        [Tools.execute_code]: initialValues[Tools.execute_code] ?? false,
        [Tools.web_search]: initialValues[Tools.web_search] ?? false,
        [Tools.file_search]: initialValues[Tools.file_search] ?? false,
        [AgentCapabilities.artifacts]: initialValues[AgentCapabilities.artifacts] ?? false,
      }));
    }
  }, [key, isSubmitting, setEphemeralAgent]);

  /** Startup config */
  const { data: startupConfig } = useGetStartupConfig();

@@ -74,10 +163,20 @@ export default function BadgeRowProvider({ children, conversationId }: BadgeRowP
    isAuthenticated: true,
  });

  /** Artifacts hook - using a custom key since it's not a Tool but a capability */
  const artifacts = useToolToggle({
    conversationId,
    toolKey: AgentCapabilities.artifacts,
    localStorageKey: LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_,
    isAuthenticated: true,
  });

  const value: BadgeRowContextType = {
    mcpSelect,
    webSearch,
    artifacts,
    fileSearch,
    agentsConfig,
    startupConfig,
    conversationId,
    codeApiKeyForm,

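The four JSON.parse try/catch blocks in the hydration effect share one shape; a small helper is one way to keep the effect flat (a sketch, not how the repo factors it):

// Sketch: shared safe-parse for the localStorage toggle values above.
function readToggle(key: string): unknown {
  const raw = localStorage.getItem(key);
  if (raw === null) {
    return undefined;
  }
  try {
    return JSON.parse(raw);
  } catch (e) {
    console.error(`Failed to parse toggle value for ${key}:`, e);
    return undefined;
  }
}

// Usage: initialValues[Tools.execute_code] = readToggle(codeToggleKey) ?? false;
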
@@ -206,9 +206,7 @@ export type AgentPanelProps = {
  setActivePanel: React.Dispatch<React.SetStateAction<Panel>>;
  setMcp: React.Dispatch<React.SetStateAction<t.MCP | undefined>>;
  setAction: React.Dispatch<React.SetStateAction<t.Action | undefined>>;
  endpointsConfig?: t.TEndpointsConfig;
  setCurrentAgentId: React.Dispatch<React.SetStateAction<string | undefined>>;
  agentsConfig?: t.TAgentsEndpoint | null;
};

export type AgentPanelContextType = {
@@ -225,6 +223,8 @@ export type AgentPanelContextType = {
  setCurrentAgentId: React.Dispatch<React.SetStateAction<string | undefined>>;
  groupedTools?: Record<string, t.AgentToolType & { tools?: t.AgentToolType[] }>;
  agent_id?: string;
  agentsConfig?: t.TAgentsEndpoint | null;
  endpointsConfig?: t.TEndpointsConfig | null;
};

export type AgentModelPanelProps = {
@@ -336,6 +336,11 @@ export type TAskProps = {
export type TOptions = {
  editedMessageId?: string | null;
  editedText?: string | null;
  editedContent?: {
    index: number;
    text: string;
    type: 'text' | 'think';
  };
  isRegenerate?: boolean;
  isContinued?: boolean;
  isEdited?: boolean;

152 client/src/components/Chat/Input/Artifacts.tsx Normal file
@@ -0,0 +1,152 @@
import React, { memo, useState, useCallback, useMemo } from 'react';
import * as Ariakit from '@ariakit/react';
import { ArtifactModes } from 'librechat-data-provider';
import { WandSparkles, ChevronDown } from 'lucide-react';
import CheckboxButton from '~/components/ui/CheckboxButton';
import { useBadgeRowContext } from '~/Providers';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';

interface ArtifactsToggleState {
  enabled: boolean;
  mode: string;
}

function Artifacts() {
  const localize = useLocalize();
  const { artifacts } = useBadgeRowContext();
  const { toggleState, debouncedChange, isPinned } = artifacts;

  const [isPopoverOpen, setIsPopoverOpen] = useState(false);

  const currentState = useMemo<ArtifactsToggleState>(() => {
    if (typeof toggleState === 'string' && toggleState) {
      return { enabled: true, mode: toggleState };
    }
    return { enabled: false, mode: '' };
  }, [toggleState]);

  const isEnabled = currentState.enabled;
  const isShadcnEnabled = currentState.mode === ArtifactModes.SHADCNUI;
  const isCustomEnabled = currentState.mode === ArtifactModes.CUSTOM;

  const handleToggle = useCallback(() => {
    if (isEnabled) {
      debouncedChange({ value: '' });
    } else {
      debouncedChange({ value: ArtifactModes.DEFAULT });
    }
  }, [isEnabled, debouncedChange]);

  const handleShadcnToggle = useCallback(() => {
    if (isShadcnEnabled) {
      debouncedChange({ value: ArtifactModes.DEFAULT });
    } else {
      debouncedChange({ value: ArtifactModes.SHADCNUI });
    }
  }, [isShadcnEnabled, debouncedChange]);

  const handleCustomToggle = useCallback(() => {
    if (isCustomEnabled) {
      debouncedChange({ value: ArtifactModes.DEFAULT });
    } else {
      debouncedChange({ value: ArtifactModes.CUSTOM });
    }
  }, [isCustomEnabled, debouncedChange]);

  if (!isEnabled && !isPinned) {
    return null;
  }

  return (
    <div className="flex">
      <CheckboxButton
        className={cn('max-w-fit', isEnabled && 'rounded-r-none border-r-0')}
        checked={isEnabled}
        setValue={handleToggle}
        label={localize('com_ui_artifacts')}
        isCheckedClassName="border-amber-600/40 bg-amber-500/10 hover:bg-amber-700/10"
        icon={<WandSparkles className="icon-md" />}
      />

      {isEnabled && (
        <Ariakit.MenuProvider open={isPopoverOpen} setOpen={setIsPopoverOpen}>
          <Ariakit.MenuButton
            className={cn(
              'w-7 rounded-l-none rounded-r-full border-b border-l-0 border-r border-t border-border-light md:w-6',
              'border-amber-600/40 bg-amber-500/10 hover:bg-amber-700/10',
              'transition-colors',
            )}
            onClick={(e) => e.stopPropagation()}
          >
            <ChevronDown className="ml-1 h-4 w-4 text-text-secondary md:ml-0" />
          </Ariakit.MenuButton>

          <Ariakit.Menu
            gutter={8}
            className={cn(
              'animate-popover z-50 flex max-h-[300px]',
              'flex-col overflow-auto overscroll-contain rounded-xl',
              'bg-surface-secondary px-1.5 py-1 text-text-primary shadow-lg',
              'border border-border-light',
              'min-w-[250px] outline-none',
            )}
            portal
          >
            <div className="px-2 py-1.5">
              <div className="mb-2 text-xs font-medium text-text-secondary">
                {localize('com_ui_artifacts_options')}
              </div>

              {/* Include shadcn/ui Option */}
              <Ariakit.MenuItem
                hideOnClick={false}
                onClick={(event) => {
                  event.preventDefault();
                  event.stopPropagation();
                  handleShadcnToggle();
                }}
                disabled={isCustomEnabled}
                className={cn(
                  'mb-1 flex items-center justify-between rounded-lg px-2 py-2',
                  'cursor-pointer outline-none transition-colors',
                  'hover:bg-black/[0.075] dark:hover:bg-white/10',
                  'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
                  isCustomEnabled && 'cursor-not-allowed opacity-50',
                )}
              >
                <div className="flex items-center gap-2">
                  <Ariakit.MenuItemCheck checked={isShadcnEnabled} />
                  <span className="text-sm">{localize('com_ui_include_shadcnui' as any)}</span>
                </div>
              </Ariakit.MenuItem>

              {/* Custom Prompt Mode Option */}
              <Ariakit.MenuItem
                hideOnClick={false}
                onClick={(event) => {
                  event.preventDefault();
                  event.stopPropagation();
                  handleCustomToggle();
                }}
                className={cn(
                  'flex items-center justify-between rounded-lg px-2 py-2',
                  'cursor-pointer outline-none transition-colors',
                  'hover:bg-black/[0.075] dark:hover:bg-white/10',
                  'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
                )}
              >
                <div className="flex items-center gap-2">
                  <Ariakit.MenuItemCheck checked={isCustomEnabled} />
                  <span className="text-sm">{localize('com_ui_custom_prompt_mode' as any)}</span>
                </div>
              </Ariakit.MenuItem>
            </div>
          </Ariakit.Menu>
        </Ariakit.MenuProvider>
      )}
    </div>
  );
}

export default memo(Artifacts);

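Note that the toggle's persisted state doubles as its mode: an empty string means off, and any non-empty ArtifactModes value means on, which is exactly what the currentState memo decodes. A tiny sketch of that state model (behavior mirrors the memo above; any other details are assumptions):

// Sketch of the state model: one string encodes both on/off and mode.
function decodeArtifactsState(toggleState: unknown): { enabled: boolean; mode: string } {
  if (typeof toggleState === 'string' && toggleState) {
    return { enabled: true, mode: toggleState }; // e.g. ArtifactModes.SHADCNUI
  }
  return { enabled: false, mode: '' }; // '' (or a non-string) means disabled
}
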
147 client/src/components/Chat/Input/ArtifactsSubMenu.tsx Normal file
@@ -0,0 +1,147 @@
import React from 'react';
import * as Ariakit from '@ariakit/react';
import { ChevronRight, WandSparkles } from 'lucide-react';
import { ArtifactModes } from 'librechat-data-provider';
import { PinIcon } from '~/components/svg';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';

interface ArtifactsSubMenuProps {
  isArtifactsPinned: boolean;
  setIsArtifactsPinned: (value: boolean) => void;
  artifactsMode: string;
  handleArtifactsToggle: () => void;
  handleShadcnToggle: () => void;
  handleCustomToggle: () => void;
}

const ArtifactsSubMenu = ({
  isArtifactsPinned,
  setIsArtifactsPinned,
  artifactsMode,
  handleArtifactsToggle,
  handleShadcnToggle,
  handleCustomToggle,
  ...props
}: ArtifactsSubMenuProps) => {
  const localize = useLocalize();

  const menuStore = Ariakit.useMenuStore({
    focusLoop: true,
    showTimeout: 100,
    placement: 'right',
  });

  const isEnabled = artifactsMode !== '' && artifactsMode !== undefined;
  const isShadcnEnabled = artifactsMode === ArtifactModes.SHADCNUI;
  const isCustomEnabled = artifactsMode === ArtifactModes.CUSTOM;

  return (
    <Ariakit.MenuProvider store={menuStore}>
      <Ariakit.MenuItem
        {...props}
        hideOnClick={false}
        render={
          <Ariakit.MenuButton
            onClick={(e: React.MouseEvent<HTMLButtonElement>) => {
              e.stopPropagation();
              handleArtifactsToggle();
            }}
            onMouseEnter={() => {
              if (isEnabled) {
                menuStore.show();
              }
            }}
            className="flex w-full cursor-pointer items-center justify-between rounded-lg p-2 hover:bg-surface-hover"
          />
        }
      >
        <div className="flex items-center gap-2">
          <WandSparkles className="icon-md" />
          <span>{localize('com_ui_artifacts')}</span>
          {isEnabled && <ChevronRight className="ml-auto h-3 w-3" />}
        </div>
        <button
          type="button"
          onClick={(e) => {
            e.stopPropagation();
            setIsArtifactsPinned(!isArtifactsPinned);
          }}
          className={cn(
            'rounded p-1 transition-all duration-200',
            'hover:bg-surface-tertiary hover:shadow-sm',
            !isArtifactsPinned && 'text-text-secondary hover:text-text-primary',
          )}
          aria-label={isArtifactsPinned ? 'Unpin' : 'Pin'}
        >
          <div className="h-4 w-4">
            <PinIcon unpin={isArtifactsPinned} />
          </div>
        </button>
      </Ariakit.MenuItem>

      {isEnabled && (
        <Ariakit.Menu
          portal={true}
          unmountOnHide={true}
          className={cn(
            'animate-popover-left z-50 ml-3 flex min-w-[250px] flex-col rounded-xl',
            'border border-border-light bg-surface-secondary px-1.5 py-1 shadow-lg',
          )}
        >
          <div className="px-2 py-1.5">
            <div className="mb-2 text-xs font-medium text-text-secondary">
              {localize('com_ui_artifacts_options')}
            </div>

            {/* Include shadcn/ui Option */}
            <Ariakit.MenuItem
              hideOnClick={false}
              onClick={(event) => {
                event.preventDefault();
                event.stopPropagation();
                handleShadcnToggle();
              }}
              disabled={isCustomEnabled}
              className={cn(
                'mb-1 flex items-center justify-between rounded-lg px-2 py-2',
                'cursor-pointer text-text-primary outline-none transition-colors',
                'hover:bg-black/[0.075] dark:hover:bg-white/10',
                'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
                isCustomEnabled && 'cursor-not-allowed opacity-50',
              )}
            >
              <div className="flex items-center gap-2">
                <Ariakit.MenuItemCheck checked={isShadcnEnabled} />
                <span className="text-sm">{localize('com_ui_include_shadcnui' as any)}</span>
              </div>
            </Ariakit.MenuItem>

            {/* Custom Prompt Mode Option */}
            <Ariakit.MenuItem
              hideOnClick={false}
              onClick={(event) => {
                event.preventDefault();
                event.stopPropagation();
                handleCustomToggle();
              }}
              className={cn(
                'flex items-center justify-between rounded-lg px-2 py-2',
                'cursor-pointer text-text-primary outline-none transition-colors',
                'hover:bg-black/[0.075] dark:hover:bg-white/10',
                'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
              )}
            >
              <div className="flex items-center gap-2">
                <Ariakit.MenuItemCheck checked={isCustomEnabled} />
                <span className="text-sm">{localize('com_ui_custom_prompt_mode' as any)}</span>
              </div>
            </Ariakit.MenuItem>
          </div>
        </Ariakit.Menu>
      )}
    </Ariakit.MenuProvider>
  );
};

export default React.memo(ArtifactsSubMenu);

@@ -18,6 +18,7 @@ import { useChatBadges } from '~/hooks';
import { Badge } from '~/components/ui';
import ToolDialogs from './ToolDialogs';
import FileSearch from './FileSearch';
import Artifacts from './Artifacts';
import MCPSelect from './MCPSelect';
import WebSearch from './WebSearch';
import store from '~/store';
@@ -27,6 +28,7 @@ interface BadgeRowProps {
  onChange: (badges: Pick<BadgeItem, 'id'>[]) => void;
  onToggle?: (badgeId: string, currentActive: boolean) => void;
  conversationId?: string | null;
  isSubmitting?: boolean;
  isInChat: boolean;
}

@@ -140,6 +142,7 @@ const dragReducer = (state: DragState, action: DragAction): DragState => {
function BadgeRow({
  showEphemeralBadges,
  conversationId,
  isSubmitting,
  onChange,
  onToggle,
  isInChat,
@@ -317,7 +320,7 @@ function BadgeRow({
  }, [dragState.draggedBadge, handleMouseMove, handleMouseUp]);

  return (
    <BadgeRowProvider conversationId={conversationId}>
    <BadgeRowProvider conversationId={conversationId} isSubmitting={isSubmitting}>
      <div ref={containerRef} className="relative flex flex-wrap items-center gap-2">
        {showEphemeralBadges === true && <ToolsDropdown />}
        {tempBadges.map((badge, index) => (
@@ -364,6 +367,7 @@ function BadgeRow({
          <WebSearch />
          <CodeInterpreter />
          <FileSearch />
          <Artifacts />
          <MCPSelect />
        </>
      )}

@@ -305,6 +305,7 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => {
        </div>
        <BadgeRow
          showEphemeralBadges={!isAgentsEndpoint(endpoint) && !isAssistantsEndpoint(endpoint)}
          isSubmitting={isSubmitting || isSubmittingAdded}
          conversationId={conversationId}
          onChange={setBadges}
          isInChat={

@@ -4,18 +4,21 @@ import {
  supportsFiles,
  mergeFileConfig,
  isAgentsEndpoint,
  isAssistantsEndpoint,
  fileConfig as defaultFileConfig,
} from 'librechat-data-provider';
import type { EndpointFileConfig } from 'librechat-data-provider';
import { useGetFileConfig } from '~/data-provider';
import AttachFileMenu from './AttachFileMenu';
import { useChatContext } from '~/Providers';
import AttachFile from './AttachFile';

function AttachFileChat({ disableInputs }: { disableInputs: boolean }) {
  const { conversation } = useChatContext();
  const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
  const { endpoint, endpointType } = conversation ?? { endpoint: null };
  const isAgents = useMemo(() => isAgentsEndpoint(endpoint), [endpoint]);
  const isAssistants = useMemo(() => isAssistantsEndpoint(endpoint), [endpoint]);

  const { data: fileConfig = defaultFileConfig } = useGetFileConfig({
    select: (data) => mergeFileConfig(data),
@@ -25,7 +28,9 @@ function AttachFileChat({ disableInputs }: { disableInputs: boolean }) {
  const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? endpoint ?? ''] ?? false;
  const isUploadDisabled = (disableInputs || endpointFileConfig?.disabled) ?? false;

  if (isAgents || (endpointSupportsFiles && !isUploadDisabled)) {
  if (isAssistants && endpointSupportsFiles && !isUploadDisabled) {
    return <AttachFile disabled={disableInputs} />;
  } else if (isAgents || (endpointSupportsFiles && !isUploadDisabled)) {
    return (
      <AttachFileMenu
        disabled={disableInputs}
@@ -34,7 +39,6 @@ function AttachFileChat({ disableInputs }: { disableInputs: boolean }) {
      />
    );
  }

  return null;
}

@@ -2,11 +2,10 @@ import { useSetRecoilState } from 'recoil';
import * as Ariakit from '@ariakit/react';
import React, { useRef, useState, useMemo } from 'react';
import { FileSearch, ImageUpIcon, TerminalSquareIcon, FileType2Icon } from 'lucide-react';
import { EToolResources, EModelEndpoint, defaultAgentCapabilities } from 'librechat-data-provider';
import type { EndpointFileConfig } from 'librechat-data-provider';
import { useLocalize, useGetAgentsConfig, useFileHandling, useAgentCapabilities } from '~/hooks';
import { FileUpload, TooltipAnchor, DropdownPopup, AttachmentIcon } from '~/components';
import { EToolResources, EModelEndpoint } from 'librechat-data-provider';
import { useGetEndpointsQuery } from '~/data-provider';
import { useLocalize, useFileHandling } from '~/hooks';
import { ephemeralAgentByConvoId } from '~/store';
import { cn } from '~/utils';

@@ -23,20 +22,17 @@ const AttachFileMenu = ({ disabled, conversationId, endpointFileConfig }: Attach
  const [isPopoverActive, setIsPopoverActive] = useState(false);
  const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId));
  const [toolResource, setToolResource] = useState<EToolResources | undefined>();
  const { data: endpointsConfig } = useGetEndpointsQuery();
  const { handleFileChange } = useFileHandling({
    overrideEndpoint: EModelEndpoint.agents,
    overrideEndpointFileConfig: endpointFileConfig,
  });

  const { agentsConfig } = useGetAgentsConfig();
  /** TODO: Ephemeral Agent Capabilities
   * Allow defining agent capabilities on a per-endpoint basis
   * Use definition for agents endpoint for ephemeral agents
   * */
  const capabilities = useMemo(
    () => endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [],
    [endpointsConfig],
  );
  const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);

  const handleUploadClick = (isImage?: boolean) => {
    if (!inputRef.current) {
@@ -60,7 +56,7 @@ const AttachFileMenu = ({ disabled, conversationId, endpointFileConfig }: Attach
    },
  ];

  if (capabilities.includes(EToolResources.ocr)) {
  if (capabilities.ocrEnabled) {
    items.push({
      label: localize('com_ui_upload_ocr_text'),
      onClick: () => {
@@ -71,7 +67,7 @@ const AttachFileMenu = ({ disabled, conversationId, endpointFileConfig }: Attach
    });
  }

  if (capabilities.includes(EToolResources.file_search)) {
  if (capabilities.fileSearchEnabled) {
    items.push({
      label: localize('com_ui_upload_file_search'),
      onClick: () => {
@@ -83,7 +79,7 @@ const AttachFileMenu = ({ disabled, conversationId, endpointFileConfig }: Attach
    });
  }

  if (capabilities.includes(EToolResources.execute_code)) {
  if (capabilities.codeEnabled) {
    items.push({
      label: localize('com_ui_upload_code_files'),
      onClick: () => {

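The switch from capabilities.includes(...) to capabilities.ocrEnabled and friends reflects useAgentCapabilities returning derived boolean flags rather than the raw array. A hedged sketch of that derivation (the enum member names are assumptions; the real hook in ~/hooks may compute more):

import { useMemo } from 'react';
import { AgentCapabilities } from 'librechat-data-provider';

// Sketch: deriving boolean flags from a capabilities array.
function useAgentCapabilitiesSketch(capabilities: string[]) {
  return useMemo(
    () => ({
      ocrEnabled: capabilities.includes(AgentCapabilities.ocr),
      codeEnabled: capabilities.includes(AgentCapabilities.execute_code),
      fileSearchEnabled: capabilities.includes(AgentCapabilities.file_search),
      webSearchEnabled: capabilities.includes(AgentCapabilities.web_search),
      artifactsEnabled: capabilities.includes(AgentCapabilities.artifacts),
    }),
    [capabilities],
  );
}
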
@@ -1,10 +1,8 @@
import React, { useMemo } from 'react';
import { EModelEndpoint, EToolResources } from 'librechat-data-provider';
import { EToolResources, defaultAgentCapabilities } from 'librechat-data-provider';
import { FileSearch, ImageUpIcon, FileType2Icon, TerminalSquareIcon } from 'lucide-react';
import OGDialogTemplate from '~/components/ui/OGDialogTemplate';
import { useGetEndpointsQuery } from '~/data-provider';
import useLocalize from '~/hooks/useLocalize';
import { OGDialog } from '~/components/ui';
import { useLocalize, useGetAgentsConfig, useAgentCapabilities } from '~/hooks';
import { OGDialog, OGDialogTemplate } from '~/components/ui';

interface DragDropModalProps {
  onOptionSelect: (option: EToolResources | undefined) => void;
@@ -22,12 +20,12 @@ interface FileOption {

const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragDropModalProps) => {
  const localize = useLocalize();
  const { data: endpointsConfig } = useGetEndpointsQuery();
  const capabilities = useMemo(
    () => endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [],
    [endpointsConfig],
  );

  const { agentsConfig } = useGetAgentsConfig();
  /** TODO: Ephemeral Agent Capabilities
   * Allow defining agent capabilities on a per-endpoint basis
   * Use definition for agents endpoint for ephemeral agents
   * */
  const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
  const options = useMemo(() => {
    const _options: FileOption[] = [
      {
@@ -37,26 +35,26 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
        condition: files.every((file) => file.type?.startsWith('image/')),
      },
    ];
    for (const capability of capabilities) {
      if (capability === EToolResources.file_search) {
        _options.push({
          label: localize('com_ui_upload_file_search'),
          value: EToolResources.file_search,
          icon: <FileSearch className="icon-md" />,
        });
      } else if (capability === EToolResources.execute_code) {
        _options.push({
          label: localize('com_ui_upload_code_files'),
          value: EToolResources.execute_code,
          icon: <TerminalSquareIcon className="icon-md" />,
        });
      } else if (capability === EToolResources.ocr) {
        _options.push({
          label: localize('com_ui_upload_ocr_text'),
          value: EToolResources.ocr,
          icon: <FileType2Icon className="icon-md" />,
        });
      }
    if (capabilities.fileSearchEnabled) {
      _options.push({
        label: localize('com_ui_upload_file_search'),
        value: EToolResources.file_search,
        icon: <FileSearch className="icon-md" />,
      });
    }
    if (capabilities.codeEnabled) {
      _options.push({
        label: localize('com_ui_upload_code_files'),
        value: EToolResources.execute_code,
        icon: <TerminalSquareIcon className="icon-md" />,
      });
    }
    if (capabilities.ocrEnabled) {
      _options.push({
        label: localize('com_ui_upload_ocr_text'),
        value: EToolResources.ocr,
        icon: <FileType2Icon className="icon-md" />,
      });
    }

    return _options;

@@ -2,6 +2,8 @@ import { useEffect } from 'react';
import { EToolResources } from 'librechat-data-provider';
import type { ExtendedFile } from '~/common';
import { useDeleteFilesMutation } from '~/data-provider';
import { useToastContext } from '~/Providers';
import { useLocalize } from '~/hooks';
import { useFileDeletion } from '~/hooks/Files';
import FileContainer from './FileContainer';
import { logger } from '~/utils';
@@ -30,6 +32,8 @@ export default function FileRow({
  isRTL?: boolean;
  Wrapper?: React.FC<{ children: React.ReactNode }>;
}) {
  const localize = useLocalize();
  const { showToast } = useToastContext();
  const files = Array.from(_files?.values() ?? []).filter((file) =>
    fileFilter ? fileFilter(file) : true,
  );
@@ -105,6 +109,10 @@ export default function FileRow({
        )
        .uniqueFiles.map((file: ExtendedFile, index: number) => {
          const handleDelete = () => {
            showToast({
              message: localize('com_ui_deleting_file'),
              status: 'info',
            });
            if (abortUpload && file.progress < 1) {
              abortUpload();
            }

@@ -2,11 +2,18 @@ import React, { useState, useMemo, useCallback } from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { Globe, Settings, Settings2, TerminalSquareIcon } from 'lucide-react';
|
||||
import type { MenuItemProps } from '~/common';
|
||||
import { Permissions, PermissionTypes, AuthType } from 'librechat-data-provider';
|
||||
import {
|
||||
AuthType,
|
||||
Permissions,
|
||||
ArtifactModes,
|
||||
PermissionTypes,
|
||||
defaultAgentCapabilities,
|
||||
} from 'librechat-data-provider';
|
||||
import { TooltipAnchor, DropdownPopup } from '~/components';
|
||||
import { useLocalize, useHasAccess, useAgentCapabilities } from '~/hooks';
|
||||
import ArtifactsSubMenu from '~/components/Chat/Input/ArtifactsSubMenu';
|
||||
import MCPSubMenu from '~/components/Chat/Input/MCPSubMenu';
|
||||
import { PinIcon, VectorIcon } from '~/components/svg';
|
||||
import { useLocalize, useHasAccess } from '~/hooks';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
@@ -21,12 +28,17 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
|
||||
const {
|
||||
webSearch,
|
||||
mcpSelect,
|
||||
artifacts,
|
||||
fileSearch,
|
||||
agentsConfig,
|
||||
startupConfig,
|
||||
codeApiKeyForm,
|
||||
codeInterpreter,
|
||||
searchApiKeyForm,
|
||||
} = useBadgeRowContext();
|
||||
const { codeEnabled, webSearchEnabled, artifactsEnabled, fileSearchEnabled } =
|
||||
useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
|
||||
|
||||
const { setIsDialogOpen: setIsCodeDialogOpen, menuTriggerRef: codeMenuTriggerRef } =
|
||||
codeApiKeyForm;
|
||||
const { setIsDialogOpen: setIsSearchDialogOpen, menuTriggerRef: searchMenuTriggerRef } =
|
||||
@@ -42,6 +54,7 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
|
||||
authData: codeAuthData,
|
||||
} = codeInterpreter;
|
||||
const { isPinned: isFileSearchPinned, setIsPinned: setIsFileSearchPinned } = fileSearch;
|
||||
const { isPinned: isArtifactsPinned, setIsPinned: setIsArtifactsPinned } = artifacts;
|
||||
const {
|
||||
mcpValues,
|
||||
mcpServerNames,
|
||||
@@ -72,19 +85,46 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
|
||||
|
||||
const handleWebSearchToggle = useCallback(() => {
|
||||
const newValue = !webSearch.toggleState;
|
||||
webSearch.debouncedChange({ isChecked: newValue });
|
||||
webSearch.debouncedChange({ value: newValue });
|
||||
}, [webSearch]);
|
||||
|
||||
const handleCodeInterpreterToggle = useCallback(() => {
|
||||
const newValue = !codeInterpreter.toggleState;
|
||||
codeInterpreter.debouncedChange({ isChecked: newValue });
|
||||
codeInterpreter.debouncedChange({ value: newValue });
|
||||
}, [codeInterpreter]);
|
||||
|
||||
const handleFileSearchToggle = useCallback(() => {
|
||||
const newValue = !fileSearch.toggleState;
|
||||
fileSearch.debouncedChange({ isChecked: newValue });
|
||||
fileSearch.debouncedChange({ value: newValue });
|
||||
}, [fileSearch]);
|
||||
|
||||
const handleArtifactsToggle = useCallback(() => {
|
||||
const currentState = artifacts.toggleState;
|
||||
if (!currentState || currentState === '') {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.DEFAULT });
|
||||
} else {
|
||||
artifacts.debouncedChange({ value: '' });
|
||||
}
|
||||
}, [artifacts]);
|
||||
|
||||
const handleShadcnToggle = useCallback(() => {
|
||||
const currentState = artifacts.toggleState;
|
||||
if (currentState === ArtifactModes.SHADCNUI) {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.DEFAULT });
|
||||
} else {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.SHADCNUI });
|
||||
}
|
||||
}, [artifacts]);
|
||||
|
||||
const handleCustomToggle = useCallback(() => {
|
||||
const currentState = artifacts.toggleState;
|
||||
if (currentState === ArtifactModes.CUSTOM) {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.DEFAULT });
|
||||
} else {
|
||||
artifacts.debouncedChange({ value: ArtifactModes.CUSTOM });
|
||||
}
|
||||
}, [artifacts]);
|
||||
|
||||
const handleMCPToggle = useCallback(
|
||||
(serverName: string) => {
|
||||
const currentValues = mcpSelect.mcpValues ?? [];
|
@@ -98,9 +138,10 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {

  const mcpPlaceholder = startupConfig?.interface?.mcpServers?.placeholder;

  const dropdownItems = useMemo(() => {
    const items: MenuItemProps[] = [];
    items.push({
  const dropdownItems: MenuItemProps[] = [];

  if (fileSearchEnabled) {
    dropdownItems.push({
      onClick: handleFileSearchToggle,
      hideOnClick: false,
      render: (props) => (
@@ -129,159 +170,149 @@ const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
        </div>
      ),
    });
  }

  if (canUseWebSearch) {
    items.push({
      onClick: handleWebSearchToggle,
      hideOnClick: false,
      render: (props) => (
        <div {...props}>
          <div className="flex items-center gap-2">
            <Globe className="icon-md" />
            <span>{localize('com_ui_web_search')}</span>
          </div>
          <div className="flex items-center gap-1">
            {showWebSearchSettings && (
              <button
                type="button"
                onClick={(e) => {
                  e.stopPropagation();
                  setIsSearchDialogOpen(true);
                }}
                className={cn(
                  'rounded p-1 transition-all duration-200',
                  'hover:bg-surface-secondary hover:shadow-sm',
                  'text-text-secondary hover:text-text-primary',
                )}
                aria-label="Configure web search"
                ref={searchMenuTriggerRef}
              >
                <div className="h-4 w-4">
                  <Settings className="h-4 w-4" />
                </div>
              </button>
            )}
  if (canUseWebSearch && webSearchEnabled) {
    dropdownItems.push({
      onClick: handleWebSearchToggle,
      hideOnClick: false,
      render: (props) => (
        <div {...props}>
          <div className="flex items-center gap-2">
            <Globe className="icon-md" />
            <span>{localize('com_ui_web_search')}</span>
          </div>
          <div className="flex items-center gap-1">
            {showWebSearchSettings && (
              <button
                type="button"
                onClick={(e) => {
                  e.stopPropagation();
                  setIsSearchPinned(!isSearchPinned);
                  setIsSearchDialogOpen(true);
                }}
                className={cn(
                  'rounded p-1 transition-all duration-200',
                  'hover:bg-surface-secondary hover:shadow-sm',
                  !isSearchPinned && 'text-text-secondary hover:text-text-primary',
                  'text-text-secondary hover:text-text-primary',
                )}
                aria-label={isSearchPinned ? 'Unpin' : 'Pin'}
                aria-label="Configure web search"
                ref={searchMenuTriggerRef}
              >
                <div className="h-4 w-4">
                  <PinIcon unpin={isSearchPinned} />
                  <Settings className="h-4 w-4" />
                </div>
              </button>
          </div>
        </div>
      ),
    });
  }

  if (canRunCode) {
    items.push({
      onClick: handleCodeInterpreterToggle,
      hideOnClick: false,
      render: (props) => (
        <div {...props}>
          <div className="flex items-center gap-2">
            <TerminalSquareIcon className="icon-md" />
            <span>{localize('com_assistants_code_interpreter')}</span>
          </div>
          <div className="flex items-center gap-1">
            {showCodeSettings && (
              <button
                type="button"
                onClick={(e) => {
                  e.stopPropagation();
                  setIsCodeDialogOpen(true);
                }}
                ref={codeMenuTriggerRef}
                className={cn(
                  'rounded p-1 transition-all duration-200',
                  'hover:bg-surface-secondary hover:shadow-sm',
                  'text-text-secondary hover:text-text-primary',
                )}
                aria-label="Configure code interpreter"
              >
                <div className="h-4 w-4">
                  <Settings className="h-4 w-4" />
                </div>
              </button>
            )}
            <button
              type="button"
              onClick={(e) => {
                e.stopPropagation();
                setIsSearchPinned(!isSearchPinned);
              }}
              className={cn(
                'rounded p-1 transition-all duration-200',
                'hover:bg-surface-secondary hover:shadow-sm',
                !isSearchPinned && 'text-text-secondary hover:text-text-primary',
              )}
              aria-label={isSearchPinned ? 'Unpin' : 'Pin'}
            >
              <div className="h-4 w-4">
                <PinIcon unpin={isSearchPinned} />
              </div>
            </button>
          </div>
        </div>
      ),
    });
  }

  if (canRunCode && codeEnabled) {
    dropdownItems.push({
      onClick: handleCodeInterpreterToggle,
      hideOnClick: false,
      render: (props) => (
        <div {...props}>
          <div className="flex items-center gap-2">
            <TerminalSquareIcon className="icon-md" />
            <span>{localize('com_assistants_code_interpreter')}</span>
          </div>
          <div className="flex items-center gap-1">
            {showCodeSettings && (
              <button
                type="button"
                onClick={(e) => {
                  e.stopPropagation();
                  setIsCodePinned(!isCodePinned);
                  setIsCodeDialogOpen(true);
                }}
                ref={codeMenuTriggerRef}
                className={cn(
                  'rounded p-1 transition-all duration-200',
                  'hover:bg-surface-secondary hover:shadow-sm',
                  !isCodePinned && 'text-text-primary hover:text-text-primary',
                  'text-text-secondary hover:text-text-primary',
                )}
                aria-label={isCodePinned ? 'Unpin' : 'Pin'}
                aria-label="Configure code interpreter"
              >
                <div className="h-4 w-4">
                  <PinIcon unpin={isCodePinned} />
                  <Settings className="h-4 w-4" />
                </div>
              </button>
          </div>
            )}
            <button
              type="button"
              onClick={(e) => {
                e.stopPropagation();
                setIsCodePinned(!isCodePinned);
              }}
              className={cn(
                'rounded p-1 transition-all duration-200',
                'hover:bg-surface-secondary hover:shadow-sm',
                !isCodePinned && 'text-text-primary hover:text-text-primary',
              )}
              aria-label={isCodePinned ? 'Unpin' : 'Pin'}
            >
              <div className="h-4 w-4">
                <PinIcon unpin={isCodePinned} />
              </div>
            </button>
          </div>
      ),
    });
  }
        </div>
      ),
    });
  }

  if (mcpServerNames && mcpServerNames.length > 0) {
    items.push({
      hideOnClick: false,
      render: (props) => (
        <MCPSubMenu
          {...props}
          mcpValues={mcpValues}
          isMCPPinned={isMCPPinned}
          placeholder={mcpPlaceholder}
          mcpServerNames={mcpServerNames}
          setIsMCPPinned={setIsMCPPinned}
          handleMCPToggle={handleMCPToggle}
        />
      ),
    });
  }
  if (artifactsEnabled) {
    dropdownItems.push({
      hideOnClick: false,
      render: (props) => (
        <ArtifactsSubMenu
          {...props}
          isArtifactsPinned={isArtifactsPinned}
          setIsArtifactsPinned={setIsArtifactsPinned}
          artifactsMode={artifacts.toggleState as string}
          handleArtifactsToggle={handleArtifactsToggle}
          handleShadcnToggle={handleShadcnToggle}
          handleCustomToggle={handleCustomToggle}
        />
      ),
    });
  }

    return items;
  }, [
    localize,
    mcpValues,
    canRunCode,
    isMCPPinned,
    isCodePinned,
    mcpPlaceholder,
    mcpServerNames,
    isSearchPinned,
    setIsMCPPinned,
    canUseWebSearch,
    setIsCodePinned,
    handleMCPToggle,
    showCodeSettings,
    setIsSearchPinned,
    isFileSearchPinned,
    codeMenuTriggerRef,
    setIsCodeDialogOpen,
    searchMenuTriggerRef,
    showWebSearchSettings,
    setIsFileSearchPinned,
    handleWebSearchToggle,
    setIsSearchDialogOpen,
    handleFileSearchToggle,
    handleCodeInterpreterToggle,
  ]);
  if (mcpServerNames && mcpServerNames.length > 0) {
    dropdownItems.push({
      hideOnClick: false,
      render: (props) => (
        <MCPSubMenu
          {...props}
          mcpValues={mcpValues}
          isMCPPinned={isMCPPinned}
          placeholder={mcpPlaceholder}
          mcpServerNames={mcpServerNames}
          setIsMCPPinned={setIsMCPPinned}
          handleMCPToggle={handleMCPToggle}
        />
      ),
    });
  }

  const menuTrigger = (
    <TooltipAnchor
@@ -8,7 +8,7 @@ import { useBadgeRowContext } from '~/Providers';
function WebSearch() {
  const localize = useLocalize();
  const { webSearch: webSearchData, searchApiKeyForm } = useBadgeRowContext();
  const { toggleState: webSearch, debouncedChange, isPinned } = webSearchData;
  const { toggleState: webSearch, debouncedChange, isPinned, authData } = webSearchData;
  const { badgeTriggerRef } = searchApiKeyForm;

  const canUseWebSearch = useHasAccess({
@@ -21,7 +21,7 @@ function WebSearch() {
  }

  return (
    (webSearch || isPinned) && (
    (isPinned || (webSearch && authData?.authenticated)) && (
      <CheckboxButton
        ref={badgeTriggerRef}
        className="max-w-fit"
@@ -81,14 +81,23 @@ const ContentParts = memo(
    return (
      <>
        {content.map((part, idx) => {
          if (part?.type !== ContentTypes.TEXT || typeof part.text !== 'string') {
          if (!part) {
            return null;
          }
          const isTextPart =
            part?.type === ContentTypes.TEXT ||
            typeof (part as unknown as Agents.MessageContentText)?.text !== 'string';
          const isThinkPart =
            part?.type === ContentTypes.THINK ||
            typeof (part as unknown as Agents.ReasoningDeltaUpdate)?.think !== 'string';
          if (!isTextPart && !isThinkPart) {
            return null;
          }

          return (
            <EditTextPart
              index={idx}
              text={part.text}
              part={part as Agents.MessageContentText | Agents.ReasoningDeltaUpdate}
              messageId={messageId}
              isSubmitting={isSubmitting}
              enterEdit={enterEdit}
@@ -1,4 +1,4 @@
import React, { memo, useMemo, useRef, useEffect } from 'react';
import React, { memo, useMemo } from 'react';
import remarkGfm from 'remark-gfm';
import remarkMath from 'remark-math';
import supersub from 'remark-supersub';
@@ -7,167 +7,16 @@ import { useRecoilValue } from 'recoil';
import ReactMarkdown from 'react-markdown';
import rehypeHighlight from 'rehype-highlight';
import remarkDirective from 'remark-directive';
import { PermissionTypes, Permissions } from 'librechat-data-provider';
import type { Pluggable } from 'unified';
import {
  useToastContext,
  ArtifactProvider,
  CodeBlockProvider,
  useCodeBlockContext,
} from '~/Providers';
import { Citation, CompositeCitation, HighlightedText } from '~/components/Web/Citation';
import { Artifact, artifactPlugin } from '~/components/Artifacts/Artifact';
import { langSubset, preprocessLaTeX, handleDoubleClick } from '~/utils';
import CodeBlock from '~/components/Messages/Content/CodeBlock';
import useHasAccess from '~/hooks/Roles/useHasAccess';
import { ArtifactProvider, CodeBlockProvider } from '~/Providers';
import MarkdownErrorBoundary from './MarkdownErrorBoundary';
import { langSubset, preprocessLaTeX } from '~/utils';
import { unicodeCitation } from '~/components/Web';
import { useFileDownload } from '~/data-provider';
import useLocalize from '~/hooks/useLocalize';
import { code, a, p } from './MarkdownComponents';
import store from '~/store';

type TCodeProps = {
  inline?: boolean;
  className?: string;
  children: React.ReactNode;
};

export const code: React.ElementType = memo(({ className, children }: TCodeProps) => {
  const canRunCode = useHasAccess({
    permissionType: PermissionTypes.RUN_CODE,
    permission: Permissions.USE,
  });
  const match = /language-(\w+)/.exec(className ?? '');
  const lang = match && match[1];
  const isMath = lang === 'math';
  const isSingleLine = typeof children === 'string' && children.split('\n').length === 1;

  const { getNextIndex, resetCounter } = useCodeBlockContext();
  const blockIndex = useRef(getNextIndex(isMath || isSingleLine)).current;

  useEffect(() => {
    resetCounter();
  }, [children, resetCounter]);

  if (isMath) {
    return <>{children}</>;
  } else if (isSingleLine) {
    return (
      <code onDoubleClick={handleDoubleClick} className={className}>
        {children}
      </code>
    );
  } else {
    return (
      <CodeBlock
        lang={lang ?? 'text'}
        codeChildren={children}
        blockIndex={blockIndex}
        allowExecution={canRunCode}
      />
    );
  }
});

export const codeNoExecution: React.ElementType = memo(({ className, children }: TCodeProps) => {
  const match = /language-(\w+)/.exec(className ?? '');
  const lang = match && match[1];

  if (lang === 'math') {
    return children;
  } else if (typeof children === 'string' && children.split('\n').length === 1) {
    return (
      <code onDoubleClick={handleDoubleClick} className={className}>
        {children}
      </code>
    );
  } else {
    return <CodeBlock lang={lang ?? 'text'} codeChildren={children} allowExecution={false} />;
  }
});

type TAnchorProps = {
  href: string;
  children: React.ReactNode;
};

export const a: React.ElementType = memo(({ href, children }: TAnchorProps) => {
  const user = useRecoilValue(store.user);
  const { showToast } = useToastContext();
  const localize = useLocalize();

  const {
    file_id = '',
    filename = '',
    filepath,
  } = useMemo(() => {
    const pattern = new RegExp(`(?:files|outputs)/${user?.id}/([^\\s]+)`);
    const match = href.match(pattern);
    if (match && match[0]) {
      const path = match[0];
      const parts = path.split('/');
      const name = parts.pop();
      const file_id = parts.pop();
      return { file_id, filename: name, filepath: path };
    }
    return { file_id: '', filename: '', filepath: '' };
  }, [user?.id, href]);

  const { refetch: downloadFile } = useFileDownload(user?.id ?? '', file_id);
  const props: { target?: string; onClick?: React.MouseEventHandler } = { target: '_new' };

  if (!file_id || !filename) {
    return (
      <a href={href} {...props}>
        {children}
      </a>
    );
  }

  const handleDownload = async (event: React.MouseEvent<HTMLAnchorElement>) => {
    event.preventDefault();
    try {
      const stream = await downloadFile();
      if (stream.data == null || stream.data === '') {
        console.error('Error downloading file: No data found');
        showToast({
          status: 'error',
          message: localize('com_ui_download_error'),
        });
        return;
      }
      const link = document.createElement('a');
      link.href = stream.data;
      link.setAttribute('download', filename);
      document.body.appendChild(link);
      link.click();
      document.body.removeChild(link);
      window.URL.revokeObjectURL(stream.data);
    } catch (error) {
      console.error('Error downloading file:', error);
    }
  };

  props.onClick = handleDownload;
  props.target = '_blank';

  return (
    <a
      href={filepath?.startsWith('files/') ? `/api/${filepath}` : `/api/files/${filepath}`}
      {...props}
    >
      {children}
    </a>
  );
});

type TParagraphProps = {
  children: React.ReactNode;
};

export const p: React.ElementType = memo(({ children }: TParagraphProps) => {
  return <p className="mb-2 whitespace-pre-wrap">{children}</p>;
});

type TContentProps = {
  content: string;
  isLatestMessage: boolean;
@@ -219,31 +68,33 @@ const Markdown = memo(({ content = '', isLatestMessage }: TContentProps) => {
  }

  return (
    <ArtifactProvider>
      <CodeBlockProvider>
        <ReactMarkdown
          /** @ts-ignore */
          remarkPlugins={remarkPlugins}
          /* @ts-ignore */
          rehypePlugins={rehypePlugins}
          components={
            {
              code,
              a,
              p,
              artifact: Artifact,
              citation: Citation,
              'highlighted-text': HighlightedText,
              'composite-citation': CompositeCitation,
            } as {
              [nodeType: string]: React.ElementType;
    <MarkdownErrorBoundary content={content} codeExecution={true}>
      <ArtifactProvider>
        <CodeBlockProvider>
          <ReactMarkdown
            /** @ts-ignore */
            remarkPlugins={remarkPlugins}
            /* @ts-ignore */
            rehypePlugins={rehypePlugins}
            components={
              {
                code,
                a,
                p,
                artifact: Artifact,
                citation: Citation,
                'highlighted-text': HighlightedText,
                'composite-citation': CompositeCitation,
              } as {
                [nodeType: string]: React.ElementType;
              }
            }
            }
          >
            {currentContent}
          </ReactMarkdown>
        </CodeBlockProvider>
      </ArtifactProvider>
          >
        {currentContent}
      </ReactMarkdown>
    </CodeBlockProvider>
  </ArtifactProvider>
    </MarkdownErrorBoundary>
  );
});

@@ -0,0 +1,179 @@
import React, { memo, useMemo, useRef, useEffect, lazy, Suspense } from 'react';
import { useRecoilValue } from 'recoil';
import { PermissionTypes, Permissions } from 'librechat-data-provider';
import { useToastContext, useCodeBlockContext } from '~/Providers';
import CodeBlock from '~/components/Messages/Content/CodeBlock';
import useHasAccess from '~/hooks/Roles/useHasAccess';
import { useFileDownload } from '~/data-provider';
import useLocalize from '~/hooks/useLocalize';
import { handleDoubleClick } from '~/utils';
import store from '~/store';

// Loading fallback component for lazy-loaded Mermaid diagrams
const MermaidLoadingFallback = memo(() => {
  const localize = useLocalize();
  return (
    <div className="my-4 rounded-lg border border-border-light bg-surface-primary p-4 text-center text-text-secondary dark:border-border-heavy dark:bg-surface-primary-alt">
      {localize('com_ui_loading_diagram')}
    </div>
  );
});

type TCodeProps = {
  inline?: boolean;
  className?: string;
  children: React.ReactNode;
};

export const code: React.ElementType = memo(({ className, children }: TCodeProps) => {
  const canRunCode = useHasAccess({
    permissionType: PermissionTypes.RUN_CODE,
    permission: Permissions.USE,
  });
  const match = /language-(\w+)/.exec(className ?? '');
  const lang = match && match[1];
  const isMath = lang === 'math';
  const isMermaid = lang === 'mermaid';
  const isSingleLine = typeof children === 'string' && children.split('\n').length === 1;

  const { getNextIndex, resetCounter } = useCodeBlockContext();
  const blockIndex = useRef(getNextIndex(isMath || isSingleLine)).current;

  useEffect(() => {
    resetCounter();
  }, [children, resetCounter]);

  if (isMath) {
    return <>{children}</>;
  } else if (isMermaid && typeof children === 'string') {
    const SandpackMermaidDiagram = lazy(() => import('./SandpackMermaidDiagram'));
    return (
      <Suspense fallback={<MermaidLoadingFallback />}>
        <SandpackMermaidDiagram content={children} />
      </Suspense>
    );
  } else if (isSingleLine) {
    return (
      <code onDoubleClick={handleDoubleClick} className={className}>
        {children}
      </code>
    );
  } else {
    return (
      <CodeBlock
        lang={lang ?? 'text'}
        codeChildren={children}
        blockIndex={blockIndex}
        allowExecution={canRunCode}
      />
    );
  }
});

export const codeNoExecution: React.ElementType = memo(({ className, children }: TCodeProps) => {
  const match = /language-(\w+)/.exec(className ?? '');
  const lang = match && match[1];
  const isMermaid = lang === 'mermaid';

  if (lang === 'math') {
    return children;
  } else if (isMermaid && typeof children === 'string') {
    const SandpackMermaidDiagram = lazy(() => import('./SandpackMermaidDiagram'));
    return (
      <Suspense fallback={<MermaidLoadingFallback />}>
        <SandpackMermaidDiagram content={children} />
      </Suspense>
    );
  } else if (typeof children === 'string' && children.split('\n').length === 1) {
    return (
      <code onDoubleClick={handleDoubleClick} className={className}>
        {children}
      </code>
    );
  } else {
    return <CodeBlock lang={lang ?? 'text'} codeChildren={children} allowExecution={false} />;
  }
});

type TAnchorProps = {
  href: string;
  children: React.ReactNode;
};

export const a: React.ElementType = memo(({ href, children }: TAnchorProps) => {
  const user = useRecoilValue(store.user);
  const { showToast } = useToastContext();
  const localize = useLocalize();

  const {
    file_id = '',
    filename = '',
    filepath,
  } = useMemo(() => {
    const pattern = new RegExp(`(?:files|outputs)/${user?.id}/([^\\s]+)`);
    const match = href.match(pattern);
    if (match && match[0]) {
      const path = match[0];
      const parts = path.split('/');
      const name = parts.pop();
      const file_id = parts.pop();
      return { file_id, filename: name, filepath: path };
    }
    return { file_id: '', filename: '', filepath: '' };
  }, [user?.id, href]);

  const { refetch: downloadFile } = useFileDownload(user?.id ?? '', file_id);
  const props: { target?: string; onClick?: React.MouseEventHandler } = { target: '_new' };

  if (!file_id || !filename) {
    return (
      <a href={href} {...props}>
        {children}
      </a>
    );
  }

  const handleDownload = async (event: React.MouseEvent<HTMLAnchorElement>) => {
    event.preventDefault();
    try {
      const stream = await downloadFile();
      if (stream.data == null || stream.data === '') {
        console.error('Error downloading file: No data found');
        showToast({
          status: 'error',
          message: localize('com_ui_download_error'),
        });
        return;
      }
      const link = document.createElement('a');
      link.href = stream.data;
      link.setAttribute('download', filename);
      document.body.appendChild(link);
      link.click();
      document.body.removeChild(link);
      window.URL.revokeObjectURL(stream.data);
    } catch (error) {
      console.error('Error downloading file:', error);
    }
  };

  props.onClick = handleDownload;
  props.target = '_blank';

  return (
    <a
      href={filepath?.startsWith('files/') ? `/api/${filepath}` : `/api/files/${filepath}`}
      {...props}
    >
      {children}
    </a>
  );
});

type TParagraphProps = {
  children: React.ReactNode;
};

export const p: React.ElementType = memo(({ children }: TParagraphProps) => {
  return <p className="mb-2 whitespace-pre-wrap">{children}</p>;
});
@@ -0,0 +1,90 @@
import React from 'react';
import remarkGfm from 'remark-gfm';
import supersub from 'remark-supersub';
import ReactMarkdown from 'react-markdown';
import rehypeHighlight from 'rehype-highlight';
import type { PluggableList } from 'unified';
import { code, codeNoExecution, a, p } from './MarkdownComponents';
import { CodeBlockProvider } from '~/Providers';
import { langSubset } from '~/utils';

interface ErrorBoundaryState {
  hasError: boolean;
  error?: Error;
}

interface MarkdownErrorBoundaryProps {
  children: React.ReactNode;
  content: string;
  codeExecution?: boolean;
}

class MarkdownErrorBoundary extends React.Component<
  MarkdownErrorBoundaryProps,
  ErrorBoundaryState
> {
  constructor(props: MarkdownErrorBoundaryProps) {
    super(props);
    this.state = { hasError: false };
  }

  static getDerivedStateFromError(error: Error): ErrorBoundaryState {
    return { hasError: true, error };
  }

  componentDidCatch(error: Error, errorInfo: React.ErrorInfo) {
    console.error('Markdown rendering error:', error, errorInfo);
  }

  componentDidUpdate(prevProps: MarkdownErrorBoundaryProps) {
    if (prevProps.content !== this.props.content && this.state.hasError) {
      this.setState({ hasError: false, error: undefined });
    }
  }

  render() {
    if (this.state.hasError) {
      const { content, codeExecution = true } = this.props;

      const rehypePlugins: PluggableList = [
        [
          rehypeHighlight,
          {
            detect: true,
            ignoreMissing: true,
            subset: langSubset,
          },
        ],
      ];

      return (
        <CodeBlockProvider>
          <ReactMarkdown
            remarkPlugins={[
              /** @ts-ignore */
              supersub,
              remarkGfm,
            ]}
            /** @ts-ignore */
            rehypePlugins={rehypePlugins}
            components={
              {
                code: codeExecution ? code : codeNoExecution,
                a,
                p,
              } as {
                [nodeType: string]: React.ElementType;
              }
            }
          >
            {content}
          </ReactMarkdown>
        </CodeBlockProvider>
      );
    }

    return this.props.children;
  }
}

export default MarkdownErrorBoundary;
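A note on the boundary above: it re-arms itself in componentDidUpdate whenever `content` changes, so an error thrown on a half-streamed message clears as soon as new tokens arrive, and the fallback re-renders the same content with a plain ReactMarkdown pipeline. A minimal usage sketch; the RichMarkdown renderer name is hypothetical, only MarkdownErrorBoundary and its content/codeExecution props come from the file above:

import MarkdownErrorBoundary from './MarkdownErrorBoundary';

declare function RichMarkdown(props: { content: string }): JSX.Element; // hypothetical renderer

// Sketch: wrap any markdown renderer that may throw mid-stream; on error the
// boundary renders `content` with the simpler fallback pipeline defined above.
function SafeMessage({ content }: { content: string }) {
  return (
    <MarkdownErrorBoundary content={content} codeExecution={false}>
      <RichMarkdown content={content} />
    </MarkdownErrorBoundary>
  );
}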
@@ -6,8 +6,9 @@ import supersub from 'remark-supersub';
import ReactMarkdown from 'react-markdown';
import rehypeHighlight from 'rehype-highlight';
import type { PluggableList } from 'unified';
import { code, codeNoExecution, a, p } from './Markdown';
import { code, codeNoExecution, a, p } from './MarkdownComponents';
import { CodeBlockProvider, ArtifactProvider } from '~/Providers';
import MarkdownErrorBoundary from './MarkdownErrorBoundary';
import { langSubset } from '~/utils';

const MarkdownLite = memo(
@@ -25,32 +26,34 @@
    ];

    return (
      <ArtifactProvider>
        <CodeBlockProvider>
          <ReactMarkdown
            remarkPlugins={[
      <MarkdownErrorBoundary content={content} codeExecution={codeExecution}>
        <ArtifactProvider>
          <CodeBlockProvider>
            <ReactMarkdown
              remarkPlugins={[
                /** @ts-ignore */
                supersub,
                remarkGfm,
                [remarkMath, { singleDollarTextMath: false }],
              ]}
              /** @ts-ignore */
              supersub,
              remarkGfm,
              [remarkMath, { singleDollarTextMath: false }],
            ]}
            /** @ts-ignore */
            rehypePlugins={rehypePlugins}
            // linkTarget="_new"
            components={
              {
                code: codeExecution ? code : codeNoExecution,
                a,
                p,
              } as {
                [nodeType: string]: React.ElementType;
              rehypePlugins={rehypePlugins}
              // linkTarget="_new"
              components={
                {
                  code: codeExecution ? code : codeNoExecution,
                  a,
                  p,
                } as {
                  [nodeType: string]: React.ElementType;
                }
              }
            }
          >
            {content}
          </ReactMarkdown>
        </CodeBlockProvider>
      </ArtifactProvider>
            >
              {content}
            </ReactMarkdown>
          </CodeBlockProvider>
        </ArtifactProvider>
      </MarkdownErrorBoundary>
    );
  },
);

@@ -13,14 +13,25 @@ export default function MemoryArtifacts({ attachments }: { attachments?: TAttach
  const [isAnimating, setIsAnimating] = useState(false);
  const prevShowInfoRef = useRef<boolean>(showInfo);

  const memoryArtifacts = useMemo(() => {
  const { hasErrors, memoryArtifacts } = useMemo(() => {
    let hasErrors = false;
    const result: MemoryArtifact[] = [];
    for (const attachment of attachments ?? []) {

    if (!attachments || attachments.length === 0) {
      return { hasErrors, memoryArtifacts: result };
    }

    for (const attachment of attachments) {
      if (attachment?.[Tools.memory] != null) {
        result.push(attachment[Tools.memory]);

        if (!hasErrors && attachment[Tools.memory].type === 'error') {
          hasErrors = true;
        }
      }
    }
    return result;

    return { hasErrors, memoryArtifacts: result };
  }, [attachments]);

  useLayoutEffect(() => {
@@ -75,7 +86,12 @@ export default function MemoryArtifacts({ attachments }: { attachments?: TAttach
        <div className="flex items-center">
          <div className="inline-block">
            <button
              className="outline-hidden my-1 flex items-center gap-1 text-sm font-semibold text-text-secondary-alt transition-colors hover:text-text-primary"
              className={cn(
                'outline-hidden my-1 flex items-center gap-1 text-sm font-semibold transition-colors',
                hasErrors
                  ? 'text-red-500 hover:text-red-600 dark:text-red-400 dark:hover:text-red-500'
                  : 'text-text-secondary-alt hover:text-text-primary',
              )}
              type="button"
              onClick={() => setShowInfo((prev) => !prev)}
              aria-expanded={showInfo}
@@ -102,7 +118,7 @@ export default function MemoryArtifacts({ attachments }: { attachments?: TAttach
                  fill="currentColor"
                />
              </svg>
              {localize('com_ui_memory_updated')}
              {hasErrors ? localize('com_ui_memory_error') : localize('com_ui_memory_updated')}
            </button>
          </div>
        </div>
@@ -1,17 +1,47 @@
import type { MemoryArtifact } from 'librechat-data-provider';
import { useMemo } from 'react';
import { useLocalize } from '~/hooks';

export default function MemoryInfo({ memoryArtifacts }: { memoryArtifacts: MemoryArtifact[] }) {
  const localize = useLocalize();

  const { updatedMemories, deletedMemories, errorMessages } = useMemo(() => {
    const updated = memoryArtifacts.filter((art) => art.type === 'update');
    const deleted = memoryArtifacts.filter((art) => art.type === 'delete');
    const errors = memoryArtifacts.filter((art) => art.type === 'error');

    const messages = errors.map((artifact) => {
      try {
        const errorData = JSON.parse(artifact.value as string);

        if (errorData.errorType === 'already_exceeded') {
          return localize('com_ui_memory_already_exceeded', {
            tokens: errorData.tokenCount,
          });
        } else if (errorData.errorType === 'would_exceed') {
          return localize('com_ui_memory_would_exceed', {
            tokens: errorData.tokenCount,
          });
        } else {
          return localize('com_ui_memory_error');
        }
      } catch {
        return localize('com_ui_memory_error');
      }
    });

    return {
      updatedMemories: updated,
      deletedMemories: deleted,
      errorMessages: messages,
    };
  }, [memoryArtifacts, localize]);

  if (memoryArtifacts.length === 0) {
    return null;
  }

  // Group artifacts by type
  const updatedMemories = memoryArtifacts.filter((artifact) => artifact.type === 'update');
  const deletedMemories = memoryArtifacts.filter((artifact) => artifact.type === 'delete');

  if (updatedMemories.length === 0 && deletedMemories.length === 0) {
  if (updatedMemories.length === 0 && deletedMemories.length === 0 && errorMessages.length === 0) {
    return null;
  }

@@ -23,8 +53,8 @@ export default function MemoryInfo({ memoryArtifacts }: { memoryArtifacts: Memor
            {localize('com_ui_memory_updated_items')}
          </h4>
          <div className="space-y-2">
            {updatedMemories.map((artifact, index) => (
              <div key={`update-${index}`} className="rounded-lg p-3">
            {updatedMemories.map((artifact) => (
              <div key={`update-${artifact.key}`} className="rounded-lg p-3">
                <div className="mb-1 text-xs font-medium uppercase tracking-wide text-text-secondary">
                  {artifact.key}
                </div>
@@ -43,8 +73,8 @@ export default function MemoryInfo({ memoryArtifacts }: { memoryArtifacts: Memor
            {localize('com_ui_memory_deleted_items')}
          </h4>
          <div className="space-y-2">
            {deletedMemories.map((artifact, index) => (
              <div key={`delete-${index}`} className="rounded-lg p-3 opacity-60">
            {deletedMemories.map((artifact) => (
              <div key={`delete-${artifact.key}`} className="rounded-lg p-3 opacity-60">
                <div className="mb-1 text-xs font-medium uppercase tracking-wide text-text-secondary">
                  {artifact.key}
                </div>
@@ -56,6 +86,24 @@ export default function MemoryInfo({ memoryArtifacts }: { memoryArtifacts: Memor
          </div>
        </div>
      )}

      {errorMessages.length > 0 && (
        <div>
          <h4 className="mb-2 text-sm font-semibold text-red-500">
            {localize('com_ui_memory_storage_full')}
          </h4>
          <div className="space-y-2">
            {errorMessages.map((errorMessage) => (
              <div
                key={errorMessage}
                className="rounded-md bg-red-50 p-3 text-sm text-red-800 dark:bg-red-900/20 dark:text-red-400"
              >
                {errorMessage}
              </div>
            ))}
          </div>
        </div>
      )}
    </div>
  );
}

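For reference, the errorType/tokenCount fields parsed out of `artifact.value` above imply an error payload shaped roughly as follows. This is an assumption inferred from the JSON.parse branch, not a confirmed server-side schema:

// Assumed shape of a memory error artifact's serialized `value`,
// inferred from the parsing code above (not a confirmed schema).
type MemoryErrorValue = {
  errorType: 'already_exceeded' | 'would_exceed' | string;
  tokenCount: number;
};

const example = JSON.stringify({ errorType: 'would_exceed', tokenCount: 10240 });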
client/src/components/Chat/Messages/Content/MermaidDiagram.tsx (new file, 455 lines)
@@ -0,0 +1,455 @@
import React, {
  useLayoutEffect,
  useState,
  memo,
  useContext,
  useMemo,
  useCallback,
  useRef,
} from 'react';
import { useTranslation } from 'react-i18next';
import DOMPurify from 'dompurify';
import { TransformWrapper, TransformComponent, ReactZoomPanPinchRef } from 'react-zoom-pan-pinch';
import { cn } from '~/utils';
import { ThemeContext, isDark } from '~/hooks/ThemeContext';
import { ClipboardIcon, CheckIcon, ZoomIn, ZoomOut, RotateCcw } from 'lucide-react';

interface InlineMermaidProps {
  content: string;
  className?: string;
}

const InlineMermaidDiagram = memo(({ content, className }: InlineMermaidProps) => {
  const { t } = useTranslation();
  const [svgContent, setSvgContent] = useState<string>('');
  const [isRendered, setIsRendered] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [isCopied, setIsCopied] = useState(false);
  const [wasAutoCorrected, setWasAutoCorrected] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const { theme } = useContext(ThemeContext);
  const isDarkMode = isDark(theme);
  const transformRef = useRef<ReactZoomPanPinchRef>(null);

  const diagramKey = useMemo(
    () => `${content.trim()}-${isDarkMode ? 'dark' : 'light'}`,
    [content, isDarkMode],
  );

  const handleCopy = useCallback(async () => {
    try {
      await navigator.clipboard.writeText(content);
      setIsCopied(true);
      setTimeout(() => setIsCopied(false), 2000);
    } catch (err) {
      console.error('Failed to copy diagram content:', err);
    }
  }, [content]);

  const handleZoomIn = useCallback(() => {
    if (transformRef.current) {
      transformRef.current.zoomIn(0.2);
    }
  }, []);

  const handleZoomOut = useCallback(() => {
    if (transformRef.current) {
      transformRef.current.zoomOut(0.2);
    }
  }, []);

  const handleResetZoom = useCallback(() => {
    if (transformRef.current) {
      transformRef.current.resetTransform();
      transformRef.current.centerView(1, 0);
    }
  }, []);

  // Memoized to prevent re-renders when content/theme changes
  const fixCommonSyntaxIssues = useMemo(() => {
    return (text: string) => {
      let fixed = text;

      fixed = fixed.replace(/--\s+>/g, '-->');
      fixed = fixed.replace(/--\s+\|/g, '--|');
      fixed = fixed.replace(/\|\s+-->/g, '|-->');
      fixed = fixed.replace(/\[([^[\]]*)"([^[\]]*)"([^[\]]*)\]/g, '[$1$2$3]');
      fixed = fixed.replace(/subgraph([A-Za-z])/g, 'subgraph $1');

      return fixed;
    };
  }, []);

  const handleTryFix = useCallback(() => {
    const fixedContent = fixCommonSyntaxIssues(content);
    if (fixedContent !== content) {
      // Currently just copies the fixed version to clipboard
      navigator.clipboard.writeText(fixedContent).then(() => {
        setError(t('com_mermaid_fix_copied'));
      });
    }
  }, [content, fixCommonSyntaxIssues, t]);

  // Use ref to track timeout to prevent stale closures
  const timeoutRef = React.useRef<NodeJS.Timeout | null>(null);

  useLayoutEffect(() => {
    let isCancelled = false;

    if (timeoutRef.current) {
      clearTimeout(timeoutRef.current);
      timeoutRef.current = null;
    }

    // Clear previous SVG content
    setSvgContent('');

    const cleanContent = content.trim();

    setError(null);
    setWasAutoCorrected(false);
    setIsRendered(false);
    setIsLoading(false);

    if (!cleanContent) {
      setError(t('com_mermaid_error_no_content'));
      return;
    }

    // Debounce rendering to avoid flickering during rapid content changes
    timeoutRef.current = setTimeout(() => {
      if (!isCancelled) {
        renderDiagram();
      }
    }, 300);

    async function renderDiagram() {
      if (isCancelled) return;

      try {
        if (
          !cleanContent.match(
            /^(graph|flowchart|sequenceDiagram|classDiagram|stateDiagram|erDiagram|journey|gantt|pie|gitgraph|mindmap|timeline|quadrant|block-beta|sankey|xychart|gitgraph)/i,
          )
        ) {
          if (!isCancelled) {
            setError(t('com_mermaid_error_invalid_type'));
            setWasAutoCorrected(false);
          }
          return;
        }

        // Dynamic import to reduce bundle size
        setIsLoading(true);
        const mermaid = await import('mermaid').then((m) => m.default);

        if (isCancelled) {
          return;
        }

        // Initialize with error suppression to avoid console spam
        mermaid.initialize({
          startOnLoad: false,
          theme: isDarkMode ? 'dark' : 'default',
          securityLevel: 'loose',
          logLevel: 'fatal',
          flowchart: {
            useMaxWidth: true,
            htmlLabels: true,
          },
          suppressErrorRendering: true,
        });

        let result;
        let contentToRender = cleanContent;

        try {
          const id = `mermaid-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
          result = await mermaid.render(id, contentToRender);
        } catch (_renderError) {
          const fixedContent = fixCommonSyntaxIssues(cleanContent);
          if (fixedContent !== cleanContent) {
            try {
              const fixedId = `mermaid-fixed-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
              result = await mermaid.render(fixedId, fixedContent);
              contentToRender = fixedContent;
              setWasAutoCorrected(true);
            } catch (_fixedRenderError) {
              if (!isCancelled) {
                setError(t('com_mermaid_error_invalid_syntax_auto_correct'));
                setWasAutoCorrected(false);
                setIsLoading(false);
              }
              return;
            }
          } else {
            if (!isCancelled) {
              setError(t('com_mermaid_error_invalid_syntax'));
              setWasAutoCorrected(false);
              setIsLoading(false);
            }
            return;
          }
        }

        // Check if component was unmounted during async render
        if (isCancelled) {
          return;
        }

        if (result && result.svg) {
          let processedSvg = result.svg;

          // Enhance SVG for better zoom/pan interaction
          processedSvg = processedSvg.replace(
            '<svg',
            '<svg style="width: 100%; height: auto;" preserveAspectRatio="xMidYMid meet"',
          );

          // Sanitize SVG content to prevent XSS attacks
          const sanitizedSvg = DOMPurify.sanitize(processedSvg, {
            USE_PROFILES: { svg: true, svgFilters: true },
            ADD_TAGS: ['foreignObject'],
            ADD_ATTR: ['preserveAspectRatio'],
            FORBID_TAGS: ['script', 'object', 'embed', 'iframe'],
            FORBID_ATTR: ['onerror', 'onload', 'onclick'],
          });

          if (!isCancelled) {
            setSvgContent(sanitizedSvg);
            setIsRendered(true);
            setIsLoading(false);
          }
        } else {
          if (!isCancelled) {
            setError(t('com_mermaid_error_no_svg'));
            setWasAutoCorrected(false);
            setIsLoading(false);
          }
        }
      } catch (err) {
        console.error('Mermaid rendering error:', err);
        if (!isCancelled) {
          const errorMessage =
            err instanceof Error
              ? err.message
              : t('com_mermaid_error_rendering_failed', 'Failed to render diagram');
          setError(t('com_mermaid_error_rendering_failed', { '0': errorMessage }));
          setWasAutoCorrected(false);
          setIsLoading(false);
        }
      }
    }

    return () => {
      isCancelled = true;
    };
  }, [diagramKey, content, isDarkMode, fixCommonSyntaxIssues, t]);

  useLayoutEffect(() => {
    return () => {
      if (timeoutRef.current) {
        clearTimeout(timeoutRef.current);
        timeoutRef.current = null;
      }
    };
  }, []);

  if (error) {
    const fixedContent = fixCommonSyntaxIssues(content);
    const canTryFix = fixedContent !== content;

    return (
      <div
        className={cn(
          'my-4 overflow-auto rounded-lg border border-red-300 bg-red-50',
          'dark:border-red-700 dark:bg-red-900/20',
          className,
        )}
      >
        <div className="p-4 text-red-600 dark:text-red-400">
          <div className="flex items-start justify-between">
            <div className="flex-1">
              <strong>{t('com_mermaid_error')}</strong> {error}
              {canTryFix && (
                <div className={cn('mt-2 text-sm text-red-500 dark:text-red-300')}>
                  💡 {t('com_mermaid_error_fixes_detected')}
                </div>
              )}
            </div>
            <div className="ml-4 flex gap-2">
              {canTryFix && (
                <button
                  onClick={handleTryFix}
                  className={cn(
                    'rounded border px-3 py-1 text-xs transition-colors',
                    'border-blue-300 bg-blue-100 text-blue-700 hover:bg-blue-200',
                    'dark:border-blue-700 dark:bg-blue-900 dark:text-blue-300 dark:hover:bg-blue-800',
                  )}
                  title={t('com_mermaid_copy_potential_fix')}
                >
                  {t('com_mermaid_try_fix')}
                </button>
              )}
              <button
                onClick={handleCopy}
                className={cn(
                  'rounded border px-3 py-1 text-xs transition-colors',
                  'border-gray-300 bg-gray-100 text-gray-700 hover:bg-gray-200',
                  'dark:border-gray-600 dark:bg-gray-800 dark:text-gray-300 dark:hover:bg-gray-700',
                )}
                title={t('com_mermaid_copy_code')}
              >
                {isCopied ? `✓ ${t('com_mermaid_copied')}` : t('com_mermaid_copy')}
              </button>
            </div>
          </div>
        </div>
        <div className="p-4 pt-0">
          <pre className="overflow-x-auto rounded bg-gray-100 p-2 text-sm dark:bg-gray-800">
            <code className="language-mermaid">{content}</code>
          </pre>
          {canTryFix && (
            <div className="mt-3 rounded border border-blue-200 bg-blue-50 p-3 dark:border-blue-800 dark:bg-blue-950">
              <div className={cn('mb-2 text-sm font-medium text-blue-800 dark:text-blue-200')}>
                {t('com_mermaid_suggested_fix')}
              </div>
              <pre className="overflow-x-auto rounded border bg-white p-2 text-sm dark:bg-gray-800">
                <code className="language-mermaid">{fixedContent}</code>
              </pre>
            </div>
          )}
        </div>
      </div>
    );
  }

  return (
    <div
      key={diagramKey}
      className={cn(
        'relative my-4 overflow-auto rounded-lg border border-border-light bg-surface-primary',
        'dark:border-border-heavy dark:bg-surface-primary-alt',
        className,
      )}
    >
      {isRendered && wasAutoCorrected && (
        <div
          className={cn(
            'absolute left-2 top-2 z-10 rounded-md px-2 py-1 text-xs',
            'bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200',
            'border border-yellow-300 dark:border-yellow-700',
            'shadow-sm',
          )}
        >
          ✨ {t('com_mermaid_auto_fixed')}
        </div>
      )}

      {isRendered && svgContent && (
        <div className="absolute right-2 top-2 z-10 flex gap-1">
          <button
            onClick={handleZoomIn}
            className={cn(
              'rounded-md p-2 transition-all duration-200',
              'hover:bg-surface-hover active:bg-surface-active',
              'text-text-secondary hover:text-text-primary',
              'border border-border-light dark:border-border-heavy',
              'bg-surface-primary dark:bg-surface-primary-alt',
              'shadow-sm hover:shadow-md',
            )}
            title={t('com_mermaid_zoom_in')}
          >
            <ZoomIn className="h-4 w-4" />
          </button>
          <button
            onClick={handleZoomOut}
            className={cn(
              'rounded-md p-2 transition-all duration-200',
              'hover:bg-surface-hover active:bg-surface-active',
              'text-text-secondary hover:text-text-primary',
              'border border-border-light dark:border-border-heavy',
              'bg-surface-primary dark:bg-surface-primary-alt',
              'shadow-sm hover:shadow-md',
            )}
            title={t('com_mermaid_zoom_out')}
          >
            <ZoomOut className="h-4 w-4" />
          </button>
          <button
            onClick={handleResetZoom}
            className={cn(
              'rounded-md p-2 transition-all duration-200',
              'hover:bg-surface-hover active:bg-surface-active',
              'text-text-secondary hover:text-text-primary',
              'border border-border-light dark:border-border-heavy',
              'bg-surface-primary dark:bg-surface-primary-alt',
              'shadow-sm hover:shadow-md',
            )}
            title={t('com_mermaid_reset_zoom')}
          >
            <RotateCcw className="h-4 w-4" />
          </button>
          <button
            onClick={handleCopy}
            className={cn(
              'rounded-md p-2 transition-all duration-200',
              'hover:bg-surface-hover active:bg-surface-active',
              'text-text-secondary hover:text-text-primary',
              'border border-border-light dark:border-border-heavy',
              'bg-surface-primary dark:bg-surface-primary-alt',
              'shadow-sm hover:shadow-md',
            )}
            title={t('com_mermaid_copy_code')}
          >
            {isCopied ? (
              <CheckIcon className="h-4 w-4 text-green-500" />
            ) : (
              <ClipboardIcon className="h-4 w-4" />
            )}
          </button>
        </div>
      )}

      <div className="p-4">
        {(isLoading || !isRendered) && (
          <div className="animate-pulse text-center text-text-secondary">
            {t('com_mermaid_rendering')}
          </div>
        )}
        {isRendered && svgContent && (
          <TransformWrapper
            ref={transformRef}
            initialScale={1}
            minScale={0.1}
            maxScale={4}
            limitToBounds={false}
            centerOnInit={true}
            wheel={{ step: 0.1 }}
            panning={{ velocityDisabled: true }}
            alignmentAnimation={{ disabled: true }}
          >
            <TransformComponent
              wrapperStyle={{
                width: '100%',
                height: 'auto',
                minHeight: '200px',
                maxHeight: '600px',
                overflow: 'hidden',
              }}
            >
              <div
                className="mermaid-container flex min-h-[200px] items-center justify-center"
                dangerouslySetInnerHTML={{ __html: svgContent }}
              />
            </TransformComponent>
          </TransformWrapper>
        )}
      </div>
    </div>
  );
});

InlineMermaidDiagram.displayName = 'InlineMermaidDiagram';

export default InlineMermaidDiagram;
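Three techniques in the component above are worth calling out: mermaid is pulled in via a dynamic import so it stays out of the main bundle, an isCancelled flag discards results that resolve after unmount, and the produced SVG goes through DOMPurify before dangerouslySetInnerHTML. A stripped-down sketch of that render core, assuming the mermaid v10+ promise API the file already uses, where render(id, text) resolves to { svg }:

// Minimal sketch of the cancellable, lazily imported render used above.
async function renderMermaid(
  text: string,
  isCancelled: () => boolean,
): Promise<string | null> {
  const mermaid = (await import('mermaid')).default; // loaded on demand
  if (isCancelled()) {
    return null; // unmounted while the chunk loaded
  }
  mermaid.initialize({ startOnLoad: false });
  const id = `mermaid-${Date.now()}`;
  const { svg } = await mermaid.render(id, text); // throws on invalid syntax
  return isCancelled() ? null : svg;
}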
@@ -85,7 +85,7 @@ const Part = memo(

    const isToolCall =
      'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL);
    if (isToolCall && toolCall.name === Tools.execute_code) {
    if (isToolCall && toolCall.name === Tools.execute_code && toolCall.args) {
      return (
        <ExecuteCode
          args={typeof toolCall.args === 'string' ? toolCall.args : ''}
@@ -1,8 +1,9 @@
import { useRef, useEffect, useCallback, useMemo } from 'react';
import { useForm } from 'react-hook-form';
import { ContentTypes } from 'librechat-data-provider';
import { useRecoilState, useRecoilValue } from 'recoil';
import { useRef, useEffect, useCallback, useMemo } from 'react';
import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query';
import type { Agents } from 'librechat-data-provider';
import type { TEditProps } from '~/common';
import Container from '~/components/Chat/Messages/Content/Container';
import { useChatContext, useAddedChatContext } from '~/Providers';
@@ -12,18 +13,19 @@ import { useLocalize } from '~/hooks';
import store from '~/store';

const EditTextPart = ({
  text,
  part,
  index,
  messageId,
  isSubmitting,
  enterEdit,
}: Omit<TEditProps, 'message' | 'ask'> & {
}: Omit<TEditProps, 'message' | 'ask' | 'text'> & {
  index: number;
  messageId: string;
  part: Agents.MessageContentText | Agents.ReasoningDeltaUpdate;
}) => {
  const localize = useLocalize();
  const { addedIndex } = useAddedChatContext();
  const { getMessages, setMessages, conversation } = useChatContext();
  const { ask, getMessages, setMessages, conversation } = useChatContext();
  const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
    store.latestMessageFamily(addedIndex),
  );
@@ -34,15 +36,16 @@ const EditTextPart = ({
    [getMessages, messageId],
  );

  const chatDirection = useRecoilValue(store.chatDirection);

  const textAreaRef = useRef<HTMLTextAreaElement | null>(null);
  const updateMessageContentMutation = useUpdateMessageContentMutation(conversationId ?? '');

  const chatDirection = useRecoilValue(store.chatDirection).toLowerCase();
  const isRTL = chatDirection === 'rtl';
  const isRTL = chatDirection?.toLowerCase() === 'rtl';

  const { register, handleSubmit, setValue } = useForm({
    defaultValues: {
      text: text ?? '',
      text: (ContentTypes.THINK in part ? part.think : part.text) || '',
    },
  });

@@ -55,15 +58,7 @@ const EditTextPart = ({
    }
  }, []);

  /*
  const resubmitMessage = () => {
    showToast({
      status: 'warning',
      message: localize('com_warning_resubmit_unsupported'),
    });

  // const resubmitMessage = (data: { text: string }) => {
  // Not supported by AWS Bedrock
  const resubmitMessage = (data: { text: string }) => {
    const messages = getMessages();
    const parentMessage = messages?.find((msg) => msg.messageId === message?.parentMessageId);

@@ -73,17 +68,19 @@ const EditTextPart = ({
    ask(
      { ...parentMessage },
      {
        editedText: data.text,
        editedContent: {
          index,
          text: data.text,
          type: part.type,
        },
        editedMessageId: messageId,
        isRegenerate: true,
        isEdited: true,
      },
    );

    setSiblingIdx((siblingIdx ?? 0) - 1);
    enterEdit(true);
  };
  */

  const updateMessage = (data: { text: string }) => {
    const messages = getMessages();
@@ -167,13 +164,13 @@ const EditTextPart = ({
          />
        </div>
        <div className="mt-2 flex w-full justify-center text-center">
          {/* <button
          <button
            className="btn btn-primary relative mr-2"
            disabled={isSubmitting}
            onClick={handleSubmit(resubmitMessage)}
          >
            {localize('com_ui_save_submit')}
          </button> */}
          </button>
          <button
            className="btn btn-secondary relative mr-2"
            disabled={isSubmitting}
@@ -10,23 +10,23 @@ import { cn } from '~/utils';
import store from '~/store';

interface ParsedArgs {
  lang: string;
  code: string;
  lang?: string;
  code?: string;
}

export function useParseArgs(args: string): ParsedArgs {
export function useParseArgs(args?: string): ParsedArgs | null {
  return useMemo(() => {
    let parsedArgs: ParsedArgs | string = args;
    let parsedArgs: ParsedArgs | string | undefined | null = args;
    try {
      parsedArgs = JSON.parse(args);
      parsedArgs = JSON.parse(args || '');
    } catch {
      // console.error('Failed to parse args:', e);
    }
    if (typeof parsedArgs === 'object') {
      return parsedArgs;
    }
    const langMatch = args.match(/"lang"\s*:\s*"(\w+)"/);
    const codeMatch = args.match(/"code"\s*:\s*"(.+?)(?="\s*,\s*"(session_id|args)"|"\s*})/s);
    const langMatch = args?.match(/"lang"\s*:\s*"(\w+)"/);
    const codeMatch = args?.match(/"code"\s*:\s*"(.+?)(?="\s*,\s*"(session_id|args)"|"\s*})/s);

    let code = '';
    if (codeMatch) {
@@ -51,7 +51,7 @@ export default function ExecuteCode({
  attachments,
}: {
  initialProgress: number;
  args: string;
  args?: string;
  output?: string;
  attachments?: TAttachment[];
}) {
@@ -65,7 +65,7 @@ export default function ExecuteCode({
  const outputRef = useRef<string>(output);
  const prevShowCodeRef = useRef<boolean>(showCode);

  const { lang, code } = useParseArgs(args);
  const { lang, code } = useParseArgs(args) ?? ({} as ParsedArgs);
  const progress = useProgress(initialProgress);

  useEffect(() => {
@@ -144,7 +144,7 @@ export default function ExecuteCode({
          onClick={() => setShowCode((prev) => !prev)}
          inProgressText={localize('com_ui_analyzing')}
          finishedText={localize('com_ui_analyzing_finished')}
          hasInput={!!code.length}
          hasInput={!!code?.length}
          isExpanded={showCode}
        />
      </div>
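The regex fallback in useParseArgs exists because tool-call args stream in incrementally: JSON.parse fails on the prefix until the final chunk arrives, but lang and code can already be scraped out of the partial string. A sketch of that behavior on a truncated payload; the values are illustrative and the code regex here is simplified relative to the one above:

// While streaming, `args` is a JSON prefix such as:
const partial = '{"lang": "py", "code": "print(1)\\nprint(2)';
// JSON.parse(partial) throws, yet the fallback regexes still recover both fields:
const lang = partial.match(/"lang"\s*:\s*"(\w+)"/)?.[1]; // 'py'
const code = partial.match(/"code"\s*:\s*"(.*)/s)?.[1]; // raw code prefix so far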
||||
@@ -0,0 +1,289 @@
|
||||
import React, { memo, useMemo, useEffect } from 'react';
|
||||
import { SandpackPreview, SandpackProvider } from '@codesandbox/sandpack-react/unstyled';
|
||||
import dedent from 'dedent';
|
||||
import { cn } from '~/utils';
|
||||
import { sharedOptions } from '~/utils/artifacts';
|
||||
|
||||
interface SandpackMermaidDiagramProps {
|
||||
content: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
// Minimal dependencies for Mermaid only
|
||||
const mermaidDependencies = {
|
||||
mermaid: '^11.8.1',
|
||||
'react-zoom-pan-pinch': '^3.7.0',
|
||||
};
|
||||
|
||||
// Lean mermaid template with inline SVG icons
|
||||
const leanMermaidTemplate = dedent`
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import {
|
||||
TransformWrapper,
|
||||
TransformComponent,
|
||||
ReactZoomPanPinchRef,
|
||||
} from "react-zoom-pan-pinch";
|
||||
import mermaid from "mermaid";
|
||||
|
||||
// Inline SVG icons
|
||||
const ZoomInIcon = () => (
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2">
|
||||
<circle cx="11" cy="11" r="8"/>
|
||||
<path d="m21 21-4.35-4.35"/>
|
||||
<line x1="11" y1="8" x2="11" y2="14"/>
|
||||
<line x1="8" y1="11" x2="14" y2="11"/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
const ZoomOutIcon = () => (
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2">
|
||||
<circle cx="11" cy="11" r="8"/>
|
||||
<path d="m21 21-4.35-4.35"/>
|
||||
<line x1="8" y1="11" x2="14" y2="11"/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
const ResetIcon = () => (
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2">
|
||||
<polyline points="1 4 1 10 7 10"/>
|
||||
<polyline points="23 20 23 14 17 14"/>
|
||||
<path d="M20.49 9A9 9 0 0 0 5.64 5.64L1 10m22 4l-4.64 4.36A9 9 0 0 1 3.51 15"/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
interface MermaidDiagramProps {
|
||||
content: string;
|
||||
}
|
||||
|
||||
const MermaidDiagram: React.FC<MermaidDiagramProps> = ({ content }) => {
|
||||
const mermaidRef = useRef<HTMLDivElement>(null);
|
||||
const transformRef = useRef<ReactZoomPanPinchRef>(null);
|
||||
const [isRendered, setIsRendered] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
mermaid.initialize({
|
||||
startOnLoad: false,
|
||||
theme: "default",
|
||||
securityLevel: "loose",
|
||||
flowchart: {
|
||||
useMaxWidth: true,
|
||||
htmlLabels: true,
|
||||
curve: "basis",
|
||||
},
|
||||
});
|
||||
|
||||
const renderDiagram = async () => {
|
||||
if (mermaidRef.current) {
|
||||
try {
|
||||
const id = "mermaid-" + Date.now();
|
||||
const { svg } = await mermaid.render(id, content);
|
||||
mermaidRef.current.innerHTML = svg;
|
||||
|
||||
const svgElement = mermaidRef.current.querySelector("svg");
|
||||
if (svgElement) {
|
||||
svgElement.style.width = "100%";
|
||||
svgElement.style.height = "100%";
|
||||
}
|
||||
setIsRendered(true);
|
||||
setError(null);
|
||||
} catch (err) {
|
||||
console.error("Mermaid rendering error:", err);
|
||||
setError(err.message || "Failed to render diagram");
|
||||
          }
        }
      };

      renderDiagram();
    }, [content]);

    const handleZoomIn = () => {
      if (transformRef.current) {
        transformRef.current.zoomIn(0.2);
      }
    };

    const handleZoomOut = () => {
      if (transformRef.current) {
        transformRef.current.zoomOut(0.2);
      }
    };

    const handleReset = () => {
      if (transformRef.current) {
        transformRef.current.resetTransform();
        transformRef.current.centerView(1, 0);
      }
    };

    if (error) {
      return (
        <div style={{ padding: '16px', color: '#ef4444', backgroundColor: '#fee2e2', borderRadius: '8px', border: '1px solid #fecaca' }}>
          <strong>Error:</strong> {error}
        </div>
      );
    }

    return (
      <div style={{ position: 'relative', height: '100%', width: '100%', backgroundColor: '#f9fafb' }}>
        <TransformWrapper
          ref={transformRef}
          initialScale={1}
          minScale={0.1}
          maxScale={4}
          wheel={{ step: 0.1 }}
          centerOnInit={true}
        >
          <TransformComponent
            wrapperStyle={{
              width: "100%",
              height: "100%",
            }}
          >
            <div
              ref={mermaidRef}
              style={{
                display: 'flex',
                alignItems: 'center',
                justifyContent: 'center',
                minHeight: '300px',
                padding: '20px',
              }}
            />
          </TransformComponent>
        </TransformWrapper>

        {isRendered && (
          <div style={{ position: 'absolute', bottom: '8px', right: '8px', display: 'flex', gap: '8px' }}>
            <button
              onClick={handleZoomIn}
              style={{
                padding: '8px',
                backgroundColor: 'white',
                border: '1px solid #e5e7eb',
                borderRadius: '6px',
                cursor: 'pointer',
                display: 'flex',
                alignItems: 'center',
                justifyContent: 'center',
              }}
              title="Zoom in"
            >
              <ZoomInIcon />
            </button>
            <button
              onClick={handleZoomOut}
              style={{
                padding: '8px',
                backgroundColor: 'white',
                border: '1px solid #e5e7eb',
                borderRadius: '6px',
                cursor: 'pointer',
                display: 'flex',
                alignItems: 'center',
                justifyContent: 'center',
              }}
              title="Zoom out"
            >
              <ZoomOutIcon />
            </button>
            <button
              onClick={handleReset}
              style={{
                padding: '8px',
                backgroundColor: 'white',
                border: '1px solid #e5e7eb',
                borderRadius: '6px',
                cursor: 'pointer',
                display: 'flex',
                alignItems: 'center',
                justifyContent: 'center',
              }}
              title="Reset zoom"
            >
              <ResetIcon />
            </button>
          </div>
        )}
      </div>
    );
  };

  export default MermaidDiagram;
`;

const wrapLeanMermaidDiagram = (content: string) => {
  return dedent`
    import React from 'react';
    import MermaidDiagram from './MermaidDiagram';

    export default function App() {
      const content = \`${content.replace(/`/g, '\\`')}\`;
      return <MermaidDiagram content={content} />;
    }
  `;
};

const getLeanMermaidFiles = (content: string) => {
  return {
    '/App.tsx': wrapLeanMermaidDiagram(content),
    '/MermaidDiagram.tsx': leanMermaidTemplate,
  };
};

const SandpackMermaidDiagram = memo(({ content, className }: SandpackMermaidDiagramProps) => {
  const files = useMemo(() => getLeanMermaidFiles(content), [content]);
  const sandpackProps = useMemo(
    () => ({
      customSetup: {
        dependencies: mermaidDependencies,
      },
    }),
    [],
  );

  // Force iframe to respect container height
  useEffect(() => {
    const fixIframeHeight = () => {
      const container = document.querySelector('.sandpack-mermaid-diagram');
      if (container) {
        const iframe = container.querySelector('iframe');
        if (iframe && iframe.style.height && iframe.style.height !== '100%') {
          iframe.style.height = '100%';
          iframe.style.minHeight = '100%';
        }
      }
    };

    // Initial fix
    fixIframeHeight();

    // Fix on any DOM changes
    const observer = new MutationObserver(fixIframeHeight);
    const container = document.querySelector('.sandpack-mermaid-diagram');
    if (container) {
      observer.observe(container, {
        attributes: true,
        childList: true,
        subtree: true,
        attributeFilter: ['style'],
      });
    }

    return () => observer.disconnect();
  }, [content]);

  return (
    <SandpackProvider files={files} options={sharedOptions} template="react-ts" {...sandpackProps}>
      <SandpackPreview
        showOpenInCodeSandbox={false}
        showRefreshButton={false}
        showSandpackErrorOverlay={true}
      />
    </SandpackProvider>
  );
});

SandpackMermaidDiagram.displayName = 'SandpackMermaidDiagram';

export default SandpackMermaidDiagram;
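A note on the template wiring above: wrapLeanMermaidDiagram splices the raw Mermaid source into a generated template literal, so backticks in the diagram text must be escaped first or they would terminate the literal early (the replace shown does not escape ${'$'}{...} sequences, so this assumes trusted diagram input). A minimal standalone sketch of that escaping step, with a hypothetical helper name:

// Hypothetical sketch of the escaping performed by wrapLeanMermaidDiagram.
const embedDiagramSource = (diagram: string): string => {
  // Each backtick becomes \` so it survives inside the generated literal.
  const escaped = diagram.replace(/`/g, '\\`');
  return 'const content = `' + escaped + '`;';
};

// A label containing backticks survives the round trip:
console.log(embedDiagramSource('graph TD; A["a `quoted` label"] --> B'));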
@@ -0,0 +1,197 @@
import React from 'react';
import { render, screen, fireEvent } from '@testing-library/react';
import '@testing-library/jest-dom/extend-expect';
import MemoryArtifacts from '../MemoryArtifacts';
import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
import { Tools } from 'librechat-data-provider';

// Mock the localize hook
jest.mock('~/hooks', () => ({
  useLocalize: () => (key: string) => {
    const translations: Record<string, string> = {
      com_ui_memory_updated: 'Updated saved memory',
      com_ui_memory_error: 'Memory Error',
    };
    return translations[key] || key;
  },
}));

// Mock the MemoryInfo component
jest.mock('../MemoryInfo', () => ({
  __esModule: true,
  default: ({ memoryArtifacts }: { memoryArtifacts: MemoryArtifact[] }) => (
    <div data-testid="memory-info">
      {memoryArtifacts.map((artifact, index) => (
        <div key={index} data-testid={`memory-artifact-${artifact.type}`}>
          {artifact.type}: {artifact.key}
        </div>
      ))}
    </div>
  ),
}));

describe('MemoryArtifacts', () => {
  const createMemoryAttachment = (type: 'update' | 'delete' | 'error', key: string): TAttachment =>
    ({
      type: Tools.memory,
      [Tools.memory]: {
        type,
        key,
        value:
          type === 'error'
            ? JSON.stringify({ errorType: 'exceeded', tokenCount: 100 })
            : 'test value',
      } as MemoryArtifact,
    }) as TAttachment;

  describe('Error State Handling', () => {
    test('displays error styling when memory artifacts contain errors', () => {
      const attachments = [
        createMemoryAttachment('error', 'system'),
        createMemoryAttachment('update', 'memory1'),
      ];

      render(<MemoryArtifacts attachments={attachments} />);

      const button = screen.getByRole('button');
      expect(button).toHaveClass('text-red-500');
      expect(button).toHaveClass('hover:text-red-600');
      expect(button).toHaveClass('dark:text-red-400');
      expect(button).toHaveClass('dark:hover:text-red-500');
    });

    test('displays normal styling when no errors present', () => {
      const attachments = [
        createMemoryAttachment('update', 'memory1'),
        createMemoryAttachment('delete', 'memory2'),
      ];

      render(<MemoryArtifacts attachments={attachments} />);

      const button = screen.getByRole('button');
      expect(button).toHaveClass('text-text-secondary-alt');
      expect(button).toHaveClass('hover:text-text-primary');
      expect(button).not.toHaveClass('text-red-500');
    });

    test('displays error message when errors are present', () => {
      const attachments = [createMemoryAttachment('error', 'system')];

      render(<MemoryArtifacts attachments={attachments} />);

      expect(screen.getByText('Memory Error')).toBeInTheDocument();
      expect(screen.queryByText('Updated saved memory')).not.toBeInTheDocument();
    });

    test('displays normal message when no errors are present', () => {
      const attachments = [createMemoryAttachment('update', 'memory1')];

      render(<MemoryArtifacts attachments={attachments} />);

      expect(screen.getByText('Updated saved memory')).toBeInTheDocument();
      expect(screen.queryByText('Memory Error')).not.toBeInTheDocument();
    });
  });

  describe('Memory Artifacts Filtering', () => {
    test('filters and passes only memory-type attachments to MemoryInfo', () => {
      const attachments = [
        createMemoryAttachment('update', 'memory1'),
        { type: 'file' } as TAttachment, // Non-memory attachment
        createMemoryAttachment('error', 'system'),
      ];

      render(<MemoryArtifacts attachments={attachments} />);

      // Click to expand
      fireEvent.click(screen.getByRole('button'));

      // Check that only memory artifacts are passed to MemoryInfo
      expect(screen.getByTestId('memory-artifact-update')).toBeInTheDocument();
      expect(screen.getByTestId('memory-artifact-error')).toBeInTheDocument();
    });

    test('correctly identifies multiple error artifacts', () => {
      const attachments = [
        createMemoryAttachment('error', 'system1'),
        createMemoryAttachment('error', 'system2'),
        createMemoryAttachment('update', 'memory1'),
      ];

      render(<MemoryArtifacts attachments={attachments} />);

      const button = screen.getByRole('button');
      expect(button).toHaveClass('text-red-500');
      expect(screen.getByText('Memory Error')).toBeInTheDocument();
    });
  });

  describe('Collapse/Expand Functionality', () => {
    test('toggles memory info visibility on button click', () => {
      const attachments = [createMemoryAttachment('update', 'memory1')];

      render(<MemoryArtifacts attachments={attachments} />);

      // Initially collapsed
      expect(screen.queryByTestId('memory-info')).not.toBeInTheDocument();

      // Click to expand
      fireEvent.click(screen.getByRole('button'));
      expect(screen.getByTestId('memory-info')).toBeInTheDocument();

      // Click to collapse
      fireEvent.click(screen.getByRole('button'));
      expect(screen.queryByTestId('memory-info')).not.toBeInTheDocument();
    });

    test('updates aria-expanded attribute correctly', () => {
      const attachments = [createMemoryAttachment('update', 'memory1')];

      render(<MemoryArtifacts attachments={attachments} />);

      const button = screen.getByRole('button');
      expect(button).toHaveAttribute('aria-expanded', 'false');

      fireEvent.click(button);
      expect(button).toHaveAttribute('aria-expanded', 'true');
    });
  });

  describe('Edge Cases', () => {
    test('handles empty attachments array', () => {
      render(<MemoryArtifacts attachments={[]} />);
      expect(screen.queryByRole('button')).not.toBeInTheDocument();
    });

    test('handles undefined attachments', () => {
      render(<MemoryArtifacts />);
      expect(screen.queryByRole('button')).not.toBeInTheDocument();
    });

    test('handles attachments with no memory artifacts', () => {
      const attachments = [{ type: 'file' } as TAttachment, { type: 'image' } as TAttachment];

      render(<MemoryArtifacts attachments={attachments} />);
      expect(screen.queryByRole('button')).not.toBeInTheDocument();
    });

    test('handles malformed memory artifacts gracefully', () => {
      const attachments = [
        {
          type: Tools.memory,
          [Tools.memory]: {
            type: 'error',
            key: 'system',
            // Missing value
          },
        } as TAttachment,
      ];

      render(<MemoryArtifacts attachments={attachments} />);

      const button = screen.getByRole('button');
      expect(button).toHaveClass('text-red-500');
      expect(screen.getByText('Memory Error')).toBeInTheDocument();
    });
  });
});
@@ -0,0 +1,267 @@
import React from 'react';
import { render, screen } from '@testing-library/react';
import '@testing-library/jest-dom/extend-expect';
import MemoryInfo from '../MemoryInfo';
import type { MemoryArtifact } from 'librechat-data-provider';

// Mock the localize hook
jest.mock('~/hooks', () => ({
  useLocalize: () => (key: string, params?: Record<string, any>) => {
    const translations: Record<string, string> = {
      com_ui_memory_updated_items: 'Updated Memories',
      com_ui_memory_deleted_items: 'Deleted Memories',
      com_ui_memory_already_exceeded: `Memory storage already full - exceeded by ${params?.tokens || 0} tokens. Delete existing memories before adding new ones.`,
      com_ui_memory_would_exceed: `Cannot save - would exceed limit by ${params?.tokens || 0} tokens. Delete existing memories to make space.`,
      com_ui_memory_deleted: 'This memory has been deleted',
      com_ui_memory_storage_full: 'Memory Storage Full',
      com_ui_memory_error: 'Memory Error',
      com_ui_updated_successfully: 'Updated successfully',
      com_ui_none_selected: 'None selected',
    };
    return translations[key] || key;
  },
}));

describe('MemoryInfo', () => {
  const createMemoryArtifact = (
    type: 'update' | 'delete' | 'error',
    key: string,
    value?: string,
  ): MemoryArtifact => ({
    type,
    key,
    value: value || `test value for ${key}`,
  });

  describe('Error Memory Display', () => {
    test('displays error section when memory is already exceeded', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        {
          type: 'error',
          key: 'system',
          value: JSON.stringify({ errorType: 'already_exceeded', tokenCount: 150 }),
        },
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      expect(screen.getByText('Memory Storage Full')).toBeInTheDocument();
      expect(
        screen.getByText(
          'Memory storage already full - exceeded by 150 tokens. Delete existing memories before adding new ones.',
        ),
      ).toBeInTheDocument();
    });

    test('displays error when memory would exceed limit', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        {
          type: 'error',
          key: 'system',
          value: JSON.stringify({ errorType: 'would_exceed', tokenCount: 50 }),
        },
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      expect(screen.getByText('Memory Storage Full')).toBeInTheDocument();
      expect(
        screen.getByText(
          'Cannot save - would exceed limit by 50 tokens. Delete existing memories to make space.',
        ),
      ).toBeInTheDocument();
    });

    test('displays multiple error messages', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        {
          type: 'error',
          key: 'system1',
          value: JSON.stringify({ errorType: 'already_exceeded', tokenCount: 100 }),
        },
        {
          type: 'error',
          key: 'system2',
          value: JSON.stringify({ errorType: 'would_exceed', tokenCount: 25 }),
        },
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      expect(
        screen.getByText(
          'Memory storage already full - exceeded by 100 tokens. Delete existing memories before adding new ones.',
        ),
      ).toBeInTheDocument();
      expect(
        screen.getByText(
          'Cannot save - would exceed limit by 25 tokens. Delete existing memories to make space.',
        ),
      ).toBeInTheDocument();
    });

    test('applies correct styling to error messages', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        {
          type: 'error',
          key: 'system',
          value: JSON.stringify({ errorType: 'would_exceed', tokenCount: 50 }),
        },
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      const errorMessage = screen.getByText(
        'Cannot save - would exceed limit by 50 tokens. Delete existing memories to make space.',
      );
      const errorContainer = errorMessage.closest('div');

      expect(errorContainer).toHaveClass('rounded-md');
      expect(errorContainer).toHaveClass('bg-red-50');
      expect(errorContainer).toHaveClass('p-3');
      expect(errorContainer).toHaveClass('text-sm');
      expect(errorContainer).toHaveClass('text-red-800');
      expect(errorContainer).toHaveClass('dark:bg-red-900/20');
      expect(errorContainer).toHaveClass('dark:text-red-400');
    });
  });

  describe('Mixed Memory Types', () => {
    test('displays all sections when different memory types are present', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        createMemoryArtifact('update', 'memory1', 'Updated content'),
        createMemoryArtifact('delete', 'memory2'),
        {
          type: 'error',
          key: 'system',
          value: JSON.stringify({ errorType: 'would_exceed', tokenCount: 200 }),
        },
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      // Check all sections are present
      expect(screen.getByText('Updated Memories')).toBeInTheDocument();
      expect(screen.getByText('Deleted Memories')).toBeInTheDocument();
      expect(screen.getByText('Memory Storage Full')).toBeInTheDocument();

      // Check content
      expect(screen.getByText('memory1')).toBeInTheDocument();
      expect(screen.getByText('Updated content')).toBeInTheDocument();
      expect(screen.getByText('memory2')).toBeInTheDocument();
      expect(
        screen.getByText(
          'Cannot save - would exceed limit by 200 tokens. Delete existing memories to make space.',
        ),
      ).toBeInTheDocument();
    });

    test('only displays sections with content', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        {
          type: 'error',
          key: 'system',
          value: JSON.stringify({ errorType: 'already_exceeded', tokenCount: 10 }),
        },
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      // Only error section should be present
      expect(screen.getByText('Memory Storage Full')).toBeInTheDocument();
      expect(screen.queryByText('Updated Memories')).not.toBeInTheDocument();
      expect(screen.queryByText('Deleted Memories')).not.toBeInTheDocument();
    });
  });

  describe('Edge Cases', () => {
    test('handles empty memory artifacts array', () => {
      const { container } = render(<MemoryInfo memoryArtifacts={[]} />);
      expect(container.firstChild).toBeNull();
    });

    test('handles malformed error data gracefully', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        {
          type: 'error',
          key: 'system',
          value: 'invalid json',
        },
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      // Should render generic error message
      expect(screen.getByText('Memory Storage Full')).toBeInTheDocument();
      expect(screen.getByText('Memory Error')).toBeInTheDocument();
    });

    test('handles missing value in error artifact', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        {
          type: 'error',
          key: 'system',
          // value is undefined
        } as MemoryArtifact,
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      expect(screen.getByText('Memory Storage Full')).toBeInTheDocument();
      expect(screen.getByText('Memory Error')).toBeInTheDocument();
    });

    test('handles unknown errorType gracefully', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        {
          type: 'error',
          key: 'system',
          value: JSON.stringify({ errorType: 'unknown_type', tokenCount: 30 }),
        },
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      // Should show generic error message for unknown types
      expect(screen.getByText('Memory Storage Full')).toBeInTheDocument();
      expect(screen.getByText('Memory Error')).toBeInTheDocument();
    });

    test('returns null when no memories of any type exist', () => {
      const memoryArtifacts: MemoryArtifact[] = [{ type: 'unknown' as any, key: 'test' }];

      const { container } = render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);
      expect(container.firstChild).toBeNull();
    });
  });

  describe('Update and Delete Memory Display', () => {
    test('displays updated memories correctly', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        createMemoryArtifact('update', 'preferences', 'User prefers dark mode'),
        createMemoryArtifact('update', 'location', 'Lives in San Francisco'),
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      expect(screen.getByText('Updated Memories')).toBeInTheDocument();
      expect(screen.getByText('preferences')).toBeInTheDocument();
      expect(screen.getByText('User prefers dark mode')).toBeInTheDocument();
      expect(screen.getByText('location')).toBeInTheDocument();
      expect(screen.getByText('Lives in San Francisco')).toBeInTheDocument();
    });

    test('displays deleted memories correctly', () => {
      const memoryArtifacts: MemoryArtifact[] = [
        createMemoryArtifact('delete', 'old_preference'),
        createMemoryArtifact('delete', 'outdated_info'),
      ];

      render(<MemoryInfo memoryArtifacts={memoryArtifacts} />);

      expect(screen.getByText('Deleted Memories')).toBeInTheDocument();
      expect(screen.getByText('old_preference')).toBeInTheDocument();
      expect(screen.getByText('outdated_info')).toBeInTheDocument();
    });
  });
});
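The error tests above rely on the error payload being JSON serialized into the artifact's value field ({ errorType, tokenCount }), with malformed or missing JSON degrading to a generic "Memory Error" message rather than throwing. A minimal sketch of that defensive parse under those assumptions; the helper and type names are hypothetical, not part of the diff:

// Hypothetical sketch of the defensive parse the tests above exercise.
type MemoryErrorInfo = { errorType?: string; tokenCount?: number };

const parseMemoryError = (value?: string): MemoryErrorInfo => {
  if (!value) {
    return {}; // missing value -> generic error message
  }
  try {
    return JSON.parse(value) as MemoryErrorInfo;
  } catch {
    return {}; // e.g. value === 'invalid json' in the test case
  }
};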
@@ -233,9 +233,17 @@ export default function Fork({
        status: 'info',
      });
    },
    onError: () => {
    onError: (error) => {
      /** Rate limit error (429 status code) */
      const isRateLimitError =
        (error as any)?.response?.status === 429 ||
        (error as any)?.status === 429 ||
        (error as any)?.statusCode === 429;

      showToast({
        message: localize('com_ui_fork_error'),
        message: isRateLimitError
          ? localize('com_ui_fork_error_rate_limit')
          : localize('com_ui_fork_error'),
        status: 'error',
      });
    },
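The onError change above probes three places an HTTP 429 can surface, since the error shape differs between fetch-style and axios-style clients. The same check extracted as a standalone sketch; the helper signature is hypothetical, not part of the diff:

// Hypothetical extraction of the 429 check from the onError handler above.
const isRateLimit = (error: unknown): boolean => {
  const e = error as {
    response?: { status?: number }; // axios-style: error.response.status
    status?: number; // fetch/Response-style: error.status
    statusCode?: number; // Node/http-errors-style: error.statusCode
  };
  return e?.response?.status === 429 || e?.status === 429 || e?.statusCode === 429;
};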
@@ -62,6 +62,7 @@ const errorMessages = {
    const { info } = json;
    return info;
  },
  [ErrorTypes.GOOGLE_TOOL_CONFLICT]: 'com_error_google_tool_conflict',
  [ViolationTypes.BAN]:
    'Your account has been temporarily banned due to violations of our service.',
  invalid_api_key:
@@ -17,7 +17,6 @@ import {
  General,
  Chat,
  Speech,
  Beta,
  Commands,
  Data,
  Account,

@@ -233,9 +232,6 @@ export default function Settings({ open, onOpenChange }: TDialogProps) {
          <Tabs.Content value={SettingsTabValues.CHAT}>
            <Chat />
          </Tabs.Content>
          <Tabs.Content value={SettingsTabValues.BETA}>
            <Beta />
          </Tabs.Content>
          <Tabs.Content value={SettingsTabValues.COMMANDS}>
            <Commands />
          </Tabs.Content>
Some files were not shown because too many files have changed in this diff.