Compare commits

..

1 Commits

Author SHA1 Message Date
Dustin Healy
fcfb0f47f9 WIP: Responsive Segmented Controls 2025-07-07 10:22:35 -07:00
215 changed files with 3566 additions and 13353 deletions

View File

@@ -349,11 +349,6 @@ REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20
TTS_VIOLATION_SCORE=0
STT_VIOLATION_SCORE=0
FORK_VIOLATION_SCORE=0
IMPORT_VIOLATION_SCORE=0
FILE_UPLOAD_VIOLATION_SCORE=0
LOGIN_MAX=7
LOGIN_WINDOW=5
@@ -580,10 +575,6 @@ ALLOW_SHARED_LINKS_PUBLIC=true
# If you have another service in front of your LibreChat doing compression, disable express based compression here
# DISABLE_COMPRESSION=true
# If you have gzipped version of uploaded image images in the same folder, this will enable gzip scan and serving of these images
# Note: The images folder will be scanned on startup and a ma kept in memory. Be careful for large number of images.
# ENABLE_IMAGE_OUTPUT_GZIP_SCAN=true
#===================================================#
# UI #
#===================================================#
@@ -601,31 +592,11 @@ HELP_AND_FAQ_URL=https://librechat.ai
# REDIS Options #
#===============#
# Enable Redis for caching and session storage
# REDIS_URI=10.10.10.10:6379
# USE_REDIS=true
# Single Redis instance
# REDIS_URI=redis://127.0.0.1:6379
# Redis cluster (multiple nodes)
# REDIS_URI=redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003
# Redis with TLS/SSL encryption and CA certificate
# REDIS_URI=rediss://127.0.0.1:6380
# REDIS_CA=/path/to/ca-cert.pem
# Redis authentication (if required)
# REDIS_USERNAME=your_redis_username
# REDIS_PASSWORD=your_redis_password
# Redis key prefix configuration
# Use environment variable name for dynamic prefix (recommended for cloud deployments)
# REDIS_KEY_PREFIX_VAR=K_REVISION
# Or use static prefix directly
# REDIS_KEY_PREFIX=librechat
# Redis connection limits
# REDIS_MAX_LISTENERS=40
# USE_REDIS_CLUSTER=true
# REDIS_CA=/path/to/ca.crt
#==================================================#
# Others #

View File

@@ -7,7 +7,7 @@ on:
- release/*
paths:
- 'api/**'
- 'packages/**'
- 'packages/api/**'
jobs:
tests_Backend:
name: Run Backend unit tests

View File

@@ -1,32 +0,0 @@
name: Publish `@librechat/client` to NPM
on:
workflow_dispatch:
inputs:
reason:
description: 'Reason for manual trigger'
required: false
default: 'Manual publish requested'
jobs:
build-and-publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Use Node.js
uses: actions/setup-node@v4
with:
node-version: '18.x'
- name: Check if client package exists
run: |
if [ -d "packages/client" ]; then
echo "Client package directory found"
else
echo "Client package directory not found - workflow ready for future use"
exit 0
fi
- name: Placeholder for future publishing
run: echo "Client package publishing workflow is ready"

View File

@@ -8,7 +8,7 @@ on:
- release/*
paths:
- 'client/**'
- 'packages/data-provider/**'
- 'packages/**'
jobs:
tests_frontend_ubuntu:

9
.gitignore vendored
View File

@@ -125,12 +125,3 @@ helm/**/.values.yaml
# SAML Idp cert
*.cert
# AI Assistants
/.claude/
/.cursor/
/.copilot/
/.aider/
/.openai/
/.tabnine/
/.codeium

View File

@@ -108,15 +108,12 @@ class BaseClient {
/**
* Abstract method to record token usage. Subclasses must implement this method.
* If a correction to the token usage is needed, the method should return an object with the corrected token counts.
* Should only be used if `recordCollectedUsage` was not used instead.
* @param {string} [model]
* @param {number} promptTokens
* @param {number} completionTokens
* @returns {Promise<void>}
*/
async recordTokenUsage({ model, promptTokens, completionTokens }) {
async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
model,
promptTokens,
completionTokens,
});
@@ -200,10 +197,6 @@ class BaseClient {
this.currentMessages[this.currentMessages.length - 1].messageId = head;
}
if (opts.isRegenerate && responseMessageId.endsWith('_')) {
responseMessageId = crypto.randomUUID();
}
this.responseMessageId = responseMessageId;
return {
@@ -744,13 +737,9 @@ class BaseClient {
} else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = responseMessage.tokenCount;
await this.recordTokenUsage({
usage,
promptTokens,
completionTokens,
model: responseMessage.model,
});
}
await this.recordTokenUsage({ promptTokens, completionTokens, usage });
}
if (userMessagePromise) {

View File

@@ -237,9 +237,41 @@ const formatAgentMessages = (payload) => {
return messages;
};
/**
* Formats an array of messages for LangChain, making sure all content fields are strings
* @param {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} payload - The array of messages to format.
* @returns {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} - The array of formatted LangChain messages, including ToolMessages for tool calls.
*/
const formatContentStrings = (payload) => {
const messages = [];
for (const message of payload) {
if (typeof message.content === 'string') {
continue;
}
if (!Array.isArray(message.content)) {
continue;
}
// Reduce text types to a single string, ignore all other types
const content = message.content.reduce((acc, curr) => {
if (curr.type === ContentTypes.TEXT) {
return `${acc}${curr[ContentTypes.TEXT]}\n`;
}
return acc;
}, '');
message.content = content.trim();
}
return messages;
};
module.exports = {
formatMessage,
formatFromLangChain,
formatAgentMessages,
formatContentStrings,
formatLangChainMessages,
};

View File

@@ -422,46 +422,6 @@ describe('BaseClient', () => {
expect(response).toEqual(expectedResult);
});
test('should replace responseMessageId with new UUID when isRegenerate is true and messageId ends with underscore', async () => {
const mockCrypto = require('crypto');
const newUUID = 'new-uuid-1234';
jest.spyOn(mockCrypto, 'randomUUID').mockReturnValue(newUUID);
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe(newUUID);
expect(TestClient.responseMessageId).not.toBe('existing-message-id_');
mockCrypto.randomUUID.mockRestore();
});
test('should not replace responseMessageId when isRegenerate is false', async () => {
const opts = {
isRegenerate: false,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id_');
});
test('should not replace responseMessageId when it does not end with underscore', async () => {
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id');
});
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
const userMessage = 'Second message in the conversation';
const opts = {

View File

@@ -11,25 +11,17 @@ const { getFiles } = require('~/models/File');
* @param {Object} options
* @param {ServerRequest} options.req
* @param {Agent['tool_resources']} options.tool_resources
* @param {string} [options.agentId] - The agent ID for file access control
* @returns {Promise<{
* files: Array<{ file_id: string; filename: string }>,
* toolContext: string
* }>}
*/
const primeFiles = async (options) => {
const { tool_resources, req, agentId } = options;
const { tool_resources } = options;
const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
const agentResourceIds = new Set(file_ids);
const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
const dbFiles = (
(await getFiles(
{ file_id: { $in: file_ids } },
null,
{ text: 0 },
{ userId: req?.user?.id, agentId },
)) ?? []
).concat(resourceFiles);
const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;

View File

@@ -1,9 +1,14 @@
const { mcpToolPattern } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator');
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
const {
Tools,
EToolResources,
loadWebSearchAuth,
replaceSpecialVars,
} = require('librechat-data-provider');
const {
availableTools,
manifestToolMap,
@@ -230,7 +235,7 @@ const loadTools = async ({
/** @type {Record<string, string>} */
const toolContextMap = {};
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
const appTools = (await getCachedTools({ includeGlobal: true })) ?? {};
for (const tool of tools) {
if (tool === Tools.execute_code) {
@@ -240,13 +245,7 @@ const loadTools = async ({
authFields: [EnvVar.CODE_API_KEY],
});
const codeApiKey = authValues[EnvVar.CODE_API_KEY];
const { files, toolContext } = await primeCodeFiles(
{
...options,
agentId: agent?.id,
},
codeApiKey,
);
const { files, toolContext } = await primeCodeFiles(options, codeApiKey);
if (toolContext) {
toolContextMap[tool] = toolContext;
}
@@ -261,10 +260,7 @@ const loadTools = async ({
continue;
} else if (tool === Tools.file_search) {
requestedTools[tool] = async () => {
const { files, toolContext } = await primeSearchFiles({
...options,
agentId: agent?.id,
});
const { files, toolContext } = await primeSearchFiles(options);
if (toolContext) {
toolContextMap[tool] = toolContext;
}
@@ -298,7 +294,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
});
};
continue;
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
} else if (tool && appTools[tool] && mcpToolPattern.test(tool)) {
requestedTools[tool] = async () =>
createMCPTool({
req: options.req,

View File

@@ -1,33 +0,0 @@
const fs = require('fs');
const { math, isEnabled } = require('@librechat/api');
// To ensure that different deployments do not interfere with each other's cache, we use a prefix for the Redis keys.
// This prefix is usually the deployment ID, which is often passed to the container or pod as an env var.
// Set REDIS_KEY_PREFIX_VAR to the env var that contains the deployment ID.
const REDIS_KEY_PREFIX_VAR = process.env.REDIS_KEY_PREFIX_VAR;
const REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX;
if (REDIS_KEY_PREFIX_VAR && REDIS_KEY_PREFIX) {
throw new Error('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
}
const USE_REDIS = isEnabled(process.env.USE_REDIS);
if (USE_REDIS && !process.env.REDIS_URI) {
throw new Error('USE_REDIS is enabled but REDIS_URI is not set.');
}
const cacheConfig = {
USE_REDIS,
REDIS_URI: process.env.REDIS_URI,
REDIS_USERNAME: process.env.REDIS_USERNAME,
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
CI: isEnabled(process.env.CI),
DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
BAN_DURATION: math(process.env.BAN_DURATION, 7200000), // 2 hours
};
module.exports = { cacheConfig };

View File

@@ -1,108 +0,0 @@
const fs = require('fs');
describe('cacheConfig', () => {
let originalEnv;
let originalReadFileSync;
beforeEach(() => {
originalEnv = { ...process.env };
originalReadFileSync = fs.readFileSync;
// Clear all related env vars first
delete process.env.REDIS_URI;
delete process.env.REDIS_CA;
delete process.env.REDIS_KEY_PREFIX_VAR;
delete process.env.REDIS_KEY_PREFIX;
delete process.env.USE_REDIS;
// Clear require cache
jest.resetModules();
});
afterEach(() => {
process.env = originalEnv;
fs.readFileSync = originalReadFileSync;
jest.resetModules();
});
describe('REDIS_KEY_PREFIX validation and resolution', () => {
test('should throw error when both REDIS_KEY_PREFIX_VAR and REDIS_KEY_PREFIX are set', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.REDIS_KEY_PREFIX = 'manual-prefix';
expect(() => {
require('./cacheConfig');
}).toThrow('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
});
test('should resolve REDIS_KEY_PREFIX from variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.DEPLOYMENT_ID = 'test-deployment-123';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('test-deployment-123');
});
test('should use direct REDIS_KEY_PREFIX value', () => {
process.env.REDIS_KEY_PREFIX = 'direct-prefix';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('direct-prefix');
});
test('should default to empty string when no prefix is configured', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle empty variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'EMPTY_VAR';
process.env.EMPTY_VAR = '';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle undefined variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'UNDEFINED_VAR';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
});
describe('USE_REDIS and REDIS_URI validation', () => {
test('should throw error when USE_REDIS is enabled but REDIS_URI is not set', () => {
process.env.USE_REDIS = 'true';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
test('should not throw error when USE_REDIS is enabled and REDIS_URI is set', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
expect(() => {
require('./cacheConfig');
}).not.toThrow();
});
test('should handle empty REDIS_URI when USE_REDIS is enabled', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = '';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
});
describe('REDIS_CA file reading', () => {
test('should be null when REDIS_CA is not set', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_CA).toBeNull();
});
});
});

View File

@@ -1,66 +0,0 @@
const KeyvRedis = require('@keyv/redis').default;
const { Keyv } = require('keyv');
const { cacheConfig } = require('./cacheConfig');
const { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } = require('./redisClients');
const { Time } = require('librechat-data-provider');
const { RedisStore: ConnectRedis } = require('connect-redis');
const MemoryStore = require('memorystore')(require('express-session'));
const { violationFile } = require('./keyvFiles');
const { RedisStore } = require('rate-limit-redis');
/**
* Creates a cache instance using Redis or a fallback store. Suitable for general caching needs.
* @param {string} namespace - The cache namespace.
* @param {number} [ttl] - Time to live for cache entries.
* @param {object} [fallbackStore] - Optional fallback store if Redis is not used.
* @returns {Keyv} Cache instance.
*/
const standardCache = (namespace, ttl = undefined, fallbackStore = undefined) => {
if (cacheConfig.USE_REDIS) {
const keyvRedis = new KeyvRedis(keyvRedisClient);
const cache = new Keyv(keyvRedis, { namespace, ttl });
keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
return cache;
}
if (fallbackStore) return new Keyv({ store: fallbackStore, namespace, ttl });
return new Keyv({ namespace, ttl });
};
/**
* Creates a cache instance for storing violation data.
* Uses a file-based fallback store if Redis is not enabled.
* @param {string} namespace - The cache namespace for violations.
* @param {number} [ttl] - Time to live for cache entries.
* @returns {Keyv} Cache instance for violations.
*/
const violationCache = (namespace, ttl = undefined) => {
return standardCache(`violations:${namespace}`, ttl, violationFile);
};
/**
* Creates a session cache instance using Redis or in-memory store.
* @param {string} namespace - The session namespace.
* @param {number} [ttl] - Time to live for session entries.
* @returns {MemoryStore | ConnectRedis} Session store instance.
*/
const sessionCache = (namespace, ttl = undefined) => {
namespace = namespace.endsWith(':') ? namespace : `${namespace}:`;
if (!cacheConfig.USE_REDIS) return new MemoryStore({ ttl, checkPeriod: Time.ONE_DAY });
return new ConnectRedis({ client: ioredisClient, ttl, prefix: namespace });
};
/**
* Creates a rate limiter cache using Redis.
* @param {string} prefix - The key prefix for rate limiting.
* @returns {RedisStore|undefined} RedisStore instance or undefined if Redis is not used.
*/
const limiterCache = (prefix) => {
if (!prefix) throw new Error('prefix is required');
if (!cacheConfig.USE_REDIS) return undefined;
prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;
return new RedisStore({ sendCommand, prefix });
};
const sendCommand = (...args) => ioredisClient?.call(...args);
module.exports = { standardCache, sessionCache, violationCache, limiterCache };

View File

@@ -1,270 +0,0 @@
const { Time } = require('librechat-data-provider');
// Mock dependencies first
const mockKeyvRedis = {
namespace: '',
keyPrefixSeparator: '',
};
const mockKeyv = jest.fn().mockReturnValue({ mock: 'keyv' });
const mockConnectRedis = jest.fn().mockReturnValue({ mock: 'connectRedis' });
const mockMemoryStore = jest.fn().mockReturnValue({ mock: 'memoryStore' });
const mockRedisStore = jest.fn().mockReturnValue({ mock: 'redisStore' });
const mockIoredisClient = {
call: jest.fn(),
};
const mockKeyvRedisClient = {};
const mockViolationFile = {};
// Mock modules before requiring the main module
jest.mock('@keyv/redis', () => ({
default: jest.fn().mockImplementation(() => mockKeyvRedis),
}));
jest.mock('keyv', () => ({
Keyv: mockKeyv,
}));
jest.mock('./cacheConfig', () => ({
cacheConfig: {
USE_REDIS: false,
REDIS_KEY_PREFIX: 'test',
},
}));
jest.mock('./redisClients', () => ({
keyvRedisClient: mockKeyvRedisClient,
ioredisClient: mockIoredisClient,
GLOBAL_PREFIX_SEPARATOR: '::',
}));
jest.mock('./keyvFiles', () => ({
violationFile: mockViolationFile,
}));
jest.mock('connect-redis', () => ({ RedisStore: mockConnectRedis }));
jest.mock('memorystore', () => jest.fn(() => mockMemoryStore));
jest.mock('rate-limit-redis', () => ({
RedisStore: mockRedisStore,
}));
// Import after mocking
const { standardCache, sessionCache, violationCache, limiterCache } = require('./cacheFactory');
const { cacheConfig } = require('./cacheConfig');
describe('cacheFactory', () => {
beforeEach(() => {
jest.clearAllMocks();
// Reset cache config mock
cacheConfig.USE_REDIS = false;
cacheConfig.REDIS_KEY_PREFIX = 'test';
});
describe('redisCache', () => {
it('should create Redis cache when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
expect(mockKeyvRedis.namespace).toBe(cacheConfig.REDIS_KEY_PREFIX);
expect(mockKeyvRedis.keyPrefixSeparator).toBe('::');
});
it('should create Redis cache with undefined ttl when not provided', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
standardCache(namespace);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl: undefined });
});
it('should use fallback store when USE_REDIS is false and fallbackStore is provided', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
const fallbackStore = { some: 'store' };
standardCache(namespace, ttl, fallbackStore);
expect(mockKeyv).toHaveBeenCalledWith({ store: fallbackStore, namespace, ttl });
});
it('should create default Keyv instance when USE_REDIS is false and no fallbackStore', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(mockKeyv).toHaveBeenCalledWith({ namespace, ttl });
});
it('should handle namespace and ttl as undefined', () => {
cacheConfig.USE_REDIS = false;
standardCache();
expect(mockKeyv).toHaveBeenCalledWith({ namespace: undefined, ttl: undefined });
});
});
describe('violationCache', () => {
it('should create violation cache with prefixed namespace', () => {
const namespace = 'test-violations';
const ttl = 7200;
// We can't easily mock the internal redisCache call since it's in the same module
// But we can test that the function executes without throwing
expect(() => violationCache(namespace, ttl)).not.toThrow();
});
it('should create violation cache with undefined ttl', () => {
const namespace = 'test-violations';
violationCache(namespace);
// The function should call redisCache with violations: prefixed namespace
// Since we can't easily mock the internal redisCache call, we test the behavior
expect(() => violationCache(namespace)).not.toThrow();
});
it('should handle undefined namespace', () => {
expect(() => violationCache(undefined)).not.toThrow();
});
});
describe('sessionCache', () => {
it('should return MemoryStore when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockMemoryStore).toHaveBeenCalledWith({ ttl, checkPeriod: Time.ONE_DAY });
expect(result).toBe(mockMemoryStore());
});
it('should return ConnectRedis when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl,
prefix: `${namespace}:`,
});
expect(result).toBe(mockConnectRedis());
});
it('should add colon to namespace if not present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should not add colon to namespace if already present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions:';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should handle undefined ttl', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockMemoryStore).toHaveBeenCalledWith({
ttl: undefined,
checkPeriod: Time.ONE_DAY,
});
});
});
describe('limiterCache', () => {
it('should return undefined when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const result = limiterCache('prefix');
expect(result).toBeUndefined();
});
it('should return RedisStore when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const result = limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: `rate-limit:`,
});
expect(result).toBe(mockRedisStore());
});
it('should add colon to prefix if not present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should not add colon to prefix if already present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit:');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should pass sendCommand function that calls ioredisClient.call', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
const sendCommandCall = mockRedisStore.mock.calls[0][0];
const sendCommand = sendCommandCall.sendCommand;
// Test that sendCommand properly delegates to ioredisClient.call
const args = ['GET', 'test-key'];
sendCommand(...args);
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
});
it('should handle undefined prefix', () => {
cacheConfig.USE_REDIS = true;
expect(() => limiterCache()).toThrow('prefix is required');
});
});
});

View File

@@ -1,52 +1,113 @@
const { cacheConfig } = require('./cacheConfig');
const { Keyv } = require('keyv');
const { isEnabled, math } = require('@librechat/api');
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
const { logFile } = require('./keyvFiles');
const { logFile, violationFile } = require('./keyvFiles');
const keyvRedis = require('./keyvRedis');
const keyvMongo = require('./keyvMongo');
const { standardCache, sessionCache, violationCache } = require('./cacheFactory');
const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE, CI } = process.env ?? {};
const duration = math(BAN_DURATION, 7200000);
const isRedisEnabled = isEnabled(USE_REDIS);
const debugMemoryCache = isEnabled(DEBUG_MEMORY_CACHE);
const createViolationInstance = (namespace) => {
const config = isRedisEnabled ? { store: keyvRedis } : { store: violationFile, namespace };
return new Keyv(config);
};
// Serve cache from memory so no need to clear it on startup/exit
const pending_req = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.PENDING_REQ });
const config = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
const roles = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ROLES });
const mcpTools = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.MCP_TOOLS });
const audioRuns = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
const messages = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.ONE_MINUTE })
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.ONE_MINUTE });
const flows = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.FLOWS, ttl: Time.ONE_MINUTE * 3 });
const tokenConfig = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
const genTitle = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
const s3ExpiryInterval = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });
const modelQueries = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
const abortKeys = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
const openIdExchangedTokensCache = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.OPENID_EXCHANGED_TOKENS, ttl: Time.TEN_MINUTES });
const namespaces = {
[ViolationTypes.GENERAL]: new Keyv({ store: logFile, namespace: 'violations' }),
[ViolationTypes.LOGINS]: violationCache(ViolationTypes.LOGINS),
[ViolationTypes.CONCURRENT]: violationCache(ViolationTypes.CONCURRENT),
[ViolationTypes.NON_BROWSER]: violationCache(ViolationTypes.NON_BROWSER),
[ViolationTypes.MESSAGE_LIMIT]: violationCache(ViolationTypes.MESSAGE_LIMIT),
[ViolationTypes.REGISTRATIONS]: violationCache(ViolationTypes.REGISTRATIONS),
[ViolationTypes.TOKEN_BALANCE]: violationCache(ViolationTypes.TOKEN_BALANCE),
[ViolationTypes.TTS_LIMIT]: violationCache(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: violationCache(ViolationTypes.STT_LIMIT),
[ViolationTypes.CONVO_ACCESS]: violationCache(ViolationTypes.CONVO_ACCESS),
[ViolationTypes.TOOL_CALL_LIMIT]: violationCache(ViolationTypes.TOOL_CALL_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: violationCache(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.VERIFY_EMAIL_LIMIT]: violationCache(ViolationTypes.VERIFY_EMAIL_LIMIT),
[ViolationTypes.RESET_PASSWORD_LIMIT]: violationCache(ViolationTypes.RESET_PASSWORD_LIMIT),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: violationCache(ViolationTypes.ILLEGAL_MODEL_REQUEST),
[ViolationTypes.BAN]: new Keyv({
[CacheKeys.ROLES]: roles,
[CacheKeys.MCP_TOOLS]: mcpTools,
[CacheKeys.CONFIG_STORE]: config,
[CacheKeys.PENDING_REQ]: pending_req,
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({
store: keyvMongo,
namespace: CacheKeys.BANS,
ttl: cacheConfig.BAN_DURATION,
namespace: CacheKeys.ENCODED_DOMAINS,
ttl: 0,
}),
[CacheKeys.OPENID_SESSION]: sessionCache(CacheKeys.OPENID_SESSION),
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
[CacheKeys.MCP_TOOLS]: standardCache(CacheKeys.MCP_TOOLS),
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),
[CacheKeys.TOKEN_CONFIG]: standardCache(CacheKeys.TOKEN_CONFIG, Time.THIRTY_MINUTES),
[CacheKeys.GEN_TITLE]: standardCache(CacheKeys.GEN_TITLE, Time.TWO_MINUTES),
[CacheKeys.S3_EXPIRY_INTERVAL]: standardCache(CacheKeys.S3_EXPIRY_INTERVAL, Time.THIRTY_MINUTES),
[CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES),
[CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES),
[CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE),
[CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 3),
[CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache(
CacheKeys.OPENID_EXCHANGED_TOKENS,
Time.TEN_MINUTES,
general: new Keyv({ store: logFile, namespace: 'violations' }),
concurrent: createViolationInstance('concurrent'),
non_browser: createViolationInstance('non_browser'),
message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
registrations: createViolationInstance('registrations'),
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
[ViolationTypes.CONVO_ACCESS]: createViolationInstance(ViolationTypes.CONVO_ACCESS),
[ViolationTypes.TOOL_CALL_LIMIT]: createViolationInstance(ViolationTypes.TOOL_CALL_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
[ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
ViolationTypes.RESET_PASSWORD_LIMIT,
),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
ViolationTypes.ILLEGAL_MODEL_REQUEST,
),
logins: createViolationInstance('logins'),
[CacheKeys.ABORT_KEYS]: abortKeys,
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
[CacheKeys.GEN_TITLE]: genTitle,
[CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
[CacheKeys.MODEL_QUERIES]: modelQueries,
[CacheKeys.AUDIO_RUNS]: audioRuns,
[CacheKeys.MESSAGES]: messages,
[CacheKeys.FLOWS]: flows,
[CacheKeys.OPENID_EXCHANGED_TOKENS]: openIdExchangedTokensCache,
};
/**
@@ -55,10 +116,7 @@ const namespaces = {
*/
function getTTLStores() {
return Object.values(namespaces).filter(
(store) =>
store instanceof Keyv &&
parseInt(store.opts?.ttl ?? '0') > 0 &&
!store.opts?.store?.constructor?.name?.includes('Redis'), // Only include non-Redis stores
(store) => store instanceof Keyv && typeof store.opts?.ttl === 'number' && store.opts.ttl > 0,
);
}
@@ -94,18 +152,18 @@ async function clearExpiredFromCache(cache) {
if (data?.expires && data.expires <= expiryTime) {
const deleted = await cache.opts.store.delete(key);
if (!deleted) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
continue;
}
cleared++;
}
} catch (error) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
const deleted = await cache.opts.store.delete(key);
if (!deleted) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
continue;
}
@@ -114,7 +172,7 @@ async function clearExpiredFromCache(cache) {
}
if (cleared > 0) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.log(
`[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`,
);
@@ -155,7 +213,7 @@ async function clearAllExpiredFromCache() {
}
}
if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
if (!isRedisEnabled && !isEnabled(CI)) {
/** @type {Set<NodeJS.Timeout>} */
const cleanupIntervals = new Set();
@@ -166,7 +224,7 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
cleanupIntervals.add(cleanup);
if (cacheConfig.DEBUG_MEMORY_CACHE) {
if (debugMemoryCache) {
const monitor = setInterval(() => {
const ttlStores = getTTLStores();
const memory = process.memoryUsage();
@@ -187,13 +245,13 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
}
const dispose = () => {
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Cleaning up and shutting down...');
debugMemoryCache && console.log('[Cache] Cleaning up and shutting down...');
cleanupIntervals.forEach((interval) => clearInterval(interval));
cleanupIntervals.clear();
// One final cleanup before exit
clearAllExpiredFromCache().then(() => {
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Final cleanup completed');
debugMemoryCache && console.log('[Cache] Final cleanup completed');
process.exit(0);
});
};

92
api/cache/ioredisClient.js vendored Normal file
View File

@@ -0,0 +1,92 @@
const fs = require('fs');
const Redis = require('ioredis');
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_MAX_LISTENERS } = process.env;
/** @type {import('ioredis').Redis | import('ioredis').Cluster} */
let ioredisClient;
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
var value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
ioredisClient = new Redis.Cluster(hosts, { redisOptions });
} else {
ioredisClient = new Redis(REDIS_URI, redisOptions);
}
ioredisClient.on('ready', () => {
logger.info('IoRedis connection ready');
});
ioredisClient.on('reconnecting', () => {
logger.info('IoRedis connection reconnecting');
});
ioredisClient.on('end', () => {
logger.info('IoRedis connection ended');
});
ioredisClient.on('close', () => {
logger.info('IoRedis connection closed');
});
ioredisClient.on('error', (err) => logger.error('IoRedis connection error:', err));
ioredisClient.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] IoRedis initialized for rate limiters. If you have issues, disable Redis or restart the server.',
);
} else {
logger.info('[Optional] IoRedis not initialized for rate limiters.');
}
module.exports = ioredisClient;

109
api/cache/keyvRedis.js vendored Normal file
View File

@@ -0,0 +1,109 @@
const fs = require('fs');
const ioredis = require('ioredis');
const KeyvRedis = require('@keyv/redis').default;
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } =
process.env;
let keyvRedis;
const redis_prefix = REDIS_KEY_PREFIX || '';
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
/** @type {import('@keyv/redis').KeyvRedisOptions} */
let keyvOpts = {
useRedisSets: false,
keyPrefix: redis_prefix,
};
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
var value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
const cluster = new ioredis.Cluster(hosts, { redisOptions });
keyvRedis = new KeyvRedis(cluster, keyvOpts);
} else {
keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts);
}
const pingInterval = setInterval(
() => {
logger.debug('KeyvRedis ping');
keyvRedis.client.ping().catch((err) => logger.error('Redis keep-alive ping failed:', err));
},
5 * 60 * 1000,
);
keyvRedis.on('ready', () => {
logger.info('KeyvRedis connection ready');
});
keyvRedis.on('reconnecting', () => {
logger.info('KeyvRedis connection reconnecting');
});
keyvRedis.on('end', () => {
logger.info('KeyvRedis connection ended');
});
keyvRedis.on('close', () => {
clearInterval(pingInterval);
logger.info('KeyvRedis connection closed');
});
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
keyvRedis.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.',
);
} else {
logger.info('[Optional] Redis not initialized.');
}
module.exports = keyvRedis;

View File

@@ -1,5 +1,4 @@
const { isEnabled } = require('~/server/utils');
const { ViolationTypes } = require('librechat-data-provider');
const getLogStores = require('./getLogStores');
const banViolation = require('./banViolation');
@@ -10,14 +9,14 @@ const banViolation = require('./banViolation');
* @param {Object} res - Express response object.
* @param {string} type - The type of violation.
* @param {Object} errorMessage - The error message to log.
* @param {number | string} [score=1] - The severity of the violation. Defaults to 1
* @param {number} [score=1] - The severity of the violation. Defaults to 1
*/
const logViolation = async (req, res, type, errorMessage, score = 1) => {
const userId = req.user?.id ?? req.user?._id;
if (!userId) {
return;
}
const logs = getLogStores(ViolationTypes.GENERAL);
const logs = getLogStores('general');
const violationLogs = getLogStores(type);
const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;

View File

@@ -1,57 +0,0 @@
const IoRedis = require('ioredis');
const { cacheConfig } = require('./cacheConfig');
const { createClient, createCluster } = require('@keyv/redis');
const GLOBAL_PREFIX_SEPARATOR = '::';
const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri));
const username = urls?.[0].username || cacheConfig.REDIS_USERNAME;
const password = urls?.[0].password || cacheConfig.REDIS_PASSWORD;
const ca = cacheConfig.REDIS_CA;
/** @type {import('ioredis').Redis | import('ioredis').Cluster | null} */
let ioredisClient = null;
if (cacheConfig.USE_REDIS) {
const redisOptions = {
username: username,
password: password,
tls: ca ? { ca } : undefined,
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
};
ioredisClient =
urls.length === 1
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
: new IoRedis.Cluster(cacheConfig.REDIS_URI, { redisOptions });
// Pinging the Redis server every 5 minutes to keep the connection alive
const pingInterval = setInterval(() => ioredisClient.ping(), 5 * 60 * 1000);
ioredisClient.on('close', () => clearInterval(pingInterval));
ioredisClient.on('end', () => clearInterval(pingInterval));
}
/** @type {import('@keyv/redis').RedisClient | import('@keyv/redis').RedisCluster | null} */
let keyvRedisClient = null;
if (cacheConfig.USE_REDIS) {
// ** WARNING ** Keyv Redis client does not support Prefix like ioredis above.
// The prefix feature will be handled by the Keyv-Redis store in cacheFactory.js
const redisOptions = { username, password, socket: { tls: ca != null, ca } };
keyvRedisClient =
urls.length === 1
? createClient({ url: cacheConfig.REDIS_URI, ...redisOptions })
: createCluster({
rootNodes: cacheConfig.REDIS_URI.split(',').map((url) => ({ url })),
defaults: redisOptions,
});
keyvRedisClient.setMaxListeners(cacheConfig.REDIS_MAX_LISTENERS);
// Pinging the Redis server every 5 minutes to keep the connection alive
const keyvPingInterval = setInterval(() => keyvRedisClient.ping(), 5 * 60 * 1000);
keyvRedisClient.on('disconnect', () => clearInterval(keyvPingInterval));
keyvRedisClient.on('end', () => clearInterval(keyvPingInterval));
}
module.exports = { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR };

View File

@@ -11,13 +11,12 @@ let flowManager = null;
/**
* @param {string} [userId] - Optional user ID, to avoid disconnecting the current user.
* @param {boolean} [skipIdleCheck] - Skip idle connection checking to avoid unnecessary pings.
* @returns {MCPManager}
*/
function getMCPManager(userId, skipIdleCheck = false) {
function getMCPManager(userId) {
if (!mcpManager) {
mcpManager = MCPManager.getInstance();
} else if (!skipIdleCheck) {
} else {
mcpManager.checkIdleConnections(userId);
}
return mcpManager;

View File

@@ -61,7 +61,7 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter)
const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => {
const { model, ...model_parameters } = _m;
/** @type {Record<string, FunctionTool>} */
const availableTools = await getCachedTools({ userId: req.user.id, includeGlobal: true });
const availableTools = await getCachedTools({ includeGlobal: true });
/** @type {TEphemeralAgent | null} */
const ephemeralAgent = req.body.ephemeralAgent;
const mcpServers = new Set(ephemeralAgent?.mcp);

View File

@@ -1,6 +1,6 @@
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
const { getMessages, deleteMessages } = require('./Message');
const { Conversation } = require('~/db/models');

View File

@@ -1,7 +1,5 @@
const { logger } = require('@librechat/data-schemas');
const { EToolResources, FileContext, Constants } = require('librechat-data-provider');
const { getProjectByName } = require('./Project');
const { getAgent } = require('./Agent');
const { EToolResources, FileContext } = require('librechat-data-provider');
const { File } = require('~/db/models');
/**
@@ -14,124 +12,17 @@ const findFileById = async (file_id, options = {}) => {
return await File.findOne({ file_id, ...options }).lean();
};
/**
* Checks if a user has access to multiple files through a shared agent (batch operation)
* @param {string} userId - The user ID to check access for
* @param {string[]} fileIds - Array of file IDs to check
* @param {string} agentId - The agent ID that might grant access
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
*/
const hasAccessToFilesViaAgent = async (userId, fileIds, agentId, checkCollaborative = true) => {
const accessMap = new Map();
// Initialize all files as no access
fileIds.forEach((fileId) => accessMap.set(fileId, false));
try {
const agent = await getAgent({ id: agentId });
if (!agent) {
return accessMap;
}
// Check if user is the author - if so, grant access to all files
if (agent.author.toString() === userId) {
fileIds.forEach((fileId) => accessMap.set(fileId, true));
return accessMap;
}
// Check if agent is shared with the user via projects
if (!agent.projectIds || agent.projectIds.length === 0) {
return accessMap;
}
// Check if agent is in global project
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
if (
!globalProject ||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
) {
return accessMap;
}
// Agent is globally shared - check if it's collaborative
if (checkCollaborative && !agent.isCollaborative) {
return accessMap;
}
// Check which files are actually attached
const attachedFileIds = new Set();
if (agent.tool_resources) {
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
}
}
}
// Grant access only to files that are attached to this agent
fileIds.forEach((fileId) => {
if (attachedFileIds.has(fileId)) {
accessMap.set(fileId, true);
}
});
return accessMap;
} catch (error) {
logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
return accessMap;
}
};
/**
* Retrieves files matching a given filter, sorted by the most recently updated.
* @param {Object} filter - The filter criteria to apply.
* @param {Object} [_sortOptions] - Optional sort parameters.
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
* Default excludes the 'text' field.
* @param {Object} [options] - Additional options
* @param {string} [options.userId] - User ID for access control
* @param {string} [options.agentId] - Agent ID that might grant access to files
* @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
*/
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => {
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
const sortOptions = { updatedAt: -1, ..._sortOptions };
const files = await File.find(filter).select(selectFields).sort(sortOptions).lean();
// If userId and agentId are provided, filter files based on access
if (options.userId && options.agentId) {
// Collect file IDs that need access check
const filesToCheck = [];
const ownedFiles = [];
for (const file of files) {
if (file.user && file.user.toString() === options.userId) {
ownedFiles.push(file);
} else {
filesToCheck.push(file);
}
}
if (filesToCheck.length === 0) {
return ownedFiles;
}
// Batch check access for all non-owned files
const fileIds = filesToCheck.map((f) => f.file_id);
const accessMap = await hasAccessToFilesViaAgent(
options.userId,
fileIds,
options.agentId,
false,
);
// Filter files based on access
const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
return [...ownedFiles, ...accessibleFiles];
}
return files;
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
};
/**
@@ -285,5 +176,4 @@ module.exports = {
deleteFiles,
deleteFileByFilter,
batchUpdateFiles,
hasAccessToFilesViaAgent,
};

View File

@@ -1,264 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { fileSchema } = require('@librechat/data-schemas');
const { agentSchema } = require('@librechat/data-schemas');
const { projectSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { getFiles, createFile } = require('./File');
const { getProjectByName } = require('./Project');
const { createAgent } = require('./Agent');
let File;
let Agent;
let Project;
describe('File Access Control', () => {
let mongoServer;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
File = mongoose.models.File || mongoose.model('File', fileSchema);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await File.deleteMany({});
await Agent.deleteMany({});
await Project.deleteMany({});
});
describe('hasAccessToFilesViaAgent', () => {
it('should efficiently check access for multiple files at once', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
// Create files
for (const fileId of fileIds) {
await createFile({
user: authorId,
file_id: fileId,
filename: `file-${fileId}.txt`,
filepath: `/uploads/${fileId}`,
});
}
// Create agent with only first two files attached
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileIds[0], fileIds[1]],
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for all files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have access only to the first two files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(false);
expect(accessMap.get(fileIds[3])).toBe(false);
});
it('should grant access to all files when user is the agent author', async () => {
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
// Create agent
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: [fileIds[0]], // Only one file attached
},
},
});
// Check access as the author
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);
// Author should have access to all files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(true);
});
it('should handle non-existent agent gracefully', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const fileIds = [uuidv4(), uuidv4()];
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');
// Should have no access to any files
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should deny access when agent is not collaborative', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4()];
// Create agent with files but isCollaborative: false
await createAgent({
id: agentId,
name: 'Non-Collaborative Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: false,
tool_resources: {
file_search: {
file_ids: fileIds,
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have no access to any files when isCollaborative is false
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
});
describe('getFiles with agent access control', () => {
test('should return files owned by user and files accessible through agent', async () => {
const authorId = new mongoose.Types.ObjectId();
const userId = new mongoose.Types.ObjectId();
const agentId = `agent_${uuidv4()}`;
const ownedFileId = `file_${uuidv4()}`;
const sharedFileId = `file_${uuidv4()}`;
const inaccessibleFileId = `file_${uuidv4()}`;
// Create/get global project using getProjectByName which will upsert
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);
// Create agent with shared file
await createAgent({
id: agentId,
name: 'Shared Agent',
provider: 'test',
model: 'test-model',
author: authorId,
projectIds: [globalProject._id],
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [sharedFileId],
},
},
});
// Create files
await createFile({
file_id: ownedFileId,
user: userId,
filename: 'owned.txt',
filepath: '/uploads/owned.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: sharedFileId,
user: authorId,
filename: 'shared.txt',
filepath: '/uploads/shared.txt',
type: 'text/plain',
bytes: 200,
embedded: true,
});
await createFile({
file_id: inaccessibleFileId,
user: authorId,
filename: 'inaccessible.txt',
filepath: '/uploads/inaccessible.txt',
type: 'text/plain',
bytes: 300,
});
// Get files with access control
const files = await getFiles(
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
null,
{ text: 0 },
{ userId: userId.toString(), agentId },
);
expect(files).toHaveLength(2);
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
});
test('should return all files when no userId/agentId provided', async () => {
const userId = new mongoose.Types.ObjectId();
const fileId1 = `file_${uuidv4()}`;
const fileId2 = `file_${uuidv4()}`;
await createFile({
file_id: fileId1,
user: userId,
filename: 'file1.txt',
filepath: '/uploads/file1.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: fileId2,
user: new mongoose.Types.ObjectId(),
filename: 'file2.txt',
filepath: '/uploads/file2.txt',
type: 'text/plain',
bytes: 200,
});
const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
expect(files).toHaveLength(2);
});
});
});

View File

@@ -1,7 +1,7 @@
const { z } = require('zod');
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
const { Message } = require('~/db/models');
const idSchema = z.string().uuid();

View File

@@ -135,11 +135,10 @@ const tokenValues = Object.assign(
'grok-2-1212': { prompt: 2.0, completion: 10.0 },
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
'grok-2': { prompt: 2.0, completion: 10.0 },
'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
'grok-3-mini': { prompt: 0.3, completion: 0.5 },
'grok-3-fast': { prompt: 5.0, completion: 25.0 },
'grok-3': { prompt: 3.0, completion: 15.0 },
'grok-4': { prompt: 3.0, completion: 15.0 },
'grok-beta': { prompt: 5.0, completion: 15.0 },
'mistral-large': { prompt: 2.0, completion: 6.0 },
'pixtral-large': { prompt: 2.0, completion: 6.0 },

View File

@@ -636,15 +636,6 @@ describe('Grok Model Tests - Pricing', () => {
);
});
test('should return correct prompt and completion rates for Grok 4 model', () => {
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
tokenValues['grok-3'].prompt,
@@ -671,15 +662,6 @@ describe('Grok Model Tests - Pricing', () => {
tokenValues['grok-3-mini-fast'].completion,
);
});
test('should return correct prompt and completion rates for Grok 4 model with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
});
});

View File

@@ -44,20 +44,19 @@
"@googleapis/youtube": "^20.0.0",
"@keyv/redis": "^4.3.3",
"@langchain/community": "^0.3.47",
"@langchain/core": "^0.3.62",
"@langchain/core": "^0.3.60",
"@langchain/google-genai": "^0.2.13",
"@langchain/google-vertexai": "^0.2.13",
"@langchain/openai": "^0.5.18",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.67",
"@librechat/agents": "^2.4.56",
"@librechat/api": "*",
"@librechat/data-schemas": "*",
"@node-saml/passport-saml": "^5.0.0",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.8.2",
"bcryptjs": "^2.4.3",
"compression": "^1.8.1",
"connect-redis": "^8.1.0",
"compression": "^1.7.4",
"connect-redis": "^7.1.0",
"cookie": "^0.7.2",
"cookie-parser": "^1.4.7",
"cors": "^2.8.5",
@@ -67,11 +66,10 @@
"express": "^4.21.2",
"express-mongo-sanitize": "^2.2.0",
"express-rate-limit": "^7.4.1",
"express-session": "^1.18.2",
"express-session": "^1.18.1",
"express-static-gzip": "^2.2.0",
"file-type": "^18.7.0",
"firebase": "^11.0.2",
"form-data": "^4.0.4",
"googleapis": "^126.0.1",
"handlebars": "^4.7.7",
"https-proxy-agent": "^7.0.6",
@@ -89,12 +87,12 @@
"mime": "^3.0.0",
"module-alias": "^2.2.3",
"mongoose": "^8.12.1",
"multer": "^2.0.2",
"multer": "^2.0.1",
"nanoid": "^3.3.7",
"node-fetch": "^2.7.0",
"nodemailer": "^6.9.15",
"ollama": "^0.5.0",
"openai": "^5.10.1",
"openai": "^4.96.2",
"openai-chat-tokens": "^0.2.8",
"openid-client": "^6.5.0",
"passport": "^0.6.0",

View File

@@ -1,10 +1,11 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, AuthType, Constants } = require('librechat-data-provider');
const { CacheKeys, AuthType } = require('librechat-data-provider');
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
const { getToolkitKey } = require('~/server/services/ToolService');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { availableTools } = require('~/app/clients/tools');
const { getLogStores } = require('~/cache');
const { Constants } = require('librechat-data-provider');
/**
* Filters out duplicate plugins from the list of plugins.
@@ -97,6 +98,7 @@ function createServerToolsCallback() {
return;
}
await mcpToolsCache.set(serverName, serverTools);
logger.debug(`MCP tools for ${serverName} added to cache.`);
} catch (error) {
logger.error('Error retrieving MCP tools from cache:', error);
}
@@ -137,21 +139,15 @@ function createGetServerTools() {
*/
const getAvailableTools = async (req, res) => {
try {
const userId = req.user?.id;
const customConfig = await getCustomConfig();
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
const cachedUserTools = await getCachedTools({ userId });
const userPlugins = await convertMCPToolsToPlugins(cachedUserTools, customConfig, userId);
if (cachedToolsArray && userPlugins) {
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
res.status(200).json(dedupedTools);
const cachedTools = await cache.get(CacheKeys.TOOLS);
if (cachedTools) {
res.status(200).json(cachedTools);
return;
}
// If not in cache, build from manifest
let pluginManifest = availableTools;
const customConfig = await getCustomConfig();
if (customConfig?.mcpServers != null) {
const mcpManager = getMCPManager();
const flowsCache = getLogStores(CacheKeys.FLOWS);
@@ -177,7 +173,7 @@ const getAvailableTools = async (req, res) => {
}
});
const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {};
const toolDefinitions = await getCachedTools({ includeGlobal: true });
const toolsOutput = [];
for (const plugin of authenticatedPlugins) {
@@ -201,150 +197,37 @@ const getAvailableTools = async (req, res) => {
const serverName = parts[parts.length - 1];
const serverConfig = customConfig?.mcpServers?.[serverName];
if (!serverConfig) {
if (!serverConfig?.customUserVars) {
toolsOutput.push(toolToAdd);
continue;
}
// Handle MCP servers with customUserVars (user-level auth required)
if (serverConfig.customUserVars) {
// Build authConfig for MCP tools
const customVarKeys = Object.keys(serverConfig.customUserVars);
if (customVarKeys.length === 0) {
toolToAdd.authConfig = [];
toolToAdd.authenticated = true;
} else {
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}));
// Check actual connection status for MCP tools with auth requirements
if (userId) {
try {
const mcpManager = getMCPManager(userId);
const connectionStatus = await mcpManager.getUserConnectionStatus(userId, serverName);
toolToAdd.authenticated = connectionStatus.connected;
} catch (error) {
logger.error(
`[getAvailableTools] Error checking connection status for ${serverName}:`,
error,
);
toolToAdd.authenticated = false;
}
} else {
// For non-authenticated requests, default to false
toolToAdd.authenticated = false;
}
} else {
// Handle app-level MCP servers (no auth required)
toolToAdd.authConfig = [];
// Check if the app-level connection is active
try {
const mcpManager = getMCPManager();
const appConnection = mcpManager.getConnection(serverName);
if (appConnection) {
const connectionState = appConnection.getConnectionState();
// For app-level connections, consider them authenticated if they're in 'connected' state
// This is more reliable than isConnected() which does network calls
toolToAdd.authenticated = connectionState === 'connected';
} else {
logger.warn(`[getAvailableTools] No app-level connection found for ${serverName}`);
toolToAdd.authenticated = false;
}
} catch (error) {
logger.error(
`[getAvailableTools] Error checking app-level connection status for ${serverName}:`,
error,
);
toolToAdd.authenticated = false;
}
toolToAdd.authenticated = false;
}
toolsOutput.push(toolToAdd);
}
const finalTools = filterUniquePlugins(toolsOutput);
await cache.set(CacheKeys.TOOLS, finalTools);
const dedupedTools = filterUniquePlugins([...userPlugins, ...finalTools]);
res.status(200).json(dedupedTools);
res.status(200).json(finalTools);
} catch (error) {
logger.error('[getAvailableTools]', error);
res.status(500).json({ message: error.message });
}
};
/**
* Converts MCP function format tools to plugin format
* @param {Object} functionTools - Object with function format tools
* @param {Object} customConfig - Custom configuration for MCP servers
* @returns {Array} Array of plugin objects
*/
async function convertMCPToolsToPlugins(functionTools, customConfig, userId = null) {
const plugins = [];
for (const [toolKey, toolData] of Object.entries(functionTools)) {
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
continue;
}
const functionData = toolData.function;
const parts = toolKey.split(Constants.mcp_delimiter);
const serverName = parts[parts.length - 1];
const plugin = {
name: parts[0], // Use the tool name without server suffix
pluginKey: toolKey,
description: functionData.description || '',
authenticated: false, // Default to false, will be updated based on connection status
icon: undefined,
};
// Build authConfig for MCP tools
const serverConfig = customConfig?.mcpServers?.[serverName];
if (!serverConfig?.customUserVars) {
plugin.authConfig = [];
plugin.authenticated = true; // No auth required
plugins.push(plugin);
continue;
}
const customVarKeys = Object.keys(serverConfig.customUserVars);
if (customVarKeys.length === 0) {
plugin.authConfig = [];
plugin.authenticated = true; // No auth required
} else {
plugin.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}));
// Check actual connection status for MCP tools with auth requirements
if (userId) {
try {
const mcpManager = getMCPManager(userId);
const connectionStatus = await mcpManager.getUserConnectionStatus(userId, serverName);
plugin.authenticated = connectionStatus.connected;
} catch (error) {
logger.error(
`[convertMCPToolsToPlugins] Error checking connection status for ${serverName}:`,
error,
);
plugin.authenticated = false;
}
} else {
plugin.authenticated = false;
}
}
plugins.push(plugin);
}
return plugins;
}
module.exports = {
getAvailableTools,
getAvailablePluginsController,

View File

@@ -1,5 +1,11 @@
const {
Tools,
Constants,
FileSources,
webSearchKeys,
extractWebSearchEnvVars,
} = require('librechat-data-provider');
const { logger } = require('@librechat/data-schemas');
const { webSearchKeys, extractWebSearchEnvVars } = require('@librechat/api');
const {
getFiles,
updateUser,
@@ -14,7 +20,6 @@ const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/service
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { Tools, Constants, FileSources } = require('librechat-data-provider');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { Transaction, Balance, User } = require('~/db/models');
const { deleteToolCalls } = require('~/models/ToolCall');
@@ -175,18 +180,14 @@ const updateUserPluginsController = async (req, res) => {
try {
const mcpManager = getMCPManager(user.id);
if (mcpManager) {
// Extract server name from pluginKey (e.g., "mcp_myserver" -> "myserver")
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
logger.info(
`[updateUserPluginsController] Disconnecting MCP connection for user ${user.id} and server ${serverName} after plugin auth update for ${pluginKey}.`,
`[updateUserPluginsController] Disconnecting MCP connections for user ${user.id} after plugin auth update for ${pluginKey}.`,
);
// Don't kill the server connection on revoke anymore, user can just reinitialize the server if thats what they want
await mcpManager.disconnectUserConnection(user.id, serverName);
await mcpManager.disconnectUserConnections(user.id);
}
} catch (disconnectError) {
logger.error(
`[updateUserPluginsController] Error disconnecting MCP connection for user ${user.id} after plugin auth update:`,
`[updateUserPluginsController] Error disconnecting MCP connections for user ${user.id} after plugin auth update:`,
disconnectError,
);
// Do not fail the request for this, but log it.

View File

@@ -1,23 +1,20 @@
require('events').EventEmitter.defaultMaxListeners = 100;
const { logger } = require('@librechat/data-schemas');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const {
sendEvent,
createRun,
Tokenizer,
checkAccess,
memoryInstructions,
formatContentStrings,
createMemoryProcessor,
} = require('@librechat/api');
const {
Callback,
Providers,
GraphEvents,
TitleMethod,
formatMessage,
formatAgentMessages,
formatContentStrings,
getTokenCountForMessage,
createMetadataAggregator,
} = require('@librechat/agents');
@@ -27,22 +24,20 @@ const {
VisionModes,
ContentTypes,
EModelEndpoint,
KnownEndpoints,
PermissionTypes,
isAgentsEndpoint,
AgentCapabilities,
bedrockInputSchema,
removeNullishValues,
} = require('librechat-data-provider');
const {
findPluginAuthsByKeys,
getFormattedMemories,
deleteMemory,
setMemory,
} = require('~/models');
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const { createGetMCPAuthMap, checkCapability } = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { getProviderConfig } = require('~/server/services/Endpoints');
const BaseClient = require('~/app/clients/BaseClient');
@@ -59,7 +54,6 @@ const omitTitleOptions = new Set([
'thinkingBudget',
'includeThoughts',
'maxOutputTokens',
'additionalModelRequestFields',
]);
/**
@@ -76,6 +70,8 @@ const payloadParser = ({ req, agent, endpoint }) => {
return req.body.endpointOption.model_parameters;
};
const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
function createTokenCounter(encoding) {
@@ -456,12 +452,6 @@ class AgentClient extends BaseClient {
res: this.options.res,
agent: prelimAgent,
allowedProviders,
endpointOption: {
endpoint:
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
? EModelEndpoint.agents
: memoryConfig.agent?.provider,
},
});
if (!agent) {
@@ -710,12 +700,17 @@ class AgentClient extends BaseClient {
version: 'v2',
};
const getUserMCPAuthMap = await createGetMCPAuthMap();
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
payload,
this.indexTokenCountMap,
toolSet,
);
if (legacyContentEndpoints.has(this.options.agent.endpoint?.toLowerCase())) {
initialMessages = formatContentStrings(initialMessages);
}
/**
*
@@ -779,9 +774,6 @@ class AgentClient extends BaseClient {
}
let messages = _messages;
if (agent.useLegacyContent === true) {
messages = formatContentStrings(messages);
}
if (
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
'prompt-caching',
@@ -830,11 +822,10 @@ class AgentClient extends BaseClient {
}
try {
if (await hasCustomUserVars()) {
config.configurable.userMCPAuthMap = await getMCPAuthMap({
if (getUserMCPAuthMap) {
config.configurable.userMCPAuthMap = await getUserMCPAuthMap({
tools: agent.tools,
userId: this.options.req.user.id,
findPluginAuthsByKeys,
});
}
} catch (err) {
@@ -1010,7 +1001,7 @@ class AgentClient extends BaseClient {
}
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
const { req, res, agent } = this.options;
let endpoint = agent.endpoint;
const endpoint = agent.endpoint;
/** @type {import('@librechat/agents').ClientOptions} */
let clientOptions = {
@@ -1018,32 +1009,17 @@ class AgentClient extends BaseClient {
model: agent.model_parameters.model,
};
let titleProviderConfig = await getProviderConfig(endpoint);
const { getOptions, overrideProvider, customEndpointConfig } =
await getProviderConfig(endpoint);
/** @type {TEndpoint | undefined} */
const endpointConfig =
req.app.locals.all ?? req.app.locals[endpoint] ?? titleProviderConfig.customEndpointConfig;
const endpointConfig = req.app.locals[endpoint] ?? customEndpointConfig;
if (!endpointConfig) {
logger.warn(
'[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config',
);
}
if (endpointConfig?.titleEndpoint && endpointConfig.titleEndpoint !== endpoint) {
try {
titleProviderConfig = await getProviderConfig(endpointConfig.titleEndpoint);
endpoint = endpointConfig.titleEndpoint;
} catch (error) {
logger.warn(
`[api/server/controllers/agents/client.js #titleConvo] Error getting title endpoint config for ${endpointConfig.titleEndpoint}, falling back to default`,
error,
);
// Fall back to original provider config
endpoint = agent.endpoint;
titleProviderConfig = await getProviderConfig(endpoint);
}
}
if (
endpointConfig &&
endpointConfig.titleModel &&
@@ -1052,7 +1028,7 @@ class AgentClient extends BaseClient {
clientOptions.model = endpointConfig.titleModel;
}
const options = await titleProviderConfig.getOptions({
const options = await getOptions({
req,
res,
optionsOnly: true,
@@ -1061,18 +1037,12 @@ class AgentClient extends BaseClient {
endpointOption: { model_parameters: clientOptions },
});
let provider = options.provider ?? titleProviderConfig.overrideProvider ?? agent.provider;
let provider = options.provider ?? overrideProvider ?? agent.provider;
if (
endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName == null
) {
provider = Providers.OPENAI;
} else if (
endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName != null &&
provider !== Providers.AZURE
) {
provider = Providers.AZURE;
}
/** @type {import('@librechat/agents').ClientOptions} */
@@ -1094,23 +1064,16 @@ class AgentClient extends BaseClient {
),
);
if (
provider === Providers.GOOGLE &&
(endpointConfig?.titleMethod === TitleMethod.FUNCTIONS ||
endpointConfig?.titleMethod === TitleMethod.STRUCTURED)
) {
if (provider === Providers.GOOGLE) {
clientOptions.json = true;
}
try {
const titleResult = await this.run.generateTitle({
provider,
clientOptions,
inputText: text,
contentParts: this.contentParts,
titleMethod: endpointConfig?.titleMethod,
titlePrompt: endpointConfig?.titlePrompt,
titlePromptTemplate: endpointConfig?.titlePromptTemplate,
clientOptions,
chainOptions: {
signal: abortController.signal,
callbacks: [
@@ -1158,52 +1121,8 @@ class AgentClient extends BaseClient {
}
}
/**
* @param {object} params
* @param {number} params.promptTokens
* @param {number} params.completionTokens
* @param {OpenAIUsageMetadata} [params.usage]
* @param {string} [params.model]
* @param {string} [params.context='message']
* @returns {Promise<void>}
*/
async recordTokenUsage({ model, promptTokens, completionTokens, usage, context = 'message' }) {
try {
await spendTokens(
{
model,
context,
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ promptTokens, completionTokens },
);
if (
usage &&
typeof usage === 'object' &&
'reasoning_tokens' in usage &&
typeof usage.reasoning_tokens === 'number'
) {
await spendTokens(
{
model,
context: 'reasoning',
conversationId: this.conversationId,
user: this.user ?? this.options.req.user?.id,
endpointTokenConfig: this.options.endpointTokenConfig,
},
{ completionTokens: usage.reasoning_tokens },
);
}
} catch (error) {
logger.error(
'[api/server/controllers/agents/client.js #recordTokenUsage] Error recording token usage',
error,
);
}
}
/** Silent method, as `recordCollectedUsage` is used instead */
async recordTokenUsage() {}
getEncoding() {
return 'o200k_base';

View File

@@ -1,730 +0,0 @@
const { Providers } = require('@librechat/agents');
const { Constants, EModelEndpoint } = require('librechat-data-provider');
const AgentClient = require('./client');
jest.mock('@librechat/agents', () => ({
...jest.requireActual('@librechat/agents'),
createMetadataAggregator: () => ({
handleLLMEnd: jest.fn(),
collected: [],
}),
}));
describe('AgentClient - titleConvo', () => {
let client;
let mockRun;
let mockReq;
let mockRes;
let mockAgent;
let mockOptions;
beforeEach(() => {
// Reset all mocks
jest.clearAllMocks();
// Mock run object
mockRun = {
generateTitle: jest.fn().mockResolvedValue({
title: 'Generated Title',
}),
};
// Mock agent - with both endpoint and provider
mockAgent = {
id: 'agent-123',
endpoint: EModelEndpoint.openAI, // Use a valid provider as endpoint for getProviderConfig
provider: EModelEndpoint.openAI, // Add provider property
model_parameters: {
model: 'gpt-4',
},
};
// Mock request and response
mockReq = {
app: {
locals: {
[EModelEndpoint.openAI]: {
// Match the agent endpoint
titleModel: 'gpt-3.5-turbo',
titlePrompt: 'Custom title prompt',
titleMethod: 'structured',
titlePromptTemplate: 'Template: {{content}}',
},
},
},
user: {
id: 'user-123',
},
body: {
model: 'gpt-4',
endpoint: EModelEndpoint.openAI,
key: null,
},
};
mockRes = {};
// Mock options
mockOptions = {
req: mockReq,
res: mockRes,
agent: mockAgent,
endpointTokenConfig: {},
};
// Create client instance
client = new AgentClient(mockOptions);
client.run = mockRun;
client.responseMessageId = 'response-123';
client.conversationId = 'convo-123';
client.contentParts = [{ type: 'text', text: 'Test content' }];
client.recordCollectedUsage = jest.fn().mockResolvedValue(); // Mock as async function that resolves
});
describe('titleConvo method', () => {
it('should throw error if run is not initialized', async () => {
client.run = null;
await expect(
client.titleConvo({ text: 'Test', abortController: new AbortController() }),
).rejects.toThrow('Run not initialized');
});
it('should use titlePrompt from endpoint config', async () => {
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
titlePrompt: 'Custom title prompt',
}),
);
});
it('should use titlePromptTemplate from endpoint config', async () => {
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
titlePromptTemplate: 'Template: {{content}}',
}),
);
});
it('should use titleMethod from endpoint config', async () => {
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
provider: Providers.OPENAI,
titleMethod: 'structured',
}),
);
});
it('should use titleModel from endpoint config when provided', async () => {
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Check that generateTitle was called with correct clientOptions
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
expect(generateTitleCall.clientOptions.model).toBe('gpt-3.5-turbo');
});
it('should handle missing endpoint config gracefully', async () => {
// Remove endpoint config
mockReq.app.locals[EModelEndpoint.openAI] = undefined;
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
titlePrompt: undefined,
titlePromptTemplate: undefined,
titleMethod: undefined,
}),
);
});
it('should use agent model when titleModel is not provided', async () => {
// Remove titleModel from config
delete mockReq.app.locals[EModelEndpoint.openAI].titleModel;
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
expect(generateTitleCall.clientOptions.model).toBe('gpt-4'); // Should use agent's model
});
it('should not use titleModel when it equals CURRENT_MODEL constant', async () => {
mockReq.app.locals[EModelEndpoint.openAI].titleModel = Constants.CURRENT_MODEL;
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
expect(generateTitleCall.clientOptions.model).toBe('gpt-4'); // Should use agent's model
});
it('should pass all required parameters to generateTitle', async () => {
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
expect(mockRun.generateTitle).toHaveBeenCalledWith({
provider: expect.any(String),
inputText: text,
contentParts: client.contentParts,
clientOptions: expect.objectContaining({
model: 'gpt-3.5-turbo',
}),
titlePrompt: 'Custom title prompt',
titlePromptTemplate: 'Template: {{content}}',
titleMethod: 'structured',
chainOptions: expect.objectContaining({
signal: abortController.signal,
}),
});
});
it('should record collected usage after title generation', async () => {
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
expect(client.recordCollectedUsage).toHaveBeenCalledWith({
model: 'gpt-3.5-turbo',
context: 'title',
collectedUsage: expect.any(Array),
});
});
it('should return the generated title', async () => {
const text = 'Test conversation text';
const abortController = new AbortController();
const result = await client.titleConvo({ text, abortController });
expect(result).toBe('Generated Title');
});
it('should handle errors gracefully and return undefined', async () => {
mockRun.generateTitle.mockRejectedValue(new Error('Title generation failed'));
const text = 'Test conversation text';
const abortController = new AbortController();
const result = await client.titleConvo({ text, abortController });
expect(result).toBeUndefined();
});
it('should pass titleEndpoint configuration to generateTitle', async () => {
// Mock the API key just for this test
const originalApiKey = process.env.ANTHROPIC_API_KEY;
process.env.ANTHROPIC_API_KEY = 'test-api-key';
// Add titleEndpoint to the config
mockReq.app.locals[EModelEndpoint.openAI].titleEndpoint = EModelEndpoint.anthropic;
mockReq.app.locals[EModelEndpoint.openAI].titleMethod = 'structured';
mockReq.app.locals[EModelEndpoint.openAI].titlePrompt = 'Custom title prompt';
mockReq.app.locals[EModelEndpoint.openAI].titlePromptTemplate = 'Custom template';
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify generateTitle was called with the custom configuration
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
titleMethod: 'structured',
provider: Providers.ANTHROPIC,
titlePrompt: 'Custom title prompt',
titlePromptTemplate: 'Custom template',
}),
);
// Restore the original API key
if (originalApiKey) {
process.env.ANTHROPIC_API_KEY = originalApiKey;
} else {
delete process.env.ANTHROPIC_API_KEY;
}
});
it('should use all config when endpoint config is missing', async () => {
// Remove endpoint-specific config
delete mockReq.app.locals[EModelEndpoint.openAI].titleModel;
delete mockReq.app.locals[EModelEndpoint.openAI].titlePrompt;
delete mockReq.app.locals[EModelEndpoint.openAI].titleMethod;
delete mockReq.app.locals[EModelEndpoint.openAI].titlePromptTemplate;
// Set 'all' config
mockReq.app.locals.all = {
titleModel: 'gpt-4o-mini',
titlePrompt: 'All config title prompt',
titleMethod: 'completion',
titlePromptTemplate: 'All config template: {{content}}',
};
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify generateTitle was called with 'all' config values
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
titleMethod: 'completion',
titlePrompt: 'All config title prompt',
titlePromptTemplate: 'All config template: {{content}}',
}),
);
// Check that the model was set from 'all' config
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
expect(generateTitleCall.clientOptions.model).toBe('gpt-4o-mini');
});
it('should prioritize all config over endpoint config for title settings', async () => {
// Set both endpoint and 'all' config
mockReq.app.locals[EModelEndpoint.openAI].titleModel = 'gpt-3.5-turbo';
mockReq.app.locals[EModelEndpoint.openAI].titlePrompt = 'Endpoint title prompt';
mockReq.app.locals[EModelEndpoint.openAI].titleMethod = 'structured';
// Remove titlePromptTemplate from endpoint config to test fallback
delete mockReq.app.locals[EModelEndpoint.openAI].titlePromptTemplate;
mockReq.app.locals.all = {
titleModel: 'gpt-4o-mini',
titlePrompt: 'All config title prompt',
titleMethod: 'completion',
titlePromptTemplate: 'All config template',
};
const text = 'Test conversation text';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify 'all' config takes precedence over endpoint config
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
titleMethod: 'completion',
titlePrompt: 'All config title prompt',
titlePromptTemplate: 'All config template',
}),
);
// Check that the model was set from 'all' config
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
expect(generateTitleCall.clientOptions.model).toBe('gpt-4o-mini');
});
it('should use all config with titleEndpoint and verify provider switch', async () => {
// Mock the API key for the titleEndpoint provider
const originalApiKey = process.env.ANTHROPIC_API_KEY;
process.env.ANTHROPIC_API_KEY = 'test-anthropic-key';
// Remove endpoint-specific config to test 'all' config
delete mockReq.app.locals[EModelEndpoint.openAI];
// Set comprehensive 'all' config with all new title options
mockReq.app.locals.all = {
titleConvo: true,
titleModel: 'claude-3-haiku-20240307',
titleMethod: 'completion', // Testing the new default method
titlePrompt: 'Generate a concise, descriptive title for this conversation',
titlePromptTemplate: 'Conversation summary: {{content}}',
titleEndpoint: EModelEndpoint.anthropic, // Should switch provider to Anthropic
};
const text = 'Test conversation about AI and machine learning';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify all config values were used
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
provider: Providers.ANTHROPIC, // Critical: Verify provider switched to Anthropic
titleMethod: 'completion',
titlePrompt: 'Generate a concise, descriptive title for this conversation',
titlePromptTemplate: 'Conversation summary: {{content}}',
inputText: text,
contentParts: client.contentParts,
}),
);
// Verify the model was set from 'all' config
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
expect(generateTitleCall.clientOptions.model).toBe('claude-3-haiku-20240307');
// Verify other client options are set correctly
expect(generateTitleCall.clientOptions).toMatchObject({
model: 'claude-3-haiku-20240307',
// Note: Anthropic's getOptions may set its own maxTokens value
});
// Restore the original API key
if (originalApiKey) {
process.env.ANTHROPIC_API_KEY = originalApiKey;
} else {
delete process.env.ANTHROPIC_API_KEY;
}
});
it('should test all titleMethod options from all config', async () => {
// Test each titleMethod: 'completion', 'functions', 'structured'
const titleMethods = ['completion', 'functions', 'structured'];
for (const method of titleMethods) {
// Clear previous calls
mockRun.generateTitle.mockClear();
// Remove endpoint config
delete mockReq.app.locals[EModelEndpoint.openAI];
// Set 'all' config with specific titleMethod
mockReq.app.locals.all = {
titleModel: 'gpt-4o-mini',
titleMethod: method,
titlePrompt: `Testing ${method} method`,
titlePromptTemplate: `Template for ${method}: {{content}}`,
};
const text = `Test conversation for ${method} method`;
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify the correct titleMethod was used
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
titleMethod: method,
titlePrompt: `Testing ${method} method`,
titlePromptTemplate: `Template for ${method}: {{content}}`,
}),
);
}
});
describe('Azure-specific title generation', () => {
let originalEnv;
beforeEach(() => {
// Reset mocks
jest.clearAllMocks();
// Save original environment variables
originalEnv = { ...process.env };
// Mock Azure API keys
process.env.AZURE_OPENAI_API_KEY = 'test-azure-key';
process.env.AZURE_API_KEY = 'test-azure-key';
process.env.EASTUS_API_KEY = 'test-eastus-key';
process.env.EASTUS2_API_KEY = 'test-eastus2-key';
});
afterEach(() => {
// Restore environment variables
process.env = originalEnv;
});
it('should use OPENAI provider for Azure serverless endpoints', async () => {
// Set up Azure endpoint with serverless config
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
mockAgent.provider = EModelEndpoint.azureOpenAI;
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
titleConvo: true,
titleModel: 'grok-3',
titleMethod: 'completion',
titlePrompt: 'Azure serverless title prompt',
streamRate: 35,
modelGroupMap: {
'grok-3': {
group: 'Azure AI Foundry',
deploymentName: 'grok-3',
},
},
groupMap: {
'Azure AI Foundry': {
apiKey: '${AZURE_API_KEY}',
baseURL: 'https://test.services.ai.azure.com/models',
version: '2024-05-01-preview',
serverless: true,
models: {
'grok-3': {
deploymentName: 'grok-3',
},
},
},
},
};
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
mockReq.body.model = 'grok-3';
const text = 'Test Azure serverless conversation';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify provider was switched to OPENAI for serverless
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
provider: Providers.OPENAI, // Should be OPENAI for serverless
titleMethod: 'completion',
titlePrompt: 'Azure serverless title prompt',
}),
);
});
it('should use AZURE provider for Azure endpoints with instanceName', async () => {
// Set up Azure endpoint
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
mockAgent.provider = EModelEndpoint.azureOpenAI;
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
titleConvo: true,
titleModel: 'gpt-4o',
titleMethod: 'structured',
titlePrompt: 'Azure instance title prompt',
streamRate: 35,
modelGroupMap: {
'gpt-4o': {
group: 'eastus',
deploymentName: 'gpt-4o',
},
},
groupMap: {
eastus: {
apiKey: '${EASTUS_API_KEY}',
instanceName: 'region-instance',
version: '2024-02-15-preview',
models: {
'gpt-4o': {
deploymentName: 'gpt-4o',
},
},
},
},
};
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
mockReq.body.model = 'gpt-4o';
const text = 'Test Azure instance conversation';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify provider remains AZURE with instanceName
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
provider: Providers.AZURE,
titleMethod: 'structured',
titlePrompt: 'Azure instance title prompt',
}),
);
});
it('should handle Azure titleModel with CURRENT_MODEL constant', async () => {
// Set up Azure endpoint
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
mockAgent.provider = EModelEndpoint.azureOpenAI;
mockAgent.model_parameters.model = 'gpt-4o-latest';
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
titleConvo: true,
titleModel: Constants.CURRENT_MODEL,
titleMethod: 'functions',
streamRate: 35,
modelGroupMap: {
'gpt-4o-latest': {
group: 'region-eastus',
deploymentName: 'gpt-4o-mini',
version: '2024-02-15-preview',
},
},
groupMap: {
'region-eastus': {
apiKey: '${EASTUS2_API_KEY}',
instanceName: 'test-instance',
version: '2024-12-01-preview',
models: {
'gpt-4o-latest': {
deploymentName: 'gpt-4o-mini',
version: '2024-02-15-preview',
},
},
},
},
};
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
mockReq.body.model = 'gpt-4o-latest';
const text = 'Test Azure current model';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify it uses the correct model when titleModel is CURRENT_MODEL
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
// When CURRENT_MODEL is used with Azure, the model gets mapped to the deployment name
// In this case, 'gpt-4o-latest' is mapped to 'gpt-4o-mini' deployment
expect(generateTitleCall.clientOptions.model).toBe('gpt-4o-mini');
// Also verify that CURRENT_MODEL constant was not passed as the model
expect(generateTitleCall.clientOptions.model).not.toBe(Constants.CURRENT_MODEL);
});
it('should handle Azure with multiple model groups', async () => {
// Set up Azure endpoint
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
mockAgent.provider = EModelEndpoint.azureOpenAI;
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
titleConvo: true,
titleModel: 'o1-mini',
titleMethod: 'completion',
streamRate: 35,
modelGroupMap: {
'gpt-4o': {
group: 'eastus',
deploymentName: 'gpt-4o',
},
'o1-mini': {
group: 'region-eastus',
deploymentName: 'o1-mini',
},
'codex-mini': {
group: 'codex-mini',
deploymentName: 'codex-mini',
},
},
groupMap: {
eastus: {
apiKey: '${EASTUS_API_KEY}',
instanceName: 'region-eastus',
version: '2024-02-15-preview',
models: {
'gpt-4o': {
deploymentName: 'gpt-4o',
},
},
},
'region-eastus': {
apiKey: '${EASTUS2_API_KEY}',
instanceName: 'region-eastus2',
version: '2024-12-01-preview',
models: {
'o1-mini': {
deploymentName: 'o1-mini',
},
},
},
'codex-mini': {
apiKey: '${AZURE_API_KEY}',
baseURL: 'https://example.cognitiveservices.azure.com/openai/',
version: '2025-04-01-preview',
serverless: true,
models: {
'codex-mini': {
deploymentName: 'codex-mini',
},
},
},
},
};
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
mockReq.body.model = 'o1-mini';
const text = 'Test Azure multi-group conversation';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify correct model and provider are used
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
provider: Providers.AZURE,
titleMethod: 'completion',
}),
);
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
expect(generateTitleCall.clientOptions.model).toBe('o1-mini');
expect(generateTitleCall.clientOptions.maxTokens).toBeUndefined(); // o1 models shouldn't have maxTokens
});
it('should use all config as fallback for Azure endpoints', async () => {
// Set up Azure endpoint with minimal config
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
mockAgent.provider = EModelEndpoint.azureOpenAI;
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
mockReq.body.model = 'gpt-4';
// Remove Azure-specific config
delete mockReq.app.locals[EModelEndpoint.azureOpenAI];
// Set 'all' config as fallback with a serverless Azure config
mockReq.app.locals.all = {
titleConvo: true,
titleModel: 'gpt-4',
titleMethod: 'structured',
titlePrompt: 'Fallback title prompt from all config',
titlePromptTemplate: 'Template: {{content}}',
modelGroupMap: {
'gpt-4': {
group: 'default-group',
deploymentName: 'gpt-4',
},
},
groupMap: {
'default-group': {
apiKey: '${AZURE_API_KEY}',
baseURL: 'https://default.openai.azure.com/',
version: '2024-02-15-preview',
serverless: true,
models: {
'gpt-4': {
deploymentName: 'gpt-4',
},
},
},
},
};
const text = 'Test Azure with all config fallback';
const abortController = new AbortController();
await client.titleConvo({ text, abortController });
// Verify all config is used
expect(mockRun.generateTitle).toHaveBeenCalledWith(
expect.objectContaining({
provider: Providers.OPENAI, // Should be OPENAI when no instanceName
titleMethod: 'structured',
titlePrompt: 'Fallback title prompt from all config',
titlePromptTemplate: 'Template: {{content}}',
}),
);
});
});
});
});

View File

@@ -12,7 +12,6 @@ const { saveMessage } = require('~/models');
const AgentController = async (req, res, next, initializeClient, addTitle) => {
let {
text,
isRegenerate,
endpointOption,
conversationId,
isContinued = false,
@@ -168,7 +167,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
onStart,
getReqData,
isContinued,
isRegenerate,
editedContent,
conversationId,
parentMessageId,

View File

@@ -391,22 +391,6 @@ const uploadAgentAvatarHandler = async (req, res) => {
return res.status(400).json({ message: 'Agent ID is required' });
}
const isAdmin = req.user.role === SystemRoles.ADMIN;
const existingAgent = await getAgent({ id: agent_id });
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
const isAuthor = existingAgent.author.toString() === req.user.id;
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) {
return res.status(403).json({
error: 'You do not have permission to modify this non-collaborative agent',
});
}
const buffer = await fs.readFile(req.file.path);
const fileStrategy = req.app.locals.fileStrategy;
@@ -429,7 +413,14 @@ const uploadAgentAvatarHandler = async (req, res) => {
source: fileStrategy,
};
let _avatar = existingAgent.avatar;
let _avatar;
try {
const agent = await getAgent({ id: agent_id });
_avatar = agent.avatar;
} catch (error) {
logger.error('[/:agent_id/avatar] Error fetching agent', error);
_avatar = {};
}
if (_avatar && _avatar.source) {
const { deleteFile } = getStrategyFunctions(_avatar.source);
@@ -451,7 +442,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
};
promises.push(
await updateAgent({ id: agent_id }, data, {
await updateAgent({ id: agent_id, author: req.user.id }, data, {
updatingUserId: req.user.id,
}),
);

View File

@@ -1,13 +1,14 @@
const { nanoid } = require('nanoid');
const { EnvVar } = require('@librechat/agents');
const { checkAccess } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { checkAccess, loadWebSearchAuth } = require('@librechat/api');
const {
Tools,
AuthType,
Permissions,
ToolCallTypes,
PermissionTypes,
loadWebSearchAuth,
} = require('librechat-data-provider');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { processCodeOutput } = require('~/server/services/Files/Code/process');

View File

@@ -16,7 +16,7 @@ const { connectDb, indexSync } = require('~/db');
const validateImageRequest = require('./middleware/validateImageRequest');
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
const errorController = require('./controllers/ErrorController');
const initializeMCPs = require('./services/initializeMCPs');
const initializeMCP = require('./services/initializeMCP');
const configureSocialLogins = require('./socialLogins');
const AppService = require('./services/AppService');
const staticCache = require('./utils/staticCache');
@@ -146,7 +146,7 @@ const startServer = async () => {
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
}
initializeMCPs(app);
initializeMCP(app);
});
};

View File

@@ -1,4 +1,3 @@
const { handleError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
EndpointURLs,
@@ -15,6 +14,7 @@ const openAI = require('~/server/services/Endpoints/openAI');
const agents = require('~/server/services/Endpoints/agents');
const custom = require('~/server/services/Endpoints/custom');
const google = require('~/server/services/Endpoints/google');
const { handleError } = require('~/server/utils');
const buildFunction = {
[EModelEndpoint.openAI]: openAI.buildOptions,

View File

@@ -1,4 +1,4 @@
const { Time, CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { Time, CacheKeys } = require('librechat-data-provider');
const clearPendingReq = require('~/cache/clearPendingReq');
const { logViolation, getLogStores } = require('~/cache');
const { isEnabled } = require('~/server/utils');
@@ -37,7 +37,7 @@ const concurrentLimiter = async (req, res, next) => {
const userId = req.user?.id ?? req.user?._id ?? '';
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
const type = ViolationTypes.CONCURRENT;
const type = 'concurrent';
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
const pendingRequests = +((await cache.get(key)) ?? 0);

View File

@@ -1,6 +1,9 @@
const rateLimit = require('express-rate-limit');
const { isEnabled } = require('@librechat/api');
const { RedisStore } = require('rate-limit-redis');
const { logger } = require('@librechat/data-schemas');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const getEnvironmentVariables = () => {
@@ -8,7 +11,6 @@ const getEnvironmentVariables = () => {
const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1;
const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7;
const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1;
const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE;
const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000;
const forkIpMax = FORK_IP_MAX;
@@ -25,18 +27,12 @@ const getEnvironmentVariables = () => {
forkUserWindowMs,
forkUserMax,
forkUserWindowInMinutes,
forkViolationScore: FORK_VIOLATION_SCORE,
};
};
const createForkHandler = (ip = true) => {
const {
forkIpMax,
forkUserMax,
forkViolationScore,
forkIpWindowInMinutes,
forkUserWindowInMinutes,
} = getEnvironmentVariables();
const { forkIpMax, forkIpWindowInMinutes, forkUserMax, forkUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
@@ -47,7 +43,7 @@ const createForkHandler = (ip = true) => {
windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, forkViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
};
};
@@ -59,7 +55,6 @@ const createForkLimiters = () => {
windowMs: forkIpWindowMs,
max: forkIpMax,
handler: createForkHandler(),
store: limiterCache('fork_ip_limiter'),
};
const userLimiterOptions = {
windowMs: forkUserWindowMs,
@@ -68,9 +63,23 @@ const createForkLimiters = () => {
keyGenerator: function (req) {
return req.user?.id;
},
store: limiterCache('fork_user_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for fork rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'fork_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'fork_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const forkIpLimiter = rateLimit(ipLimiterOptions);
const forkUserLimiter = rateLimit(userLimiterOptions);
return { forkIpLimiter, forkUserLimiter };

View File

@@ -1,6 +1,9 @@
const rateLimit = require('express-rate-limit');
const { isEnabled } = require('@librechat/api');
const { RedisStore } = require('rate-limit-redis');
const { logger } = require('@librechat/data-schemas');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const getEnvironmentVariables = () => {
@@ -8,7 +11,6 @@ const getEnvironmentVariables = () => {
const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE;
const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
const importIpMax = IMPORT_IP_MAX;
@@ -25,18 +27,12 @@ const getEnvironmentVariables = () => {
importUserWindowMs,
importUserMax,
importUserWindowInMinutes,
importViolationScore: IMPORT_VIOLATION_SCORE,
};
};
const createImportHandler = (ip = true) => {
const {
importIpMax,
importUserMax,
importViolationScore,
importIpWindowInMinutes,
importUserWindowInMinutes,
} = getEnvironmentVariables();
const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
@@ -47,7 +43,7 @@ const createImportHandler = (ip = true) => {
windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, importViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
};
};
@@ -60,7 +56,6 @@ const createImportLimiters = () => {
windowMs: importIpWindowMs,
max: importIpMax,
handler: createImportHandler(),
store: limiterCache('import_ip_limiter'),
};
const userLimiterOptions = {
windowMs: importUserWindowMs,
@@ -69,9 +64,23 @@ const createImportLimiters = () => {
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
store: limiterCache('import_user_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for import rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'import_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'import_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const importIpLimiter = rateLimit(ipLimiterOptions);
const importUserLimiter = rateLimit(userLimiterOptions);
return { importIpLimiter, importUserLimiter };

View File

@@ -1,8 +1,9 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory');
const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
const windowMs = LOGIN_WINDOW * 60 * 1000;
@@ -11,7 +12,7 @@ const windowInMinutes = windowMs / 60000;
const message = `Too many login attempts, please try again after ${windowInMinutes} minutes.`;
const handler = async (req, res) => {
const type = ViolationTypes.LOGINS;
const type = 'logins';
const errorMessage = {
type,
max,
@@ -27,9 +28,17 @@ const limiterOptions = {
max,
handler,
keyGenerator: removePorts,
store: limiterCache('login_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for login rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'login_limiter:',
});
limiterOptions.store = store;
}
const loginLimiter = rateLimit(limiterOptions);
module.exports = loginLimiter;

View File

@@ -1,15 +1,16 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { RedisStore } = require('rate-limit-redis');
const denyRequest = require('~/server/middleware/denyRequest');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const { isEnabled } = require('~/server/utils');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const {
MESSAGE_IP_MAX = 40,
MESSAGE_IP_WINDOW = 1,
MESSAGE_USER_MAX = 40,
MESSAGE_USER_WINDOW = 1,
MESSAGE_VIOLATION_SCORE: score,
} = process.env;
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
@@ -30,7 +31,7 @@ const userWindowInMinutes = userWindowMs / 60000;
*/
const createHandler = (ip = true) => {
return async (req, res) => {
const type = ViolationTypes.MESSAGE_LIMIT;
const type = 'message_limit';
const errorMessage = {
type,
max: ip ? ipMax : userMax,
@@ -38,7 +39,7 @@ const createHandler = (ip = true) => {
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, score);
await logViolation(req, res, type, errorMessage);
return await denyRequest(req, res, errorMessage);
};
};
@@ -50,7 +51,6 @@ const ipLimiterOptions = {
windowMs: ipWindowMs,
max: ipMax,
handler: createHandler(),
store: limiterCache('message_ip_limiter'),
};
const userLimiterOptions = {
@@ -60,9 +60,23 @@ const userLimiterOptions = {
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
store: limiterCache('message_user_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for message rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'message_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'message_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
/**
* Message request rate limiter by IP
*/

View File

@@ -1,8 +1,9 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory');
const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
const windowMs = REGISTER_WINDOW * 60 * 1000;
@@ -11,7 +12,7 @@ const windowInMinutes = windowMs / 60000;
const message = `Too many accounts created, please try again after ${windowInMinutes} minutes`;
const handler = async (req, res) => {
const type = ViolationTypes.REGISTRATIONS;
const type = 'registrations';
const errorMessage = {
type,
max,
@@ -27,9 +28,17 @@ const limiterOptions = {
max,
handler,
keyGenerator: removePorts,
store: limiterCache('register_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for register rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'register_limiter:',
});
limiterOptions.store = store;
}
const registerLimiter = rateLimit(limiterOptions);
module.exports = registerLimiter;

View File

@@ -1,8 +1,10 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory');
const { removePorts, isEnabled } = require('~/server/utils');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const {
RESET_PASSWORD_WINDOW = 2,
@@ -31,9 +33,17 @@ const limiterOptions = {
max,
handler,
keyGenerator: removePorts,
store: limiterCache('reset_password_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for reset password rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'reset_password_limiter:',
});
limiterOptions.store = store;
}
const resetPasswordLimiter = rateLimit(limiterOptions);
module.exports = resetPasswordLimiter;

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE;
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
const sttIpMax = STT_IP_MAX;
@@ -25,12 +27,11 @@ const getEnvironmentVariables = () => {
sttUserWindowMs,
sttUserMax,
sttUserWindowInMinutes,
sttViolationScore: STT_VIOLATION_SCORE,
};
};
const createSTTHandler = (ip = true) => {
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } =
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
@@ -42,7 +43,7 @@ const createSTTHandler = (ip = true) => {
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, sttViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many STT requests. Try again later' });
};
};
@@ -54,7 +55,6 @@ const createSTTLimiters = () => {
windowMs: sttIpWindowMs,
max: sttIpMax,
handler: createSTTHandler(),
store: limiterCache('stt_ip_limiter'),
};
const userLimiterOptions = {
@@ -64,9 +64,23 @@ const createSTTLimiters = () => {
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
store: limiterCache('stt_user_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for STT rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'stt_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'stt_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const sttIpLimiter = rateLimit(ipLimiterOptions);
const sttUserLimiter = rateLimit(userLimiterOptions);

View File

@@ -1,9 +1,10 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { TOOL_CALL_VIOLATION_SCORE: score } = process.env;
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const handler = async (req, res) => {
const type = ViolationTypes.TOOL_CALL_LIMIT;
@@ -14,7 +15,7 @@ const handler = async (req, res) => {
windowInMinutes: 1,
};
await logViolation(req, res, type, errorMessage, score);
await logViolation(req, res, type, errorMessage, 0);
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
};
@@ -25,9 +26,17 @@ const limiterOptions = {
keyGenerator: function (req) {
return req.user?.id;
},
store: limiterCache('tool_call_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for tool call rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'tool_call_limiter:',
});
limiterOptions.store = store;
}
const toolCallLimiter = rateLimit(limiterOptions);
module.exports = toolCallLimiter;

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { limiterCache } = require('~/cache/cacheFactory');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE;
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
const ttsIpMax = TTS_IP_MAX;
@@ -25,12 +27,11 @@ const getEnvironmentVariables = () => {
ttsUserWindowMs,
ttsUserMax,
ttsUserWindowInMinutes,
ttsViolationScore: TTS_VIOLATION_SCORE,
};
};
const createTTSHandler = (ip = true) => {
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } =
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
@@ -42,7 +43,7 @@ const createTTSHandler = (ip = true) => {
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, ttsViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
};
};
@@ -54,19 +55,32 @@ const createTTSLimiters = () => {
windowMs: ttsIpWindowMs,
max: ttsIpMax,
handler: createTTSHandler(),
store: limiterCache('tts_ip_limiter'),
};
const userLimiterOptions = {
windowMs: ttsUserWindowMs,
max: ttsUserMax,
handler: createTTSHandler(false),
store: limiterCache('tts_user_limiter'),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for TTS rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'tts_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'tts_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const ttsIpLimiter = rateLimit(ipLimiterOptions);
const ttsUserLimiter = rateLimit(userLimiterOptions);

View File

@@ -1,14 +1,16 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { limiterCache } = require('~/cache/cacheFactory');
const ioredisClient = require('~/cache/ioredisClient');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100;
const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE;
const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
const fileUploadIpMax = FILE_UPLOAD_IP_MAX;
@@ -25,7 +27,6 @@ const getEnvironmentVariables = () => {
fileUploadUserWindowMs,
fileUploadUserMax,
fileUploadUserWindowInMinutes,
fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE,
};
};
@@ -35,7 +36,6 @@ const createFileUploadHandler = (ip = true) => {
fileUploadIpWindowInMinutes,
fileUploadUserMax,
fileUploadUserWindowInMinutes,
fileUploadViolationScore,
} = getEnvironmentVariables();
return async (req, res) => {
@@ -47,7 +47,7 @@ const createFileUploadHandler = (ip = true) => {
windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage, fileUploadViolationScore);
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many file upload requests. Try again later' });
};
};
@@ -60,7 +60,6 @@ const createFileLimiters = () => {
windowMs: fileUploadIpWindowMs,
max: fileUploadIpMax,
handler: createFileUploadHandler(),
store: limiterCache('file_upload_ip_limiter'),
};
const userLimiterOptions = {
@@ -70,9 +69,23 @@ const createFileLimiters = () => {
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
store: limiterCache('file_upload_user_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for file upload rate limiters.');
const sendCommand = (...args) => ioredisClient.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'file_upload_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'file_upload_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const fileUploadIpLimiter = rateLimit(ipLimiterOptions);
const fileUploadUserLimiter = rateLimit(userLimiterOptions);

View File

@@ -1,8 +1,10 @@
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { limiterCache } = require('~/cache/cacheFactory');
const { removePorts, isEnabled } = require('~/server/utils');
const ioredisClient = require('~/cache/ioredisClient');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const {
VERIFY_EMAIL_WINDOW = 2,
@@ -31,9 +33,17 @@ const limiterOptions = {
max,
handler,
keyGenerator: removePorts,
store: limiterCache('verify_email_limiter'),
};
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
logger.debug('Using Redis for verify email rate limiter.');
const store = new RedisStore({
sendCommand: (...args) => ioredisClient.call(...args),
prefix: 'verify_email_limiter:',
});
limiterOptions.store = store;
}
const verifyEmailLimiter = rateLimit(limiterOptions);
module.exports = verifyEmailLimiter;

View File

@@ -1,6 +1,5 @@
const uap = require('ua-parser-js');
const { ViolationTypes } = require('librechat-data-provider');
const { handleError } = require('@librechat/api');
const { handleError } = require('../utils');
const { logViolation } = require('../../cache');
/**
@@ -22,7 +21,7 @@ async function uaParser(req, res, next) {
const ua = uap(req.headers['user-agent']);
if (!ua.browser.name) {
const type = ViolationTypes.NON_BROWSER;
const type = 'non_browser';
await logViolation(req, res, type, { type }, score);
return handleError(res, { message: 'Illegal request' });
}

View File

@@ -1,4 +1,4 @@
const { handleError } = require('@librechat/api');
const { handleError } = require('../utils');
function validateEndpoint(req, res, next) {
const { endpoint: _endpoint, endpointType } = req.body;

View File

@@ -1,6 +1,6 @@
const { handleError } = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider');
const { getModelsConfig } = require('~/server/controllers/ModelController');
const { handleError } = require('~/server/utils');
const { logViolation } = require('~/cache');
/**
* Validates the model of the request.

View File

@@ -1,162 +0,0 @@
const fs = require('fs');
const path = require('path');
const express = require('express');
const request = require('supertest');
const zlib = require('zlib');
// Create test setup
const mockTestDir = path.join(__dirname, 'test-static-route');
// Mock the paths module to point to our test directory
jest.mock('~/config/paths', () => ({
imageOutput: mockTestDir,
}));
describe('Static Route Integration', () => {
let app;
let staticRoute;
let testDir;
let testImagePath;
beforeAll(() => {
// Create a test directory and files
testDir = mockTestDir;
testImagePath = path.join(testDir, 'test-image.jpg');
if (!fs.existsSync(testDir)) {
fs.mkdirSync(testDir, { recursive: true });
}
// Create a test image file
fs.writeFileSync(testImagePath, 'fake-image-data');
// Create a gzipped version of the test image (for gzip scanning tests)
fs.writeFileSync(testImagePath + '.gz', zlib.gzipSync('fake-image-data'));
});
afterAll(() => {
// Clean up test files
if (fs.existsSync(testDir)) {
fs.rmSync(testDir, { recursive: true, force: true });
}
});
// Helper function to set up static route with specific config
const setupStaticRoute = (skipGzipScan = false) => {
if (skipGzipScan) {
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
} else {
process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN = 'true';
}
staticRoute = require('../static');
app.use('/images', staticRoute);
};
beforeEach(() => {
// Clear the module cache to get fresh imports
jest.resetModules();
app = express();
// Clear environment variables
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
delete process.env.NODE_ENV;
});
describe('route functionality', () => {
it('should serve static image files', async () => {
process.env.NODE_ENV = 'production';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.body.toString()).toBe('fake-image-data');
});
it('should return 404 for non-existent files', async () => {
setupStaticRoute();
const response = await request(app).get('/images/nonexistent.jpg');
expect(response.status).toBe(404);
});
});
describe('cache behavior', () => {
it('should set cache headers for images in production', async () => {
process.env.NODE_ENV = 'production';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
});
it('should not set cache headers in development', async () => {
process.env.NODE_ENV = 'development';
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
// Our middleware should not set the production cache-control header in development
expect(response.headers['cache-control']).not.toBe('public, max-age=172800, s-maxage=86400');
});
});
describe('gzip compression behavior', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should serve gzipped files when gzip scanning is enabled', async () => {
setupStaticRoute(false); // Enable gzip scanning
const response = await request(app)
.get('/images/test-image.jpg')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.body.toString()).toBe('fake-image-data');
});
it('should not serve gzipped files when gzip scanning is disabled', async () => {
setupStaticRoute(true); // Disable gzip scanning
const response = await request(app)
.get('/images/test-image.jpg')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.body.toString()).toBe('fake-image-data');
});
});
describe('path configuration', () => {
it('should use the configured imageOutput path', async () => {
setupStaticRoute();
const response = await request(app).get('/images/test-image.jpg').expect(200);
expect(response.body.toString()).toBe('fake-image-data');
});
it('should serve from subdirectories', async () => {
// Create a subdirectory with a file
const subDir = path.join(testDir, 'thumbs');
fs.mkdirSync(subDir, { recursive: true });
const thumbPath = path.join(subDir, 'thumb.jpg');
fs.writeFileSync(thumbPath, 'thumbnail-data');
setupStaticRoute();
const response = await request(app).get('/images/thumbs/thumb.jpg').expect(200);
expect(response.body.toString()).toBe('thumbnail-data');
// Clean up
fs.rmSync(subDir, { recursive: true, force: true });
});
});
});

View File

@@ -106,7 +106,6 @@ router.get('/', async function (req, res) {
const serverConfig = config.mcpServers[serverName];
payload.mcpServers[serverName] = {
customUserVars: serverConfig?.customUserVars || {},
requiresOAuth: req.app.locals.mcpOAuthRequirements?.[serverName] || false,
};
}
}

View File

@@ -1,282 +0,0 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
// Mock dependencies
jest.mock('~/server/services/Files/process', () => ({
processDeleteRequest: jest.fn().mockResolvedValue({}),
filterFile: jest.fn(),
processFileUpload: jest.fn(),
processAgentFileUpload: jest.fn(),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(() => ({})),
}));
jest.mock('~/server/controllers/assistants/helpers', () => ({
getOpenAIClient: jest.fn(),
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
refreshS3FileUrls: jest.fn(),
}));
jest.mock('~/cache', () => ({
getLogStores: jest.fn(() => ({
get: jest.fn(),
set: jest.fn(),
})),
}));
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
},
}));
const { createFile } = require('~/models/File');
const { createAgent } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
// Import the router after mocks
const router = require('./files');
describe('File Routes - Agent Files Endpoint', () => {
let app;
let mongoServer;
let authorId;
let otherUserId;
let agentId;
let fileId1;
let fileId2;
let fileId3;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
// Initialize models
require('~/db/models');
app = express();
app.use(express.json());
// Mock authentication middleware
app.use((req, res, next) => {
req.user = { id: otherUserId || 'default-user' };
req.app = { locals: {} };
next();
});
app.use('/files', router);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
jest.clearAllMocks();
// Clear database
const collections = mongoose.connection.collections;
for (const key in collections) {
await collections[key].deleteMany({});
}
authorId = new mongoose.Types.ObjectId().toString();
otherUserId = new mongoose.Types.ObjectId().toString();
agentId = uuidv4();
fileId1 = uuidv4();
fileId2 = uuidv4();
fileId3 = uuidv4();
// Create files
await createFile({
user: authorId,
file_id: fileId1,
filename: 'agent-file1.txt',
filepath: `/uploads/${authorId}/${fileId1}`,
bytes: 1024,
type: 'text/plain',
});
await createFile({
user: authorId,
file_id: fileId2,
filename: 'agent-file2.txt',
filepath: `/uploads/${authorId}/${fileId2}`,
bytes: 2048,
type: 'text/plain',
});
await createFile({
user: otherUserId,
file_id: fileId3,
filename: 'user-file.txt',
filepath: `/uploads/${otherUserId}/${fileId3}`,
bytes: 512,
type: 'text/plain',
});
// Create an agent with files attached
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileId1, fileId2],
},
},
});
// Share the agent globally
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
if (globalProject) {
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
}
});
describe('GET /files/agent/:agent_id', () => {
it('should return files accessible through the agent for non-author', async () => {
const response = await request(app).get(`/files/agent/${agentId}`);
expect(response.status).toBe(200);
expect(response.body).toHaveLength(2); // Only agent files, not user-owned files
const fileIds = response.body.map((f) => f.file_id);
expect(fileIds).toContain(fileId1);
expect(fileIds).toContain(fileId2);
expect(fileIds).not.toContain(fileId3); // User's own file not included
});
it('should return 400 when agent_id is not provided', async () => {
const response = await request(app).get('/files/agent/');
expect(response.status).toBe(404); // Express returns 404 for missing route parameter
});
it('should return empty array for non-existent agent', async () => {
const response = await request(app).get('/files/agent/non-existent-agent');
expect(response.status).toBe(200);
expect(response.body).toEqual([]); // Empty array for non-existent agent
});
it('should return empty array when agent is not collaborative', async () => {
// Create a non-collaborative agent
const nonCollabAgentId = uuidv4();
await createAgent({
id: nonCollabAgentId,
name: 'Non-Collaborative Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: false,
tool_resources: {
file_search: {
file_ids: [fileId1],
},
},
});
// Share it globally
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
if (globalProject) {
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: nonCollabAgentId }, { projectIds: [globalProject._id] });
}
const response = await request(app).get(`/files/agent/${nonCollabAgentId}`);
expect(response.status).toBe(200);
expect(response.body).toEqual([]); // Empty array when not collaborative
});
it('should return agent files for agent author', async () => {
// Create a new app instance with author authentication
const authorApp = express();
authorApp.use(express.json());
authorApp.use((req, res, next) => {
req.user = { id: authorId };
req.app = { locals: {} };
next();
});
authorApp.use('/files', router);
const response = await request(authorApp).get(`/files/agent/${agentId}`);
expect(response.status).toBe(200);
expect(response.body).toHaveLength(2); // Agent files for author
const fileIds = response.body.map((f) => f.file_id);
expect(fileIds).toContain(fileId1);
expect(fileIds).toContain(fileId2);
expect(fileIds).not.toContain(fileId3); // User's own file not included
});
it('should return files uploaded by other users to shared agent for author', async () => {
// Create a file uploaded by another user
const otherUserFileId = uuidv4();
const anotherUserId = new mongoose.Types.ObjectId().toString();
await createFile({
user: anotherUserId,
file_id: otherUserFileId,
filename: 'other-user-file.txt',
filepath: `/uploads/${anotherUserId}/${otherUserFileId}`,
bytes: 4096,
type: 'text/plain',
});
// Update agent to include the file uploaded by another user
const { updateAgent } = require('~/models/Agent');
await updateAgent(
{ id: agentId },
{
tool_resources: {
file_search: {
file_ids: [fileId1, fileId2, otherUserFileId],
},
},
},
);
// Create app instance with author authentication
const authorApp = express();
authorApp.use(express.json());
authorApp.use((req, res, next) => {
req.user = { id: authorId };
req.app = { locals: {} };
next();
});
authorApp.use('/files', router);
const response = await request(authorApp).get(`/files/agent/${agentId}`);
expect(response.status).toBe(200);
expect(response.body).toHaveLength(3); // Including file from another user
const fileIds = response.body.map((f) => f.file_id);
expect(fileIds).toContain(fileId1);
expect(fileIds).toContain(fileId2);
expect(fileIds).toContain(otherUserFileId); // File uploaded by another user
});
});
});

View File

@@ -5,7 +5,6 @@ const {
Time,
isUUID,
CacheKeys,
Constants,
FileSources,
EModelEndpoint,
isAgentsEndpoint,
@@ -17,12 +16,11 @@ const {
processDeleteRequest,
processAgentFileUpload,
} = require('~/server/services/Files/process');
const { getFiles, batchUpdateFiles, hasAccessToFilesViaAgent } = require('~/models/File');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
const { getProjectByName } = require('~/models/Project');
const { getFiles, batchUpdateFiles } = require('~/models/File');
const { getAssistant } = require('~/models/Assistant');
const { getAgent } = require('~/models/Agent');
const { getLogStores } = require('~/cache');
@@ -52,68 +50,6 @@ router.get('/', async (req, res) => {
}
});
/**
* Get files specific to an agent
* @route GET /files/agent/:agent_id
* @param {string} agent_id - The agent ID to get files for
* @returns {Promise<TFile[]>} Array of files attached to the agent
*/
router.get('/agent/:agent_id', async (req, res) => {
try {
const { agent_id } = req.params;
const userId = req.user.id;
if (!agent_id) {
return res.status(400).json({ error: 'Agent ID is required' });
}
// Get the agent to check ownership and attached files
const agent = await getAgent({ id: agent_id });
if (!agent) {
// No agent found, return empty array
return res.status(200).json([]);
}
// Check if user has access to the agent
if (agent.author.toString() !== userId) {
// Non-authors need the agent to be globally shared and collaborative
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
if (
!globalProject ||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString()) ||
!agent.isCollaborative
) {
return res.status(200).json([]);
}
}
// Collect all file IDs from agent's tool resources
const agentFileIds = [];
if (agent.tool_resources) {
for (const [, resource] of Object.entries(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
agentFileIds.push(...resource.file_ids);
}
}
}
// If no files attached to agent, return empty array
if (agentFileIds.length === 0) {
return res.status(200).json([]);
}
// Get only the files attached to this agent
const files = await getFiles({ file_id: { $in: agentFileIds } }, null, { text: 0 });
res.status(200).json(files);
} catch (error) {
logger.error('[/files/agent/:agent_id] Error fetching agent files:', error);
res.status(500).json({ error: 'Failed to fetch agent files' });
}
});
router.get('/config', async (req, res) => {
try {
res.status(200).json(req.app.locals.fileConfig);
@@ -150,62 +86,11 @@ router.delete('/', async (req, res) => {
const fileIds = files.map((file) => file.file_id);
const dbFiles = await getFiles({ file_id: { $in: fileIds } });
const ownedFiles = [];
const nonOwnedFiles = [];
const fileMap = new Map();
for (const file of dbFiles) {
fileMap.set(file.file_id, file);
if (file.user.toString() === req.user.id) {
ownedFiles.push(file);
} else {
nonOwnedFiles.push(file);
}
}
// If all files are owned by the user, no need for further checks
if (nonOwnedFiles.length === 0) {
await processDeleteRequest({ req, files: ownedFiles });
logger.debug(
`[/files] Files deleted successfully: ${ownedFiles
.filter((f) => f.file_id)
.map((f) => f.file_id)
.join(', ')}`,
);
res.status(200).json({ message: 'Files deleted successfully' });
return;
}
// Check access for non-owned files
let authorizedFiles = [...ownedFiles];
let unauthorizedFiles = [];
if (req.body.agent_id && nonOwnedFiles.length > 0) {
// Batch check access for all non-owned files
const nonOwnedFileIds = nonOwnedFiles.map((f) => f.file_id);
const accessMap = await hasAccessToFilesViaAgent(
req.user.id,
nonOwnedFileIds,
req.body.agent_id,
);
// Separate authorized and unauthorized files
for (const file of nonOwnedFiles) {
if (accessMap.get(file.file_id)) {
authorizedFiles.push(file);
} else {
unauthorizedFiles.push(file);
}
}
} else {
// No agent context, all non-owned files are unauthorized
unauthorizedFiles = nonOwnedFiles;
}
const unauthorizedFiles = dbFiles.filter((file) => file.user.toString() !== req.user.id);
if (unauthorizedFiles.length > 0) {
return res.status(403).json({
message: 'You can only delete files you have access to',
message: 'You can only delete your own files',
unauthorizedFiles: unauthorizedFiles.map((f) => f.file_id),
});
}
@@ -246,10 +131,10 @@ router.delete('/', async (req, res) => {
.json({ message: 'File associations removed successfully from Azure Assistant' });
}
await processDeleteRequest({ req, files: authorizedFiles });
await processDeleteRequest({ req, files: dbFiles });
logger.debug(
`[/files] Files deleted successfully: ${authorizedFiles
`[/files] Files deleted successfully: ${files
.filter((f) => f.file_id)
.map((f) => f.file_id)
.join(', ')}`,

View File

@@ -1,302 +0,0 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
// Mock dependencies
jest.mock('~/server/services/Files/process', () => ({
processDeleteRequest: jest.fn().mockResolvedValue({}),
filterFile: jest.fn(),
processFileUpload: jest.fn(),
processAgentFileUpload: jest.fn(),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(() => ({})),
}));
jest.mock('~/server/controllers/assistants/helpers', () => ({
getOpenAIClient: jest.fn(),
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
refreshS3FileUrls: jest.fn(),
}));
jest.mock('~/cache', () => ({
getLogStores: jest.fn(() => ({
get: jest.fn(),
set: jest.fn(),
})),
}));
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
},
}));
const { createFile } = require('~/models/File');
const { createAgent } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
const { processDeleteRequest } = require('~/server/services/Files/process');
// Import the router after mocks
const router = require('./files');
describe('File Routes - Delete with Agent Access', () => {
let app;
let mongoServer;
let authorId;
let otherUserId;
let agentId;
let fileId;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
// Initialize models
require('~/db/models');
app = express();
app.use(express.json());
// Mock authentication middleware
app.use((req, res, next) => {
req.user = { id: otherUserId || 'default-user' };
req.app = { locals: {} };
next();
});
app.use('/files', router);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
jest.clearAllMocks();
// Clear database
const collections = mongoose.connection.collections;
for (const key in collections) {
await collections[key].deleteMany({});
}
authorId = new mongoose.Types.ObjectId().toString();
otherUserId = new mongoose.Types.ObjectId().toString();
fileId = uuidv4();
// Create a file owned by the author
await createFile({
user: authorId,
file_id: fileId,
filename: 'test.txt',
filepath: `/uploads/${authorId}/${fileId}`,
bytes: 1024,
type: 'text/plain',
});
// Create an agent with the file attached
const agent = await createAgent({
id: uuidv4(),
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileId],
},
},
});
agentId = agent.id;
// Share the agent globally
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
if (globalProject) {
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
}
});
describe('DELETE /files', () => {
it('should allow deleting files owned by the user', async () => {
// Create a file owned by the current user
const userFileId = uuidv4();
await createFile({
user: otherUserId,
file_id: userFileId,
filename: 'user-file.txt',
filepath: `/uploads/${otherUserId}/${userFileId}`,
bytes: 1024,
type: 'text/plain',
});
const response = await request(app)
.delete('/files')
.send({
files: [
{
file_id: userFileId,
filepath: `/uploads/${otherUserId}/${userFileId}`,
},
],
});
expect(response.status).toBe(200);
expect(response.body.message).toBe('Files deleted successfully');
expect(processDeleteRequest).toHaveBeenCalled();
});
it('should prevent deleting files not owned by user without agent context', async () => {
const response = await request(app)
.delete('/files')
.send({
files: [
{
file_id: fileId,
filepath: `/uploads/${authorId}/${fileId}`,
},
],
});
expect(response.status).toBe(403);
expect(response.body.message).toBe('You can only delete files you have access to');
expect(response.body.unauthorizedFiles).toContain(fileId);
expect(processDeleteRequest).not.toHaveBeenCalled();
});
it('should allow deleting files accessible through shared agent', async () => {
const response = await request(app)
.delete('/files')
.send({
agent_id: agentId,
files: [
{
file_id: fileId,
filepath: `/uploads/${authorId}/${fileId}`,
},
],
});
expect(response.status).toBe(200);
expect(response.body.message).toBe('Files deleted successfully');
expect(processDeleteRequest).toHaveBeenCalled();
});
it('should prevent deleting files not attached to the specified agent', async () => {
// Create another file not attached to the agent
const unattachedFileId = uuidv4();
await createFile({
user: authorId,
file_id: unattachedFileId,
filename: 'unattached.txt',
filepath: `/uploads/${authorId}/${unattachedFileId}`,
bytes: 1024,
type: 'text/plain',
});
const response = await request(app)
.delete('/files')
.send({
agent_id: agentId,
files: [
{
file_id: unattachedFileId,
filepath: `/uploads/${authorId}/${unattachedFileId}`,
},
],
});
expect(response.status).toBe(403);
expect(response.body.message).toBe('You can only delete files you have access to');
expect(response.body.unauthorizedFiles).toContain(unattachedFileId);
});
it('should handle mixed authorized and unauthorized files', async () => {
// Create a file owned by the current user
const userFileId = uuidv4();
await createFile({
user: otherUserId,
file_id: userFileId,
filename: 'user-file.txt',
filepath: `/uploads/${otherUserId}/${userFileId}`,
bytes: 1024,
type: 'text/plain',
});
// Create an unauthorized file
const unauthorizedFileId = uuidv4();
await createFile({
user: authorId,
file_id: unauthorizedFileId,
filename: 'unauthorized.txt',
filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
bytes: 1024,
type: 'text/plain',
});
const response = await request(app)
.delete('/files')
.send({
agent_id: agentId,
files: [
{
file_id: fileId, // Authorized through agent
filepath: `/uploads/${authorId}/${fileId}`,
},
{
file_id: userFileId, // Owned by user
filepath: `/uploads/${otherUserId}/${userFileId}`,
},
{
file_id: unauthorizedFileId, // Not authorized
filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
},
],
});
expect(response.status).toBe(403);
expect(response.body.message).toBe('You can only delete files you have access to');
expect(response.body.unauthorizedFiles).toContain(unauthorizedFileId);
expect(response.body.unauthorizedFiles).not.toContain(fileId);
expect(response.body.unauthorizedFiles).not.toContain(userFileId);
});
it('should prevent deleting files when agent is not collaborative', async () => {
// Update the agent to be non-collaborative
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: agentId }, { isCollaborative: false });
const response = await request(app)
.delete('/files')
.send({
agent_id: agentId,
files: [
{
file_id: fileId,
filepath: `/uploads/${authorId}/${fileId}`,
},
],
});
expect(response.status).toBe(403);
expect(response.body.message).toBe('You can only delete files you have access to');
expect(response.body.unauthorizedFiles).toContain(fileId);
expect(processDeleteRequest).not.toHaveBeenCalled();
});
});
});

View File

@@ -1,18 +1,13 @@
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { MCPOAuthHandler } = require('@librechat/api');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config');
const { getUserPluginAuthValueByPlugin } = require('~/server/services/PluginService');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { requireJwtAuth } = require('~/server/middleware');
const { getFlowStateManager } = require('~/config');
const { getLogStores } = require('~/cache');
const router = Router();
const suppressLogging = true;
/**
* Initiate OAuth flow
* This endpoint is called when the user clicks the auth link in the UI
@@ -207,472 +202,4 @@ router.get('/oauth/status/:flowId', async (req, res) => {
}
});
/**
* Get connection status for all MCP servers
* This endpoint returns the actual connection status from MCPManager
*/
router.get('/connection/status', requireJwtAuth, async (req, res) => {
try {
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
const mcpManager = getMCPManager(null, true); // Skip idle checks to avoid ping spam
const connectionStatus = {};
// Get all MCP server names from custom config
const config = await loadCustomConfig(suppressLogging);
const mcpConfig = config?.mcpServers;
if (mcpConfig) {
for (const [serverName, config] of Object.entries(mcpConfig)) {
try {
// Check if this is an app-level connection (exists in mcpManager.connections)
const appConnection = mcpManager.getConnection(serverName);
const hasAppConnection = !!appConnection;
// Check if this is a user-level connection (exists in mcpManager.userConnections)
const userConnection = mcpManager.getUserConnectionIfExists(user.id, serverName);
const hasUserConnection = !!userConnection;
// Use lightweight connection state check instead of ping-based isConnected()
let connected = false;
if (hasAppConnection) {
// Check connection state without ping to avoid rate limits
connected = appConnection.connectionState === 'connected';
} else if (hasUserConnection) {
// Check connection state without ping to avoid rate limits
connected = userConnection.connectionState === 'connected';
}
// Determine if this server requires user authentication
const hasAuthConfig =
config.customUserVars && Object.keys(config.customUserVars).length > 0;
const requiresOAuth = req.app.locals.mcpOAuthRequirements?.[serverName] || false;
connectionStatus[serverName] = {
connected,
hasAuthConfig,
hasConnection: hasAppConnection || hasUserConnection,
isAppLevel: hasAppConnection,
isUserLevel: hasUserConnection,
requiresOAuth,
};
} catch (error) {
logger.error(
`[MCP Connection Status] Error checking connection for ${serverName}:`,
error,
);
connectionStatus[serverName] = {
connected: false,
hasAuthConfig: config.customUserVars && Object.keys(config.customUserVars).length > 0,
hasConnection: false,
isAppLevel: false,
isUserLevel: false,
requiresOAuth: req.app.locals.mcpOAuthRequirements?.[serverName] || false,
error: error.message,
};
}
}
}
res.json({
success: true,
connectionStatus,
});
} catch (error) {
logger.error('[MCP Connection Status] Failed to get connection status', error);
res.status(500).json({ error: 'Failed to get connection status' });
}
});
/**
* Check which authentication values exist for a specific MCP server
* This endpoint returns only boolean flags indicating if values are set, not the actual values
*/
router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
try {
const { serverName } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
const config = await loadCustomConfig(suppressLogging);
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`,
});
}
const serverConfig = config.mcpServers[serverName];
const pluginKey = `${Constants.mcp_prefix}${serverName}`;
const authValueFlags = {};
// Check existence of saved values for each custom user variable (don't fetch actual values)
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
for (const varName of Object.keys(serverConfig.customUserVars)) {
try {
const value = await getUserPluginAuthValueByPlugin(user.id, varName, pluginKey, false);
// Only store boolean flag indicating if value exists
authValueFlags[varName] = !!(value && value.length > 0);
} catch (err) {
logger.error(
`[MCP Auth Value Flags] Error checking ${varName} for user ${user.id}:`,
err,
);
// Default to false if we can't check
authValueFlags[varName] = false;
}
}
}
res.json({
success: true,
serverName,
authValueFlags,
});
} catch (error) {
logger.error(
`[MCP Auth Value Flags] Failed to check auth value flags for ${req.params.serverName}`,
error,
);
res.status(500).json({ error: 'Failed to check auth value flags' });
}
});
/**
* Check if a specific MCP server requires OAuth
* This endpoint checks if a specific MCP server requires OAuth authentication
*/
router.get('/:serverName/oauth/required', requireJwtAuth, async (req, res) => {
try {
const { serverName } = req.params;
const user = req.user;
if (!user?.id) {
return res.status(401).json({ error: 'User not authenticated' });
}
const mcpManager = getMCPManager();
const requiresOAuth = await mcpManager.isOAuthRequired(serverName);
res.json({
success: true,
serverName,
requiresOAuth,
});
} catch (error) {
logger.error(
`[MCP OAuth Required] Failed to check OAuth requirement for ${req.params.serverName}`,
error,
);
res.status(500).json({ error: 'Failed to check OAuth requirement' });
}
});
/**
* Complete MCP server reinitialization after OAuth
* This endpoint completes the reinitialization process after OAuth authentication
*/
router.post('/:serverName/reinitialize/complete', requireJwtAuth, async (req, res) => {
let responseSent = false;
try {
const { serverName } = req.params;
const user = req.user;
if (!user?.id) {
responseSent = true;
return res.status(401).json({ error: 'User not authenticated' });
}
logger.info(`[MCP Complete Reinitialize] Starting completion for ${serverName}`);
const mcpManager = getMCPManager();
// Wait for connection to be established via event-driven approach
const userConnection = await new Promise((resolve, reject) => {
// Set a reasonable timeout (10 seconds)
const timeout = setTimeout(() => {
mcpManager.removeListener('connectionEstablished', connectionHandler);
reject(new Error('Timeout waiting for connection establishment'));
}, 10000);
const connectionHandler = ({
userId: eventUserId,
serverName: eventServerName,
connection,
}) => {
if (eventUserId === user.id && eventServerName === serverName) {
clearTimeout(timeout);
mcpManager.removeListener('connectionEstablished', connectionHandler);
resolve(connection);
}
};
// Check if connection already exists
const existingConnection = mcpManager.getUserConnectionIfExists(user.id, serverName);
if (existingConnection) {
clearTimeout(timeout);
resolve(existingConnection);
return;
}
// Listen for the connection establishment event
mcpManager.on('connectionEstablished', connectionHandler);
});
if (!userConnection) {
responseSent = true;
return res.status(404).json({ error: 'User connection not found' });
}
const userTools = (await getCachedTools({ userId: user.id })) || {};
// Remove any old tools from this server in the user's cache
const mcpDelimiter = Constants.mcp_delimiter;
for (const key of Object.keys(userTools)) {
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
delete userTools[key];
}
}
// Add the new tools from this server
const tools = await userConnection.fetchTools();
for (const tool of tools) {
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
userTools[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema,
},
};
}
// Save the updated user tool cache
await setCachedTools(userTools, { userId: user.id });
responseSent = true;
res.json({
success: true,
message: `MCP server '${serverName}' reinitialized successfully`,
serverName,
});
} catch (error) {
logger.error(
`[MCP Complete Reinitialize] Error completing reinitialization for ${req.params.serverName}:`,
error,
);
if (!responseSent) {
res.status(500).json({
success: false,
message: 'Failed to complete MCP server reinitialization',
serverName: req.params.serverName,
});
}
}
});
/**
* Reinitialize MCP server
* This endpoint allows reinitializing a specific MCP server
*/
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
let responseSent = false;
try {
const { serverName } = req.params;
const user = req.user;
if (!user?.id) {
responseSent = true;
return res.status(401).json({ error: 'User not authenticated' });
}
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
const config = await loadCustomConfig(suppressLogging);
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
responseSent = true;
return res.status(404).json({
error: `MCP server '${serverName}' not found in configuration`,
});
}
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = getFlowStateManager(flowsCache);
const mcpManager = getMCPManager();
// Clean up any stale OAuth flows for this server
try {
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
const existingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (existingFlow && existingFlow.status === 'PENDING') {
logger.info(`[MCP Reinitialize] Cleaning up stale OAuth flow for ${serverName}`);
await flowManager.failFlow(flowId, 'mcp_oauth', new Error('OAuth flow interrupted'));
}
} catch (error) {
logger.warn(
`[MCP Reinitialize] Error cleaning up stale OAuth flow for ${serverName}:`,
error,
);
}
await mcpManager.disconnectServer(serverName);
logger.info(`[MCP Reinitialize] Disconnected existing server: ${serverName}`);
const serverConfig = config.mcpServers[serverName];
mcpManager.mcpConfigs[serverName] = serverConfig;
let customUserVars = {};
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
for (const varName of Object.keys(serverConfig.customUserVars)) {
try {
const pluginKey = `${Constants.mcp_prefix}${serverName}`;
const value = await getUserPluginAuthValueByPlugin(user.id, varName, pluginKey, false);
if (value) {
customUserVars[varName] = value;
}
} catch (err) {
logger.error(`[MCP Reinitialize] Error fetching ${varName} for user ${user.id}:`, err);
}
}
}
let userConnection = null;
try {
userConnection = await mcpManager.getUserConnection({
user,
serverName,
flowManager,
customUserVars,
tokenMethods: {
findToken,
updateToken,
createToken,
deleteTokens,
},
oauthStart: (authURL) => {
// This will be called if OAuth is required
responseSent = true;
logger.info(`[MCP Reinitialize] OAuth required for ${serverName}, auth URL: ${authURL}`);
// Get the flow ID for polling
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
// Return the OAuth response immediately - client will poll for completion
res.json({
success: false,
oauthRequired: true,
authURL,
flowId,
message: `OAuth authentication required for MCP server '${serverName}'`,
serverName,
});
},
oauthEnd: () => {
// This will be called when OAuth flow completes
logger.info(`[MCP Reinitialize] OAuth flow completed for ${serverName}`);
},
});
// If response was already sent for OAuth, don't continue
if (responseSent) {
return;
}
} catch (err) {
logger.error(`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`, err);
// Check if this is an OAuth error
if (err.message && err.message.includes('OAuth required')) {
// Try to get the OAuth URL from the flow manager
try {
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
const existingFlow = await flowManager.getFlowState(flowId, 'mcp_oauth');
if (existingFlow && existingFlow.metadata) {
const { serverUrl, oauth: oauthConfig } = existingFlow.metadata;
if (serverUrl && oauthConfig) {
const { authorizationUrl: authUrl } = await MCPOAuthHandler.initiateOAuthFlow(
serverName,
serverUrl,
user.id,
oauthConfig,
);
return res.json({
success: false,
oauthRequired: true,
authURL: authUrl,
flowId,
message: `OAuth authentication required for MCP server '${serverName}'`,
serverName,
});
}
}
} catch (oauthErr) {
logger.error(`[MCP Reinitialize] Error getting OAuth URL for ${serverName}:`, oauthErr);
}
responseSent = true;
return res.status(401).json({
success: false,
oauthRequired: true,
message: `OAuth authentication required for MCP server '${serverName}'`,
serverName,
});
}
responseSent = true;
return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
}
const userTools = (await getCachedTools({ userId: user.id })) || {};
// Remove any old tools from this server in the user's cache
const mcpDelimiter = Constants.mcp_delimiter;
for (const key of Object.keys(userTools)) {
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
delete userTools[key];
}
}
// Add the new tools from this server
const tools = await userConnection.fetchTools();
for (const tool of tools) {
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
userTools[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema,
},
};
}
// Save the updated user tool cache
await setCachedTools(userTools, { userId: user.id });
responseSent = true;
res.json({
success: true,
message: `MCP server '${serverName}' reinitialized successfully`,
serverName,
});
} catch (error) {
logger.error('[MCP Reinitialize] Unexpected error', error);
if (!responseSent) {
res.status(500).json({ error: 'Internal server error' });
}
}
});
module.exports = router;

View File

@@ -172,68 +172,40 @@ router.patch('/preferences', checkMemoryOptOut, async (req, res) => {
/**
* PATCH /memories/:key
* Updates the value of an existing memory entry for the authenticated user.
* Body: { key?: string, value: string }
* Body: { value: string }
* Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
*/
router.patch('/:key', checkMemoryUpdate, async (req, res) => {
const { key: urlKey } = req.params;
const { key: bodyKey, value } = req.body || {};
const { key } = req.params;
const { value } = req.body || {};
if (typeof value !== 'string' || value.trim() === '') {
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
}
// Use the key from the body if provided, otherwise use the key from the URL
const newKey = bodyKey || urlKey;
try {
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
const memories = await getAllUserMemories(req.user.id);
const existingMemory = memories.find((m) => m.key === urlKey);
const existingMemory = memories.find((m) => m.key === key);
if (!existingMemory) {
return res.status(404).json({ error: 'Memory not found.' });
}
// If the key is changing, we need to handle it specially
if (newKey !== urlKey) {
const keyExists = memories.find((m) => m.key === newKey);
if (keyExists) {
return res.status(409).json({ error: 'Memory with this key already exists.' });
}
const result = await setMemory({
userId: req.user.id,
key,
value,
tokenCount,
});
const createResult = await createMemory({
userId: req.user.id,
key: newKey,
value,
tokenCount,
});
if (!createResult.ok) {
return res.status(500).json({ error: 'Failed to create new memory.' });
}
const deleteResult = await deleteMemory({ userId: req.user.id, key: urlKey });
if (!deleteResult.ok) {
return res.status(500).json({ error: 'Failed to delete old memory.' });
}
} else {
// Key is not changing, just update the value
const result = await setMemory({
userId: req.user.id,
key: newKey,
value,
tokenCount,
});
if (!result.ok) {
return res.status(500).json({ error: 'Failed to update memory.' });
}
if (!result.ok) {
return res.status(500).json({ error: 'Failed to update memory.' });
}
const updatedMemories = await getAllUserMemories(req.user.id);
const updatedMemory = updatedMemories.find((m) => m.key === newKey);
const updatedMemory = updatedMemories.find((m) => m.key === key);
res.json({ updated: true, memory: updatedMemory });
} catch (error) {

View File

@@ -1,11 +1,8 @@
const express = require('express');
const staticCache = require('../utils/staticCache');
const paths = require('~/config/paths');
const { isEnabled } = require('~/server/utils');
const skipGzipScan = !isEnabled(process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN);
const router = express.Router();
router.use(staticCache(paths.imageOutput, { skipGzipScan }));
router.use(staticCache(paths.imageOutput));
module.exports = router;

View File

@@ -1,11 +1,12 @@
const { agentsConfigSetup, loadWebSearchConfig } = require('@librechat/api');
const {
FileSources,
loadOCRConfig,
EModelEndpoint,
loadMemoryConfig,
getConfigDefaults,
loadWebSearchConfig,
} = require('librechat-data-provider');
const { agentsConfigSetup } = require('@librechat/api');
const {
checkHealth,
checkConfig,
@@ -157,10 +158,6 @@ const AppService = async (app) => {
}
});
if (endpoints?.all) {
endpointLocals.all = endpoints.all;
}
app.locals = {
...defaultLocals,
fileConfig: config?.fileConfig,

View File

@@ -543,206 +543,6 @@ describe('AppService', () => {
expect(process.env.IMPORT_USER_MAX).toEqual('initialUserMax');
expect(process.env.IMPORT_USER_WINDOW).toEqual('initialUserWindow');
});
it('should correctly configure endpoint with titlePrompt, titleMethod, and titlePromptTemplate', async () => {
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve({
endpoints: {
[EModelEndpoint.openAI]: {
titleConvo: true,
titleModel: 'gpt-3.5-turbo',
titleMethod: 'structured',
titlePrompt: 'Custom title prompt for conversation',
titlePromptTemplate: 'Summarize this conversation: {{conversation}}',
},
[EModelEndpoint.assistants]: {
titleMethod: 'functions',
titlePrompt: 'Generate a title for this assistant conversation',
titlePromptTemplate: 'Assistant conversation template: {{messages}}',
},
[EModelEndpoint.azureOpenAI]: {
groups: azureGroups,
titleConvo: true,
titleMethod: 'completion',
titleModel: 'gpt-4',
titlePrompt: 'Azure title prompt',
titlePromptTemplate: 'Azure conversation: {{context}}',
},
},
}),
);
await AppService(app);
// Check OpenAI endpoint configuration
expect(app.locals).toHaveProperty(EModelEndpoint.openAI);
expect(app.locals[EModelEndpoint.openAI]).toEqual(
expect.objectContaining({
titleConvo: true,
titleModel: 'gpt-3.5-turbo',
titleMethod: 'structured',
titlePrompt: 'Custom title prompt for conversation',
titlePromptTemplate: 'Summarize this conversation: {{conversation}}',
}),
);
// Check Assistants endpoint configuration
expect(app.locals).toHaveProperty(EModelEndpoint.assistants);
expect(app.locals[EModelEndpoint.assistants]).toMatchObject({
titleMethod: 'functions',
titlePrompt: 'Generate a title for this assistant conversation',
titlePromptTemplate: 'Assistant conversation template: {{messages}}',
});
// Check Azure OpenAI endpoint configuration
expect(app.locals).toHaveProperty(EModelEndpoint.azureOpenAI);
expect(app.locals[EModelEndpoint.azureOpenAI]).toEqual(
expect.objectContaining({
titleConvo: true,
titleMethod: 'completion',
titleModel: 'gpt-4',
titlePrompt: 'Azure title prompt',
titlePromptTemplate: 'Azure conversation: {{context}}',
}),
);
});
it('should configure Agent endpoint with title generation settings', async () => {
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve({
endpoints: {
[EModelEndpoint.agents]: {
disableBuilder: false,
titleConvo: true,
titleModel: 'gpt-4',
titleMethod: 'structured',
titlePrompt: 'Generate a descriptive title for this agent conversation',
titlePromptTemplate: 'Agent conversation summary: {{content}}',
recursionLimit: 15,
capabilities: [AgentCapabilities.tools, AgentCapabilities.actions],
},
},
}),
);
await AppService(app);
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
expect(app.locals[EModelEndpoint.agents]).toMatchObject({
disableBuilder: false,
titleConvo: true,
titleModel: 'gpt-4',
titleMethod: 'structured',
titlePrompt: 'Generate a descriptive title for this agent conversation',
titlePromptTemplate: 'Agent conversation summary: {{content}}',
recursionLimit: 15,
capabilities: expect.arrayContaining([AgentCapabilities.tools, AgentCapabilities.actions]),
});
});
it('should handle missing title configuration options with defaults', async () => {
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve({
endpoints: {
[EModelEndpoint.openAI]: {
titleConvo: true,
// titlePrompt and titlePromptTemplate are not provided
},
},
}),
);
await AppService(app);
expect(app.locals).toHaveProperty(EModelEndpoint.openAI);
expect(app.locals[EModelEndpoint.openAI]).toMatchObject({
titleConvo: true,
});
// Check that the optional fields are undefined when not provided
expect(app.locals[EModelEndpoint.openAI].titlePrompt).toBeUndefined();
expect(app.locals[EModelEndpoint.openAI].titlePromptTemplate).toBeUndefined();
expect(app.locals[EModelEndpoint.openAI].titleMethod).toBeUndefined();
});
it('should correctly configure titleEndpoint when specified', async () => {
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve({
endpoints: {
[EModelEndpoint.openAI]: {
titleConvo: true,
titleModel: 'gpt-3.5-turbo',
titleEndpoint: EModelEndpoint.anthropic,
titlePrompt: 'Generate a concise title',
},
[EModelEndpoint.agents]: {
titleEndpoint: 'custom-provider',
titleMethod: 'structured',
},
},
}),
);
await AppService(app);
// Check OpenAI endpoint has titleEndpoint
expect(app.locals).toHaveProperty(EModelEndpoint.openAI);
expect(app.locals[EModelEndpoint.openAI]).toMatchObject({
titleConvo: true,
titleModel: 'gpt-3.5-turbo',
titleEndpoint: EModelEndpoint.anthropic,
titlePrompt: 'Generate a concise title',
});
// Check Agents endpoint has titleEndpoint
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
expect(app.locals[EModelEndpoint.agents]).toMatchObject({
titleEndpoint: 'custom-provider',
titleMethod: 'structured',
});
});
it('should correctly configure all endpoint when specified', async () => {
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve({
endpoints: {
all: {
titleConvo: true,
titleModel: 'gpt-4o-mini',
titleMethod: 'structured',
titlePrompt: 'Default title prompt for all endpoints',
titlePromptTemplate: 'Default template: {{conversation}}',
titleEndpoint: EModelEndpoint.anthropic,
streamRate: 50,
},
[EModelEndpoint.openAI]: {
titleConvo: true,
titleModel: 'gpt-3.5-turbo',
},
},
}),
);
await AppService(app);
// Check that 'all' endpoint config is loaded
expect(app.locals).toHaveProperty('all');
expect(app.locals.all).toMatchObject({
titleConvo: true,
titleModel: 'gpt-4o-mini',
titleMethod: 'structured',
titlePrompt: 'Default title prompt for all endpoints',
titlePromptTemplate: 'Default template: {{conversation}}',
titleEndpoint: EModelEndpoint.anthropic,
streamRate: 50,
});
// Check that OpenAI endpoint has its own config
expect(app.locals).toHaveProperty(EModelEndpoint.openAI);
expect(app.locals[EModelEndpoint.openAI]).toMatchObject({
titleConvo: true,
titleModel: 'gpt-3.5-turbo',
});
});
});
describe('AppService updating app.locals and issuing warnings', () => {

View File

@@ -1,9 +1,10 @@
const { logger } = require('@librechat/data-schemas');
const { isEnabled, getUserMCPAuthMap } = require('@librechat/api');
const { getUserMCPAuthMap } = require('@librechat/api');
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName } = require('~/server/utils');
const { normalizeEndpointName, isEnabled } = require('~/server/utils');
const loadCustomConfig = require('./loadCustomConfig');
const { getCachedTools } = require('./getCachedTools');
const { findPluginAuthsByKeys } = require('~/models');
const getLogStores = require('~/cache/getLogStores');
/**
@@ -54,48 +55,46 @@ const getCustomEndpointConfig = async (endpoint) => {
);
};
/**
* @param {Object} params
* @param {string} params.userId
* @param {GenericTool[]} [params.tools]
* @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
*/
async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
try {
if (!tools || tools.length === 0) {
return;
}
const appTools = await getCachedTools({
userId,
});
return await getUserMCPAuthMap({
tools,
userId,
appTools,
findPluginAuthsByKeys,
});
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
err,
);
}
}
/**
* @returns {Promise<boolean>}
*/
async function hasCustomUserVars() {
async function createGetMCPAuthMap() {
const customConfig = await getCustomConfig();
const mcpServers = customConfig?.mcpServers;
return Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
const hasCustomUserVars = Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
if (!hasCustomUserVars) {
return;
}
/**
* @param {Object} params
* @param {GenericTool[]} [params.tools]
* @param {string} params.userId
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
*/
return async function ({ tools, userId }) {
try {
if (!tools || tools.length === 0) {
return;
}
const appTools = await getCachedTools({
userId,
});
return await getUserMCPAuthMap({
tools,
userId,
appTools,
findPluginAuthsByKeys,
});
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
err,
);
}
};
}
module.exports = {
getMCPAuthMap,
getCustomConfig,
getBalanceConfig,
hasCustomUserVars,
createGetMCPAuthMap,
getCustomEndpointConfig,
};

View File

@@ -22,7 +22,8 @@ async function loadAsyncEndpoints(req) {
} else {
/** Only attempt to load service key if GOOGLE_KEY is not provided */
const serviceKeyPath =
process.env.GOOGLE_SERVICE_KEY_FILE || path.join(__dirname, '../../..', 'data', 'auth.json');
process.env.GOOGLE_SERVICE_KEY_FILE_PATH ||
path.join(__dirname, '../../..', 'data', 'auth.json');
try {
serviceKey = await loadServiceKey(serviceKeyPath);

View File

@@ -23,10 +23,9 @@ let i = 0;
* Load custom configuration files and caches the object if the `cache` field at root is true.
* Validation via parsing the config file with the config schema.
* @function loadCustomConfig
* @param {boolean} [suppressLogging=false] - If true, suppresses the verbose config logging of the entire config when called.
* @returns {Promise<TCustomConfig | null>} A promise that resolves to null or the custom config object.
* */
async function loadCustomConfig(suppressLogging = false) {
async function loadCustomConfig() {
// Use CONFIG_PATH if set, otherwise fallback to defaultConfigPath
const configPath = process.env.CONFIG_PATH || defaultConfigPath;
@@ -109,11 +108,9 @@ https://www.librechat.ai/docs/configuration/stt_tts`);
return null;
} else {
if (!suppressLogging) {
logger.info('Custom config file loaded:');
logger.info(JSON.stringify(customConfig, null, 2));
logger.debug('Custom config:', customConfig);
}
logger.info('Custom config file loaded:');
logger.info(JSON.stringify(customConfig, null, 2));
logger.debug('Custom config:', customConfig);
}
(customConfig.endpoints?.custom ?? [])

View File

@@ -8,12 +8,11 @@ const {
ErrorTypes,
EModelEndpoint,
EToolResources,
isAgentsEndpoint,
replaceSpecialVars,
providerEndpointMap,
} = require('librechat-data-provider');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { getProviderConfig } = require('~/server/services/Endpoints');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { processFiles } = require('~/server/services/Files/process');
const { getFiles, getToolFilesByIds } = require('~/models/File');
const { getConvoFiles } = require('~/models/Conversation');
@@ -43,11 +42,7 @@ const initializeAgent = async ({
allowedProviders,
isInitialAgent = false,
}) => {
if (
isAgentsEndpoint(endpointOption?.endpoint) &&
allowedProviders.size > 0 &&
!allowedProviders.has(agent.provider)
) {
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
throw new Error(
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
);
@@ -87,7 +82,6 @@ const initializeAgent = async ({
attachments: currentFiles,
tool_resources: agent.tool_resources,
requestFileSet: new Set(requestFiles?.map((file) => file.file_id)),
agentId: agent.id,
});
const provider = agent.provider;
@@ -104,7 +98,7 @@ const initializeAgent = async ({
agent.endpoint = provider;
const { getOptions, overrideProvider } = await getProviderConfig(provider);
if (overrideProvider !== agent.provider) {
if (overrideProvider) {
agent.provider = overrideProvider;
}
@@ -186,11 +180,10 @@ const initializeAgent = async ({
return {
...agent,
tools,
attachments,
resendFiles,
toolContextMap,
useLegacyContent: !!options.useLegacyContent,
tools,
maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9,
};
};

View File

@@ -1,5 +1,5 @@
const OpenAI = require('openai');
const { ProxyAgent } = require('undici');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { ErrorTypes, EModelEndpoint } = require('librechat-data-provider');
const {
getUserKeyValues,
@@ -59,10 +59,7 @@ const initializeClient = async ({ req, res, endpointOption, version, initAppClie
}
if (PROXY) {
const proxyAgent = new ProxyAgent(PROXY);
opts.fetchOptions = {
dispatcher: proxyAgent,
};
opts.httpAgent = new HttpsProxyAgent(PROXY);
}
if (OPENAI_ORGANIZATION) {

View File

@@ -1,5 +1,5 @@
// const OpenAI = require('openai');
const { ProxyAgent } = require('undici');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { ErrorTypes } = require('librechat-data-provider');
const { getUserKey, getUserKeyExpiry, getUserKeyValues } = require('~/server/services/UserService');
const initializeClient = require('./initalize');
@@ -107,7 +107,6 @@ describe('initializeClient', () => {
const res = {};
const { openai } = await initializeClient({ req, res });
expect(openai.fetchOptions).toBeDefined();
expect(openai.fetchOptions.dispatcher).toBeInstanceOf(ProxyAgent);
expect(openai.httpAgent).toBeInstanceOf(HttpsProxyAgent);
});
});

View File

@@ -1,5 +1,5 @@
const OpenAI = require('openai');
const { ProxyAgent } = require('undici');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { constructAzureURL, isUserProvided, resolveHeaders } = require('@librechat/api');
const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider');
const {
@@ -158,10 +158,7 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie
}
if (PROXY) {
const proxyAgent = new ProxyAgent(PROXY);
opts.fetchOptions = {
dispatcher: proxyAgent,
};
opts.httpAgent = new HttpsProxyAgent(PROXY);
}
if (OPENAI_ORGANIZATION) {

View File

@@ -1,5 +1,5 @@
// const OpenAI = require('openai');
const { ProxyAgent } = require('undici');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { ErrorTypes } = require('librechat-data-provider');
const { getUserKey, getUserKeyExpiry, getUserKeyValues } = require('~/server/services/UserService');
const initializeClient = require('./initialize');
@@ -107,7 +107,6 @@ describe('initializeClient', () => {
const res = {};
const { openai } = await initializeClient({ req, res });
expect(openai.fetchOptions).toBeDefined();
expect(openai.fetchOptions.dispatcher).toBeInstanceOf(ProxyAgent);
expect(openai.httpAgent).toBeInstanceOf(HttpsProxyAgent);
});
});

View File

@@ -139,9 +139,6 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
);
clientOptions.modelOptions.user = req.user.id;
const options = getOpenAIConfig(apiKey, clientOptions, endpoint);
if (options != null) {
options.useLegacyContent = true;
}
if (!customOptions.streamRate) {
return options;
}
@@ -159,7 +156,6 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
}
return {
useLegacyContent: true,
llmConfig: modelOptions,
};
}

View File

@@ -25,7 +25,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
/** Only attempt to load service key if GOOGLE_KEY is not provided */
try {
const serviceKeyPath =
process.env.GOOGLE_SERVICE_KEY_FILE ||
process.env.GOOGLE_SERVICE_KEY_FILE_PATH ||
path.join(__dirname, '../../../..', 'data', 'auth.json');
serviceKey = await loadServiceKey(serviceKeyPath);
if (!serviceKey) {

View File

@@ -34,13 +34,13 @@ const providerConfigMap = {
* @param {string} provider - The provider string
* @returns {Promise<{
* getOptions: Function,
* overrideProvider: string,
* overrideProvider?: string,
* customEndpointConfig?: TEndpoint
* }>}
*/
async function getProviderConfig(provider) {
let getOptions = providerConfigMap[provider];
let overrideProvider = provider;
let overrideProvider;
/** @type {TEndpoint | undefined} */
let customEndpointConfig;
@@ -56,7 +56,7 @@ async function getProviderConfig(provider) {
overrideProvider = Providers.OPENAI;
}
if (isKnownCustomProvider(overrideProvider) && !customEndpointConfig) {
if (isKnownCustomProvider(overrideProvider || provider) && !customEndpointConfig) {
customEndpointConfig = await getCustomEndpointConfig(provider);
if (!customEndpointConfig) {
throw new Error(`Provider ${provider} not supported`);

View File

@@ -65,20 +65,19 @@ const initializeClient = async ({
const isAzureOpenAI = endpoint === EModelEndpoint.azureOpenAI;
/** @type {false | TAzureConfig} */
const azureConfig = isAzureOpenAI && req.app.locals[EModelEndpoint.azureOpenAI];
let serverless = false;
if (isAzureOpenAI && azureConfig) {
const { modelGroupMap, groupMap } = azureConfig;
const {
azureOptions,
baseURL,
headers = {},
serverless: _serverless,
serverless,
} = mapModelToAzureConfig({
modelName,
modelGroupMap,
groupMap,
});
serverless = _serverless;
clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
clientOptions.headers = resolveHeaders(
@@ -144,9 +143,6 @@ const initializeClient = async ({
clientOptions = Object.assign({ modelOptions }, clientOptions);
clientOptions.modelOptions.user = req.user.id;
const options = getOpenAIConfig(apiKey, clientOptions);
if (options != null && serverless === true) {
options.useLegacyContent = true;
}
const streamRate = clientOptions.streamRate;
if (!streamRate) {
return options;

View File

@@ -152,7 +152,6 @@ async function getSessionInfo(fileIdentifier, apiKey) {
* @param {Object} options
* @param {ServerRequest} options.req
* @param {Agent['tool_resources']} options.tool_resources
* @param {string} [options.agentId] - The agent ID for file access control
* @param {string} apiKey
* @returns {Promise<{
* files: Array<{ id: string; session_id: string; name: string }>,
@@ -160,18 +159,11 @@ async function getSessionInfo(fileIdentifier, apiKey) {
* }>}
*/
const primeFiles = async (options, apiKey) => {
const { tool_resources, req, agentId } = options;
const { tool_resources } = options;
const file_ids = tool_resources?.[EToolResources.execute_code]?.file_ids ?? [];
const agentResourceIds = new Set(file_ids);
const resourceFiles = tool_resources?.[EToolResources.execute_code]?.files ?? [];
const dbFiles = (
(await getFiles(
{ file_id: { $in: file_ids } },
null,
{ text: 0 },
{ userId: req?.user?.id, agentId },
)) ?? []
).concat(resourceFiles);
const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
const files = [];
const sessions = new Map();

View File

@@ -2,16 +2,16 @@ const { z } = require('zod');
const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Time, CacheKeys, StepTypes } = require('librechat-data-provider');
const { sendEvent, normalizeServerName, MCPOAuthHandler } = require('@librechat/api');
const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents');
const { Constants, ContentTypes, isAssistantsEndpoint } = require('librechat-data-provider');
const {
sendEvent,
MCPOAuthHandler,
normalizeServerName,
convertWithResolvedRefs,
} = require('@librechat/api');
const { findToken, createToken, updateToken } = require('~/models');
Constants,
ContentTypes,
isAssistantsEndpoint,
convertJsonSchemaToZod,
} = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { findToken, createToken, updateToken } = require('~/models');
const { getCachedTools } = require('./Config');
const { getLogStores } = require('~/cache');
@@ -104,7 +104,7 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createMCPTool({ req, res, toolKey, provider: _provider }) {
const availableTools = await getCachedTools({ userId: req.user?.id, includeGlobal: true });
const availableTools = await getCachedTools({ includeGlobal: true });
const toolDefinition = availableTools?.[toolKey]?.function;
if (!toolDefinition) {
logger.error(`Tool ${toolKey} not found in available tools`);
@@ -113,7 +113,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
/** @type {LCTool} */
const { description, parameters } = toolDefinition;
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
let schema = convertWithResolvedRefs(parameters, {
let schema = convertJsonSchemaToZod(parameters, {
allowEmptyObject: !isGoogle,
transformOneOfAnyOf: true,
});

View File

@@ -41,38 +41,6 @@ const getUserPluginAuthValue = async (userId, authField, throwError = true) => {
}
};
/**
* Asynchronously retrieves and decrypts the authentication value for a user's specific plugin, based on a specified authentication field and plugin key.
*
* @param {string} userId - The unique identifier of the user for whom the plugin authentication value is to be retrieved.
* @param {string} authField - The specific authentication field (e.g., 'API_KEY', 'URL') whose value is to be retrieved and decrypted.
* @param {string} pluginKey - The plugin key to filter by (e.g., 'mcp_github-mcp').
* @param {boolean} throwError - Whether to throw an error if the authentication value does not exist. Defaults to `true`.
* @returns {Promise<string|null>} A promise that resolves to the decrypted authentication value if found, or `null` if no such authentication value exists for the given user, field, and plugin.
*
* @throws {Error} Throws an error if there's an issue during the retrieval or decryption process, or if the authentication value does not exist.
* @async
*/
const getUserPluginAuthValueByPlugin = async (userId, authField, pluginKey, throwError = true) => {
try {
const pluginAuth = await findOnePluginAuth({ userId, authField, pluginKey });
if (!pluginAuth) {
throw new Error(
`No plugin auth ${authField} found for user ${userId} and plugin ${pluginKey}`,
);
}
const decryptedValue = await decrypt(pluginAuth.value);
return decryptedValue;
} catch (err) {
if (!throwError) {
return null;
}
logger.error('[getUserPluginAuthValueByPlugin]', err);
throw err;
}
};
// const updateUserPluginAuth = async (userId, authField, pluginKey, value) => {
// try {
// const encryptedValue = encrypt(value);
@@ -151,7 +119,6 @@ const deleteUserPluginAuth = async (userId, authField, all = false, pluginKey) =
module.exports = {
getUserPluginAuthValue,
getUserPluginAuthValueByPlugin,
updateUserPluginAuth,
deleteUserPluginAuth,
};

View File

@@ -226,7 +226,7 @@ async function processRequiredActions(client, requiredActions) {
`[required actions] user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`,
requiredActions,
);
const toolDefinitions = await getCachedTools({ userId: client.req.user.id, includeGlobal: true });
const toolDefinitions = await getCachedTools({ includeGlobal: true });
const seenToolkits = new Set();
const tools = requiredActions
.map((action) => {

View File

@@ -9,44 +9,20 @@ const { getLogStores } = require('~/cache');
* Initialize MCP servers
* @param {import('express').Application} app - Express app instance
*/
async function initializeMCPs(app) {
// TEMPORARY: Reset all OAuth tokens for fresh testing
try {
logger.info('[MCP] Resetting all OAuth tokens for fresh testing...');
await deleteTokens({});
logger.info('[MCP] All OAuth tokens reset successfully');
} catch (error) {
logger.error('[MCP] Error resetting OAuth tokens:', error);
}
async function initializeMCP(app) {
const mcpServers = app.locals.mcpConfig;
if (!mcpServers) {
return;
}
// Filter out servers with startup: false
const filteredServers = {};
for (const [name, config] of Object.entries(mcpServers)) {
if (config.startup === false) {
logger.info(`Skipping MCP server '${name}' due to startup: false`);
continue;
}
filteredServers[name] = config;
}
if (Object.keys(filteredServers).length === 0) {
logger.info('[MCP] No MCP servers to initialize (all skipped or none configured)');
return;
}
logger.info('Initializing MCP servers...');
const mcpManager = getMCPManager();
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
try {
const oauthRequirements = await mcpManager.initializeMCPs({
mcpServers: filteredServers,
await mcpManager.initializeMCP({
mcpServers,
flowManager,
tokenMethods: {
findToken,
@@ -68,17 +44,10 @@ async function initializeMCPs(app) {
await mcpManager.mapAvailableTools(toolsCopy, flowManager);
await setCachedTools(toolsCopy, { isGlobal: true });
const cache = getLogStores(CacheKeys.CONFIG_STORE);
await cache.delete(CacheKeys.TOOLS);
logger.debug('Cleared tools array cache after MCP initialization');
logger.info('MCP servers initialized successfully');
// Store OAuth requirement information in app locals for client access
app.locals.mcpOAuthRequirements = oauthRequirements;
} catch (error) {
logger.error('Failed to initialize MCP servers:', error);
}
}
module.exports = initializeMCPs;
module.exports = initializeMCP;

View File

@@ -52,11 +52,6 @@ function assistantsConfigSetup(config, assistantsEndpoint, prevConfig = {}) {
privateAssistants: parsedConfig.privateAssistants,
timeoutMs: parsedConfig.timeoutMs,
streamRate: parsedConfig.streamRate,
titlePrompt: parsedConfig.titlePrompt,
titleMethod: parsedConfig.titleMethod,
titleModel: parsedConfig.titleModel,
titleEndpoint: parsedConfig.titleEndpoint,
titlePromptTemplate: parsedConfig.titlePromptTemplate,
};
}

View File

@@ -1,6 +1,6 @@
const { webSearchKeys } = require('@librechat/api');
const {
Constants,
webSearchKeys,
deprecatedAzureVariables,
conflictingAzureVariables,
extractVariableName,

View File

@@ -1,5 +1,8 @@
const { Keyv } = require('keyv');
const passport = require('passport');
const session = require('express-session');
const MemoryStore = require('memorystore')(session);
const RedisStore = require('connect-redis').default;
const {
setupOpenId,
googleLogin,
@@ -11,9 +14,8 @@ const {
openIdJwtLogin,
} = require('~/strategies');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const { getLogStores } = require('~/cache');
const { CacheKeys } = require('librechat-data-provider');
/**
*
@@ -49,8 +51,17 @@ const configureSocialLogins = async (app) => {
secret: process.env.OPENID_SESSION_SECRET,
resave: false,
saveUninitialized: false,
store: getLogStores(CacheKeys.OPENID_SESSION),
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for session storage in OpenID...');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.client;
sessionOptions.store = new RedisStore({ client, prefix: 'openid_session' });
} else {
sessionOptions.store = new MemoryStore({
checkPeriod: 86400000, // prune expired entries every 24h
});
}
app.use(session(sessionOptions));
app.use(passport.session());
const config = await setupOpenId();
@@ -71,8 +82,17 @@ const configureSocialLogins = async (app) => {
secret: process.env.SAML_SESSION_SECRET,
resave: false,
saveUninitialized: false,
store: getLogStores(CacheKeys.SAML_SESSION),
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for session storage in SAML...');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.client;
sessionOptions.store = new RedisStore({ client, prefix: 'saml_session' });
} else {
sessionOptions.store = new MemoryStore({
checkPeriod: 86400000, // prune expired entries every 24h
});
}
app.use(session(sessionOptions));
app.use(passport.session());
setupSaml();

View File

@@ -1,407 +0,0 @@
const fs = require('fs');
const path = require('path');
const express = require('express');
const request = require('supertest');
const zlib = require('zlib');
const staticCache = require('../staticCache');
describe('staticCache', () => {
let app;
let testDir;
let testFile;
let indexFile;
let manifestFile;
let swFile;
beforeAll(() => {
// Create a test directory and files
testDir = path.join(__dirname, 'test-static');
if (!fs.existsSync(testDir)) {
fs.mkdirSync(testDir, { recursive: true });
}
// Create test files
testFile = path.join(testDir, 'test.js');
indexFile = path.join(testDir, 'index.html');
manifestFile = path.join(testDir, 'manifest.json');
swFile = path.join(testDir, 'sw.js');
const jsContent = 'console.log("test");';
const htmlContent = '<html><body>Test</body></html>';
const jsonContent = '{"name": "test"}';
const swContent = 'self.addEventListener("install", () => {});';
fs.writeFileSync(testFile, jsContent);
fs.writeFileSync(indexFile, htmlContent);
fs.writeFileSync(manifestFile, jsonContent);
fs.writeFileSync(swFile, swContent);
// Create gzipped versions of some files
fs.writeFileSync(testFile + '.gz', zlib.gzipSync(jsContent));
fs.writeFileSync(path.join(testDir, 'test.css'), 'body { color: red; }');
fs.writeFileSync(path.join(testDir, 'test.css.gz'), zlib.gzipSync('body { color: red; }'));
// Create a file that only exists in gzipped form
fs.writeFileSync(
path.join(testDir, 'only-gzipped.js.gz'),
zlib.gzipSync('console.log("only gzipped");'),
);
// Create a subdirectory for dist/images testing
const distImagesDir = path.join(testDir, 'dist', 'images');
fs.mkdirSync(distImagesDir, { recursive: true });
fs.writeFileSync(path.join(distImagesDir, 'logo.png'), 'fake-png-data');
});
afterAll(() => {
// Clean up test files
if (fs.existsSync(testDir)) {
fs.rmSync(testDir, { recursive: true, force: true });
}
});
beforeEach(() => {
app = express();
// Clear environment variables
delete process.env.NODE_ENV;
delete process.env.STATIC_CACHE_S_MAX_AGE;
delete process.env.STATIC_CACHE_MAX_AGE;
});
describe('cache headers in production', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should set standard cache headers for regular files', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
});
it('should set no-cache headers for index.html', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/index.html').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
});
it('should set no-cache headers for manifest.json', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/manifest.json').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
});
it('should set no-cache headers for sw.js', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/sw.js').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
});
it('should not set cache headers for /dist/images/ files', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/dist/images/logo.png').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=0');
});
it('should set no-cache headers when noCache option is true', async () => {
app.use(staticCache(testDir, { noCache: true }));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
});
});
describe('cache headers in non-production', () => {
beforeEach(() => {
process.env.NODE_ENV = 'development';
});
it('should not set cache headers in development', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
// Our middleware should not set cache-control in non-production
// Express static might set its own default headers
const cacheControl = response.headers['cache-control'];
expect(cacheControl).toBe('public, max-age=0');
});
});
describe('environment variable configuration', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should use custom s-maxage from environment', async () => {
process.env.STATIC_CACHE_S_MAX_AGE = '3600';
// Need to re-require to pick up new env vars
jest.resetModules();
const freshStaticCache = require('../staticCache');
app.use(freshStaticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=3600');
});
it('should use custom max-age from environment', async () => {
process.env.STATIC_CACHE_MAX_AGE = '7200';
// Need to re-require to pick up new env vars
jest.resetModules();
const freshStaticCache = require('../staticCache');
app.use(freshStaticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=7200, s-maxage=86400');
});
it('should use both custom values from environment', async () => {
process.env.STATIC_CACHE_S_MAX_AGE = '1800';
process.env.STATIC_CACHE_MAX_AGE = '3600';
// Need to re-require to pick up new env vars
jest.resetModules();
const freshStaticCache = require('../staticCache');
app.use(freshStaticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=3600, s-maxage=1800');
});
});
describe('express-static-gzip behavior', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should serve gzipped files when client accepts gzip encoding', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip, deflate')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['content-type']).toMatch(/javascript/);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
// Content should be decompressed by supertest
expect(response.text).toBe('console.log("test");');
});
it('should fall back to uncompressed files when client does not accept gzip', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'identity')
.expect(200);
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.headers['content-type']).toMatch(/javascript/);
expect(response.text).toBe('console.log("test");');
});
it('should serve gzipped CSS files with correct content-type', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/test.css')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['content-type']).toMatch(/css/);
expect(response.text).toBe('body { color: red; }');
});
it('should serve uncompressed files when no gzipped version exists', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/manifest.json')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.headers['content-type']).toMatch(/json/);
expect(response.text).toBe('{"name": "test"}');
});
it('should handle files that only exist in gzipped form', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/only-gzipped.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['content-type']).toMatch(/javascript/);
expect(response.text).toBe('console.log("only gzipped");');
});
it('should return 404 for gzip-only files when client does not accept gzip', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/only-gzipped.js')
.set('Accept-Encoding', 'identity');
expect(response.status).toBe(404);
});
it('should handle cache headers correctly for gzipped content', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
expect(response.headers['content-type']).toMatch(/javascript/);
});
it('should preserve original MIME types for gzipped files', async () => {
app.use(staticCache(testDir, { skipGzipScan: false }));
const jsResponse = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
const cssResponse = await request(app)
.get('/test.css')
.set('Accept-Encoding', 'gzip')
.expect(200);
expect(jsResponse.headers['content-type']).toMatch(/javascript/);
expect(cssResponse.headers['content-type']).toMatch(/css/);
expect(jsResponse.headers['content-encoding']).toBe('gzip');
expect(cssResponse.headers['content-encoding']).toBe('gzip');
});
});
describe('skipGzipScan option comparison', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should use express.static (no gzip) when skipGzipScan is true', async () => {
app.use(staticCache(testDir, { skipGzipScan: true }));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
// Should NOT serve gzipped version even though client accepts it
expect(response.headers['content-encoding']).toBeUndefined();
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
expect(response.text).toBe('console.log("test");');
});
it('should use expressStaticGzip when skipGzipScan is false', async () => {
app.use(staticCache(testDir));
const response = await request(app)
.get('/test.js')
.set('Accept-Encoding', 'gzip')
.expect(200);
// Should serve gzipped version when client accepts it
expect(response.headers['content-encoding']).toBe('gzip');
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
expect(response.text).toBe('console.log("test");');
});
});
describe('file serving', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should serve files correctly', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/test.js').expect(200);
expect(response.text).toBe('console.log("test");');
expect(response.headers['content-type']).toMatch(/javascript|text/);
});
it('should return 404 for non-existent files', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/nonexistent.js');
expect(response.status).toBe(404);
});
it('should serve HTML files', async () => {
app.use(staticCache(testDir));
const response = await request(app).get('/index.html').expect(200);
expect(response.text).toBe('<html><body>Test</body></html>');
expect(response.headers['content-type']).toMatch(/html/);
});
});
describe('edge cases', () => {
beforeEach(() => {
process.env.NODE_ENV = 'production';
});
it('should handle webmanifest files', async () => {
// Create a webmanifest file
const webmanifestFile = path.join(testDir, 'site.webmanifest');
fs.writeFileSync(webmanifestFile, '{"name": "test app"}');
app.use(staticCache(testDir));
const response = await request(app).get('/site.webmanifest').expect(200);
expect(response.headers['cache-control']).toBe('no-store, no-cache, must-revalidate');
// Clean up
fs.unlinkSync(webmanifestFile);
});
it('should handle files in subdirectories', async () => {
const subDir = path.join(testDir, 'subdir');
fs.mkdirSync(subDir, { recursive: true });
const subFile = path.join(subDir, 'nested.js');
fs.writeFileSync(subFile, 'console.log("nested");');
app.use(staticCache(testDir));
const response = await request(app).get('/subdir/nested.js').expect(200);
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
expect(response.text).toBe('console.log("nested");');
// Clean up
fs.rmSync(subDir, { recursive: true, force: true });
});
});
});

View File

@@ -1,5 +1,4 @@
const path = require('path');
const express = require('express');
const expressStaticGzip = require('express-static-gzip');
const oneDayInSeconds = 24 * 60 * 60;
@@ -8,55 +7,44 @@ const sMaxAge = process.env.STATIC_CACHE_S_MAX_AGE || oneDayInSeconds;
const maxAge = process.env.STATIC_CACHE_MAX_AGE || oneDayInSeconds * 2;
/**
* Creates an Express static middleware with optional gzip compression and configurable caching
* Creates an Express static middleware with gzip compression and configurable caching
*
* @param {string} staticPath - The file system path to serve static files from
* @param {Object} [options={}] - Configuration options
* @param {boolean} [options.noCache=false] - If true, disables caching entirely for all files
* @param {boolean} [options.skipGzipScan=false] - If true, skips expressStaticGzip middleware
* @returns {ReturnType<expressStaticGzip>|ReturnType<express.static>} Express middleware function for serving static files
* @returns {ReturnType<expressStaticGzip>} Express middleware function for serving static files
*/
function staticCache(staticPath, options = {}) {
const { noCache = false, skipGzipScan = false } = options;
const { noCache = false } = options;
return expressStaticGzip(staticPath, {
enableBrotli: false,
orderPreference: ['gz'],
setHeaders: (res, filePath) => {
if (process.env.NODE_ENV?.toLowerCase() !== 'production') {
return;
}
if (noCache) {
res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate');
return;
}
if (filePath.includes('/dist/images/')) {
return;
}
const fileName = path.basename(filePath);
const setHeaders = (res, filePath) => {
if (process.env.NODE_ENV?.toLowerCase() !== 'production') {
return;
}
if (noCache) {
res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate');
return;
}
if (filePath && filePath.includes('/dist/images/')) {
return;
}
const fileName = filePath ? path.basename(filePath) : '';
if (
fileName === 'index.html' ||
fileName.endsWith('.webmanifest') ||
fileName === 'manifest.json' ||
fileName === 'sw.js'
) {
res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate');
} else {
res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`);
}
};
if (skipGzipScan) {
return express.static(staticPath, {
setHeaders,
index: false,
});
} else {
return expressStaticGzip(staticPath, {
enableBrotli: false,
orderPreference: ['gz'],
setHeaders,
index: false,
});
}
if (
fileName === 'index.html' ||
fileName.endsWith('.webmanifest') ||
fileName === 'manifest.json' ||
fileName === 'sw.js'
) {
res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate');
} else {
res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`);
}
},
index: false,
});
}
module.exports = staticCache;

View File

@@ -1074,7 +1074,7 @@
/**
* @exports JsonSchemaType
* @typedef {import('@librechat/api').JsonSchemaType} JsonSchemaType
* @typedef {import('librechat-data-provider').JsonSchemaType} JsonSchemaType
* @memberof typedefs
*/

View File

@@ -223,7 +223,6 @@ const xAIModels = {
'grok-3-fast': 131072,
'grok-3-mini': 131072,
'grok-3-mini-fast': 131072,
'grok-4': 256000, // 256K context
};
const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels, ...xAIModels };

View File

@@ -386,7 +386,7 @@ describe('matchModelName', () => {
});
it('should return the closest matching key for gpt-4-1106 partial matches', () => {
expect(matchModelName('gpt-4-1106/something')).toBe('gpt-4-1106');
expect(matchModelName('something/gpt-4-1106')).toBe('gpt-4-1106');
expect(matchModelName('gpt-4-1106-preview')).toBe('gpt-4-1106');
expect(matchModelName('gpt-4-1106-vision-preview')).toBe('gpt-4-1106');
});
@@ -589,10 +589,6 @@ describe('Grok Model Tests - Tokens', () => {
expect(getModelMaxTokens('grok-3-mini-fast')).toBe(131072);
});
test('should return correct tokens for Grok 4 model', () => {
expect(getModelMaxTokens('grok-4-0709')).toBe(256000);
});
test('should handle partial matches for Grok models with prefixes', () => {
// Vision models should match before general models
expect(getModelMaxTokens('xai/grok-2-vision-1212')).toBe(32768);
@@ -610,8 +606,6 @@ describe('Grok Model Tests - Tokens', () => {
expect(getModelMaxTokens('xai/grok-3-fast')).toBe(131072);
expect(getModelMaxTokens('xai/grok-3-mini')).toBe(131072);
expect(getModelMaxTokens('xai/grok-3-mini-fast')).toBe(131072);
// Grok 4 model
expect(getModelMaxTokens('xai/grok-4-0709')).toBe(256000);
});
});
@@ -633,8 +627,6 @@ describe('Grok Model Tests - Tokens', () => {
expect(matchModelName('grok-3-fast')).toBe('grok-3-fast');
expect(matchModelName('grok-3-mini')).toBe('grok-3-mini');
expect(matchModelName('grok-3-mini-fast')).toBe('grok-3-mini-fast');
// Grok 4 model
expect(matchModelName('grok-4-0709')).toBe('grok-4');
});
test('should match Grok model variations with prefixes', () => {
@@ -654,8 +646,6 @@ describe('Grok Model Tests - Tokens', () => {
expect(matchModelName('xai/grok-3-fast')).toBe('grok-3-fast');
expect(matchModelName('xai/grok-3-mini')).toBe('grok-3-mini');
expect(matchModelName('xai/grok-3-mini-fast')).toBe('grok-3-mini-fast');
// Grok 4 model
expect(matchModelName('xai/grok-4-0709')).toBe('grok-4');
});
});
});

View File

@@ -1,31 +0,0 @@
import React, { createContext, useContext, useMemo } from 'react';
import type { EModelEndpoint } from 'librechat-data-provider';
import { useChatContext } from './ChatContext';
interface SidePanelContextValue {
endpoint?: EModelEndpoint | null;
}
const SidePanelContext = createContext<SidePanelContextValue | undefined>(undefined);
export function SidePanelProvider({ children }: { children: React.ReactNode }) {
const { conversation } = useChatContext();
/** Context value only created when endpoint changes */
const contextValue = useMemo<SidePanelContextValue>(
() => ({
endpoint: conversation?.endpoint,
}),
[conversation?.endpoint],
);
return <SidePanelContext.Provider value={contextValue}>{children}</SidePanelContext.Provider>;
}
export function useSidePanelContext() {
const context = useContext(SidePanelContext);
if (!context) {
throw new Error('useSidePanelContext must be used within SidePanelProvider');
}
return context;
}

View File

@@ -24,5 +24,4 @@ export * from './ToolCallsMapContext';
export * from './SetConvoContext';
export * from './SearchContext';
export * from './BadgeRowContext';
export * from './SidePanelContext';
export { default as BadgeRowProvider } from './BadgeRowContext';

View File

@@ -345,7 +345,9 @@ export type TOptions = {
isContinued?: boolean;
isEdited?: boolean;
overrideMessages?: t.TMessage[];
/** Currently only utilized when resubmitting user-created message, uses that message's currently attached files */
/** This value is only true when the user submits a message with "Save & Submit" for a user-created message */
isResubmission?: boolean;
/** Currently only utilized when `isResubmission === true`, uses that message's currently attached files */
overrideFiles?: t.TMessage['files'];
};

View File

@@ -1,8 +1,10 @@
import React, { useMemo } from 'react';
import { EToolResources, defaultAgentCapabilities } from 'librechat-data-provider';
import { EModelEndpoint, EToolResources } from 'librechat-data-provider';
import { FileSearch, ImageUpIcon, FileType2Icon, TerminalSquareIcon } from 'lucide-react';
import { useLocalize, useGetAgentsConfig, useAgentCapabilities } from '~/hooks';
import { OGDialog, OGDialogTemplate } from '~/components/ui';
import OGDialogTemplate from '~/components/ui/OGDialogTemplate';
import { useGetEndpointsQuery } from '~/data-provider';
import useLocalize from '~/hooks/useLocalize';
import { OGDialog } from '~/components/ui';
interface DragDropModalProps {
onOptionSelect: (option: EToolResources | undefined) => void;
@@ -20,12 +22,12 @@ interface FileOption {
const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragDropModalProps) => {
const localize = useLocalize();
const { agentsConfig } = useGetAgentsConfig();
/** TODO: Ephemeral Agent Capabilities
* Allow defining agent capabilities on a per-endpoint basis
* Use definition for agents endpoint for ephemeral agents
* */
const capabilities = useAgentCapabilities(agentsConfig?.capabilities ?? defaultAgentCapabilities);
const { data: endpointsConfig } = useGetEndpointsQuery();
const capabilities = useMemo(
() => endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [],
[endpointsConfig],
);
const options = useMemo(() => {
const _options: FileOption[] = [
{
@@ -35,26 +37,26 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
condition: files.every((file) => file.type?.startsWith('image/')),
},
];
if (capabilities.fileSearchEnabled) {
_options.push({
label: localize('com_ui_upload_file_search'),
value: EToolResources.file_search,
icon: <FileSearch className="icon-md" />,
});
}
if (capabilities.codeEnabled) {
_options.push({
label: localize('com_ui_upload_code_files'),
value: EToolResources.execute_code,
icon: <TerminalSquareIcon className="icon-md" />,
});
}
if (capabilities.ocrEnabled) {
_options.push({
label: localize('com_ui_upload_ocr_text'),
value: EToolResources.ocr,
icon: <FileType2Icon className="icon-md" />,
});
for (const capability of capabilities) {
if (capability === EToolResources.file_search) {
_options.push({
label: localize('com_ui_upload_file_search'),
value: EToolResources.file_search,
icon: <FileSearch className="icon-md" />,
});
} else if (capability === EToolResources.execute_code) {
_options.push({
label: localize('com_ui_upload_code_files'),
value: EToolResources.execute_code,
icon: <TerminalSquareIcon className="icon-md" />,
});
} else if (capability === EToolResources.ocr) {
_options.push({
label: localize('com_ui_upload_ocr_text'),
value: EToolResources.ocr,
icon: <FileType2Icon className="icon-md" />,
});
}
}
return _options;

View File

@@ -2,8 +2,6 @@ import { useEffect } from 'react';
import { EToolResources } from 'librechat-data-provider';
import type { ExtendedFile } from '~/common';
import { useDeleteFilesMutation } from '~/data-provider';
import { useToastContext } from '~/Providers';
import { useLocalize } from '~/hooks';
import { useFileDeletion } from '~/hooks/Files';
import FileContainer from './FileContainer';
import { logger } from '~/utils';
@@ -32,8 +30,6 @@ export default function FileRow({
isRTL?: boolean;
Wrapper?: React.FC<{ children: React.ReactNode }>;
}) {
const localize = useLocalize();
const { showToast } = useToastContext();
const files = Array.from(_files?.values() ?? []).filter((file) =>
fileFilter ? fileFilter(file) : true,
);
@@ -109,10 +105,6 @@ export default function FileRow({
)
.uniqueFiles.map((file: ExtendedFile, index: number) => {
const handleDelete = () => {
showToast({
message: localize('com_ui_deleting_file'),
status: 'info',
});
if (abortUpload && file.progress < 1) {
abortUpload();
}

View File

@@ -1,12 +1,9 @@
import React, { memo, useCallback, useState, useMemo } from 'react';
import { SettingsIcon, PlugZap } from 'lucide-react';
import React, { memo, useCallback, useState } from 'react';
import { SettingsIcon } from 'lucide-react';
import { Constants } from 'librechat-data-provider';
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
import { useMCPConnectionStatusQuery } from '~/data-provider';
import { useQueryClient } from '@tanstack/react-query';
import { QueryKeys } from 'librechat-data-provider';
import type { TUpdateUserPlugins, TPlugin } from 'librechat-data-provider';
import { MCPConfigDialog, type ConfigFieldDetail } from '~/components/ui/MCP';
import MCPConfigDialog, { type ConfigFieldDetail } from '~/components/ui/MCPConfigDialog';
import { useToastContext, useBadgeRowContext } from '~/Providers';
import MultiSelect from '~/components/ui/MultiSelect';
import { MCPIcon } from '~/components/svg';
@@ -21,44 +18,15 @@ function MCPSelect() {
const localize = useLocalize();
const { showToast } = useToastContext();
const { mcpSelect, startupConfig } = useBadgeRowContext();
const { mcpValues, setMCPValues, isPinned } = mcpSelect;
// Get real connection status from MCPManager
const { data: statusQuery } = useMCPConnectionStatusQuery();
const mcpServerStatuses = useMemo(
() => statusQuery?.connectionStatus || {},
[statusQuery?.connectionStatus],
);
console.log('mcpServerStatuses', mcpServerStatuses);
console.log('statusQuery', statusQuery);
const { mcpValues, setMCPValues, mcpServerNames, mcpToolDetails, isPinned } = mcpSelect;
const [isConfigModalOpen, setIsConfigModalOpen] = useState(false);
const [selectedToolForConfig, setSelectedToolForConfig] = useState<TPlugin | null>(null);
const queryClient = useQueryClient();
const updateUserPluginsMutation = useUpdateUserPluginsMutation({
onSuccess: async () => {
onSuccess: () => {
setIsConfigModalOpen(false);
showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' });
// // For 'uninstall' actions (revoke), remove the server from selected values
// if (variables.action === 'uninstall') {
// const serverName = variables.pluginKey.replace(Constants.mcp_prefix, '');
// const currentValues = mcpValues ?? [];
// const filteredValues = currentValues.filter((name) => name !== serverName);
// setMCPValues(filteredValues);
// }
// Wait for all refetches to complete before ending loading state
await Promise.all([
queryClient.invalidateQueries([QueryKeys.tools]),
queryClient.refetchQueries([QueryKeys.tools]),
queryClient.invalidateQueries([QueryKeys.mcpAuthValues]),
queryClient.refetchQueries([QueryKeys.mcpAuthValues]),
queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]),
queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]),
]);
},
onError: (error: unknown) => {
console.error('Error updating MCP auth:', error);
@@ -85,12 +53,10 @@ function MCPSelect() {
const handleConfigSave = useCallback(
(targetName: string, authData: Record<string, string>) => {
if (selectedToolForConfig && selectedToolForConfig.name === targetName) {
// Use the pluginKey directly since it's already in the correct format
console.log(
`[MCP Select] Saving config for ${targetName}, pluginKey: ${`${Constants.mcp_prefix}${targetName}`}`,
);
const basePluginKey = getBaseMCPPluginKey(selectedToolForConfig.pluginKey);
const payload: TUpdateUserPlugins = {
pluginKey: `${Constants.mcp_prefix}${targetName}`,
pluginKey: basePluginKey,
action: 'install',
auth: authData,
};
@@ -103,12 +69,10 @@ function MCPSelect() {
const handleConfigRevoke = useCallback(
(targetName: string) => {
if (selectedToolForConfig && selectedToolForConfig.name === targetName) {
// Use the pluginKey directly since it's already in the correct format
console.log(
`[MCP Select] Revoking config for ${targetName}, pluginKey: ${`${Constants.mcp_prefix}${targetName}`}`,
);
const basePluginKey = getBaseMCPPluginKey(selectedToolForConfig.pluginKey);
const payload: TUpdateUserPlugins = {
pluginKey: `${Constants.mcp_prefix}${targetName}`,
pluginKey: basePluginKey,
action: 'uninstall',
auth: {},
};
@@ -118,149 +82,60 @@ function MCPSelect() {
[selectedToolForConfig, updateUserPluginsMutation],
);
// Create stable callback references to prevent stale closures
const handleSave = useCallback(
(authData: Record<string, string>) => {
if (selectedToolForConfig) {
handleConfigSave(selectedToolForConfig.name, authData);
}
},
[selectedToolForConfig, handleConfigSave],
);
const handleRevoke = useCallback(() => {
if (selectedToolForConfig) {
handleConfigRevoke(selectedToolForConfig.name);
}
}, [selectedToolForConfig, handleConfigRevoke]);
// Only allow connected servers to be selected
const handleSetSelectedValues = useCallback(
(values: string[]) => {
// Filter to only include connected servers
const connectedValues = values.filter((serverName) => {
const serverStatus = mcpServerStatuses?.[serverName];
return serverStatus?.connected || false;
});
setMCPValues(connectedValues);
},
[setMCPValues, mcpServerStatuses],
);
const renderItemContent = useCallback(
(serverName: string, defaultContent: React.ReactNode) => {
const serverStatus = mcpServerStatuses?.[serverName];
const connected = serverStatus?.connected || false;
const hasAuthConfig = serverStatus?.hasAuthConfig || false;
const tool = mcpToolDetails?.find((t) => t.name === serverName);
const hasAuthConfig = tool?.authConfig && tool.authConfig.length > 0;
// Icon logic:
// - connected with auth config = gear (green)
// - connected without auth config = no icon (just text)
// - not connected = zap (orange)
let icon: React.ReactNode = null;
let tooltip = 'Configure server';
// Common wrapper for the main content (check mark + text)
// Ensures Check & Text are adjacent and the group takes available space.
const mainContentWrapper = (
<div className="flex flex-grow items-center">{defaultContent}</div>
);
if (connected) {
if (hasAuthConfig) {
icon = <SettingsIcon className="h-4 w-4 text-green-500" />;
tooltip = 'Configure connected server';
} else {
// No icon for connected servers without auth config
tooltip = 'Connected server (no configuration needed)';
}
} else {
icon = <PlugZap className="h-4 w-4 text-orange-400" />;
tooltip = 'Configure server';
}
const onClick = () => {
const serverConfig = startupConfig?.mcpServers?.[serverName];
if (serverConfig) {
const serverTool = {
name: serverName,
pluginKey: `${Constants.mcp_prefix}${serverName}`,
authConfig: Object.entries(serverConfig.customUserVars || {}).map(([key, config]) => ({
authField: key,
label: config.title,
description: config.description,
requiresOAuth: serverConfig.requiresOAuth || false,
})),
authenticated: connected,
};
setSelectedToolForConfig(serverTool);
setIsConfigModalOpen(true);
}
};
return (
<div className="flex w-full items-center justify-between">
<div className={`flex flex-grow items-center ${!connected ? 'opacity-50' : ''}`}>
{defaultContent}
</div>
{icon && (
if (tool && hasAuthConfig) {
return (
<div className="flex w-full items-center justify-between">
{mainContentWrapper}
<button
type="button"
onClick={(e) => {
e.stopPropagation();
e.preventDefault();
onClick();
setSelectedToolForConfig(tool);
setIsConfigModalOpen(true);
}}
className="ml-2 flex h-6 w-6 items-center justify-center rounded p-1 hover:bg-surface-secondary"
aria-label={tooltip}
title={tooltip}
aria-label={`Configure ${serverName}`}
>
{icon}
<SettingsIcon className={`h-4 w-4 ${tool.authenticated ? 'text-green-500' : ''}`} />
</button>
)}
</div>
);
</div>
);
}
// For items without a settings icon, return the consistently wrapped main content.
return mainContentWrapper;
},
[mcpServerStatuses, setSelectedToolForConfig, setIsConfigModalOpen, startupConfig],
[mcpToolDetails, setSelectedToolForConfig, setIsConfigModalOpen],
);
// Memoize schema and initial values to prevent unnecessary re-renders
const fieldsSchema = useMemo(() => {
const schema: Record<string, ConfigFieldDetail> = {};
if (selectedToolForConfig?.authConfig) {
selectedToolForConfig.authConfig.forEach((field) => {
schema[field.authField] = {
title: field.label,
description: field.description,
};
});
}
return schema;
}, [selectedToolForConfig?.authConfig]);
const initialValues = useMemo(() => {
const initial: Record<string, string> = {};
// Always start with empty values for security - never prefill sensitive data
if (selectedToolForConfig?.authConfig) {
selectedToolForConfig.authConfig.forEach((field) => {
initial[field.authField] = '';
});
}
return initial;
}, [selectedToolForConfig?.authConfig]);
// Don't render if no MCP servers are available at all
if (!mcpServerStatuses || Object.keys(mcpServerStatuses).length === 0) {
return null;
}
// Don't render if no servers are selected and not pinned
if ((!mcpValues || mcpValues.length === 0) && !isPinned) {
return null;
}
if (!mcpToolDetails || mcpToolDetails.length === 0) {
return null;
}
const placeholderText =
startupConfig?.interface?.mcpServers?.placeholder || localize('com_ui_mcp_servers');
return (
<>
<MultiSelect
items={Object.keys(mcpServerStatuses) || []}
items={mcpServerNames}
selectedValues={mcpValues ?? []}
setSelectedValues={handleSetSelectedValues}
setSelectedValues={setMCPValues}
defaultSelectedValues={mcpValues ?? []}
renderSelectedValues={renderSelectedValues}
renderItemContent={renderItemContent}
@@ -276,13 +151,39 @@ function MCPSelect() {
isOpen={isConfigModalOpen}
onOpenChange={setIsConfigModalOpen}
serverName={selectedToolForConfig.name}
fieldsSchema={fieldsSchema}
initialValues={initialValues}
onSave={handleSave}
onRevoke={handleRevoke}
fieldsSchema={(() => {
const schema: Record<string, ConfigFieldDetail> = {};
if (selectedToolForConfig?.authConfig) {
selectedToolForConfig.authConfig.forEach((field) => {
schema[field.authField] = {
title: field.label,
description: field.description,
};
});
}
return schema;
})()}
initialValues={(() => {
const initial: Record<string, string> = {};
// Note: Actual initial values might need to be fetched if they are stored user-specifically
if (selectedToolForConfig?.authConfig) {
selectedToolForConfig.authConfig.forEach((field) => {
initial[field.authField] = ''; // Or fetched value
});
}
return initial;
})()}
onSave={(authData) => {
if (selectedToolForConfig) {
handleConfigSave(selectedToolForConfig.name, authData);
}
}}
onRevoke={() => {
if (selectedToolForConfig) {
handleConfigRevoke(selectedToolForConfig.name);
}
}}
isSubmitting={updateUserPluginsMutation.isLoading}
isConnected={mcpServerStatuses?.[selectedToolForConfig.name]?.connected || false}
authConfig={selectedToolForConfig.authConfig}
/>
)}
</>

View File

@@ -60,6 +60,7 @@ const EditMessage = ({
conversationId,
},
{
isResubmission: true,
overrideFiles: message.files,
},
);

View File

@@ -1,4 +1,4 @@
import React, { memo, useMemo } from 'react';
import React, { memo, useMemo, useRef, useEffect } from 'react';
import remarkGfm from 'remark-gfm';
import remarkMath from 'remark-math';
import supersub from 'remark-supersub';
@@ -7,16 +7,167 @@ import { useRecoilValue } from 'recoil';
import ReactMarkdown from 'react-markdown';
import rehypeHighlight from 'rehype-highlight';
import remarkDirective from 'remark-directive';
import { PermissionTypes, Permissions } from 'librechat-data-provider';
import type { Pluggable } from 'unified';
import {
useToastContext,
ArtifactProvider,
CodeBlockProvider,
useCodeBlockContext,
} from '~/Providers';
import { Citation, CompositeCitation, HighlightedText } from '~/components/Web/Citation';
import { Artifact, artifactPlugin } from '~/components/Artifacts/Artifact';
import { ArtifactProvider, CodeBlockProvider } from '~/Providers';
import MarkdownErrorBoundary from './MarkdownErrorBoundary';
import { langSubset, preprocessLaTeX } from '~/utils';
import { langSubset, preprocessLaTeX, handleDoubleClick } from '~/utils';
import CodeBlock from '~/components/Messages/Content/CodeBlock';
import useHasAccess from '~/hooks/Roles/useHasAccess';
import { unicodeCitation } from '~/components/Web';
import { code, a, p } from './MarkdownComponents';
import { useFileDownload } from '~/data-provider';
import useLocalize from '~/hooks/useLocalize';
import store from '~/store';
type TCodeProps = {
inline?: boolean;
className?: string;
children: React.ReactNode;
};
export const code: React.ElementType = memo(({ className, children }: TCodeProps) => {
const canRunCode = useHasAccess({
permissionType: PermissionTypes.RUN_CODE,
permission: Permissions.USE,
});
const match = /language-(\w+)/.exec(className ?? '');
const lang = match && match[1];
const isMath = lang === 'math';
const isSingleLine = typeof children === 'string' && children.split('\n').length === 1;
const { getNextIndex, resetCounter } = useCodeBlockContext();
const blockIndex = useRef(getNextIndex(isMath || isSingleLine)).current;
useEffect(() => {
resetCounter();
}, [children, resetCounter]);
if (isMath) {
return <>{children}</>;
} else if (isSingleLine) {
return (
<code onDoubleClick={handleDoubleClick} className={className}>
{children}
</code>
);
} else {
return (
<CodeBlock
lang={lang ?? 'text'}
codeChildren={children}
blockIndex={blockIndex}
allowExecution={canRunCode}
/>
);
}
});
export const codeNoExecution: React.ElementType = memo(({ className, children }: TCodeProps) => {
const match = /language-(\w+)/.exec(className ?? '');
const lang = match && match[1];
if (lang === 'math') {
return children;
} else if (typeof children === 'string' && children.split('\n').length === 1) {
return (
<code onDoubleClick={handleDoubleClick} className={className}>
{children}
</code>
);
} else {
return <CodeBlock lang={lang ?? 'text'} codeChildren={children} allowExecution={false} />;
}
});
type TAnchorProps = {
href: string;
children: React.ReactNode;
};
export const a: React.ElementType = memo(({ href, children }: TAnchorProps) => {
const user = useRecoilValue(store.user);
const { showToast } = useToastContext();
const localize = useLocalize();
const {
file_id = '',
filename = '',
filepath,
} = useMemo(() => {
const pattern = new RegExp(`(?:files|outputs)/${user?.id}/([^\\s]+)`);
const match = href.match(pattern);
if (match && match[0]) {
const path = match[0];
const parts = path.split('/');
const name = parts.pop();
const file_id = parts.pop();
return { file_id, filename: name, filepath: path };
}
return { file_id: '', filename: '', filepath: '' };
}, [user?.id, href]);
const { refetch: downloadFile } = useFileDownload(user?.id ?? '', file_id);
const props: { target?: string; onClick?: React.MouseEventHandler } = { target: '_new' };
if (!file_id || !filename) {
return (
<a href={href} {...props}>
{children}
</a>
);
}
const handleDownload = async (event: React.MouseEvent<HTMLAnchorElement>) => {
event.preventDefault();
try {
const stream = await downloadFile();
if (stream.data == null || stream.data === '') {
console.error('Error downloading file: No data found');
showToast({
status: 'error',
message: localize('com_ui_download_error'),
});
return;
}
const link = document.createElement('a');
link.href = stream.data;
link.setAttribute('download', filename);
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
window.URL.revokeObjectURL(stream.data);
} catch (error) {
console.error('Error downloading file:', error);
}
};
props.onClick = handleDownload;
props.target = '_blank';
return (
<a
href={filepath?.startsWith('files/') ? `/api/${filepath}` : `/api/files/${filepath}`}
{...props}
>
{children}
</a>
);
});
type TParagraphProps = {
children: React.ReactNode;
};
export const p: React.ElementType = memo(({ children }: TParagraphProps) => {
return <p className="mb-2 whitespace-pre-wrap">{children}</p>;
});
type TContentProps = {
content: string;
isLatestMessage: boolean;
@@ -68,33 +219,31 @@ const Markdown = memo(({ content = '', isLatestMessage }: TContentProps) => {
}
return (
<MarkdownErrorBoundary content={content} codeExecution={true}>
<ArtifactProvider>
<CodeBlockProvider>
<ReactMarkdown
/** @ts-ignore */
remarkPlugins={remarkPlugins}
/* @ts-ignore */
rehypePlugins={rehypePlugins}
components={
{
code,
a,
p,
artifact: Artifact,
citation: Citation,
'highlighted-text': HighlightedText,
'composite-citation': CompositeCitation,
} as {
[nodeType: string]: React.ElementType;
}
<ArtifactProvider>
<CodeBlockProvider>
<ReactMarkdown
/** @ts-ignore */
remarkPlugins={remarkPlugins}
/* @ts-ignore */
rehypePlugins={rehypePlugins}
components={
{
code,
a,
p,
artifact: Artifact,
citation: Citation,
'highlighted-text': HighlightedText,
'composite-citation': CompositeCitation,
} as {
[nodeType: string]: React.ElementType;
}
>
{currentContent}
</ReactMarkdown>
</CodeBlockProvider>
</ArtifactProvider>
</MarkdownErrorBoundary>
}
>
{currentContent}
</ReactMarkdown>
</CodeBlockProvider>
</ArtifactProvider>
);
});

View File

@@ -1,153 +0,0 @@
import React, { memo, useMemo, useRef, useEffect } from 'react';
import { useRecoilValue } from 'recoil';
import { PermissionTypes, Permissions } from 'librechat-data-provider';
import { useToastContext, useCodeBlockContext } from '~/Providers';
import CodeBlock from '~/components/Messages/Content/CodeBlock';
import useHasAccess from '~/hooks/Roles/useHasAccess';
import { useFileDownload } from '~/data-provider';
import useLocalize from '~/hooks/useLocalize';
import { handleDoubleClick } from '~/utils';
import store from '~/store';
type TCodeProps = {
inline?: boolean;
className?: string;
children: React.ReactNode;
};
export const code: React.ElementType = memo(({ className, children }: TCodeProps) => {
const canRunCode = useHasAccess({
permissionType: PermissionTypes.RUN_CODE,
permission: Permissions.USE,
});
const match = /language-(\w+)/.exec(className ?? '');
const lang = match && match[1];
const isMath = lang === 'math';
const isSingleLine = typeof children === 'string' && children.split('\n').length === 1;
const { getNextIndex, resetCounter } = useCodeBlockContext();
const blockIndex = useRef(getNextIndex(isMath || isSingleLine)).current;
useEffect(() => {
resetCounter();
}, [children, resetCounter]);
if (isMath) {
return <>{children}</>;
} else if (isSingleLine) {
return (
<code onDoubleClick={handleDoubleClick} className={className}>
{children}
</code>
);
} else {
return (
<CodeBlock
lang={lang ?? 'text'}
codeChildren={children}
blockIndex={blockIndex}
allowExecution={canRunCode}
/>
);
}
});
export const codeNoExecution: React.ElementType = memo(({ className, children }: TCodeProps) => {
const match = /language-(\w+)/.exec(className ?? '');
const lang = match && match[1];
if (lang === 'math') {
return children;
} else if (typeof children === 'string' && children.split('\n').length === 1) {
return (
<code onDoubleClick={handleDoubleClick} className={className}>
{children}
</code>
);
} else {
return <CodeBlock lang={lang ?? 'text'} codeChildren={children} allowExecution={false} />;
}
});
type TAnchorProps = {
href: string;
children: React.ReactNode;
};
export const a: React.ElementType = memo(({ href, children }: TAnchorProps) => {
const user = useRecoilValue(store.user);
const { showToast } = useToastContext();
const localize = useLocalize();
const {
file_id = '',
filename = '',
filepath,
} = useMemo(() => {
const pattern = new RegExp(`(?:files|outputs)/${user?.id}/([^\\s]+)`);
const match = href.match(pattern);
if (match && match[0]) {
const path = match[0];
const parts = path.split('/');
const name = parts.pop();
const file_id = parts.pop();
return { file_id, filename: name, filepath: path };
}
return { file_id: '', filename: '', filepath: '' };
}, [user?.id, href]);
const { refetch: downloadFile } = useFileDownload(user?.id ?? '', file_id);
const props: { target?: string; onClick?: React.MouseEventHandler } = { target: '_new' };
if (!file_id || !filename) {
return (
<a href={href} {...props}>
{children}
</a>
);
}
const handleDownload = async (event: React.MouseEvent<HTMLAnchorElement>) => {
event.preventDefault();
try {
const stream = await downloadFile();
if (stream.data == null || stream.data === '') {
console.error('Error downloading file: No data found');
showToast({
status: 'error',
message: localize('com_ui_download_error'),
});
return;
}
const link = document.createElement('a');
link.href = stream.data;
link.setAttribute('download', filename);
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
window.URL.revokeObjectURL(stream.data);
} catch (error) {
console.error('Error downloading file:', error);
}
};
props.onClick = handleDownload;
props.target = '_blank';
return (
<a
href={filepath?.startsWith('files/') ? `/api/${filepath}` : `/api/files/${filepath}`}
{...props}
>
{children}
</a>
);
});
type TParagraphProps = {
children: React.ReactNode;
};
export const p: React.ElementType = memo(({ children }: TParagraphProps) => {
return <p className="mb-2 whitespace-pre-wrap">{children}</p>;
});

View File

@@ -1,90 +0,0 @@
import React from 'react';
import remarkGfm from 'remark-gfm';
import supersub from 'remark-supersub';
import ReactMarkdown from 'react-markdown';
import rehypeHighlight from 'rehype-highlight';
import type { PluggableList } from 'unified';
import { code, codeNoExecution, a, p } from './MarkdownComponents';
import { CodeBlockProvider } from '~/Providers';
import { langSubset } from '~/utils';
interface ErrorBoundaryState {
hasError: boolean;
error?: Error;
}
interface MarkdownErrorBoundaryProps {
children: React.ReactNode;
content: string;
codeExecution?: boolean;
}
class MarkdownErrorBoundary extends React.Component<
MarkdownErrorBoundaryProps,
ErrorBoundaryState
> {
constructor(props: MarkdownErrorBoundaryProps) {
super(props);
this.state = { hasError: false };
}
static getDerivedStateFromError(error: Error): ErrorBoundaryState {
return { hasError: true, error };
}
componentDidCatch(error: Error, errorInfo: React.ErrorInfo) {
console.error('Markdown rendering error:', error, errorInfo);
}
componentDidUpdate(prevProps: MarkdownErrorBoundaryProps) {
if (prevProps.content !== this.props.content && this.state.hasError) {
this.setState({ hasError: false, error: undefined });
}
}
render() {
if (this.state.hasError) {
const { content, codeExecution = true } = this.props;
const rehypePlugins: PluggableList = [
[
rehypeHighlight,
{
detect: true,
ignoreMissing: true,
subset: langSubset,
},
],
];
return (
<CodeBlockProvider>
<ReactMarkdown
remarkPlugins={[
/** @ts-ignore */
supersub,
remarkGfm,
]}
/** @ts-ignore */
rehypePlugins={rehypePlugins}
components={
{
code: codeExecution ? code : codeNoExecution,
a,
p,
} as {
[nodeType: string]: React.ElementType;
}
}
>
{content}
</ReactMarkdown>
</CodeBlockProvider>
);
}
return this.props.children;
}
}
export default MarkdownErrorBoundary;

View File

@@ -6,9 +6,8 @@ import supersub from 'remark-supersub';
import ReactMarkdown from 'react-markdown';
import rehypeHighlight from 'rehype-highlight';
import type { PluggableList } from 'unified';
import { code, codeNoExecution, a, p } from './MarkdownComponents';
import { code, codeNoExecution, a, p } from './Markdown';
import { CodeBlockProvider, ArtifactProvider } from '~/Providers';
import MarkdownErrorBoundary from './MarkdownErrorBoundary';
import { langSubset } from '~/utils';
const MarkdownLite = memo(
@@ -26,34 +25,32 @@ const MarkdownLite = memo(
];
return (
<MarkdownErrorBoundary content={content} codeExecution={codeExecution}>
<ArtifactProvider>
<CodeBlockProvider>
<ReactMarkdown
remarkPlugins={[
/** @ts-ignore */
supersub,
remarkGfm,
[remarkMath, { singleDollarTextMath: false }],
]}
<ArtifactProvider>
<CodeBlockProvider>
<ReactMarkdown
remarkPlugins={[
/** @ts-ignore */
rehypePlugins={rehypePlugins}
// linkTarget="_new"
components={
{
code: codeExecution ? code : codeNoExecution,
a,
p,
} as {
[nodeType: string]: React.ElementType;
}
supersub,
remarkGfm,
[remarkMath, { singleDollarTextMath: false }],
]}
/** @ts-ignore */
rehypePlugins={rehypePlugins}
// linkTarget="_new"
components={
{
code: codeExecution ? code : codeNoExecution,
a,
p,
} as {
[nodeType: string]: React.ElementType;
}
>
{content}
</ReactMarkdown>
</CodeBlockProvider>
</ArtifactProvider>
</MarkdownErrorBoundary>
}
>
{content}
</ReactMarkdown>
</CodeBlockProvider>
</ArtifactProvider>
);
},
);

View File

@@ -13,25 +13,14 @@ export default function MemoryArtifacts({ attachments }: { attachments?: TAttach
const [isAnimating, setIsAnimating] = useState(false);
const prevShowInfoRef = useRef<boolean>(showInfo);
const { hasErrors, memoryArtifacts } = useMemo(() => {
let hasErrors = false;
const memoryArtifacts = useMemo(() => {
const result: MemoryArtifact[] = [];
if (!attachments || attachments.length === 0) {
return { hasErrors, memoryArtifacts: result };
}
for (const attachment of attachments) {
for (const attachment of attachments ?? []) {
if (attachment?.[Tools.memory] != null) {
result.push(attachment[Tools.memory]);
if (!hasErrors && attachment[Tools.memory].type === 'error') {
hasErrors = true;
}
}
}
return { hasErrors, memoryArtifacts: result };
return result;
}, [attachments]);
useLayoutEffect(() => {
@@ -86,12 +75,7 @@ export default function MemoryArtifacts({ attachments }: { attachments?: TAttach
<div className="flex items-center">
<div className="inline-block">
<button
className={cn(
'outline-hidden my-1 flex items-center gap-1 text-sm font-semibold transition-colors',
hasErrors
? 'text-red-500 hover:text-red-600 dark:text-red-400 dark:hover:text-red-500'
: 'text-text-secondary-alt hover:text-text-primary',
)}
className="outline-hidden my-1 flex items-center gap-1 text-sm font-semibold text-text-secondary-alt transition-colors hover:text-text-primary"
type="button"
onClick={() => setShowInfo((prev) => !prev)}
aria-expanded={showInfo}
@@ -118,7 +102,7 @@ export default function MemoryArtifacts({ attachments }: { attachments?: TAttach
fill="currentColor"
/>
</svg>
{hasErrors ? localize('com_ui_memory_error') : localize('com_ui_memory_updated')}
{localize('com_ui_memory_updated')}
</button>
</div>
</div>

View File

@@ -1,47 +1,17 @@
import type { MemoryArtifact } from 'librechat-data-provider';
import { useMemo } from 'react';
import { useLocalize } from '~/hooks';
export default function MemoryInfo({ memoryArtifacts }: { memoryArtifacts: MemoryArtifact[] }) {
const localize = useLocalize();
const { updatedMemories, deletedMemories, errorMessages } = useMemo(() => {
const updated = memoryArtifacts.filter((art) => art.type === 'update');
const deleted = memoryArtifacts.filter((art) => art.type === 'delete');
const errors = memoryArtifacts.filter((art) => art.type === 'error');
const messages = errors.map((artifact) => {
try {
const errorData = JSON.parse(artifact.value as string);
if (errorData.errorType === 'already_exceeded') {
return localize('com_ui_memory_already_exceeded', {
tokens: errorData.tokenCount,
});
} else if (errorData.errorType === 'would_exceed') {
return localize('com_ui_memory_would_exceed', {
tokens: errorData.tokenCount,
});
} else {
return localize('com_ui_memory_error');
}
} catch {
return localize('com_ui_memory_error');
}
});
return {
updatedMemories: updated,
deletedMemories: deleted,
errorMessages: messages,
};
}, [memoryArtifacts, localize]);
if (memoryArtifacts.length === 0) {
return null;
}
if (updatedMemories.length === 0 && deletedMemories.length === 0 && errorMessages.length === 0) {
// Group artifacts by type
const updatedMemories = memoryArtifacts.filter((artifact) => artifact.type === 'update');
const deletedMemories = memoryArtifacts.filter((artifact) => artifact.type === 'delete');
if (updatedMemories.length === 0 && deletedMemories.length === 0) {
return null;
}
@@ -53,8 +23,8 @@ export default function MemoryInfo({ memoryArtifacts }: { memoryArtifacts: Memor
{localize('com_ui_memory_updated_items')}
</h4>
<div className="space-y-2">
{updatedMemories.map((artifact) => (
<div key={`update-${artifact.key}`} className="rounded-lg p-3">
{updatedMemories.map((artifact, index) => (
<div key={`update-${index}`} className="rounded-lg p-3">
<div className="mb-1 text-xs font-medium uppercase tracking-wide text-text-secondary">
{artifact.key}
</div>
@@ -73,8 +43,8 @@ export default function MemoryInfo({ memoryArtifacts }: { memoryArtifacts: Memor
{localize('com_ui_memory_deleted_items')}
</h4>
<div className="space-y-2">
{deletedMemories.map((artifact) => (
<div key={`delete-${artifact.key}`} className="rounded-lg p-3 opacity-60">
{deletedMemories.map((artifact, index) => (
<div key={`delete-${index}`} className="rounded-lg p-3 opacity-60">
<div className="mb-1 text-xs font-medium uppercase tracking-wide text-text-secondary">
{artifact.key}
</div>
@@ -86,24 +56,6 @@ export default function MemoryInfo({ memoryArtifacts }: { memoryArtifacts: Memor
</div>
</div>
)}
{errorMessages.length > 0 && (
<div>
<h4 className="mb-2 text-sm font-semibold text-red-500">
{localize('com_ui_memory_storage_full')}
</h4>
<div className="space-y-2">
{errorMessages.map((errorMessage) => (
<div
key={errorMessage}
className="rounded-md bg-red-50 p-3 text-sm text-red-800 dark:bg-red-900/20 dark:text-red-400"
>
{errorMessage}
</div>
))}
</div>
</div>
)}
</div>
);
}

Some files were not shown because too many files have changed in this diff Show More