Compare commits

Comparing feat/add-b… with feat/clien… — 57 commits
| Author | SHA1 | Date |
|---|---|---|
|  | 014060a11b |  |
|  | 80d406b629 |  |
|  | 22035bbf95 |  |
|  | 17a164b420 |  |
|  | 5b663d0e35 |  |
|  | c273dfc1f4 |  |
|  | de69bcdd64 |  |
|  | 59412c2b36 |  |
|  | 0af8fba7ca |  |
|  | d5cf83313b |  |
|  | 355da0fc2e |  |
|  | 8ba5aa6055 |  |
|  | f1204531a8 |  |
|  | a65e33758d |  |
|  | 0855631c54 |  |
|  | a1e052871f |  |
|  | ca18ada9e2 |  |
|  | 103af99879 |  |
|  | 412948e025 |  |
|  | 4da25826d9 |  |
|  | 6e7fdeb3a3 |  |
|  | 1c3f5b972d |  |
|  | e7aa83e073 |  |
|  | 0b5155d277 |  |
|  | 86deb4d19a |  |
|  | f741a59ec4 |  |
|  | 83477bba34 |  |
|  | 5476029bca |  |
|  | 05a0a1f7cd |  |
|  | 55f67212d5 |  |
|  | 63d0c301a0 |  |
|  | bac6e499b7 |  |
|  | 59de92afa9 |  |
|  | e297386cee |  |
|  | e2b1cc607f |  |
|  | 55bda03d19 |  |
|  | b11ba35790 |  |
|  | 8e1b00da2a |  |
|  | 39f5dd47dc |  |
|  | 16f83c6e8e |  |
|  | b71a82d0e9 |  |
|  | 63a5902404 |  |
|  | 21c3a831c3 |  |
|  | ae43b4eed0 |  |
|  | 2af4ca5b5c |  |
|  | b6c7b0bc71 |  |
|  | f8738b207c |  |
|  | 6ea1d5eab2 |  |
|  | 992911514c |  |
|  | 9289aeb2ba |  |
|  | 898d273aaf |  |
|  | 5859350bcb |  |
|  | 882cca247a |  |
|  | 0b7dd55797 |  |
|  | 9f270127d3 |  |
|  | 1380db85cb |  |
|  | dcaa5af598 |  |
@@ -442,8 +442,6 @@ OPENID_REQUIRED_ROLE_PARAMETER_PATH=
OPENID_USERNAME_CLAIM=
# Set to determine which user info property returned from OpenID Provider to store as the User's name
OPENID_NAME_CLAIM=
# Optional audience parameter for OpenID authorization requests
OPENID_AUDIENCE=

OPENID_BUTTON_LABEL=
OPENID_IMAGE_URL=
.github/workflows/data-schemas.yml (2 changed lines)

@@ -22,7 +22,7 @@ jobs:
- name: Use Node.js
  uses: actions/setup-node@v4
  with:
    node-version: '20.x'
    node-version: '18.x'

- name: Install dependencies
  run: cd packages/data-schemas && npm ci
.github/workflows/locize-i18n-sync.yml (2 changed lines)

@@ -48,7 +48,7 @@ jobs:

# 2. Download translation files from locize.
- name: Download Translations from locize
  uses: locize/download@v2
  uses: locize/download@v1
  with:
    project-id: ${{ secrets.LOCIZE_PROJECT_ID }}
    path: "client/src/locales"
.gitignore (3 changed lines)

@@ -13,9 +13,6 @@ pids
*.seed
.git

# CI/CD data
test-image*

# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
@@ -1,4 +1,4 @@
# v0.8.0-rc1
# v0.7.9

# Base node image
FROM node:20-alpine AS node
@@ -1,5 +1,5 @@
# Dockerfile.multi
# v0.8.0-rc1
# v0.7.9

# Base for all builds
FROM node:20-alpine AS base-min

@@ -16,7 +16,6 @@ COPY package*.json ./
COPY packages/data-provider/package*.json ./packages/data-provider/
COPY packages/api/package*.json ./packages/api/
COPY packages/data-schemas/package*.json ./packages/data-schemas/
COPY packages/client/package*.json ./packages/client/
COPY client/package*.json ./client/
COPY api/package*.json ./api/

@@ -46,19 +45,11 @@ COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/d
COPY --from=data-schemas-build /app/packages/data-schemas/dist /app/packages/data-schemas/dist
RUN npm run build

# Build `client` package
FROM base AS client-package-build
WORKDIR /app/packages/client
COPY packages/client ./
RUN npm run build

# Client build
FROM base AS client-build
WORKDIR /app/client
COPY client ./
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
COPY --from=client-package-build /app/packages/client/dist /app/packages/client/dist
COPY --from=client-package-build /app/packages/client/src /app/packages/client/src
ENV NODE_OPTIONS="--max-old-space-size=2048"
RUN npm run build
@@ -653,10 +653,8 @@ class OpenAIClient extends BaseClient {
    if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
      configOptions.baseOptions = {
        headers: resolveHeaders({
          headers: {
            ...headers,
            ...configOptions?.baseOptions?.headers,
          },
          ...headers,
          ...configOptions?.baseOptions?.headers,
        }),
      };
    }

@@ -751,7 +749,7 @@ class OpenAIClient extends BaseClient {
      groupMap,
    });

    this.options.headers = resolveHeaders({ headers });
    this.options.headers = resolveHeaders(headers);
    this.options.reverseProxyUrl = baseURL ?? null;
    this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
    this.apiKey = azureOptions.azureOpenAIApiKey;

@@ -1183,7 +1181,7 @@ ${convo}
        modelGroupMap,
        groupMap,
      });
      opts.defaultHeaders = resolveHeaders({ headers });
      opts.defaultHeaders = resolveHeaders(headers);
      this.langchainProxy = extractBaseURL(baseURL);
      this.apiKey = azureOptions.azureOpenAIApiKey;

@@ -1224,9 +1222,7 @@ ${convo}
    }

    if (this.isOmni === true && modelOptions.max_tokens != null) {
      const paramName =
        modelOptions.useResponsesApi === true ? 'max_output_tokens' : 'max_completion_tokens';
      modelOptions[paramName] = modelOptions.max_tokens;
      modelOptions.max_completion_tokens = modelOptions.max_tokens;
      delete modelOptions.max_tokens;
    }
    if (this.isOmni === true && modelOptions.temperature != null) {
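For orientation, here is a minimal runnable sketch of the token-parameter remapping shown in the last hunk above. It is not the project's code: `remapMaxTokens` is a hypothetical helper name, and the `useResponsesApi` branch mirrors only one side of the diff (the other side always writes `max_completion_tokens`).

```js
// Hypothetical helper illustrating the remapping in the hunk above.
// Assumption: "omni" models reject `max_tokens` and expect either
// `max_completion_tokens` (Chat Completions) or `max_output_tokens` (Responses API).
function remapMaxTokens(modelOptions, { isOmni = false } = {}) {
  if (!isOmni || modelOptions.max_tokens == null) {
    return { ...modelOptions };
  }
  const paramName =
    modelOptions.useResponsesApi === true ? 'max_output_tokens' : 'max_completion_tokens';
  const remapped = { ...modelOptions, [paramName]: modelOptions.max_tokens };
  delete remapped.max_tokens;
  return remapped;
}

// Example:
console.log(remapMaxTokens({ max_tokens: 1024, useResponsesApi: true }, { isOmni: true }));
// -> { useResponsesApi: true, max_output_tokens: 1024 }
```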
@@ -3,8 +3,8 @@ const path = require('path');
const OpenAI = require('openai');
const fetch = require('node-fetch');
const { v4: uuidv4 } = require('uuid');
const { ProxyAgent } = require('undici');
const { Tool } = require('@langchain/core/tools');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { FileContext, ContentTypes } = require('librechat-data-provider');
const { getImageBasename } = require('~/server/services/Files/images');
const extractBaseURL = require('~/utils/extractBaseURL');

@@ -46,10 +46,7 @@ class DALLE3 extends Tool {
    }

    if (process.env.PROXY) {
      const proxyAgent = new ProxyAgent(process.env.PROXY);
      config.fetchOptions = {
        dispatcher: proxyAgent,
      };
      config.httpAgent = new HttpsProxyAgent(process.env.PROXY);
    }

    /** @type {OpenAI} */

@@ -166,8 +163,7 @@ Error Message: ${error.message}`);
    if (this.isAgent) {
      let fetchOptions = {};
      if (process.env.PROXY) {
        const proxyAgent = new ProxyAgent(process.env.PROXY);
        fetchOptions.dispatcher = proxyAgent;
        fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY);
      }
      const imageResponse = await fetch(theImageUrl, fetchOptions);
      const arrayBuffer = await imageResponse.arrayBuffer();
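The DALLE3 hunks above differ mainly in how the outbound HTTP proxy is wired. As a rough, hedged illustration (not this repo's code; `fetchViaDispatcher` and `fetchViaAgent` are made-up names), the two styles compare like this:

```js
const { ProxyAgent } = require('undici');
const { HttpsProxyAgent } = require('https-proxy-agent');

// Style A: undici-based clients (including Node 18+ global fetch) accept a
// `dispatcher`, so the proxy is attached as an undici ProxyAgent in fetch options.
async function fetchViaDispatcher(url) {
  const fetchOptions = {};
  if (process.env.PROXY) {
    fetchOptions.dispatcher = new ProxyAgent(process.env.PROXY);
  }
  return fetch(url, fetchOptions);
}

// Style B: classic http.Agent consumers (node-fetch's `agent` option, or a
// client's `httpAgent` setting) take an https-proxy-agent instance instead.
async function fetchViaAgent(url) {
  const fetch = require('node-fetch');
  const fetchOptions = {};
  if (process.env.PROXY) {
    fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY);
  }
  return fetch(url, fetchOptions);
}

module.exports = { fetchViaDispatcher, fetchViaAgent };
```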
@@ -3,10 +3,10 @@ const axios = require('axios');
const { v4 } = require('uuid');
const OpenAI = require('openai');
const FormData = require('form-data');
const { ProxyAgent } = require('undici');
const { tool } = require('@langchain/core/tools');
const { logAxiosError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { extractBaseURL } = require('~/utils');

@@ -189,10 +189,7 @@ function createOpenAIImageTools(fields = {}) {
  }
  const clientConfig = { ...closureConfig };
  if (process.env.PROXY) {
    const proxyAgent = new ProxyAgent(process.env.PROXY);
    clientConfig.fetchOptions = {
      dispatcher: proxyAgent,
    };
    clientConfig.httpAgent = new HttpsProxyAgent(process.env.PROXY);
  }

  /** @type {OpenAI} */

@@ -338,10 +335,7 @@ Error Message: ${error.message}`);

  const clientConfig = { ...closureConfig };
  if (process.env.PROXY) {
    const proxyAgent = new ProxyAgent(process.env.PROXY);
    clientConfig.fetchOptions = {
      dispatcher: proxyAgent,
    };
    clientConfig.httpAgent = new HttpsProxyAgent(process.env.PROXY);
  }

  const formData = new FormData();

@@ -453,19 +447,6 @@ Error Message: ${error.message}`);
    baseURL,
  };

  if (process.env.PROXY) {
    try {
      const url = new URL(process.env.PROXY);
      axiosConfig.proxy = {
        host: url.hostname.replace(/^\[|\]$/g, ''),
        port: url.port ? parseInt(url.port, 10) : undefined,
        protocol: url.protocol.replace(':', ''),
      };
    } catch (error) {
      logger.error('Error parsing proxy URL:', error);
    }
  }

  if (process.env.IMAGE_GEN_OAI_AZURE_API_VERSION && process.env.IMAGE_GEN_OAI_BASEURL) {
    axiosConfig.params = {
      'api-version': process.env.IMAGE_GEN_OAI_AZURE_API_VERSION,
@@ -1,94 +0,0 @@
const DALLE3 = require('../DALLE3');
const { ProxyAgent } = require('undici');

const processFileURL = jest.fn();

jest.mock('~/server/services/Files/images', () => ({
  getImageBasename: jest.fn().mockImplementation((url) => {
    const parts = url.split('/');
    const lastPart = parts.pop();
    const imageExtensionRegex = /\.(jpg|jpeg|png|gif|bmp|tiff|svg)$/i;
    if (imageExtensionRegex.test(lastPart)) {
      return lastPart;
    }
    return '';
  }),
}));

jest.mock('fs', () => {
  return {
    existsSync: jest.fn(),
    mkdirSync: jest.fn(),
    promises: {
      writeFile: jest.fn(),
      readFile: jest.fn(),
      unlink: jest.fn(),
    },
  };
});

jest.mock('path', () => {
  return {
    resolve: jest.fn(),
    join: jest.fn(),
    relative: jest.fn(),
    extname: jest.fn().mockImplementation((filename) => {
      return filename.slice(filename.lastIndexOf('.'));
    }),
  };
});

describe('DALLE3 Proxy Configuration', () => {
  let originalEnv;

  beforeAll(() => {
    originalEnv = { ...process.env };
  });

  beforeEach(() => {
    jest.resetModules();
    process.env = { ...originalEnv };
  });

  afterEach(() => {
    process.env = originalEnv;
  });

  it('should configure ProxyAgent in fetchOptions.dispatcher when PROXY env is set', () => {
    // Set proxy environment variable
    process.env.PROXY = 'http://proxy.example.com:8080';
    process.env.DALLE_API_KEY = 'test-api-key';

    // Create instance
    const dalleWithProxy = new DALLE3({ processFileURL });

    // Check that the openai client exists
    expect(dalleWithProxy.openai).toBeDefined();

    // Check that _options exists and has fetchOptions with a dispatcher
    expect(dalleWithProxy.openai._options).toBeDefined();
    expect(dalleWithProxy.openai._options.fetchOptions).toBeDefined();
    expect(dalleWithProxy.openai._options.fetchOptions.dispatcher).toBeDefined();
    expect(dalleWithProxy.openai._options.fetchOptions.dispatcher).toBeInstanceOf(ProxyAgent);
  });

  it('should not configure ProxyAgent when PROXY env is not set', () => {
    // Ensure PROXY is not set
    delete process.env.PROXY;
    process.env.DALLE_API_KEY = 'test-api-key';

    // Create instance
    const dalleWithoutProxy = new DALLE3({ processFileURL });

    // Check that the openai client exists
    expect(dalleWithoutProxy.openai).toBeDefined();

    // Check that _options exists but fetchOptions either doesn't exist or doesn't have a dispatcher
    expect(dalleWithoutProxy.openai._options).toBeDefined();

    // fetchOptions should either not exist or not have a dispatcher
    if (dalleWithoutProxy.openai._options.fetchOptions) {
      expect(dalleWithoutProxy.openai._options.fetchOptions.dispatcher).toBeUndefined();
    }
  });
});
api/cache/cacheConfig.js (8 changed lines)

@@ -44,14 +44,6 @@ const cacheConfig = {
  REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
  REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
  REDIS_PING_INTERVAL: math(process.env.REDIS_PING_INTERVAL, 0),
  /** Max delay between reconnection attempts in ms */
  REDIS_RETRY_MAX_DELAY: math(process.env.REDIS_RETRY_MAX_DELAY, 3000),
  /** Max number of reconnection attempts (0 = infinite) */
  REDIS_RETRY_MAX_ATTEMPTS: math(process.env.REDIS_RETRY_MAX_ATTEMPTS, 10),
  /** Connection timeout in ms */
  REDIS_CONNECT_TIMEOUT: math(process.env.REDIS_CONNECT_TIMEOUT, 10000),
  /** Queue commands when disconnected */
  REDIS_ENABLE_OFFLINE_QUEUE: isEnabled(process.env.REDIS_ENABLE_OFFLINE_QUEUE ?? 'true'),

  CI: isEnabled(process.env.CI),
  DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
api/cache/cacheFactory.js (61 changed lines)

@@ -1,13 +1,12 @@
const KeyvRedis = require('@keyv/redis').default;
const { Keyv } = require('keyv');
const { RedisStore } = require('rate-limit-redis');
const { cacheConfig } = require('./cacheConfig');
const { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } = require('./redisClients');
const { Time } = require('librechat-data-provider');
const { logger } = require('@librechat/data-schemas');
const { RedisStore: ConnectRedis } = require('connect-redis');
const MemoryStore = require('memorystore')(require('express-session'));
const { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } = require('./redisClients');
const { cacheConfig } = require('./cacheConfig');
const { violationFile } = require('./keyvFiles');
const { RedisStore } = require('rate-limit-redis');

/**
 * Creates a cache instance using Redis or a fallback store. Suitable for general caching needs.

@@ -21,21 +20,11 @@ const standardCache = (namespace, ttl = undefined, fallbackStore = undefined) =>
    cacheConfig.USE_REDIS &&
    !cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES?.includes(namespace)
  ) {
    try {
      const keyvRedis = new KeyvRedis(keyvRedisClient);
      const cache = new Keyv(keyvRedis, { namespace, ttl });
      keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
      keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;

      cache.on('error', (err) => {
        logger.error(`Cache error in namespace ${namespace}:`, err);
      });

      return cache;
    } catch (err) {
      logger.error(`Failed to create Redis cache for namespace ${namespace}:`, err);
      throw err;
    }
    const keyvRedis = new KeyvRedis(keyvRedisClient);
    const cache = new Keyv(keyvRedis, { namespace, ttl });
    keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
    keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
    return cache;
  }
  if (fallbackStore) return new Keyv({ store: fallbackStore, namespace, ttl });
  return new Keyv({ namespace, ttl });

@@ -61,13 +50,7 @@ const violationCache = (namespace, ttl = undefined) => {
const sessionCache = (namespace, ttl = undefined) => {
  namespace = namespace.endsWith(':') ? namespace : `${namespace}:`;
  if (!cacheConfig.USE_REDIS) return new MemoryStore({ ttl, checkPeriod: Time.ONE_DAY });
  const store = new ConnectRedis({ client: ioredisClient, ttl, prefix: namespace });
  if (ioredisClient) {
    ioredisClient.on('error', (err) => {
      logger.error(`Session store Redis error for namespace ${namespace}:`, err);
    });
  }
  return store;
  return new ConnectRedis({ client: ioredisClient, ttl, prefix: namespace });
};

/**

@@ -79,30 +62,8 @@ const limiterCache = (prefix) => {
  if (!prefix) throw new Error('prefix is required');
  if (!cacheConfig.USE_REDIS) return undefined;
  prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;

  try {
    if (!ioredisClient) {
      logger.warn(`Redis client not available for rate limiter with prefix ${prefix}`);
      return undefined;
    }

    return new RedisStore({ sendCommand, prefix });
  } catch (err) {
    logger.error(`Failed to create Redis rate limiter for prefix ${prefix}:`, err);
    return undefined;
  }
};

const sendCommand = (...args) => {
  if (!ioredisClient) {
    logger.warn('Redis client not available for command execution');
    return Promise.reject(new Error('Redis client not available'));
  }

  return ioredisClient.call(...args).catch((err) => {
    logger.error('Redis command execution failed:', err);
    throw err;
  });
  return new RedisStore({ sendCommand, prefix });
};
const sendCommand = (...args) => ioredisClient?.call(...args);

module.exports = { standardCache, sessionCache, violationCache, limiterCache };
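As context for the `limiterCache`/`sendCommand` changes above, here is a hedged usage sketch of how the returned store plugs into `express-rate-limit`. This is assumed wiring, not taken from this repository; the route, window, and limit values are illustrative.

```js
const rateLimit = require('express-rate-limit');
const { limiterCache } = require('./cacheFactory');

// rate-limit-redis forwards every Redis command through the `sendCommand`
// callback, which in this module delegates to ioredisClient.call(...).
const loginLimiter = rateLimit({
  windowMs: 15 * 60 * 1000, // 15-minute window (illustrative)
  max: 20, // illustrative request budget per window
  // When USE_REDIS is off, limiterCache returns undefined and
  // express-rate-limit falls back to its built-in in-memory store.
  store: limiterCache('login-limiter'),
});

// app.use('/api/auth/login', loginLimiter);
module.exports = loginLimiter;
```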
142
api/cache/cacheFactory.spec.js
vendored
142
api/cache/cacheFactory.spec.js
vendored
@@ -6,17 +6,13 @@ const mockKeyvRedis = {
|
||||
keyPrefixSeparator: '',
|
||||
};
|
||||
|
||||
const mockKeyv = jest.fn().mockReturnValue({
|
||||
mock: 'keyv',
|
||||
on: jest.fn(),
|
||||
});
|
||||
const mockKeyv = jest.fn().mockReturnValue({ mock: 'keyv' });
|
||||
const mockConnectRedis = jest.fn().mockReturnValue({ mock: 'connectRedis' });
|
||||
const mockMemoryStore = jest.fn().mockReturnValue({ mock: 'memoryStore' });
|
||||
const mockRedisStore = jest.fn().mockReturnValue({ mock: 'redisStore' });
|
||||
|
||||
const mockIoredisClient = {
|
||||
call: jest.fn(),
|
||||
on: jest.fn(),
|
||||
};
|
||||
|
||||
const mockKeyvRedisClient = {};
|
||||
@@ -57,14 +53,6 @@ jest.mock('rate-limit-redis', () => ({
|
||||
RedisStore: mockRedisStore,
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
info: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Import after mocking
|
||||
const { standardCache, sessionCache, violationCache, limiterCache } = require('./cacheFactory');
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
@@ -154,28 +142,6 @@ describe('cacheFactory', () => {
|
||||
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
|
||||
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
|
||||
});
|
||||
|
||||
it('should throw error when Redis cache creation fails', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
const testError = new Error('Redis connection failed');
|
||||
|
||||
const KeyvRedis = require('@keyv/redis').default;
|
||||
KeyvRedis.mockImplementationOnce(() => {
|
||||
throw testError;
|
||||
});
|
||||
|
||||
expect(() => standardCache(namespace, ttl)).toThrow('Redis connection failed');
|
||||
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
`Failed to create Redis cache for namespace ${namespace}:`,
|
||||
testError,
|
||||
);
|
||||
|
||||
expect(mockKeyv).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('violationCache', () => {
|
||||
@@ -267,86 +233,6 @@ describe('cacheFactory', () => {
|
||||
checkPeriod: Time.ONE_DAY,
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw error when ConnectRedis constructor fails', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
const ttl = 86400;
|
||||
|
||||
// Mock ConnectRedis to throw an error during construction
|
||||
const redisError = new Error('Redis connection failed');
|
||||
mockConnectRedis.mockImplementationOnce(() => {
|
||||
throw redisError;
|
||||
});
|
||||
|
||||
// The error should propagate up, not be caught
|
||||
expect(() => sessionCache(namespace, ttl)).toThrow('Redis connection failed');
|
||||
|
||||
// Verify that MemoryStore was NOT used as fallback
|
||||
expect(mockMemoryStore).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should register error handler but let errors propagate to Express', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
|
||||
// Create a mock session store with middleware methods
|
||||
const mockSessionStore = {
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
mockConnectRedis.mockReturnValue(mockSessionStore);
|
||||
|
||||
const store = sessionCache(namespace);
|
||||
|
||||
// Verify error handler was registered
|
||||
expect(mockIoredisClient.on).toHaveBeenCalledWith('error', expect.any(Function));
|
||||
|
||||
// Get the error handler
|
||||
const errorHandler = mockIoredisClient.on.mock.calls.find((call) => call[0] === 'error')[1];
|
||||
|
||||
// Simulate an error from Redis during a session operation
|
||||
const redisError = new Error('Socket closed unexpectedly');
|
||||
|
||||
// The error handler should log but not swallow the error
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
errorHandler(redisError);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
`Session store Redis error for namespace ${namespace}::`,
|
||||
redisError,
|
||||
);
|
||||
|
||||
// Now simulate what happens when session middleware tries to use the store
|
||||
const callback = jest.fn();
|
||||
mockSessionStore.get.mockImplementation((sid, cb) => {
|
||||
cb(new Error('Redis connection lost'));
|
||||
});
|
||||
|
||||
// Call the store's get method (as Express session would)
|
||||
store.get('test-session-id', callback);
|
||||
|
||||
// The error should be passed to the callback, not swallowed
|
||||
expect(callback).toHaveBeenCalledWith(new Error('Redis connection lost'));
|
||||
});
|
||||
|
||||
it('should handle null ioredisClient gracefully', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
|
||||
// Temporarily set ioredisClient to null (simulating connection not established)
|
||||
const originalClient = require('./redisClients').ioredisClient;
|
||||
require('./redisClients').ioredisClient = null;
|
||||
|
||||
// ConnectRedis might accept null client but would fail on first use
|
||||
// The important thing is it doesn't throw uncaught exceptions during construction
|
||||
const store = sessionCache(namespace);
|
||||
expect(store).toBeDefined();
|
||||
|
||||
// Restore original client
|
||||
require('./redisClients').ioredisClient = originalClient;
|
||||
});
|
||||
});
|
||||
|
||||
describe('limiterCache', () => {
|
||||
@@ -388,10 +274,8 @@ describe('cacheFactory', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should pass sendCommand function that calls ioredisClient.call', async () => {
|
||||
it('should pass sendCommand function that calls ioredisClient.call', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
mockIoredisClient.call.mockResolvedValue('test-value');
|
||||
|
||||
limiterCache('rate-limit');
|
||||
|
||||
const sendCommandCall = mockRedisStore.mock.calls[0][0];
|
||||
@@ -399,29 +283,9 @@ describe('cacheFactory', () => {
|
||||
|
||||
// Test that sendCommand properly delegates to ioredisClient.call
|
||||
const args = ['GET', 'test-key'];
|
||||
const result = await sendCommand(...args);
|
||||
sendCommand(...args);
|
||||
|
||||
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
|
||||
expect(result).toBe('test-value');
|
||||
});
|
||||
|
||||
it('should handle sendCommand errors properly', async () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
|
||||
// Mock the call method to reject with an error
|
||||
const testError = new Error('Redis error');
|
||||
mockIoredisClient.call.mockRejectedValue(testError);
|
||||
|
||||
limiterCache('rate-limit');
|
||||
|
||||
const sendCommandCall = mockRedisStore.mock.calls[0][0];
|
||||
const sendCommand = sendCommandCall.sendCommand;
|
||||
|
||||
// Test that sendCommand properly handles errors
|
||||
const args = ['GET', 'test-key'];
|
||||
|
||||
await expect(sendCommand(...args)).rejects.toThrow('Redis error');
|
||||
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
|
||||
});
|
||||
|
||||
it('should handle undefined prefix', () => {
|
||||
|
||||
116
api/cache/redisClients.js
vendored
116
api/cache/redisClients.js
vendored
@@ -13,82 +13,23 @@ const ca = cacheConfig.REDIS_CA;
|
||||
/** @type {import('ioredis').Redis | import('ioredis').Cluster | null} */
|
||||
let ioredisClient = null;
|
||||
if (cacheConfig.USE_REDIS) {
|
||||
/** @type {import('ioredis').RedisOptions | import('ioredis').ClusterOptions} */
|
||||
const redisOptions = {
|
||||
username: username,
|
||||
password: password,
|
||||
tls: ca ? { ca } : undefined,
|
||||
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
|
||||
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
|
||||
retryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 50, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
reconnectOnError: (err) => {
|
||||
const targetError = 'READONLY';
|
||||
if (err.message.includes(targetError)) {
|
||||
logger.warn('ioredis reconnecting due to READONLY error');
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
connectTimeout: cacheConfig.REDIS_CONNECT_TIMEOUT,
|
||||
maxRetriesPerRequest: 3,
|
||||
};
|
||||
|
||||
ioredisClient =
|
||||
urls.length === 1
|
||||
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
|
||||
: new IoRedis.Cluster(cacheConfig.REDIS_URI, {
|
||||
redisOptions,
|
||||
clusterRetryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis cluster giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis cluster reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
});
|
||||
: new IoRedis.Cluster(cacheConfig.REDIS_URI, { redisOptions });
|
||||
|
||||
ioredisClient.on('error', (err) => {
|
||||
logger.error('ioredis client error:', err);
|
||||
});
|
||||
|
||||
ioredisClient.on('connect', () => {
|
||||
logger.info('ioredis client connected');
|
||||
});
|
||||
|
||||
ioredisClient.on('ready', () => {
|
||||
logger.info('ioredis client ready');
|
||||
});
|
||||
|
||||
ioredisClient.on('reconnecting', (delay) => {
|
||||
logger.info(`ioredis client reconnecting in ${delay}ms`);
|
||||
});
|
||||
|
||||
ioredisClient.on('close', () => {
|
||||
logger.warn('ioredis client connection closed');
|
||||
});
|
||||
|
||||
/** Ping Interval to keep the Redis server connection alive (if enabled) */
|
||||
let pingInterval = null;
|
||||
const clearPingInterval = () => {
|
||||
@@ -101,9 +42,7 @@ if (cacheConfig.USE_REDIS) {
|
||||
if (cacheConfig.REDIS_PING_INTERVAL > 0) {
|
||||
pingInterval = setInterval(() => {
|
||||
if (ioredisClient && ioredisClient.status === 'ready') {
|
||||
ioredisClient.ping().catch((err) => {
|
||||
logger.error('ioredis ping failed:', err);
|
||||
});
|
||||
ioredisClient.ping();
|
||||
}
|
||||
}, cacheConfig.REDIS_PING_INTERVAL * 1000);
|
||||
ioredisClient.on('close', clearPingInterval);
|
||||
@@ -117,32 +56,8 @@ if (cacheConfig.USE_REDIS) {
|
||||
/**
|
||||
* ** WARNING ** Keyv Redis client does not support Prefix like ioredis above.
|
||||
* The prefix feature will be handled by the Keyv-Redis store in cacheFactory.js
|
||||
* @type {import('@keyv/redis').RedisClientOptions | import('@keyv/redis').RedisClusterOptions}
|
||||
*/
|
||||
const redisOptions = {
|
||||
username,
|
||||
password,
|
||||
socket: {
|
||||
tls: ca != null,
|
||||
ca,
|
||||
connectTimeout: cacheConfig.REDIS_CONNECT_TIMEOUT,
|
||||
reconnectStrategy: (retries) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
retries > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`@keyv/redis client giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return new Error('Max reconnection attempts reached');
|
||||
}
|
||||
const delay = Math.min(retries * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`@keyv/redis reconnecting... attempt ${retries}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
},
|
||||
disableOfflineQueue: !cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
};
|
||||
const redisOptions = { username, password, socket: { tls: ca != null, ca } };
|
||||
|
||||
keyvRedisClient =
|
||||
urls.length === 1
|
||||
@@ -158,27 +73,6 @@ if (cacheConfig.USE_REDIS) {
|
||||
logger.error('@keyv/redis client error:', err);
|
||||
});
|
||||
|
||||
keyvRedisClient.on('connect', () => {
|
||||
logger.info('@keyv/redis client connected');
|
||||
});
|
||||
|
||||
keyvRedisClient.on('ready', () => {
|
||||
logger.info('@keyv/redis client ready');
|
||||
});
|
||||
|
||||
keyvRedisClient.on('reconnecting', () => {
|
||||
logger.info('@keyv/redis client reconnecting...');
|
||||
});
|
||||
|
||||
keyvRedisClient.on('disconnect', () => {
|
||||
logger.warn('@keyv/redis client disconnected');
|
||||
});
|
||||
|
||||
keyvRedisClient.connect().catch((err) => {
|
||||
logger.error('@keyv/redis initial connection failed:', err);
|
||||
throw err;
|
||||
});
|
||||
|
||||
/** Ping Interval to keep the Redis server connection alive (if enabled) */
|
||||
let pingInterval = null;
|
||||
const clearPingInterval = () => {
|
||||
@@ -191,9 +85,7 @@ if (cacheConfig.USE_REDIS) {
|
||||
if (cacheConfig.REDIS_PING_INTERVAL > 0) {
|
||||
pingInterval = setInterval(() => {
|
||||
if (keyvRedisClient && keyvRedisClient.isReady) {
|
||||
keyvRedisClient.ping().catch((err) => {
|
||||
logger.error('@keyv/redis ping failed:', err);
|
||||
});
|
||||
keyvRedisClient.ping();
|
||||
}
|
||||
}, cacheConfig.REDIS_PING_INTERVAL * 1000);
|
||||
keyvRedisClient.on('disconnect', clearPingInterval);
|
||||
|
||||
@@ -316,10 +316,17 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
  if (shouldCreateVersion) {
    const duplicateVersion = isDuplicateVersion(updateData, versionData, versions, actionsHash);
    if (duplicateVersion && !forceVersion) {
      // No changes detected, return the current agent without creating a new version
      const agentObj = currentAgent.toObject();
      agentObj.version = versions.length;
      return agentObj;
      const error = new Error(
        'Duplicate version: This would create a version identical to an existing one',
      );
      error.statusCode = 409;
      error.details = {
        duplicateVersion,
        versionIndex: versions.findIndex(
          (v) => JSON.stringify(duplicateVersion) === JSON.stringify(v),
        ),
      };
      throw error;
    }
  }
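The two sides of this hunk disagree on what `updateAgent` does when an update would produce a version identical to an existing one: one returns the current agent unchanged, the other throws an error carrying `statusCode` 409 and duplicate details. A hedged sketch of a caller that tolerates both behaviors (hypothetical handler, not the project's controller):

```js
// Hypothetical Express-style handler; `updateAgent` is injected to keep the
// sketch self-contained.
async function handleAgentUpdate(req, res, updateAgent) {
  try {
    const agent = await updateAgent({ id: req.params.id }, req.body, {
      updatingUserId: req.user.id,
    });
    // Silent-dedupe behavior: the unchanged agent (same version count) comes back.
    return res.json(agent);
  } catch (error) {
    if (error.statusCode === 409) {
      // Throw-on-duplicate behavior: surface the duplicate-version details.
      return res.status(409).json({ message: error.message, details: error.details });
    }
    throw error;
  }
}

module.exports = handleAgentUpdate;
```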
@@ -879,31 +879,45 @@ describe('models/Agent', () => {
|
||||
expect(emptyParamsAgent.model_parameters).toEqual({});
|
||||
});
|
||||
|
||||
test('should not create new version for duplicate updates', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const testCases = generateVersionTestCases();
|
||||
test('should detect duplicate versions and reject updates', async () => {
|
||||
const originalConsoleError = console.error;
|
||||
console.error = jest.fn();
|
||||
|
||||
for (const testCase of testCases) {
|
||||
const testAgentId = `agent_${uuidv4()}`;
|
||||
try {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const testCases = generateVersionTestCases();
|
||||
|
||||
await createAgent({
|
||||
id: testAgentId,
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
...testCase.initial,
|
||||
});
|
||||
for (const testCase of testCases) {
|
||||
const testAgentId = `agent_${uuidv4()}`;
|
||||
|
||||
const updatedAgent = await updateAgent({ id: testAgentId }, testCase.update);
|
||||
expect(updatedAgent.versions).toHaveLength(2); // No new version created
|
||||
await createAgent({
|
||||
id: testAgentId,
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
...testCase.initial,
|
||||
});
|
||||
|
||||
// Update with duplicate data should succeed but not create a new version
|
||||
const duplicateUpdate = await updateAgent({ id: testAgentId }, testCase.duplicate);
|
||||
await updateAgent({ id: testAgentId }, testCase.update);
|
||||
|
||||
expect(duplicateUpdate.versions).toHaveLength(2); // No new version created
|
||||
let error;
|
||||
try {
|
||||
await updateAgent({ id: testAgentId }, testCase.duplicate);
|
||||
} catch (e) {
|
||||
error = e;
|
||||
}
|
||||
|
||||
const agent = await getAgent({ id: testAgentId });
|
||||
expect(agent.versions).toHaveLength(2);
|
||||
expect(error).toBeDefined();
|
||||
expect(error.message).toContain('Duplicate version');
|
||||
expect(error.statusCode).toBe(409);
|
||||
expect(error.details).toBeDefined();
|
||||
expect(error.details.duplicateVersion).toBeDefined();
|
||||
|
||||
const agent = await getAgent({ id: testAgentId });
|
||||
expect(agent.versions).toHaveLength(2);
|
||||
}
|
||||
} finally {
|
||||
console.error = originalConsoleError;
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1079,13 +1093,20 @@ describe('models/Agent', () => {
|
||||
expect(secondUpdate.versions).toHaveLength(3);
|
||||
|
||||
// Update without forceVersion and no changes should not create a version
|
||||
const duplicateUpdate = await updateAgent(
|
||||
{ id: agentId },
|
||||
{ tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] },
|
||||
{ updatingUserId: authorId.toString(), forceVersion: false },
|
||||
);
|
||||
let error;
|
||||
try {
|
||||
await updateAgent(
|
||||
{ id: agentId },
|
||||
{ tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] },
|
||||
{ updatingUserId: authorId.toString(), forceVersion: false },
|
||||
);
|
||||
} catch (e) {
|
||||
error = e;
|
||||
}
|
||||
|
||||
expect(duplicateUpdate.versions).toHaveLength(3); // No new version created
|
||||
expect(error).toBeDefined();
|
||||
expect(error.message).toContain('Duplicate version');
|
||||
expect(error.statusCode).toBe(409);
|
||||
});
|
||||
|
||||
test('should handle isDuplicateVersion with arrays containing null/undefined values', async () => {
|
||||
@@ -2379,18 +2400,11 @@ describe('models/Agent', () => {
|
||||
agent_ids: ['agent1', 'agent2'],
|
||||
});
|
||||
|
||||
const updatedAgent = await updateAgent(
|
||||
{ id: agentId },
|
||||
{ agent_ids: ['agent1', 'agent2', 'agent3'] },
|
||||
);
|
||||
expect(updatedAgent.versions).toHaveLength(2);
|
||||
await updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 'agent3'] });
|
||||
|
||||
// Update with same agent_ids should succeed but not create a new version
|
||||
const duplicateUpdate = await updateAgent(
|
||||
{ id: agentId },
|
||||
{ agent_ids: ['agent1', 'agent2', 'agent3'] },
|
||||
);
|
||||
expect(duplicateUpdate.versions).toHaveLength(2); // No new version created
|
||||
await expect(
|
||||
updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 'agent3'] }),
|
||||
).rejects.toThrow('Duplicate version');
|
||||
});
|
||||
|
||||
test('should handle agent_ids field alongside other fields', async () => {
|
||||
@@ -2529,10 +2543,9 @@ describe('models/Agent', () => {
|
||||
expect(updated.versions).toHaveLength(2);
|
||||
expect(updated.agent_ids).toEqual([]);
|
||||
|
||||
// Update with same empty agent_ids should succeed but not create a new version
|
||||
const duplicateUpdate = await updateAgent({ id: agentId }, { agent_ids: [] });
|
||||
expect(duplicateUpdate.versions).toHaveLength(2); // No new version created
|
||||
expect(duplicateUpdate.agent_ids).toEqual([]);
|
||||
await expect(updateAgent({ id: agentId }, { agent_ids: [] })).rejects.toThrow(
|
||||
'Duplicate version',
|
||||
);
|
||||
});
|
||||
|
||||
test('should handle agent without agent_ids field', async () => {
|
||||
|
||||
@@ -1,6 +1,6 @@
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getMessages, deleteMessages } = require('./Message');
const { Conversation } = require('~/db/models');
@@ -1,572 +0,0 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const {
|
||||
deleteNullOrEmptyConversations,
|
||||
searchConversation,
|
||||
getConvosByCursor,
|
||||
getConvosQueried,
|
||||
getConvoFiles,
|
||||
getConvoTitle,
|
||||
deleteConvos,
|
||||
saveConvo,
|
||||
getConvo,
|
||||
} = require('./Conversation');
|
||||
jest.mock('~/server/services/Config/getCustomConfig');
|
||||
jest.mock('./Message');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
|
||||
const { Conversation } = require('~/db/models');
|
||||
|
||||
describe('Conversation Operations', () => {
|
||||
let mongoServer;
|
||||
let mockReq;
|
||||
let mockConversationData;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear database
|
||||
await Conversation.deleteMany({});
|
||||
|
||||
// Reset mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Default mock implementations
|
||||
getMessages.mockResolvedValue([]);
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 0 });
|
||||
|
||||
mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {},
|
||||
};
|
||||
|
||||
mockConversationData = {
|
||||
conversationId: uuidv4(),
|
||||
title: 'Test Conversation',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
};
|
||||
});
|
||||
|
||||
describe('saveConvo', () => {
|
||||
it('should save a conversation for an authenticated user', async () => {
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.user).toBe('user123');
|
||||
expect(result.title).toBe('Test Conversation');
|
||||
expect(result.endpoint).toBe(EModelEndpoint.openAI);
|
||||
|
||||
// Verify the conversation was actually saved to the database
|
||||
const savedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
expect(savedConvo).toBeTruthy();
|
||||
expect(savedConvo.title).toBe('Test Conversation');
|
||||
});
|
||||
|
||||
it('should query messages when saving a conversation', async () => {
|
||||
// Mock messages as ObjectIds
|
||||
const mongoose = require('mongoose');
|
||||
const mockMessages = [new mongoose.Types.ObjectId(), new mongoose.Types.ObjectId()];
|
||||
getMessages.mockResolvedValue(mockMessages);
|
||||
|
||||
await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
// Verify that getMessages was called with correct parameters
|
||||
expect(getMessages).toHaveBeenCalledWith(
|
||||
{ conversationId: mockConversationData.conversationId },
|
||||
'_id',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle newConversationId when provided', async () => {
|
||||
const newConversationId = uuidv4();
|
||||
const result = await saveConvo(mockReq, {
|
||||
...mockConversationData,
|
||||
newConversationId,
|
||||
});
|
||||
|
||||
expect(result.conversationId).toBe(newConversationId);
|
||||
});
|
||||
|
||||
it('should handle unsetFields metadata', async () => {
|
||||
const metadata = {
|
||||
unsetFields: { someField: 1 },
|
||||
};
|
||||
|
||||
await saveConvo(mockReq, mockConversationData, metadata);
|
||||
|
||||
const savedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
expect(savedConvo.someField).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTemporary conversation handling', () => {
|
||||
it('should save a conversation with expiredAt when isTemporary is true', async () => {
|
||||
// Mock custom config with 24 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
const afterSave = new Date();
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
expect(result.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 24 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 24 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should save a conversation without expiredAt when isTemporary is false', async () => {
|
||||
mockReq.body = { isTemporary: false };
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should save a conversation without expiredAt when isTemporary is not provided', async () => {
|
||||
// No isTemporary in body
|
||||
mockReq.body = {};
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
// Mock custom config with 48 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 48,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 48 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 48 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock custom config with less than minimum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 0.5, // Half hour - should be clamped to 1 hour
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 1 hour in the future (minimum)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 1 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock custom config with more than maximum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 10000, // Should be clamped to 8760 hours
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 8760 hours (1 year) in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 8760 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle getCustomConfig errors gracefully', async () => {
|
||||
// Mock getCustomConfig to throw an error
|
||||
getCustomConfig.mockRejectedValue(new Error('Config service unavailable'));
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
// Should still save the conversation but with expiredAt as null
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getCustomConfig to return empty config
|
||||
getCustomConfig.mockResolvedValue({});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Default retention is 30 days (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 30 * 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should update expiredAt when saving existing temporary conversation', async () => {
|
||||
// First save a temporary conversation
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const firstSave = await saveConvo(mockReq, mockConversationData);
|
||||
const originalExpiredAt = firstSave.expiredAt;
|
||||
|
||||
// Wait a bit to ensure time difference
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Save again with same conversationId but different title
|
||||
const updatedData = { ...mockConversationData, title: 'Updated Title' };
|
||||
const secondSave = await saveConvo(mockReq, updatedData);
|
||||
|
||||
// Should update title and create new expiredAt
|
||||
expect(secondSave.title).toBe('Updated Title');
|
||||
expect(secondSave.expiredAt).toBeDefined();
|
||||
expect(new Date(secondSave.expiredAt).getTime()).toBeGreaterThan(
|
||||
new Date(originalExpiredAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not set expiredAt when updating non-temporary conversation', async () => {
|
||||
// First save a non-temporary conversation
|
||||
mockReq.body = { isTemporary: false };
|
||||
const firstSave = await saveConvo(mockReq, mockConversationData);
|
||||
expect(firstSave.expiredAt).toBeNull();
|
||||
|
||||
// Update without isTemporary flag
|
||||
mockReq.body = {};
|
||||
const updatedData = { ...mockConversationData, title: 'Updated Title' };
|
||||
const secondSave = await saveConvo(mockReq, updatedData);
|
||||
|
||||
expect(secondSave.title).toBe('Updated Title');
|
||||
expect(secondSave.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should filter out expired conversations in getConvosByCursor', async () => {
|
||||
// Create some test conversations
|
||||
const nonExpiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Non-expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
updatedAt: new Date(),
|
||||
});
|
||||
|
||||
await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Future expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000), // 24 hours from now
|
||||
updatedAt: new Date(),
|
||||
});
|
||||
|
||||
// Mock Meili search
|
||||
Conversation.meiliSearch = jest.fn().mockResolvedValue({ hits: [] });
|
||||
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
// Should only return conversations with null or non-existent expiredAt
|
||||
expect(result.conversations).toHaveLength(1);
|
||||
expect(result.conversations[0].conversationId).toBe(nonExpiredConvo.conversationId);
|
||||
});
|
||||
|
||||
it('should filter out expired conversations in getConvosQueried', async () => {
|
||||
// Create test conversations
|
||||
const nonExpiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Non-expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
});
|
||||
|
||||
const expiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000),
|
||||
});
|
||||
|
||||
const convoIds = [
|
||||
{ conversationId: nonExpiredConvo.conversationId },
|
||||
{ conversationId: expiredConvo.conversationId },
|
||||
];
|
||||
|
||||
const result = await getConvosQueried('user123', convoIds);
|
||||
|
||||
// Should only return the non-expired conversation
|
||||
expect(result.conversations).toHaveLength(1);
|
||||
expect(result.conversations[0].conversationId).toBe(nonExpiredConvo.conversationId);
|
||||
expect(result.convoMap[nonExpiredConvo.conversationId]).toBeDefined();
|
||||
expect(result.convoMap[expiredConvo.conversationId]).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('searchConversation', () => {
|
||||
it('should find a conversation by conversationId', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await searchConversation(mockConversationData.conversationId);
|
||||
|
||||
expect(result).toBeTruthy();
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.user).toBe('user123');
|
||||
expect(result.title).toBeUndefined(); // Only returns conversationId and user
|
||||
});
|
||||
|
||||
it('should return null if conversation not found', async () => {
|
||||
const result = await searchConversation('non-existent-id');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvo', () => {
|
||||
it('should retrieve a conversation for a user', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test Conversation',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvo('user123', mockConversationData.conversationId);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.user).toBe('user123');
|
||||
expect(result.title).toBe('Test Conversation');
|
||||
});
|
||||
|
||||
it('should return null if conversation not found', async () => {
|
||||
const result = await getConvo('user123', 'non-existent-id');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvoTitle', () => {
|
||||
it('should return the conversation title', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test Title',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoTitle('user123', mockConversationData.conversationId);
|
||||
expect(result).toBe('Test Title');
|
||||
});
|
||||
|
||||
it('should return null if conversation has no title', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: null,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoTitle('user123', mockConversationData.conversationId);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return "New Chat" if conversation not found', async () => {
|
||||
const result = await getConvoTitle('user123', 'non-existent-id');
|
||||
expect(result).toBe('New Chat');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvoFiles', () => {
|
||||
it('should return conversation files', async () => {
|
||||
const files = ['file1', 'file2'];
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
files,
|
||||
});
|
||||
|
||||
const result = await getConvoFiles(mockConversationData.conversationId);
|
||||
expect(result).toEqual(files);
|
||||
});
|
||||
|
||||
it('should return empty array if no files', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoFiles(mockConversationData.conversationId);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return empty array if conversation not found', async () => {
|
||||
const result = await getConvoFiles('non-existent-id');
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteConvos', () => {
|
||||
it('should delete conversations and associated messages', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'To Delete',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 5 });
|
||||
|
||||
const result = await deleteConvos('user123', {
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
|
||||
expect(result.deletedCount).toBe(1);
|
||||
expect(result.messages.deletedCount).toBe(5);
|
||||
expect(deleteMessages).toHaveBeenCalledWith({
|
||||
conversationId: { $in: [mockConversationData.conversationId] },
|
||||
});
|
||||
|
||||
// Verify conversation was deleted
|
||||
const deletedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
expect(deletedConvo).toBeNull();
|
||||
});
|
||||
|
||||
it('should throw error if no conversations found', async () => {
|
||||
await expect(deleteConvos('user123', { conversationId: 'non-existent' })).rejects.toThrow(
|
||||
'Conversation not found or already deleted.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteNullOrEmptyConversations', () => {
|
||||
it('should delete conversations with null, empty, or missing conversationIds', async () => {
|
||||
// Since conversationId is required by the schema, we can't create documents with null/missing IDs
|
||||
// This test should verify the function works when such documents exist (e.g., from data corruption)
|
||||
|
||||
// For this test, let's create a valid conversation and verify the function doesn't delete it
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user4',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 0 });
|
||||
|
||||
const result = await deleteNullOrEmptyConversations();
|
||||
|
||||
expect(result.conversations.deletedCount).toBe(0); // No invalid conversations to delete
|
||||
expect(result.messages.deletedCount).toBe(0);
|
||||
|
||||
// Verify valid conversation remains
|
||||
const remainingConvos = await Conversation.find({});
|
||||
expect(remainingConvos).toHaveLength(1);
|
||||
expect(remainingConvos[0].conversationId).toBe(mockConversationData.conversationId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle database errors in saveConvo', async () => {
|
||||
// Force a database error by disconnecting
|
||||
await mongoose.disconnect();
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result).toEqual({ message: 'Error saving conversation' });
|
||||
|
||||
// Reconnect for other tests
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,7 @@
|
||||
const { z } = require('zod');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { Message } = require('~/db/models');
|
||||
|
||||
const idSchema = z.string().uuid();
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { messageSchema } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
const {
|
||||
saveMessage,
|
||||
getMessages,
|
||||
updateMessage,
|
||||
deleteMessages,
|
||||
bulkSaveMessages,
|
||||
updateMessageText,
|
||||
deleteMessagesSince,
|
||||
} = require('./Message');
|
||||
|
||||
jest.mock('~/server/services/Config/getCustomConfig');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IMessage>}
|
||||
*/
|
||||
@@ -121,21 +117,21 @@ describe('Message Operations', () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
// Create multiple messages in the same conversation
|
||||
await saveMessage(mockReq, {
|
||||
const message1 = await saveMessage(mockReq, {
|
||||
messageId: 'msg1',
|
||||
conversationId,
|
||||
text: 'First message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
await saveMessage(mockReq, {
|
||||
const message2 = await saveMessage(mockReq, {
|
||||
messageId: 'msg2',
|
||||
conversationId,
|
||||
text: 'Second message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
await saveMessage(mockReq, {
|
||||
const message3 = await saveMessage(mockReq, {
|
||||
messageId: 'msg3',
|
||||
conversationId,
|
||||
text: 'Third message',
|
||||
@@ -318,265 +314,4 @@ describe('Message Operations', () => {
|
||||
expect(messages[0].text).toBe('Victim message');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTemporary message handling', () => {
|
||||
beforeEach(() => {
|
||||
// Reset mocks before each test
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should save a message with expiredAt when isTemporary is true', async () => {
|
||||
// Mock custom config with 24 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
const afterSave = new Date();
|
||||
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
expect(result.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 24 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 24 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should save a message without expiredAt when isTemporary is false', async () => {
|
||||
mockReq.body = { isTemporary: false };
|
||||
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should save a message without expiredAt when isTemporary is not provided', async () => {
|
||||
// No isTemporary in body
|
||||
mockReq.body = {};
|
||||
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
// Mock custom config with 48 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 48,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 48 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 48 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock custom config with less than minimum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 0.5, // Half hour - should be clamped to 1 hour
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 1 hour in the future (minimum)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 1 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock custom config with more than maximum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 10000, // Should be clamped to 8760 hours
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 8760 hours (1 year) in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 8760 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle getCustomConfig errors gracefully', async () => {
|
||||
// Mock getCustomConfig to throw an error
|
||||
getCustomConfig.mockRejectedValue(new Error('Config service unavailable'));
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
// Should still save the message but with expiredAt as null
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getCustomConfig to return empty config
|
||||
getCustomConfig.mockResolvedValue({});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Default retention is 30 days (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 30 * 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not update expiredAt on message update', async () => {
|
||||
// First save a temporary message
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const savedMessage = await saveMessage(mockReq, mockMessageData);
|
||||
const originalExpiredAt = savedMessage.expiredAt;
|
||||
|
||||
// Now update the message without isTemporary flag
|
||||
mockReq.body = {};
|
||||
const updatedMessage = await updateMessage(mockReq, {
|
||||
messageId: 'msg123',
|
||||
text: 'Updated text',
|
||||
});
|
||||
|
||||
// expiredAt should not be in the returned updated message object
|
||||
expect(updatedMessage.expiredAt).toBeUndefined();
|
||||
|
||||
// Verify in database that expiredAt wasn't changed
|
||||
const dbMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
|
||||
expect(dbMessage.expiredAt).toEqual(originalExpiredAt);
|
||||
});
|
||||
|
||||
it('should preserve expiredAt when saving existing temporary message', async () => {
|
||||
// First save a temporary message
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const firstSave = await saveMessage(mockReq, mockMessageData);
|
||||
const originalExpiredAt = firstSave.expiredAt;
|
||||
|
||||
// Wait a bit to ensure time difference
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Save again with same messageId but different text
|
||||
const updatedData = { ...mockMessageData, text: 'Updated text' };
|
||||
const secondSave = await saveMessage(mockReq, updatedData);
|
||||
|
||||
// Should update text but create new expiredAt
|
||||
expect(secondSave.text).toBe('Updated text');
|
||||
expect(secondSave.expiredAt).toBeDefined();
|
||||
expect(new Date(secondSave.expiredAt).getTime()).toBeGreaterThan(
|
||||
new Date(originalExpiredAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle bulk operations with temporary messages', async () => {
|
||||
// This test verifies bulkSaveMessages doesn't interfere with expiredAt
|
||||
const messages = [
|
||||
{
|
||||
messageId: 'bulk1',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Bulk message 1',
|
||||
user: 'user123',
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000),
|
||||
},
|
||||
{
|
||||
messageId: 'bulk2',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Bulk message 2',
|
||||
user: 'user123',
|
||||
expiredAt: null,
|
||||
},
|
||||
];
|
||||
|
||||
await bulkSaveMessages(messages);
|
||||
|
||||
const savedMessages = await Message.find({
|
||||
messageId: { $in: ['bulk1', 'bulk2'] },
|
||||
}).lean();
|
||||
|
||||
expect(savedMessages).toHaveLength(2);
|
||||
|
||||
const bulk1 = savedMessages.find((m) => m.messageId === 'bulk1');
|
||||
const bulk2 = savedMessages.find((m) => m.messageId === 'bulk2');
|
||||
|
||||
expect(bulk1.expiredAt).toBeDefined();
|
||||
expect(bulk2.expiredAt).toBeNull();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { matchModelName } = require('../utils/tokens');
|
||||
const { matchModelName } = require('../utils');
|
||||
const defaultRate = 6;
|
||||
|
||||
/**
|
||||
@@ -87,9 +87,6 @@ const tokenValues = Object.assign(
|
||||
'gpt-4.1': { prompt: 2, completion: 8 },
|
||||
'gpt-4.5': { prompt: 75, completion: 150 },
|
||||
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
|
||||
'gpt-5': { prompt: 1.25, completion: 10 },
|
||||
'gpt-5-mini': { prompt: 0.25, completion: 2 },
|
||||
'gpt-5-nano': { prompt: 0.05, completion: 0.4 },
|
||||
'gpt-4o': { prompt: 2.5, completion: 10 },
|
||||
'gpt-4o-2024-05-13': { prompt: 5, completion: 15 },
|
||||
'gpt-4-1106': { prompt: 10, completion: 30 },
|
||||
@@ -150,9 +147,6 @@ const tokenValues = Object.assign(
|
||||
codestral: { prompt: 0.3, completion: 0.9 },
|
||||
'ministral-8b': { prompt: 0.1, completion: 0.1 },
|
||||
'ministral-3b': { prompt: 0.04, completion: 0.04 },
|
||||
// GPT-OSS models
|
||||
'gpt-oss-20b': { prompt: 0.05, completion: 0.2 },
|
||||
'gpt-oss-120b': { prompt: 0.15, completion: 0.6 },
|
||||
},
|
||||
bedrockValues,
|
||||
);
|
||||
@@ -220,12 +214,6 @@ const getValueKey = (model, endpoint) => {
|
||||
return 'gpt-4.1';
|
||||
} else if (modelName.includes('gpt-4o-2024-05-13')) {
|
||||
return 'gpt-4o-2024-05-13';
|
||||
} else if (modelName.includes('gpt-5-nano')) {
|
||||
return 'gpt-5-nano';
|
||||
} else if (modelName.includes('gpt-5-mini')) {
|
||||
return 'gpt-5-mini';
|
||||
} else if (modelName.includes('gpt-5')) {
|
||||
return 'gpt-5';
|
||||
} else if (modelName.includes('gpt-4o-mini')) {
|
||||
return 'gpt-4o-mini';
|
||||
} else if (modelName.includes('gpt-4o')) {
|
||||
|
||||
@@ -25,14 +25,8 @@ describe('getValueKey', () => {
|
||||
expect(getValueKey('gpt-4-some-other-info')).toBe('8k');
|
||||
});
|
||||
|
||||
it('should return "gpt-5" for model name containing "gpt-5"', () => {
|
||||
expect(getValueKey('gpt-5-some-other-info')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-2025-01-30')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-2025-01-30-0130')).toBe('gpt-5');
|
||||
expect(getValueKey('openai/gpt-5')).toBe('gpt-5');
|
||||
expect(getValueKey('openai/gpt-5-2025-01-30')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-turbo')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-0130')).toBe('gpt-5');
|
||||
it('should return undefined for model names that do not match any known patterns', () => {
|
||||
expect(getValueKey('gpt-5-some-other-info')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return "gpt-3.5-turbo-1106" for model name containing "gpt-3.5-turbo-1106"', () => {
|
||||
@@ -90,29 +84,6 @@ describe('getValueKey', () => {
|
||||
expect(getValueKey('gpt-4.1-nano-0125')).toBe('gpt-4.1-nano');
|
||||
});
|
||||
|
||||
it('should return "gpt-5" for model type of "gpt-5"', () => {
|
||||
expect(getValueKey('gpt-5-2025-01-30')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-2025-01-30-0130')).toBe('gpt-5');
|
||||
expect(getValueKey('openai/gpt-5')).toBe('gpt-5');
|
||||
expect(getValueKey('openai/gpt-5-2025-01-30')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-turbo')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-0130')).toBe('gpt-5');
|
||||
});
|
||||
|
||||
it('should return "gpt-5-mini" for model type of "gpt-5-mini"', () => {
|
||||
expect(getValueKey('gpt-5-mini-2025-01-30')).toBe('gpt-5-mini');
|
||||
expect(getValueKey('openai/gpt-5-mini')).toBe('gpt-5-mini');
|
||||
expect(getValueKey('gpt-5-mini-0130')).toBe('gpt-5-mini');
|
||||
expect(getValueKey('gpt-5-mini-2025-01-30-0130')).toBe('gpt-5-mini');
|
||||
});
|
||||
|
||||
it('should return "gpt-5-nano" for model type of "gpt-5-nano"', () => {
|
||||
expect(getValueKey('gpt-5-nano-2025-01-30')).toBe('gpt-5-nano');
|
||||
expect(getValueKey('openai/gpt-5-nano')).toBe('gpt-5-nano');
|
||||
expect(getValueKey('gpt-5-nano-0130')).toBe('gpt-5-nano');
|
||||
expect(getValueKey('gpt-5-nano-2025-01-30-0130')).toBe('gpt-5-nano');
|
||||
});
|
||||
|
||||
it('should return "gpt-4o" for model type of "gpt-4o"', () => {
|
||||
expect(getValueKey('gpt-4o-2024-08-06')).toBe('gpt-4o');
|
||||
expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o');
|
||||
@@ -236,48 +207,6 @@ describe('getMultiplier', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5', () => {
|
||||
const valueKey = getValueKey('gpt-5-2025-01-30');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-5'].prompt);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5-preview', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5-mini', () => {
|
||||
const valueKey = getValueKey('gpt-5-mini-2025-01-30');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-5-mini'].prompt);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5-mini'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5-mini-preview', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5-mini'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5-mini', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5-mini'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5-nano', () => {
|
||||
const valueKey = getValueKey('gpt-5-nano-2025-01-30');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-5-nano'].prompt);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5-nano'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5-nano-preview', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5-nano'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5-nano', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5-nano'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4o', () => {
|
||||
const valueKey = getValueKey('gpt-4o-2024-08-06');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
|
||||
@@ -378,22 +307,10 @@ describe('getMultiplier', () => {
|
||||
});
|
||||
|
||||
it('should return defaultRate if derived valueKey does not match any known patterns', () => {
|
||||
expect(getMultiplier({ tokenType: 'prompt', model: 'gpt-10-some-other-info' })).toBe(
|
||||
expect(getMultiplier({ tokenType: 'prompt', model: 'gpt-5-some-other-info' })).toBe(
|
||||
defaultRate,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return correct multipliers for GPT-OSS models', () => {
|
||||
const models = ['gpt-oss-20b', 'gpt-oss-120b'];
|
||||
models.forEach((key) => {
|
||||
const expectedPrompt = tokenValues[key].prompt;
|
||||
const expectedCompletion = tokenValues[key].completion;
|
||||
expect(getMultiplier({ valueKey: key, tokenType: 'prompt' })).toBe(expectedPrompt);
|
||||
expect(getMultiplier({ valueKey: key, tokenType: 'completion' })).toBe(expectedCompletion);
|
||||
expect(getMultiplier({ model: key, tokenType: 'prompt' })).toBe(expectedPrompt);
|
||||
expect(getMultiplier({ model: key, tokenType: 'completion' })).toBe(expectedCompletion);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('AWS Bedrock Model Tests', () => {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "v0.8.0-rc1",
|
||||
"version": "v0.7.9",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
@@ -49,10 +49,10 @@
|
||||
"@langchain/google-vertexai": "^0.2.13",
|
||||
"@langchain/openai": "^0.5.18",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.4.75",
|
||||
"@librechat/agents": "^2.4.68",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@modelcontextprotocol/sdk": "^1.17.1",
|
||||
"@modelcontextprotocol/sdk": "^1.17.0",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
"axios": "^1.8.2",
|
||||
|
||||
@@ -1,16 +1,54 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
getToolkitKey,
|
||||
checkPluginAuth,
|
||||
filterUniquePlugins,
|
||||
convertMCPToolsToPlugins,
|
||||
} = require('@librechat/api');
|
||||
const { CacheKeys, AuthType, Constants } = require('librechat-data-provider');
|
||||
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const { availableTools, toolkits } = require('~/app/clients/tools');
|
||||
const { getToolkitKey } = require('~/server/services/ToolService');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { availableTools } = require('~/app/clients/tools');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* Filters out duplicate plugins from the list of plugins.
|
||||
*
|
||||
* @param {TPlugin[]} plugins The list of plugins to filter.
|
||||
* @returns {TPlugin[]} The list of plugins with duplicates removed.
|
||||
*/
|
||||
const filterUniquePlugins = (plugins) => {
|
||||
const seen = new Set();
|
||||
return plugins.filter((plugin) => {
|
||||
const duplicate = seen.has(plugin.pluginKey);
|
||||
seen.add(plugin.pluginKey);
|
||||
return !duplicate;
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Determines if a plugin is authenticated by checking if all required authentication fields have non-empty values.
|
||||
* Supports alternate authentication fields, allowing validation against multiple possible environment variables.
|
||||
*
|
||||
* @param {TPlugin} plugin The plugin object containing the authentication configuration.
|
||||
* @returns {boolean} True if the plugin is authenticated for all required fields, false otherwise.
|
||||
*/
|
||||
const checkPluginAuth = (plugin) => {
|
||||
if (!plugin.authConfig || plugin.authConfig.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return plugin.authConfig.every((authFieldObj) => {
|
||||
const authFieldOptions = authFieldObj.authField.split('||');
|
||||
let isFieldAuthenticated = false;
|
||||
|
||||
for (const fieldOption of authFieldOptions) {
|
||||
const envValue = process.env[fieldOption];
|
||||
if (envValue && envValue.trim() !== '' && envValue !== AuthType.USER_PROVIDED) {
|
||||
isFieldAuthenticated = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return isFieldAuthenticated;
|
||||
});
|
||||
};
|
||||
|
||||
const getAvailablePluginsController = async (req, res) => {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
@@ -105,9 +143,9 @@ const getAvailableTools = async (req, res) => {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
const cachedUserTools = await getCachedTools({ userId });
|
||||
const userPlugins = convertMCPToolsToPlugins({ functionTools: cachedUserTools, customConfig });
|
||||
const userPlugins = convertMCPToolsToPlugins(cachedUserTools, customConfig);
|
||||
|
||||
if (cachedToolsArray != null && userPlugins != null) {
|
||||
if (cachedToolsArray && userPlugins) {
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
|
||||
res.status(200).json(dedupedTools);
|
||||
return;
|
||||
@@ -147,9 +185,7 @@ const getAvailableTools = async (req, res) => {
|
||||
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
|
||||
const isToolkit =
|
||||
plugin.toolkit === true &&
|
||||
Object.keys(toolDefinitions).some(
|
||||
(key) => getToolkitKey({ toolkits, toolName: key }) === plugin.pluginKey,
|
||||
);
|
||||
Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey);
|
||||
|
||||
if (!isToolDefined && !isToolkit) {
|
||||
continue;
|
||||
@@ -199,6 +235,58 @@ const getAvailableTools = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Converts MCP function format tools to plugin format
|
||||
* @param {Object} functionTools - Object with function format tools
|
||||
* @param {Object} customConfig - Custom configuration for MCP servers
|
||||
* @returns {Array} Array of plugin objects
|
||||
*/
|
||||
function convertMCPToolsToPlugins(functionTools, customConfig) {
|
||||
const plugins = [];
|
||||
|
||||
for (const [toolKey, toolData] of Object.entries(functionTools)) {
|
||||
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const functionData = toolData.function;
|
||||
const parts = toolKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
|
||||
const serverConfig = customConfig?.mcpServers?.[serverName];
|
||||
|
||||
const plugin = {
|
||||
name: parts[0], // Use the tool name without server suffix
|
||||
pluginKey: toolKey,
|
||||
description: functionData.description || '',
|
||||
authenticated: true,
|
||||
icon: serverConfig?.iconPath,
|
||||
};
|
||||
|
||||
// Build authConfig for MCP tools
|
||||
if (!serverConfig?.customUserVars) {
|
||||
plugin.authConfig = [];
|
||||
plugins.push(plugin);
|
||||
continue;
|
||||
}
|
||||
|
||||
const customVarKeys = Object.keys(serverConfig.customUserVars);
|
||||
if (customVarKeys.length === 0) {
|
||||
plugin.authConfig = [];
|
||||
} else {
|
||||
plugin.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
|
||||
authField: key,
|
||||
label: value.title || key,
|
||||
description: value.description || '',
|
||||
}));
|
||||
}
|
||||
|
||||
plugins.push(plugin);
|
||||
}
|
||||
|
||||
return plugins;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getAvailableTools,
|
||||
getAvailablePluginsController,
|
||||
|
||||
@@ -28,211 +28,19 @@ jest.mock('~/config', () => ({
|
||||
|
||||
jest.mock('~/app/clients/tools', () => ({
|
||||
availableTools: [],
|
||||
toolkits: [],
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
getToolkitKey: jest.fn(),
|
||||
checkPluginAuth: jest.fn(),
|
||||
filterUniquePlugins: jest.fn(),
|
||||
convertMCPToolsToPlugins: jest.fn(),
|
||||
}));
|
||||
|
||||
// Import the actual module with the function we want to test
|
||||
const { getAvailableTools, getAvailablePluginsController } = require('./PluginController');
|
||||
const {
|
||||
filterUniquePlugins,
|
||||
checkPluginAuth,
|
||||
convertMCPToolsToPlugins,
|
||||
getToolkitKey,
|
||||
} = require('@librechat/api');
|
||||
const { getAvailableTools } = require('./PluginController');
|
||||
|
||||
describe('PluginController', () => {
|
||||
let mockReq, mockRes, mockCache;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockReq = { user: { id: 'test-user-id' } };
|
||||
mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() };
|
||||
mockCache = { get: jest.fn(), set: jest.fn() };
|
||||
getLogStores.mockReturnValue(mockCache);
|
||||
});
|
||||
|
||||
describe('getAvailablePluginsController', () => {
|
||||
beforeEach(() => {
|
||||
mockReq.app = { locals: { filteredTools: [], includedTools: [] } };
|
||||
});
|
||||
|
||||
it('should use filterUniquePlugins to remove duplicate plugins', async () => {
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue(mockPlugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(filterUniquePlugins).toHaveBeenCalled();
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
// The response includes authenticated: true for each plugin when checkPluginAuth returns true
|
||||
expect(mockRes.json).toHaveBeenCalledWith([
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First', authenticated: true },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second', authenticated: true },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify plugin authentication', async () => {
|
||||
const mockPlugin = { name: 'Plugin1', pluginKey: 'key1', description: 'First' };
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue([mockPlugin]);
|
||||
checkPluginAuth.mockReturnValueOnce(true);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(checkPluginAuth).toHaveBeenCalledWith(mockPlugin);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData[0].authenticated).toBe(true);
|
||||
});
|
||||
|
||||
it('should return cached plugins when available', async () => {
|
||||
const cachedPlugins = [
|
||||
{ name: 'CachedPlugin', pluginKey: 'cached', description: 'Cached plugin' },
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(cachedPlugins);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(filterUniquePlugins).not.toHaveBeenCalled();
|
||||
expect(checkPluginAuth).not.toHaveBeenCalled();
|
||||
expect(mockRes.json).toHaveBeenCalledWith(cachedPlugins);
|
||||
});
|
||||
|
||||
it('should filter plugins based on includedTools', async () => {
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
];
|
||||
|
||||
mockReq.app.locals.includedTools = ['key1'];
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue(mockPlugins);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData).toHaveLength(1);
|
||||
expect(responseData[0].pluginKey).toBe('key1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAvailableTools', () => {
|
||||
it('should use convertMCPToolsToPlugins for user-specific MCP tools', async () => {
|
||||
const mockUserTools = {
|
||||
[`tool1${Constants.mcp_delimiter}server1`]: {
|
||||
function: { name: 'tool1', description: 'Tool 1' },
|
||||
},
|
||||
};
|
||||
const mockConvertedPlugins = [
|
||||
{
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}server1`,
|
||||
description: 'Tool 1',
|
||||
},
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue(mockConvertedPlugins);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: mockUserTools,
|
||||
customConfig: null,
|
||||
});
|
||||
});
|
||||
|
||||
it('should use filterUniquePlugins to deduplicate combined tools', async () => {
|
||||
const mockUserPlugins = [
|
||||
{ name: 'UserTool', pluginKey: 'user-tool', description: 'User tool' },
|
||||
];
|
||||
const mockManifestPlugins = [
|
||||
{ name: 'ManifestTool', pluginKey: 'manifest-tool', description: 'Manifest tool' },
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(mockManifestPlugins);
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
convertMCPToolsToPlugins.mockReturnValue(mockUserPlugins);
|
||||
filterUniquePlugins.mockReturnValue([...mockUserPlugins, ...mockManifestPlugins]);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should be called to deduplicate the combined array
|
||||
expect(filterUniquePlugins).toHaveBeenLastCalledWith([
|
||||
...mockUserPlugins,
|
||||
...mockManifestPlugins,
|
||||
]);
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify authentication status', async () => {
|
||||
const mockPlugin = { name: 'Tool1', pluginKey: 'tool1', description: 'Tool 1' };
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockPlugin]);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock getCachedTools second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({ tool1: true });
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(checkPluginAuth).toHaveBeenCalledWith(mockPlugin);
|
||||
});
|
||||
|
||||
it('should use getToolkitKey for toolkit validation', async () => {
|
||||
const mockToolkit = {
|
||||
name: 'Toolkit1',
|
||||
pluginKey: 'toolkit1',
|
||||
description: 'Toolkit 1',
|
||||
toolkit: true,
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockToolkit]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
getToolkitKey.mockReturnValue('toolkit1');
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock getCachedTools second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({
|
||||
toolkit1_function: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(getToolkitKey).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('plugin.icon behavior', () => {
|
||||
let mockReq, mockRes, mockCache;
|
||||
|
||||
const callGetAvailableToolsWithMCPServer = async (mcpServers) => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue({ mcpServers });
|
||||
@@ -242,22 +50,7 @@ describe('PluginController', () => {
|
||||
function: { name: 'test-tool', description: 'A test tool' },
|
||||
},
|
||||
};
|
||||
|
||||
const mockConvertedPlugin = {
|
||||
name: 'test-tool',
|
||||
pluginKey: `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
description: 'A test tool',
|
||||
icon: mcpServers['test-server']?.iconPath,
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
};
|
||||
|
||||
getCachedTools.mockResolvedValueOnce(functionTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue([mockConvertedPlugin]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
getToolkitKey.mockReturnValue(undefined);
|
||||
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
@@ -267,6 +60,14 @@ describe('PluginController', () => {
|
||||
return responseData.find((tool) => tool.name === 'test-tool');
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockReq = { user: { id: 'test-user-id' } };
|
||||
mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() };
|
||||
mockCache = { get: jest.fn(), set: jest.fn() };
|
||||
getLogStores.mockReturnValue(mockCache);
|
||||
});
|
||||
|
||||
it('should set plugin.icon when iconPath is defined', async () => {
|
||||
const mcpServers = {
|
||||
'test-server': {
|
||||
@@ -285,236 +86,4 @@ describe('PluginController', () => {
|
||||
expect(testTool.icon).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('helper function integration', () => {
|
||||
it('should properly handle MCP tools with custom user variables', async () => {
|
||||
const customConfig = {
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: { title: 'API Key', description: 'Your API key' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// We need to test the actual flow where MCP manager tools are included
|
||||
const mcpManagerTools = [
|
||||
{
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
authenticated: true,
|
||||
},
|
||||
];
|
||||
|
||||
// Mock the MCP manager to return tools
|
||||
const mockMCPManager = {
|
||||
loadManifestTools: jest.fn().mockResolvedValue(mcpManagerTools),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue(customConfig);
|
||||
|
||||
// First call returns user tools (empty in this case)
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Mock convertMCPToolsToPlugins to return empty array for user tools
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
|
||||
// Mock filterUniquePlugins to pass through
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
|
||||
// Mock checkPluginAuth
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
// Second call returns tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Find the MCP tool in the response
|
||||
const mcpTool = responseData.find(
|
||||
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}test-server`,
|
||||
);
|
||||
|
||||
// The actual implementation adds authConfig and sets authenticated to false when customUserVars exist
|
||||
expect(mcpTool).toBeDefined();
|
||||
expect(mcpTool.authConfig).toEqual([
|
||||
{ authField: 'API_KEY', label: 'API Key', description: 'Your API key' },
|
||||
]);
|
||||
expect(mcpTool.authenticated).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle error cases gracefully', async () => {
|
||||
mockCache.get.mockRejectedValue(new Error('Cache error'));
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({ message: 'Cache error' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases with undefined/null values', () => {
|
||||
it('should handle undefined cache gracefully', async () => {
|
||||
getLogStores.mockReturnValue(undefined);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
});
|
||||
|
||||
it('should handle null cachedTools and cachedUserTools', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
convertMCPToolsToPlugins.mockReturnValue(undefined);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: null,
|
||||
customConfig: null,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle when getCachedTools returns undefined', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue(undefined);
|
||||
convertMCPToolsToPlugins.mockReturnValue(undefined);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
// Mock getCachedTools to return undefined for both calls
|
||||
getCachedTools.mockReset();
|
||||
getCachedTools.mockResolvedValueOnce(undefined).mockResolvedValueOnce(undefined);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: undefined,
|
||||
customConfig: null,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle cachedToolsArray and userPlugins both being defined', async () => {
|
||||
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
|
||||
const userTools = {
|
||||
'user-tool': { function: { name: 'user-tool', description: 'User tool' } },
|
||||
};
|
||||
const userPlugins = [{ name: 'UserTool', pluginKey: 'user-tool', description: 'User tool' }];
|
||||
|
||||
mockCache.get.mockResolvedValue(cachedTools);
|
||||
getCachedTools.mockResolvedValue(userTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue(userPlugins);
|
||||
filterUniquePlugins.mockReturnValue([...userPlugins, ...cachedTools]);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([...userPlugins, ...cachedTools]);
|
||||
});
|
||||
|
||||
it('should handle empty toolDefinitions object', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// With empty tool definitions, no tools should be in the final output
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle MCP tools without customUserVars', async () => {
|
||||
const customConfig = {
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
// No customUserVars defined
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockUserTools = {
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: {
|
||||
function: { name: 'tool1', description: 'Tool 1' },
|
||||
},
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue(customConfig);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
const mockPlugin = {
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
};
|
||||
|
||||
convertMCPToolsToPlugins.mockReturnValue([mockPlugin]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData[0].authenticated).toBe(true);
|
||||
// The actual implementation doesn't set authConfig on tools without customUserVars
|
||||
expect(responseData[0].authConfig).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle req.app.locals with undefined filteredTools and includedTools', async () => {
|
||||
mockReq.app = { locals: {} };
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue([]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle toolkit with undefined toolDefinitions keys', async () => {
|
||||
const mockToolkit = {
|
||||
name: 'Toolkit1',
|
||||
pluginKey: 'toolkit1',
|
||||
description: 'Toolkit 1',
|
||||
toolkit: true,
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockToolkit]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
getToolkitKey.mockReturnValue(undefined);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock getCachedTools second call to return null
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce(null);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle null toolDefinitions gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { webSearchKeys, extractWebSearchEnvVars, normalizeHttpError } = require('@librechat/api');
|
||||
const { webSearchKeys, extractWebSearchEnvVars } = require('@librechat/api');
|
||||
const {
|
||||
getFiles,
|
||||
updateUser,
|
||||
@@ -89,8 +89,8 @@ const updateUserPluginsController = async (req, res) => {
|
||||
|
||||
if (userPluginsService instanceof Error) {
|
||||
logger.error('[userPluginsService]', userPluginsService);
|
||||
const { status, message } = normalizeHttpError(userPluginsService);
|
||||
return res.status(status).send({ message });
|
||||
const { status, message } = userPluginsService;
|
||||
res.status(status).send({ message });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
authService = await updateUserPluginAuth(user.id, keys[i], pluginKey, values[i]);
|
||||
if (authService instanceof Error) {
|
||||
logger.error('[authService]', authService);
|
||||
({ status, message } = normalizeHttpError(authService));
|
||||
({ status, message } = authService);
|
||||
}
|
||||
}
|
||||
} else if (action === 'uninstall') {
|
||||
@@ -151,7 +151,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
`[authService] Error deleting all auth for MCP tool ${pluginKey}:`,
|
||||
authService,
|
||||
);
|
||||
({ status, message } = normalizeHttpError(authService));
|
||||
({ status, message } = authService);
|
||||
}
|
||||
} else {
|
||||
// This handles:
|
||||
@@ -163,7 +163,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
authService = await deleteUserPluginAuth(user.id, keys[i]); // Deletes by authField name
|
||||
if (authService instanceof Error) {
|
||||
logger.error('[authService] Error deleting specific auth key:', authService);
|
||||
({ status, message } = normalizeHttpError(authService));
|
||||
({ status, message } = authService);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -193,8 +193,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
return res.status(status).send();
|
||||
}
|
||||
|
||||
const normalized = normalizeHttpError({ status, message });
|
||||
return res.status(normalized.status).send({ message: normalized.message });
|
||||
res.status(status).send({ message });
|
||||
} catch (err) {
|
||||
logger.error('[updateUserPluginsController]', err);
|
||||
return res.status(500).json({ message: 'Something went wrong.' });
|
||||
|
||||
@@ -402,34 +402,6 @@ class AgentClient extends BaseClient {
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a promise that resolves with the memory promise result or undefined after a timeout
|
||||
* @param {Promise<(TAttachment | null)[] | undefined>} memoryPromise - The memory promise to await
|
||||
* @param {number} timeoutMs - Timeout in milliseconds (default: 3000)
|
||||
* @returns {Promise<(TAttachment | null)[] | undefined>}
|
||||
*/
|
||||
async awaitMemoryWithTimeout(memoryPromise, timeoutMs = 3000) {
|
||||
if (!memoryPromise) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const timeoutPromise = new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Memory processing timeout')), timeoutMs),
|
||||
);
|
||||
|
||||
const attachments = await Promise.race([memoryPromise, timeoutPromise]);
|
||||
return attachments;
|
||||
} catch (error) {
|
||||
if (error.message === 'Memory processing timeout') {
|
||||
logger.warn('[AgentClient] Memory processing timed out after 3 seconds');
|
||||
} else {
|
||||
logger.error('[AgentClient] Error processing memory:', error);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {Promise<string | undefined>}
|
||||
*/
|
||||
@@ -540,39 +512,6 @@ class AgentClient extends BaseClient {
|
||||
return withoutKeys;
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters out image URLs from message content
|
||||
* @param {BaseMessage} message - The message to filter
|
||||
* @returns {BaseMessage} - A new message with image URLs removed
|
||||
*/
|
||||
filterImageUrls(message) {
|
||||
if (!message.content || typeof message.content === 'string') {
|
||||
return message;
|
||||
}
|
||||
|
||||
if (Array.isArray(message.content)) {
|
||||
const filteredContent = message.content.filter(
|
||||
(part) => part.type !== ContentTypes.IMAGE_URL,
|
||||
);
|
||||
|
||||
if (filteredContent.length === 1 && filteredContent[0].type === ContentTypes.TEXT) {
|
||||
const MessageClass = message.constructor;
|
||||
return new MessageClass({
|
||||
content: filteredContent[0].text,
|
||||
additional_kwargs: message.additional_kwargs,
|
||||
});
|
||||
}
|
||||
|
||||
const MessageClass = message.constructor;
|
||||
return new MessageClass({
|
||||
content: filteredContent,
|
||||
additional_kwargs: message.additional_kwargs,
|
||||
});
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {BaseMessage[]} messages
|
||||
* @returns {Promise<void | (TAttachment | null)[]>}
|
||||
@@ -601,8 +540,7 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
}
|
||||
|
||||
const filteredMessages = messagesToProcess.map((msg) => this.filterImageUrls(msg));
|
||||
const bufferString = getBufferString(filteredMessages);
|
||||
const bufferString = getBufferString(messagesToProcess);
|
||||
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
|
||||
return await this.processMemory([bufferMessage]);
|
||||
} catch (error) {
|
||||
@@ -1030,9 +968,11 @@ class AgentClient extends BaseClient {
|
||||
});
|
||||
|
||||
try {
|
||||
const attachments = await this.awaitMemoryWithTimeout(memoryPromise);
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
if (memoryPromise) {
|
||||
const attachments = await memoryPromise;
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
}
|
||||
await this.recordCollectedUsage({ context: 'message' });
|
||||
} catch (err) {
|
||||
@@ -1042,9 +982,11 @@ class AgentClient extends BaseClient {
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
const attachments = await this.awaitMemoryWithTimeout(memoryPromise);
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
if (memoryPromise) {
|
||||
const attachments = await memoryPromise;
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
}
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
|
||||
@@ -1146,16 +1088,11 @@ class AgentClient extends BaseClient {
|
||||
clientOptions.configuration = options.configOptions;
|
||||
}
|
||||
|
||||
const shouldRemoveMaxTokens = /\b(o\d|gpt-[5-9])\b/i.test(clientOptions.model);
|
||||
if (shouldRemoveMaxTokens && clientOptions.maxTokens != null) {
|
||||
delete clientOptions.maxTokens;
|
||||
} else if (!shouldRemoveMaxTokens && !clientOptions.maxTokens) {
|
||||
// Ensure maxTokens is set for non-o1 models
|
||||
if (!/\b(o\d)\b/i.test(clientOptions.model) && !clientOptions.maxTokens) {
|
||||
clientOptions.maxTokens = 75;
|
||||
}
|
||||
if (shouldRemoveMaxTokens && clientOptions?.modelKwargs?.max_completion_tokens != null) {
|
||||
delete clientOptions.modelKwargs.max_completion_tokens;
|
||||
} else if (shouldRemoveMaxTokens && clientOptions?.modelKwargs?.max_output_tokens != null) {
|
||||
delete clientOptions.modelKwargs.max_output_tokens;
|
||||
} else if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||
delete clientOptions.maxTokens;
|
||||
}
|
||||
|
||||
clientOptions = Object.assign(
|
||||
|
||||
@@ -727,464 +727,4 @@ describe('AgentClient - titleConvo', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getOptions method - GPT-5+ model handling', () => {
|
||||
let mockReq;
|
||||
let mockRes;
|
||||
let mockAgent;
|
||||
let mockOptions;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
mockAgent = {
|
||||
id: 'agent-123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
provider: EModelEndpoint.openAI,
|
||||
model_parameters: {
|
||||
model: 'gpt-5',
|
||||
},
|
||||
};
|
||||
|
||||
mockReq = {
|
||||
app: {
|
||||
locals: {},
|
||||
},
|
||||
user: {
|
||||
id: 'user-123',
|
||||
},
|
||||
};
|
||||
|
||||
mockRes = {};
|
||||
|
||||
mockOptions = {
|
||||
req: mockReq,
|
||||
res: mockRes,
|
||||
agent: mockAgent,
|
||||
};
|
||||
|
||||
client = new AgentClient(mockOptions);
|
||||
});
|
||||
|
||||
it('should move maxTokens to modelKwargs.max_completion_tokens for GPT-5 models', () => {
|
||||
const clientOptions = {
|
||||
model: 'gpt-5',
|
||||
maxTokens: 2048,
|
||||
temperature: 0.7,
|
||||
};
|
||||
|
||||
// Simulate the getOptions logic that handles GPT-5+ models
|
||||
if (/\bgpt-[5-9]\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||
clientOptions.modelKwargs = clientOptions.modelKwargs ?? {};
|
||||
clientOptions.modelKwargs.max_completion_tokens = clientOptions.maxTokens;
|
||||
delete clientOptions.maxTokens;
|
||||
}
|
||||
|
||||
expect(clientOptions.maxTokens).toBeUndefined();
|
||||
expect(clientOptions.modelKwargs).toBeDefined();
|
||||
expect(clientOptions.modelKwargs.max_completion_tokens).toBe(2048);
|
||||
expect(clientOptions.temperature).toBe(0.7); // Other options should remain
|
||||
});
|
||||
|
||||
it('should move maxTokens to modelKwargs.max_output_tokens for GPT-5 models with useResponsesApi', () => {
|
||||
const clientOptions = {
|
||||
model: 'gpt-5',
|
||||
maxTokens: 2048,
|
||||
temperature: 0.7,
|
||||
useResponsesApi: true,
|
||||
};
|
||||
|
||||
if (/\bgpt-[5-9]\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||
clientOptions.modelKwargs = clientOptions.modelKwargs ?? {};
|
||||
const paramName =
|
||||
clientOptions.useResponsesApi === true ? 'max_output_tokens' : 'max_completion_tokens';
|
||||
clientOptions.modelKwargs[paramName] = clientOptions.maxTokens;
|
||||
delete clientOptions.maxTokens;
|
||||
}
|
||||
|
||||
expect(clientOptions.maxTokens).toBeUndefined();
|
||||
expect(clientOptions.modelKwargs).toBeDefined();
|
||||
expect(clientOptions.modelKwargs.max_output_tokens).toBe(2048);
|
||||
expect(clientOptions.temperature).toBe(0.7); // Other options should remain
|
||||
});
|
||||
|
||||
it('should handle GPT-5+ models with existing modelKwargs', () => {
|
||||
const clientOptions = {
|
||||
model: 'gpt-6',
|
||||
maxTokens: 1500,
|
||||
temperature: 0.8,
|
||||
modelKwargs: {
|
||||
customParam: 'value',
|
||||
},
|
||||
};
|
||||
|
||||
// Simulate the getOptions logic
|
||||
if (/\bgpt-[5-9]\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||
clientOptions.modelKwargs = clientOptions.modelKwargs ?? {};
|
||||
clientOptions.modelKwargs.max_completion_tokens = clientOptions.maxTokens;
|
||||
delete clientOptions.maxTokens;
|
||||
}
|
||||
|
||||
expect(clientOptions.maxTokens).toBeUndefined();
|
||||
expect(clientOptions.modelKwargs).toEqual({
|
||||
customParam: 'value',
|
||||
max_completion_tokens: 1500,
|
||||
});
|
||||
});
|
||||
|
||||
it('should not modify maxTokens for non-GPT-5+ models', () => {
|
||||
const clientOptions = {
|
||||
model: 'gpt-4',
|
||||
maxTokens: 2048,
|
||||
temperature: 0.7,
|
||||
};
|
||||
|
||||
// Simulate the getOptions logic
|
||||
if (/\bgpt-[5-9]\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||
clientOptions.modelKwargs = clientOptions.modelKwargs ?? {};
|
||||
clientOptions.modelKwargs.max_completion_tokens = clientOptions.maxTokens;
|
||||
delete clientOptions.maxTokens;
|
||||
}
|
||||
|
||||
// Should not be modified since it's GPT-4
|
||||
expect(clientOptions.maxTokens).toBe(2048);
|
||||
expect(clientOptions.modelKwargs).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle various GPT-5+ model formats', () => {
|
||||
const testCases = [
|
||||
{ model: 'gpt-5', shouldTransform: true },
|
||||
{ model: 'gpt-5-turbo', shouldTransform: true },
|
||||
{ model: 'gpt-6', shouldTransform: true },
|
||||
{ model: 'gpt-7-preview', shouldTransform: true },
|
||||
{ model: 'gpt-8', shouldTransform: true },
|
||||
{ model: 'gpt-9-mini', shouldTransform: true },
|
||||
{ model: 'gpt-4', shouldTransform: false },
|
||||
{ model: 'gpt-4o', shouldTransform: false },
|
||||
{ model: 'gpt-3.5-turbo', shouldTransform: false },
|
||||
{ model: 'claude-3', shouldTransform: false },
|
||||
];
|
||||
|
||||
testCases.forEach(({ model, shouldTransform }) => {
|
||||
const clientOptions = {
|
||||
model,
|
||||
maxTokens: 1000,
|
||||
};
|
||||
|
||||
// Simulate the getOptions logic
|
||||
if (/\bgpt-[5-9]\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||
clientOptions.modelKwargs = clientOptions.modelKwargs ?? {};
|
||||
clientOptions.modelKwargs.max_completion_tokens = clientOptions.maxTokens;
|
||||
delete clientOptions.maxTokens;
|
||||
}
|
||||
|
||||
if (shouldTransform) {
|
||||
expect(clientOptions.maxTokens).toBeUndefined();
|
||||
expect(clientOptions.modelKwargs?.max_completion_tokens).toBe(1000);
|
||||
} else {
|
||||
expect(clientOptions.maxTokens).toBe(1000);
|
||||
expect(clientOptions.modelKwargs).toBeUndefined();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('should not swap max token param for older models when using useResponsesApi', () => {
|
||||
const testCases = [
|
||||
{ model: 'gpt-5', shouldTransform: true },
|
||||
{ model: 'gpt-5-turbo', shouldTransform: true },
|
||||
{ model: 'gpt-6', shouldTransform: true },
|
||||
{ model: 'gpt-7-preview', shouldTransform: true },
|
||||
{ model: 'gpt-8', shouldTransform: true },
|
||||
{ model: 'gpt-9-mini', shouldTransform: true },
|
||||
{ model: 'gpt-4', shouldTransform: false },
|
||||
{ model: 'gpt-4o', shouldTransform: false },
|
||||
{ model: 'gpt-3.5-turbo', shouldTransform: false },
|
||||
{ model: 'claude-3', shouldTransform: false },
|
||||
];
|
||||
|
||||
testCases.forEach(({ model, shouldTransform }) => {
|
||||
const clientOptions = {
|
||||
model,
|
||||
maxTokens: 1000,
|
||||
useResponsesApi: true,
|
||||
};
|
||||
|
||||
if (/\bgpt-[5-9]\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||
clientOptions.modelKwargs = clientOptions.modelKwargs ?? {};
|
||||
const paramName =
|
||||
clientOptions.useResponsesApi === true ? 'max_output_tokens' : 'max_completion_tokens';
|
||||
clientOptions.modelKwargs[paramName] = clientOptions.maxTokens;
|
||||
delete clientOptions.maxTokens;
|
||||
}
|
||||
|
||||
if (shouldTransform) {
|
||||
expect(clientOptions.maxTokens).toBeUndefined();
|
||||
expect(clientOptions.modelKwargs?.max_output_tokens).toBe(1000);
|
||||
} else {
|
||||
expect(clientOptions.maxTokens).toBe(1000);
|
||||
expect(clientOptions.modelKwargs).toBeUndefined();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('should not transform if maxTokens is null or undefined', () => {
|
||||
const testCases = [
|
||||
{ model: 'gpt-5', maxTokens: null },
|
||||
{ model: 'gpt-5', maxTokens: undefined },
|
||||
{ model: 'gpt-6', maxTokens: 0 }, // Should transform even if 0
|
||||
];
|
||||
|
||||
testCases.forEach(({ model, maxTokens }, index) => {
|
||||
const clientOptions = {
|
||||
model,
|
||||
maxTokens,
|
||||
temperature: 0.7,
|
||||
};
|
||||
|
||||
// Simulate the getOptions logic
|
||||
if (/\bgpt-[5-9]\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||
clientOptions.modelKwargs = clientOptions.modelKwargs ?? {};
|
||||
clientOptions.modelKwargs.max_completion_tokens = clientOptions.maxTokens;
|
||||
delete clientOptions.maxTokens;
|
||||
}
|
||||
|
||||
if (index < 2) {
|
||||
// null or undefined cases
|
||||
expect(clientOptions.maxTokens).toBe(maxTokens);
|
||||
expect(clientOptions.modelKwargs).toBeUndefined();
|
||||
} else {
|
||||
// 0 case - should transform
|
||||
expect(clientOptions.maxTokens).toBeUndefined();
|
||||
expect(clientOptions.modelKwargs?.max_completion_tokens).toBe(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('runMemory method', () => {
|
||||
let client;
|
||||
let mockReq;
|
||||
let mockRes;
|
||||
let mockAgent;
|
||||
let mockOptions;
|
||||
let mockProcessMemory;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
mockAgent = {
|
||||
id: 'agent-123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
provider: EModelEndpoint.openAI,
|
||||
model_parameters: {
|
||||
model: 'gpt-4',
|
||||
},
|
||||
};
|
||||
|
||||
mockReq = {
|
||||
app: {
|
||||
locals: {
|
||||
memory: {
|
||||
messageWindowSize: 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
user: {
|
||||
id: 'user-123',
|
||||
personalization: {
|
||||
memories: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockRes = {};
|
||||
|
||||
mockOptions = {
|
||||
req: mockReq,
|
||||
res: mockRes,
|
||||
agent: mockAgent,
|
||||
};
|
||||
|
||||
mockProcessMemory = jest.fn().mockResolvedValue([]);
|
||||
|
||||
client = new AgentClient(mockOptions);
|
||||
client.processMemory = mockProcessMemory;
|
||||
client.conversationId = 'convo-123';
|
||||
client.responseMessageId = 'response-123';
|
||||
});
|
||||
|
||||
it('should filter out image URLs from message content', async () => {
|
||||
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'What is in this image?',
|
||||
},
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==',
|
||||
detail: 'auto',
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
new AIMessage('I can see a small red pixel in the image.'),
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'What about this one?',
|
||||
},
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: 'data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD/',
|
||||
detail: 'high',
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
await client.runMemory(messages);
|
||||
|
||||
expect(mockProcessMemory).toHaveBeenCalledTimes(1);
|
||||
const processedMessage = mockProcessMemory.mock.calls[0][0][0];
|
||||
|
||||
// Verify the buffer message was created
|
||||
expect(processedMessage.constructor.name).toBe('HumanMessage');
|
||||
expect(processedMessage.content).toContain('# Current Chat:');
|
||||
|
||||
// Verify that image URLs are not in the buffer string
|
||||
expect(processedMessage.content).not.toContain('image_url');
|
||||
expect(processedMessage.content).not.toContain('data:image');
|
||||
expect(processedMessage.content).not.toContain('base64');
|
||||
|
||||
// Verify text content is preserved
|
||||
expect(processedMessage.content).toContain('What is in this image?');
|
||||
expect(processedMessage.content).toContain('I can see a small red pixel in the image.');
|
||||
expect(processedMessage.content).toContain('What about this one?');
|
||||
});
|
||||
|
||||
it('should handle messages with only text content', async () => {
|
||||
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
|
||||
const messages = [
|
||||
new HumanMessage('Hello, how are you?'),
|
||||
new AIMessage('I am doing well, thank you!'),
|
||||
new HumanMessage('That is great to hear.'),
|
||||
];
|
||||
|
||||
await client.runMemory(messages);
|
||||
|
||||
expect(mockProcessMemory).toHaveBeenCalledTimes(1);
|
||||
const processedMessage = mockProcessMemory.mock.calls[0][0][0];
|
||||
|
||||
expect(processedMessage.content).toContain('Hello, how are you?');
|
||||
expect(processedMessage.content).toContain('I am doing well, thank you!');
|
||||
expect(processedMessage.content).toContain('That is great to hear.');
|
||||
});
|
||||
|
||||
it('should handle mixed content types correctly', async () => {
|
||||
const { HumanMessage } = require('@langchain/core/messages');
|
||||
const { ContentTypes } = require('librechat-data-provider');
|
||||
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Here is some text',
|
||||
},
|
||||
{
|
||||
type: ContentTypes.IMAGE_URL,
|
||||
image_url: {
|
||||
url: 'https://example.com/image.png',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: ' and more text',
|
||||
},
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
await client.runMemory(messages);
|
||||
|
||||
expect(mockProcessMemory).toHaveBeenCalledTimes(1);
|
||||
const processedMessage = mockProcessMemory.mock.calls[0][0][0];
|
||||
|
||||
// Should contain text parts but not image URLs
|
||||
expect(processedMessage.content).toContain('Here is some text');
|
||||
expect(processedMessage.content).toContain('and more text');
|
||||
expect(processedMessage.content).not.toContain('example.com/image.png');
|
||||
expect(processedMessage.content).not.toContain('IMAGE_URL');
|
||||
});
|
||||
|
||||
it('should preserve original messages without mutation', async () => {
|
||||
const { HumanMessage } = require('@langchain/core/messages');
|
||||
const originalContent = [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Original text',
|
||||
},
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: 'data:image/png;base64,ABC123',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [...originalContent],
|
||||
}),
|
||||
];
|
||||
|
||||
await client.runMemory(messages);
|
||||
|
||||
// Verify original message wasn't mutated
|
||||
expect(messages[0].content).toHaveLength(2);
|
||||
expect(messages[0].content[1].type).toBe('image_url');
|
||||
expect(messages[0].content[1].image_url.url).toBe('data:image/png;base64,ABC123');
|
||||
});
|
||||
|
||||
it('should handle message window size correctly', async () => {
|
||||
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
|
||||
const messages = [
|
||||
new HumanMessage('Message 1'),
|
||||
new AIMessage('Response 1'),
|
||||
new HumanMessage('Message 2'),
|
||||
new AIMessage('Response 2'),
|
||||
new HumanMessage('Message 3'),
|
||||
new AIMessage('Response 3'),
|
||||
];
|
||||
|
||||
// Window size is set to 3 in mockReq
|
||||
await client.runMemory(messages);
|
||||
|
||||
expect(mockProcessMemory).toHaveBeenCalledTimes(1);
|
||||
const processedMessage = mockProcessMemory.mock.calls[0][0][0];
|
||||
|
||||
// Should only include last 3 messages due to window size
|
||||
expect(processedMessage.content).toContain('Message 3');
|
||||
expect(processedMessage.content).toContain('Response 3');
|
||||
expect(processedMessage.content).not.toContain('Message 1');
|
||||
expect(processedMessage.content).not.toContain('Response 1');
|
||||
});
|
||||
|
||||
it('should return early if processMemory is not set', async () => {
|
||||
const { HumanMessage } = require('@langchain/core/messages');
|
||||
client.processMemory = null;
|
||||
|
||||
const result = await client.runMemory([new HumanMessage('Test')]);
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
expect(mockProcessMemory).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -105,6 +105,8 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
return res.end();
|
||||
}
|
||||
await cache.delete(cacheKey);
|
||||
// const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
// logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error cancelling run`, error);
|
||||
}
|
||||
@@ -113,6 +115,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
|
||||
let run;
|
||||
try {
|
||||
// run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
@@ -125,9 +128,18 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
|
||||
let finalEvent;
|
||||
try {
|
||||
// const errorContentPart = {
|
||||
// text: {
|
||||
// value:
|
||||
// error?.message ?? 'There was an error processing your request. Please try again later.',
|
||||
// },
|
||||
// type: ContentTypes.ERROR,
|
||||
// };
|
||||
|
||||
finalEvent = {
|
||||
final: true,
|
||||
conversation: await getConvo(req.user.id, conversationId),
|
||||
// runMessages,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error finalizing error process`, error);
|
||||
|
||||
@@ -233,26 +233,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
);
|
||||
}
|
||||
}
|
||||
// Edge case: sendMessage completed but abort happened during sendCompletion
|
||||
// We need to ensure a final event is sent
|
||||
else if (!res.headersSent && !res.finished) {
|
||||
logger.debug(
|
||||
'[AgentController] Handling edge case: `sendMessage` completed but aborted during `sendCompletion`',
|
||||
);
|
||||
|
||||
const finalResponse = { ...response };
|
||||
finalResponse.error = true;
|
||||
|
||||
sendEvent(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: finalResponse,
|
||||
error: { message: 'Request was aborted during completion' },
|
||||
});
|
||||
res.end();
|
||||
}
|
||||
|
||||
// Save user message if needed
|
||||
if (!client.skipSaveUserMessage) {
|
||||
|
||||
@@ -194,9 +194,6 @@ const updateAgentHandler = async (req, res) => {
|
||||
});
|
||||
}
|
||||
|
||||
// Add version count to the response
|
||||
updatedAgent.version = updatedAgent.versions ? updatedAgent.versions.length : 0;
|
||||
|
||||
if (updatedAgent.author) {
|
||||
updatedAgent.author = updatedAgent.author.toString();
|
||||
}
|
||||
|
||||
@@ -498,28 +498,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
|
||||
});
|
||||
|
||||
test('should include version field in update response', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated with Version Check',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify version field is included and is a number
|
||||
expect(updatedAgent).toHaveProperty('version');
|
||||
expect(typeof updatedAgent.version).toBe('number');
|
||||
expect(updatedAgent.version).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(updatedAgent.version).toBe(agentInDb.versions.length);
|
||||
});
|
||||
|
||||
test('should handle validation errors properly', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
|
||||
@@ -152,7 +152,7 @@ const chatV1 = async (req, res) => {
|
||||
return res.end();
|
||||
}
|
||||
await cache.delete(cacheKey);
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(run_id, { thread_id });
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/chat/] Error cancelling run', error);
|
||||
@@ -162,7 +162,7 @@ const chatV1 = async (req, res) => {
|
||||
|
||||
let run;
|
||||
try {
|
||||
run = await openai.beta.threads.runs.retrieve(run_id, { thread_id });
|
||||
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
@@ -623,7 +623,7 @@ const chatV1 = async (req, res) => {
|
||||
|
||||
if (!response.run.usage) {
|
||||
await sleep(3000);
|
||||
completedRun = await openai.beta.threads.runs.retrieve(response.run.id, { thread_id });
|
||||
completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id);
|
||||
if (completedRun.usage) {
|
||||
await recordUsage({
|
||||
...completedRun.usage,
|
||||
|
||||
@@ -467,7 +467,7 @@ const chatV2 = async (req, res) => {
|
||||
|
||||
if (!response.run.usage) {
|
||||
await sleep(3000);
|
||||
completedRun = await openai.beta.threads.runs.retrieve(response.run.id, { thread_id });
|
||||
completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id);
|
||||
if (completedRun.usage) {
|
||||
await recordUsage({
|
||||
...completedRun.usage,
|
||||
|
||||
@@ -108,7 +108,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
return res.end();
|
||||
}
|
||||
await cache.delete(cacheKey);
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(run_id, { thread_id });
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error cancelling run`, error);
|
||||
@@ -118,7 +118,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
|
||||
let run;
|
||||
try {
|
||||
run = await openai.beta.threads.runs.retrieve(run_id, { thread_id });
|
||||
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
|
||||
@@ -173,16 +173,6 @@ const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, que
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Initializes the OpenAI client.
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {ServerRequest} params.req - The request object.
|
||||
* @param {ServerResponse} params.res - The response object.
|
||||
* @param {TEndpointOption} params.endpointOption - The endpoint options.
|
||||
* @param {boolean} params.initAppClient - Whether to initialize the app client.
|
||||
* @param {string} params.overrideEndpoint - The endpoint to override.
|
||||
* @returns {Promise<{ openai: OpenAIClient, openAIApiKey: string; client: import('~/app/clients/OpenAIClient') }>} - The initialized OpenAI client.
|
||||
*/
|
||||
async function getOpenAIClient({ req, res, endpointOption, initAppClient, overrideEndpoint }) {
|
||||
let endpoint = overrideEndpoint ?? req.body.endpoint ?? req.query.endpoint;
|
||||
const version = await getCurrentVersion(req, endpoint);
|
||||
|
||||
@@ -197,7 +197,7 @@ const deleteAssistant = async (req, res) => {
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const deletionStatus = await openai.beta.assistants.delete(assistant_id);
|
||||
const deletionStatus = await openai.beta.assistants.del(assistant_id);
|
||||
if (deletionStatus?.deleted) {
|
||||
await deleteAssistantActions({ req, assistant_id });
|
||||
}
|
||||
@@ -365,7 +365,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
||||
try {
|
||||
await fs.unlink(req.file.path);
|
||||
logger.debug('[/:agent_id/avatar] Temp. image upload file deleted');
|
||||
} catch {
|
||||
} catch (error) {
|
||||
logger.debug('[/:agent_id/avatar] Temp. image upload file already deleted');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ async function abortRun(req, res) {
|
||||
|
||||
try {
|
||||
await cache.set(cacheKey, 'cancelled', three_minutes);
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(run_id, { thread_id });
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
logger.debug('[abortRun] Cancelled run:', cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error('[abortRun] Error cancelling run', error);
|
||||
@@ -60,7 +60,7 @@ async function abortRun(req, res) {
|
||||
}
|
||||
|
||||
try {
|
||||
const run = await openai.beta.threads.runs.retrieve(run_id, { thread_id });
|
||||
const run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,10 @@
|
||||
const express = require('express');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getLdapConfig } = require('~/server/services/Config/ldap');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const router = express.Router();
|
||||
@@ -103,16 +102,11 @@ router.get('/', async function (req, res) {
|
||||
payload.mcpServers = {};
|
||||
const config = await getCustomConfig();
|
||||
if (config?.mcpServers != null) {
|
||||
const mcpManager = getMCPManager();
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
|
||||
for (const serverName in config.mcpServers) {
|
||||
const serverConfig = config.mcpServers[serverName];
|
||||
payload.mcpServers[serverName] = {
|
||||
customUserVars: serverConfig?.customUserVars || {},
|
||||
chatMenu: serverConfig?.chatMenu,
|
||||
isOAuth: oauthServers.has(serverName),
|
||||
startup: serverConfig?.startup,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ router.delete('/', async (req, res) => {
|
||||
/** @type {{ openai: OpenAI }} */
|
||||
const { openai } = await assistantClients[endpoint].initializeClient({ req, res });
|
||||
try {
|
||||
const response = await openai.beta.threads.delete(thread_id);
|
||||
const response = await openai.beta.threads.del(thread_id);
|
||||
logger.debug('Deleted OpenAI thread:', response);
|
||||
} catch (error) {
|
||||
logger.error('Error deleting OpenAI thread:', error);
|
||||
|
||||
@@ -413,15 +413,13 @@ router.post('/', async (req, res) => {
|
||||
logger.error('[/files] Error deleting file:', error);
|
||||
}
|
||||
res.status(500).json({ message });
|
||||
} finally {
|
||||
if (cleanup) {
|
||||
try {
|
||||
await fs.unlink(req.file.path);
|
||||
} catch (error) {
|
||||
logger.error('[/files] Error deleting file after file processing:', error);
|
||||
}
|
||||
} else {
|
||||
logger.debug('[/files] File processing completed without cleanup');
|
||||
}
|
||||
|
||||
if (cleanup) {
|
||||
try {
|
||||
await fs.unlink(req.file.path);
|
||||
} catch (error) {
|
||||
logger.error('[/files] Error deleting file after file processing:', error);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -4,7 +4,6 @@ const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
@@ -93,6 +92,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
return res.redirect('/oauth/error?error=missing_state');
|
||||
}
|
||||
|
||||
// Extract flow ID from state
|
||||
const flowId = state;
|
||||
logger.debug('[MCP OAuth] Using flow ID from state', { flowId });
|
||||
|
||||
@@ -115,17 +115,22 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
hasCodeVerifier: !!flowState.codeVerifier,
|
||||
});
|
||||
|
||||
// Complete the OAuth flow
|
||||
logger.debug('[MCP OAuth] Completing OAuth flow');
|
||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
|
||||
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
||||
|
||||
// Try to establish the MCP connection with the new tokens
|
||||
try {
|
||||
const mcpManager = getMCPManager(flowState.userId);
|
||||
logger.debug(`[MCP OAuth] Attempting to reconnect ${serverName} with new OAuth tokens`);
|
||||
|
||||
// For user-level OAuth, try to establish the connection
|
||||
if (flowState.userId !== 'system') {
|
||||
// We need to get the user object - in this case we'll need to reconstruct it
|
||||
const user = { id: flowState.userId };
|
||||
|
||||
// Try to establish connection with the new tokens
|
||||
const userConnection = await mcpManager.getUserConnection({
|
||||
user,
|
||||
serverName,
|
||||
@@ -142,8 +147,10 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
|
||||
);
|
||||
|
||||
// Fetch and cache tools now that we have a successful connection
|
||||
const userTools = (await getCachedTools({ userId: flowState.userId })) || {};
|
||||
|
||||
// Remove any old tools from this server in the user's cache
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
for (const key of Object.keys(userTools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
@@ -151,6 +158,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new tools from this server
|
||||
const tools = await userConnection.fetchTools();
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
@@ -164,6 +172,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
};
|
||||
}
|
||||
|
||||
// Save the updated user tool cache
|
||||
await setCachedTools(userTools, { userId: flowState.userId });
|
||||
|
||||
logger.debug(
|
||||
@@ -173,6 +182,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
logger.debug(`[MCP OAuth] System-level OAuth completed for ${serverName}`);
|
||||
}
|
||||
} catch (error) {
|
||||
// Don't fail the OAuth callback if reconnection fails - the tokens are still saved
|
||||
logger.warn(
|
||||
`[MCP OAuth] Failed to reconnect ${serverName} after OAuth, but tokens are saved:`,
|
||||
error,
|
||||
@@ -208,6 +218,7 @@ router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
// Allow system flows or user-owned flows
|
||||
if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) {
|
||||
return res.status(403).json({ error: 'Access denied' });
|
||||
}
|
||||
@@ -275,7 +286,11 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
// Generate the flow ID for this user/server combination
|
||||
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
|
||||
|
||||
// Check if flow exists
|
||||
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (!flowState) {
|
||||
@@ -286,7 +301,8 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
|
||||
});
|
||||
}
|
||||
|
||||
await flowManager.failFlow(flowId, 'mcp_oauth', 'User cancelled OAuth flow');
|
||||
// Cancel the flow by marking it as failed
|
||||
await flowManager.completeFlow(flowId, 'mcp_oauth', null, 'User cancelled OAuth flow');
|
||||
|
||||
logger.info(`[MCP OAuth Cancel] Successfully cancelled OAuth flow for ${serverName}`);
|
||||
|
||||
@@ -337,7 +353,9 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
for (const varName of Object.keys(serverConfig.customUserVars)) {
|
||||
try {
|
||||
const value = await getUserPluginAuthValue(user.id, varName, false);
|
||||
customUserVars[varName] = value;
|
||||
if (value) {
|
||||
customUserVars[varName] = value;
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(`[MCP Reinitialize] Error fetching ${varName} for user ${user.id}:`, err);
|
||||
}
|
||||
@@ -360,7 +378,8 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
createToken,
|
||||
deleteTokens,
|
||||
},
|
||||
returnOnOAuth: true,
|
||||
returnOnOAuth: true, // Return immediately when OAuth is initiated
|
||||
// Add OAuth handlers to capture the OAuth URL when needed
|
||||
oauthStart: async (authURL) => {
|
||||
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
|
||||
oauthUrl = authURL;
|
||||
@@ -375,6 +394,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
||||
);
|
||||
|
||||
// Check if this is an OAuth error - if so, the flow state should be set up now
|
||||
const isOAuthError =
|
||||
err.message?.includes('OAuth') ||
|
||||
err.message?.includes('authentication') ||
|
||||
@@ -387,6 +407,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
|
||||
);
|
||||
oauthRequired = true;
|
||||
// Don't return error - continue so frontend can handle OAuth
|
||||
} else {
|
||||
logger.error(
|
||||
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
|
||||
@@ -396,9 +417,11 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Only fetch and cache tools if we successfully connected (no OAuth required)
|
||||
if (userConnection && !oauthRequired) {
|
||||
const userTools = (await getCachedTools({ userId: user.id })) || {};
|
||||
|
||||
// Remove any old tools from this server in the user's cache
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
for (const key of Object.keys(userTools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
@@ -406,6 +429,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new tools from this server
|
||||
const tools = await userConnection.fetchTools();
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
@@ -419,6 +443,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
};
|
||||
}
|
||||
|
||||
// Save the updated user tool cache
|
||||
await setCachedTools(userTools, { userId: user.id });
|
||||
}
|
||||
|
||||
@@ -426,19 +451,11 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
||||
);
|
||||
|
||||
const getResponseMessage = () => {
|
||||
if (oauthRequired) {
|
||||
return `MCP server '${serverName}' ready for OAuth authentication`;
|
||||
}
|
||||
if (userConnection) {
|
||||
return `MCP server '${serverName}' reinitialized successfully`;
|
||||
}
|
||||
return `Failed to reinitialize MCP server '${serverName}'`;
|
||||
};
|
||||
|
||||
res.json({
|
||||
success: (userConnection && !oauthRequired) || (oauthRequired && oauthUrl),
|
||||
message: getResponseMessage(),
|
||||
success: true,
|
||||
message: oauthRequired
|
||||
? `MCP server '${serverName}' ready for OAuth authentication`
|
||||
: `MCP server '${serverName}' reinitialized successfully`,
|
||||
serverName,
|
||||
oauthRequired,
|
||||
oauthUrl,
|
||||
@@ -451,7 +468,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
|
||||
/**
|
||||
* Get connection status for all MCP servers
|
||||
* This endpoint returns all app level and user-scoped connection statuses from MCPManager without disconnecting idle connections
|
||||
* This endpoint returns the actual connection status from MCPManager without disconnecting idle connections
|
||||
*/
|
||||
router.get('/connection/status', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
@@ -461,19 +478,84 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
|
||||
user.id,
|
||||
);
|
||||
const mcpManager = getMCPManager(user.id);
|
||||
const connectionStatus = {};
|
||||
|
||||
const printConfig = false;
|
||||
const config = await loadCustomConfig(printConfig);
|
||||
const mcpConfig = config?.mcpServers;
|
||||
|
||||
const appConnections = mcpManager.getAllConnections() || new Map();
|
||||
const userConnections = mcpManager.getUserConnections(user.id) || new Map();
|
||||
const oauthServers = mcpManager.getOAuthServers() || new Set();
|
||||
|
||||
if (!mcpConfig) {
|
||||
return res.status(404).json({ error: 'MCP config not found' });
|
||||
}
|
||||
|
||||
// Get flow manager to check for active/timed-out OAuth flows
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
for (const [serverName] of Object.entries(mcpConfig)) {
|
||||
connectionStatus[serverName] = await getServerConnectionStatus(
|
||||
user.id,
|
||||
serverName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
const getConnectionState = (serverName) =>
|
||||
appConnections.get(serverName)?.connectionState ??
|
||||
userConnections.get(serverName)?.connectionState ??
|
||||
'disconnected';
|
||||
|
||||
const baseConnectionState = getConnectionState(serverName);
|
||||
|
||||
let hasActiveOAuthFlow = false;
|
||||
let hasFailedOAuthFlow = false;
|
||||
|
||||
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
|
||||
try {
|
||||
// Check for user-specific OAuth flows
|
||||
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
|
||||
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
if (flowState) {
|
||||
// Check if flow failed or timed out
|
||||
const flowAge = Date.now() - flowState.createdAt;
|
||||
const flowTTL = flowState.ttl || 180000; // Default 3 minutes
|
||||
|
||||
if (flowState.status === 'FAILED' || flowAge > flowTTL) {
|
||||
hasFailedOAuthFlow = true;
|
||||
logger.debug(`[MCP Connection Status] Found failed OAuth flow for ${serverName}`, {
|
||||
flowId,
|
||||
status: flowState.status,
|
||||
flowAge,
|
||||
flowTTL,
|
||||
timedOut: flowAge > flowTTL,
|
||||
});
|
||||
} else if (flowState.status === 'PENDING') {
|
||||
hasActiveOAuthFlow = true;
|
||||
logger.debug(`[MCP Connection Status] Found active OAuth flow for ${serverName}`, {
|
||||
flowId,
|
||||
flowAge,
|
||||
flowTTL,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[MCP Connection Status] Error checking OAuth flows for ${serverName}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the final connection state
|
||||
let finalConnectionState = baseConnectionState;
|
||||
if (hasFailedOAuthFlow) {
|
||||
finalConnectionState = 'error'; // Report as error if OAuth failed
|
||||
} else if (hasActiveOAuthFlow && baseConnectionState === 'disconnected') {
|
||||
finalConnectionState = 'connecting'; // Still waiting for OAuth
|
||||
}
|
||||
|
||||
connectionStatus[serverName] = {
|
||||
requiresOAuth: oauthServers.has(serverName),
|
||||
connectionState: finalConnectionState,
|
||||
};
|
||||
}
|
||||
|
||||
res.json({
|
||||
@@ -481,63 +563,11 @@ router.get('/connection/status', requireJwtAuth, async (req, res) => {
|
||||
connectionStatus,
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.message === 'MCP config not found') {
|
||||
return res.status(404).json({ error: error.message });
|
||||
}
|
||||
logger.error('[MCP Connection Status] Failed to get connection status', error);
|
||||
res.status(500).json({ error: 'Failed to get connection status' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Get connection status for a single MCP server
|
||||
* This endpoint returns the connection status for a specific server for a given user
|
||||
*/
|
||||
router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const user = req.user;
|
||||
const { serverName } = req.params;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
|
||||
user.id,
|
||||
);
|
||||
|
||||
if (!mcpConfig[serverName]) {
|
||||
return res
|
||||
.status(404)
|
||||
.json({ error: `MCP server '${serverName}' not found in configuration` });
|
||||
}
|
||||
|
||||
const serverStatus = await getServerConnectionStatus(
|
||||
user.id,
|
||||
serverName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
serverName,
|
||||
connectionStatus: serverStatus.connectionState,
|
||||
requiresOAuth: serverStatus.requiresOAuth,
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.message === 'MCP config not found') {
|
||||
return res.status(404).json({ error: error.message });
|
||||
}
|
||||
logger.error(
|
||||
`[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`,
|
||||
error,
|
||||
);
|
||||
res.status(500).json({ error: 'Failed to get connection status' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Check which authentication values exist for a specific MCP server
|
||||
* This endpoint returns only boolean flags indicating if values are set, not the actual values
|
||||
@@ -563,16 +593,19 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
|
||||
const pluginKey = `${Constants.mcp_prefix}${serverName}`;
|
||||
const authValueFlags = {};
|
||||
|
||||
// Check existence of saved values for each custom user variable (don't fetch actual values)
|
||||
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
||||
for (const varName of Object.keys(serverConfig.customUserVars)) {
|
||||
try {
|
||||
const value = await getUserPluginAuthValue(user.id, varName, false, pluginKey);
|
||||
// Only store boolean flag indicating if value exists
|
||||
authValueFlags[varName] = !!(value && value.length > 0);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[MCP Auth Value Flags] Error checking ${varName} for user ${user.id}:`,
|
||||
err,
|
||||
);
|
||||
// Default to false if we can't check
|
||||
authValueFlags[varName] = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,14 +60,7 @@ const replaceArtifactContent = (originalText, artifact, original, updated) => {
|
||||
|
||||
// Find boundaries between ARTIFACT_START and ARTIFACT_END
|
||||
const contentStart = artifactContent.indexOf('\n', artifactContent.indexOf(ARTIFACT_START)) + 1;
|
||||
let contentEnd = artifactContent.lastIndexOf(ARTIFACT_END);
|
||||
|
||||
// Special case: if contentEnd is 0, it means the only ::: found is at the start of :::artifact
|
||||
// This indicates an incomplete artifact (no closing :::)
|
||||
// We need to check that it's exactly at position 0 (the beginning of artifactContent)
|
||||
if (contentEnd === 0 && artifactContent.indexOf(ARTIFACT_START) === 0) {
|
||||
contentEnd = artifactContent.length;
|
||||
}
|
||||
const contentEnd = artifactContent.lastIndexOf(ARTIFACT_END);
|
||||
|
||||
if (contentStart === -1 || contentEnd === -1) {
|
||||
return null;
|
||||
@@ -79,20 +72,12 @@ const replaceArtifactContent = (originalText, artifact, original, updated) => {
|
||||
|
||||
// Determine where to look for the original content
|
||||
let searchStart, searchEnd;
|
||||
if (codeBlockStart !== -1) {
|
||||
// Code block starts
|
||||
if (codeBlockStart !== -1 && codeBlockEnd !== -1) {
|
||||
// If code blocks exist, search between them
|
||||
searchStart = codeBlockStart + 4; // after ```\n
|
||||
|
||||
if (codeBlockEnd !== -1 && codeBlockEnd > codeBlockStart) {
|
||||
// Code block has proper ending
|
||||
searchEnd = codeBlockEnd;
|
||||
} else {
|
||||
// No closing backticks found or they're before the opening (shouldn't happen)
|
||||
// This might be an incomplete artifact - search to contentEnd
|
||||
searchEnd = contentEnd;
|
||||
}
|
||||
searchEnd = codeBlockEnd;
|
||||
} else {
|
||||
// No code blocks at all
|
||||
// Otherwise search in the whole artifact content
|
||||
searchStart = contentStart;
|
||||
searchEnd = contentEnd;
|
||||
}
|
||||
|
||||
@@ -89,9 +89,9 @@ describe('replaceArtifactContent', () => {
|
||||
};
|
||||
|
||||
test('should replace content within artifact boundaries', () => {
|
||||
const original = "console.log('hello')";
|
||||
const original = 'console.log(\'hello\')';
|
||||
const artifact = createTestArtifact(original);
|
||||
const updated = "console.log('updated')";
|
||||
const updated = 'console.log(\'updated\')';
|
||||
|
||||
const result = replaceArtifactContent(artifact.text, artifact, original, updated);
|
||||
expect(result).toContain(updated);
|
||||
@@ -317,182 +317,4 @@ console.log(greeting);`;
|
||||
expect(result).not.toContain('\n\n```');
|
||||
expect(result).not.toContain('```\n\n');
|
||||
});
|
||||
|
||||
describe('incomplete artifacts', () => {
|
||||
test('should handle incomplete artifacts (missing closing ::: and ```)', () => {
|
||||
const original = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>Pomodoro</title>
|
||||
<meta name="description" content="A single-file Pomodoro timer with logs, charts, sounds, and dark mode." />
|
||||
<style>
|
||||
:root{`;
|
||||
|
||||
const prefix = `Awesome idea! I'll deliver a complete single-file HTML app called "Pomodoro" with:
|
||||
- Custom session/break durations
|
||||
|
||||
You can save this as pomodoro.html and open it directly in your browser.
|
||||
|
||||
`;
|
||||
|
||||
// This simulates the real incomplete artifact case - no closing ``` or :::
|
||||
const incompleteArtifact = `${ARTIFACT_START}{identifier="pomodoro-single-file-app" type="text/html" title="Pomodoro — Single File App"}
|
||||
\`\`\`
|
||||
${original}`;
|
||||
|
||||
const fullText = prefix + incompleteArtifact;
|
||||
const message = { text: fullText };
|
||||
const artifacts = findAllArtifacts(message);
|
||||
|
||||
expect(artifacts).toHaveLength(1);
|
||||
expect(artifacts[0].end).toBe(fullText.length);
|
||||
|
||||
const updated = original.replace('Pomodoro</title>', 'Pomodoro</title>UPDATED');
|
||||
const result = replaceArtifactContent(fullText, artifacts[0], original, updated);
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result).toContain('UPDATED');
|
||||
expect(result).toContain(prefix);
|
||||
// Should not have added closing markers
|
||||
expect(result).not.toMatch(/:::\s*$/);
|
||||
});
|
||||
|
||||
test('should handle incomplete artifacts with only opening code block', () => {
|
||||
const original = 'function hello() { console.log("world"); }';
|
||||
const incompleteArtifact = `${ARTIFACT_START}{id="test"}\n\`\`\`\n${original}`;
|
||||
|
||||
const message = { text: incompleteArtifact };
|
||||
const artifacts = findAllArtifacts(message);
|
||||
|
||||
expect(artifacts).toHaveLength(1);
|
||||
|
||||
const updated = 'function hello() { console.log("UPDATED"); }';
|
||||
const result = replaceArtifactContent(incompleteArtifact, artifacts[0], original, updated);
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result).toContain('UPDATED');
|
||||
});
|
||||
|
||||
test('should handle incomplete artifacts without code blocks', () => {
|
||||
const original = 'Some plain text content';
|
||||
const incompleteArtifact = `${ARTIFACT_START}{id="test"}\n${original}`;
|
||||
|
||||
const message = { text: incompleteArtifact };
|
||||
const artifacts = findAllArtifacts(message);
|
||||
|
||||
expect(artifacts).toHaveLength(1);
|
||||
|
||||
const updated = 'Some UPDATED text content';
|
||||
const result = replaceArtifactContent(incompleteArtifact, artifacts[0], original, updated);
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result).toContain('UPDATED');
|
||||
});
|
||||
});
|
||||
|
||||
describe('regression tests for edge cases', () => {
|
||||
test('should still handle complete artifacts correctly', () => {
|
||||
// Ensure we didn't break normal artifact handling
|
||||
const original = 'console.log("test");';
|
||||
const artifact = createArtifactText({ content: original });
|
||||
|
||||
const message = { text: artifact };
|
||||
const artifacts = findAllArtifacts(message);
|
||||
|
||||
expect(artifacts).toHaveLength(1);
|
||||
|
||||
const updated = 'console.log("updated");';
|
||||
const result = replaceArtifactContent(artifact, artifacts[0], original, updated);
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result).toContain(updated);
|
||||
expect(result).toContain(ARTIFACT_END);
|
||||
expect(result).toMatch(/```\nconsole\.log\("updated"\);\n```/);
|
||||
});
|
||||
|
||||
test('should handle multiple complete artifacts', () => {
|
||||
// Ensure multiple artifacts still work
|
||||
const content1 = 'First artifact';
|
||||
const content2 = 'Second artifact';
|
||||
const text = `${createArtifactText({ content: content1 })}\n\n${createArtifactText({ content: content2 })}`;
|
||||
|
||||
const message = { text };
|
||||
const artifacts = findAllArtifacts(message);
|
||||
|
||||
expect(artifacts).toHaveLength(2);
|
||||
|
||||
// Update first artifact
|
||||
const result1 = replaceArtifactContent(text, artifacts[0], content1, 'First UPDATED');
|
||||
expect(result1).not.toBeNull();
|
||||
expect(result1).toContain('First UPDATED');
|
||||
expect(result1).toContain(content2);
|
||||
|
||||
// Update second artifact
|
||||
const result2 = replaceArtifactContent(text, artifacts[1], content2, 'Second UPDATED');
|
||||
expect(result2).not.toBeNull();
|
||||
expect(result2).toContain(content1);
|
||||
expect(result2).toContain('Second UPDATED');
|
||||
});
|
||||
|
||||
test('should not mistake ::: at position 0 for artifact end in complete artifacts', () => {
|
||||
// This tests the specific fix - ensuring contentEnd=0 doesn't break complete artifacts
|
||||
const original = 'test content';
|
||||
// Create an artifact that will have ::: at position 0 when substring'd
|
||||
const artifact = `${ARTIFACT_START}\n\`\`\`\n${original}\n\`\`\`\n${ARTIFACT_END}`;
|
||||
|
||||
const message = { text: artifact };
|
||||
const artifacts = findAllArtifacts(message);
|
||||
|
||||
expect(artifacts).toHaveLength(1);
|
||||
|
||||
const updated = 'updated content';
|
||||
const result = replaceArtifactContent(artifact, artifacts[0], original, updated);
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result).toContain(updated);
|
||||
expect(result).toContain(ARTIFACT_END);
|
||||
});
|
||||
|
||||
test('should handle empty artifacts', () => {
|
||||
// Edge case: empty artifact
|
||||
const artifact = `${ARTIFACT_START}\n${ARTIFACT_END}`;
|
||||
|
||||
const message = { text: artifact };
|
||||
const artifacts = findAllArtifacts(message);
|
||||
|
||||
expect(artifacts).toHaveLength(1);
|
||||
|
||||
// Trying to replace non-existent content should return null
|
||||
const result = replaceArtifactContent(artifact, artifacts[0], 'something', 'updated');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
test('should preserve whitespace and formatting in complete artifacts', () => {
|
||||
const original = ` function test() {
|
||||
return {
|
||||
value: 42
|
||||
};
|
||||
}`;
|
||||
const artifact = createArtifactText({ content: original });
|
||||
|
||||
const message = { text: artifact };
|
||||
const artifacts = findAllArtifacts(message);
|
||||
|
||||
const updated = ` function test() {
|
||||
return {
|
||||
value: 100
|
||||
};
|
||||
}`;
|
||||
const result = replaceArtifactContent(artifact, artifacts[0], original, updated);
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result).toContain('value: 100');
|
||||
// Should preserve exact formatting
|
||||
expect(result).toMatch(
|
||||
/```\n {2}function test\(\) \{\n {4}return \{\n {6}value: 100\n {4}\};\n {2}\}\n```/,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -281,7 +281,7 @@ function createInProgressHandler(openai, thread_id, messages) {
|
||||
|
||||
openai.seenCompletedMessages.add(message_id);
|
||||
|
||||
const message = await openai.beta.threads.messages.retrieve(message_id, { thread_id });
|
||||
const message = await openai.beta.threads.messages.retrieve(thread_id, message_id);
|
||||
if (!message?.content?.length) {
|
||||
return;
|
||||
}
|
||||
@@ -435,11 +435,9 @@ async function runAssistant({
|
||||
};
|
||||
});
|
||||
|
||||
const tool_outputs = await processRequiredActions(openai, actions);
|
||||
const toolRun = await openai.beta.threads.runs.submitToolOutputs(run.id, {
|
||||
thread_id: run.thread_id,
|
||||
tool_outputs,
|
||||
});
|
||||
const outputs = await processRequiredActions(openai, actions);
|
||||
|
||||
const toolRun = await openai.beta.threads.runs.submitToolOutputs(run.thread_id, run.id, outputs);
|
||||
|
||||
// Recursive call with accumulated steps and messages
|
||||
return await runAssistant({
|
||||
|
||||
@@ -6,7 +6,7 @@ const {
|
||||
getUserKeyExpiry,
|
||||
checkUserKeyExpiry,
|
||||
} = require('~/server/services/UserService');
|
||||
const OAIClient = require('~/app/clients/OpenAIClient');
|
||||
const OpenAIClient = require('~/app/clients/OpenAIClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption, version, initAppClient = false }) => {
|
||||
@@ -79,7 +79,7 @@ const initializeClient = async ({ req, res, endpointOption, version, initAppClie
|
||||
openai.res = res;
|
||||
|
||||
if (endpointOption && initAppClient) {
|
||||
const client = new OAIClient(apiKey, clientOptions);
|
||||
const client = new OpenAIClient(apiKey, clientOptions);
|
||||
return {
|
||||
client,
|
||||
openai,
|
||||
|
||||
@@ -3,11 +3,11 @@ const { ProxyAgent } = require('undici');
|
||||
const { constructAzureURL, isUserProvided, resolveHeaders } = require('@librechat/api');
|
||||
const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider');
|
||||
const {
|
||||
checkUserKeyExpiry,
|
||||
getUserKeyValues,
|
||||
getUserKeyExpiry,
|
||||
checkUserKeyExpiry,
|
||||
} = require('~/server/services/UserService');
|
||||
const OAIClient = require('~/app/clients/OpenAIClient');
|
||||
const OpenAIClient = require('~/app/clients/OpenAIClient');
|
||||
|
||||
class Files {
|
||||
constructor(client) {
|
||||
@@ -109,15 +109,14 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie
|
||||
|
||||
apiKey = azureOptions.azureOpenAIApiKey;
|
||||
opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion };
|
||||
opts.defaultHeaders = resolveHeaders({
|
||||
headers: {
|
||||
opts.defaultHeaders = resolveHeaders(
|
||||
{
|
||||
...headers,
|
||||
'api-key': apiKey,
|
||||
'OpenAI-Beta': `assistants=${version}`,
|
||||
},
|
||||
user: req.user,
|
||||
body: req.body,
|
||||
});
|
||||
req.user,
|
||||
);
|
||||
opts.model = azureOptions.azureOpenAIApiDeploymentName;
|
||||
|
||||
if (initAppClient) {
|
||||
@@ -185,7 +184,7 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie
|
||||
}
|
||||
|
||||
if (endpointOption && initAppClient) {
|
||||
const client = new OAIClient(apiKey, clientOptions);
|
||||
const client = new OpenAIClient(apiKey, clientOptions);
|
||||
return {
|
||||
client,
|
||||
openai,
|
||||
|
||||
@@ -28,11 +28,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
||||
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
|
||||
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
|
||||
|
||||
let resolvedHeaders = resolveHeaders({
|
||||
headers: endpointConfig.headers,
|
||||
user: req.user,
|
||||
body: req.body,
|
||||
});
|
||||
let resolvedHeaders = resolveHeaders(endpointConfig.headers, req.user);
|
||||
|
||||
if (CUSTOM_API_KEY.match(envVarRegex)) {
|
||||
throw new Error(`Missing API Key for ${endpoint}.`);
|
||||
|
||||
@@ -64,14 +64,13 @@ describe('custom/initializeClient', () => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('calls resolveHeaders with headers, user, and body for body placeholder support', async () => {
|
||||
it('calls resolveHeaders with headers and user', async () => {
|
||||
const { resolveHeaders } = require('@librechat/api');
|
||||
await initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true });
|
||||
expect(resolveHeaders).toHaveBeenCalledWith({
|
||||
headers: { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
|
||||
user: { id: 'user-123', email: 'test@example.com' },
|
||||
body: { endpoint: 'test-endpoint' }, // body - supports {{LIBRECHAT_BODY_*}} placeholders
|
||||
});
|
||||
expect(resolveHeaders).toHaveBeenCalledWith(
|
||||
{ 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
|
||||
{ id: 'user-123', email: 'test@example.com' },
|
||||
);
|
||||
});
|
||||
|
||||
it('throws if endpoint config is missing', async () => {
|
||||
|
||||
@@ -81,11 +81,10 @@ const initializeClient = async ({
    serverless = _serverless;

    clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
    clientOptions.headers = resolveHeaders({
      headers: { ...headers, ...(clientOptions.headers ?? {}) },
      user: req.user,
      body: req.body,
    });
    clientOptions.headers = resolveHeaders(
      { ...headers, ...(clientOptions.headers ?? {}) },
      req.user,
    );

    clientOptions.titleConvo = azureConfig.titleConvo;
    clientOptions.titleModel = azureConfig.titleModel;

@@ -12,7 +12,7 @@ const {
} = require('@librechat/api');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, loadCustomConfig } = require('./Config');
const { getCachedTools } = require('./Config');
const { getLogStores } = require('~/cache');

/**
@@ -189,7 +189,6 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
    },
    oauthStart,
    oauthEnd,
    body: req.body,
  });

  if (isAssistantsEndpoint(provider) && Array.isArray(result)) {
@@ -240,135 +239,6 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
  return toolInstance;
}

/**
 * Get MCP setup data including config, connections, and OAuth servers
 * @param {string} userId - The user ID
 * @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers
 */
async function getMCPSetupData(userId) {
  const printConfig = false;
  const config = await loadCustomConfig(printConfig);
  const mcpConfig = config?.mcpServers;

  if (!mcpConfig) {
    throw new Error('MCP config not found');
  }

  const mcpManager = getMCPManager(userId);
  const appConnections = mcpManager.getAllConnections() || new Map();
  const userConnections = mcpManager.getUserConnections(userId) || new Map();
  const oauthServers = mcpManager.getOAuthServers() || new Set();

  return {
    mcpConfig,
    appConnections,
    userConnections,
    oauthServers,
  };
}

/**
 * Check OAuth flow status for a user and server
 * @param {string} userId - The user ID
 * @param {string} serverName - The server name
 * @returns {Object} Object containing hasActiveFlow and hasFailedFlow flags
 */
async function checkOAuthFlowStatus(userId, serverName) {
  const flowsCache = getLogStores(CacheKeys.FLOWS);
  const flowManager = getFlowStateManager(flowsCache);
  const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);

  try {
    const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
    if (!flowState) {
      return { hasActiveFlow: false, hasFailedFlow: false };
    }

    const flowAge = Date.now() - flowState.createdAt;
    const flowTTL = flowState.ttl || 180000; // Default 3 minutes

    if (flowState.status === 'FAILED' || flowAge > flowTTL) {
      const wasCancelled = flowState.error && flowState.error.includes('cancelled');

      if (wasCancelled) {
        logger.debug(`[MCP Connection Status] Found cancelled OAuth flow for ${serverName}`, {
          flowId,
          status: flowState.status,
          error: flowState.error,
        });
        return { hasActiveFlow: false, hasFailedFlow: false };
      } else {
        logger.debug(`[MCP Connection Status] Found failed OAuth flow for ${serverName}`, {
          flowId,
          status: flowState.status,
          flowAge,
          flowTTL,
          timedOut: flowAge > flowTTL,
          error: flowState.error,
        });
        return { hasActiveFlow: false, hasFailedFlow: true };
      }
    }

    if (flowState.status === 'PENDING') {
      logger.debug(`[MCP Connection Status] Found active OAuth flow for ${serverName}`, {
        flowId,
        flowAge,
        flowTTL,
      });
      return { hasActiveFlow: true, hasFailedFlow: false };
    }

    return { hasActiveFlow: false, hasFailedFlow: false };
  } catch (error) {
    logger.error(`[MCP Connection Status] Error checking OAuth flows for ${serverName}:`, error);
    return { hasActiveFlow: false, hasFailedFlow: false };
  }
}

/**
 * Get connection status for a specific MCP server
 * @param {string} userId - The user ID
 * @param {string} serverName - The server name
 * @param {Map} appConnections - App-level connections
 * @param {Map} userConnections - User-level connections
 * @param {Set} oauthServers - Set of OAuth servers
 * @returns {Object} Object containing requiresOAuth and connectionState
 */
async function getServerConnectionStatus(
  userId,
  serverName,
  appConnections,
  userConnections,
  oauthServers,
) {
  const getConnectionState = () =>
    appConnections.get(serverName)?.connectionState ??
    userConnections.get(serverName)?.connectionState ??
    'disconnected';

  const baseConnectionState = getConnectionState();
  let finalConnectionState = baseConnectionState;

  if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
    const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);

    if (hasFailedFlow) {
      finalConnectionState = 'error';
    } else if (hasActiveFlow) {
      finalConnectionState = 'connecting';
    }
  }

  return {
    requiresOAuth: oauthServers.has(serverName),
    connectionState: finalConnectionState,
  };
}

module.exports = {
  createMCPTool,
  getMCPSetupData,
  checkOAuthFlowStatus,
  getServerConnectionStatus,
};

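Note: as an orientation sketch only — the route path, response shape, and require path below are assumptions for illustration, not part of this changeset — the helpers exported above could back a connection-status endpoint along these lines:

// Hypothetical Express handler; route path and response shape are assumed.
const express = require('express');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');

const router = express.Router();

router.get('/mcp/connection/status', async (req, res) => {
  try {
    const userId = req.user.id;
    const { mcpConfig, appConnections, userConnections, oauthServers } =
      await getMCPSetupData(userId);

    // Resolve a per-server status, consulting OAuth flow state only when needed.
    const connectionStatus = {};
    for (const serverName of Object.keys(mcpConfig)) {
      connectionStatus[serverName] = await getServerConnectionStatus(
        userId,
        serverName,
        appConnections,
        userConnections,
        oauthServers,
      );
    }

    res.json({ connectionStatus });
  } catch (error) {
    res.status(500).json({ error: error.message });
  }
});

module.exports = router;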
@@ -1,510 +0,0 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getMCPSetupData, checkOAuthFlowStatus, getServerConnectionStatus } = require('./MCP');
|
||||
|
||||
// Mock all dependencies
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
MCPOAuthHandler: {
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
CacheKeys: {
|
||||
FLOWS: 'flows',
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('./Config', () => ({
|
||||
loadCustomConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(),
|
||||
getFlowStateManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
findToken: jest.fn(),
|
||||
createToken: jest.fn(),
|
||||
updateToken: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('tests for the new helper functions used by the MCP connection status endpoints', () => {
|
||||
let mockLoadCustomConfig;
|
||||
let mockGetMCPManager;
|
||||
let mockGetFlowStateManager;
|
||||
let mockGetLogStores;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
mockLoadCustomConfig = require('./Config').loadCustomConfig;
|
||||
mockGetMCPManager = require('~/config').getMCPManager;
|
||||
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||
mockGetLogStores = require('~/cache').getLogStores;
|
||||
});
|
||||
|
||||
describe('getMCPSetupData', () => {
|
||||
const mockUserId = 'user-123';
|
||||
const mockConfig = {
|
||||
mcpServers: {
|
||||
server1: { type: 'stdio' },
|
||||
server2: { type: 'http' },
|
||||
},
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockGetMCPManager.mockReturnValue({
|
||||
getAllConnections: jest.fn(() => new Map()),
|
||||
getUserConnections: jest.fn(() => new Map()),
|
||||
getOAuthServers: jest.fn(() => new Set()),
|
||||
});
|
||||
});
|
||||
|
||||
it('should successfully return MCP setup data', async () => {
|
||||
mockLoadCustomConfig.mockResolvedValue(mockConfig);
|
||||
|
||||
const mockAppConnections = new Map([['server1', { status: 'connected' }]]);
|
||||
const mockUserConnections = new Map([['server2', { status: 'disconnected' }]]);
|
||||
const mockOAuthServers = new Set(['server2']);
|
||||
|
||||
const mockMCPManager = {
|
||||
getAllConnections: jest.fn(() => mockAppConnections),
|
||||
getUserConnections: jest.fn(() => mockUserConnections),
|
||||
getOAuthServers: jest.fn(() => mockOAuthServers),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
|
||||
expect(mockLoadCustomConfig).toHaveBeenCalledWith(false);
|
||||
expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMCPManager.getAllConnections).toHaveBeenCalled();
|
||||
expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId);
|
||||
expect(mockMCPManager.getOAuthServers).toHaveBeenCalled();
|
||||
|
||||
expect(result).toEqual({
|
||||
mcpConfig: mockConfig.mcpServers,
|
||||
appConnections: mockAppConnections,
|
||||
userConnections: mockUserConnections,
|
||||
oauthServers: mockOAuthServers,
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw error when MCP config not found', async () => {
|
||||
mockLoadCustomConfig.mockResolvedValue({});
|
||||
await expect(getMCPSetupData(mockUserId)).rejects.toThrow('MCP config not found');
|
||||
});
|
||||
|
||||
it('should handle null values from MCP manager gracefully', async () => {
|
||||
mockLoadCustomConfig.mockResolvedValue(mockConfig);
|
||||
|
||||
const mockMCPManager = {
|
||||
getAllConnections: jest.fn(() => null),
|
||||
getUserConnections: jest.fn(() => null),
|
||||
getOAuthServers: jest.fn(() => null),
|
||||
};
|
||||
mockGetMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
const result = await getMCPSetupData(mockUserId);
|
||||
|
||||
expect(result).toEqual({
|
||||
mcpConfig: mockConfig.mcpServers,
|
||||
appConnections: new Map(),
|
||||
userConnections: new Map(),
|
||||
oauthServers: new Set(),
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('checkOAuthFlowStatus', () => {
|
||||
const mockUserId = 'user-123';
|
||||
const mockServerName = 'test-server';
|
||||
const mockFlowId = 'flow-123';
|
||||
|
||||
beforeEach(() => {
|
||||
const mockFlowsCache = {};
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(),
|
||||
};
|
||||
|
||||
mockGetLogStores.mockReturnValue(mockFlowsCache);
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
MCPOAuthHandler.generateFlowId.mockReturnValue(mockFlowId);
|
||||
});
|
||||
|
||||
it('should return false flags when no flow state exists', async () => {
|
||||
const mockFlowManager = { getFlowState: jest.fn(() => null) };
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
|
||||
|
||||
expect(mockGetLogStores).toHaveBeenCalledWith(CacheKeys.FLOWS);
|
||||
expect(MCPOAuthHandler.generateFlowId).toHaveBeenCalledWith(mockUserId, mockServerName);
|
||||
expect(mockFlowManager.getFlowState).toHaveBeenCalledWith(mockFlowId, 'mcp_oauth');
|
||||
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false });
|
||||
});
|
||||
|
||||
it('should detect failed flow when status is FAILED', async () => {
|
||||
const mockFlowState = {
|
||||
status: 'FAILED',
|
||||
createdAt: Date.now() - 60000, // 1 minute ago
|
||||
ttl: 180000,
|
||||
};
|
||||
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
|
||||
|
||||
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true });
|
||||
expect(logger.debug).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Found failed OAuth flow'),
|
||||
expect.objectContaining({
|
||||
flowId: mockFlowId,
|
||||
status: 'FAILED',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should detect failed flow when flow has timed out', async () => {
|
||||
const mockFlowState = {
|
||||
status: 'PENDING',
|
||||
createdAt: Date.now() - 200000, // 200 seconds ago (> 180s TTL)
|
||||
ttl: 180000,
|
||||
};
|
||||
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
|
||||
|
||||
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true });
|
||||
expect(logger.debug).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Found failed OAuth flow'),
|
||||
expect.objectContaining({
|
||||
timedOut: true,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should detect failed flow when TTL not specified and flow exceeds default TTL', async () => {
|
||||
const mockFlowState = {
|
||||
status: 'PENDING',
|
||||
createdAt: Date.now() - 200000, // 200 seconds ago (> 180s default TTL)
|
||||
// ttl not specified, should use 180000 default
|
||||
};
|
||||
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
|
||||
|
||||
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: true });
|
||||
});
|
||||
|
||||
it('should detect active flow when status is PENDING and within TTL', async () => {
|
||||
const mockFlowState = {
|
||||
status: 'PENDING',
|
||||
createdAt: Date.now() - 60000, // 1 minute ago (< 180s TTL)
|
||||
ttl: 180000,
|
||||
};
|
||||
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
|
||||
|
||||
expect(result).toEqual({ hasActiveFlow: true, hasFailedFlow: false });
|
||||
expect(logger.debug).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Found active OAuth flow'),
|
||||
expect.objectContaining({
|
||||
flowId: mockFlowId,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return false flags for other statuses', async () => {
|
||||
const mockFlowState = {
|
||||
status: 'COMPLETED',
|
||||
createdAt: Date.now() - 60000,
|
||||
ttl: 180000,
|
||||
};
|
||||
const mockFlowManager = { getFlowState: jest.fn(() => mockFlowState) };
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
|
||||
|
||||
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false });
|
||||
});
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
const mockError = new Error('Flow state error');
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => {
|
||||
throw mockError;
|
||||
}),
|
||||
};
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
|
||||
const result = await checkOAuthFlowStatus(mockUserId, mockServerName);
|
||||
|
||||
expect(result).toEqual({ hasActiveFlow: false, hasFailedFlow: false });
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Error checking OAuth flows'),
|
||||
mockError,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getServerConnectionStatus', () => {
|
||||
const mockUserId = 'user-123';
|
||||
const mockServerName = 'test-server';
|
||||
|
||||
it('should return app connection state when available', async () => {
|
||||
const appConnections = new Map([[mockServerName, { connectionState: 'connected' }]]);
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set();
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: false,
|
||||
connectionState: 'connected',
|
||||
});
|
||||
});
|
||||
|
||||
it('should fallback to user connection state when app connection not available', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map([[mockServerName, { connectionState: 'connecting' }]]);
|
||||
const oauthServers = new Set();
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: false,
|
||||
connectionState: 'connecting',
|
||||
});
|
||||
});
|
||||
|
||||
it('should default to disconnected when no connections exist', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set();
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: false,
|
||||
connectionState: 'disconnected',
|
||||
});
|
||||
});
|
||||
|
||||
it('should prioritize app connection over user connection', async () => {
|
||||
const appConnections = new Map([[mockServerName, { connectionState: 'connected' }]]);
|
||||
const userConnections = new Map([[mockServerName, { connectionState: 'disconnected' }]]);
|
||||
const oauthServers = new Set();
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: false,
|
||||
connectionState: 'connected',
|
||||
});
|
||||
});
|
||||
|
||||
it('should indicate OAuth requirement when server is in OAuth servers set', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result.requiresOAuth).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle OAuth flow status when disconnected and requires OAuth with failed flow', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock flow state to return failed flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
status: 'FAILED',
|
||||
createdAt: Date.now() - 60000,
|
||||
ttl: 180000,
|
||||
})),
|
||||
};
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
mockGetLogStores.mockReturnValue({});
|
||||
MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id');
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: true,
|
||||
connectionState: 'error',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle OAuth flow status when disconnected and requires OAuth with active flow', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock flow state to return active flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
status: 'PENDING',
|
||||
createdAt: Date.now() - 60000, // 1 minute ago
|
||||
ttl: 180000, // 3 minutes TTL
|
||||
})),
|
||||
};
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
mockGetLogStores.mockReturnValue({});
|
||||
MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id');
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: true,
|
||||
connectionState: 'connecting',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle OAuth flow status when disconnected and requires OAuth with no flow', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock flow state to return no flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => null),
|
||||
};
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
mockGetLogStores.mockReturnValue({});
|
||||
MCPOAuthHandler.generateFlowId.mockReturnValue('test-flow-id');
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: true,
|
||||
connectionState: 'disconnected',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not check OAuth flow status when server is connected', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(),
|
||||
};
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
mockGetLogStores.mockReturnValue({});
|
||||
|
||||
const appConnections = new Map([[mockServerName, { connectionState: 'connected' }]]);
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: true,
|
||||
connectionState: 'connected',
|
||||
});
|
||||
|
||||
// Should not call flow manager since server is connected
|
||||
expect(mockFlowManager.getFlowState).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not check OAuth flow status when server does not require OAuth', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(),
|
||||
};
|
||||
mockGetFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
mockGetLogStores.mockReturnValue({});
|
||||
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set(); // Server not in OAuth servers
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: false,
|
||||
connectionState: 'disconnected',
|
||||
});
|
||||
|
||||
// Should not call flow manager since server doesn't require OAuth
|
||||
expect(mockFlowManager.getFlowState).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -91,10 +91,11 @@ class RunManager {
   * @param {boolean} [params.final] - The end of the run polling loop, due to `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, or `expired` statuses.
   */
  async fetchRunSteps({ openai, thread_id, run_id, runStatus, final = false }) {
    // const { data: steps, first_id, last_id, has_more } = await openai.beta.threads.runs.steps.list(run_id, { thread_id });
    // const { data: steps, first_id, last_id, has_more } = await openai.beta.threads.runs.steps.list(thread_id, run_id);
    const { data: _steps } = await openai.beta.threads.runs.steps.list(
      thread_id,
      run_id,
      { thread_id },
      {},
      {
        timeout: 3000,
        maxRetries: 5,

@@ -573,9 +573,9 @@ class StreamRunManager {
    let toolRun;
    try {
      toolRun = this.openai.beta.threads.runs.submitToolOutputsStream(
        run.thread_id,
        run.id,
        {
          thread_id: run.thread_id,
          tool_outputs,
          stream: true,
        },

@@ -179,7 +179,7 @@ async function waitForRun({
 * @return {Promise<RunStep[]>} A promise that resolves to an array of RunStep objects.
 */
async function _retrieveRunSteps({ openai, thread_id, run_id }) {
  const runSteps = await openai.beta.threads.runs.steps.list(run_id, { thread_id });
  const runSteps = await openai.beta.threads.runs.steps.list(thread_id, run_id);
  return runSteps;
}

@@ -192,8 +191,7 @@ async function addThreadMetadata({ openai, thread_id, messageId, messages }) {
  const promises = [];
  for (const message of messages) {
    promises.push(
      openai.beta.threads.messages.update(message.id, {
        thread_id,
      openai.beta.threads.messages.update(thread_id, message.id, {
        metadata: {
          messageId,
        },
@@ -264,8 +263,7 @@ async function syncMessages({
    }

    modifyPromises.push(
      openai.beta.threads.messages.update(apiMessage.id, {
        thread_id,
      openai.beta.threads.messages.update(thread_id, apiMessage.id, {
        metadata: {
          messageId: dbMessage.messageId,
        },
@@ -415,7 +413,7 @@ async function checkMessageGaps({
}) {
  const promises = [];
  promises.push(openai.beta.threads.messages.list(thread_id, defaultOrderQuery));
  promises.push(openai.beta.threads.runs.steps.list(run_id, { thread_id }));
  promises.push(openai.beta.threads.runs.steps.list(thread_id, run_id));
  /** @type {[{ data: ThreadMessage[] }, { data: RunStep[] }]} */
  const [response, stepsResponse] = await Promise.all(promises);
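Note: the hunks above flip the argument order of the Assistants run-step listing and message-update calls; which order is correct depends on the installed openai SDK version, which this diff does not pin down. A minimal sketch of the positional (thread_id, run_id) shape shown on one side of the diff, assuming an already-constructed openai client:

// Sketch only; mirrors the `.list(thread_id, run_id, query, options)` shape
// visible above. The query object here is an assumed example, not from the diff.
async function listRunStepsSketch(openai, thread_id, run_id) {
  const { data: steps } = await openai.beta.threads.runs.steps.list(
    thread_id,
    run_id,
    { order: 'asc' },
    { timeout: 3000, maxRetries: 5 },
  );
  return steps;
}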

@@ -1,7 +1,6 @@
const fs = require('fs');
const path = require('path');
const { sleep } = require('@librechat/agents');
const { getToolkitKey } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { zodToJsonSchema } = require('zod-to-json-schema');
const { Calculator } = require('@langchain/community/tools/calculator');
@@ -12,6 +11,7 @@ const {
  ErrorTypes,
  ContentTypes,
  imageGenTools,
  EToolResources,
  EModelEndpoint,
  actionDelimiter,
  ImageVisionTool,
@@ -40,6 +40,30 @@ const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util');
const { redactMessage } = require('~/config/parsers');

/**
 * @param {string} toolName
 * @returns {string | undefined} toolKey
 */
function getToolkitKey(toolName) {
  /** @type {string|undefined} */
  let toolkitKey;
  for (const toolkit of toolkits) {
    if (toolName.startsWith(EToolResources.image_edit)) {
      const splitMatches = toolkit.pluginKey.split('_');
      const suffix = splitMatches[splitMatches.length - 1];
      if (toolName.endsWith(suffix)) {
        toolkitKey = toolkit.pluginKey;
        break;
      }
    }
    if (toolName.startsWith(toolkit.pluginKey)) {
      toolkitKey = toolkit.pluginKey;
      break;
    }
  }
  return toolkitKey;
}

/**
 * Loads and formats tools from the specified tool directory.
 *
@@ -121,7 +145,7 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] })
  for (const toolInstance of basicToolInstances) {
    const formattedTool = formatToOpenAIAssistantTool(toolInstance);
    let toolName = formattedTool[Tools.function].name;
    toolName = getToolkitKey({ toolkits, toolName }) ?? toolName;
    toolName = getToolkitKey(toolName) ?? toolName;
    if (filter.has(toolName) && included.size === 0) {
      continue;
    }

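Note: the call site in the hunk above moves from getToolkitKey({ toolkits, toolName }) to getToolkitKey(toolName); the local helper defined above closes over the module-level toolkits array instead of receiving it. A quick illustration (the plugin key below is invented for the example):

// Assuming `toolkits` contains an entry like { pluginKey: 'image_edit_oai' }
// (hypothetical key), the helper maps a concrete tool name back to its toolkit:
//   getToolkitKey('image_edit_oai')      -> 'image_edit_oai'  (prefix/suffix match)
//   getToolkitKey('some_unrelated_tool') -> undefined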
@@ -104,14 +104,6 @@ class CustomOpenIDStrategy extends OpenIDStrategy {
    if (options?.state && !params.has('state')) {
      params.set('state', options.state);
    }

    if (process.env.OPENID_AUDIENCE) {
      params.set('audience', process.env.OPENID_AUDIENCE);
      logger.debug(
        `[openidStrategy] Adding audience to authorization request: ${process.env.OPENID_AUDIENCE}`,
      );
    }

    return params;
  }
}
@@ -361,7 +353,7 @@ async function setupOpenId() {
        username = userinfo[process.env.OPENID_USERNAME_CLAIM];
      } else {
        username = convertToUsername(
          userinfo.preferred_username || userinfo.username || userinfo.email,
          userinfo.username || userinfo.given_name || userinfo.email,
        );
      }

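Note: the two argument lines above are the alternative fallback orders passed to convertToUsername. Spelled out for readability (illustrative only; convertToUsername itself is not shown in this diff):

// One side of the diff prefers the preferred_username claim:
//   userinfo.preferred_username || userinfo.username || userinfo.email
// The other prefers the username claim and falls back to given_name:
//   userinfo.username || userinfo.given_name || userinfo.email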
@@ -52,7 +52,9 @@ jest.mock('openid-client', () => {
|
||||
}),
|
||||
fetchUserInfo: jest.fn().mockImplementation((config, accessToken, sub) => {
|
||||
// Only return additional properties, but don't override any claims
|
||||
return Promise.resolve({});
|
||||
return Promise.resolve({
|
||||
preferred_username: 'preferred_username',
|
||||
});
|
||||
}),
|
||||
customFetch: Symbol('customFetch'),
|
||||
};
|
||||
@@ -102,7 +104,6 @@ describe('setupOpenId', () => {
|
||||
given_name: 'First',
|
||||
family_name: 'Last',
|
||||
name: 'My Full',
|
||||
preferred_username: 'testusername',
|
||||
username: 'flast',
|
||||
picture: 'https://example.com/avatar.png',
|
||||
}),
|
||||
@@ -155,20 +156,20 @@ describe('setupOpenId', () => {
|
||||
verifyCallback = require('openid-client/passport').__getVerifyCallback();
|
||||
});
|
||||
|
||||
it('should create a new user with correct username when preferred_username claim exists', async () => {
|
||||
// Arrange – our userinfo already has preferred_username 'testusername'
|
||||
it('should create a new user with correct username when username claim exists', async () => {
|
||||
// Arrange – our userinfo already has username 'flast'
|
||||
const userinfo = tokenset.claims();
|
||||
|
||||
// Act
|
||||
const { user } = await validate(tokenset);
|
||||
|
||||
// Assert
|
||||
expect(user.username).toBe(userinfo.preferred_username);
|
||||
expect(user.username).toBe(userinfo.username);
|
||||
expect(createUser).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username: userinfo.preferred_username,
|
||||
username: userinfo.username,
|
||||
email: userinfo.email,
|
||||
name: `${userinfo.given_name} ${userinfo.family_name}`,
|
||||
}),
|
||||
@@ -178,12 +179,12 @@ describe('setupOpenId', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should use username as username when preferred_username claim is missing', async () => {
|
||||
// Arrange – remove preferred_username from userinfo
|
||||
it('should use given_name as username when username claim is missing', async () => {
|
||||
// Arrange – remove username from userinfo
|
||||
const userinfo = { ...tokenset.claims() };
|
||||
delete userinfo.preferred_username;
|
||||
// Expect the username to be the "username"
|
||||
const expectUsername = userinfo.username;
|
||||
delete userinfo.username;
|
||||
// Expect the username to be the given name (unchanged case)
|
||||
const expectUsername = userinfo.given_name;
|
||||
|
||||
// Act
|
||||
const { user } = await validate({ ...tokenset, claims: () => userinfo });
|
||||
@@ -198,11 +199,11 @@ describe('setupOpenId', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should use email as username when username and preferred_username are missing', async () => {
|
||||
// Arrange – remove username and preferred_username
|
||||
it('should use email as username when username and given_name are missing', async () => {
|
||||
// Arrange – remove username and given_name
|
||||
const userinfo = { ...tokenset.claims() };
|
||||
delete userinfo.username;
|
||||
delete userinfo.preferred_username;
|
||||
delete userinfo.given_name;
|
||||
const expectUsername = userinfo.email;
|
||||
|
||||
// Act
|
||||
@@ -288,7 +289,7 @@ describe('setupOpenId', () => {
|
||||
expect.objectContaining({
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username: userinfo.preferred_username,
|
||||
username: userinfo.username,
|
||||
name: `${userinfo.given_name} ${userinfo.family_name}`,
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -19,9 +19,6 @@ const openAIModels = {
  'gpt-4.1': 1047576,
  'gpt-4.1-mini': 1047576,
  'gpt-4.1-nano': 1047576,
  'gpt-5': 400000,
  'gpt-5-mini': 400000,
  'gpt-5-nano': 400000,
  'gpt-4o': 127500, // -500 from max
  'gpt-4o-mini': 127500, // -500 from max
  'gpt-4o-2024-05-13': 127500, // -500 from max
@@ -199,7 +196,6 @@ const amazonModels = {
  'amazon.nova-micro-v1:0': 127000, // -1000 from max,
  'amazon.nova-lite-v1:0': 295000, // -5000 from max,
  'amazon.nova-pro-v1:0': 295000, // -5000 from max,
  'amazon.nova-premier-v1:0': 995000, // -5000 from max,
};

const bedrockModels = {
@@ -237,9 +233,6 @@ const aggregateModels = {
  ...xAIModels,
  // misc.
  kimi: 131000,
  // GPT-OSS
  'gpt-oss-20b': 131000,
  'gpt-oss-120b': 131000,
};

const maxTokensMap = {
@@ -256,11 +249,6 @@ const modelMaxOutputs = {
  o1: 32268, // -500 from max: 32,768
  'o1-mini': 65136, // -500 from max: 65,536
  'o1-preview': 32268, // -500 from max: 32,768
  'gpt-5': 128000,
  'gpt-5-mini': 128000,
  'gpt-5-nano': 128000,
  'gpt-oss-20b': 131000,
  'gpt-oss-120b': 131000,
  system_default: 1024,
};

@@ -479,11 +467,10 @@ const tiktokenModels = new Set([
]);

module.exports = {
  tiktokenModels,
  maxTokensMap,
  inputSchema,
  modelSchema,
  maxTokensMap,
  tiktokenModels,
  maxOutputTokensMap,
  matchModelName,
  processModelData,
  getModelMaxTokens,

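Note: the hunks above add or remove context-window entries (gpt-5*, gpt-oss*, amazon.nova-premier) and reorder the exports. As a usage sketch grounded in the spec that follows (exact return values depend on which side of this diff is applied):

// Usage sketch based on the calls exercised in the spec below.
const { getModelMaxTokens, matchModelName } = require('./tokens');

// matchModelName normalizes dated or vendor-prefixed names to a base key:
//   matchModelName('gpt-4.1-nano-2024-08-06') === 'gpt-4.1-nano'
// getModelMaxTokens then resolves that key against maxTokensMap:
//   getModelMaxTokens('gpt-4o') === 127500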
@@ -1,11 +1,5 @@
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
maxOutputTokensMap,
|
||||
getModelMaxTokens,
|
||||
processModelData,
|
||||
matchModelName,
|
||||
maxTokensMap,
|
||||
} = require('./tokens');
|
||||
const { getModelMaxTokens, processModelData, matchModelName, maxTokensMap } = require('./tokens');
|
||||
|
||||
describe('getModelMaxTokens', () => {
|
||||
test('should return correct tokens for exact match', () => {
|
||||
@@ -156,35 +150,6 @@ describe('getModelMaxTokens', () => {
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct tokens for gpt-5 matches', () => {
|
||||
expect(getModelMaxTokens('gpt-5')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-5']);
|
||||
expect(getModelMaxTokens('gpt-5-preview')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-5']);
|
||||
expect(getModelMaxTokens('openai/gpt-5')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-5']);
|
||||
expect(getModelMaxTokens('gpt-5-2025-01-30')).toBe(
|
||||
maxTokensMap[EModelEndpoint.openAI]['gpt-5'],
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct tokens for gpt-5-mini matches', () => {
|
||||
expect(getModelMaxTokens('gpt-5-mini')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-5-mini']);
|
||||
expect(getModelMaxTokens('gpt-5-mini-preview')).toBe(
|
||||
maxTokensMap[EModelEndpoint.openAI]['gpt-5-mini'],
|
||||
);
|
||||
expect(getModelMaxTokens('openai/gpt-5-mini')).toBe(
|
||||
maxTokensMap[EModelEndpoint.openAI]['gpt-5-mini'],
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct tokens for gpt-5-nano matches', () => {
|
||||
expect(getModelMaxTokens('gpt-5-nano')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-5-nano']);
|
||||
expect(getModelMaxTokens('gpt-5-nano-preview')).toBe(
|
||||
maxTokensMap[EModelEndpoint.openAI]['gpt-5-nano'],
|
||||
);
|
||||
expect(getModelMaxTokens('openai/gpt-5-nano')).toBe(
|
||||
maxTokensMap[EModelEndpoint.openAI]['gpt-5-nano'],
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct tokens for Anthropic models', () => {
|
||||
const models = [
|
||||
'claude-2.1',
|
||||
@@ -384,39 +349,6 @@ describe('getModelMaxTokens', () => {
|
||||
expect(getModelMaxTokens('o3')).toBe(o3Tokens);
|
||||
expect(getModelMaxTokens('openai/o3')).toBe(o3Tokens);
|
||||
});
|
||||
|
||||
test('should return correct tokens for GPT-OSS models', () => {
|
||||
const expected = maxTokensMap[EModelEndpoint.openAI]['gpt-oss-20b'];
|
||||
['gpt-oss-20b', 'gpt-oss-120b', 'openai/gpt-oss-20b', 'openai/gpt-oss-120b'].forEach((name) => {
|
||||
expect(getModelMaxTokens(name)).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
test('should return correct max output tokens for GPT-5 models', () => {
|
||||
const { getModelMaxOutputTokens } = require('./tokens');
|
||||
['gpt-5', 'gpt-5-mini', 'gpt-5-nano'].forEach((model) => {
|
||||
expect(getModelMaxOutputTokens(model)).toBe(maxOutputTokensMap[EModelEndpoint.openAI][model]);
|
||||
expect(getModelMaxOutputTokens(model, EModelEndpoint.openAI)).toBe(
|
||||
maxOutputTokensMap[EModelEndpoint.openAI][model],
|
||||
);
|
||||
expect(getModelMaxOutputTokens(model, EModelEndpoint.azureOpenAI)).toBe(
|
||||
maxOutputTokensMap[EModelEndpoint.azureOpenAI][model],
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
test('should return correct max output tokens for GPT-OSS models', () => {
|
||||
const { getModelMaxOutputTokens } = require('./tokens');
|
||||
['gpt-oss-20b', 'gpt-oss-120b'].forEach((model) => {
|
||||
expect(getModelMaxOutputTokens(model)).toBe(maxOutputTokensMap[EModelEndpoint.openAI][model]);
|
||||
expect(getModelMaxOutputTokens(model, EModelEndpoint.openAI)).toBe(
|
||||
maxOutputTokensMap[EModelEndpoint.openAI][model],
|
||||
);
|
||||
expect(getModelMaxOutputTokens(model, EModelEndpoint.azureOpenAI)).toBe(
|
||||
maxOutputTokensMap[EModelEndpoint.azureOpenAI][model],
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('matchModelName', () => {
|
||||
@@ -488,25 +420,6 @@ describe('matchModelName', () => {
|
||||
expect(matchModelName('gpt-4.1-nano-2024-08-06')).toBe('gpt-4.1-nano');
|
||||
});
|
||||
|
||||
it('should return the closest matching key for gpt-5 matches', () => {
|
||||
expect(matchModelName('openai/gpt-5')).toBe('gpt-5');
|
||||
expect(matchModelName('gpt-5-preview')).toBe('gpt-5');
|
||||
expect(matchModelName('gpt-5-2025-01-30')).toBe('gpt-5');
|
||||
expect(matchModelName('gpt-5-2025-01-30-0130')).toBe('gpt-5');
|
||||
});
|
||||
|
||||
it('should return the closest matching key for gpt-5-mini matches', () => {
|
||||
expect(matchModelName('openai/gpt-5-mini')).toBe('gpt-5-mini');
|
||||
expect(matchModelName('gpt-5-mini-preview')).toBe('gpt-5-mini');
|
||||
expect(matchModelName('gpt-5-mini-2025-01-30')).toBe('gpt-5-mini');
|
||||
});
|
||||
|
||||
it('should return the closest matching key for gpt-5-nano matches', () => {
|
||||
expect(matchModelName('openai/gpt-5-nano')).toBe('gpt-5-nano');
|
||||
expect(matchModelName('gpt-5-nano-preview')).toBe('gpt-5-nano');
|
||||
expect(matchModelName('gpt-5-nano-2025-01-30')).toBe('gpt-5-nano');
|
||||
});
|
||||
|
||||
// Tests for Google models
|
||||
it('should return the exact model name if it exists in maxTokensMap - Google models', () => {
|
||||
expect(matchModelName('text-bison-32k', EModelEndpoint.google)).toBe('text-bison-32k');
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/frontend",
|
||||
"version": "v0.8.0-rc1",
|
||||
"version": "v0.7.9",
|
||||
"description": "",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
@@ -57,8 +57,8 @@
|
||||
"@react-spring/web": "^9.7.5",
|
||||
"@tanstack/react-query": "^4.28.0",
|
||||
"@tanstack/react-table": "^8.11.7",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"class-variance-authority": "^0.6.0",
|
||||
"clsx": "^1.2.1",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"cross-env": "^7.0.3",
|
||||
"date-fns": "^3.3.1",
|
||||
@@ -76,7 +76,7 @@
|
||||
"librechat-data-provider": "*",
|
||||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.394.0",
|
||||
"match-sorter": "^8.1.0",
|
||||
"match-sorter": "^6.3.4",
|
||||
"micromark-extension-llm-math": "^3.1.0",
|
||||
"qrcode.react": "^4.2.0",
|
||||
"rc-input-number": "^7.4.2",
|
||||
|
||||
@@ -7,7 +7,6 @@ import { ReactQueryDevtools } from '@tanstack/react-query-devtools';
|
||||
import { Toast, ThemeProvider, ToastProvider } from '@librechat/client';
|
||||
import { QueryClient, QueryClientProvider, QueryCache } from '@tanstack/react-query';
|
||||
import { ScreenshotProvider, useApiErrorBoundary } from './hooks';
|
||||
import { getThemeFromEnv } from './utils/getThemeFromEnv';
|
||||
import { LiveAnnouncer } from '~/a11y';
|
||||
import { router } from './routes';
|
||||
|
||||
@@ -24,23 +23,11 @@ const App = () => {
|
||||
}),
|
||||
});
|
||||
|
||||
// Load theme from environment variables if available
|
||||
const envTheme = getThemeFromEnv();
|
||||
|
||||
return (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<RecoilRoot>
|
||||
<LiveAnnouncer>
|
||||
<ThemeProvider
|
||||
// Only pass initialTheme and themeRGB if environment theme exists
|
||||
// This allows localStorage values to persist when no env theme is set
|
||||
{...(envTheme && { initialTheme: 'system', themeRGB: envTheme })}
|
||||
>
|
||||
{/* The ThemeProvider will automatically:
|
||||
1. Apply dark/light mode classes
|
||||
2. Apply custom theme colors if envTheme is provided
|
||||
3. Otherwise use stored theme preferences from localStorage
|
||||
4. Fall back to default theme colors if nothing is stored */}
|
||||
<ThemeProvider>
|
||||
<RadixToast.Provider>
|
||||
<ToastProvider>
|
||||
<DndProvider backend={HTML5Backend}>
|
||||
|
||||
@@ -1,46 +0,0 @@
import React, { createContext, useContext, useMemo } from 'react';
import type { TMessage } from 'librechat-data-provider';
import { useChatContext } from './ChatContext';
import { getLatestText } from '~/utils';

interface ArtifactsContextValue {
  isSubmitting: boolean;
  latestMessageId: string | null;
  latestMessageText: string;
  conversationId: string | null;
}

const ArtifactsContext = createContext<ArtifactsContextValue | undefined>(undefined);

export function ArtifactsProvider({ children }: { children: React.ReactNode }) {
  const { isSubmitting, latestMessage, conversation } = useChatContext();

  const latestMessageText = useMemo(() => {
    return getLatestText({
      messageId: latestMessage?.messageId ?? null,
      text: latestMessage?.text ?? null,
      content: latestMessage?.content ?? null,
    } as TMessage);
  }, [latestMessage?.messageId, latestMessage?.text, latestMessage?.content]);

  /** Context value only created when relevant values change */
  const contextValue = useMemo<ArtifactsContextValue>(
    () => ({
      isSubmitting,
      latestMessageText,
      latestMessageId: latestMessage?.messageId ?? null,
      conversationId: conversation?.conversationId ?? null,
    }),
    [isSubmitting, latestMessage?.messageId, latestMessageText, conversation?.conversationId],
  );

  return <ArtifactsContext.Provider value={contextValue}>{children}</ArtifactsContext.Provider>;
}

export function useArtifactsContext() {
  const context = useContext(ArtifactsContext);
  if (!context) {
    throw new Error('useArtifactsContext must be used within ArtifactsProvider');
  }
  return context;
}

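Note: for orientation only — the consumer component below is invented for the example; the actual consumers in this diff are ArtifactCodeEditor and ArtifactTabs — the context added above is used like this:

// Hypothetical consumer; must render inside ArtifactsProvider, which itself
// relies on ChatContext for isSubmitting / latestMessage / conversation.
import React from 'react';
import { ArtifactsProvider, useArtifactsContext } from '~/Providers';

function ArtifactStatusBadge() {
  const { isSubmitting, latestMessageId } = useArtifactsContext();
  return <span>{isSubmitting ? 'Generating…' : (latestMessageId ?? 'No message yet')}</span>;
}

// <ArtifactsProvider>
//   <ArtifactStatusBadge />
// </ArtifactsProvider>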
@@ -23,5 +23,4 @@ export * from './SetConvoContext';
|
||||
export * from './SearchContext';
|
||||
export * from './BadgeRowContext';
|
||||
export * from './SidePanelContext';
|
||||
export * from './ArtifactsContext';
|
||||
export { default as BadgeRowProvider } from './BadgeRowContext';
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import debounce from 'lodash/debounce';
|
||||
import React, { useMemo, useState, useEffect, useCallback } from 'react';
|
||||
import React, { memo, useEffect, useMemo, useCallback } from 'react';
|
||||
import {
|
||||
useSandpack,
|
||||
SandpackCodeEditor,
|
||||
@@ -10,8 +10,8 @@ import type { SandpackBundlerFile } from '@codesandbox/sandpack-client';
|
||||
import type { CodeEditorRef } from '@codesandbox/sandpack-react';
|
||||
import type { ArtifactFiles, Artifact } from '~/common';
|
||||
import { useEditArtifact, useGetStartupConfig } from '~/data-provider';
|
||||
import { useEditorContext, useArtifactsContext } from '~/Providers';
|
||||
import { sharedFiles, sharedOptions } from '~/utils/artifacts';
|
||||
import { useEditorContext } from '~/Providers';
|
||||
|
||||
const createDebouncedMutation = (
|
||||
callback: (params: {
|
||||
@@ -29,21 +29,18 @@ const CodeEditor = ({
|
||||
editorRef,
|
||||
}: {
|
||||
fileKey: string;
|
||||
readOnly?: boolean;
|
||||
readOnly: boolean;
|
||||
artifact: Artifact;
|
||||
editorRef: React.MutableRefObject<CodeEditorRef>;
|
||||
}) => {
|
||||
const { sandpack } = useSandpack();
|
||||
const [currentUpdate, setCurrentUpdate] = useState<string | null>(null);
|
||||
const { isMutating, setIsMutating, setCurrentCode } = useEditorContext();
|
||||
const editArtifact = useEditArtifact({
|
||||
onMutate: (vars) => {
|
||||
onMutate: () => {
|
||||
setIsMutating(true);
|
||||
setCurrentUpdate(vars.updated);
|
||||
},
|
||||
onSuccess: () => {
|
||||
setIsMutating(false);
|
||||
setCurrentUpdate(null);
|
||||
},
|
||||
onError: () => {
|
||||
setIsMutating(false);
|
||||
@@ -74,14 +71,8 @@ const CodeEditor = ({
|
||||
}
|
||||
|
||||
const currentCode = (sandpack.files['/' + fileKey] as SandpackBundlerFile | undefined)?.code;
|
||||
const isNotOriginal =
|
||||
currentCode && artifact.content != null && currentCode.trim() !== artifact.content.trim();
|
||||
const isNotRepeated =
|
||||
currentUpdate == null
|
||||
? true
|
||||
: currentCode != null && currentCode.trim() !== currentUpdate.trim();
|
||||
|
||||
if (artifact.content && isNotOriginal && isNotRepeated) {
|
||||
if (currentCode && artifact.content != null && currentCode.trim() !== artifact.content.trim()) {
|
||||
setCurrentCode(currentCode);
|
||||
debouncedMutation({
|
||||
index: artifact.index,
|
||||
@@ -101,9 +92,8 @@ const CodeEditor = ({
|
||||
artifact.messageId,
|
||||
readOnly,
|
||||
isMutating,
|
||||
currentUpdate,
|
||||
setIsMutating,
|
||||
sandpack.files,
|
||||
setIsMutating,
|
||||
setCurrentCode,
|
||||
debouncedMutation,
|
||||
]);
|
||||
@@ -112,32 +102,33 @@ const CodeEditor = ({
|
||||
<SandpackCodeEditor
|
||||
ref={editorRef}
|
||||
showTabs={false}
|
||||
readOnly={readOnly}
|
||||
showRunButton={false}
|
||||
showLineNumbers={true}
|
||||
showInlineErrors={true}
|
||||
readOnly={readOnly === true}
|
||||
className="hljs language-javascript bg-black"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export const ArtifactCodeEditor = function ({
|
||||
export const ArtifactCodeEditor = memo(function ({
|
||||
files,
|
||||
fileKey,
|
||||
template,
|
||||
artifact,
|
||||
editorRef,
|
||||
sharedProps,
|
||||
isSubmitting,
|
||||
}: {
|
||||
fileKey: string;
|
||||
artifact: Artifact;
|
||||
files: ArtifactFiles;
|
||||
isSubmitting: boolean;
|
||||
template: SandpackProviderProps['template'];
|
||||
sharedProps: Partial<SandpackProviderProps>;
|
||||
editorRef: React.MutableRefObject<CodeEditorRef>;
|
||||
}) {
|
||||
const { data: config } = useGetStartupConfig();
|
||||
const { isSubmitting } = useArtifactsContext();
|
||||
const options: typeof sharedOptions = useMemo(() => {
|
||||
if (!config) {
|
||||
return sharedOptions;
|
||||
@@ -147,10 +138,6 @@ export const ArtifactCodeEditor = function ({
|
||||
bundlerURL: template === 'static' ? config.staticBundlerURL : config.bundlerURL,
|
||||
};
|
||||
}, [config, template]);
|
||||
const [readOnly, setReadOnly] = useState(isSubmitting ?? false);
|
||||
useEffect(() => {
|
||||
setReadOnly(isSubmitting ?? false);
|
||||
}, [isSubmitting]);
|
||||
|
||||
if (Object.keys(files).length === 0) {
|
||||
return null;
|
||||
@@ -167,7 +154,12 @@ export const ArtifactCodeEditor = function ({
|
||||
{...sharedProps}
|
||||
template={template}
|
||||
>
|
||||
<CodeEditor fileKey={fileKey} artifact={artifact} editorRef={editorRef} readOnly={readOnly} />
|
||||
<CodeEditor
|
||||
editorRef={editorRef}
|
||||
fileKey={fileKey}
|
||||
readOnly={isSubmitting}
|
||||
artifact={artifact}
|
||||
/>
|
||||
</StyledProvider>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
@@ -2,12 +2,12 @@ import { useRef, useEffect } from 'react';
|
||||
import * as Tabs from '@radix-ui/react-tabs';
|
||||
import type { SandpackPreviewRef, CodeEditorRef } from '@codesandbox/sandpack-react';
|
||||
import type { Artifact } from '~/common';
|
||||
import { useEditorContext, useArtifactsContext } from '~/Providers';
|
||||
import useArtifactProps from '~/hooks/Artifacts/useArtifactProps';
|
||||
import { useAutoScroll } from '~/hooks/Artifacts/useAutoScroll';
|
||||
import { ArtifactCodeEditor } from './ArtifactCodeEditor';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import { ArtifactPreview } from './ArtifactPreview';
|
||||
import { useEditorContext } from '~/Providers';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
export default function ArtifactTabs({
|
||||
@@ -15,13 +15,14 @@ export default function ArtifactTabs({
|
||||
isMermaid,
|
||||
editorRef,
|
||||
previewRef,
|
||||
isSubmitting,
|
||||
}: {
|
||||
artifact: Artifact;
|
||||
isMermaid: boolean;
|
||||
isSubmitting: boolean;
|
||||
editorRef: React.MutableRefObject<CodeEditorRef>;
|
||||
previewRef: React.MutableRefObject<SandpackPreviewRef>;
|
||||
}) {
|
||||
const { isSubmitting } = useArtifactsContext();
|
||||
const { currentCode, setCurrentCode } = useEditorContext();
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const lastIdRef = useRef<string | null>(null);
|
||||
@@ -51,6 +52,7 @@ export default function ArtifactTabs({
|
||||
artifact={artifact}
|
||||
editorRef={editorRef}
|
||||
sharedProps={sharedProps}
|
||||
isSubmitting={isSubmitting}
|
||||
/>
|
||||
</Tabs.Content>
|
||||
<Tabs.Content
|
||||
|
||||
@@ -29,6 +29,7 @@ export default function Artifacts() {
|
||||
isMermaid,
|
||||
setActiveTab,
|
||||
currentIndex,
|
||||
isSubmitting,
|
||||
cycleArtifact,
|
||||
currentArtifact,
|
||||
orderedArtifactIds,
|
||||
@@ -115,6 +116,7 @@ export default function Artifacts() {
|
||||
<ArtifactTabs
|
||||
isMermaid={isMermaid}
|
||||
artifact={currentArtifact}
|
||||
isSubmitting={isSubmitting}
|
||||
editorRef={editorRef as React.MutableRefObject<CodeEditorRef>}
|
||||
previewRef={previewRef as React.MutableRefObject<SandpackPreviewRef>}
|
||||
/>
|
||||
|
||||
@@ -15,142 +15,133 @@ interface ArtifactsSubMenuProps {
|
||||
handleCustomToggle: () => void;
|
||||
}
|
||||
|
||||
const ArtifactsSubMenu = React.forwardRef<HTMLDivElement, ArtifactsSubMenuProps>(
|
||||
(
|
||||
{
|
||||
isArtifactsPinned,
|
||||
setIsArtifactsPinned,
|
||||
artifactsMode,
|
||||
handleArtifactsToggle,
|
||||
handleShadcnToggle,
|
||||
handleCustomToggle,
|
||||
...props
|
||||
},
|
||||
ref,
|
||||
) => {
|
||||
const localize = useLocalize();
|
||||
const ArtifactsSubMenu = ({
|
||||
isArtifactsPinned,
|
||||
setIsArtifactsPinned,
|
||||
artifactsMode,
|
||||
handleArtifactsToggle,
|
||||
handleShadcnToggle,
|
||||
handleCustomToggle,
|
||||
...props
|
||||
}: ArtifactsSubMenuProps) => {
|
||||
const localize = useLocalize();
|
||||
|
||||
const menuStore = Ariakit.useMenuStore({
|
||||
focusLoop: true,
|
||||
showTimeout: 100,
|
||||
placement: 'right',
|
||||
});
|
||||
const menuStore = Ariakit.useMenuStore({
|
||||
focusLoop: true,
|
||||
showTimeout: 100,
|
||||
placement: 'right',
|
||||
});
|
||||
|
||||
const isEnabled = artifactsMode !== '' && artifactsMode !== undefined;
|
||||
const isShadcnEnabled = artifactsMode === ArtifactModes.SHADCNUI;
|
||||
const isCustomEnabled = artifactsMode === ArtifactModes.CUSTOM;
|
||||
const isEnabled = artifactsMode !== '' && artifactsMode !== undefined;
|
||||
const isShadcnEnabled = artifactsMode === ArtifactModes.SHADCNUI;
|
||||
const isCustomEnabled = artifactsMode === ArtifactModes.CUSTOM;
|
||||
|
||||
return (
|
||||
<div ref={ref}>
|
||||
<Ariakit.MenuProvider store={menuStore}>
|
||||
<Ariakit.MenuItem
|
||||
{...props}
|
||||
hideOnClick={false}
|
||||
render={
|
||||
<Ariakit.MenuButton
|
||||
onClick={(e: React.MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
handleArtifactsToggle();
|
||||
}}
|
||||
onMouseEnter={() => {
|
||||
if (isEnabled) {
|
||||
menuStore.show();
|
||||
}
|
||||
}}
|
||||
className="flex w-full cursor-pointer items-center justify-between rounded-lg p-2 hover:bg-surface-hover"
|
||||
/>
|
||||
}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<WandSparkles className="icon-md" />
|
||||
<span>{localize('com_ui_artifacts')}</span>
|
||||
{isEnabled && <ChevronRight className="ml-auto h-3 w-3" />}
|
||||
return (
|
||||
<Ariakit.MenuProvider store={menuStore}>
|
||||
<Ariakit.MenuItem
|
||||
{...props}
|
||||
hideOnClick={false}
|
||||
render={
|
||||
<Ariakit.MenuButton
|
||||
onClick={(e: React.MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
handleArtifactsToggle();
|
||||
}}
|
||||
onMouseEnter={() => {
|
||||
if (isEnabled) {
|
||||
menuStore.show();
}
}}
className="flex w-full cursor-pointer items-center justify-between rounded-lg p-2 hover:bg-surface-hover"
/>
}
>
<div className="flex items-center gap-2">
<WandSparkles className="icon-md" />
<span>{localize('com_ui_artifacts')}</span>
{isEnabled && <ChevronRight className="ml-auto h-3 w-3" />}
</div>
<button
type="button"
onClick={(e) => {
e.stopPropagation();
setIsArtifactsPinned(!isArtifactsPinned);
}}
className={cn(
'rounded p-1 transition-all duration-200',
'hover:bg-surface-tertiary hover:shadow-sm',
!isArtifactsPinned && 'text-text-secondary hover:text-text-primary',
)}
aria-label={isArtifactsPinned ? 'Unpin' : 'Pin'}
>
<div className="h-4 w-4">
<PinIcon unpin={isArtifactsPinned} />
</div>
</button>
</Ariakit.MenuItem>

{isEnabled && (
<Ariakit.Menu
portal={true}
unmountOnHide={true}
className={cn(
'animate-popover-left z-50 ml-3 flex min-w-[250px] flex-col rounded-xl',
'border border-border-light bg-surface-secondary px-1.5 py-1 shadow-lg',
)}
>
<div className="px-2 py-1.5">
<div className="mb-2 text-xs font-medium text-text-secondary">
{localize('com_ui_artifacts_options')}
</div>
<button
type="button"
onClick={(e) => {
e.stopPropagation();
setIsArtifactsPinned(!isArtifactsPinned);

{/* Include shadcn/ui Option */}
<Ariakit.MenuItem
hideOnClick={false}
onClick={(event) => {
event.preventDefault();
event.stopPropagation();
handleShadcnToggle();
}}
disabled={isCustomEnabled}
className={cn(
'mb-1 flex items-center justify-between rounded-lg px-2 py-2',
'cursor-pointer text-text-primary outline-none transition-colors',
'hover:bg-black/[0.075] dark:hover:bg-white/10',
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
isCustomEnabled && 'cursor-not-allowed opacity-50',
)}
>
<div className="flex items-center gap-2">
<Ariakit.MenuItemCheck checked={isShadcnEnabled} />
<span className="text-sm">{localize('com_ui_include_shadcnui' as any)}</span>
</div>
</Ariakit.MenuItem>

{/* Custom Prompt Mode Option */}
<Ariakit.MenuItem
hideOnClick={false}
onClick={(event) => {
event.preventDefault();
event.stopPropagation();
handleCustomToggle();
}}
className={cn(
'rounded p-1 transition-all duration-200',
'hover:bg-surface-tertiary hover:shadow-sm',
!isArtifactsPinned && 'text-text-secondary hover:text-text-primary',
)}
aria-label={isArtifactsPinned ? 'Unpin' : 'Pin'}
>
<div className="h-4 w-4">
<PinIcon unpin={isArtifactsPinned} />
</div>
</button>
</Ariakit.MenuItem>

{isEnabled && (
<Ariakit.Menu
portal={true}
unmountOnHide={true}
className={cn(
'animate-popover-left z-50 ml-3 flex min-w-[250px] flex-col rounded-xl',
'border border-border-light bg-surface-secondary px-1.5 py-1 shadow-lg',
'flex items-center justify-between rounded-lg px-2 py-2',
'cursor-pointer text-text-primary outline-none transition-colors',
'hover:bg-black/[0.075] dark:hover:bg-white/10',
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
)}
>
<div className="px-2 py-1.5">
<div className="mb-2 text-xs font-medium text-text-secondary">
{localize('com_ui_artifacts_options')}
</div>

{/* Include shadcn/ui Option */}
<Ariakit.MenuItem
hideOnClick={false}
onClick={(event) => {
event.preventDefault();
event.stopPropagation();
handleShadcnToggle();
}}
disabled={isCustomEnabled}
className={cn(
'mb-1 flex items-center justify-between rounded-lg px-2 py-2',
'cursor-pointer text-text-primary outline-none transition-colors',
'hover:bg-black/[0.075] dark:hover:bg-white/10',
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
isCustomEnabled && 'cursor-not-allowed opacity-50',
)}
>
<div className="flex items-center gap-2">
<Ariakit.MenuItemCheck checked={isShadcnEnabled} />
<span className="text-sm">{localize('com_ui_include_shadcnui' as any)}</span>
</div>
</Ariakit.MenuItem>

{/* Custom Prompt Mode Option */}
<Ariakit.MenuItem
hideOnClick={false}
onClick={(event) => {
event.preventDefault();
event.stopPropagation();
handleCustomToggle();
}}
className={cn(
'flex items-center justify-between rounded-lg px-2 py-2',
'cursor-pointer text-text-primary outline-none transition-colors',
'hover:bg-black/[0.075] dark:hover:bg-white/10',
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
)}
>
<div className="flex items-center gap-2">
<Ariakit.MenuItemCheck checked={isCustomEnabled} />
<span className="text-sm">{localize('com_ui_custom_prompt_mode' as any)}</span>
</div>
</Ariakit.MenuItem>
<div className="flex items-center gap-2">
<Ariakit.MenuItemCheck checked={isCustomEnabled} />
<span className="text-sm">{localize('com_ui_custom_prompt_mode' as any)}</span>
</div>
</Ariakit.Menu>
)}
</Ariakit.MenuProvider>
</div>
);
},
);

ArtifactsSubMenu.displayName = 'ArtifactsSubMenu';
</Ariakit.MenuItem>
</div>
</Ariakit.Menu>
)}
</Ariakit.MenuProvider>
);
};

export default React.memo(ArtifactsSubMenu);

@@ -1,8 +1,8 @@
import React, { memo, useCallback } from 'react';
import { MultiSelect, MCPIcon } from '@librechat/client';
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
import MCPServerStatusIcon from '~/components/ui/MCP/MCPServerStatusIcon';
import { useMCPServerManager } from '~/hooks/MCP/useMCPServerManager';
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
import MCPConfigDialog from '~/components/ui/MCP/MCPConfigDialog';

function MCPSelect() {
const {
@@ -13,7 +13,6 @@ function MCPSelect() {
batchToggleServers,
getServerStatusIconProps,
getConfigDialogProps,
isInitializing,
localize,
} = useMCPServerManager();

@@ -33,20 +32,14 @@ function MCPSelect() {
const renderItemContent = useCallback(
(serverName: string, defaultContent: React.ReactNode) => {
const statusIconProps = getServerStatusIconProps(serverName);
const isServerInitializing = isInitializing(serverName);

/**
Common wrapper for the main content (check mark + text).
Ensures Check & Text are adjacent and the group takes available space.
*/
// Common wrapper for the main content (check mark + text)
// Ensures Check & Text are adjacent and the group takes available space.
const mainContentWrapper = (
<button
type="button"
className={`flex flex-grow items-center rounded bg-transparent p-0 text-left transition-colors focus:outline-none ${
isServerInitializing ? 'opacity-50' : ''
}`}
className="flex flex-grow items-center rounded bg-transparent p-0 text-left transition-colors focus:outline-none"
tabIndex={0}
disabled={isServerInitializing}
>
{defaultContent}
</button>
@@ -65,13 +58,15 @@ function MCPSelect() {

return mainContentWrapper;
},
[getServerStatusIconProps, isInitializing],
[getServerStatusIconProps],
);

// Don't render if no servers are selected and not pinned
if ((!mcpValues || mcpValues.length === 0) && !isPinned) {
return null;
}

// Don't render if no MCP servers are configured
if (!configuredServers || configuredServers.length === 0) {
return null;
}
@@ -84,6 +79,7 @@ function MCPSelect() {
items={configuredServers}
selectedValues={mcpValues ?? []}
setSelectedValues={batchToggleServers}
defaultSelectedValues={mcpValues ?? []}
renderSelectedValues={renderSelectedValues}
renderItemContent={renderItemContent}
placeholder={placeholderText}

@@ -2,129 +2,124 @@ import React from 'react';
import * as Ariakit from '@ariakit/react';
import { ChevronRight } from 'lucide-react';
import { PinIcon, MCPIcon } from '@librechat/client';
import MCPServerStatusIcon from '~/components/MCP/MCPServerStatusIcon';
import MCPServerStatusIcon from '~/components/ui/MCP/MCPServerStatusIcon';
import { useMCPServerManager } from '~/hooks/MCP/useMCPServerManager';
import MCPConfigDialog from '~/components/MCP/MCPConfigDialog';
import MCPConfigDialog from '~/components/ui/MCP/MCPConfigDialog';
import { cn } from '~/utils';

interface MCPSubMenuProps {
placeholder?: string;
}

const MCPSubMenu = React.forwardRef<HTMLDivElement, MCPSubMenuProps>(
({ placeholder, ...props }, ref) => {
const {
configuredServers,
mcpValues,
isPinned,
setIsPinned,
placeholderText,
toggleServerSelection,
getServerStatusIconProps,
getConfigDialogProps,
isInitializing,
} = useMCPServerManager();
const MCPSubMenu = ({ placeholder, ...props }: MCPSubMenuProps) => {
const {
configuredServers,
mcpValues,
isPinned,
setIsPinned,
placeholderText,
toggleServerSelection,
getServerStatusIconProps,
getConfigDialogProps,
} = useMCPServerManager();

const menuStore = Ariakit.useMenuStore({
focusLoop: true,
showTimeout: 100,
placement: 'right',
});
const menuStore = Ariakit.useMenuStore({
focusLoop: true,
showTimeout: 100,
placement: 'right',
});

// Don't render if no MCP servers are configured
if (!configuredServers || configuredServers.length === 0) {
return null;
}
// Don't render if no MCP servers are configured
if (!configuredServers || configuredServers.length === 0) {
return null;
}

const configDialogProps = getConfigDialogProps();
const configDialogProps = getConfigDialogProps();

return (
<div ref={ref}>
<Ariakit.MenuProvider store={menuStore}>
<Ariakit.MenuItem
{...props}
render={
<Ariakit.MenuButton
onClick={(e: React.MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
menuStore.toggle();
}}
className="flex w-full cursor-pointer items-center justify-between rounded-lg p-2 hover:bg-surface-hover"
/>
}
>
<div className="flex items-center gap-2">
<MCPIcon className="icon-md" />
<span>{placeholder || placeholderText}</span>
<ChevronRight className="ml-auto h-3 w-3" />
</div>
<button
type="button"
onClick={(e) => {
return (
<>
<Ariakit.MenuProvider store={menuStore}>
<Ariakit.MenuItem
{...props}
render={
<Ariakit.MenuButton
onClick={(e: React.MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
setIsPinned(!isPinned);
menuStore.toggle();
}}
className={cn(
'rounded p-1 transition-all duration-200',
'hover:bg-surface-tertiary hover:shadow-sm',
!isPinned && 'text-text-secondary hover:text-text-primary',
)}
aria-label={isPinned ? 'Unpin' : 'Pin'}
>
<div className="h-4 w-4">
<PinIcon unpin={isPinned} />
</div>
</button>
</Ariakit.MenuItem>
<Ariakit.Menu
portal={true}
unmountOnHide={true}
className="flex w-full cursor-pointer items-center justify-between rounded-lg p-2 hover:bg-surface-hover"
/>
}
>
<div className="flex items-center gap-2">
<MCPIcon className="icon-md" />
<span>{placeholder || placeholderText}</span>
<ChevronRight className="ml-auto h-3 w-3" />
</div>
<button
type="button"
onClick={(e) => {
e.stopPropagation();
setIsPinned(!isPinned);
}}
className={cn(
'animate-popover-left z-50 ml-3 flex min-w-[200px] flex-col rounded-xl',
'border border-border-light bg-surface-secondary p-1 shadow-lg',
'rounded p-1 transition-all duration-200',
'hover:bg-surface-tertiary hover:shadow-sm',
!isPinned && 'text-text-secondary hover:text-text-primary',
)}
aria-label={isPinned ? 'Unpin' : 'Pin'}
>
{configuredServers.map((serverName) => {
const statusIconProps = getServerStatusIconProps(serverName);
const isSelected = mcpValues?.includes(serverName) ?? false;
const isServerInitializing = isInitializing(serverName);
<div className="h-4 w-4">
<PinIcon unpin={isPinned} />
</div>
</button>
</Ariakit.MenuItem>
<Ariakit.Menu
portal={true}
unmountOnHide={true}
className={cn(
'animate-popover-left z-50 ml-3 flex min-w-[200px] flex-col rounded-xl',
'border border-border-light bg-surface-secondary p-1 shadow-lg',
)}
>
{configuredServers.map((serverName) => {
const statusIconProps = getServerStatusIconProps(serverName);
const isSelected = mcpValues?.includes(serverName) ?? false;

const statusIcon = statusIconProps && <MCPServerStatusIcon {...statusIconProps} />;
const statusIcon = statusIconProps && <MCPServerStatusIcon {...statusIconProps} />;

return (
<Ariakit.MenuItem
key={serverName}
onClick={(event) => {
event.preventDefault();
toggleServerSelection(serverName);
}}
disabled={isServerInitializing}
className={cn(
'flex items-center gap-2 rounded-lg px-2 py-1.5 text-text-primary hover:cursor-pointer',
'scroll-m-1 outline-none transition-colors',
'hover:bg-black/[0.075] dark:hover:bg-white/10',
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
'w-full min-w-0 justify-between text-sm',
isServerInitializing &&
'opacity-50 hover:bg-transparent dark:hover:bg-transparent',
)}
return (
<Ariakit.MenuItem
key={serverName}
onClick={(event) => {
event.preventDefault();
toggleServerSelection(serverName);
}}
className={cn(
'flex items-center gap-2 rounded-lg px-2 py-1.5 text-text-primary hover:cursor-pointer',
'scroll-m-1 outline-none transition-colors',
'hover:bg-black/[0.075] dark:hover:bg-white/10',
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
'w-full min-w-0 justify-between text-sm',
)}
>
<button
type="button"
className="flex flex-grow items-center gap-2 rounded bg-transparent p-0 text-left transition-colors focus:outline-none"
tabIndex={0}
>
<div className="flex flex-grow items-center gap-2">
<Ariakit.MenuItemCheck checked={isSelected} />
<span>{serverName}</span>
</div>
{statusIcon && <div className="ml-2 flex items-center">{statusIcon}</div>}
</Ariakit.MenuItem>
);
})}
</Ariakit.Menu>
</Ariakit.MenuProvider>
{configDialogProps && <MCPConfigDialog {...configDialogProps} />}
</div>
);
},
);

MCPSubMenu.displayName = 'MCPSubMenu';
<Ariakit.MenuItemCheck checked={isSelected} />
<span>{serverName}</span>
</button>
{statusIcon && <div className="ml-2 flex items-center">{statusIcon}</div>}
</Ariakit.MenuItem>
);
})}
</Ariakit.Menu>
</Ariakit.MenuProvider>
{configDialogProps && <MCPConfigDialog {...configDialogProps} />}
</>
);
};

export default React.memo(MCPSubMenu);

@@ -4,7 +4,7 @@ import { FileSources, LocalStorageKeys } from 'librechat-data-provider';
import type { ExtendedFile } from '~/common';
import { useDeleteFilesMutation } from '~/data-provider';
import DragDropWrapper from '~/components/Chat/Input/Files/DragDropWrapper';
import { EditorProvider, SidePanelProvider, ArtifactsProvider } from '~/Providers';
import { EditorProvider, SidePanelProvider } from '~/Providers';
import Artifacts from '~/components/Artifacts/Artifacts';
import { SidePanelGroup } from '~/components/SidePanel';
import { useSetFilesToDelete } from '~/hooks';
@@ -66,11 +66,9 @@ export default function Presentation({ children }: { children: React.ReactNode }
defaultCollapsed={defaultCollapsed}
artifacts={
artifactsVisibility === true && Object.keys(artifacts ?? {}).length > 0 ? (
<ArtifactsProvider>
<EditorProvider>
<Artifacts />
</EditorProvider>
</ArtifactsProvider>
<EditorProvider>
<Artifacts />
</EditorProvider>
) : null
}
>

@@ -25,7 +25,7 @@ type EndpointIcon = {

function getOpenAIColor(_model: string | null | undefined) {
const model = _model?.toLowerCase() ?? '';
if (model && (/\b(o\d)\b/i.test(model) || /\bgpt-[5-9]\b/i.test(model))) {
if (model && /\b(o\d)\b/i.test(model)) {
return '#000000';
}
return model.includes('gpt-4') ? '#AB68FF' : '#19C37D';

@@ -1,99 +0,0 @@
import React from 'react';
import { RefreshCw } from 'lucide-react';
import { Button, Spinner } from '@librechat/client';
import { useMCPServerManager } from '~/hooks/MCP/useMCPServerManager';
import { useLocalize } from '~/hooks';

interface ServerInitializationSectionProps {
sidePanel?: boolean;
serverName: string;
requiresOAuth: boolean;
hasCustomUserVars?: boolean;
}

export default function ServerInitializationSection({
sidePanel = false,
serverName,
requiresOAuth,
hasCustomUserVars = false,
}: ServerInitializationSectionProps) {
const localize = useLocalize();

const {
initializeServer,
connectionStatus,
cancelOAuthFlow,
isInitializing,
isCancellable,
getOAuthUrl,
} = useMCPServerManager();

const serverStatus = connectionStatus[serverName];
const isConnected = serverStatus?.connectionState === 'connected';
const canCancel = isCancellable(serverName);
const isServerInitializing = isInitializing(serverName);
const serverOAuthUrl = getOAuthUrl(serverName);

const shouldShowReinit = isConnected && (requiresOAuth || hasCustomUserVars);
const shouldShowInit = !isConnected && !serverOAuthUrl;

if (!shouldShowReinit && !shouldShowInit && !serverOAuthUrl) {
return null;
}

if (serverOAuthUrl) {
return (
<>
<div className="flex items-center gap-2">
<Button
onClick={() => cancelOAuthFlow(serverName)}
disabled={!canCancel}
variant="outline"
title={!canCancel ? 'disabled' : undefined}
>
{localize('com_ui_cancel')}
</Button>
<Button
variant="submit"
onClick={() => window.open(serverOAuthUrl, '_blank', 'noopener,noreferrer')}
className="flex-1"
>
{localize('com_ui_continue_oauth')}
</Button>
</div>
</>
);
}

// Unified button rendering
const isReinit = shouldShowReinit;
const outerClass = isReinit ? 'flex justify-start' : 'flex justify-end';
const buttonVariant = isReinit ? undefined : 'default';
const buttonText = isServerInitializing
? localize('com_ui_loading')
: isReinit
? localize('com_ui_reinitialize')
: requiresOAuth
? localize('com_ui_authenticate')
: localize('com_ui_mcp_initialize');
const icon = isServerInitializing ? (
<Spinner className="h-4 w-4" />
) : (
<RefreshCw className="h-4 w-4" />
);

return (
<div className={outerClass}>
<Button
variant={buttonVariant}
onClick={() => initializeServer(serverName, false)}
disabled={isServerInitializing}
size={sidePanel ? 'sm' : 'default'}
className="w-full"
>
{icon}
{buttonText}
</Button>
</div>
);
}
@@ -72,7 +72,7 @@ export default function ExportModal({

const { exportConversation } = useExportConversation({
conversation,
filename: filenamify(filename),
filename,
type,
includeOptions,
exportBranches,
@@ -95,7 +95,7 @@ export default function ExportModal({
<Input
id="filename"
value={filename}
onChange={(e) => setFileName(e.target.value || '')}
onChange={(e) => setFileName(filenamify(e.target.value || ''))}
placeholder={localize('com_nav_export_filename_placeholder')}
/>
</div>

@@ -105,8 +105,6 @@ export const LangSelector = ({
{ value: 'nl-NL', label: localize('com_nav_lang_dutch') },
{ value: 'id-ID', label: localize('com_nav_lang_indonesia') },
{ value: 'fi-FI', label: localize('com_nav_lang_finnish') },
{ value: 'bo', label: localize('com_nav_lang_tibetan') },
{ value: 'uk-UA', label: localize('com_nav_lang_ukrainian') },
];

return (

@@ -1,10 +1,8 @@
import React, { useCallback } from 'react';
import { Trash2 } from 'lucide-react';
import { useDeletePrompt } from '~/data-provider';
import { Button, OGDialog, OGDialogTrigger, Label, OGDialogTemplate } from '@librechat/client';
import { useLocalize } from '~/hooks';

const DeleteConfirmDialog = ({
const DeleteVersion = ({
name,
disabled,
selectHandler,
@@ -60,42 +58,4 @@ const DeleteConfirmDialog = ({
);
};

interface DeletePromptProps {
promptId?: string;
groupId: string;
promptName: string;
disabled: boolean;
}

const DeletePrompt = React.memo(
({ promptId, groupId, promptName, disabled }: DeletePromptProps) => {
const deletePromptMutation = useDeletePrompt();

const handleDelete = useCallback(() => {
if (!promptId) {
console.warn('No prompt ID provided for deletion');
return;
}
deletePromptMutation.mutate({
_id: promptId,
groupId,
});
}, [promptId, groupId, deletePromptMutation]);

if (!promptId) {
return null;
}

return (
<DeleteConfirmDialog
name={promptName}
disabled={disabled || !promptId}
selectHandler={handleDelete}
/>
);
},
);

DeletePrompt.displayName = 'DeletePrompt';

export default DeletePrompt;
export default DeleteVersion;

@@ -1,13 +1,10 @@
import React, { useMemo, useState } from 'react';
import * as Ariakit from '@ariakit/react';
import React, { useMemo } from 'react';
import { Dropdown } from '@librechat/client';
import { useTranslation } from 'react-i18next';
import { DropdownPopup } from '@librechat/client';
import { LocalStorageKeys } from 'librechat-data-provider';
import { useFormContext, Controller } from 'react-hook-form';
import type { MenuItemProps } from '@librechat/client';
import { LocalStorageKeys } from 'librechat-data-provider';
import type { ReactNode } from 'react';
import { useCategories } from '~/hooks';
import { cn } from '~/utils';

interface CategorySelectorProps {
currentCategory?: string;
@@ -23,11 +20,10 @@ const CategorySelector: React.FC<CategorySelectorProps> = ({
const { t } = useTranslation();
const formContext = useFormContext();
const { categories, emptyCategory } = useCategories();
const [isOpen, setIsOpen] = useState(false);

const control = formContext?.control;
const watch = formContext?.watch;
const setValue = formContext?.setValue;
const control = formContext.control;
const watch = formContext.watch;
const setValue = formContext.setValue;

const watchedCategory = watch ? watch('category') : currentCategory;

@@ -50,71 +46,53 @@ const CategorySelector: React.FC<CategorySelectorProps> = ({
return categoryOption;
}, [categoryOption, t]);

const menuItems: MenuItemProps[] = useMemo(() => {
if (!categories) return [];

return categories.map((category) => ({
id: category.value,
label: category.label,
icon: 'icon' in category ? category.icon : undefined,
onClick: () => {
const value = category.value || '';
if (formContext && setValue) {
setValue('category', value, { shouldDirty: false });
}
localStorage.setItem(LocalStorageKeys.LAST_PROMPT_CATEGORY, value);
onValueChange?.(value);
setIsOpen(false);
},
}));
}, [categories, formContext, setValue, onValueChange]);

const trigger = (
<Ariakit.MenuButton
className={cn(
'focus:ring-offset-ring-offset relative inline-flex items-center justify-between rounded-xl border border-input bg-background px-3 py-2 text-sm text-text-primary transition-all duration-200 ease-in-out hover:bg-accent hover:text-accent-foreground focus:ring-ring-primary',
'w-fit gap-2',
className,
)}
onClick={() => setIsOpen(!isOpen)}
aria-label="Prompt's category selector"
aria-labelledby="category-selector-label"
>
<div className="flex items-center space-x-2">
{'icon' in displayCategory && displayCategory.icon != null && (
<span>{displayCategory.icon as ReactNode}</span>
)}
<span>{displayCategory.value ? displayCategory.label : t('com_ui_category')}</span>
</div>
<Ariakit.MenuButtonArrow />
</Ariakit.MenuButton>
);

return formContext ? (
<Controller
name="category"
control={control}
render={() => (
<DropdownPopup
trigger={trigger}
items={menuItems}
isOpen={isOpen}
setIsOpen={setIsOpen}
menuId="category-selector-menu"
className="mt-2"
portal={true}
<Dropdown
value={displayCategory.value ?? ''}
label={displayCategory.value ? undefined : t('com_ui_category')}
onChange={(value: string) => {
setValue('category', value, { shouldDirty: false });
localStorage.setItem(LocalStorageKeys.LAST_PROMPT_CATEGORY, value);
onValueChange?.(value);
}}
aria-labelledby="category-selector-label"
ariaLabel="Prompt's category selector"
className={className}
options={categories || []}
renderValue={() => (
<div className="flex items-center space-x-2">
{'icon' in displayCategory && displayCategory.icon != null && (
<span>{displayCategory.icon as ReactNode}</span>
)}
<span>{displayCategory.label}</span>
</div>
)}
/>
)}
/>
) : (
<DropdownPopup
trigger={trigger}
items={menuItems}
isOpen={isOpen}
setIsOpen={setIsOpen}
menuId="category-selector-menu"
className="mt-2"
portal={true}
<Dropdown
value={currentCategory ?? ''}
onChange={(value: string) => {
localStorage.setItem(LocalStorageKeys.LAST_PROMPT_CATEGORY, value);
onValueChange?.(value);
}}
aria-labelledby="category-selector-label"
ariaLabel="Prompt's category selector"
className={className}
options={categories || []}
renderValue={() => (
<div className="flex items-center space-x-2">
{'icon' in displayCategory && displayCategory.icon != null && (
<span>{displayCategory.icon as ReactNode}</span>
)}
<span>{displayCategory.label}</span>
</div>
)}
/>
);
};

@@ -1,5 +1,4 @@
import { useEffect, useState, useMemo, useCallback, useRef } from 'react';
import React from 'react';
import debounce from 'lodash/debounce';
import { useRecoilValue } from 'recoil';
import { Menu, Rocket } from 'lucide-react';
@@ -7,13 +6,14 @@ import { useForm, FormProvider } from 'react-hook-form';
import { useParams, useOutletContext } from 'react-router-dom';
import { Button, Skeleton, useToastContext } from '@librechat/client';
import { SystemRoles, PermissionTypes, Permissions } from 'librechat-data-provider';
import type { TCreatePrompt, TPrompt, TPromptGroup } from 'librechat-data-provider';
import type { TCreatePrompt } from 'librechat-data-provider';
import {
useGetPrompts,
useCreatePrompt,
useGetPrompts,
useGetPromptGroup,
useUpdatePromptGroup,
useMakePromptProduction,
useDeletePrompt,
} from '~/data-provider';
import { useAuthContext, usePromptGroupsNav, useHasAccess, useLocalize } from '~/hooks';
import CategorySelector from './Groups/CategorySelector';
@@ -22,7 +22,7 @@ import PromptVariables from './PromptVariables';
import { cn, findPromptGroup } from '~/utils';
import PromptVersions from './PromptVersions';
import { PromptsEditorMode } from '~/common';
import DeleteVersion from './DeleteVersion';
import DeleteConfirm from './DeleteVersion';
import PromptDetails from './PromptDetails';
import PromptEditor from './PromptEditor';
import SkeletonForm from './SkeletonForm';
@@ -32,136 +32,16 @@ import PromptName from './PromptName';
import Command from './Command';
import store from '~/store';

interface RightPanelProps {
group: TPromptGroup;
prompts: TPrompt[];
selectedPrompt: any;
selectionIndex: number;
selectedPromptId?: string;
isLoadingPrompts: boolean;
setSelectionIndex: React.Dispatch<React.SetStateAction<number>>;
}

const RightPanel = React.memo(
({
group,
prompts,
selectedPrompt,
selectedPromptId,
isLoadingPrompts,
selectionIndex,
setSelectionIndex,
}: RightPanelProps) => {
const localize = useLocalize();
const { showToast } = useToastContext();
const editorMode = useRecoilValue(store.promptsEditorMode);
const hasShareAccess = useHasAccess({
permissionType: PermissionTypes.PROMPTS,
permission: Permissions.SHARED_GLOBAL,
});

const updateGroupMutation = useUpdatePromptGroup({
onError: () => {
showToast({
status: 'error',
message: localize('com_ui_prompt_update_error'),
});
},
});

const makeProductionMutation = useMakePromptProduction();

const groupId = group?._id || '';
const groupName = group?.name || '';
const groupCategory = group?.category || '';
const isLoadingGroup = !group;

return (
<div
className="h-full w-full overflow-y-auto bg-surface-primary px-4"
style={{ maxHeight: 'calc(100vh - 100px)' }}
>
<div className="mb-2 flex flex-col lg:flex-row lg:items-center lg:justify-center lg:gap-x-2 xl:flex-row xl:space-y-0">
<CategorySelector
currentCategory={groupCategory}
onValueChange={(value) =>
updateGroupMutation.mutate({
id: groupId,
payload: { name: groupName, category: value },
})
}
/>
<div className="mt-2 flex flex-row items-center justify-center gap-x-2 lg:mt-0">
{hasShareAccess && <SharePrompt group={group} disabled={isLoadingGroup} />}
{editorMode === PromptsEditorMode.ADVANCED && (
<Button
variant="submit"
size="sm"
aria-label="Make prompt production"
className="h-10 w-10 border border-transparent p-0.5 transition-all"
onClick={() => {
if (!selectedPrompt) {
console.warn('No prompt is selected');
return;
}
const { _id: promptVersionId = '', prompt } = selectedPrompt;
makeProductionMutation.mutate({
id: promptVersionId,
groupId,
productionPrompt: { prompt },
});
}}
disabled={
isLoadingGroup ||
!selectedPrompt ||
selectedPrompt._id === group?.productionId ||
makeProductionMutation.isLoading
}
>
<Rocket className="size-5 cursor-pointer text-white" />
</Button>
)}
<DeleteVersion
promptId={selectedPromptId}
groupId={groupId}
promptName={groupName}
disabled={isLoadingGroup}
/>
</div>
</div>
{editorMode === PromptsEditorMode.ADVANCED &&
(isLoadingPrompts
? Array.from({ length: 6 }).map((_, index: number) => (
<div key={index} className="my-2">
<Skeleton className="h-[72px] w-full" />
</div>
))
: prompts.length > 0 && (
<PromptVersions
group={group}
prompts={prompts}
selectionIndex={selectionIndex}
setSelectionIndex={setSelectionIndex}
/>
))}
</div>
);
},
);

RightPanel.displayName = 'RightPanel';

const PromptForm = () => {
const params = useParams();
const localize = useLocalize();
const { user } = useAuthContext();
const { showToast } = useToastContext();
const alwaysMakeProd = useRecoilValue(store.alwaysMakeProd);
const { showToast } = useToastContext();
const promptId = params.promptId || '';

const editorMode = useRecoilValue(store.promptsEditorMode);
const [selectionIndex, setSelectionIndex] = useState<number>(0);

const editorMode = useRecoilValue(store.promptsEditorMode);
const prevIsEditingRef = useRef(false);
const [isEditing, setIsEditing] = useState(false);
const [initialLoad, setInitialLoad] = useState(true);
@@ -192,9 +72,11 @@ const PromptForm = () => {
[prompts, selectionIndex],
);

const selectedPromptId = useMemo(() => selectedPrompt?._id, [selectedPrompt?._id]);

const { groupsQuery } = useOutletContext<ReturnType<typeof usePromptGroupsNav>>();
const hasShareAccess = useHasAccess({
permissionType: PermissionTypes.PROMPTS,
permission: Permissions.SHARED_GLOBAL,
});

const updateGroupMutation = useUpdatePromptGroup({
onError: () => {
@@ -206,6 +88,7 @@ const PromptForm = () => {
});

const makeProductionMutation = useMakePromptProduction();
const deletePromptMutation = useDeletePrompt();

const createPromptMutation = useCreatePrompt({
onMutate: (variables) => {
@@ -294,40 +177,24 @@ const PromptForm = () => {
return () => window.removeEventListener('resize', handleResize);
}, []);

const debouncedUpdateOneliner = useMemo(
() =>
debounce((groupId: string, oneliner: string, mutate: any) => {
mutate({ id: groupId, payload: { oneliner } });
}, 950),
[],
);

const debouncedUpdateCommand = useMemo(
() =>
debounce((groupId: string, command: string, mutate: any) => {
mutate({ id: groupId, payload: { command } });
}, 950),
[],
);

const handleUpdateOneliner = useCallback(
(oneliner: string) => {
const debouncedUpdateOneliner = useCallback(
debounce((oneliner: string) => {
if (!group || !group._id) {
return console.warn('Group not found');
}
debouncedUpdateOneliner(group._id, oneliner, updateGroupMutation.mutate);
},
[group, updateGroupMutation.mutate, debouncedUpdateOneliner],
updateGroupMutation.mutate({ id: group._id, payload: { oneliner } });
}, 950),
[updateGroupMutation, group],
);

const handleUpdateCommand = useCallback(
(command: string) => {
const debouncedUpdateCommand = useCallback(
debounce((command: string) => {
if (!group || !group._id) {
return console.warn('Group not found');
}
debouncedUpdateCommand(group._id, command, updateGroupMutation.mutate);
},
[group, updateGroupMutation.mutate, debouncedUpdateCommand],
updateGroupMutation.mutate({ id: group._id, payload: { command } });
}, 950),
[updateGroupMutation, group],
);

if (initialLoad) {
@@ -350,7 +217,89 @@ const PromptForm = () => {
return null;
}

const groupId = group._id;

const groupName = group.name;
const groupCategory = group.category;

const RightPanel = () => (
<div
className="h-full w-full overflow-y-auto bg-surface-primary px-4"
style={{ maxHeight: 'calc(100vh - 100px)' }}
>
<div className="mb-2 flex flex-col lg:flex-row lg:items-center lg:justify-center lg:gap-x-2 xl:flex-row xl:space-y-0">
<CategorySelector
currentCategory={groupCategory}
onValueChange={(value) =>
updateGroupMutation.mutate({
id: groupId,
payload: { name: groupName, category: value },
})
}
/>
<div className="mt-2 flex flex-row items-center justify-center gap-x-2 lg:mt-0">
{hasShareAccess && <SharePrompt group={group} disabled={isLoadingGroup} />}
{editorMode === PromptsEditorMode.ADVANCED && (
<Button
variant="submit"
size="sm"
aria-label="Make prompt production"
className="h-10 w-10 border border-transparent p-0.5 transition-all"
onClick={() => {
if (!selectedPrompt) {
console.warn('No prompt is selected');
return;
}
const { _id: promptVersionId = '', prompt } = selectedPrompt;
makeProductionMutation.mutate({
id: promptVersionId,
groupId,
productionPrompt: { prompt },
});
}}
disabled={
isLoadingGroup ||
!selectedPrompt ||
selectedPrompt._id === group.productionId ||
makeProductionMutation.isLoading
}
>
<Rocket className="size-5 cursor-pointer text-white" />
</Button>
)}
<DeleteConfirm
name={groupName}
disabled={isLoadingGroup}
selectHandler={() => {
if (!selectedPrompt || !selectedPrompt._id) {
console.warn('No prompt is selected or prompt _id is missing');
return;
}
deletePromptMutation.mutate({
_id: selectedPrompt._id,
groupId,
});
}}
/>
</div>
</div>
{editorMode === PromptsEditorMode.ADVANCED &&
(isLoadingPrompts
? Array.from({ length: 6 }).map((_, index: number) => (
<div key={index} className="my-2">
<Skeleton className="h-[72px] w-full" />
</div>
))
: prompts.length > 0 && (
<PromptVersions
group={group}
prompts={prompts}
selectionIndex={selectionIndex}
setSelectionIndex={setSelectionIndex}
/>
))}
</div>
);

return (
<FormProvider {...methods}>
@@ -390,17 +339,7 @@ const PromptForm = () => {
<Menu className="size-5" />
</Button>
<div className="hidden lg:block">
{editorMode === PromptsEditorMode.SIMPLE && (
<RightPanel
group={group}
prompts={prompts}
selectedPrompt={selectedPrompt}
selectionIndex={selectionIndex}
selectedPromptId={selectedPromptId}
isLoadingPrompts={isLoadingPrompts}
setSelectionIndex={setSelectionIndex}
/>
)}
{editorMode === PromptsEditorMode.SIMPLE && <RightPanel />}
</div>
</>
)}
@@ -413,11 +352,11 @@ const PromptForm = () => {
<PromptVariables promptText={promptText} />
<Description
initialValue={group.oneliner ?? ''}
onValueChange={handleUpdateOneliner}
onValueChange={debouncedUpdateOneliner}
/>
<Command
initialValue={group.command ?? ''}
onValueChange={handleUpdateCommand}
onValueChange={debouncedUpdateCommand}
/>
</div>
)}
@@ -425,15 +364,7 @@ const PromptForm = () => {

{editorMode === PromptsEditorMode.ADVANCED && (
<div className="hidden w-1/4 border-l border-border-light lg:block">
<RightPanel
group={group}
prompts={prompts}
selectionIndex={selectionIndex}
selectedPrompt={selectedPrompt}
selectedPromptId={selectedPromptId}
isLoadingPrompts={isLoadingPrompts}
setSelectionIndex={setSelectionIndex}
/>
<RightPanel />
</div>
)}
</div>
@@ -464,15 +395,7 @@ const PromptForm = () => {
>
<div className="h-full">
<div className="h-full overflow-auto">
<RightPanel
group={group}
prompts={prompts}
selectionIndex={selectionIndex}
selectedPrompt={selectedPrompt}
selectedPromptId={selectedPromptId}
isLoadingPrompts={isLoadingPrompts}
setSelectionIndex={setSelectionIndex}
/>
<RightPanel />
</div>
</div>
</div>

@@ -1,380 +0,0 @@
/**
* @jest-environment jsdom
*/
import * as React from 'react';
import { describe, it, expect, beforeEach, jest } from '@jest/globals';
import { render, waitFor, fireEvent } from '@testing-library/react';
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
import type { Agent } from 'librechat-data-provider';

// Mock toast context - define this after all mocks
let mockShowToast: jest.Mock;

// Mock notification severity enum before other imports
jest.mock('~/common/types', () => ({
NotificationSeverity: {
SUCCESS: 'success',
ERROR: 'error',
INFO: 'info',
WARNING: 'warning',
},
}));

// Mock store to prevent import errors
jest.mock('~/store/toast', () => ({
default: () => ({
showToast: jest.fn(),
}),
}));

jest.mock('~/store', () => {});

// Mock the data service to control network responses
jest.mock('librechat-data-provider', () => {
const actualModule = jest.requireActual('librechat-data-provider') as any;
return {
...actualModule,
dataService: {
updateAgent: jest.fn(),
},
Tools: {
execute_code: 'execute_code',
file_search: 'file_search',
web_search: 'web_search',
},
Constants: {
EPHEMERAL_AGENT_ID: 'ephemeral',
},
SystemRoles: {
ADMIN: 'ADMIN',
},
EModelEndpoint: {
agents: 'agents',
chatGPTBrowser: 'chatGPTBrowser',
gptPlugins: 'gptPlugins',
},
isAssistantsEndpoint: jest.fn(() => false),
};
});

jest.mock('@librechat/client', () => ({
Button: ({ children, onClick, ...props }: any) => (
<button onClick={onClick} {...props}>
{children}
</button>
),
useToastContext: () => ({
get showToast() {
return mockShowToast || jest.fn();
},
}),
}));

// Mock other dependencies
jest.mock('librechat-data-provider/react-query', () => ({
useGetModelsQuery: () => ({ data: {} }),
}));

jest.mock('~/utils', () => ({
createProviderOption: jest.fn((provider: string) => ({ value: provider, label: provider })),
getDefaultAgentFormValues: jest.fn(() => ({
id: '',
name: '',
description: '',
model: '',
provider: '',
})),
}));

jest.mock('~/hooks', () => ({
useSelectAgent: () => ({ onSelect: jest.fn() }),
useLocalize: () => (key: string) => key,
useAuthContext: () => ({ user: { id: 'user-123', role: 'USER' } }),
}));

jest.mock('~/Providers/AgentPanelContext', () => ({
useAgentPanelContext: () => ({
activePanel: 'builder',
agentsConfig: { allowedProviders: [] },
setActivePanel: jest.fn(),
endpointsConfig: {},
setCurrentAgentId: jest.fn(),
agent_id: 'agent-123',
}),
}));

jest.mock('~/common', () => ({
Panel: {
model: 'model',
builder: 'builder',
advanced: 'advanced',
},
}));

// Mock child components to simplify testing
jest.mock('./AgentPanelSkeleton', () => ({
__esModule: true,
default: () => <div>{`Loading...`}</div>,
}));

jest.mock('./Advanced/AdvancedPanel', () => ({
__esModule: true,
default: () => <div>{`Advanced Panel`}</div>,
}));

jest.mock('./AgentConfig', () => ({
__esModule: true,
default: () => <div>{`Agent Config`}</div>,
}));

jest.mock('./AgentSelect', () => ({
__esModule: true,
default: () => <div>{`Agent Select`}</div>,
}));

jest.mock('./ModelPanel', () => ({
__esModule: true,
default: () => <div>{`Model Panel`}</div>,
}));

// Mock AgentFooter to provide a save button
jest.mock('./AgentFooter', () => ({
__esModule: true,
default: () => (
<button type="submit" data-testid="save-agent-button">
{`Save Agent`}
</button>
),
}));

// Mock react-hook-form to capture form submission
let mockFormSubmitHandler: (() => void) | null = null;

jest.mock('react-hook-form', () => {
const actual = jest.requireActual('react-hook-form') as any;
return {
...actual,
useForm: () => {
const methods = actual.useForm({
defaultValues: {
id: 'agent-123',
name: 'Test Agent',
description: 'Test description',
model: 'gpt-4',
provider: 'openai',
tools: [],
execute_code: false,
file_search: false,
web_search: false,
},
});

return {
...methods,
handleSubmit: (onSubmit: any) => (e?: any) => {
e?.preventDefault?.();
mockFormSubmitHandler = () => onSubmit(methods.getValues());
return mockFormSubmitHandler;
},
};
},
FormProvider: ({ children }: any) => children,
useWatch: () => 'agent-123',
};
});

// Import after mocks
import { dataService } from 'librechat-data-provider';
import { useGetAgentByIdQuery } from '~/data-provider';
import AgentPanel from './AgentPanel';

// Mock useGetAgentByIdQuery
jest.mock('~/data-provider', () => {
const actual = jest.requireActual('~/data-provider') as any;
return {
...actual,
useGetAgentByIdQuery: jest.fn(),
useUpdateAgentMutation: actual.useUpdateAgentMutation,
};
});

// Test wrapper with QueryClient
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
mutations: { retry: false },
},
});

return ({ children }: { children: React.ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
};

// Test helpers
const setupMocks = () => {
const mockUseGetAgentByIdQuery = useGetAgentByIdQuery as jest.MockedFunction<
typeof useGetAgentByIdQuery
>;
const mockUpdateAgent = dataService.updateAgent as jest.MockedFunction<
typeof dataService.updateAgent
>;

return { mockUseGetAgentByIdQuery, mockUpdateAgent };
};

const mockAgentQuery = (
mockUseGetAgentByIdQuery: jest.MockedFunction<typeof useGetAgentByIdQuery>,
agent: Partial<Agent>,
) => {
mockUseGetAgentByIdQuery.mockReturnValue({
data: {
id: 'agent-123',
author: 'user-123',
isCollaborative: false,
...agent,
} as Agent,
isInitialLoading: false,
} as any);
};

const createMockAgent = (overrides: Partial<Agent> = {}): Agent =>
({
id: 'agent-123',
provider: 'openai',
model: 'gpt-4',
...overrides,
}) as Agent;

const renderAndSubmitForm = async () => {
const Wrapper = createWrapper();
const { container, rerender } = render(<AgentPanel />, { wrapper: Wrapper });

const form = container.querySelector('form');
expect(form).toBeTruthy();

fireEvent.submit(form!);

if (mockFormSubmitHandler) {
mockFormSubmitHandler();
}

return { container, rerender, form };
};

describe('AgentPanel - Update Agent Toast Messages', () => {
beforeEach(() => {
jest.clearAllMocks();
mockShowToast = jest.fn();
mockFormSubmitHandler = null;
});

describe('AgentPanel', () => {
it('should show "no changes" toast when version does not change', async () => {
const { mockUseGetAgentByIdQuery, mockUpdateAgent } = setupMocks();

// Mock the agent query with version 2
mockAgentQuery(mockUseGetAgentByIdQuery, {
name: 'Test Agent',
version: 2,
});

// Mock network response - same version
mockUpdateAgent.mockResolvedValue(createMockAgent({ name: 'Test Agent', version: 2 }));

await renderAndSubmitForm();

// Wait for the toast to be shown
await waitFor(() => {
expect(mockShowToast).toHaveBeenCalledWith({
message: 'com_ui_no_changes',
status: 'info',
});
});
});

it('should show "update success" toast when version changes', async () => {
const { mockUseGetAgentByIdQuery, mockUpdateAgent } = setupMocks();

// Mock the agent query with version 2
mockAgentQuery(mockUseGetAgentByIdQuery, {
name: 'Test Agent',
version: 2,
});

// Mock network response - different version
mockUpdateAgent.mockResolvedValue(createMockAgent({ name: 'Test Agent', version: 3 }));

await renderAndSubmitForm();

await waitFor(() => {
expect(mockShowToast).toHaveBeenCalledWith({
message: 'com_assistants_update_success Test Agent',
});
});
});

it('should show "update success" with default name when agent has no name', async () => {
const { mockUseGetAgentByIdQuery, mockUpdateAgent } = setupMocks();

// Mock the agent query without name
mockAgentQuery(mockUseGetAgentByIdQuery, {
version: 1,
});

// Mock network response - no name
mockUpdateAgent.mockResolvedValue(createMockAgent({ version: 2 }));

await renderAndSubmitForm();

await waitFor(() => {
expect(mockShowToast).toHaveBeenCalledWith({
message: 'com_assistants_update_success com_ui_agent',
});
});
});

it('should show "update success" when agent query has no version (undefined)', async () => {
const { mockUseGetAgentByIdQuery, mockUpdateAgent } = setupMocks();

// Mock the agent query with no version data
mockAgentQuery(mockUseGetAgentByIdQuery, {
name: 'Test Agent',
// No version property
});

mockUpdateAgent.mockResolvedValue(createMockAgent({ name: 'Test Agent', version: 1 }));

await renderAndSubmitForm();

await waitFor(() => {
expect(mockShowToast).toHaveBeenCalledWith({
message: 'com_assistants_update_success Test Agent',
});
});
});

it('should show error toast on update failure', async () => {
const { mockUseGetAgentByIdQuery, mockUpdateAgent } = setupMocks();

// Mock the agent query
mockAgentQuery(mockUseGetAgentByIdQuery, {
name: 'Test Agent',
version: 1,
});

// Mock network error
mockUpdateAgent.mockRejectedValue(new Error('Update failed'));

await renderAndSubmitForm();

await waitFor(() => {
expect(mockShowToast).toHaveBeenCalledWith({
message: 'com_agents_update_error com_ui_error: Update failed',
status: 'error',
});
});
});
});
});
@@ -1,5 +1,5 @@
|
||||
import { Plus } from 'lucide-react';
|
||||
import React, { useMemo, useCallback, useRef } from 'react';
|
||||
import React, { useMemo, useCallback } from 'react';
|
||||
import { Button, useToastContext } from '@librechat/client';
|
||||
import { useWatch, useForm, FormProvider } from 'react-hook-form';
|
||||
import { useGetModelsQuery } from 'librechat-data-provider/react-query';
|
||||
@@ -54,7 +54,6 @@ export default function AgentPanel() {
|
||||
|
||||
const { control, handleSubmit, reset } = methods;
|
||||
const agent_id = useWatch({ control, name: 'id' });
|
||||
const previousVersionRef = useRef<number | undefined>();
|
||||
|
||||
const allowedProviders = useMemo(
|
||||
() => new Set(agentsConfig?.allowedProviders),
|
||||
@@ -78,29 +77,50 @@ export default function AgentPanel() {
|
||||
|
||||
/* Mutations */
|
||||
const update = useUpdateAgentMutation({
|
||||
onMutate: () => {
|
||||
// Store the current version before mutation
|
||||
previousVersionRef.current = agentQuery.data?.version;
|
||||
},
|
||||
onSuccess: (data) => {
|
||||
// Check if agent version is the same (no changes were made)
|
||||
if (previousVersionRef.current !== undefined && data.version === previousVersionRef.current) {
|
||||
showToast({
|
||||
message: localize('com_ui_no_changes'),
|
||||
status: 'info',
|
||||
});
|
||||
} else {
|
||||
showToast({
|
||||
message: `${localize('com_assistants_update_success')} ${
|
||||
data.name ?? localize('com_ui_agent')
|
||||
}`,
|
||||
});
|
||||
}
|
||||
// Clear the ref after use
|
||||
previousVersionRef.current = undefined;
|
||||
showToast({
|
||||
message: `${localize('com_assistants_update_success')} ${
|
||||
data.name ?? localize('com_ui_agent')
|
||||
}`,
|
||||
});
|
||||
},
|
||||
onError: (err) => {
|
||||
const error = err as Error;
|
||||
const error = err as Error & {
|
||||
statusCode?: number;
|
||||
details?: { duplicateVersion?: any; versionIndex?: number };
|
||||
response?: { status?: number; data?: any };
|
||||
};
|
||||
|
||||
const isDuplicateVersionError =
|
||||
(error.statusCode === 409 && error.details?.duplicateVersion) ||
|
||||
(error.response?.status === 409 && error.response?.data?.details?.duplicateVersion);
|
||||
|
||||
if (isDuplicateVersionError) {
|
||||
let versionIndex: number | undefined = undefined;
|
||||
|
||||
if (error.details?.versionIndex !== undefined) {
|
||||
versionIndex = error.details.versionIndex;
|
||||
} else if (error.response?.data?.details?.versionIndex !== undefined) {
|
||||
versionIndex = error.response.data.details.versionIndex;
|
||||
}
|
||||
|
||||
if (versionIndex === undefined || versionIndex < 0) {
|
||||
showToast({
|
||||
message: localize('com_agents_update_error'),
|
||||
status: 'error',
|
||||
duration: 5000,
|
||||
});
|
||||
} else {
|
||||
showToast({
|
||||
message: localize('com_ui_agent_version_duplicate', { versionIndex: versionIndex + 1 }),
|
||||
status: 'error',
|
||||
duration: 10000,
|
||||
});
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
showToast({
|
||||
message: `${localize('com_agents_update_error')}${
|
||||
error.message ? ` ${localize('com_ui_error')}: ${error.message}` : ''
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
/* eslint-disable react-hooks/rules-of-hooks */
|
||||
import { ArrowUpDown } from 'lucide-react';
|
||||
import { Button } from '@librechat/client';
|
||||
import type { ColumnDef } from '@tanstack/react-table';
|
||||
|
||||
@@ -5,15 +5,14 @@ import { Button, useToastContext } from '@librechat/client';
|
||||
import { Constants, QueryKeys } from 'librechat-data-provider';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TUpdateUserPlugins } from 'librechat-data-provider';
|
||||
import ServerInitializationSection from '~/components/MCP/ServerInitializationSection';
|
||||
import ServerInitializationSection from '~/components/ui/MCP/ServerInitializationSection';
|
||||
import CustomUserVarsSection from '~/components/ui/MCP/CustomUserVarsSection';
|
||||
import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries';
|
||||
import CustomUserVarsSection from '~/components/MCP/CustomUserVarsSection';
|
||||
import BadgeRowProvider from '~/Providers/BadgeRowContext';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import MCPPanelSkeleton from './MCPPanelSkeleton';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
function MCPPanelContent() {
|
||||
export default function MCPPanel() {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const queryClient = useQueryClient();
|
||||
@@ -127,45 +126,50 @@ function MCPPanelContent() {
const serverStatus = connectionStatus[selectedServerNameForEditing];

return (
<div className="h-auto max-w-full space-y-4 overflow-x-hidden py-2">
<Button variant="outline" onClick={handleGoBackToList} size="sm">
<div className="h-auto max-w-full overflow-x-hidden p-3">
<Button
variant="outline"
onClick={handleGoBackToList}
className="mb-3 flex items-center px-3 py-2 text-sm"
>
<ChevronLeft className="mr-1 h-4 w-4" />
{localize('com_ui_back')}
</Button>

<h3 className="mb-3 text-lg font-medium">
{localize('com_sidepanel_mcp_variables_for', { '0': serverBeingEdited.serverName })}
</h3>

{/* Server Initialization Section */}
<div className="mb-4">
<CustomUserVarsSection
<ServerInitializationSection
serverName={selectedServerNameForEditing}
fields={serverBeingEdited.config.customUserVars}
onSave={(authData) => {
if (selectedServerNameForEditing) {
handleConfigSave(selectedServerNameForEditing, authData);
}
}}
onRevoke={() => {
if (selectedServerNameForEditing) {
handleConfigRevoke(selectedServerNameForEditing);
}
}}
isSubmitting={updateUserPluginsMutation.isLoading}
requiresOAuth={serverStatus?.requiresOAuth || false}
/>
</div>

<ServerInitializationSection
sidePanel={true}
{/* Custom User Variables Section */}
<CustomUserVarsSection
serverName={selectedServerNameForEditing}
requiresOAuth={serverStatus?.requiresOAuth || false}
hasCustomUserVars={
serverBeingEdited.config.customUserVars &&
Object.keys(serverBeingEdited.config.customUserVars).length > 0
}
fields={serverBeingEdited.config.customUserVars}
onSave={(authData) => {
if (selectedServerNameForEditing) {
handleConfigSave(selectedServerNameForEditing, authData);
}
}}
onRevoke={() => {
if (selectedServerNameForEditing) {
handleConfigRevoke(selectedServerNameForEditing);
}
}}
isSubmitting={updateUserPluginsMutation.isLoading}
/>
</div>
);
} else {
// Server List View
return (
<div className="h-auto max-w-full overflow-x-hidden py-2">
<div className="h-auto max-w-full overflow-x-hidden p-3">
<div className="space-y-2">
{mcpServerDefinitions.map((server) => {
const serverStatus = connectionStatus[server.serverName];
@@ -182,7 +186,7 @@ function MCPPanelContent() {
<span>{server.serverName}</span>
{serverStatus && (
<span
className={`rounded-xl px-2 py-0.5 text-xs ${
className={`rounded px-2 py-0.5 text-xs ${
isConnected
? 'bg-green-100 text-green-700 dark:bg-green-900 dark:text-green-300'
: 'bg-gray-100 text-gray-700 dark:bg-gray-800 dark:text-gray-300'
@@ -201,11 +205,3 @@ function MCPPanelContent() {
);
}
}

export default function MCPPanel() {
return (
<BadgeRowProvider>
<MCPPanelContent />
</BadgeRowProvider>
);
}

@@ -4,18 +4,18 @@ import { Plus } from 'lucide-react';
import { matchSorter } from 'match-sorter';
import { SystemRoles, PermissionTypes, Permissions } from 'librechat-data-provider';
import {
Spinner,
EditIcon,
TrashIcon,
Table,
Input,
Label,
Button,
Switch,
Spinner,
TableRow,
OGDialog,
EditIcon,
TableHead,
TableBody,
TrashIcon,
TableCell,
TableHeader,
TooltipAnchor,
@@ -25,10 +25,10 @@ import {
} from '@librechat/client';
import type { TUserMemory } from 'librechat-data-provider';
import {
useUpdateMemoryPreferencesMutation,
useDeleteMemoryMutation,
useMemoriesQuery,
useGetUserQuery,
useMemoriesQuery,
useDeleteMemoryMutation,
useUpdateMemoryPreferencesMutation,
} from '~/data-provider';
import { useLocalize, useAuthContext, useHasAccess } from '~/hooks';
import MemoryCreateDialog from './MemoryCreateDialog';
@@ -36,114 +36,18 @@ import MemoryEditDialog from './MemoryEditDialog';
import AdminSettings from './AdminSettings';
import { cn } from '~/utils';

const EditMemoryButton = ({ memory }: { memory: TUserMemory }) => {
const localize = useLocalize();
const [open, setOpen] = useState(false);
const triggerRef = useRef<HTMLDivElement>(null);

return (
<MemoryEditDialog
open={open}
memory={memory}
onOpenChange={setOpen}
triggerRef={triggerRef as React.MutableRefObject<HTMLButtonElement | null>}
>
<OGDialogTrigger asChild>
<TooltipAnchor
description={localize('com_ui_edit_memory')}
render={
<Button
variant="ghost"
aria-label={localize('com_ui_bookmarks_edit')}
onClick={() => setOpen(!open)}
className="h-8 w-8 p-0"
>
<EditIcon />
</Button>
}
/>
</OGDialogTrigger>
</MemoryEditDialog>
);
};

const DeleteMemoryButton = ({ memory }: { memory: TUserMemory }) => {
const localize = useLocalize();
const { showToast } = useToastContext();
const [open, setOpen] = useState(false);
const { mutate: deleteMemory } = useDeleteMemoryMutation();
const [deletingKey, setDeletingKey] = useState<string | null>(null);

const confirmDelete = async () => {
setDeletingKey(memory.key);
deleteMemory(memory.key, {
onSuccess: () => {
showToast({
message: localize('com_ui_deleted'),
status: 'success',
});
setOpen(false);
},
onError: () =>
showToast({
message: localize('com_ui_error'),
status: 'error',
}),
onSettled: () => setDeletingKey(null),
});
};

return (
<OGDialog open={open} onOpenChange={setOpen}>
<OGDialogTrigger asChild>
<TooltipAnchor
description={localize('com_ui_delete_memory')}
render={
<Button
variant="ghost"
aria-label={localize('com_ui_delete')}
onClick={() => setOpen(!open)}
className="h-8 w-8 p-0"
>
{deletingKey === memory.key ? (
<Spinner className="size-4 animate-spin" />
) : (
<TrashIcon className="size-4" />
)}
</Button>
}
/>
</OGDialogTrigger>
<OGDialogTemplate
showCloseButton={false}
title={localize('com_ui_delete_memory')}
className="w-11/12 max-w-lg"
main={
<Label className="text-left text-sm font-medium">
{localize('com_ui_delete_confirm')} "{memory.key}"?
</Label>
}
selection={{
selectHandler: confirmDelete,
selectClasses:
'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 text-white',
selectText: localize('com_ui_delete'),
}}
/>
</OGDialog>
);
};

const pageSize = 10;
export default function MemoryViewer() {
const localize = useLocalize();
const { user } = useAuthContext();
const { data: userData } = useGetUserQuery();
const { data: memData, isLoading } = useMemoriesQuery();
const { mutate: deleteMemory } = useDeleteMemoryMutation();
const { showToast } = useToastContext();
const [pageIndex, setPageIndex] = useState(0);
const [searchQuery, setSearchQuery] = useState('');
const pageSize = 10;
const [createDialogOpen, setCreateDialogOpen] = useState(false);
const [deletingKey, setDeletingKey] = useState<string | null>(null);
const [referenceSavedMemories, setReferenceSavedMemories] = useState(true);

const updateMemoryPreferencesMutation = useUpdateMemoryPreferencesMutation({
@@ -215,6 +119,108 @@ export default function MemoryViewer() {
|
||||
return 'stroke-green-500';
|
||||
};
|
||||
|
||||
const EditMemoryButton = ({ memory }: { memory: TUserMemory }) => {
|
||||
const [open, setOpen] = useState(false);
|
||||
const triggerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Only show edit button if user has UPDATE permission
|
||||
if (!hasUpdateAccess) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<MemoryEditDialog
|
||||
open={open}
|
||||
memory={memory}
|
||||
onOpenChange={setOpen}
|
||||
triggerRef={triggerRef as React.MutableRefObject<HTMLButtonElement | null>}
|
||||
>
|
||||
<OGDialogTrigger asChild>
|
||||
<TooltipAnchor
|
||||
description={localize('com_ui_edit_memory')}
|
||||
render={
|
||||
<Button
|
||||
variant="ghost"
|
||||
aria-label={localize('com_ui_bookmarks_edit')}
|
||||
onClick={() => setOpen(!open)}
|
||||
className="h-8 w-8 p-0"
|
||||
>
|
||||
<EditIcon />
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
</OGDialogTrigger>
|
||||
</MemoryEditDialog>
|
||||
);
|
||||
};
|
||||
|
||||
const DeleteMemoryButton = ({ memory }: { memory: TUserMemory }) => {
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
if (!hasUpdateAccess) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const confirmDelete = async () => {
|
||||
setDeletingKey(memory.key);
|
||||
deleteMemory(memory.key, {
|
||||
onSuccess: () => {
|
||||
showToast({
|
||||
message: localize('com_ui_deleted'),
|
||||
status: 'success',
|
||||
});
|
||||
setOpen(false);
|
||||
},
|
||||
onError: () =>
|
||||
showToast({
|
||||
message: localize('com_ui_error'),
|
||||
status: 'error',
|
||||
}),
|
||||
onSettled: () => setDeletingKey(null),
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<OGDialog open={open} onOpenChange={setOpen}>
|
||||
<OGDialogTrigger asChild>
|
||||
<TooltipAnchor
|
||||
description={localize('com_ui_delete_memory')}
|
||||
render={
|
||||
<Button
|
||||
variant="ghost"
|
||||
aria-label={localize('com_ui_delete')}
|
||||
onClick={() => setOpen(!open)}
|
||||
className="h-8 w-8 p-0"
|
||||
>
|
||||
{deletingKey === memory.key ? (
|
||||
<Spinner className="size-4 animate-spin" />
|
||||
) : (
|
||||
<TrashIcon className="size-4" />
|
||||
)}
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
</OGDialogTrigger>
|
||||
<OGDialogTemplate
|
||||
showCloseButton={false}
|
||||
title={localize('com_ui_delete_memory')}
|
||||
className="w-11/12 max-w-lg"
|
||||
main={
|
||||
<Label className="text-left text-sm font-medium">
|
||||
{localize('com_ui_delete_confirm')} "{memory.key}"?
|
||||
</Label>
|
||||
}
|
||||
selection={{
|
||||
selectHandler: confirmDelete,
|
||||
selectClasses:
|
||||
'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 text-white',
|
||||
selectText: localize('com_ui_delete'),
|
||||
}}
|
||||
/>
|
||||
</OGDialog>
|
||||
);
|
||||
};
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex h-full w-full items-center justify-center p-4">
|
||||
|
||||
@@ -21,7 +21,7 @@ interface SourceItemProps {
|
||||
expanded?: boolean;
|
||||
}
|
||||
|
||||
function SourceItem({ source, isNews: _isNews, expanded = false }: SourceItemProps) {
|
||||
function SourceItem({ source, isNews, expanded = false }: SourceItemProps) {
|
||||
const localize = useLocalize();
|
||||
const domain = getCleanDomain(source.link);
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useMemo } from 'react';
|
||||
import { useForm, Controller } from 'react-hook-form';
|
||||
import { Input, Label, Button, TooltipAnchor, CircleHelpIcon } from '@librechat/client';
|
||||
import { Input, Label, Button } from '@librechat/client';
|
||||
import { useMCPAuthValuesQuery } from '~/data-provider/Tools/queries';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
@@ -31,25 +31,16 @@ function AuthField({ name, config, hasValue, control, errors }: AuthFieldProps)
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<TooltipAnchor
|
||||
enableHTML={true}
|
||||
description={config.description || ''}
|
||||
render={
|
||||
<div className="flex items-center gap-2">
|
||||
<Label htmlFor={name} className="text-sm font-medium">
|
||||
{config.title}
|
||||
</Label>
|
||||
<CircleHelpIcon className="h-6 w-6 cursor-help text-text-secondary transition-colors hover:text-text-primary" />
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
<Label htmlFor={name} className="text-sm font-medium">
|
||||
{config.title}
|
||||
</Label>
|
||||
{hasValue ? (
|
||||
<div className="flex min-w-fit items-center gap-2 whitespace-nowrap rounded-full border border-border-light px-2 py-0.5 text-xs font-medium text-text-secondary">
|
||||
<div className="flex min-w-fit items-center gap-2 whitespace-nowrap rounded-full border border-border-medium px-2 py-0.5 text-xs font-medium text-text-secondary">
|
||||
<div className="h-1.5 w-1.5 rounded-full bg-green-500" />
|
||||
<span>{localize('com_ui_set')}</span>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex min-w-fit items-center gap-2 whitespace-nowrap rounded-full border border-border-light px-2 py-0.5 text-xs font-medium text-text-secondary">
|
||||
<div className="flex min-w-fit items-center gap-2 whitespace-nowrap rounded-full border border-border-medium px-2 py-0.5 text-xs font-medium text-text-secondary">
|
||||
<div className="h-1.5 w-1.5 rounded-full border border-border-medium" />
|
||||
<span>{localize('com_ui_unset')}</span>
|
||||
</div>
|
||||
@@ -69,10 +60,16 @@ function AuthField({ name, config, hasValue, control, errors }: AuthFieldProps)
|
||||
? localize('com_ui_mcp_update_var', { 0: config.title })
|
||||
: localize('com_ui_mcp_enter_var', { 0: config.title })
|
||||
}
|
||||
className="w-full shadow-sm sm:text-sm"
|
||||
className="w-full rounded-md border-gray-300 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 dark:border-gray-600 dark:bg-gray-700 dark:text-white sm:text-sm"
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
{config.description && (
|
||||
<p
|
||||
className="text-xs text-text-secondary [&_a]:text-blue-500 [&_a]:hover:text-blue-600 dark:[&_a]:text-blue-400 dark:[&_a]:hover:text-blue-300"
|
||||
dangerouslySetInnerHTML={{ __html: config.description }}
|
||||
/>
|
||||
)}
|
||||
{errors[name] && <p className="text-xs text-red-500">{errors[name]?.message}</p>}
|
||||
</div>
|
||||
);
|
||||
@@ -113,15 +110,17 @@ export default function CustomUserVarsSection({
|
||||
|
||||
const handleRevokeClick = () => {
|
||||
onRevoke();
|
||||
// Reset form after revoke
|
||||
reset();
|
||||
};
|
||||
|
||||
// Don't render if no fields to configure
|
||||
if (!fields || Object.keys(fields).length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex-1 space-y-4">
|
||||
<div className="space-y-4">
|
||||
<form onSubmit={handleSubmit(onFormSubmit)} className="space-y-4">
|
||||
{Object.entries(fields).map(([key, config]) => {
|
||||
const hasValue = authValuesData?.authValueFlags?.[key] || false;
|
||||
@@ -139,11 +138,21 @@ export default function CustomUserVarsSection({
|
||||
})}
|
||||
</form>
|
||||
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button onClick={handleRevokeClick} variant="destructive" disabled={isSubmitting}>
|
||||
<div className="flex justify-end gap-2 pt-2">
|
||||
<Button
|
||||
onClick={handleRevokeClick}
|
||||
className="bg-red-600 text-white hover:bg-red-700 dark:hover:bg-red-800"
|
||||
disabled={isSubmitting}
|
||||
size="sm"
|
||||
>
|
||||
{localize('com_ui_revoke')}
|
||||
</Button>
|
||||
<Button onClick={handleSubmit(onFormSubmit)} variant="submit" disabled={isSubmitting}>
|
||||
<Button
|
||||
onClick={handleSubmit(onFormSubmit)}
|
||||
className="bg-green-500 text-white hover:bg-green-600"
|
||||
disabled={isSubmitting}
|
||||
size="sm"
|
||||
>
|
||||
{isSubmitting ? localize('com_ui_saving') : localize('com_ui_save')}
|
||||
</Button>
|
||||
</div>
|
||||
@@ -1,13 +1,13 @@
|
||||
import React from 'react';
|
||||
import { KeyRound, PlugZap, AlertTriangle } from 'lucide-react';
|
||||
import { Loader2, KeyRound, PlugZap, AlertTriangle } from 'lucide-react';
|
||||
import { MCPServerStatus } from 'librechat-data-provider/dist/types/types/queries';
|
||||
import {
|
||||
Spinner,
|
||||
OGDialog,
|
||||
OGDialogTitle,
|
||||
OGDialogHeader,
|
||||
OGDialogContent,
|
||||
OGDialogHeader,
|
||||
OGDialogTitle,
|
||||
OGDialogDescription,
|
||||
} from '@librechat/client';
|
||||
import type { MCPServerStatus } from 'librechat-data-provider';
|
||||
import ServerInitializationSection from './ServerInitializationSection';
|
||||
import CustomUserVarsSection from './CustomUserVarsSection';
|
||||
import { useLocalize } from '~/hooks';
|
||||
@@ -45,6 +45,9 @@ export default function MCPConfigDialog({
|
||||
const dialogTitle = hasFields
|
||||
? localize('com_ui_configure_mcp_variables_for', { 0: serverName })
|
||||
: `${serverName} MCP Server`;
|
||||
const dialogDescription = hasFields
|
||||
? localize('com_ui_mcp_dialog_desc')
|
||||
: `Manage connection and settings for the ${serverName} MCP server.`;
|
||||
|
||||
// Helper function to render status badge based on connection state
|
||||
const renderStatusBadge = () => {
|
||||
@@ -57,7 +60,7 @@ export default function MCPConfigDialog({
|
||||
if (connectionState === 'connecting') {
|
||||
return (
|
||||
<div className="flex items-center gap-2 rounded-full bg-blue-50 px-2 py-0.5 text-xs font-medium text-blue-600 dark:bg-blue-950 dark:text-blue-400">
|
||||
<Spinner className="h-3 w-3" />
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
<span>{localize('com_ui_connecting')}</span>
|
||||
</div>
|
||||
);
|
||||
@@ -104,30 +107,31 @@ export default function MCPConfigDialog({
|
||||
|
||||
return (
|
||||
<OGDialog open={isOpen} onOpenChange={onOpenChange}>
|
||||
<OGDialogContent className="flex max-h-screen w-11/12 max-w-lg flex-col space-y-2">
|
||||
<OGDialogContent className="flex max-h-[90vh] w-full max-w-md flex-col">
|
||||
<OGDialogHeader>
|
||||
<div className="flex items-center gap-3">
|
||||
<OGDialogTitle className="text-xl">
|
||||
{dialogTitle.charAt(0).toUpperCase() + dialogTitle.slice(1)}
|
||||
</OGDialogTitle>
|
||||
<OGDialogTitle>{dialogTitle}</OGDialogTitle>
|
||||
{renderStatusBadge()}
|
||||
</div>
|
||||
<OGDialogDescription>{dialogDescription}</OGDialogDescription>
|
||||
</OGDialogHeader>
|
||||
|
||||
{/* Custom User Variables Section */}
|
||||
<CustomUserVarsSection
|
||||
serverName={serverName}
|
||||
fields={fieldsSchema}
|
||||
onSave={onSave}
|
||||
onRevoke={onRevoke || (() => {})}
|
||||
isSubmitting={isSubmitting}
|
||||
/>
|
||||
{/* Content */}
|
||||
<div className="flex-1 overflow-y-auto p-6">
|
||||
{/* Custom User Variables Section */}
|
||||
<CustomUserVarsSection
|
||||
serverName={serverName}
|
||||
fields={fieldsSchema}
|
||||
onSave={onSave}
|
||||
onRevoke={onRevoke || (() => {})}
|
||||
isSubmitting={isSubmitting}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Server Initialization Section */}
|
||||
<ServerInitializationSection
|
||||
serverName={serverName}
|
||||
requiresOAuth={serverStatus?.requiresOAuth || false}
|
||||
hasCustomUserVars={fieldsSchema && Object.keys(fieldsSchema).length > 0}
|
||||
/>
|
||||
</OGDialogContent>
|
||||
</OGDialog>
|
||||
@@ -1,6 +1,5 @@
|
||||
import React from 'react';
|
||||
import { Spinner } from '@librechat/client';
|
||||
import { SettingsIcon, AlertTriangle, KeyRound, PlugZap, X } from 'lucide-react';
|
||||
import { SettingsIcon, AlertTriangle, Loader2, KeyRound, PlugZap, X } from 'lucide-react';
|
||||
import type { MCPServerStatus, TPlugin } from 'librechat-data-provider';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
@@ -97,12 +96,12 @@ function InitializingStatusIcon({ serverName, onCancel, canCancel }: Initializin
|
||||
<button
|
||||
type="button"
|
||||
onClick={onCancel}
|
||||
className="group flex h-6 w-6 items-center justify-center rounded p-1 hover:bg-red-100 dark:hover:bg-red-900/20"
|
||||
className="flex h-6 w-6 items-center justify-center rounded p-1 hover:bg-red-100 dark:hover:bg-red-900/20"
|
||||
aria-label={localize('com_ui_cancel')}
|
||||
title={localize('com_ui_cancel')}
|
||||
>
|
||||
<div className="relative h-4 w-4">
|
||||
<Spinner className="h-4 w-4 group-hover:opacity-0" />
|
||||
<div className="group relative h-4 w-4">
|
||||
<Loader2 className="h-4 w-4 animate-spin text-blue-500 group-hover:opacity-0" />
|
||||
<X className="absolute inset-0 h-4 w-4 text-red-500 opacity-0 group-hover:opacity-100" />
|
||||
</div>
|
||||
</button>
|
||||
@@ -111,8 +110,8 @@ function InitializingStatusIcon({ serverName, onCancel, canCancel }: Initializin
|
||||
|
||||
return (
|
||||
<div className="flex h-6 w-6 items-center justify-center rounded p-1">
|
||||
<Spinner
|
||||
className="h-4 w-4"
|
||||
<Loader2
|
||||
className="h-4 w-4 animate-spin text-blue-500"
|
||||
aria-label={localize('com_nav_mcp_status_connecting', { 0: serverName })}
|
||||
/>
|
||||
</div>
|
||||
@@ -122,8 +121,8 @@ function InitializingStatusIcon({ serverName, onCancel, canCancel }: Initializin
|
||||
function ConnectingStatusIcon({ serverName }: StatusIconProps) {
|
||||
return (
|
||||
<div className="flex h-6 w-6 items-center justify-center rounded p-1">
|
||||
<Spinner
|
||||
className="h-4 w-4"
|
||||
<Loader2
|
||||
className="h-4 w-4 animate-spin text-blue-500"
|
||||
aria-label={localize('com_nav_mcp_status_connecting', { 0: serverName })}
|
||||
/>
|
||||
</div>
|
||||
client/src/components/ui/MCP/ServerInitializationSection.tsx (131 lines, new file)
@@ -0,0 +1,131 @@
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import { Button } from '@librechat/client';
|
||||
import { RefreshCw, Link } from 'lucide-react';
|
||||
import { useMCPServerInitialization } from '~/hooks/MCP/useMCPServerInitialization';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
interface ServerInitializationSectionProps {
|
||||
serverName: string;
|
||||
requiresOAuth: boolean;
|
||||
}
|
||||
|
||||
export default function ServerInitializationSection({
|
||||
serverName,
|
||||
requiresOAuth,
|
||||
}: ServerInitializationSectionProps) {
|
||||
const localize = useLocalize();
|
||||
|
||||
const [oauthUrl, setOauthUrl] = useState<string | null>(null);
|
||||
|
||||
// Use the shared initialization hook
|
||||
const { initializeServer, isLoading, connectionStatus, cancelOAuthFlow, isCancellable } =
|
||||
useMCPServerInitialization({
|
||||
onOAuthStarted: (name, url) => {
|
||||
// Store the OAuth URL locally for display
|
||||
setOauthUrl(url);
|
||||
},
|
||||
onSuccess: () => {
|
||||
// Clear OAuth URL on success
|
||||
setOauthUrl(null);
|
||||
},
|
||||
});
|
||||
|
||||
const serverStatus = connectionStatus[serverName];
|
||||
const isConnected = serverStatus?.connectionState === 'connected';
|
||||
const canCancel = isCancellable(serverName);
|
||||
|
||||
const handleInitializeClick = useCallback(() => {
|
||||
setOauthUrl(null);
|
||||
initializeServer(serverName);
|
||||
}, [initializeServer, serverName]);
|
||||
|
||||
const handleCancelClick = useCallback(() => {
|
||||
setOauthUrl(null);
|
||||
cancelOAuthFlow(serverName);
|
||||
}, [cancelOAuthFlow, serverName]);
|
||||
|
||||
// Show subtle reinitialize option if connected
|
||||
if (isConnected) {
|
||||
return (
|
||||
<div className="flex justify-start">
|
||||
<button
|
||||
onClick={handleInitializeClick}
|
||||
disabled={isLoading}
|
||||
className="flex items-center gap-1 text-xs text-gray-400 hover:text-gray-600 disabled:opacity-50 dark:text-gray-500 dark:hover:text-gray-400"
|
||||
>
|
||||
<RefreshCw className={`h-3 w-3 ${isLoading ? 'animate-spin' : ''}`} />
|
||||
{isLoading ? localize('com_ui_loading') : localize('com_ui_reinitialize')}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border border-[#991b1b] bg-[#2C1315] p-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm font-medium text-red-700 dark:text-red-300">
|
||||
{requiresOAuth
|
||||
? localize('com_ui_mcp_not_authenticated', { 0: serverName })
|
||||
: localize('com_ui_mcp_not_initialized', { 0: serverName })}
|
||||
</span>
|
||||
</div>
|
||||
{/* Only show authenticate button when OAuth URL is not present */}
|
||||
{!oauthUrl && (
|
||||
<Button
|
||||
onClick={handleInitializeClick}
|
||||
disabled={isLoading}
|
||||
className="flex items-center gap-2 bg-blue-600 px-4 py-2 text-white hover:bg-blue-700 dark:hover:bg-blue-800"
|
||||
>
|
||||
{isLoading ? (
|
||||
<>
|
||||
<RefreshCw className="h-4 w-4 animate-spin" />
|
||||
{localize('com_ui_loading')}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<RefreshCw className="h-4 w-4" />
|
||||
{requiresOAuth
|
||||
? localize('com_ui_authenticate')
|
||||
: localize('com_ui_mcp_initialize')}
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* OAuth URL display */}
|
||||
{oauthUrl && (
|
||||
<div className="mt-4 rounded-lg border border-blue-200 bg-blue-50 p-3 dark:border-blue-700 dark:bg-blue-900/20">
|
||||
<div className="mb-2 flex items-center gap-2">
|
||||
<div className="flex h-4 w-4 items-center justify-center rounded-full bg-blue-500">
|
||||
<Link className="h-2.5 w-2.5 text-white" />
|
||||
</div>
|
||||
<span className="text-sm font-medium text-blue-700 dark:text-blue-300">
|
||||
{localize('com_ui_auth_url')}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Button
|
||||
onClick={() => window.open(oauthUrl, '_blank', 'noopener,noreferrer')}
|
||||
className="flex-1 bg-blue-600 text-white hover:bg-blue-700 dark:hover:bg-blue-800"
|
||||
>
|
||||
{localize('com_ui_continue_oauth')}
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleCancelClick}
|
||||
disabled={!canCancel}
|
||||
className="bg-gray-200 text-gray-700 hover:bg-gray-300 disabled:cursor-not-allowed disabled:opacity-50 dark:bg-gray-700 dark:text-gray-200 dark:hover:bg-gray-600"
|
||||
title={!canCancel ? 'disabled' : undefined}
|
||||
>
|
||||
{localize('com_ui_cancel')}
|
||||
</Button>
|
||||
</div>
|
||||
<p className="mt-2 text-xs text-blue-600 dark:text-blue-400">
|
||||
{localize('com_ui_oauth_flow_desc')}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
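A minimal usage sketch of the new ServerInitializationSection component follows; it is illustrative only, not part of the changeset above, and the server name shown is a placeholder.
// Illustrative sketch only, not part of the diff; 'github' is a placeholder server name.
import ServerInitializationSection from '~/components/ui/MCP/ServerInitializationSection';

export function ExampleMCPServerSettings() {
  // requiresOAuth would normally come from the server's connection status.
  return <ServerInitializationSection serverName="github" requiresOAuth={true} />;
}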
|
||||
@@ -43,7 +43,11 @@ export const useCreateAgentMutation = (
|
||||
*/
|
||||
export const useUpdateAgentMutation = (
|
||||
options?: t.UpdateAgentMutationOptions,
|
||||
): UseMutationResult<t.Agent, Error, { agent_id: string; data: t.AgentUpdateParams }> => {
|
||||
): UseMutationResult<
|
||||
t.Agent,
|
||||
t.DuplicateVersionError,
|
||||
{ agent_id: string; data: t.AgentUpdateParams }
|
||||
> => {
|
||||
const queryClient = useQueryClient();
|
||||
return useMutation(
|
||||
({ agent_id, data }: { agent_id: string; data: t.AgentUpdateParams }) => {
|
||||
@@ -55,7 +59,8 @@ export const useUpdateAgentMutation = (
|
||||
{
|
||||
onMutate: (variables) => options?.onMutate?.(variables),
|
||||
onError: (error, variables, context) => {
|
||||
return options?.onError?.(error, variables, context);
|
||||
const typedError = error as t.DuplicateVersionError;
|
||||
return options?.onError?.(typedError, variables, context);
|
||||
},
|
||||
onSuccess: (updatedAgent, variables, context) => {
|
||||
const listRes = queryClient.getQueryData<t.AgentListResponse>([
|
||||
|
||||
@@ -55,10 +55,9 @@ export const useUpdateConversationMutation = (
|
||||
return useMutation(
|
||||
(payload: t.TUpdateConversationRequest) => dataService.updateConversation(payload),
|
||||
{
|
||||
onSuccess: (updatedConvo, payload) => {
|
||||
const targetId = payload.conversationId || id;
|
||||
queryClient.setQueryData([QueryKeys.conversation, targetId], updatedConvo);
|
||||
updateConvoInAllQueries(queryClient, targetId, () => updatedConvo);
|
||||
onSuccess: (updatedConvo) => {
|
||||
queryClient.setQueryData([QueryKeys.conversation, id], updatedConvo);
|
||||
updateConvoInAllQueries(queryClient, id, () => updatedConvo);
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import { useMemo, useState, useEffect, useRef } from 'react';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useRecoilState, useRecoilValue, useResetRecoilState } from 'recoil';
|
||||
import { logger } from '~/utils';
|
||||
import { useArtifactsContext } from '~/Providers';
|
||||
import { getLatestText, logger } from '~/utils';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { getKey } from '~/utils/artifacts';
|
||||
import store from '~/store';
|
||||
|
||||
export default function useArtifacts() {
|
||||
const [activeTab, setActiveTab] = useState('preview');
|
||||
const { isSubmitting, latestMessageId, latestMessageText, conversationId } =
|
||||
useArtifactsContext();
|
||||
const { isSubmitting, latestMessage, conversation } = useChatContext();
|
||||
|
||||
const artifacts = useRecoilValue(store.artifactsState);
|
||||
const resetArtifacts = useResetRecoilState(store.artifactsState);
|
||||
@@ -32,23 +31,26 @@ export default function useArtifacts() {
|
||||
const resetState = () => {
|
||||
resetArtifacts();
|
||||
resetCurrentArtifactId();
|
||||
prevConversationIdRef.current = conversationId;
|
||||
prevConversationIdRef.current = conversation?.conversationId ?? null;
|
||||
lastRunMessageIdRef.current = null;
|
||||
lastContentRef.current = null;
|
||||
hasEnclosedArtifactRef.current = false;
|
||||
};
|
||||
if (conversationId !== prevConversationIdRef.current && prevConversationIdRef.current != null) {
|
||||
if (
|
||||
conversation?.conversationId !== prevConversationIdRef.current &&
|
||||
prevConversationIdRef.current != null
|
||||
) {
|
||||
resetState();
|
||||
} else if (conversationId === Constants.NEW_CONVO) {
|
||||
} else if (conversation?.conversationId === Constants.NEW_CONVO) {
|
||||
resetState();
|
||||
}
|
||||
prevConversationIdRef.current = conversationId;
|
||||
prevConversationIdRef.current = conversation?.conversationId ?? null;
|
||||
/** Resets artifacts when unmounting */
|
||||
return () => {
|
||||
logger.log('artifacts_visibility', 'Unmounting artifacts');
|
||||
resetState();
|
||||
};
|
||||
}, [conversationId, resetArtifacts, resetCurrentArtifactId]);
|
||||
}, [conversation?.conversationId, resetArtifacts, resetCurrentArtifactId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (orderedArtifactIds.length > 0) {
|
||||
@@ -64,7 +66,7 @@ export default function useArtifacts() {
|
||||
if (orderedArtifactIds.length === 0) {
|
||||
return;
|
||||
}
|
||||
if (latestMessageId == null) {
|
||||
if (latestMessage == null) {
|
||||
return;
|
||||
}
|
||||
const latestArtifactId = orderedArtifactIds[orderedArtifactIds.length - 1];
|
||||
@@ -76,6 +78,7 @@ export default function useArtifacts() {
|
||||
setCurrentArtifactId(latestArtifactId);
|
||||
lastContentRef.current = latestArtifact?.content ?? null;
|
||||
|
||||
const latestMessageText = getLatestText(latestMessage);
|
||||
const hasEnclosedArtifact =
|
||||
/:::artifact(?:\{[^}]*\})?(?:\s|\n)*(?:```[\s\S]*?```(?:\s|\n)*)?:::/m.test(
|
||||
latestMessageText.trim(),
|
||||
@@ -92,22 +95,15 @@ export default function useArtifacts() {
|
||||
hasAutoSwitchedToCodeRef.current = true;
|
||||
}
|
||||
}
|
||||
}, [
|
||||
artifacts,
|
||||
isSubmitting,
|
||||
latestMessageId,
|
||||
latestMessageText,
|
||||
orderedArtifactIds,
|
||||
setCurrentArtifactId,
|
||||
]);
|
||||
}, [setCurrentArtifactId, isSubmitting, orderedArtifactIds, artifacts, latestMessage]);
|
||||
|
||||
useEffect(() => {
|
||||
if (latestMessageId !== lastRunMessageIdRef.current) {
|
||||
lastRunMessageIdRef.current = latestMessageId;
|
||||
if (latestMessage?.messageId !== lastRunMessageIdRef.current) {
|
||||
lastRunMessageIdRef.current = latestMessage?.messageId ?? null;
|
||||
hasEnclosedArtifactRef.current = false;
|
||||
hasAutoSwitchedToCodeRef.current = false;
|
||||
}
|
||||
}, [latestMessageId]);
|
||||
}, [latestMessage]);
|
||||
|
||||
const currentArtifact = currentArtifactId != null ? artifacts?.[currentArtifactId] : null;
|
||||
|
||||
@@ -135,6 +131,7 @@ export default function useArtifacts() {
|
||||
isMermaid,
|
||||
setActiveTab,
|
||||
currentIndex,
|
||||
isSubmitting,
|
||||
cycleArtifact,
|
||||
currentArtifact,
|
||||
orderedArtifactIds,
|
||||
|
||||
@@ -107,7 +107,7 @@ const useSpeechToTextExternal = (
|
||||
});
|
||||
setPermission(true);
|
||||
audioStream.current = streamData ?? null;
|
||||
} catch {
|
||||
} catch (err) {
|
||||
setPermission(false);
|
||||
}
|
||||
};
|
||||
@@ -268,7 +268,6 @@ const useSpeechToTextExternal = (
|
||||
return () => {
|
||||
window.removeEventListener('keydown', handleKeyDown);
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [isListening]);
|
||||
|
||||
return {
|
||||
|
||||
@@ -1 +1 @@
|
||||
export { useMCPServerManager } from './useMCPServerManager';
|
||||
export { useMCPServerInitialization } from './useMCPServerInitialization';
|
||||
|
||||
client/src/hooks/MCP/useMCPServerInitialization.ts (317 lines, new file)
@@ -0,0 +1,317 @@
|
||||
import { useCallback, useState, useEffect, useMemo } from 'react';
|
||||
import { useToastContext } from '@librechat/client';
|
||||
import { QueryKeys } from 'librechat-data-provider';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import {
|
||||
useReinitializeMCPServerMutation,
|
||||
useCancelMCPOAuthMutation,
|
||||
} from 'librechat-data-provider/react-query';
|
||||
import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { logger } from '~/utils';
|
||||
|
||||
interface UseMCPServerInitializationOptions {
|
||||
onSuccess?: (serverName: string) => void;
|
||||
onOAuthStarted?: (serverName: string, oauthUrl: string) => void;
|
||||
onError?: (serverName: string, error: any) => void;
|
||||
}
|
||||
|
||||
export function useMCPServerInitialization(options?: UseMCPServerInitializationOptions) {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
// OAuth state management
|
||||
const [oauthPollingServers, setOauthPollingServers] = useState<Map<string, string>>(new Map());
|
||||
const [oauthStartTimes, setOauthStartTimes] = useState<Map<string, number>>(new Map());
|
||||
const [initializingServers, setInitializingServers] = useState<Set<string>>(new Set());
|
||||
const [cancellableServers, setCancellableServers] = useState<Set<string>>(new Set());
|
||||
|
||||
// Get connection status
|
||||
const { data: connectionStatusData } = useMCPConnectionStatusQuery();
|
||||
const connectionStatus = useMemo(
|
||||
() => connectionStatusData?.connectionStatus || {},
|
||||
[connectionStatusData],
|
||||
);
|
||||
|
||||
// Main initialization mutation
|
||||
const reinitializeMutation = useReinitializeMCPServerMutation();
|
||||
|
||||
// Track which server is currently being processed
|
||||
const [currentProcessingServer, setCurrentProcessingServer] = useState<string | null>(null);
|
||||
|
||||
// Cancel OAuth mutation
|
||||
const cancelOAuthMutation = useCancelMCPOAuthMutation();
|
||||
|
||||
// Helper function to clean up OAuth state
|
||||
const cleanupOAuthState = useCallback((serverName: string) => {
|
||||
setOauthPollingServers((prev) => {
|
||||
const newMap = new Map(prev);
|
||||
newMap.delete(serverName);
|
||||
return newMap;
|
||||
});
|
||||
|
||||
setOauthStartTimes((prev) => {
|
||||
const newMap = new Map(prev);
|
||||
newMap.delete(serverName);
|
||||
return newMap;
|
||||
});
|
||||
|
||||
setInitializingServers((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(serverName);
|
||||
return newSet;
|
||||
});
|
||||
|
||||
setCancellableServers((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(serverName);
|
||||
return newSet;
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Cancel OAuth flow
|
||||
const cancelOAuthFlow = useCallback(
|
||||
(serverName: string) => {
|
||||
logger.info(`[MCP OAuth] User cancelling OAuth flow for ${serverName}`);
|
||||
|
||||
cancelOAuthMutation.mutate(serverName, {
|
||||
onSuccess: () => {
|
||||
cleanupOAuthState(serverName);
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_oauth_cancelled', { 0: serverName }),
|
||||
status: 'info',
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
logger.error(`[MCP OAuth] Failed to cancel OAuth flow for ${serverName}:`, error);
|
||||
// Clean up state anyway
|
||||
cleanupOAuthState(serverName);
|
||||
},
|
||||
});
|
||||
},
|
||||
[cancelOAuthMutation, cleanupOAuthState, showToast, localize],
|
||||
);
|
||||
|
||||
// Helper function to handle successful connection
|
||||
const handleSuccessfulConnection = useCallback(
|
||||
async (serverName: string, message: string) => {
|
||||
showToast({ message, status: 'success' });
|
||||
|
||||
// Force immediate refetch to update UI
|
||||
await Promise.all([
|
||||
queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]),
|
||||
queryClient.refetchQueries([QueryKeys.tools]),
|
||||
]);
|
||||
|
||||
// Clean up OAuth state
|
||||
cleanupOAuthState(serverName);
|
||||
|
||||
// Call optional success callback
|
||||
options?.onSuccess?.(serverName);
|
||||
},
|
||||
[showToast, queryClient, options, cleanupOAuthState],
|
||||
);
|
||||
|
||||
// Helper function to handle OAuth timeout/failure
|
||||
const handleOAuthFailure = useCallback(
|
||||
(serverName: string, isTimeout: boolean) => {
|
||||
logger.warn(
|
||||
`[MCP OAuth] OAuth ${isTimeout ? 'timed out' : 'failed'} for ${serverName}, stopping poll`,
|
||||
);
|
||||
|
||||
// Clean up OAuth state
|
||||
cleanupOAuthState(serverName);
|
||||
|
||||
// Show error toast
|
||||
showToast({
|
||||
message: isTimeout
|
||||
? localize('com_ui_mcp_oauth_timeout', { 0: serverName })
|
||||
: localize('com_ui_mcp_init_failed'),
|
||||
status: 'error',
|
||||
});
|
||||
},
|
||||
[showToast, localize, cleanupOAuthState],
|
||||
);
|
||||
|
||||
// Poll for OAuth completion
|
||||
useEffect(() => {
|
||||
if (oauthPollingServers.size === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const pollInterval = setInterval(() => {
|
||||
// Check each polling server
|
||||
oauthPollingServers.forEach((oauthUrl, serverName) => {
|
||||
const serverStatus = connectionStatus[serverName];
|
||||
|
||||
// Check for client-side timeout (3 minutes)
|
||||
const startTime = oauthStartTimes.get(serverName);
|
||||
const hasTimedOut = startTime && Date.now() - startTime > 180000; // 3 minutes
|
||||
|
||||
if (serverStatus?.connectionState === 'connected') {
|
||||
// OAuth completed successfully
|
||||
handleSuccessfulConnection(
|
||||
serverName,
|
||||
localize('com_ui_mcp_authenticated_success', { 0: serverName }),
|
||||
);
|
||||
} else if (serverStatus?.connectionState === 'error' || hasTimedOut) {
|
||||
// OAuth failed or timed out
|
||||
handleOAuthFailure(serverName, !!hasTimedOut);
|
||||
}
|
||||
|
||||
setCancellableServers((prev) => new Set(prev).add(serverName));
|
||||
});
|
||||
|
||||
queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]);
|
||||
}, 3500);
|
||||
|
||||
return () => {
|
||||
clearInterval(pollInterval);
|
||||
};
|
||||
}, [
|
||||
oauthPollingServers,
|
||||
oauthStartTimes,
|
||||
connectionStatus,
|
||||
queryClient,
|
||||
handleSuccessfulConnection,
|
||||
handleOAuthFailure,
|
||||
localize,
|
||||
]);
|
||||
|
||||
// Initialize server function
|
||||
const initializeServer = useCallback(
|
||||
(serverName: string) => {
|
||||
// Prevent spam - check if already initializing
|
||||
if (initializingServers.has(serverName)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (connectionStatus[serverName]?.requiresOAuth) {
|
||||
setCancellableServers((prev) => new Set(prev).add(serverName));
|
||||
}
|
||||
|
||||
// Add to initializing set
|
||||
setInitializingServers((prev) => new Set(prev).add(serverName));
|
||||
|
||||
// If there's already a server being processed, that one will be cancelled
|
||||
if (currentProcessingServer && currentProcessingServer !== serverName) {
|
||||
// Clean up the cancelled server's state immediately
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_init_cancelled', { 0: currentProcessingServer }),
|
||||
status: 'warning',
|
||||
});
|
||||
|
||||
cleanupOAuthState(currentProcessingServer);
|
||||
}
|
||||
|
||||
// Track the current server being processed
|
||||
setCurrentProcessingServer(serverName);
|
||||
|
||||
reinitializeMutation.mutate(serverName, {
|
||||
onSuccess: (response: any) => {
|
||||
// Clear current processing server
|
||||
setCurrentProcessingServer(null);
|
||||
|
||||
if (response.success) {
|
||||
if (response.oauthRequired && response.oauthUrl) {
|
||||
// OAuth required - store URL and start polling
|
||||
setOauthPollingServers((prev) => new Map(prev).set(serverName, response.oauthUrl));
|
||||
|
||||
// Track when OAuth started for timeout detection
|
||||
setOauthStartTimes((prev) => new Map(prev).set(serverName, Date.now()));
|
||||
|
||||
// Call optional OAuth callback or open URL directly
|
||||
if (options?.onOAuthStarted) {
|
||||
options.onOAuthStarted(serverName, response.oauthUrl);
|
||||
} else {
|
||||
window.open(response.oauthUrl, '_blank', 'noopener,noreferrer');
|
||||
}
|
||||
|
||||
showToast({
|
||||
message: localize('com_ui_connecting'),
|
||||
status: 'info',
|
||||
});
|
||||
} else if (response.oauthRequired) {
|
||||
// OAuth required but no URL - shouldn't happen
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_oauth_no_url'),
|
||||
status: 'warning',
|
||||
});
|
||||
// Remove from initializing since it failed
|
||||
setInitializingServers((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(serverName);
|
||||
return newSet;
|
||||
});
|
||||
} else {
|
||||
// Successful connection without OAuth
|
||||
handleSuccessfulConnection(
|
||||
serverName,
|
||||
response.message || localize('com_ui_mcp_initialized_success', { 0: serverName }),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Remove from initializing if not successful
|
||||
setInitializingServers((prev) => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(serverName);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
},
|
||||
onError: (error: any) => {
|
||||
console.error(`Error initializing MCP server ${serverName}:`, error);
|
||||
setCurrentProcessingServer(null);
|
||||
|
||||
const isCancelled =
|
||||
error?.name === 'CanceledError' ||
|
||||
error?.code === 'ERR_CANCELED' ||
|
||||
error?.message?.includes('cancel') ||
|
||||
error?.message?.includes('abort');
|
||||
|
||||
if (isCancelled) {
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_init_cancelled', { 0: serverName }),
|
||||
status: 'warning',
|
||||
});
|
||||
} else {
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_init_failed'),
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
|
||||
// Clean up OAuth state using helper function
|
||||
cleanupOAuthState(serverName);
|
||||
|
||||
// Call optional error callback
|
||||
options?.onError?.(serverName, error);
|
||||
},
|
||||
});
|
||||
},
|
||||
[
|
||||
initializingServers,
|
||||
connectionStatus,
|
||||
currentProcessingServer,
|
||||
reinitializeMutation,
|
||||
showToast,
|
||||
localize,
|
||||
cleanupOAuthState,
|
||||
options,
|
||||
handleSuccessfulConnection,
|
||||
],
|
||||
);
|
||||
|
||||
return {
|
||||
initializeServer,
|
||||
isInitializing: (serverName: string) => initializingServers.has(serverName),
|
||||
isCancellable: (serverName: string) => cancellableServers.has(serverName),
|
||||
initializingServers,
|
||||
oauthPollingServers,
|
||||
oauthStartTimes,
|
||||
connectionStatus,
|
||||
isLoading: reinitializeMutation.isLoading,
|
||||
cancelOAuthFlow,
|
||||
};
|
||||
}
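A minimal consumption sketch for the new useMCPServerInitialization hook, illustrative only and based on the option and return shapes shown above; the wrapper name is hypothetical.
// Illustrative sketch only, not part of the diff.
import { useMCPServerInitialization } from '~/hooks/MCP/useMCPServerInitialization';

export function useConnectControls(serverName: string) {
  const { initializeServer, isInitializing, isLoading, cancelOAuthFlow } =
    useMCPServerInitialization({
      // onSuccess receives the server name once the connection is confirmed.
      onSuccess: (name) => console.log(`${name} connected`),
    });

  return {
    connect: () => initializeServer(serverName),
    cancel: () => cancelOAuthFlow(serverName),
    busy: isLoading || isInitializing(serverName),
  };
}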
|
||||
@@ -1,52 +1,34 @@
|
||||
import { useCallback, useState, useMemo, useRef, useEffect } from 'react';
|
||||
import { useToastContext } from '@librechat/client';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { Constants, QueryKeys } from 'librechat-data-provider';
|
||||
import {
|
||||
useCancelMCPOAuthMutation,
|
||||
useUpdateUserPluginsMutation,
|
||||
useReinitializeMCPServerMutation,
|
||||
} from 'librechat-data-provider/react-query';
|
||||
import { useCallback, useState, useMemo, useRef } from 'react';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import { useMCPServerInitialization } from '~/hooks/MCP/useMCPServerInitialization';
|
||||
import type { ConfigFieldDetail } from '~/components/ui/MCP/MCPConfigDialog';
|
||||
import type { TUpdateUserPlugins, TPlugin } from 'librechat-data-provider';
|
||||
import type { ConfigFieldDetail } from '~/components/MCP/MCPConfigDialog';
|
||||
import { useMCPConnectionStatusQuery } from '~/data-provider/Tools/queries';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
interface ServerState {
|
||||
isInitializing: boolean;
|
||||
oauthUrl: string | null;
|
||||
oauthStartTime: number | null;
|
||||
isCancellable: boolean;
|
||||
pollInterval: NodeJS.Timeout | null;
|
||||
}
|
||||
|
||||
export function useMCPServerManager() {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const { mcpSelect, startupConfig } = useBadgeRowContext();
|
||||
const { mcpValues, setMCPValues, mcpToolDetails, isPinned, setIsPinned } = mcpSelect;
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const [isConfigModalOpen, setIsConfigModalOpen] = useState(false);
|
||||
const [selectedToolForConfig, setSelectedToolForConfig] = useState<TPlugin | null>(null);
|
||||
const previousFocusRef = useRef<HTMLElement | null>(null);
|
||||
const mcpValuesRef = useRef(mcpValues);
|
||||
|
||||
// fixes the issue where OAuth flows would deselect all the servers except the one that is being authenticated on success
|
||||
useEffect(() => {
|
||||
mcpValuesRef.current = mcpValues;
|
||||
}, [mcpValues]);
|
||||
|
||||
const configuredServers = useMemo(() => {
|
||||
if (!startupConfig?.mcpServers) return [];
|
||||
if (!startupConfig?.mcpServers) {
|
||||
return [];
|
||||
}
|
||||
return Object.entries(startupConfig.mcpServers)
|
||||
.filter(([, config]) => config.chatMenu !== false)
|
||||
.map(([serverName]) => serverName);
|
||||
}, [startupConfig?.mcpServers]);
|
||||
|
||||
const reinitializeMutation = useReinitializeMCPServerMutation();
|
||||
const cancelOAuthMutation = useCancelMCPOAuthMutation();
|
||||
const [isConfigModalOpen, setIsConfigModalOpen] = useState(false);
|
||||
const [selectedToolForConfig, setSelectedToolForConfig] = useState<TPlugin | null>(null);
|
||||
const previousFocusRef = useRef<HTMLElement | null>(null);
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const updateUserPluginsMutation = useUpdateUserPluginsMutation({
|
||||
onSuccess: async () => {
|
||||
@@ -67,314 +49,52 @@ export function useMCPServerManager() {
|
||||
},
|
||||
});
|
||||
|
||||
const [serverStates, setServerStates] = useState<Record<string, ServerState>>(() => {
|
||||
const initialStates: Record<string, ServerState> = {};
|
||||
configuredServers.forEach((serverName) => {
|
||||
initialStates[serverName] = {
|
||||
isInitializing: false,
|
||||
oauthUrl: null,
|
||||
oauthStartTime: null,
|
||||
isCancellable: false,
|
||||
pollInterval: null,
|
||||
};
|
||||
});
|
||||
return initialStates;
|
||||
});
|
||||
|
||||
const { data: connectionStatusData } = useMCPConnectionStatusQuery({
|
||||
enabled: !!startupConfig?.mcpServers && Object.keys(startupConfig.mcpServers).length > 0,
|
||||
});
|
||||
const connectionStatus = useMemo(
|
||||
() => connectionStatusData?.connectionStatus || {},
|
||||
[connectionStatusData?.connectionStatus],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!mcpValues?.length) return;
|
||||
|
||||
const connectedSelected = mcpValues.filter(
|
||||
(serverName) => connectionStatus[serverName]?.connectionState === 'connected',
|
||||
);
|
||||
|
||||
if (connectedSelected.length !== mcpValues.length) {
|
||||
setMCPValues(connectedSelected);
|
||||
}
|
||||
}, [connectionStatus, mcpValues, setMCPValues]);
|
||||
|
||||
const updateServerState = useCallback((serverName: string, updates: Partial<ServerState>) => {
|
||||
setServerStates((prev) => {
|
||||
const newStates = { ...prev };
|
||||
const currentState = newStates[serverName] || {
|
||||
isInitializing: false,
|
||||
oauthUrl: null,
|
||||
oauthStartTime: null,
|
||||
isCancellable: false,
|
||||
pollInterval: null,
|
||||
};
|
||||
newStates[serverName] = { ...currentState, ...updates };
|
||||
return newStates;
|
||||
});
|
||||
}, []);
|
||||
|
||||
const cleanupServerState = useCallback(
|
||||
(serverName: string) => {
|
||||
const state = serverStates[serverName];
|
||||
if (state?.pollInterval) {
|
||||
clearInterval(state.pollInterval);
|
||||
}
|
||||
updateServerState(serverName, {
|
||||
isInitializing: false,
|
||||
oauthUrl: null,
|
||||
oauthStartTime: null,
|
||||
isCancellable: false,
|
||||
pollInterval: null,
|
||||
});
|
||||
},
|
||||
[serverStates, updateServerState],
|
||||
);
|
||||
|
||||
const startServerPolling = useCallback(
|
||||
(serverName: string) => {
|
||||
const pollInterval = setInterval(async () => {
|
||||
try {
|
||||
await queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]);
|
||||
|
||||
const freshConnectionData = queryClient.getQueryData([
|
||||
QueryKeys.mcpConnectionStatus,
|
||||
]) as any;
|
||||
const freshConnectionStatus = freshConnectionData?.connectionStatus || {};
|
||||
|
||||
const state = serverStates[serverName];
|
||||
const serverStatus = freshConnectionStatus[serverName];
|
||||
|
||||
if (serverStatus?.connectionState === 'connected') {
|
||||
clearInterval(pollInterval);
|
||||
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_authenticated_success', { 0: serverName }),
|
||||
status: 'success',
|
||||
});
|
||||
|
||||
const currentValues = mcpValuesRef.current ?? [];
|
||||
if (!currentValues.includes(serverName)) {
|
||||
setMCPValues([...currentValues, serverName]);
|
||||
}
|
||||
|
||||
await queryClient.invalidateQueries([QueryKeys.tools]);
|
||||
|
||||
// This delay is to ensure UI has updated with new connection status before cleanup
|
||||
// Otherwise servers will show as disconnected for a second after OAuth flow completes
|
||||
setTimeout(() => {
|
||||
cleanupServerState(serverName);
|
||||
}, 1000);
|
||||
return;
|
||||
}
|
||||
|
||||
if (state?.oauthStartTime && Date.now() - state.oauthStartTime > 180000) {
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_oauth_timeout', { 0: serverName }),
|
||||
status: 'error',
|
||||
});
|
||||
clearInterval(pollInterval);
|
||||
cleanupServerState(serverName);
|
||||
return;
|
||||
}
|
||||
|
||||
if (serverStatus?.connectionState === 'error') {
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_init_failed'),
|
||||
status: 'error',
|
||||
});
|
||||
clearInterval(pollInterval);
|
||||
cleanupServerState(serverName);
|
||||
return;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`[MCP Manager] Error polling server ${serverName}:`, error);
|
||||
clearInterval(pollInterval);
|
||||
cleanupServerState(serverName);
|
||||
return;
|
||||
}
|
||||
}, 3500);
|
||||
|
||||
updateServerState(serverName, { pollInterval });
|
||||
},
|
||||
[
|
||||
queryClient,
|
||||
serverStates,
|
||||
showToast,
|
||||
localize,
|
||||
setMCPValues,
|
||||
cleanupServerState,
|
||||
updateServerState,
|
||||
],
|
||||
);
|
||||
|
||||
const initializeServer = useCallback(
|
||||
async (serverName: string, autoOpenOAuth: boolean = true) => {
|
||||
updateServerState(serverName, { isInitializing: true });
|
||||
|
||||
try {
|
||||
const response = await reinitializeMutation.mutateAsync(serverName);
|
||||
|
||||
if (response.success) {
|
||||
if (response.oauthRequired && response.oauthUrl) {
|
||||
updateServerState(serverName, {
|
||||
oauthUrl: response.oauthUrl,
|
||||
oauthStartTime: Date.now(),
|
||||
isCancellable: true,
|
||||
isInitializing: true,
|
||||
});
|
||||
|
||||
if (autoOpenOAuth) {
|
||||
window.open(response.oauthUrl, '_blank', 'noopener,noreferrer');
|
||||
}
|
||||
|
||||
startServerPolling(serverName);
|
||||
} else {
|
||||
await queryClient.refetchQueries([QueryKeys.mcpConnectionStatus]);
|
||||
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_initialized_success', { 0: serverName }),
|
||||
status: 'success',
|
||||
});
|
||||
|
||||
const currentValues = mcpValues ?? [];
|
||||
if (!currentValues.includes(serverName)) {
|
||||
setMCPValues([...currentValues, serverName]);
|
||||
}
|
||||
|
||||
cleanupServerState(serverName);
|
||||
}
|
||||
} else {
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_init_failed', { 0: serverName }),
|
||||
status: 'error',
|
||||
});
|
||||
cleanupServerState(serverName);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`[MCP Manager] Failed to initialize ${serverName}:`, error);
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_init_failed', { 0: serverName }),
|
||||
status: 'error',
|
||||
});
|
||||
cleanupServerState(serverName);
|
||||
}
|
||||
},
|
||||
[
|
||||
updateServerState,
|
||||
reinitializeMutation,
|
||||
startServerPolling,
|
||||
queryClient,
|
||||
showToast,
|
||||
localize,
|
||||
mcpValues,
|
||||
cleanupServerState,
|
||||
setMCPValues,
|
||||
],
|
||||
);
|
||||
|
||||
const cancelOAuthFlow = useCallback(
|
||||
(serverName: string) => {
|
||||
cancelOAuthMutation.mutate(serverName, {
|
||||
onSuccess: () => {
|
||||
cleanupServerState(serverName);
|
||||
queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]);
|
||||
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_oauth_cancelled', { 0: serverName }),
|
||||
status: 'warning',
|
||||
});
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error(`[MCP Manager] Failed to cancel OAuth for ${serverName}:`, error);
|
||||
showToast({
|
||||
message: localize('com_ui_mcp_init_failed', { 0: serverName }),
|
||||
            status: 'error',
          });
        },
      });
    },
    [queryClient, cleanupServerState, showToast, localize, cancelOAuthMutation],
  );

  const isInitializing = useCallback(
    (serverName: string) => {
      return serverStates[serverName]?.isInitializing || false;
    },
    [serverStates],
  );

  const isCancellable = useCallback(
    (serverName: string) => {
      return serverStates[serverName]?.isCancellable || false;
    },
    [serverStates],
  );

  const getOAuthUrl = useCallback(
    (serverName: string) => {
      return serverStates[serverName]?.oauthUrl || null;
    },
    [serverStates],
  );

  const placeholderText = useMemo(
    () => startupConfig?.interface?.mcpServers?.placeholder || localize('com_ui_mcp_servers'),
    [startupConfig?.interface?.mcpServers?.placeholder, localize],
  );

  const batchToggleServers = useCallback(
    (serverNames: string[]) => {
      const connectedServers: string[] = [];
      const disconnectedServers: string[] = [];

      serverNames.forEach((serverName) => {
        if (isInitializing(serverName)) {
          return;
        }

        const serverStatus = connectionStatus[serverName];
        if (serverStatus?.connectionState === 'connected') {
          connectedServers.push(serverName);
        } else {
          disconnectedServers.push(serverName);
        }
      });

      setMCPValues(connectedServers);

      disconnectedServers.forEach((serverName) => {
        initializeServer(serverName);
      });
    },
    [connectionStatus, setMCPValues, initializeServer, isInitializing],
  );

  const toggleServerSelection = useCallback(
    (serverName: string) => {
      if (isInitializing(serverName)) {
        return;
      }

      const currentValues = mcpValues ?? [];
      const isCurrentlySelected = currentValues.includes(serverName);

      if (isCurrentlySelected) {
        const filteredValues = currentValues.filter((name) => name !== serverName);
        setMCPValues(filteredValues);
      } else {
        const serverStatus = connectionStatus[serverName];
        if (serverStatus?.connectionState === 'connected') {
  const { initializeServer, isInitializing, connectionStatus, cancelOAuthFlow, isCancellable } =
    useMCPServerInitialization({
      onSuccess: (serverName) => {
        const currentValues = mcpValues ?? [];
        if (!currentValues.includes(serverName)) {
          setMCPValues([...currentValues, serverName]);
        } else {
          initializeServer(serverName);
        }
      }
      },
    [mcpValues, setMCPValues, connectionStatus, initializeServer, isInitializing],
  );
    },
      onError: (serverName) => {
        const tool = mcpToolDetails?.find((t) => t.name === serverName);
        const serverConfig = startupConfig?.mcpServers?.[serverName];
        const serverStatus = connectionStatus[serverName];

        const hasAuthConfig =
          (tool?.authConfig && tool.authConfig.length > 0) ||
          (serverConfig?.customUserVars && Object.keys(serverConfig.customUserVars).length > 0);

        const wouldShowButton =
          !serverStatus ||
          serverStatus.connectionState === 'disconnected' ||
          serverStatus.connectionState === 'error' ||
          (serverStatus.connectionState === 'connected' && hasAuthConfig);

        if (!wouldShowButton) {
          return;
        }

        const configTool = tool || {
          name: serverName,
          pluginKey: `${Constants.mcp_prefix}${serverName}`,
          authConfig: serverConfig?.customUserVars
            ? Object.entries(serverConfig.customUserVars).map(([key, config]) => ({
                authField: key,
                label: config.title,
                description: config.description,
              }))
            : [],
          authenticated: false,
        };

        previousFocusRef.current = document.activeElement as HTMLElement;

        setSelectedToolForConfig(configTool);
        setIsConfigModalOpen(true);
      },
    });

  const handleConfigSave = useCallback(
    (targetName: string, authData: Record<string, string>) => {
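For reference, a minimal standalone restatement of the batch-toggle behavior in the hunk above: servers that are already connected stay selected, servers with an initialization in flight are skipped, and everything else is re-initialized. All names and types below are illustrative and not taken from this codebase.

// Hypothetical sketch only; function and type names are illustrative.
type MCPConnectionState = 'connected' | 'connecting' | 'disconnected' | 'error';

interface BatchToggleDeps {
  connectionStatus: Record<string, { connectionState: MCPConnectionState } | undefined>;
  isInitializing: (serverName: string) => boolean;
  setSelected: (serverNames: string[]) => void;
  initializeServer: (serverName: string) => void;
}

function batchToggle(serverNames: string[], deps: BatchToggleDeps): void {
  const connected: string[] = [];
  const pendingInit: string[] = [];

  for (const name of serverNames) {
    if (deps.isInitializing(name)) {
      continue; // a toggle is a no-op while initialization is already in flight
    }
    if (deps.connectionStatus[name]?.connectionState === 'connected') {
      connected.push(name);
    } else {
      pendingInit.push(name);
    }
  }

  // Only already-connected servers become selected right away; the rest are
  // initialized, and a success callback would add them to the selection later.
  deps.setSelected(connected);
  pendingInit.forEach(deps.initializeServer);
}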
@@ -436,6 +156,48 @@ export function useMCPServerManager() {
      }
  }, []);

  const toggleServerSelection = useCallback(
    (serverName: string) => {
      const currentValues = mcpValues ?? [];
      const serverStatus = connectionStatus[serverName];

      if (currentValues.includes(serverName)) {
        const filteredValues = currentValues.filter((name) => name !== serverName);
        setMCPValues(filteredValues);
      } else {
        if (serverStatus?.connectionState === 'connected') {
          setMCPValues([...currentValues, serverName]);
        } else {
          initializeServer(serverName);
        }
      }
    },
    [connectionStatus, mcpValues, setMCPValues, initializeServer],
  );

  const batchToggleServers = useCallback(
    (serverNames: string[]) => {
      const connectedServers: string[] = [];
      const disconnectedServers: string[] = [];

      serverNames.forEach((serverName) => {
        const serverStatus = connectionStatus[serverName];
        if (serverStatus?.connectionState === 'connected') {
          connectedServers.push(serverName);
        } else {
          disconnectedServers.push(serverName);
        }
      });

      setMCPValues(connectedServers);

      disconnectedServers.forEach((serverName) => {
        initializeServer(serverName);
      });
    },
    [connectionStatus, setMCPValues, initializeServer],
  );

  const getServerStatusIconProps = useCallback(
    (serverName: string) => {
      const tool = mcpToolDetails?.find((t) => t.name === serverName);
@@ -494,6 +256,11 @@ export function useMCPServerManager() {
    ],
  );

  const placeholderText = useMemo(
    () => startupConfig?.interface?.mcpServers?.placeholder || localize('com_ui_mcp_servers'),
    [startupConfig?.interface?.mcpServers?.placeholder, localize],
  );

  const getConfigDialogProps = useCallback(() => {
    if (!selectedToolForConfig) return null;
@@ -536,31 +303,27 @@ export function useMCPServerManager() {
  ]);

  return {
    // Data
    configuredServers,
    connectionStatus,
    initializeServer,
    cancelOAuthFlow,
    isInitializing,
    isCancellable,
    getOAuthUrl,
    mcpValues,
    setMCPValues,

    mcpToolDetails,
    isPinned,
    setIsPinned,
    startupConfig,
    connectionStatus,
    placeholderText,
    batchToggleServers,
    toggleServerSelection,
    localize,

    isConfigModalOpen,
    handleDialogOpenChange,
    selectedToolForConfig,
    setSelectedToolForConfig,
    handleSave,
    handleRevoke,
    // Handlers
    toggleServerSelection,
    batchToggleServers,
    getServerStatusIconProps,

    // Dialog state
    selectedToolForConfig,
    isConfigModalOpen,
    getConfigDialogProps,

    // Utilities
    localize,
  };
}
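For context, a hedged sketch of how a selector component might consume this hook. The return shape differs between the two sides of the hunk above, so this uses only fields visible in both lists (toggleServerSelection, connectionStatus, placeholderText); the component, props, and import path are hypothetical.

// Illustrative usage sketch only; not part of the diff.
import { useMCPServerManager } from './useMCPServerManager'; // hypothetical path

export function MCPServerButtons({ serverNames }: { serverNames: string[] }) {
  const { toggleServerSelection, connectionStatus, placeholderText } = useMCPServerManager();

  return (
    <div aria-label={placeholderText}>
      {serverNames.map((name) => (
        <button key={name} onClick={() => toggleServerSelection(name)}>
          {name} ({connectionStatus[name]?.connectionState ?? 'unknown'})
        </button>
      ))}
    </div>
  );
}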
@@ -155,10 +155,7 @@ export default function useSideNavLinks({
  if (
    startupConfig?.mcpServers &&
    Object.values(startupConfig.mcpServers).some(
      (server: any) =>
        (server.customUserVars && Object.keys(server.customUserVars).length > 0) ||
        server.isOAuth ||
        server.startup === false,
      (server) => server.customUserVars && Object.keys(server.customUserVars).length > 0,
    )
  ) {
    links.push({
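The hunk above changes which MCP server configurations cause the side-nav settings link to be added: one side of the diff only checks for customUserVars, the other also accepts OAuth servers and servers configured with startup: false. Restated as hypothetical helpers, purely for illustration:

// Illustrative helpers; names and types are not from the codebase.
interface MCPServerConfigLike {
  customUserVars?: Record<string, unknown>;
  isOAuth?: boolean;
  startup?: boolean;
}

const hasCustomUserVars = (server: MCPServerConfigLike): boolean =>
  !!server.customUserVars && Object.keys(server.customUserVars).length > 0;

// Narrow predicate: only customUserVars matter.
const showPanelNarrow = (servers: Record<string, MCPServerConfigLike>): boolean =>
  Object.values(servers).some(hasCustomUserVars);

// Broad predicate: customUserVars, OAuth, or startup disabled.
const showPanelBroad = (servers: Record<string, MCPServerConfigLike>): boolean =>
  Object.values(servers).some(
    (server) => hasCustomUserVars(server) || server.isOAuth || server.startup === false,
  );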
@@ -81,21 +81,13 @@ export function useMCPSelect({ conversationId }: UseMCPSelectOptions) {
    [setEphemeralAgent],
  );

  const [mcpValues, setMCPValuesRaw] = useLocalStorage<string[]>(
  const [mcpValues, setMCPValues] = useLocalStorage<string[]>(
    `${LocalStorageKeys.LAST_MCP_}${key}`,
    mcpState,
    setSelectedValues,
    storageCondition,
  );

  const setMCPValuesRawRef = useRef(setMCPValuesRaw);
  setMCPValuesRawRef.current = setMCPValuesRaw;

  // Create a stable memoized setter to avoid re-creating it on every render and causing an infinite render loop
  const setMCPValues = useCallback((value: string[]) => {
    setMCPValuesRawRef.current(value);
  }, []);

  const [isPinned, setIsPinned] = useLocalStorage<boolean>(
    `${LocalStorageKeys.PIN_MCP_}${key}`,
    true,
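The useRef/useCallback pair above is the standard "stable callback" pattern: the ref always points at the latest setter, while the returned function keeps a constant identity so it can safely appear in dependency arrays. A generic, self-contained sketch of the same pattern (the hook name is illustrative):

// Generic sketch; not the project's code.
import { useCallback, useRef } from 'react';

function useStableCallback<Args extends unknown[]>(fn: (...args: Args) => void) {
  const fnRef = useRef(fn);
  fnRef.current = fn; // always delegate to the latest implementation

  // Identity never changes across renders, so effects depending on it do not re-run.
  return useCallback((...args: Args) => {
    fnRef.current(...args);
  }, []);
}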
@@ -8,7 +8,6 @@ export * from './Nav';
export * from './Files';
export * from './Generic';
export * from './Input';
export * from './MCP';
export * from './Messages';
export * from './Plugins';
export * from './Prompts';
Some files were not shown because too many files have changed in this diff.