Compare commits
111 Commits
feat/direc
...
feat/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
278590d0bb | ||
|
|
41a4674469 | ||
|
|
e7a9cf88ac | ||
|
|
f6925f906b | ||
|
|
e90fd1df15 | ||
|
|
a1f9f3dd39 | ||
|
|
fbe0def2fa | ||
|
|
d04da60b3b | ||
|
|
0e94d97bfb | ||
|
|
45ab4d4503 | ||
|
|
0ceef12eea | ||
|
|
6738360051 | ||
|
|
52b65492d5 | ||
|
|
7a9a99d2a0 | ||
|
|
5bfb06b417 | ||
|
|
2ce8f1f686 | ||
|
|
1a47601533 | ||
|
|
5245aeea8f | ||
|
|
dd93db40bc | ||
|
|
136cf1d5a8 | ||
|
|
751522087a | ||
|
|
7fe830acfc | ||
|
|
cdfe686987 | ||
|
|
5b5723343c | ||
|
|
30c24a66f6 | ||
|
|
ecf9733bc1 | ||
|
|
133312fb40 | ||
|
|
b62ffb533c | ||
|
|
d75fb76338 | ||
|
|
51f2d43fed | ||
|
|
e3a645e8fb | ||
|
|
180046a3c5 | ||
|
|
916742ab9d | ||
|
|
d91f34dd42 | ||
|
|
5676976564 | ||
|
|
85aa3e7d9c | ||
|
|
a2ff6613c5 | ||
|
|
8d6cb5eee0 | ||
|
|
31445e391a | ||
|
|
04c3a5a861 | ||
|
|
5667cc9702 | ||
|
|
c0f95f971a | ||
|
|
f125f5bd32 | ||
|
|
f3eca8c7a7 | ||
|
|
f22e5f965e | ||
|
|
749f539dfc | ||
|
|
1247207afe | ||
|
|
5c0e9d8fbb | ||
|
|
957fa7a994 | ||
|
|
751c2e1d17 | ||
|
|
519645c0b0 | ||
|
|
0d0a318c3c | ||
|
|
588e0c4611 | ||
|
|
79144a6365 | ||
|
|
ca53c20370 | ||
|
|
d635503f49 | ||
|
|
920966f895 | ||
|
|
c46e0d3ecc | ||
|
|
c6ecf0095b | ||
|
|
7de6f6e44c | ||
|
|
035f85c3ba | ||
|
|
6f6a34d126 | ||
|
|
fff1f1cf27 | ||
|
|
1869854d70 | ||
|
|
4dd2998592 | ||
|
|
a4a174b3dc | ||
|
|
65c83317aa | ||
|
|
e95e0052da | ||
|
|
0ecafcd38e | ||
|
|
cadfe14abe | ||
|
|
75dd6fb28b | ||
|
|
eef93024d5 | ||
|
|
cd73cb0b3e | ||
|
|
e705b09280 | ||
|
|
23bd4dfbfd | ||
|
|
df17582103 | ||
|
|
d79b80a4bf | ||
|
|
45da421e7d | ||
|
|
122ff416ac | ||
|
|
b66bf93b31 | ||
|
|
6d791e3e12 | ||
|
|
f9b12517b0 | ||
|
|
195e1e9eb2 | ||
|
|
47aa90df1d | ||
|
|
460eac36f6 | ||
|
|
3a47deac07 | ||
|
|
49e8443ec5 | ||
|
|
d16f93b5f7 | ||
|
|
20b29bbfa6 | ||
|
|
e2a6937ca6 | ||
|
|
005a0cb84a | ||
|
|
beabe38311 | ||
|
|
62315be197 | ||
|
|
a26597a696 | ||
|
|
8772b04d1d | ||
|
|
7742b18c9c | ||
|
|
b75b799e34 | ||
|
|
43add11b05 | ||
|
|
1764de53a5 | ||
|
|
c0511b9a5f | ||
|
|
2483623c88 | ||
|
|
229d6f2dfe | ||
|
|
d5ec838218 | ||
|
|
15d7a3d221 | ||
|
|
c3e88b97c8 | ||
|
|
ba424666f8 | ||
|
|
ea3b671182 | ||
|
|
f209f616c9 | ||
|
|
961af515d5 | ||
|
|
a362963017 | ||
|
|
78d735f35c |
15
.env.example
15
.env.example
@@ -40,6 +40,13 @@ NO_INDEX=true
|
||||
# Defaulted to 1.
|
||||
TRUST_PROXY=1
|
||||
|
||||
# Minimum password length for user authentication
|
||||
# Default: 8
|
||||
# Note: When using LDAP authentication, you may want to set this to 1
|
||||
# to bypass local password validation, as LDAP servers handle their own
|
||||
# password policies.
|
||||
# MIN_PASSWORD_LENGTH=8
|
||||
|
||||
#===============#
|
||||
# JSON Logging #
|
||||
#===============#
|
||||
@@ -660,6 +667,10 @@ HELP_AND_FAQ_URL=https://librechat.ai
|
||||
# REDIS_URI=rediss://127.0.0.1:6380
|
||||
# REDIS_CA=/path/to/ca-cert.pem
|
||||
|
||||
# Elasticache may need to use an alternate dnsLookup for TLS connections. see "Special Note: Aws Elasticache Clusters with TLS" on this webpage: https://www.npmjs.com/package/ioredis
|
||||
# Enable alternative dnsLookup for redis
|
||||
# REDIS_USE_ALTERNATIVE_DNS_LOOKUP=true
|
||||
|
||||
# Redis authentication (if required)
|
||||
# REDIS_USERNAME=your_redis_username
|
||||
# REDIS_PASSWORD=your_redis_password
|
||||
@@ -679,8 +690,8 @@ HELP_AND_FAQ_URL=https://librechat.ai
|
||||
# REDIS_PING_INTERVAL=300
|
||||
|
||||
# Force specific cache namespaces to use in-memory storage even when Redis is enabled
|
||||
# Comma-separated list of CacheKeys (e.g., STATIC_CONFIG,ROLES,MESSAGES)
|
||||
# FORCED_IN_MEMORY_CACHE_NAMESPACES=STATIC_CONFIG,ROLES
|
||||
# Comma-separated list of CacheKeys (e.g., ROLES,MESSAGES)
|
||||
# FORCED_IN_MEMORY_CACHE_NAMESPACES=ROLES,MESSAGES
|
||||
|
||||
#==================================================#
|
||||
# Others #
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# v0.8.0-rc3
|
||||
# v0.8.0-rc4
|
||||
|
||||
# Base node image
|
||||
FROM node:20-alpine AS node
|
||||
@@ -30,7 +30,7 @@ RUN \
|
||||
# Allow mounting of these files, which have no default
|
||||
touch .env ; \
|
||||
# Create directories for the volumes to inherit the correct permissions
|
||||
mkdir -p /app/client/public/images /app/api/logs ; \
|
||||
mkdir -p /app/client/public/images /app/api/logs /app/uploads ; \
|
||||
npm config set fetch-retry-maxtimeout 600000 ; \
|
||||
npm config set fetch-retries 5 ; \
|
||||
npm config set fetch-retry-mintimeout 15000 ; \
|
||||
@@ -44,8 +44,6 @@ RUN \
|
||||
npm prune --production; \
|
||||
npm cache clean --force
|
||||
|
||||
RUN mkdir -p /app/client/public/images /app/api/logs
|
||||
|
||||
# Node API setup
|
||||
EXPOSE 3080
|
||||
ENV HOST=0.0.0.0
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Dockerfile.multi
|
||||
# v0.8.0-rc3
|
||||
# v0.8.0-rc4
|
||||
|
||||
# Base for all builds
|
||||
FROM node:20-alpine AS base-min
|
||||
|
||||
@@ -75,6 +75,7 @@
|
||||
- 🔍 **Web Search**:
|
||||
- Search the internet and retrieve relevant information to enhance your AI context
|
||||
- Combines search providers, content scrapers, and result rerankers for optimal results
|
||||
- **Customizable Jina Reranking**: Configure custom Jina API URLs for reranking services
|
||||
- **[Learn More →](https://www.librechat.ai/docs/features/web_search)**
|
||||
|
||||
- 🪄 **Generative UI with Code Artifacts**:
|
||||
|
||||
@@ -10,7 +10,17 @@ const {
|
||||
validateVisionModel,
|
||||
} = require('librechat-data-provider');
|
||||
const { SplitStreamHandler: _Handler } = require('@librechat/agents');
|
||||
const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api');
|
||||
const {
|
||||
Tokenizer,
|
||||
createFetch,
|
||||
matchModelName,
|
||||
getClaudeHeaders,
|
||||
getModelMaxTokens,
|
||||
configureReasoning,
|
||||
checkPromptCacheSupport,
|
||||
getModelMaxOutputTokens,
|
||||
createStreamEventHandlers,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
truncateText,
|
||||
formatMessage,
|
||||
@@ -19,12 +29,6 @@ const {
|
||||
parseParamFromPrompt,
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const {
|
||||
getClaudeHeaders,
|
||||
configureReasoning,
|
||||
checkPromptCacheSupport,
|
||||
} = require('~/server/services/Endpoints/anthropic/helpers');
|
||||
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { sleep } = require('~/server/utils');
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const { google } = require('googleapis');
|
||||
const { getModelMaxTokens } = require('@librechat/api');
|
||||
const { concat } = require('@langchain/core/utils/stream');
|
||||
const { ChatVertexAI } = require('@langchain/google-vertexai');
|
||||
const { Tokenizer, getSafetySettings } = require('@librechat/api');
|
||||
@@ -21,7 +22,6 @@ const {
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const {
|
||||
|
||||
@@ -7,7 +7,9 @@ const {
|
||||
createFetch,
|
||||
resolveHeaders,
|
||||
constructAzureURL,
|
||||
getModelMaxTokens,
|
||||
genAzureChatCompletion,
|
||||
getModelMaxOutputTokens,
|
||||
createStreamEventHandlers,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
@@ -31,13 +33,13 @@ const {
|
||||
titleInstruction,
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { addSpaceIfNeeded, sleep } = require('~/server/utils');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const { summaryBuffer } = require('./memory');
|
||||
const { runTitleChain } = require('./chains');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const { tokenSplit } = require('./document');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { createLLM } = require('./llm');
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const { getModelMaxTokens } = require('@librechat/api');
|
||||
const BaseClient = require('../BaseClient');
|
||||
const { getModelMaxTokens } = require('../../../utils');
|
||||
|
||||
class FakeClient extends BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
|
||||
@@ -71,9 +71,10 @@ const primeFiles = async (options) => {
|
||||
* @param {ServerRequest} options.req
|
||||
* @param {Array<{ file_id: string; filename: string }>} options.files
|
||||
* @param {string} [options.entity_id]
|
||||
* @param {boolean} [options.fileCitations=false] - Whether to include citation instructions
|
||||
* @returns
|
||||
*/
|
||||
const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
const createFileSearchTool = async ({ req, files, entity_id, fileCitations = false }) => {
|
||||
return tool(
|
||||
async ({ query }) => {
|
||||
if (files.length === 0) {
|
||||
@@ -142,9 +143,9 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
const formattedString = formattedResults
|
||||
.map(
|
||||
(result, index) =>
|
||||
`File: ${result.filename}\nAnchor: \\ue202turn0file${index} (${result.filename})\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${
|
||||
result.content
|
||||
}\n`,
|
||||
`File: ${result.filename}${
|
||||
fileCitations ? `\nAnchor: \\ue202turn0file${index} (${result.filename})` : ''
|
||||
}\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${result.content}\n`,
|
||||
)
|
||||
.join('\n---\n');
|
||||
|
||||
@@ -158,12 +159,14 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
pageRelevance: result.page ? { [result.page]: 1.0 - result.distance } : {},
|
||||
}));
|
||||
|
||||
return [formattedString, { [Tools.file_search]: { sources } }];
|
||||
return [formattedString, { [Tools.file_search]: { sources, fileCitations } }];
|
||||
},
|
||||
{
|
||||
name: Tools.file_search,
|
||||
responseFormat: 'content_and_artifact',
|
||||
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.
|
||||
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.${
|
||||
fileCitations
|
||||
? `
|
||||
|
||||
**CITE FILE SEARCH RESULTS:**
|
||||
Use anchor markers immediately after statements derived from file content. Reference the filename in your text:
|
||||
@@ -171,7 +174,9 @@ Use anchor markers immediately after statements derived from file content. Refer
|
||||
- Page reference: "According to report.docx... \\ue202turn0file1"
|
||||
- Multi-file: "Multiple sources confirm... \\ue200\\ue202turn0file0\\ue202turn0file1\\ue201"
|
||||
|
||||
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`,
|
||||
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`
|
||||
: ''
|
||||
}`,
|
||||
schema: z.object({
|
||||
query: z
|
||||
.string()
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SerpAPI } = require('@langchain/community/tools/serpapi');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
|
||||
const { mcpToolPattern, loadWebSearchAuth, checkAccess } = require('@librechat/api');
|
||||
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
||||
const { Tools, Constants, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
Permissions,
|
||||
EToolResources,
|
||||
PermissionTypes,
|
||||
replaceSpecialVars,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
availableTools,
|
||||
manifestToolMap,
|
||||
@@ -27,6 +34,7 @@ const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
/**
|
||||
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
|
||||
@@ -281,7 +289,29 @@ const loadTools = async ({
|
||||
if (toolContext) {
|
||||
toolContextMap[tool] = toolContext;
|
||||
}
|
||||
return createFileSearchTool({ req: options.req, files, entity_id: agent?.id });
|
||||
|
||||
/** @type {boolean | undefined} Check if user has FILE_CITATIONS permission */
|
||||
let fileCitations;
|
||||
if (fileCitations == null && options.req?.user != null) {
|
||||
try {
|
||||
fileCitations = await checkAccess({
|
||||
user: options.req.user,
|
||||
permissionType: PermissionTypes.FILE_CITATIONS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[handleTools] FILE_CITATIONS permission check failed:', error);
|
||||
fileCitations = false;
|
||||
}
|
||||
}
|
||||
|
||||
return createFileSearchTool({
|
||||
req: options.req,
|
||||
files,
|
||||
entity_id: agent?.id,
|
||||
fileCitations,
|
||||
});
|
||||
};
|
||||
continue;
|
||||
} else if (tool === Tools.web_search) {
|
||||
@@ -312,6 +342,16 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
continue;
|
||||
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
|
||||
const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
|
||||
if (toolName === Constants.mcp_server) {
|
||||
/** Placeholder used for UI purposes */
|
||||
continue;
|
||||
}
|
||||
if (serverName && options.req?.config?.mcpConfig?.[serverName] == null) {
|
||||
logger.warn(
|
||||
`MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (toolName === Constants.mcp_all) {
|
||||
const currentMCPGenerator = async (index) =>
|
||||
createMCPTools({
|
||||
|
||||
3
api/cache/cacheConfig.js
vendored
3
api/cache/cacheConfig.js
vendored
@@ -52,6 +52,9 @@ const cacheConfig = {
|
||||
REDIS_CONNECT_TIMEOUT: math(process.env.REDIS_CONNECT_TIMEOUT, 10000),
|
||||
/** Queue commands when disconnected */
|
||||
REDIS_ENABLE_OFFLINE_QUEUE: isEnabled(process.env.REDIS_ENABLE_OFFLINE_QUEUE ?? 'true'),
|
||||
/** flag to modify redis connection by adding dnsLookup this is required when connecting to elasticache for ioredis
|
||||
* see "Special Note: Aws Elasticache Clusters with TLS" on this webpage: https://www.npmjs.com/package/ioredis **/
|
||||
REDIS_USE_ALTERNATIVE_DNS_LOOKUP: isEnabled(process.env.REDIS_USE_ALTERNATIVE_DNS_LOOKUP),
|
||||
/** Enable redis cluster without the need of multiple URIs */
|
||||
USE_REDIS_CLUSTER: isEnabled(process.env.USE_REDIS_CLUSTER ?? 'false'),
|
||||
CI: isEnabled(process.env.CI),
|
||||
|
||||
3
api/cache/cacheConfig.spec.js
vendored
3
api/cache/cacheConfig.spec.js
vendored
@@ -157,12 +157,11 @@ describe('cacheConfig', () => {
|
||||
|
||||
describe('FORCED_IN_MEMORY_CACHE_NAMESPACES validation', () => {
|
||||
test('should parse comma-separated cache keys correctly', () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, STATIC_CONFIG ,MESSAGES ';
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, MESSAGES ';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([
|
||||
'ROLES',
|
||||
'STATIC_CONFIG',
|
||||
'MESSAGES',
|
||||
]);
|
||||
});
|
||||
|
||||
2
api/cache/getLogStores.js
vendored
2
api/cache/getLogStores.js
vendored
@@ -31,8 +31,8 @@ const namespaces = {
|
||||
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
|
||||
|
||||
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
|
||||
[CacheKeys.APP_CONFIG]: standardCache(CacheKeys.APP_CONFIG),
|
||||
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
|
||||
[CacheKeys.STATIC_CONFIG]: standardCache(CacheKeys.STATIC_CONFIG),
|
||||
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
|
||||
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
|
||||
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),
|
||||
|
||||
3
api/cache/redisClients.js
vendored
3
api/cache/redisClients.js
vendored
@@ -53,6 +53,9 @@ if (cacheConfig.USE_REDIS) {
|
||||
: new IoRedis.Cluster(
|
||||
urls.map((url) => ({ host: url.hostname, port: parseInt(url.port, 10) || 6379 })),
|
||||
{
|
||||
...(cacheConfig.REDIS_USE_ALTERNATIVE_DNS_LOOKUP
|
||||
? { dnsLookup: (address, callback) => callback(null, address) }
|
||||
: {}),
|
||||
redisOptions,
|
||||
clusterRetryStrategy: (times) => {
|
||||
if (
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||
const { EventSource } = require('eventsource');
|
||||
const { Time } = require('librechat-data-provider');
|
||||
const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api');
|
||||
const logger = require('./winston');
|
||||
|
||||
global.EventSource = EventSource;
|
||||
@@ -26,4 +26,6 @@ module.exports = {
|
||||
createMCPManager: MCPManager.createInstance,
|
||||
getMCPManager: MCPManager.getInstance,
|
||||
getFlowStateManager,
|
||||
createOAuthReconnectionManager: OAuthReconnectionManager.createInstance,
|
||||
getOAuthReconnectionManager: OAuthReconnectionManager.getInstance,
|
||||
};
|
||||
|
||||
@@ -211,7 +211,67 @@ describe('File Access Control', () => {
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should deny access when user only has VIEW permission', async () => {
|
||||
it('should deny access when user only has VIEW permission and needs access for deletion', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent with files
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'View-Only Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access for files
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId,
|
||||
isDelete: true,
|
||||
});
|
||||
|
||||
// Should have no access to any files when only VIEW permission
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should grant access when user has VIEW permission', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
@@ -265,9 +325,8 @@ describe('File Access Control', () => {
|
||||
agentId,
|
||||
});
|
||||
|
||||
// Should have no access to any files when only VIEW permission
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -269,7 +269,7 @@ async function getListPromptGroupsByAccess({
|
||||
const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };
|
||||
|
||||
// Add cursor condition
|
||||
if (after) {
|
||||
if (after && typeof after === 'string' && after !== 'undefined' && after !== 'null') {
|
||||
try {
|
||||
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
|
||||
const { updatedAt, _id } = cursor;
|
||||
|
||||
@@ -189,11 +189,15 @@ async function createAutoRefillTransaction(txData) {
|
||||
* @param {txData} _txData - Transaction data.
|
||||
*/
|
||||
async function createTransaction(_txData) {
|
||||
const { balance, ...txData } = _txData;
|
||||
const { balance, transactions, ...txData } = _txData;
|
||||
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (transactions?.enabled === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const transaction = new Transaction(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
calculateTokenValue(transaction);
|
||||
@@ -222,7 +226,11 @@ async function createTransaction(_txData) {
|
||||
* @param {txData} _txData - Transaction data.
|
||||
*/
|
||||
async function createStructuredTransaction(_txData) {
|
||||
const { balance, ...txData } = _txData;
|
||||
const { balance, transactions, ...txData } = _txData;
|
||||
if (transactions?.enabled === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const transaction = new Transaction({
|
||||
...txData,
|
||||
endpointTokenConfig: txData.endpointTokenConfig,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||
|
||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
const { createTransaction } = require('./Transaction');
|
||||
const { Balance } = require('~/db/models');
|
||||
const { createTransaction, createStructuredTransaction } = require('./Transaction');
|
||||
const { Balance, Transaction } = require('~/db/models');
|
||||
|
||||
let mongoServer;
|
||||
beforeAll(async () => {
|
||||
@@ -380,3 +379,188 @@ describe('NaN Handling Tests', () => {
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Transactions Config Tests', () => {
|
||||
test('createTransaction should not save when transactions.enabled is false', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(0);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createTransaction should save when transactions.enabled is true', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created
|
||||
expect(result).toBeDefined();
|
||||
expect(result.balance).toBeLessThan(initialBalance);
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].rawAmount).toBe(-100);
|
||||
});
|
||||
|
||||
test('createTransaction should save when balance.enabled is true even if transactions config is missing', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
balance: { enabled: true },
|
||||
// No transactions config provided
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created (backward compatibility)
|
||||
expect(result).toBeDefined();
|
||||
expect(result.balance).toBeLessThan(initialBalance);
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
});
|
||||
|
||||
test('createTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created but balance unchanged
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].rawAmount).toBe(-100);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createStructuredTransaction should not save when transactions.enabled is false', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'message',
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
transactions: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createStructuredTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(0);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createStructuredTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'message',
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createStructuredTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created but balance unchanged
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].inputTokens).toBe(-10);
|
||||
expect(transactions[0].writeTokens).toBe(-100);
|
||||
expect(transactions[0].readTokens).toBe(-5);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,47 +1,9 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { buildTree } = require('librechat-data-provider');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { getMessages, bulkSaveMessages } = require('./Message');
|
||||
const { Message } = require('~/db/models');
|
||||
|
||||
// Original version of buildTree function
|
||||
function buildTree({ messages, fileMap }) {
|
||||
if (messages === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const messageMap = {};
|
||||
const rootMessages = [];
|
||||
const childrenCount = {};
|
||||
|
||||
messages.forEach((message) => {
|
||||
const parentId = message.parentMessageId ?? '';
|
||||
childrenCount[parentId] = (childrenCount[parentId] || 0) + 1;
|
||||
|
||||
const extendedMessage = {
|
||||
...message,
|
||||
children: [],
|
||||
depth: 0,
|
||||
siblingIndex: childrenCount[parentId] - 1,
|
||||
};
|
||||
|
||||
if (message.files && fileMap) {
|
||||
extendedMessage.files = message.files.map((file) => fileMap[file.file_id ?? ''] ?? file);
|
||||
}
|
||||
|
||||
messageMap[message.messageId] = extendedMessage;
|
||||
|
||||
const parentMessage = messageMap[parentId];
|
||||
if (parentMessage) {
|
||||
parentMessage.children.push(extendedMessage);
|
||||
extendedMessage.depth = parentMessage.depth + 1;
|
||||
} else {
|
||||
rootMessages.push(extendedMessage);
|
||||
}
|
||||
});
|
||||
|
||||
return rootMessages;
|
||||
}
|
||||
|
||||
let mongod;
|
||||
beforeAll(async () => {
|
||||
mongod = await MongoMemoryServer.create();
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { matchModelName } = require('../utils/tokens');
|
||||
const { matchModelName } = require('@librechat/api');
|
||||
const defaultRate = 6;
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "v0.8.0-rc3",
|
||||
"version": "v0.8.0-rc4",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
@@ -49,14 +49,14 @@
|
||||
"@langchain/google-vertexai": "^0.2.13",
|
||||
"@langchain/openai": "^0.5.18",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.4.76",
|
||||
"@librechat/agents": "^2.4.79",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||
"@modelcontextprotocol/sdk": "^1.17.1",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
"axios": "^1.8.2",
|
||||
"axios": "^1.12.1",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"compression": "^1.8.1",
|
||||
"connect-redis": "^8.1.0",
|
||||
|
||||
@@ -11,8 +11,9 @@ const {
|
||||
registerUser,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||
const { getOAuthReconnectionManager } = require('~/config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
@@ -75,7 +76,7 @@ const refreshController = async (req, res) => {
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
const token = setOpenIDAuthTokens(tokenset, res);
|
||||
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString());
|
||||
return res.status(200).send({ token, user });
|
||||
} catch (error) {
|
||||
logger.error('[refreshController] OpenID token refresh error', error);
|
||||
@@ -96,14 +97,25 @@ const refreshController = async (req, res) => {
|
||||
return res.status(200).send({ token, user });
|
||||
}
|
||||
|
||||
// Find the session with the hashed refresh token
|
||||
const session = await findSession({
|
||||
userId: userId,
|
||||
refreshToken: refreshToken,
|
||||
});
|
||||
/** Session with the hashed refresh token */
|
||||
const session = await findSession(
|
||||
{
|
||||
userId: userId,
|
||||
refreshToken: refreshToken,
|
||||
},
|
||||
{ lean: false },
|
||||
);
|
||||
|
||||
if (session && session.expiration > new Date()) {
|
||||
const token = await setAuthTokens(userId, res, session._id);
|
||||
const token = await setAuthTokens(userId, res, session);
|
||||
|
||||
// trigger OAuth MCP server reconnection asynchronously (best effort)
|
||||
void getOAuthReconnectionManager()
|
||||
.reconnectServers(userId)
|
||||
.catch((err) => {
|
||||
logger.error('Error reconnecting OAuth MCP servers:', err);
|
||||
});
|
||||
|
||||
res.status(200).send({ token, user });
|
||||
} else if (req?.query?.retry) {
|
||||
// Retrying from a refresh token request that failed (401)
|
||||
|
||||
@@ -74,14 +74,23 @@ const getAvailableTools = async (req, res) => {
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
const cachedUserTools = await getCachedTools({ userId });
|
||||
|
||||
const mcpManager = getMCPManager();
|
||||
const userPlugins =
|
||||
cachedUserTools != null
|
||||
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
|
||||
: undefined;
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
|
||||
if (cachedToolsArray != null && userPlugins != null) {
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
|
||||
/** @type {TPlugin[]} */
|
||||
let mcpPlugins;
|
||||
if (appConfig?.mcpConfig) {
|
||||
const mcpManager = getMCPManager();
|
||||
mcpPlugins =
|
||||
cachedUserTools != null
|
||||
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
|
||||
: undefined;
|
||||
}
|
||||
|
||||
if (
|
||||
cachedToolsArray != null &&
|
||||
(appConfig?.mcpConfig != null ? mcpPlugins != null && mcpPlugins.length > 0 : true)
|
||||
) {
|
||||
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...cachedToolsArray]);
|
||||
res.status(200).json(dedupedTools);
|
||||
return;
|
||||
}
|
||||
@@ -93,9 +102,9 @@ const getAvailableTools = async (req, res) => {
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
let pluginManifest = availableTools;
|
||||
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
if (appConfig?.mcpConfig != null) {
|
||||
try {
|
||||
const mcpManager = getMCPManager();
|
||||
const mcpTools = await mcpManager.getAllToolFunctions(userId);
|
||||
prelimCachedTools = prelimCachedTools ?? {};
|
||||
for (const [toolKey, toolData] of Object.entries(mcpTools)) {
|
||||
@@ -175,7 +184,7 @@ const getAvailableTools = async (req, res) => {
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
|
||||
const dedupedTools = filterUniquePlugins([...(userPlugins ?? []), ...finalTools]);
|
||||
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...finalTools]);
|
||||
res.status(200).json(dedupedTools);
|
||||
} catch (error) {
|
||||
logger.error('[getAvailableTools]', error);
|
||||
|
||||
@@ -174,10 +174,19 @@ describe('PluginController', () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
mcpConfig: {
|
||||
server1: {},
|
||||
},
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return empty tools initially (since getAllToolFunctions is called)
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// Mock second call to return tool definitions (includeGlobal: true)
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
@@ -505,7 +514,7 @@ describe('PluginController', () => {
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle cachedToolsArray and userPlugins both being defined', async () => {
|
||||
it('should handle `cachedToolsArray` and `mcpPlugins` both being defined', async () => {
|
||||
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
|
||||
// Use MCP delimiter for the user tool so convertMCPToolsToPlugins works
|
||||
const userTools = {
|
||||
@@ -522,10 +531,19 @@ describe('PluginController', () => {
|
||||
mockCache.get.mockResolvedValue(cachedTools);
|
||||
getCachedTools.mockResolvedValueOnce(userTools);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
mcpConfig: {
|
||||
server1: {},
|
||||
},
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return empty tools initially
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// The controller expects a second call to getCachedTools
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
'cached-tool': { type: 'function', function: { name: 'cached-tool' } },
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { webSearchKeys, extractWebSearchEnvVars, normalizeHttpError } = require('@librechat/api');
|
||||
const {
|
||||
webSearchKeys,
|
||||
extractWebSearchEnvVars,
|
||||
normalizeHttpError,
|
||||
MCPTokenStorage,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
getFiles,
|
||||
updateUser,
|
||||
@@ -16,11 +21,17 @@ const { verifyEmail, resendVerificationEmail } = require('~/server/services/Auth
|
||||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { Tools, Constants, FileSources } = require('librechat-data-provider');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { Transaction, Balance, User } = require('~/db/models');
|
||||
const { Transaction, Balance, User, Token } = require('~/db/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { deleteAllSharedLinks } = require('~/models');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { clearMCPServerTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
const { findToken } = require('~/models');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
@@ -162,6 +173,15 @@ const updateUserPluginsController = async (req, res) => {
|
||||
);
|
||||
({ status, message } = normalizeHttpError(authService));
|
||||
}
|
||||
try {
|
||||
// if the MCP server uses OAuth, perform a full cleanup and token revocation
|
||||
await maybeUninstallOAuthMCP(user.id, pluginKey, appConfig);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[updateUserPluginsController] Error uninstalling OAuth MCP for ${pluginKey}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// This handles:
|
||||
// 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}).
|
||||
@@ -187,7 +207,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
// Extract server name from pluginKey (format: "mcp_<serverName>")
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
logger.info(
|
||||
`[updateUserPluginsController] Disconnecting MCP server ${serverName} for user ${user.id} after plugin auth update for ${pluginKey}.`,
|
||||
`[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`,
|
||||
);
|
||||
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||
}
|
||||
@@ -269,6 +289,97 @@ const resendVerificationController = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* OAuth MCP specific uninstall logic
|
||||
*/
|
||||
const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||
if (!pluginKey.startsWith(Constants.mcp_prefix)) {
|
||||
// this is not an MCP server, so nothing to do here
|
||||
return;
|
||||
}
|
||||
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
const mcpManager = getMCPManager(userId);
|
||||
const serverConfig = mcpManager.getRawConfig(serverName) ?? appConfig?.mcpServers?.[serverName];
|
||||
|
||||
if (!mcpManager.getOAuthServers().has(serverName)) {
|
||||
// this server does not use OAuth, so nothing to do here as well
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. get client info used for revocation (client id, secret)
|
||||
const clientTokenData = await MCPTokenStorage.getClientInfoAndMetadata({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
});
|
||||
if (clientTokenData == null) {
|
||||
return;
|
||||
}
|
||||
const { clientInfo, clientMetadata } = clientTokenData;
|
||||
|
||||
// 2. get decrypted tokens before deletion
|
||||
const tokens = await MCPTokenStorage.getTokens({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
});
|
||||
|
||||
// 3. revoke OAuth tokens at the provider
|
||||
const revocationEndpoint =
|
||||
serverConfig.oauth?.revocation_endpoint ?? clientMetadata.revocation_endpoint;
|
||||
const revocationEndpointAuthMethodsSupported =
|
||||
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
|
||||
clientMetadata.revocation_endpoint_auth_methods_supported;
|
||||
|
||||
if (tokens?.access_token) {
|
||||
try {
|
||||
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.access_token, 'access', {
|
||||
serverUrl: serverConfig.url,
|
||||
clientId: clientInfo.client_id,
|
||||
clientSecret: clientInfo.client_secret ?? '',
|
||||
revocationEndpoint,
|
||||
revocationEndpointAuthMethodsSupported,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
if (tokens?.refresh_token) {
|
||||
try {
|
||||
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.refresh_token, 'refresh', {
|
||||
serverUrl: serverConfig.url,
|
||||
clientId: clientInfo.client_id,
|
||||
clientSecret: clientInfo.client_secret ?? '',
|
||||
revocationEndpoint,
|
||||
revocationEndpointAuthMethodsSupported,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
// 4. delete tokens from the DB after revocation attempts
|
||||
await MCPTokenStorage.deleteUserTokens({
|
||||
userId,
|
||||
serverName,
|
||||
deleteToken: async (filter) => {
|
||||
await Token.deleteOne(filter);
|
||||
},
|
||||
});
|
||||
|
||||
// 5. clear the flow state for the OAuth tokens
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
|
||||
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
// 6. clear the tools cache for the server
|
||||
await clearMCPServerTools({ userId, serverName });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getUserController,
|
||||
getTermsStatusController,
|
||||
|
||||
342
api/server/controllers/agents/__tests__/callbacks.spec.js
Normal file
342
api/server/controllers/agents/__tests__/callbacks.spec.js
Normal file
@@ -0,0 +1,342 @@
|
||||
const { Tools } = require('librechat-data-provider');
|
||||
|
||||
// Mock all dependencies before requiring the module
|
||||
jest.mock('nanoid', () => ({
|
||||
nanoid: jest.fn(() => 'mock-id'),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
sendEvent: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/agents', () => ({
|
||||
EnvVar: { CODE_API_KEY: 'CODE_API_KEY' },
|
||||
Providers: { GOOGLE: 'google' },
|
||||
GraphEvents: {},
|
||||
getMessageId: jest.fn(),
|
||||
ToolEndHandler: jest.fn(),
|
||||
handleToolCalls: jest.fn(),
|
||||
ChatModelStreamHandler: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/Citations', () => ({
|
||||
processFileCitations: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/Code/process', () => ({
|
||||
processCodeOutput: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
saveBase64Image: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('createToolEndCallback', () => {
|
||||
let req, res, artifactPromises, createToolEndCallback;
|
||||
let logger;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Get the mocked logger
|
||||
logger = require('@librechat/data-schemas').logger;
|
||||
|
||||
// Now require the module after all mocks are set up
|
||||
const callbacks = require('../callbacks');
|
||||
createToolEndCallback = callbacks.createToolEndCallback;
|
||||
|
||||
req = {
|
||||
user: { id: 'user123' },
|
||||
};
|
||||
res = {
|
||||
headersSent: false,
|
||||
write: jest.fn(),
|
||||
};
|
||||
artifactPromises = [];
|
||||
});
|
||||
|
||||
describe('ui_resources artifact handling', () => {
|
||||
it('should process ui_resources artifact and return attachment when headers not sent', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {
|
||||
0: { type: 'button', label: 'Click me' },
|
||||
1: { type: 'input', placeholder: 'Enter text' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
|
||||
// Wait for all promises to resolve
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
// When headers are not sent, it returns attachment without writing
|
||||
expect(res.write).not.toHaveBeenCalled();
|
||||
|
||||
const attachment = results[0];
|
||||
expect(attachment).toEqual({
|
||||
type: Tools.ui_resources,
|
||||
messageId: 'run456',
|
||||
toolCallId: 'tool123',
|
||||
conversationId: 'thread789',
|
||||
[Tools.ui_resources]: {
|
||||
0: { type: 'button', label: 'Click me' },
|
||||
1: { type: 'input', placeholder: 'Enter text' },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should write to response when headers are already sent', async () => {
|
||||
res.headersSent = true;
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {
|
||||
0: { type: 'carousel', items: [] },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
expect(res.write).toHaveBeenCalled();
|
||||
expect(results[0]).toEqual({
|
||||
type: Tools.ui_resources,
|
||||
messageId: 'run456',
|
||||
toolCallId: 'tool123',
|
||||
conversationId: 'thread789',
|
||||
[Tools.ui_resources]: {
|
||||
0: { type: 'carousel', items: [] },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle errors when processing ui_resources', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
// Mock res.write to throw an error
|
||||
res.headersSent = true;
|
||||
res.write.mockImplementation(() => {
|
||||
throw new Error('Write failed');
|
||||
});
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {
|
||||
0: { type: 'test' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'Error processing artifact content:',
|
||||
expect.any(Error),
|
||||
);
|
||||
expect(results[0]).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle multiple artifacts including ui_resources', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {
|
||||
0: { type: 'chart', data: [] },
|
||||
},
|
||||
},
|
||||
[Tools.web_search]: {
|
||||
results: ['result1', 'result2'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
// Both ui_resources and web_search should be processed
|
||||
expect(artifactPromises).toHaveLength(2);
|
||||
expect(results).toHaveLength(2);
|
||||
|
||||
// Check ui_resources attachment
|
||||
const uiResourceAttachment = results.find((r) => r?.type === Tools.ui_resources);
|
||||
expect(uiResourceAttachment).toBeTruthy();
|
||||
expect(uiResourceAttachment[Tools.ui_resources]).toEqual({
|
||||
0: { type: 'chart', data: [] },
|
||||
});
|
||||
|
||||
// Check web_search attachment
|
||||
const webSearchAttachment = results.find((r) => r?.type === Tools.web_search);
|
||||
expect(webSearchAttachment).toBeTruthy();
|
||||
expect(webSearchAttachment[Tools.web_search]).toEqual({
|
||||
results: ['result1', 'result2'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should not process artifacts when output has no artifacts', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
content: 'Some regular content',
|
||||
// No artifact property
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
|
||||
expect(artifactPromises).toHaveLength(0);
|
||||
expect(res.write).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle empty ui_resources data object', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
expect(results[0]).toEqual({
|
||||
type: Tools.ui_resources,
|
||||
messageId: 'run456',
|
||||
toolCallId: 'tool123',
|
||||
conversationId: 'thread789',
|
||||
[Tools.ui_resources]: {},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle ui_resources with complex nested data', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const complexData = {
|
||||
0: {
|
||||
type: 'form',
|
||||
fields: [
|
||||
{ name: 'field1', type: 'text', required: true },
|
||||
{ name: 'field2', type: 'select', options: ['a', 'b', 'c'] },
|
||||
],
|
||||
nested: {
|
||||
deep: {
|
||||
value: 123,
|
||||
array: [1, 2, 3],
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: complexData,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
expect(results[0][Tools.ui_resources]).toEqual(complexData);
|
||||
});
|
||||
|
||||
it('should handle when output is undefined', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output: undefined }, metadata);
|
||||
|
||||
expect(artifactPromises).toHaveLength(0);
|
||||
expect(res.write).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle when data parameter is undefined', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback(undefined, metadata);
|
||||
|
||||
expect(artifactPromises).toHaveLength(0);
|
||||
expect(res.write).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -265,6 +265,30 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
||||
);
|
||||
}
|
||||
|
||||
// TODO: a lot of duplicated code in createToolEndCallback
|
||||
// we should refactor this to use a helper function in a follow-up PR
|
||||
if (output.artifact[Tools.ui_resources]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
const attachment = {
|
||||
type: Tools.ui_resources,
|
||||
messageId: metadata.run_id,
|
||||
toolCallId: output.tool_call_id,
|
||||
conversationId: metadata.thread_id,
|
||||
[Tools.ui_resources]: output.artifact[Tools.ui_resources].data,
|
||||
};
|
||||
if (!res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
return null;
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
if (output.artifact[Tools.web_search]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
|
||||
@@ -7,9 +7,12 @@ const {
|
||||
createRun,
|
||||
Tokenizer,
|
||||
checkAccess,
|
||||
logAxiosError,
|
||||
resolveHeaders,
|
||||
getBalanceConfig,
|
||||
memoryInstructions,
|
||||
formatContentStrings,
|
||||
getTransactionsConfig,
|
||||
createMemoryProcessor,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
@@ -86,11 +89,10 @@ function createTokenCounter(encoding) {
|
||||
}
|
||||
|
||||
function logToolError(graph, error, toolId) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
|
||||
logAxiosError({
|
||||
error,
|
||||
toolId,
|
||||
);
|
||||
message: `[api/server/controllers/agents/client.js #chatCompletion] Tool Error "${toolId}"`,
|
||||
});
|
||||
}
|
||||
|
||||
class AgentClient extends BaseClient {
|
||||
@@ -622,11 +624,13 @@ class AgentClient extends BaseClient {
|
||||
* @param {string} [params.model]
|
||||
* @param {string} [params.context='message']
|
||||
* @param {AppConfig['balance']} [params.balance]
|
||||
* @param {AppConfig['transactions']} [params.transactions]
|
||||
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
|
||||
*/
|
||||
async recordCollectedUsage({
|
||||
model,
|
||||
balance,
|
||||
transactions,
|
||||
context = 'message',
|
||||
collectedUsage = this.collectedUsage,
|
||||
}) {
|
||||
@@ -652,6 +656,7 @@ class AgentClient extends BaseClient {
|
||||
const txMetadata = {
|
||||
context,
|
||||
balance,
|
||||
transactions,
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
@@ -867,11 +872,10 @@ class AgentClient extends BaseClient {
|
||||
if (agent.useLegacyContent === true) {
|
||||
messages = formatContentStrings(messages);
|
||||
}
|
||||
if (
|
||||
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
|
||||
'prompt-caching',
|
||||
)
|
||||
) {
|
||||
const defaultHeaders =
|
||||
agent.model_parameters?.clientOptions?.defaultHeaders ??
|
||||
agent.model_parameters?.configuration?.defaultHeaders;
|
||||
if (defaultHeaders?.['anthropic-beta']?.includes('prompt-caching')) {
|
||||
messages = addCacheControl(messages);
|
||||
}
|
||||
|
||||
@@ -879,6 +883,16 @@ class AgentClient extends BaseClient {
|
||||
memoryPromise = this.runMemory(messages);
|
||||
}
|
||||
|
||||
/** Resolve request-based headers for Custom Endpoints. Note: if this is added to
|
||||
* non-custom endpoints, needs consideration of varying provider header configs.
|
||||
*/
|
||||
if (agent.model_parameters?.configuration?.defaultHeaders != null) {
|
||||
agent.model_parameters.configuration.defaultHeaders = resolveHeaders({
|
||||
headers: agent.model_parameters.configuration.defaultHeaders,
|
||||
body: config.configurable.requestBody,
|
||||
});
|
||||
}
|
||||
|
||||
run = await createRun({
|
||||
agent,
|
||||
req: this.options.req,
|
||||
@@ -1040,7 +1054,12 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
await this.recordCollectedUsage({ context: 'message', balance: balanceConfig });
|
||||
const transactionsConfig = getTransactionsConfig(appConfig);
|
||||
await this.recordCollectedUsage({
|
||||
context: 'message',
|
||||
balance: balanceConfig,
|
||||
transactions: transactionsConfig,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
|
||||
@@ -1181,6 +1200,20 @@ class AgentClient extends BaseClient {
|
||||
clientOptions.json = true;
|
||||
}
|
||||
|
||||
/** Resolve request-based headers for Custom Endpoints. Note: if this is added to
|
||||
* non-custom endpoints, needs consideration of varying provider header configs.
|
||||
*/
|
||||
if (clientOptions?.configuration?.defaultHeaders != null) {
|
||||
clientOptions.configuration.defaultHeaders = resolveHeaders({
|
||||
headers: clientOptions.configuration.defaultHeaders,
|
||||
body: {
|
||||
messageId: this.responseMessageId,
|
||||
conversationId: this.conversationId,
|
||||
parentMessageId: this.parentMessageId,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
const titleResult = await this.run.generateTitle({
|
||||
provider,
|
||||
@@ -1220,11 +1253,13 @@ class AgentClient extends BaseClient {
|
||||
});
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
const transactionsConfig = getTransactionsConfig(appConfig);
|
||||
await this.recordCollectedUsage({
|
||||
collectedUsage,
|
||||
context: 'title',
|
||||
model: clientOptions.model,
|
||||
balance: balanceConfig,
|
||||
transactions: transactionsConfig,
|
||||
}).catch((err) => {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',
|
||||
|
||||
@@ -237,6 +237,9 @@ describe('AgentClient - titleConvo', () => {
|
||||
balance: {
|
||||
enabled: false,
|
||||
},
|
||||
transactions: {
|
||||
enabled: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
SystemRoles,
|
||||
FileSources,
|
||||
ResourceType,
|
||||
@@ -69,9 +70,9 @@ const createAgentHandler = async (req, res) => {
|
||||
for (const tool of tools) {
|
||||
if (availableTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
}
|
||||
|
||||
if (systemTools[tool]) {
|
||||
} else if (systemTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
} else if (tool.includes(Constants.mcp_delimiter)) {
|
||||
agentData.tools.push(tool);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { sendEvent, getBalanceConfig } = require('@librechat/api');
|
||||
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -34,7 +34,6 @@ const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { sendEvent, getBalanceConfig } = require('@librechat/api');
|
||||
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -31,7 +31,6 @@ const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generate2FATempToken } = require('~/server/services/twoFactorService');
|
||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const loginController = async (req, res) => {
|
||||
try {
|
||||
|
||||
50
api/server/controllers/auth/oauth.js
Normal file
50
api/server/controllers/auth/oauth.js
Normal file
@@ -0,0 +1,50 @@
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
|
||||
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
|
||||
const { checkBan } = require('~/server/middleware');
|
||||
|
||||
const domains = {
|
||||
client: process.env.DOMAIN_CLIENT,
|
||||
server: process.env.DOMAIN_SERVER,
|
||||
};
|
||||
|
||||
function createOAuthHandler(redirectUri = domains.client) {
|
||||
/**
|
||||
* A handler to process OAuth authentication results.
|
||||
* @type {Function}
|
||||
* @param {ServerRequest} req - Express request object.
|
||||
* @param {ServerResponse} res - Express response object.
|
||||
* @param {NextFunction} next - Express next middleware function.
|
||||
*/
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
if (res.headersSent) {
|
||||
return;
|
||||
}
|
||||
|
||||
await checkBan(req, res);
|
||||
if (req.banned) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
req.user &&
|
||||
req.user.provider == 'openid' &&
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
}
|
||||
res.redirect(redirectUri);
|
||||
} catch (err) {
|
||||
logger.error('Error in setting authentication tokens:', err);
|
||||
next(err);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
createOAuthHandler,
|
||||
};
|
||||
@@ -12,7 +12,8 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { isEnabled, ErrorController } = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const { updateInterfacePermissions } = require('~/models/interface');
|
||||
const { checkMigrations } = require('./services/start/migration');
|
||||
@@ -108,6 +109,7 @@ const startServer = async () => {
|
||||
app.use('/oauth', routes.oauth);
|
||||
/* API Endpoints */
|
||||
app.use('/api/auth', routes.auth);
|
||||
app.use('/api/admin', routes.adminAuth);
|
||||
app.use('/api/actions', routes.actions);
|
||||
app.use('/api/keys', routes.keys);
|
||||
app.use('/api/user', routes.user);
|
||||
@@ -126,7 +128,7 @@ const startServer = async () => {
|
||||
app.use('/api/config', routes.config);
|
||||
app.use('/api/assistants', routes.assistants);
|
||||
app.use('/api/files', await routes.files.initialize());
|
||||
app.use('/images/', validateImageRequest, routes.staticRoute);
|
||||
app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute);
|
||||
app.use('/api/share', routes.share);
|
||||
app.use('/api/roles', routes.roles);
|
||||
app.use('/api/agents', routes.agents);
|
||||
@@ -154,7 +156,7 @@ const startServer = async () => {
|
||||
res.send(updatedIndexHtml);
|
||||
});
|
||||
|
||||
app.listen(port, host, () => {
|
||||
app.listen(port, host, async () => {
|
||||
if (host === '0.0.0.0') {
|
||||
logger.info(
|
||||
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
|
||||
@@ -163,7 +165,9 @@ const startServer = async () => {
|
||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
initializeMCPs().then(() => checkMigrations());
|
||||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
await checkMigrations();
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -11,18 +11,25 @@ const { getAppConfig } = require('~/server/services/Config');
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {Function} next - Next middleware function.
|
||||
*
|
||||
* @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if the domain's email is allowed
|
||||
* @returns {Promise<void>} - Calls next middleware if the domain's email is allowed, otherwise redirects to login
|
||||
*/
|
||||
const checkDomainAllowed = async (req, res, next = () => {}) => {
|
||||
const email = req?.user?.email;
|
||||
const appConfig = await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
});
|
||||
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
|
||||
return res.redirect('/login');
|
||||
} else {
|
||||
return next();
|
||||
const checkDomainAllowed = async (req, res, next) => {
|
||||
try {
|
||||
const email = req?.user?.email;
|
||||
const appConfig = await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
});
|
||||
|
||||
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
|
||||
res.redirect('/login');
|
||||
return;
|
||||
}
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error('[checkDomainAllowed] Error checking domain:', error);
|
||||
res.redirect('/login');
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
const validatePasswordReset = require('./validatePasswordReset');
|
||||
const validateRegistration = require('./validateRegistration');
|
||||
const validateImageRequest = require('./validateImageRequest');
|
||||
const buildEndpointOption = require('./buildEndpointOption');
|
||||
const validateMessageReq = require('./validateMessageReq');
|
||||
const checkDomainAllowed = require('./checkDomainAllowed');
|
||||
@@ -15,6 +14,7 @@ const checkInviteUser = require('./checkInviteUser');
|
||||
const requireJwtAuth = require('./requireJwtAuth');
|
||||
const configMiddleware = require('./config/app');
|
||||
const validateModel = require('./validateModel');
|
||||
const requireAdmin = require('./requireAdmin');
|
||||
const moderateText = require('./moderateText');
|
||||
const logHeaders = require('./logHeaders');
|
||||
const setHeaders = require('./setHeaders');
|
||||
@@ -37,6 +37,7 @@ module.exports = {
|
||||
setHeaders,
|
||||
logHeaders,
|
||||
moderateText,
|
||||
requireAdmin,
|
||||
validateModel,
|
||||
requireJwtAuth,
|
||||
checkInviteUser,
|
||||
@@ -50,6 +51,5 @@ module.exports = {
|
||||
validateMessageReq,
|
||||
buildEndpointOption,
|
||||
validateRegistration,
|
||||
validateImageRequest,
|
||||
validatePasswordReset,
|
||||
};
|
||||
|
||||
22
api/server/middleware/requireAdmin.js
Normal file
22
api/server/middleware/requireAdmin.js
Normal file
@@ -0,0 +1,22 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Middleware to check if authenticated user has admin role
|
||||
* Should be used AFTER authentication middleware (requireJwtAuth, requireLocalAuth, etc.)
|
||||
*/
|
||||
const requireAdmin = (req, res, next) => {
|
||||
if (!req.user) {
|
||||
logger.warn('[requireAdmin] No user found in request');
|
||||
return res.status(401).json({ message: 'Authentication required' });
|
||||
}
|
||||
|
||||
if (!req.user.role || req.user.role !== SystemRoles.ADMIN) {
|
||||
logger.debug('[requireAdmin] Access denied for non-admin user:', req.user.email);
|
||||
return res.status(403).json({ message: 'Access denied: Admin privileges required' });
|
||||
}
|
||||
|
||||
next();
|
||||
};
|
||||
|
||||
module.exports = requireAdmin;
|
||||
@@ -1,6 +1,6 @@
|
||||
const passport = require('passport');
|
||||
const cookies = require('cookie');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const passport = require('passport');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
const jwt = require('jsonwebtoken');
|
||||
const validateImageRequest = require('~/server/middleware/validateImageRequest');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const createValidateImageRequest = require('~/server/middleware/validateImageRequest');
|
||||
|
||||
jest.mock('~/server/services/Config/app', () => ({
|
||||
getAppConfig: jest.fn(),
|
||||
jest.mock('@librechat/api', () => ({
|
||||
isEnabled: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('validateImageRequest middleware', () => {
|
||||
let req, res, next;
|
||||
let req, res, next, validateImageRequest;
|
||||
const validObjectId = '65cfb246f7ecadb8b1e8036b';
|
||||
const { getAppConfig } = require('~/server/services/Config/app');
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
@@ -22,116 +22,278 @@ describe('validateImageRequest middleware', () => {
|
||||
};
|
||||
next = jest.fn();
|
||||
process.env.JWT_REFRESH_SECRET = 'test-secret';
|
||||
process.env.OPENID_REUSE_TOKENS = 'false';
|
||||
|
||||
// Mock getAppConfig to return secureImageLinks: true by default
|
||||
getAppConfig.mockResolvedValue({
|
||||
secureImageLinks: true,
|
||||
});
|
||||
// Default: OpenID token reuse disabled
|
||||
isEnabled.mockReturnValue(false);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
test('should call next() if secureImageLinks is false', async () => {
|
||||
getAppConfig.mockResolvedValue({
|
||||
secureImageLinks: false,
|
||||
describe('Factory function', () => {
|
||||
test('should return a pass-through middleware if secureImageLinks is false', async () => {
|
||||
const middleware = createValidateImageRequest(false);
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return validation middleware if secureImageLinks is true', async () => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.send).toHaveBeenCalledWith('Unauthorized');
|
||||
});
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 401 if refresh token is not provided', async () => {
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.send).toHaveBeenCalledWith('Unauthorized');
|
||||
});
|
||||
describe('Standard LibreChat token flow', () => {
|
||||
beforeEach(() => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
});
|
||||
|
||||
test('should return 403 if refresh token is invalid', async () => {
|
||||
req.headers.cookie = 'refreshToken=invalid-token';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
test('should return 401 if refresh token is not provided', async () => {
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.send).toHaveBeenCalledWith('Unauthorized');
|
||||
});
|
||||
|
||||
test('should return 403 if refresh token is expired', async () => {
|
||||
const expiredToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${expiredToken}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should call next() for valid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 for invalid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should return 403 for invalid ObjectId format', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/123/example.jpg'; // Invalid ObjectId
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
// File traversal tests
|
||||
test('should prevent file traversal attempts', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
|
||||
const traversalAttempts = [
|
||||
`/images/${validObjectId}/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
|
||||
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
|
||||
];
|
||||
|
||||
for (const attempt of traversalAttempts) {
|
||||
req.originalUrl = attempt;
|
||||
test('should return 403 if refresh token is invalid', async () => {
|
||||
req.headers.cookie = 'refreshToken=invalid-token';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
jest.clearAllMocks();
|
||||
}
|
||||
});
|
||||
|
||||
test('should return 403 if refresh token is expired', async () => {
|
||||
const expiredToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${expiredToken}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should call next() for valid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 for invalid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should allow agent avatar pattern for any valid ObjectId', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should prevent file traversal attempts', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
|
||||
const traversalAttempts = [
|
||||
`/images/${validObjectId}/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
|
||||
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
|
||||
];
|
||||
|
||||
for (const attempt of traversalAttempts) {
|
||||
req.originalUrl = attempt;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
jest.clearAllMocks();
|
||||
// Reset mocks for next iteration
|
||||
res.status = jest.fn().mockReturnThis();
|
||||
res.send = jest.fn();
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle URL encoded characters in valid paths', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle URL encoded characters in valid paths', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
describe('OpenID token flow', () => {
|
||||
beforeEach(() => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
// Enable OpenID token reuse
|
||||
isEnabled.mockReturnValue(true);
|
||||
process.env.OPENID_REUSE_TOKENS = 'true';
|
||||
});
|
||||
|
||||
test('should return 403 if no OpenID user ID cookie when token_provider is openid', async () => {
|
||||
req.headers.cookie = 'refreshToken=dummy-token; token_provider=openid';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should validate JWT-signed user ID for OpenID flow', async () => {
|
||||
const signedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
|
||||
req.originalUrl = `/images/${validObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 for invalid JWT-signed user ID', async () => {
|
||||
req.headers.cookie =
|
||||
'refreshToken=dummy-token; token_provider=openid; openid_user_id=invalid-jwt';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should return 403 for expired JWT-signed user ID', async () => {
|
||||
const expiredSignedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${expiredSignedUserId}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should validate image path against JWT-signed user ID', async () => {
|
||||
const signedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
const differentObjectId = '65cfb246f7ecadb8b1e8036c';
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
|
||||
req.originalUrl = `/images/${differentObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should allow agent avatars in OpenID flow', async () => {
|
||||
const signedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Security edge cases', () => {
|
||||
let validToken;
|
||||
|
||||
beforeEach(() => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
});
|
||||
|
||||
test('should handle very long image filenames', async () => {
|
||||
const longFilename = 'a'.repeat(1000) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle URLs with maximum practical length', async () => {
|
||||
// Most browsers support URLs up to ~2000 characters
|
||||
const longFilename = 'x'.repeat(1900) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should accept URLs just under the 2048 limit', async () => {
|
||||
// Create a URL exactly 2047 characters long
|
||||
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
|
||||
const filenameLength = 2047 - baseLength;
|
||||
const filename = 'a'.repeat(filenameLength) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${filename}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle malformed URL encoding gracefully', async () => {
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/test%ZZinvalid.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should reject URLs with null bytes', async () => {
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/test\x00.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should handle URLs with repeated slashes', async () => {
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}//test.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should reject extremely long URLs as potential DoS', async () => {
|
||||
// Create a URL longer than 2048 characters
|
||||
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
|
||||
const filenameLength = 2049 - baseLength; // Ensure total length exceeds 2048
|
||||
const extremelyLongFilename = 'x'.repeat(filenameLength) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${extremelyLongFilename}`;
|
||||
// Verify our test URL is actually too long
|
||||
expect(req.originalUrl.length).toBeGreaterThan(2048);
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const cookies = require('cookie');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getAppConfig } = require('~/server/services/Config/app');
|
||||
|
||||
const OBJECT_ID_LENGTH = 24;
|
||||
const OBJECT_ID_PATTERN = /^[0-9a-f]{24}$/i;
|
||||
@@ -22,50 +22,129 @@ function isValidObjectId(id) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Middleware to validate image request.
|
||||
* Must be set by `secureImageLinks` via custom config file.
|
||||
* Validates a LibreChat refresh token
|
||||
* @param {string} refreshToken - The refresh token to validate
|
||||
* @returns {{valid: boolean, userId?: string, error?: string}} - Validation result
|
||||
*/
|
||||
async function validateImageRequest(req, res, next) {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
if (!appConfig.secureImageLinks) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
|
||||
if (!refreshToken) {
|
||||
logger.warn('[validateImageRequest] Refresh token not provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
let payload;
|
||||
function validateToken(refreshToken) {
|
||||
try {
|
||||
payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
|
||||
if (!isValidObjectId(payload.id)) {
|
||||
return { valid: false, error: 'Invalid User ID' };
|
||||
}
|
||||
|
||||
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
|
||||
if (payload.exp < currentTimeInSeconds) {
|
||||
return { valid: false, error: 'Refresh token expired' };
|
||||
}
|
||||
|
||||
return { valid: true, userId: payload.id };
|
||||
} catch (err) {
|
||||
logger.warn('[validateImageRequest]', err);
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
if (!isValidObjectId(payload.id)) {
|
||||
logger.warn('[validateImageRequest] Invalid User ID');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
|
||||
if (payload.exp < currentTimeInSeconds) {
|
||||
logger.warn('[validateImageRequest] Refresh token expired');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const fullPath = decodeURIComponent(req.originalUrl);
|
||||
const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`);
|
||||
|
||||
if (pathPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
next();
|
||||
} else {
|
||||
logger.warn('[validateImageRequest] Invalid image path');
|
||||
res.status(403).send('Access Denied');
|
||||
logger.warn('[validateToken]', err);
|
||||
return { valid: false, error: 'Invalid token' };
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = validateImageRequest;
|
||||
/**
|
||||
* Factory to create the `validateImageRequest` middleware with configured secureImageLinks
|
||||
* @param {boolean} [secureImageLinks] - Whether secure image links are enabled
|
||||
*/
|
||||
function createValidateImageRequest(secureImageLinks) {
|
||||
if (!secureImageLinks) {
|
||||
return (_req, _res, next) => next();
|
||||
}
|
||||
/**
|
||||
* Middleware to validate image request.
|
||||
* Supports both LibreChat refresh tokens and OpenID JWT tokens.
|
||||
* Must be set by `secureImageLinks` via custom config file.
|
||||
*/
|
||||
return async function validateImageRequest(req, res, next) {
|
||||
try {
|
||||
const cookieHeader = req.headers.cookie;
|
||||
if (!cookieHeader) {
|
||||
logger.warn('[validateImageRequest] No cookies provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
const parsedCookies = cookies.parse(cookieHeader);
|
||||
const refreshToken = parsedCookies.refreshToken;
|
||||
|
||||
if (!refreshToken) {
|
||||
logger.warn('[validateImageRequest] Token not provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
const tokenProvider = parsedCookies.token_provider;
|
||||
let userIdForPath;
|
||||
|
||||
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
const openidUserId = parsedCookies.openid_user_id;
|
||||
if (!openidUserId) {
|
||||
logger.warn('[validateImageRequest] No OpenID user ID cookie found');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const validationResult = validateToken(openidUserId);
|
||||
if (!validationResult.valid) {
|
||||
logger.warn(`[validateImageRequest] ${validationResult.error}`);
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
userIdForPath = validationResult.userId;
|
||||
} else {
|
||||
const validationResult = validateToken(refreshToken);
|
||||
if (!validationResult.valid) {
|
||||
logger.warn(`[validateImageRequest] ${validationResult.error}`);
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
userIdForPath = validationResult.userId;
|
||||
}
|
||||
|
||||
if (!userIdForPath) {
|
||||
logger.warn('[validateImageRequest] No user ID available for path validation');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const MAX_URL_LENGTH = 2048;
|
||||
if (req.originalUrl.length > MAX_URL_LENGTH) {
|
||||
logger.warn('[validateImageRequest] URL too long');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
if (req.originalUrl.includes('\x00')) {
|
||||
logger.warn('[validateImageRequest] URL contains null byte');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
let fullPath;
|
||||
try {
|
||||
fullPath = decodeURIComponent(req.originalUrl);
|
||||
} catch {
|
||||
logger.warn('[validateImageRequest] Invalid URL encoding');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const agentAvatarPattern = /^\/images\/[a-f0-9]{24}\/agent-[^/]*$/;
|
||||
if (agentAvatarPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
return next();
|
||||
}
|
||||
|
||||
const escapedUserId = userIdForPath.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const pathPattern = new RegExp(`^/images/${escapedUserId}/[^/]+$`);
|
||||
|
||||
if (pathPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
next();
|
||||
} else {
|
||||
logger.warn('[validateImageRequest] Invalid image path');
|
||||
res.status(403).send('Access Denied');
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[validateImageRequest] Error:', error);
|
||||
res.status(500).send('Internal Server Error');
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = createValidateImageRequest;
|
||||
|
||||
66
api/server/routes/admin/auth.js
Normal file
66
api/server/routes/admin/auth.js
Normal file
@@ -0,0 +1,66 @@
|
||||
const express = require('express');
|
||||
const passport = require('passport');
|
||||
const { randomState } = require('openid-client');
|
||||
const { createSetBalanceConfig } = require('@librechat/api');
|
||||
const { loginController } = require('~/server/controllers/auth/LoginController');
|
||||
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const middleware = require('~/server/middleware');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
const setBalanceConfig = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
});
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post(
|
||||
'/login/local',
|
||||
middleware.logHeaders,
|
||||
middleware.loginLimiter,
|
||||
middleware.checkBan,
|
||||
middleware.requireLocalAuth,
|
||||
middleware.requireAdmin,
|
||||
setBalanceConfig,
|
||||
loginController,
|
||||
);
|
||||
|
||||
router.get('/verify', middleware.requireJwtAuth, middleware.requireAdmin, (req, res) => {
|
||||
const { password: _p, totpSecret: _t, __v, ...user } = req.user;
|
||||
user.id = user._id.toString();
|
||||
res.status(200).json({ user });
|
||||
});
|
||||
|
||||
router.get('/oauth/openid/check', (req, res) => {
|
||||
const openidConfig = getOpenIdConfig();
|
||||
if (!openidConfig) {
|
||||
return res.status(404).json({ message: 'OpenID configuration not found' });
|
||||
}
|
||||
res.status(200).json({ message: 'OpenID check successful' });
|
||||
});
|
||||
|
||||
router.get('/oauth/openid', (req, res, next) => {
|
||||
return passport.authenticate('openidAdmin', {
|
||||
session: false,
|
||||
state: randomState(),
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.get(
|
||||
'/oauth/openid/callback',
|
||||
passport.authenticate('openidAdmin', {
|
||||
failureRedirect: `${process.env.DOMAIN_CLIENT}/oauth/error`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
middleware.requireAdmin,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(
|
||||
(process.env.ADMIN_PANEL_URL || 'http://localhost:3000') + '/auth/openid/callback',
|
||||
),
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
@@ -117,9 +117,16 @@ router.get('/', async function (req, res) {
|
||||
openidReuseTokens,
|
||||
};
|
||||
|
||||
payload.mcpServers = {};
|
||||
const minPasswordLength = parseInt(process.env.MIN_PASSWORD_LENGTH, 10);
|
||||
if (minPasswordLength && !isNaN(minPasswordLength)) {
|
||||
payload.minPasswordLength = minPasswordLength;
|
||||
}
|
||||
|
||||
const getMCPServers = () => {
|
||||
try {
|
||||
if (appConfig?.mcpConfig == null) {
|
||||
return;
|
||||
}
|
||||
const mcpManager = getMCPManager();
|
||||
if (!mcpManager) {
|
||||
return;
|
||||
@@ -128,6 +135,9 @@ router.get('/', async function (req, res) {
|
||||
if (!mcpServers) return;
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
for (const serverName in mcpServers) {
|
||||
if (!payload.mcpServers) {
|
||||
payload.mcpServers = {};
|
||||
}
|
||||
const serverConfig = mcpServers[serverName];
|
||||
payload.mcpServers[serverName] = removeNullishValues({
|
||||
startup: serverConfig?.startup,
|
||||
|
||||
@@ -4,9 +4,13 @@ const { sleep } = require('@librechat/agents');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
createImportLimiters,
|
||||
createForkLimiters,
|
||||
configMiddleware,
|
||||
} = require('~/server/middleware');
|
||||
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
|
||||
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
|
||||
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
|
||||
const { storage, importFileFilter } = require('~/server/routes/files/multer');
|
||||
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||
const { importConversations } = require('~/server/utils/import');
|
||||
@@ -171,6 +175,7 @@ router.post(
|
||||
'/import',
|
||||
importIpLimiter,
|
||||
importUserLimiter,
|
||||
configMiddleware,
|
||||
upload.single('file'),
|
||||
async (req, res) => {
|
||||
try {
|
||||
|
||||
@@ -31,6 +31,7 @@ const { getAssistant } = require('~/models/Assistant');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
const { Readable } = require('stream');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
@@ -184,6 +185,7 @@ router.delete('/', async (req, res) => {
|
||||
role: req.user.role,
|
||||
fileIds: nonOwnedFileIds,
|
||||
agentId: req.body.agent_id,
|
||||
isDelete: true,
|
||||
});
|
||||
|
||||
for (const file of nonOwnedFiles) {
|
||||
@@ -325,11 +327,6 @@ router.get('/download/:userId/:file_id', fileAccess, async (req, res) => {
|
||||
res.setHeader('X-File-Metadata', JSON.stringify(file));
|
||||
};
|
||||
|
||||
/** @type {{ body: import('stream').PassThrough } | undefined} */
|
||||
let passThrough;
|
||||
/** @type {ReadableStream | undefined} */
|
||||
let fileStream;
|
||||
|
||||
if (checkOpenAIStorage(file.source)) {
|
||||
req.body = { model: file.model };
|
||||
const endpointMap = {
|
||||
@@ -342,12 +339,19 @@ router.get('/download/:userId/:file_id', fileAccess, async (req, res) => {
|
||||
overrideEndpoint: endpointMap[file.source],
|
||||
});
|
||||
logger.debug(`Downloading file ${file_id} from OpenAI`);
|
||||
passThrough = await getDownloadStream(file_id, openai);
|
||||
const passThrough = await getDownloadStream(file_id, openai);
|
||||
setHeaders();
|
||||
logger.debug(`File ${file_id} downloaded from OpenAI`);
|
||||
passThrough.body.pipe(res);
|
||||
|
||||
// Handle both Node.js and Web streams
|
||||
const stream =
|
||||
passThrough.body && typeof passThrough.body.getReader === 'function'
|
||||
? Readable.fromWeb(passThrough.body)
|
||||
: passThrough.body;
|
||||
|
||||
stream.pipe(res);
|
||||
} else {
|
||||
fileStream = await getDownloadStream(req, file.filepath);
|
||||
const fileStream = await getDownloadStream(req, file.filepath);
|
||||
|
||||
fileStream.on('error', (streamError) => {
|
||||
logger.error('[DOWNLOAD ROUTE] Stream error:', streamError);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
const accessPermissions = require('./accessPermissions');
|
||||
const assistants = require('./assistants');
|
||||
const categories = require('./categories');
|
||||
const adminAuth = require('./admin/auth');
|
||||
const tokenizer = require('./tokenizer');
|
||||
const endpoints = require('./endpoints');
|
||||
const staticRoute = require('./static');
|
||||
@@ -32,6 +33,7 @@ module.exports = {
|
||||
mcp,
|
||||
edit,
|
||||
auth,
|
||||
adminAuth,
|
||||
keys,
|
||||
user,
|
||||
tags,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
const { Router } = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { findPluginAuthsByKeys } = require('~/models');
|
||||
@@ -144,6 +144,10 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
|
||||
);
|
||||
|
||||
// clear any reconnection attempts
|
||||
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||
oauthReconnectionManager.clearReconnection(flowState.userId, serverName);
|
||||
|
||||
const tools = await userConnection.fetchTools();
|
||||
await updateMCPUserTools({
|
||||
userId: flowState.userId,
|
||||
|
||||
@@ -4,10 +4,9 @@ const passport = require('passport');
|
||||
const { randomState } = require('openid-client');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, createSetBalanceConfig } = require('@librechat/api');
|
||||
const { checkDomainAllowed, loginLimiter, logHeaders, checkBan } = require('~/server/middleware');
|
||||
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
|
||||
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
|
||||
const { createSetBalanceConfig } = require('@librechat/api');
|
||||
const { checkDomainAllowed, loginLimiter, logHeaders } = require('~/server/middleware');
|
||||
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
@@ -26,28 +25,7 @@ const domains = {
|
||||
router.use(logHeaders);
|
||||
router.use(loginLimiter);
|
||||
|
||||
const oauthHandler = async (req, res) => {
|
||||
try {
|
||||
await checkDomainAllowed(req, res);
|
||||
await checkBan(req, res);
|
||||
if (req.banned) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
req.user &&
|
||||
req.user.provider == 'openid' &&
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res);
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
}
|
||||
res.redirect(domains.client);
|
||||
} catch (err) {
|
||||
logger.error('Error in setting authentication tokens:', err);
|
||||
}
|
||||
};
|
||||
const oauthHandler = createOAuthHandler();
|
||||
|
||||
router.get('/error', (req, res) => {
|
||||
/** A single error message is pushed by passport when authentication fails. */
|
||||
@@ -79,6 +57,7 @@ router.get(
|
||||
scope: ['openid', 'profile', 'email'],
|
||||
}),
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
@@ -104,6 +83,7 @@ router.get(
|
||||
profileFields: ['id', 'email', 'name'],
|
||||
}),
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
@@ -125,6 +105,7 @@ router.get(
|
||||
session: false,
|
||||
}),
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
@@ -148,6 +129,7 @@ router.get(
|
||||
scope: ['user:email', 'read:user'],
|
||||
}),
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
@@ -171,6 +153,7 @@ router.get(
|
||||
scope: ['identify', 'email'],
|
||||
}),
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
@@ -192,6 +175,7 @@ router.post(
|
||||
session: false,
|
||||
}),
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
oauthHandler,
|
||||
);
|
||||
|
||||
|
||||
@@ -156,7 +156,7 @@ router.get('/all', async (req, res) => {
|
||||
router.get('/groups', async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { pageSize, pageNumber, limit, cursor, name, category, ...otherFilters } = req.query;
|
||||
const { pageSize, limit, cursor, name, category, ...otherFilters } = req.query;
|
||||
|
||||
const { filter, searchShared, searchSharedOnly } = buildPromptGroupFilter({
|
||||
name,
|
||||
@@ -171,6 +171,13 @@ router.get('/groups', async (req, res) => {
|
||||
actualLimit = parseInt(pageSize, 10);
|
||||
}
|
||||
|
||||
if (
|
||||
actualCursor &&
|
||||
(actualCursor === 'undefined' || actualCursor === 'null' || actualCursor.length === 0)
|
||||
) {
|
||||
actualCursor = null;
|
||||
}
|
||||
|
||||
let accessibleIds = await findAccessibleResources({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
@@ -190,6 +197,7 @@ router.get('/groups', async (req, res) => {
|
||||
publicPromptGroupIds: publiclyAccessibleIds,
|
||||
});
|
||||
|
||||
// Cursor-based pagination only
|
||||
const result = await getListPromptGroupsByAccess({
|
||||
accessibleIds: filteredAccessibleIds,
|
||||
otherParams: filter,
|
||||
@@ -198,19 +206,21 @@ router.get('/groups', async (req, res) => {
|
||||
});
|
||||
|
||||
if (!result) {
|
||||
const emptyResponse = createEmptyPromptGroupsResponse({ pageNumber, pageSize, actualLimit });
|
||||
const emptyResponse = createEmptyPromptGroupsResponse({
|
||||
pageNumber: '1',
|
||||
pageSize: actualLimit,
|
||||
actualLimit,
|
||||
});
|
||||
return res.status(200).send(emptyResponse);
|
||||
}
|
||||
|
||||
const { data: promptGroups = [], has_more = false, after = null } = result;
|
||||
|
||||
const groupsWithPublicFlag = markPublicPromptGroups(promptGroups, publiclyAccessibleIds);
|
||||
|
||||
const response = formatPromptGroupsResponse({
|
||||
promptGroups: groupsWithPublicFlag,
|
||||
pageNumber,
|
||||
pageSize,
|
||||
actualLimit,
|
||||
pageNumber: '1', // Always 1 for cursor-based pagination
|
||||
pageSize: actualLimit.toString(),
|
||||
hasMore: has_more,
|
||||
after,
|
||||
});
|
||||
|
||||
@@ -33,22 +33,11 @@ let promptRoutes;
|
||||
let Prompt, PromptGroup, AclEntry, AccessRole, User;
|
||||
let testUsers, testRoles;
|
||||
let grantPermission;
|
||||
let currentTestUser; // Track current user for middleware
|
||||
|
||||
// Helper function to set user in middleware
|
||||
function setTestUser(app, user) {
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...(user.toObject ? user.toObject() : user),
|
||||
id: user.id || user._id.toString(),
|
||||
_id: user._id,
|
||||
name: user.name,
|
||||
role: user.role,
|
||||
};
|
||||
if (user.role === SystemRoles.ADMIN) {
|
||||
console.log('Setting admin user with role:', req.user.role);
|
||||
}
|
||||
next();
|
||||
});
|
||||
currentTestUser = user;
|
||||
}
|
||||
|
||||
beforeAll(async () => {
|
||||
@@ -75,14 +64,35 @@ beforeAll(async () => {
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Mock authentication middleware - default to owner
|
||||
setTestUser(app, testUsers.owner);
|
||||
// Add user middleware before routes
|
||||
app.use((req, res, next) => {
|
||||
if (currentTestUser) {
|
||||
req.user = {
|
||||
...(currentTestUser.toObject ? currentTestUser.toObject() : currentTestUser),
|
||||
id: currentTestUser._id.toString(),
|
||||
_id: currentTestUser._id,
|
||||
name: currentTestUser.name,
|
||||
role: currentTestUser.role,
|
||||
};
|
||||
}
|
||||
next();
|
||||
});
|
||||
|
||||
// Import routes after mocks are set up
|
||||
// Set default user
|
||||
currentTestUser = testUsers.owner;
|
||||
|
||||
// Import routes after middleware is set up
|
||||
promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Always reset to owner user after each test for isolation
|
||||
if (currentTestUser !== testUsers.owner) {
|
||||
currentTestUser = testUsers.owner;
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
@@ -116,36 +126,26 @@ async function setupTestData() {
|
||||
// Create test users
|
||||
testUsers = {
|
||||
owner: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Prompt Owner',
|
||||
email: 'owner@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
viewer: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Prompt Viewer',
|
||||
email: 'viewer@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
editor: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Prompt Editor',
|
||||
email: 'editor@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
noAccess: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'No Access',
|
||||
email: 'noaccess@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
admin: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Admin',
|
||||
email: 'admin@example.com',
|
||||
role: SystemRoles.ADMIN,
|
||||
@@ -181,8 +181,7 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
it('should have routes loaded', async () => {
|
||||
// This should at least not crash
|
||||
const response = await request(app).get('/api/prompts/test-404');
|
||||
console.log('Test 404 response status:', response.status);
|
||||
console.log('Test 404 response body:', response.body);
|
||||
|
||||
// We expect a 401 or 404, not 500
|
||||
expect(response.status).not.toBe(500);
|
||||
});
|
||||
@@ -207,12 +206,6 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
|
||||
const response = await request(app).post('/api/prompts').send(promptData);
|
||||
|
||||
if (response.status !== 200) {
|
||||
console.log('POST /api/prompts error status:', response.status);
|
||||
console.log('POST /api/prompts error body:', response.body);
|
||||
console.log('Console errors:', consoleErrorSpy.mock.calls);
|
||||
}
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.prompt).toBeDefined();
|
||||
expect(response.body.prompt.prompt).toBe(promptData.prompt.prompt);
|
||||
@@ -318,29 +311,8 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
});
|
||||
|
||||
it('should allow admin access without explicit permissions', async () => {
|
||||
// First, reset the app to remove previous middleware
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Set admin user BEFORE adding routes
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.admin.toObject(),
|
||||
id: testUsers.admin._id.toString(),
|
||||
_id: testUsers.admin._id,
|
||||
name: testUsers.admin.name,
|
||||
role: testUsers.admin.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
|
||||
// Now add the routes
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
|
||||
console.log('Admin user:', testUsers.admin);
|
||||
console.log('Admin role:', testUsers.admin.role);
|
||||
console.log('SystemRoles.ADMIN:', SystemRoles.ADMIN);
|
||||
// Set admin user
|
||||
setTestUser(app, testUsers.admin);
|
||||
|
||||
const response = await request(app).get(`/api/prompts/${testPrompt._id}`).expect(200);
|
||||
|
||||
@@ -432,21 +404,8 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
grantedBy: testUsers.editor._id,
|
||||
});
|
||||
|
||||
// Recreate app with viewer user
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.viewer.toObject(),
|
||||
id: testUsers.viewer._id.toString(),
|
||||
_id: testUsers.viewer._id,
|
||||
name: testUsers.viewer.name,
|
||||
role: testUsers.viewer.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
// Set viewer user
|
||||
setTestUser(app, testUsers.viewer);
|
||||
|
||||
await request(app)
|
||||
.delete(`/api/prompts/${authorPrompt._id}`)
|
||||
@@ -499,21 +458,8 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Recreate app to ensure fresh middleware
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.owner.toObject(),
|
||||
id: testUsers.owner._id.toString(),
|
||||
_id: testUsers.owner._id,
|
||||
name: testUsers.owner.name,
|
||||
role: testUsers.owner.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
// Ensure owner user
|
||||
setTestUser(app, testUsers.owner);
|
||||
|
||||
const response = await request(app)
|
||||
.patch(`/api/prompts/${testPrompt._id}/tags/production`)
|
||||
@@ -537,21 +483,8 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Recreate app with viewer user
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.viewer.toObject(),
|
||||
id: testUsers.viewer._id.toString(),
|
||||
_id: testUsers.viewer._id,
|
||||
name: testUsers.viewer.name,
|
||||
role: testUsers.viewer.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
// Set viewer user
|
||||
setTestUser(app, testUsers.viewer);
|
||||
|
||||
await request(app).patch(`/api/prompts/${testPrompt._id}/tags/production`).expect(403);
|
||||
|
||||
@@ -610,4 +543,305 @@ describe('Prompt Routes - ACL Permissions', () => {
|
||||
expect(response.body._id).toBe(publicPrompt._id.toString());
|
||||
});
|
||||
});
|
||||
|
||||
describe('Pagination', () => {
|
||||
beforeEach(async () => {
|
||||
// Create multiple prompt groups for pagination testing
|
||||
const groups = [];
|
||||
for (let i = 0; i < 15; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Test Group ${i + 1}`,
|
||||
category: 'pagination-test',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - i * 1000), // Stagger updatedAt for consistent ordering
|
||||
});
|
||||
groups.push(group);
|
||||
|
||||
// Grant owner permissions on each group
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('should correctly indicate hasMore when there are more pages', async () => {
|
||||
const response = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10' })
|
||||
.expect(200);
|
||||
|
||||
expect(response.body.promptGroups).toHaveLength(10);
|
||||
expect(response.body.has_more).toBe(true);
|
||||
expect(response.body.after).toBeTruthy();
|
||||
// Since has_more is true, pages should be a high number (9999 in our fix)
|
||||
expect(parseInt(response.body.pages)).toBeGreaterThan(1);
|
||||
});
|
||||
|
||||
it('should correctly indicate no more pages on the last page', async () => {
|
||||
// First get the cursor for page 2
|
||||
const firstPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10' })
|
||||
.expect(200);
|
||||
|
||||
expect(firstPage.body.has_more).toBe(true);
|
||||
expect(firstPage.body.after).toBeTruthy();
|
||||
|
||||
// Now fetch the second page using the cursor
|
||||
const response = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10', cursor: firstPage.body.after })
|
||||
.expect(200);
|
||||
|
||||
expect(response.body.promptGroups).toHaveLength(5); // 15 total, 10 on page 1, 5 on page 2
|
||||
expect(response.body.has_more).toBe(false);
|
||||
});
|
||||
|
||||
it('should support cursor-based pagination', async () => {
|
||||
// First page
|
||||
const firstPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5' })
|
||||
.expect(200);
|
||||
|
||||
expect(firstPage.body.promptGroups).toHaveLength(5);
|
||||
expect(firstPage.body.has_more).toBe(true);
|
||||
expect(firstPage.body.after).toBeTruthy();
|
||||
|
||||
// Second page using cursor
|
||||
const secondPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', cursor: firstPage.body.after })
|
||||
.expect(200);
|
||||
|
||||
expect(secondPage.body.promptGroups).toHaveLength(5);
|
||||
expect(secondPage.body.has_more).toBe(true);
|
||||
expect(secondPage.body.after).toBeTruthy();
|
||||
|
||||
// Verify different groups
|
||||
const firstPageIds = firstPage.body.promptGroups.map((g) => g._id);
|
||||
const secondPageIds = secondPage.body.promptGroups.map((g) => g._id);
|
||||
expect(firstPageIds).not.toEqual(secondPageIds);
|
||||
});
|
||||
|
||||
it('should paginate correctly with category filtering', async () => {
|
||||
// Create groups with different categories
|
||||
await PromptGroup.deleteMany({}); // Clear existing groups
|
||||
await AclEntry.deleteMany({});
|
||||
|
||||
// Create 8 groups with category 'test-cat-1'
|
||||
for (let i = 0; i < 8; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Category 1 Group ${i + 1}`,
|
||||
category: 'test-cat-1',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - i * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Create 7 groups with category 'test-cat-2'
|
||||
for (let i = 0; i < 7; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Category 2 Group ${i + 1}`,
|
||||
category: 'test-cat-2',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - (i + 8) * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Test pagination with category filter
|
||||
const firstPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', category: 'test-cat-1' })
|
||||
.expect(200);
|
||||
|
||||
expect(firstPage.body.promptGroups).toHaveLength(5);
|
||||
expect(firstPage.body.promptGroups.every((g) => g.category === 'test-cat-1')).toBe(true);
|
||||
expect(firstPage.body.has_more).toBe(true);
|
||||
expect(firstPage.body.after).toBeTruthy();
|
||||
|
||||
const secondPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', cursor: firstPage.body.after, category: 'test-cat-1' })
|
||||
.expect(200);
|
||||
|
||||
expect(secondPage.body.promptGroups).toHaveLength(3); // 8 total, 5 on page 1, 3 on page 2
|
||||
expect(secondPage.body.promptGroups.every((g) => g.category === 'test-cat-1')).toBe(true);
|
||||
expect(secondPage.body.has_more).toBe(false);
|
||||
});
|
||||
|
||||
it('should paginate correctly with name/keyword filtering', async () => {
|
||||
// Create groups with specific names
|
||||
await PromptGroup.deleteMany({}); // Clear existing groups
|
||||
await AclEntry.deleteMany({});
|
||||
|
||||
// Create 12 groups with 'Search' in the name
|
||||
for (let i = 0; i < 12; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Search Test Group ${i + 1}`,
|
||||
category: 'search-test',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - i * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Create 5 groups without 'Search' in the name
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `Other Group ${i + 1}`,
|
||||
category: 'other-test',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - (i + 12) * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Test pagination with name filter
|
||||
const firstPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10', name: 'Search' })
|
||||
.expect(200);
|
||||
|
||||
expect(firstPage.body.promptGroups).toHaveLength(10);
|
||||
expect(firstPage.body.promptGroups.every((g) => g.name.includes('Search'))).toBe(true);
|
||||
expect(firstPage.body.has_more).toBe(true);
|
||||
expect(firstPage.body.after).toBeTruthy();
|
||||
|
||||
const secondPage = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '10', cursor: firstPage.body.after, name: 'Search' })
|
||||
.expect(200);
|
||||
|
||||
expect(secondPage.body.promptGroups).toHaveLength(2); // 12 total, 10 on page 1, 2 on page 2
|
||||
expect(secondPage.body.promptGroups.every((g) => g.name.includes('Search'))).toBe(true);
|
||||
expect(secondPage.body.has_more).toBe(false);
|
||||
});
|
||||
|
||||
it('should paginate correctly with combined filters', async () => {
|
||||
// Create groups with various combinations
|
||||
await PromptGroup.deleteMany({}); // Clear existing groups
|
||||
await AclEntry.deleteMany({});
|
||||
|
||||
// Create 6 groups matching both category and name filters
|
||||
for (let i = 0; i < 6; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `API Test Group ${i + 1}`,
|
||||
category: 'api-category',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - i * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Create groups that only match one filter
|
||||
for (let i = 0; i < 4; i++) {
|
||||
const group = await PromptGroup.create({
|
||||
name: `API Other Group ${i + 1}`,
|
||||
category: 'other-category',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
updatedAt: new Date(Date.now() - (i + 6) * 1000),
|
||||
});
|
||||
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: group._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
}
|
||||
|
||||
// Test pagination with both filters
|
||||
const response = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', name: 'API', category: 'api-category' })
|
||||
.expect(200);
|
||||
|
||||
expect(response.body.promptGroups).toHaveLength(5);
|
||||
expect(
|
||||
response.body.promptGroups.every(
|
||||
(g) => g.name.includes('API') && g.category === 'api-category',
|
||||
),
|
||||
).toBe(true);
|
||||
expect(response.body.has_more).toBe(true);
|
||||
expect(response.body.after).toBeTruthy();
|
||||
|
||||
// Page 2
|
||||
const page2 = await request(app)
|
||||
.get('/api/prompts/groups')
|
||||
.query({ limit: '5', cursor: response.body.after, name: 'API', category: 'api-category' })
|
||||
.expect(200);
|
||||
|
||||
expect(page2.body.promptGroups).toHaveLength(1); // 6 total, 5 on page 1, 1 on page 2
|
||||
expect(page2.body.has_more).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -49,6 +49,7 @@ const AppService = async () => {
|
||||
enabled: isEnabled(process.env.CHECK_BALANCE),
|
||||
startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
|
||||
};
|
||||
const transactions = config.transactions ?? configDefaults.transactions;
|
||||
const imageOutputType = config?.imageOutputType ?? configDefaults.imageOutputType;
|
||||
|
||||
process.env.CDN_PROVIDER = fileStrategy;
|
||||
@@ -84,6 +85,7 @@ const AppService = async () => {
|
||||
memory,
|
||||
speech,
|
||||
balance,
|
||||
transactions,
|
||||
mcpConfig,
|
||||
webSearch,
|
||||
fileStrategy,
|
||||
|
||||
@@ -152,6 +152,7 @@ describe('AppService', () => {
|
||||
webSearch: expect.objectContaining({
|
||||
safeSearch: 1,
|
||||
jinaApiKey: '${JINA_API_KEY}',
|
||||
jinaApiUrl: '${JINA_API_URL}',
|
||||
cohereApiKey: '${COHERE_API_KEY}',
|
||||
serperApiKey: '${SERPER_API_KEY}',
|
||||
searxngApiKey: '${SEARXNG_API_KEY}',
|
||||
|
||||
@@ -3,12 +3,12 @@ const jwt = require('jsonwebtoken');
|
||||
const { webcrypto } = require('node:crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, checkEmailConfig } = require('@librechat/api');
|
||||
const { SystemRoles, errorsToString } = require('librechat-data-provider');
|
||||
const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider');
|
||||
const {
|
||||
findUser,
|
||||
findToken,
|
||||
createUser,
|
||||
updateUser,
|
||||
findToken,
|
||||
countUsers,
|
||||
getUserById,
|
||||
findSession,
|
||||
@@ -181,6 +181,14 @@ const registerUser = async (user, additionalData = {}) => {
|
||||
|
||||
let newUserId;
|
||||
try {
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
const errorMessage =
|
||||
'The email address provided cannot be used. Please use a different email address.';
|
||||
logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`);
|
||||
return { status: 403, message: errorMessage };
|
||||
}
|
||||
|
||||
const existingUser = await findUser({ email }, 'email _id');
|
||||
|
||||
if (existingUser) {
|
||||
@@ -195,14 +203,6 @@ const registerUser = async (user, additionalData = {}) => {
|
||||
return { status: 200, message: genericVerificationMessage };
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig({ role: user.role });
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
const errorMessage =
|
||||
'The email address provided cannot be used. Please use a different email address.';
|
||||
logger.error(`[registerUser] [Registration not allowed] [Email: ${user.email}]`);
|
||||
return { status: 403, message: errorMessage };
|
||||
}
|
||||
|
||||
//determine if this is the first registered user (not counting anonymous_user)
|
||||
const isFirstRegisteredUser = (await countUsers()) === 0;
|
||||
|
||||
@@ -252,6 +252,13 @@ const registerUser = async (user, additionalData = {}) => {
|
||||
*/
|
||||
const requestPasswordReset = async (req) => {
|
||||
const { email } = req.body;
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
const error = new Error(ErrorTypes.AUTH_FAILED);
|
||||
error.code = ErrorTypes.AUTH_FAILED;
|
||||
error.message = 'Email domain not allowed';
|
||||
return error;
|
||||
}
|
||||
const user = await findUser({ email }, 'email _id');
|
||||
const emailEnabled = checkEmailConfig();
|
||||
|
||||
@@ -350,23 +357,18 @@ const resetPassword = async (userId, token, password) => {
|
||||
|
||||
/**
|
||||
* Set Auth Tokens
|
||||
*
|
||||
* @param {String | ObjectId} userId
|
||||
* @param {Object} res
|
||||
* @param {String} sessionId
|
||||
* @param {ServerResponse} res
|
||||
* @param {ISession | null} [session=null]
|
||||
* @returns
|
||||
*/
|
||||
const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
const setAuthTokens = async (userId, res, _session = null) => {
|
||||
try {
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
|
||||
let session;
|
||||
let session = _session;
|
||||
let refreshToken;
|
||||
let refreshTokenExpires;
|
||||
|
||||
if (sessionId) {
|
||||
session = await findSession({ sessionId: sessionId }, { lean: false });
|
||||
if (session && session._id && session.expiration != null) {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
refreshToken = await generateRefreshToken(session);
|
||||
} else {
|
||||
@@ -376,6 +378,9 @@ const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
}
|
||||
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
|
||||
res.cookie('refreshToken', refreshToken, {
|
||||
expires: new Date(refreshTokenExpires),
|
||||
httpOnly: true,
|
||||
@@ -402,9 +407,10 @@ const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
* @param {import('openid-client').TokenEndpointResponse & import('openid-client').TokenEndpointResponseHelpers} tokenset
|
||||
* - The tokenset object containing access and refresh tokens
|
||||
* @param {Object} res - response object
|
||||
* @param {string} [userId] - Optional MongoDB user ID for image path validation
|
||||
* @returns {String} - access token
|
||||
*/
|
||||
const setOpenIDAuthTokens = (tokenset, res) => {
|
||||
const setOpenIDAuthTokens = (tokenset, res, userId) => {
|
||||
try {
|
||||
if (!tokenset) {
|
||||
logger.error('[setOpenIDAuthTokens] No tokenset found in request');
|
||||
@@ -435,6 +441,18 @@ const setOpenIDAuthTokens = (tokenset, res) => {
|
||||
secure: isProduction,
|
||||
sameSite: 'strict',
|
||||
});
|
||||
if (userId && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
/** JWT-signed user ID cookie for image path validation when OPENID_REUSE_TOKENS is enabled */
|
||||
const signedUserId = jwt.sign({ id: userId }, process.env.JWT_REFRESH_SECRET, {
|
||||
expiresIn: expiryInMilliseconds / 1000,
|
||||
});
|
||||
res.cookie('openid_user_id', signedUserId, {
|
||||
expires: expirationDate,
|
||||
httpOnly: true,
|
||||
secure: isProduction,
|
||||
sameSite: 'strict',
|
||||
});
|
||||
}
|
||||
return tokenset.access_token;
|
||||
} catch (error) {
|
||||
logger.error('[setOpenIDAuthTokens] Error in setting authentication tokens:', error);
|
||||
@@ -452,7 +470,7 @@ const setOpenIDAuthTokens = (tokenset, res) => {
|
||||
const resendVerificationEmail = async (req) => {
|
||||
try {
|
||||
const { email } = req.body;
|
||||
await deleteTokens(email);
|
||||
await deleteTokens({ email });
|
||||
const user = await findUser({ email }, 'email _id name');
|
||||
|
||||
if (!user) {
|
||||
|
||||
@@ -4,6 +4,8 @@ const AppService = require('~/server/services/AppService');
|
||||
const { setCachedTools } = require('./getCachedTools');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
const BASE_CONFIG_KEY = '_BASE_';
|
||||
|
||||
/**
|
||||
* Get the app configuration based on user context
|
||||
* @param {Object} [options]
|
||||
@@ -14,8 +16,8 @@ const getLogStores = require('~/cache/getLogStores');
|
||||
async function getAppConfig(options = {}) {
|
||||
const { role, refresh } = options;
|
||||
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cacheKey = role ? `${CacheKeys.APP_CONFIG}:${role}` : CacheKeys.APP_CONFIG;
|
||||
const cache = getLogStores(CacheKeys.APP_CONFIG);
|
||||
const cacheKey = role ? role : BASE_CONFIG_KEY;
|
||||
|
||||
if (!refresh) {
|
||||
const cached = await cache.get(cacheKey);
|
||||
@@ -24,7 +26,7 @@ async function getAppConfig(options = {}) {
|
||||
}
|
||||
}
|
||||
|
||||
let baseConfig = await cache.get(CacheKeys.APP_CONFIG);
|
||||
let baseConfig = await cache.get(BASE_CONFIG_KEY);
|
||||
if (!baseConfig) {
|
||||
logger.info('[getAppConfig] App configuration not initialized. Initializing AppService...');
|
||||
baseConfig = await AppService();
|
||||
@@ -37,7 +39,7 @@ async function getAppConfig(options = {}) {
|
||||
await setCachedTools(baseConfig.availableTools, { isGlobal: true });
|
||||
}
|
||||
|
||||
await cache.set(CacheKeys.APP_CONFIG, baseConfig);
|
||||
await cache.set(BASE_CONFIG_KEY, baseConfig);
|
||||
}
|
||||
|
||||
// For now, return the base config
|
||||
|
||||
@@ -119,10 +119,6 @@ https://www.librechat.ai/docs/configuration/stt_tts`);
|
||||
.filter((endpoint) => endpoint.customParams)
|
||||
.forEach((endpoint) => parseCustomParams(endpoint.name, endpoint.customParams));
|
||||
|
||||
if (customConfig.cache) {
|
||||
const cache = getLogStores(CacheKeys.STATIC_CONFIG);
|
||||
await cache.set(CacheKeys.LIBRECHAT_YAML_CONFIG, customConfig);
|
||||
}
|
||||
|
||||
if (result.data.modelSpecs) {
|
||||
customConfig.modelSpecs = result.data.modelSpecs;
|
||||
|
||||
@@ -48,16 +48,11 @@ const axios = require('axios');
|
||||
const { loadYaml } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
describe('loadCustomConfig', () => {
|
||||
const mockSet = jest.fn();
|
||||
const mockCache = { set: mockSet };
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetAllMocks();
|
||||
delete process.env.CONFIG_PATH;
|
||||
getLogStores.mockReturnValue(mockCache);
|
||||
});
|
||||
|
||||
it('should return null and log error if remote config fetch fails', async () => {
|
||||
@@ -94,7 +89,6 @@ describe('loadCustomConfig', () => {
|
||||
const result = await loadCustomConfig();
|
||||
|
||||
expect(result).toEqual(mockConfig);
|
||||
expect(mockSet).toHaveBeenCalledWith(expect.anything(), mockConfig);
|
||||
});
|
||||
|
||||
it('should return null and log if config schema validation fails', async () => {
|
||||
@@ -134,7 +128,6 @@ describe('loadCustomConfig', () => {
|
||||
axios.get.mockResolvedValue({ data: mockConfig });
|
||||
const result = await loadCustomConfig();
|
||||
expect(result).toEqual(mockConfig);
|
||||
expect(mockSet).toHaveBeenCalledWith(expect.anything(), mockConfig);
|
||||
});
|
||||
|
||||
it('should return null if the remote config file is not found', async () => {
|
||||
@@ -168,7 +161,6 @@ describe('loadCustomConfig', () => {
|
||||
process.env.CONFIG_PATH = 'validConfig.yaml';
|
||||
loadYaml.mockReturnValueOnce(mockConfig);
|
||||
await loadCustomConfig();
|
||||
expect(mockSet).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should log the loaded custom config', async () => {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const {
|
||||
primeResources,
|
||||
getModelMaxTokens,
|
||||
extractLibreChatParams,
|
||||
optionalChainWithEmptyCheck,
|
||||
} = require('@librechat/api');
|
||||
@@ -17,7 +18,6 @@ const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { getFiles, getToolFilesByIds } = require('~/models/File');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { getLLMConfig } = require('@librechat/api');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
|
||||
const AnthropicClient = require('~/app/clients/AnthropicClient');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
|
||||
@@ -40,7 +40,6 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||
clientOptions = Object.assign(
|
||||
{
|
||||
proxy: PROXY ?? null,
|
||||
userId: req.user.id,
|
||||
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
|
||||
modelOptions: endpointOption?.model_parameters ?? {},
|
||||
},
|
||||
@@ -49,6 +48,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||
if (overrideModel) {
|
||||
clientOptions.modelOptions.model = overrideModel;
|
||||
}
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
return getLLMConfig(anthropicApiKey, clientOptions);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
|
||||
const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');
|
||||
|
||||
/**
|
||||
* Generates configuration options for creating an Anthropic language model (LLM) instance.
|
||||
*
|
||||
* @param {string} apiKey - The API key for authentication with Anthropic.
|
||||
* @param {Object} [options={}] - Additional options for configuring the LLM.
|
||||
* @param {Object} [options.modelOptions] - Model-specific options.
|
||||
* @param {string} [options.modelOptions.model] - The name of the model to use.
|
||||
* @param {number} [options.modelOptions.maxOutputTokens] - The maximum number of tokens to generate.
|
||||
* @param {number} [options.modelOptions.temperature] - Controls randomness in output generation.
|
||||
* @param {number} [options.modelOptions.topP] - Controls diversity of output generation.
|
||||
* @param {number} [options.modelOptions.topK] - Controls the number of top tokens to consider.
|
||||
* @param {string[]} [options.modelOptions.stop] - Sequences where the API will stop generating further tokens.
|
||||
* @param {boolean} [options.modelOptions.stream] - Whether to stream the response.
|
||||
* @param {string} options.userId - The user ID for tracking and personalization.
|
||||
* @param {string} [options.proxy] - Proxy server URL.
|
||||
* @param {string} [options.reverseProxyUrl] - URL for a reverse proxy, if used.
|
||||
*
|
||||
* @returns {Object} Configuration options for creating an Anthropic LLM instance, with null and undefined values removed.
|
||||
*/
|
||||
function getLLMConfig(apiKey, options = {}) {
|
||||
const systemOptions = {
|
||||
thinking: options.modelOptions.thinking ?? anthropicSettings.thinking.default,
|
||||
promptCache: options.modelOptions.promptCache ?? anthropicSettings.promptCache.default,
|
||||
thinkingBudget: options.modelOptions.thinkingBudget ?? anthropicSettings.thinkingBudget.default,
|
||||
};
|
||||
for (let key in systemOptions) {
|
||||
delete options.modelOptions[key];
|
||||
}
|
||||
const defaultOptions = {
|
||||
model: anthropicSettings.model.default,
|
||||
maxOutputTokens: anthropicSettings.maxOutputTokens.default,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
const mergedOptions = Object.assign(defaultOptions, options.modelOptions);
|
||||
|
||||
/** @type {AnthropicClientOptions} */
|
||||
let requestOptions = {
|
||||
apiKey,
|
||||
model: mergedOptions.model,
|
||||
stream: mergedOptions.stream,
|
||||
temperature: mergedOptions.temperature,
|
||||
stopSequences: mergedOptions.stop,
|
||||
maxTokens:
|
||||
mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model),
|
||||
clientOptions: {},
|
||||
invocationKwargs: {
|
||||
metadata: {
|
||||
user_id: options.userId,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
requestOptions = configureReasoning(requestOptions, systemOptions);
|
||||
|
||||
if (!/claude-3[-.]7/.test(mergedOptions.model)) {
|
||||
requestOptions.topP = mergedOptions.topP;
|
||||
requestOptions.topK = mergedOptions.topK;
|
||||
} else if (requestOptions.thinking == null) {
|
||||
requestOptions.topP = mergedOptions.topP;
|
||||
requestOptions.topK = mergedOptions.topK;
|
||||
}
|
||||
|
||||
const supportsCacheControl =
|
||||
systemOptions.promptCache === true && checkPromptCacheSupport(requestOptions.model);
|
||||
const headers = getClaudeHeaders(requestOptions.model, supportsCacheControl);
|
||||
if (headers) {
|
||||
requestOptions.clientOptions.defaultHeaders = headers;
|
||||
}
|
||||
|
||||
if (options.proxy) {
|
||||
const proxyAgent = new ProxyAgent(options.proxy);
|
||||
requestOptions.clientOptions.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
if (options.reverseProxyUrl) {
|
||||
requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
|
||||
requestOptions.anthropicApiUrl = options.reverseProxyUrl;
|
||||
}
|
||||
|
||||
const tools = [];
|
||||
|
||||
if (mergedOptions.web_search) {
|
||||
tools.push({
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
tools,
|
||||
/** @type {AnthropicClientOptions} */
|
||||
llmConfig: removeNullishValues(requestOptions),
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = { getLLMConfig };
|
||||
@@ -1,341 +0,0 @@
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
|
||||
|
||||
jest.mock('https-proxy-agent', () => ({
|
||||
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
|
||||
}));
|
||||
|
||||
describe('getLLMConfig', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should create a basic configuration with default values', () => {
|
||||
const result = getLLMConfig('test-api-key', { modelOptions: {} });
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key');
|
||||
expect(result.llmConfig).toHaveProperty('model', 'claude-3-5-sonnet-latest');
|
||||
expect(result.llmConfig).toHaveProperty('stream', true);
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens');
|
||||
});
|
||||
|
||||
it('should include proxy settings when provided', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {},
|
||||
proxy: 'http://proxy:8080',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
|
||||
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
|
||||
'ProxyAgent',
|
||||
);
|
||||
});
|
||||
|
||||
it('should include reverse proxy URL when provided', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {},
|
||||
reverseProxyUrl: 'http://reverse-proxy',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy');
|
||||
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'http://reverse-proxy');
|
||||
});
|
||||
|
||||
it('should include topK and topP for non-Claude-3.7 models', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
|
||||
it('should include topK and topP for Claude-3.5 models', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
|
||||
it('should NOT include topK and topP for Claude-3-7 models with thinking enabled (hyphen notation)', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).not.toHaveProperty('topK');
|
||||
expect(result.llmConfig).not.toHaveProperty('topP');
|
||||
expect(result.llmConfig).toHaveProperty('thinking');
|
||||
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
|
||||
// When thinking is enabled, it uses the default thinkingBudget of 2000
|
||||
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model', () => {
|
||||
const modelOptions = {
|
||||
model: 'claude-sonnet-4-20250514',
|
||||
promptCache: true,
|
||||
};
|
||||
const result = getLLMConfig('test-key', { modelOptions });
|
||||
const clientOptions = result.llmConfig.clientOptions;
|
||||
expect(clientOptions.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model formats', () => {
|
||||
const modelVariations = [
|
||||
'claude-sonnet-4-20250514',
|
||||
'claude-sonnet-4-latest',
|
||||
'anthropic/claude-sonnet-4-20250514',
|
||||
];
|
||||
|
||||
modelVariations.forEach((model) => {
|
||||
const modelOptions = { model, promptCache: true };
|
||||
const result = getLLMConfig('test-key', { modelOptions });
|
||||
const clientOptions = result.llmConfig.clientOptions;
|
||||
expect(clientOptions.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should NOT include topK and topP for Claude-3.7 models with thinking enabled (decimal notation)', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3.7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).not.toHaveProperty('topK');
|
||||
expect(result.llmConfig).not.toHaveProperty('topP');
|
||||
expect(result.llmConfig).toHaveProperty('thinking');
|
||||
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
|
||||
// When thinking is enabled, it uses the default thinkingBudget of 2000
|
||||
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
|
||||
});
|
||||
|
||||
it('should handle custom maxOutputTokens', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus',
|
||||
maxOutputTokens: 2048,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens', 2048);
|
||||
});
|
||||
|
||||
it('should handle promptCache setting', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet',
|
||||
promptCache: true,
|
||||
},
|
||||
});
|
||||
|
||||
// We're not checking specific header values since that depends on the actual helper function
|
||||
// Just verifying that the promptCache setting is processed
|
||||
expect(result.llmConfig).toBeDefined();
|
||||
});
|
||||
|
||||
it('should include topK and topP for Claude-3.7 models when thinking is not enabled', () => {
|
||||
// Test with thinking explicitly set to null/undefined
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
|
||||
// Test with thinking explicitly set to false
|
||||
const result2 = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result2.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result2.llmConfig).toHaveProperty('topP', 0.9);
|
||||
|
||||
// Test with decimal notation as well
|
||||
const result3 = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3.7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result3.llmConfig).toHaveProperty('topK', 10);
|
||||
expect(result3.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should handle missing apiKey', () => {
|
||||
const result = getLLMConfig(undefined, { modelOptions: {} });
|
||||
expect(result.llmConfig).not.toHaveProperty('apiKey');
|
||||
});
|
||||
|
||||
it('should handle empty modelOptions', () => {
|
||||
expect(() => {
|
||||
getLLMConfig('test-api-key', {});
|
||||
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
|
||||
});
|
||||
|
||||
it('should handle no options parameter', () => {
|
||||
expect(() => {
|
||||
getLLMConfig('test-api-key');
|
||||
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
|
||||
});
|
||||
|
||||
it('should handle temperature, stop sequences, and stream settings', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
temperature: 0.7,
|
||||
stop: ['\n\n', 'END'],
|
||||
stream: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('temperature', 0.7);
|
||||
expect(result.llmConfig).toHaveProperty('stopSequences', ['\n\n', 'END']);
|
||||
expect(result.llmConfig).toHaveProperty('stream', false);
|
||||
});
|
||||
|
||||
it('should handle maxOutputTokens when explicitly set to falsy value', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus',
|
||||
maxOutputTokens: null,
|
||||
},
|
||||
});
|
||||
|
||||
// The actual anthropicSettings.maxOutputTokens.reset('claude-3-opus') returns 4096
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens', 4096);
|
||||
});
|
||||
|
||||
it('should handle both proxy and reverseProxyUrl', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {},
|
||||
proxy: 'http://proxy:8080',
|
||||
reverseProxyUrl: 'https://reverse-proxy.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
|
||||
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
|
||||
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
|
||||
'ProxyAgent',
|
||||
);
|
||||
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'https://reverse-proxy.com');
|
||||
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'https://reverse-proxy.com');
|
||||
});
|
||||
|
||||
it('should handle prompt cache with supported model', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet',
|
||||
promptCache: true,
|
||||
},
|
||||
});
|
||||
|
||||
// claude-3-5-sonnet supports prompt caching and should get the appropriate headers
|
||||
expect(result.llmConfig.clientOptions.defaultHeaders).toEqual({
|
||||
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle thinking and thinkingBudget options', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
thinking: true,
|
||||
thinkingBudget: 10000, // This exceeds the default max_tokens of 8192
|
||||
},
|
||||
});
|
||||
|
||||
// The function should add thinking configuration for claude-3-7 models
|
||||
expect(result.llmConfig).toHaveProperty('thinking');
|
||||
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
|
||||
// With claude-3-7-sonnet, the max_tokens default is 8192
|
||||
// Budget tokens gets adjusted to 90% of max_tokens (8192 * 0.9 = 7372) when it exceeds max_tokens
|
||||
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 7372);
|
||||
|
||||
// Test with budget_tokens within max_tokens limit
|
||||
const result2 = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
thinking: true,
|
||||
thinkingBudget: 2000,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result2.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
|
||||
});
|
||||
|
||||
it('should remove system options from modelOptions', () => {
|
||||
const modelOptions = {
|
||||
model: 'claude-3-opus',
|
||||
thinking: true,
|
||||
promptCache: true,
|
||||
thinkingBudget: 1000,
|
||||
temperature: 0.5,
|
||||
};
|
||||
|
||||
getLLMConfig('test-api-key', { modelOptions });
|
||||
|
||||
expect(modelOptions).not.toHaveProperty('thinking');
|
||||
expect(modelOptions).not.toHaveProperty('promptCache');
|
||||
expect(modelOptions).not.toHaveProperty('thinkingBudget');
|
||||
expect(modelOptions).toHaveProperty('temperature', 0.5);
|
||||
});
|
||||
|
||||
it('should handle all nullish values removal', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
temperature: null,
|
||||
topP: undefined,
|
||||
topK: 0,
|
||||
stop: [],
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).not.toHaveProperty('temperature');
|
||||
expect(result.llmConfig).not.toHaveProperty('topP');
|
||||
expect(result.llmConfig).toHaveProperty('topK', 0);
|
||||
expect(result.llmConfig).toHaveProperty('stopSequences', []);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,3 +1,4 @@
|
||||
const { getModelMaxTokens } = require('@librechat/api');
|
||||
const { createContentAggregator } = require('@librechat/agents');
|
||||
const {
|
||||
EModelEndpoint,
|
||||
@@ -7,7 +8,6 @@ const {
|
||||
const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks');
|
||||
const getOptions = require('~/server/services/Endpoints/bedrock/options');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
if (!endpointOption) {
|
||||
|
||||
@@ -36,10 +36,12 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
||||
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
|
||||
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
|
||||
|
||||
/** Intentionally excludes passing `body`, i.e. `req.body`, as
|
||||
* values may not be accurate until `AgentClient` is initialized
|
||||
*/
|
||||
let resolvedHeaders = resolveHeaders({
|
||||
headers: endpointConfig.headers,
|
||||
user: req.user,
|
||||
body: req.body,
|
||||
});
|
||||
|
||||
if (CUSTOM_API_KEY.match(envVarRegex)) {
|
||||
|
||||
@@ -76,7 +76,10 @@ describe('custom/initializeClient', () => {
|
||||
expect(resolveHeaders).toHaveBeenCalledWith({
|
||||
headers: { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
|
||||
user: { id: 'user-123', email: 'test@example.com', role: 'user' },
|
||||
/**
|
||||
* Note: Request-based Header Resolution is deferred until right before LLM request is made
|
||||
body: { endpoint: 'test-endpoint' }, // body - supports {{LIBRECHAT_BODY_*}} placeholders
|
||||
*/
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -109,9 +109,11 @@ class STTService {
|
||||
* @throws {Error} If no STT schema is set, multiple providers are set, or no provider is set.
|
||||
*/
|
||||
async getProviderSchema(req) {
|
||||
const appConfig = await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
});
|
||||
const appConfig =
|
||||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
}));
|
||||
const sttSchema = appConfig?.speech?.stt;
|
||||
if (!sttSchema) {
|
||||
throw new Error(
|
||||
@@ -157,9 +159,11 @@ class STTService {
|
||||
* Prepares the request for the OpenAI STT provider.
|
||||
* @param {Object} sttSchema - The STT schema for OpenAI.
|
||||
* @param {Stream} audioReadStream - The audio data to be transcribed.
|
||||
* @param {Object} audioFile - The audio file object (unused in OpenAI provider).
|
||||
* @param {string} language - The language code for the transcription.
|
||||
* @returns {Array} An array containing the URL, data, and headers for the request.
|
||||
*/
|
||||
openAIProvider(sttSchema, audioReadStream) {
|
||||
openAIProvider(sttSchema, audioReadStream, audioFile, language) {
|
||||
const url = sttSchema?.url || 'https://api.openai.com/v1/audio/transcriptions';
|
||||
const apiKey = extractEnvVariable(sttSchema.apiKey) || '';
|
||||
|
||||
@@ -168,6 +172,12 @@ class STTService {
|
||||
model: sttSchema.model,
|
||||
};
|
||||
|
||||
if (language) {
|
||||
/** Converted locale code (e.g., "en-US") to ISO-639-1 format (e.g., "en") */
|
||||
const isoLanguage = language.split('-')[0];
|
||||
data.language = isoLanguage;
|
||||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
...(apiKey && { Authorization: `Bearer ${apiKey}` }),
|
||||
@@ -182,10 +192,11 @@ class STTService {
|
||||
* @param {Object} sttSchema - The STT schema for Azure OpenAI.
|
||||
* @param {Buffer} audioBuffer - The audio data to be transcribed.
|
||||
* @param {Object} audioFile - The audio file object containing originalname, mimetype, and size.
|
||||
* @param {string} language - The language code for the transcription.
|
||||
* @returns {Array} An array containing the URL, data, and headers for the request.
|
||||
* @throws {Error} If the audio file size exceeds 25MB or the audio file format is not accepted.
|
||||
*/
|
||||
azureOpenAIProvider(sttSchema, audioBuffer, audioFile) {
|
||||
azureOpenAIProvider(sttSchema, audioBuffer, audioFile, language) {
|
||||
const url = `${genAzureEndpoint({
|
||||
azureOpenAIApiInstanceName: extractEnvVariable(sttSchema?.instanceName),
|
||||
azureOpenAIApiDeploymentName: extractEnvVariable(sttSchema?.deploymentName),
|
||||
@@ -209,6 +220,12 @@ class STTService {
|
||||
contentType: audioFile.mimetype,
|
||||
});
|
||||
|
||||
if (language) {
|
||||
/** Converted locale code (e.g., "en-US") to ISO-639-1 format (e.g., "en") */
|
||||
const isoLanguage = language.split('-')[0];
|
||||
formData.append('language', isoLanguage);
|
||||
}
|
||||
|
||||
const headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
...(apiKey && { 'api-key': apiKey }),
|
||||
@@ -227,10 +244,11 @@ class STTService {
|
||||
* @param {Object} requestData - The data required for the STT request.
|
||||
* @param {Buffer} requestData.audioBuffer - The audio data to be transcribed.
|
||||
* @param {Object} requestData.audioFile - The audio file object containing originalname, mimetype, and size.
|
||||
* @param {string} requestData.language - The language code for the transcription.
|
||||
* @returns {Promise<string>} A promise that resolves to the transcribed text.
|
||||
* @throws {Error} If the provider is invalid, the response status is not 200, or the response data is missing.
|
||||
*/
|
||||
async sttRequest(provider, sttSchema, { audioBuffer, audioFile }) {
|
||||
async sttRequest(provider, sttSchema, { audioBuffer, audioFile, language }) {
|
||||
const strategy = this.providerStrategies[provider];
|
||||
if (!strategy) {
|
||||
throw new Error('Invalid provider');
|
||||
@@ -241,7 +259,13 @@ class STTService {
|
||||
const audioReadStream = Readable.from(audioBuffer);
|
||||
audioReadStream.path = `audio.${fileExtension}`;
|
||||
|
||||
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);
|
||||
const [url, data, headers] = strategy.call(
|
||||
this,
|
||||
sttSchema,
|
||||
audioReadStream,
|
||||
audioFile,
|
||||
language,
|
||||
);
|
||||
|
||||
try {
|
||||
const response = await axios.post(url, data, { headers });
|
||||
@@ -282,7 +306,8 @@ class STTService {
|
||||
|
||||
try {
|
||||
const [provider, sttSchema] = await this.getProviderSchema(req);
|
||||
const text = await this.sttRequest(provider, sttSchema, { audioBuffer, audioFile });
|
||||
const language = req.body?.language || '';
|
||||
const text = await this.sttRequest(provider, sttSchema, { audioBuffer, audioFile, language });
|
||||
res.json({ text });
|
||||
} catch (error) {
|
||||
logger.error('An error occurred while processing the audio:', error);
|
||||
|
||||
@@ -35,11 +35,12 @@ class TTSService {
|
||||
|
||||
/**
|
||||
* Retrieves the configured TTS provider.
|
||||
* @param {AppConfig | null | undefined} [appConfig] - The app configuration object.
|
||||
* @returns {string} The name of the configured provider.
|
||||
* @throws {Error} If no provider is set or multiple providers are set.
|
||||
*/
|
||||
getProvider() {
|
||||
const ttsSchema = this.customConfig.speech.tts;
|
||||
getProvider(appConfig) {
|
||||
const ttsSchema = appConfig?.speech?.tts;
|
||||
if (!ttsSchema) {
|
||||
throw new Error(
|
||||
'No TTS schema is set. Did you configure TTS in the custom config (librechat.yaml)?',
|
||||
@@ -276,8 +277,8 @@ class TTSService {
|
||||
/**
|
||||
* Processes a text-to-speech request.
|
||||
* @async
|
||||
* @param {Object} req - The request object.
|
||||
* @param {Object} res - The response object.
|
||||
* @param {ServerRequest} req - The request object.
|
||||
* @param {ServerResponse} res - The response object.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async processTextToSpeech(req, res) {
|
||||
@@ -287,12 +288,14 @@ class TTSService {
|
||||
return res.status(400).send('Missing text in request body');
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig({
|
||||
role: req.user?.role,
|
||||
});
|
||||
const appConfig =
|
||||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req.user?.role,
|
||||
}));
|
||||
try {
|
||||
res.setHeader('Content-Type', 'audio/mpeg');
|
||||
const provider = this.getProvider();
|
||||
const provider = this.getProvider(appConfig);
|
||||
const ttsSchema = appConfig?.speech?.tts?.[provider];
|
||||
const voice = await this.getVoice(ttsSchema, requestVoice);
|
||||
|
||||
@@ -344,14 +347,19 @@ class TTSService {
|
||||
/**
|
||||
* Streams audio data from the TTS provider.
|
||||
* @async
|
||||
* @param {Object} req - The request object.
|
||||
* @param {Object} res - The response object.
|
||||
* @param {ServerRequest} req - The request object.
|
||||
* @param {ServerResponse} res - The response object.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async streamAudio(req, res) {
|
||||
res.setHeader('Content-Type', 'audio/mpeg');
|
||||
const provider = this.getProvider();
|
||||
const ttsSchema = this.customConfig.speech.tts[provider];
|
||||
const appConfig =
|
||||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req.user?.role,
|
||||
}));
|
||||
const provider = this.getProvider(appConfig);
|
||||
const ttsSchema = appConfig?.speech?.tts?.[provider];
|
||||
const voice = await this.getVoice(ttsSchema, req.body.voice);
|
||||
|
||||
let shouldContinue = true;
|
||||
@@ -436,8 +444,8 @@ async function createTTSService() {
|
||||
/**
|
||||
* Wrapper function for text-to-speech processing.
|
||||
* @async
|
||||
* @param {Object} req - The request object.
|
||||
* @param {Object} res - The response object.
|
||||
* @param {ServerRequest} req - The request object.
|
||||
* @param {ServerResponse} res - The response object.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function textToSpeech(req, res) {
|
||||
@@ -460,11 +468,12 @@ async function streamAudio(req, res) {
|
||||
/**
|
||||
* Wrapper function to get the configured TTS provider.
|
||||
* @async
|
||||
* @param {AppConfig | null | undefined} appConfig - The app configuration object.
|
||||
* @returns {Promise<string>} A promise that resolves to the name of the configured provider.
|
||||
*/
|
||||
async function getProvider() {
|
||||
async function getProvider(appConfig) {
|
||||
const ttsService = await createTTSService();
|
||||
return ttsService.getProvider();
|
||||
return ttsService.getProvider(appConfig);
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
|
||||
@@ -14,16 +14,18 @@ const { getProvider } = require('./TTSService');
|
||||
*/
|
||||
async function getVoices(req, res) {
|
||||
try {
|
||||
const appConfig = await getAppConfig({
|
||||
role: req.user?.role,
|
||||
});
|
||||
const appConfig =
|
||||
req.config ??
|
||||
(await getAppConfig({
|
||||
role: req.user?.role,
|
||||
}));
|
||||
|
||||
if (!appConfig || !appConfig?.speech?.tts) {
|
||||
const ttsSchema = appConfig?.speech?.tts;
|
||||
if (!ttsSchema) {
|
||||
throw new Error('Configuration or TTS schema is missing');
|
||||
}
|
||||
|
||||
const ttsSchema = appConfig?.speech?.tts;
|
||||
const provider = await getProvider(ttsSchema);
|
||||
const provider = await getProvider(appConfig);
|
||||
let voices;
|
||||
|
||||
switch (provider) {
|
||||
|
||||
@@ -17,7 +17,7 @@ const { Files } = require('~/models');
|
||||
* @param {IUser} options.user - The user object
|
||||
* @param {AppConfig} options.appConfig - The app configuration object
|
||||
* @param {GraphRunnableConfig['configurable']} options.metadata - The metadata
|
||||
* @param {any} options.toolArtifact - The tool artifact containing structured data
|
||||
* @param {{ [Tools.file_search]: { sources: Object[]; fileCitations: boolean } }} options.toolArtifact - The tool artifact containing structured data
|
||||
* @param {string} options.toolCallId - The tool call ID
|
||||
* @returns {Promise<Object|null>} The file search attachment or null
|
||||
*/
|
||||
@@ -29,12 +29,14 @@ async function processFileCitations({ user, appConfig, toolArtifact, toolCallId,
|
||||
|
||||
if (user) {
|
||||
try {
|
||||
const hasFileCitationsAccess = await checkAccess({
|
||||
user,
|
||||
permissionType: PermissionTypes.FILE_CITATIONS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
const hasFileCitationsAccess =
|
||||
toolArtifact?.[Tools.file_search]?.fileCitations ??
|
||||
(await checkAccess({
|
||||
user,
|
||||
permissionType: PermissionTypes.FILE_CITATIONS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
}));
|
||||
|
||||
if (!hasFileCitationsAccess) {
|
||||
logger.debug(
|
||||
|
||||
@@ -10,9 +10,10 @@ const { getAgent } = require('~/models/Agent');
|
||||
* @param {string} [params.role] - Optional user role to avoid DB query
|
||||
* @param {string[]} params.fileIds - Array of file IDs to check
|
||||
* @param {string} params.agentId - The agent ID that might grant access
|
||||
* @param {boolean} [params.isDelete] - Whether the operation is a delete operation
|
||||
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
|
||||
*/
|
||||
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId }) => {
|
||||
const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId, isDelete }) => {
|
||||
const accessMap = new Map();
|
||||
|
||||
// Initialize all files as no access
|
||||
@@ -44,22 +45,23 @@ const hasAccessToFilesViaAgent = async ({ userId, role, fileIds, agentId }) => {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if user has EDIT permission (which would indicate collaborative access)
|
||||
const hasEditPermission = await checkPermission({
|
||||
userId,
|
||||
role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
if (isDelete) {
|
||||
// Check if user has EDIT permission (which would indicate collaborative access)
|
||||
const hasEditPermission = await checkPermission({
|
||||
userId,
|
||||
role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
// If user only has VIEW permission, they can't access files
|
||||
// Only users with EDIT permission or higher can access agent files
|
||||
if (!hasEditPermission) {
|
||||
return accessMap;
|
||||
// If user only has VIEW permission, they can't access files
|
||||
// Only users with EDIT permission or higher can access agent files
|
||||
if (!hasEditPermission) {
|
||||
return accessMap;
|
||||
}
|
||||
}
|
||||
|
||||
// User has edit permissions - check which files are actually attached
|
||||
const attachedFileIds = new Set();
|
||||
if (agent.tool_resources) {
|
||||
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
|
||||
|
||||
@@ -616,7 +616,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
||||
|
||||
if (shouldUseSTT) {
|
||||
const sttService = await STTService.getInstance();
|
||||
const { text, bytes } = await processAudioFile({ file, sttService });
|
||||
const { text, bytes } = await processAudioFile({ req, file, sttService });
|
||||
return await createTextFile({ text, bytes });
|
||||
}
|
||||
|
||||
@@ -646,8 +646,8 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
||||
req,
|
||||
file,
|
||||
file_id,
|
||||
entity_id,
|
||||
basePath,
|
||||
entity_id,
|
||||
});
|
||||
|
||||
// SECOND: Upload to Vector DB
|
||||
@@ -670,17 +670,18 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
|
||||
req,
|
||||
file,
|
||||
file_id,
|
||||
entity_id,
|
||||
basePath,
|
||||
entity_id,
|
||||
});
|
||||
}
|
||||
|
||||
const { bytes, filename, filepath: _filepath, height, width } = storageResult;
|
||||
let { bytes, filename, filepath: _filepath, height, width } = storageResult;
|
||||
// For RAG files, use embedding result; for others, use storage result
|
||||
const embedded =
|
||||
tool_resource === EToolResources.file_search
|
||||
? embeddingResult?.embedded
|
||||
: storageResult.embedded;
|
||||
let embedded = storageResult.embedded;
|
||||
if (tool_resource === EToolResources.file_search) {
|
||||
embedded = embeddingResult?.embedded;
|
||||
filename = embeddingResult?.filename || filename;
|
||||
}
|
||||
|
||||
let filepath = _filepath;
|
||||
|
||||
@@ -929,6 +930,7 @@ async function saveBase64Image(
|
||||
url,
|
||||
{ req, file_id: _file_id, filename: _filename, endpoint, context, resolution },
|
||||
) {
|
||||
const appConfig = req.config;
|
||||
const effectiveResolution = resolution ?? appConfig.fileConfig?.imageGeneration ?? 'high';
|
||||
const file_id = _file_id ?? v4();
|
||||
let filename = `${file_id}-${_filename}`;
|
||||
@@ -943,7 +945,6 @@ async function saveBase64Image(
|
||||
}
|
||||
|
||||
const image = await resizeImageBuffer(inputBuffer, effectiveResolution, endpoint);
|
||||
const appConfig = req.config;
|
||||
const source = getFileStrategy(appConfig, { isImage: true });
|
||||
const { saveBuffer } = getStrategyFunctions(source);
|
||||
const filepath = await saveBuffer({
|
||||
|
||||
@@ -20,8 +20,8 @@ const {
|
||||
ContentTypes,
|
||||
isAssistantsEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getCachedTools, getAppConfig } = require('./Config');
|
||||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getLogStores } = require('~/cache');
|
||||
@@ -271,6 +271,7 @@ async function createMCPTool({
|
||||
availableTools: tools,
|
||||
}) {
|
||||
const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter);
|
||||
|
||||
const availableTools =
|
||||
tools ?? (await getCachedTools({ userId: req.user?.id, includeGlobal: true }));
|
||||
/** @type {LCTool | undefined} */
|
||||
@@ -537,13 +538,20 @@ async function getServerConnectionStatus(
|
||||
const baseConnectionState = getConnectionState();
|
||||
let finalConnectionState = baseConnectionState;
|
||||
|
||||
// connection state overrides specific to OAuth servers
|
||||
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
// check if server is actively being reconnected
|
||||
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
|
||||
finalConnectionState = 'connecting';
|
||||
} else {
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
finalConnectionState = 'connecting';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ jest.mock('./Config', () => ({
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(),
|
||||
getFlowStateManager: jest.fn(),
|
||||
getOAuthReconnectionManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
@@ -48,6 +49,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
let mockGetMCPManager;
|
||||
let mockGetFlowStateManager;
|
||||
let mockGetLogStores;
|
||||
let mockGetOAuthReconnectionManager;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
@@ -56,6 +58,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
mockGetMCPManager = require('~/config').getMCPManager;
|
||||
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||
mockGetLogStores = require('~/cache').getLogStores;
|
||||
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
|
||||
});
|
||||
|
||||
describe('getMCPSetupData', () => {
|
||||
@@ -354,6 +357,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
@@ -370,6 +379,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return failed flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
@@ -401,6 +416,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return active flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
@@ -432,6 +453,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return no flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => null),
|
||||
@@ -454,6 +481,35 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
});
|
||||
});
|
||||
|
||||
it('should return connecting state when OAuth server is reconnecting', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager to return true for isReconnecting
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => true),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: true,
|
||||
connectionState: 'connecting',
|
||||
});
|
||||
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not check OAuth flow status when server is connected', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(),
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
const axios = require('axios');
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { logAxiosError, inputSchema, processModelData } = require('@librechat/api');
|
||||
const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
|
||||
const { inputSchema, extractBaseURL, processModelData } = require('~/utils');
|
||||
const { OllamaClient } = require('~/app/clients/OllamaClient');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
|
||||
/**
|
||||
* Splits a string by commas and trims each resulting value.
|
||||
|
||||
@@ -11,8 +11,8 @@ const {
|
||||
getAnthropicModels,
|
||||
} = require('./ModelService');
|
||||
|
||||
jest.mock('~/utils', () => {
|
||||
const originalUtils = jest.requireActual('~/utils');
|
||||
jest.mock('@librechat/api', () => {
|
||||
const originalUtils = jest.requireActual('@librechat/api');
|
||||
return {
|
||||
...originalUtils,
|
||||
processModelData: jest.fn((...args) => {
|
||||
@@ -108,7 +108,7 @@ describe('fetchModels with createTokenConfig true', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
// Clears the mock's history before each test
|
||||
const _utils = require('~/utils');
|
||||
const _utils = require('@librechat/api');
|
||||
axios.get.mockResolvedValue({ data });
|
||||
});
|
||||
|
||||
@@ -120,7 +120,7 @@ describe('fetchModels with createTokenConfig true', () => {
|
||||
createTokenConfig: true,
|
||||
});
|
||||
|
||||
const { processModelData } = require('~/utils');
|
||||
const { processModelData } = require('@librechat/api');
|
||||
expect(processModelData).toHaveBeenCalled();
|
||||
expect(processModelData).toHaveBeenCalledWith(data);
|
||||
});
|
||||
|
||||
@@ -313,7 +313,7 @@ const ensurePrincipalExists = async function (principal) {
|
||||
idOnTheSource: principal.idOnTheSource,
|
||||
};
|
||||
|
||||
const userId = await createUser(userData, true, false);
|
||||
const userId = await createUser(userData, true, true);
|
||||
return userId.toString();
|
||||
}
|
||||
|
||||
|
||||
26
api/server/services/initializeOAuthReconnectManager.js
Normal file
26
api/server/services/initializeOAuthReconnectManager.js
Normal file
@@ -0,0 +1,26 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* Initialize OAuth reconnect manager
|
||||
*/
|
||||
async function initializeOAuthReconnectManager() {
|
||||
try {
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
const tokenMethods = {
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
deleteTokens,
|
||||
};
|
||||
await createOAuthReconnectionManager(flowManager, tokenMethods);
|
||||
logger.info(`OAuth reconnect manager initialized successfully.`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to initialize OAuth reconnect manager:', error);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = initializeOAuthReconnectManager;
|
||||
@@ -229,7 +229,7 @@
|
||||
>
|
||||
<!--[if mso]><style>.v-button {background: transparent !important;}</style><![endif]-->
|
||||
<div align='left'>
|
||||
<!--[if mso]><v:roundrect xmlns:v="urn:schemas-microsoft-com:vml" xmlns:w="urn:schemas-microsoft-com:office:word" href="href="{{verificationLink}}"" style="height:37px; v-text-anchor:middle; width:114px;" arcsize="11%" stroke="f" fillcolor="#10a37f"><w:anchorlock/><center style="color:#FFFFFF;"><![endif]-->
|
||||
<!--[if mso]><v:roundrect xmlns:v="urn:schemas-microsoft-com:vml" xmlns:w="urn:schemas-microsoft-com:office:word" href="{{verificationLink}}" style="height:37px; v-text-anchor:middle; width:114px;" arcsize="11%" stroke="f" fillcolor="#10a37f"><w:anchorlock/><center style="color:#FFFFFF;"><![endif]-->
|
||||
<a
|
||||
href='{{verificationLink}}'
|
||||
target='_blank'
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const fs = require('fs').promises;
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getImporter } = require('./importers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Job definition for importing a conversation.
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
|
||||
const { createImportBatchBuilder } = require('./importBatchBuilder');
|
||||
const { cloneMessagesWithTimestamps } = require('./fork');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
/**
|
||||
* Returns the appropriate importer function based on the provided JSON data.
|
||||
@@ -212,6 +212,29 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to find the nearest non-system parent
|
||||
* @param {string} parentId - The ID of the parent message.
|
||||
* @returns {string} The ID of the nearest non-system parent message.
|
||||
*/
|
||||
const findNonSystemParent = (parentId) => {
|
||||
if (!parentId || !messageMap.has(parentId)) {
|
||||
return Constants.NO_PARENT;
|
||||
}
|
||||
|
||||
const parentMapping = conv.mapping[parentId];
|
||||
if (!parentMapping?.message) {
|
||||
return Constants.NO_PARENT;
|
||||
}
|
||||
|
||||
/* If parent is a system message, traverse up to find the nearest non-system parent */
|
||||
if (parentMapping.message.author?.role === 'system') {
|
||||
return findNonSystemParent(parentMapping.parent);
|
||||
}
|
||||
|
||||
return messageMap.get(parentId);
|
||||
};
|
||||
|
||||
// Create and save messages using the mapped IDs
|
||||
const messages = [];
|
||||
for (const [id, mapping] of Object.entries(conv.mapping)) {
|
||||
@@ -220,23 +243,27 @@ function processConversation(conv, importBatchBuilder, requestUserId) {
|
||||
messageMap.delete(id);
|
||||
continue;
|
||||
} else if (role === 'system') {
|
||||
messageMap.delete(id);
|
||||
// Skip system messages but keep their ID in messageMap for parent references
|
||||
continue;
|
||||
}
|
||||
|
||||
const newMessageId = messageMap.get(id);
|
||||
const parentMessageId =
|
||||
mapping.parent && messageMap.has(mapping.parent)
|
||||
? messageMap.get(mapping.parent)
|
||||
: Constants.NO_PARENT;
|
||||
const parentMessageId = findNonSystemParent(mapping.parent);
|
||||
|
||||
const messageText = formatMessageText(mapping.message);
|
||||
|
||||
const isCreatedByUser = role === 'user';
|
||||
let sender = isCreatedByUser ? 'user' : 'GPT-3.5';
|
||||
let sender = isCreatedByUser ? 'user' : 'assistant';
|
||||
const model = mapping.message.metadata.model_slug || openAISettings.model.default;
|
||||
if (model.includes('gpt-4')) {
|
||||
sender = 'GPT-4';
|
||||
|
||||
if (!isCreatedByUser) {
|
||||
/** Extracted model name from model slug */
|
||||
const gptMatch = model.match(/gpt-(.+)/i);
|
||||
if (gptMatch) {
|
||||
sender = `GPT-${gptMatch[1]}`;
|
||||
} else {
|
||||
sender = model || 'assistant';
|
||||
}
|
||||
}
|
||||
|
||||
messages.push({
|
||||
|
||||
@@ -99,6 +99,404 @@ describe('importChatGptConvo', () => {
|
||||
|
||||
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle system messages without breaking parent-child relationships', async () => {
|
||||
/**
|
||||
* Test data that reproduces message graph "breaking" when it encounters a system message
|
||||
*/
|
||||
const testData = [
|
||||
{
|
||||
title: 'System Message Parent Test',
|
||||
create_time: 1714585031.148505,
|
||||
update_time: 1714585060.879308,
|
||||
mapping: {
|
||||
'root-node': {
|
||||
id: 'root-node',
|
||||
message: null,
|
||||
parent: null,
|
||||
children: ['user-msg-1'],
|
||||
},
|
||||
'user-msg-1': {
|
||||
id: 'user-msg-1',
|
||||
message: {
|
||||
id: 'user-msg-1',
|
||||
author: { role: 'user' },
|
||||
create_time: 1714585031.150442,
|
||||
content: { content_type: 'text', parts: ['First user message'] },
|
||||
metadata: { model_slug: 'gpt-4' },
|
||||
},
|
||||
parent: 'root-node',
|
||||
children: ['assistant-msg-1'],
|
||||
},
|
||||
'assistant-msg-1': {
|
||||
id: 'assistant-msg-1',
|
||||
message: {
|
||||
id: 'assistant-msg-1',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585032.150442,
|
||||
content: { content_type: 'text', parts: ['First assistant response'] },
|
||||
metadata: { model_slug: 'gpt-4' },
|
||||
},
|
||||
parent: 'user-msg-1',
|
||||
children: ['system-msg'],
|
||||
},
|
||||
'system-msg': {
|
||||
id: 'system-msg',
|
||||
message: {
|
||||
id: 'system-msg',
|
||||
author: { role: 'system' },
|
||||
create_time: 1714585033.150442,
|
||||
content: { content_type: 'text', parts: ['System message in middle'] },
|
||||
metadata: { model_slug: 'gpt-4' },
|
||||
},
|
||||
parent: 'assistant-msg-1',
|
||||
children: ['user-msg-2'],
|
||||
},
|
||||
'user-msg-2': {
|
||||
id: 'user-msg-2',
|
||||
message: {
|
||||
id: 'user-msg-2',
|
||||
author: { role: 'user' },
|
||||
create_time: 1714585034.150442,
|
||||
content: { content_type: 'text', parts: ['Second user message'] },
|
||||
metadata: { model_slug: 'gpt-4' },
|
||||
},
|
||||
parent: 'system-msg',
|
||||
children: ['assistant-msg-2'],
|
||||
},
|
||||
'assistant-msg-2': {
|
||||
id: 'assistant-msg-2',
|
||||
message: {
|
||||
id: 'assistant-msg-2',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585035.150442,
|
||||
content: { content_type: 'text', parts: ['Second assistant response'] },
|
||||
metadata: { model_slug: 'gpt-4' },
|
||||
},
|
||||
parent: 'user-msg-2',
|
||||
children: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(testData);
|
||||
await importer(testData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
/** 2 user messages + 2 assistant messages (system message should be skipped) */
|
||||
const expectedMessages = 4;
|
||||
expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedMessages);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
|
||||
const messageMap = new Map();
|
||||
savedMessages.forEach((msg) => {
|
||||
messageMap.set(msg.text, msg);
|
||||
});
|
||||
|
||||
const firstUser = messageMap.get('First user message');
|
||||
const firstAssistant = messageMap.get('First assistant response');
|
||||
const secondUser = messageMap.get('Second user message');
|
||||
const secondAssistant = messageMap.get('Second assistant response');
|
||||
|
||||
expect(firstUser).toBeDefined();
|
||||
expect(firstAssistant).toBeDefined();
|
||||
expect(secondUser).toBeDefined();
|
||||
expect(secondAssistant).toBeDefined();
|
||||
expect(firstUser.parentMessageId).toBe(Constants.NO_PARENT);
|
||||
expect(firstAssistant.parentMessageId).toBe(firstUser.messageId);
|
||||
|
||||
// This is the key test: second user message should have first assistant as parent
|
||||
// (not NO_PARENT which would indicate the system message broke the chain)
|
||||
expect(secondUser.parentMessageId).toBe(firstAssistant.messageId);
|
||||
expect(secondAssistant.parentMessageId).toBe(secondUser.messageId);
|
||||
});
|
||||
|
||||
it('should maintain correct sender for user messages regardless of GPT-4 model', async () => {
|
||||
/**
|
||||
* Test data with GPT-4 model to ensure user messages keep 'user' sender
|
||||
*/
|
||||
const testData = [
|
||||
{
|
||||
title: 'GPT-4 Sender Test',
|
||||
create_time: 1714585031.148505,
|
||||
update_time: 1714585060.879308,
|
||||
mapping: {
|
||||
'root-node': {
|
||||
id: 'root-node',
|
||||
message: null,
|
||||
parent: null,
|
||||
children: ['user-msg-1'],
|
||||
},
|
||||
'user-msg-1': {
|
||||
id: 'user-msg-1',
|
||||
message: {
|
||||
id: 'user-msg-1',
|
||||
author: { role: 'user' },
|
||||
create_time: 1714585031.150442,
|
||||
content: { content_type: 'text', parts: ['User message with GPT-4'] },
|
||||
metadata: { model_slug: 'gpt-4' },
|
||||
},
|
||||
parent: 'root-node',
|
||||
children: ['assistant-msg-1'],
|
||||
},
|
||||
'assistant-msg-1': {
|
||||
id: 'assistant-msg-1',
|
||||
message: {
|
||||
id: 'assistant-msg-1',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585032.150442,
|
||||
content: { content_type: 'text', parts: ['Assistant response with GPT-4'] },
|
||||
metadata: { model_slug: 'gpt-4' },
|
||||
},
|
||||
parent: 'user-msg-1',
|
||||
children: ['user-msg-2'],
|
||||
},
|
||||
'user-msg-2': {
|
||||
id: 'user-msg-2',
|
||||
message: {
|
||||
id: 'user-msg-2',
|
||||
author: { role: 'user' },
|
||||
create_time: 1714585033.150442,
|
||||
content: { content_type: 'text', parts: ['Another user message with GPT-4o-mini'] },
|
||||
metadata: { model_slug: 'gpt-4o-mini' },
|
||||
},
|
||||
parent: 'assistant-msg-1',
|
||||
children: ['assistant-msg-2'],
|
||||
},
|
||||
'assistant-msg-2': {
|
||||
id: 'assistant-msg-2',
|
||||
message: {
|
||||
id: 'assistant-msg-2',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585034.150442,
|
||||
content: { content_type: 'text', parts: ['Assistant response with GPT-3.5'] },
|
||||
metadata: { model_slug: 'gpt-3.5-turbo' },
|
||||
},
|
||||
parent: 'user-msg-2',
|
||||
children: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(testData);
|
||||
await importer(testData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
|
||||
const userMsg1 = savedMessages.find((msg) => msg.text === 'User message with GPT-4');
|
||||
const assistantMsg1 = savedMessages.find((msg) => msg.text === 'Assistant response with GPT-4');
|
||||
const userMsg2 = savedMessages.find(
|
||||
(msg) => msg.text === 'Another user message with GPT-4o-mini',
|
||||
);
|
||||
const assistantMsg2 = savedMessages.find(
|
||||
(msg) => msg.text === 'Assistant response with GPT-3.5',
|
||||
);
|
||||
|
||||
expect(userMsg1.sender).toBe('user');
|
||||
expect(userMsg1.isCreatedByUser).toBe(true);
|
||||
expect(userMsg1.model).toBe('gpt-4');
|
||||
|
||||
expect(userMsg2.sender).toBe('user');
|
||||
expect(userMsg2.isCreatedByUser).toBe(true);
|
||||
expect(userMsg2.model).toBe('gpt-4o-mini');
|
||||
|
||||
expect(assistantMsg1.sender).toBe('GPT-4');
|
||||
expect(assistantMsg1.isCreatedByUser).toBe(false);
|
||||
expect(assistantMsg1.model).toBe('gpt-4');
|
||||
|
||||
expect(assistantMsg2.sender).toBe('GPT-3.5-turbo');
|
||||
expect(assistantMsg2.isCreatedByUser).toBe(false);
|
||||
expect(assistantMsg2.model).toBe('gpt-3.5-turbo');
|
||||
});
|
||||
|
||||
it('should correctly extract and format model names from various model slugs', async () => {
|
||||
/**
|
||||
* Test data with various model slugs to test dynamic model identifier extraction
|
||||
*/
|
||||
const testData = [
|
||||
{
|
||||
title: 'Dynamic Model Identifier Test',
|
||||
create_time: 1714585031.148505,
|
||||
update_time: 1714585060.879308,
|
||||
mapping: {
|
||||
'root-node': {
|
||||
id: 'root-node',
|
||||
message: null,
|
||||
parent: null,
|
||||
children: ['msg-1'],
|
||||
},
|
||||
'msg-1': {
|
||||
id: 'msg-1',
|
||||
message: {
|
||||
id: 'msg-1',
|
||||
author: { role: 'user' },
|
||||
create_time: 1714585031.150442,
|
||||
content: { content_type: 'text', parts: ['Test message'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'root-node',
|
||||
children: ['msg-2', 'msg-3', 'msg-4', 'msg-5', 'msg-6', 'msg-7', 'msg-8', 'msg-9'],
|
||||
},
|
||||
'msg-2': {
|
||||
id: 'msg-2',
|
||||
message: {
|
||||
id: 'msg-2',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585032.150442,
|
||||
content: { content_type: 'text', parts: ['GPT-4 response'] },
|
||||
metadata: { model_slug: 'gpt-4' },
|
||||
},
|
||||
parent: 'msg-1',
|
||||
children: [],
|
||||
},
|
||||
'msg-3': {
|
||||
id: 'msg-3',
|
||||
message: {
|
||||
id: 'msg-3',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585033.150442,
|
||||
content: { content_type: 'text', parts: ['GPT-4o response'] },
|
||||
metadata: { model_slug: 'gpt-4o' },
|
||||
},
|
||||
parent: 'msg-1',
|
||||
children: [],
|
||||
},
|
||||
'msg-4': {
|
||||
id: 'msg-4',
|
||||
message: {
|
||||
id: 'msg-4',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585034.150442,
|
||||
content: { content_type: 'text', parts: ['GPT-4o-mini response'] },
|
||||
metadata: { model_slug: 'gpt-4o-mini' },
|
||||
},
|
||||
parent: 'msg-1',
|
||||
children: [],
|
||||
},
|
||||
'msg-5': {
|
||||
id: 'msg-5',
|
||||
message: {
|
||||
id: 'msg-5',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585035.150442,
|
||||
content: { content_type: 'text', parts: ['GPT-3.5-turbo response'] },
|
||||
metadata: { model_slug: 'gpt-3.5-turbo' },
|
||||
},
|
||||
parent: 'msg-1',
|
||||
children: [],
|
||||
},
|
||||
'msg-6': {
|
||||
id: 'msg-6',
|
||||
message: {
|
||||
id: 'msg-6',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585036.150442,
|
||||
content: { content_type: 'text', parts: ['GPT-4-turbo response'] },
|
||||
metadata: { model_slug: 'gpt-4-turbo' },
|
||||
},
|
||||
parent: 'msg-1',
|
||||
children: [],
|
||||
},
|
||||
'msg-7': {
|
||||
id: 'msg-7',
|
||||
message: {
|
||||
id: 'msg-7',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585037.150442,
|
||||
content: { content_type: 'text', parts: ['GPT-4-1106-preview response'] },
|
||||
metadata: { model_slug: 'gpt-4-1106-preview' },
|
||||
},
|
||||
parent: 'msg-1',
|
||||
children: [],
|
||||
},
|
||||
'msg-8': {
|
||||
id: 'msg-8',
|
||||
message: {
|
||||
id: 'msg-8',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585038.150442,
|
||||
content: { content_type: 'text', parts: ['Claude response'] },
|
||||
metadata: { model_slug: 'claude-3-opus' },
|
||||
},
|
||||
parent: 'msg-1',
|
||||
children: [],
|
||||
},
|
||||
'msg-9': {
|
||||
id: 'msg-9',
|
||||
message: {
|
||||
id: 'msg-9',
|
||||
author: { role: 'assistant' },
|
||||
create_time: 1714585039.150442,
|
||||
content: { content_type: 'text', parts: ['No model slug response'] },
|
||||
metadata: {},
|
||||
},
|
||||
parent: 'msg-1',
|
||||
children: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
|
||||
const importer = getImporter(testData);
|
||||
await importer(testData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const savedMessages = importBatchBuilder.saveMessage.mock.calls.map((call) => call[0]);
|
||||
|
||||
// Test various GPT model slug formats
|
||||
const gpt4 = savedMessages.find((msg) => msg.text === 'GPT-4 response');
|
||||
expect(gpt4.sender).toBe('GPT-4');
|
||||
expect(gpt4.model).toBe('gpt-4');
|
||||
|
||||
const gpt4o = savedMessages.find((msg) => msg.text === 'GPT-4o response');
|
||||
expect(gpt4o.sender).toBe('GPT-4o');
|
||||
expect(gpt4o.model).toBe('gpt-4o');
|
||||
|
||||
const gpt4oMini = savedMessages.find((msg) => msg.text === 'GPT-4o-mini response');
|
||||
expect(gpt4oMini.sender).toBe('GPT-4o-mini');
|
||||
expect(gpt4oMini.model).toBe('gpt-4o-mini');
|
||||
|
||||
const gpt35Turbo = savedMessages.find((msg) => msg.text === 'GPT-3.5-turbo response');
|
||||
expect(gpt35Turbo.sender).toBe('GPT-3.5-turbo');
|
||||
expect(gpt35Turbo.model).toBe('gpt-3.5-turbo');
|
||||
|
||||
const gpt4Turbo = savedMessages.find((msg) => msg.text === 'GPT-4-turbo response');
|
||||
expect(gpt4Turbo.sender).toBe('GPT-4-turbo');
|
||||
expect(gpt4Turbo.model).toBe('gpt-4-turbo');
|
||||
|
||||
const gpt4Preview = savedMessages.find((msg) => msg.text === 'GPT-4-1106-preview response');
|
||||
expect(gpt4Preview.sender).toBe('GPT-4-1106-preview');
|
||||
expect(gpt4Preview.model).toBe('gpt-4-1106-preview');
|
||||
|
||||
// Test non-GPT model (should use the model slug as sender)
|
||||
const claude = savedMessages.find((msg) => msg.text === 'Claude response');
|
||||
expect(claude.sender).toBe('claude-3-opus');
|
||||
expect(claude.model).toBe('claude-3-opus');
|
||||
|
||||
// Test missing model slug (should default to openAISettings.model.default)
|
||||
const noModel = savedMessages.find((msg) => msg.text === 'No model slug response');
|
||||
// When no model slug is provided, it defaults to gpt-4o-mini which gets formatted to GPT-4o-mini
|
||||
expect(noModel.sender).toBe('GPT-4o-mini');
|
||||
expect(noModel.model).toBe(openAISettings.model.default);
|
||||
|
||||
// Verify user message is unaffected
|
||||
const userMsg = savedMessages.find((msg) => msg.text === 'Test message');
|
||||
expect(userMsg.sender).toBe('user');
|
||||
expect(userMsg.isCreatedByUser).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('importLibreChatConvo', () => {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
const jwt = require('jsonwebtoken');
|
||||
const mongoose = require('mongoose');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Strategy: AppleStrategy } = require('passport-apple');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { createSocialUser, handleExistingUser } = require('./process');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const socialLogin = require('./socialLogin');
|
||||
const { findUser } = require('~/models');
|
||||
const { User } = require('~/db/models');
|
||||
@@ -17,6 +17,8 @@ jest.mock('@librechat/data-schemas', () => {
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
@@ -24,12 +26,19 @@ jest.mock('./process', () => ({
|
||||
createSocialUser: jest.fn(),
|
||||
handleExistingUser: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/server/utils', () => ({
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
isEnabled: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models', () => ({
|
||||
findUser: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn().mockResolvedValue({
|
||||
fileStrategy: 'local',
|
||||
balance: { enabled: false },
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('Apple Login Strategy', () => {
|
||||
let mongoServer;
|
||||
@@ -288,7 +297,14 @@ describe('Apple Login Strategy', () => {
|
||||
|
||||
expect(mockVerifyCallback).toHaveBeenCalledWith(null, existingUser);
|
||||
expect(existingUser.avatarUrl).toBeNull(); // As per getProfileDetails
|
||||
expect(handleExistingUser).toHaveBeenCalledWith(existingUser, null);
|
||||
expect(handleExistingUser).toHaveBeenCalledWith(
|
||||
existingUser,
|
||||
null,
|
||||
expect.objectContaining({
|
||||
fileStrategy: 'local',
|
||||
balance: { enabled: false },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle missing idToken gracefully', async () => {
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
const appleLogin = require('./appleStrategy');
|
||||
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
|
||||
const openIdJwtLogin = require('./openIdJwtStrategy');
|
||||
const facebookLogin = require('./facebookStrategy');
|
||||
const discordLogin = require('./discordStrategy');
|
||||
const passportLogin = require('./localStrategy');
|
||||
const googleLogin = require('./googleStrategy');
|
||||
const githubLogin = require('./githubStrategy');
|
||||
const discordLogin = require('./discordStrategy');
|
||||
const facebookLogin = require('./facebookStrategy');
|
||||
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
|
||||
const jwtLogin = require('./jwtStrategy');
|
||||
const ldapLogin = require('./ldapStrategy');
|
||||
const { setupSaml } = require('./samlStrategy');
|
||||
const openIdJwtLogin = require('./openIdJwtStrategy');
|
||||
const appleLogin = require('./appleStrategy');
|
||||
const ldapLogin = require('./ldapStrategy');
|
||||
const jwtLogin = require('./jwtStrategy');
|
||||
|
||||
module.exports = {
|
||||
appleLogin,
|
||||
|
||||
@@ -4,6 +4,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, getBalanceConfig } = require('@librechat/api');
|
||||
const { SystemRoles, ErrorTypes } = require('librechat-data-provider');
|
||||
const { createUser, findUser, updateUser, countUsers } = require('~/models');
|
||||
const { isEmailDomainAllowed } = require('~/server/services/domains');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
||||
const {
|
||||
@@ -121,9 +122,18 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
||||
);
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(mail, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[LDAP Strategy] Authentication blocked - email domain not allowed [Email: ${mail}]`,
|
||||
);
|
||||
return done(null, false, { message: 'Email domain not allowed' });
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
const isFirstRegisteredUser = (await countUsers()) === 0;
|
||||
const role = isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER;
|
||||
|
||||
user = {
|
||||
provider: 'ldap',
|
||||
ldapId,
|
||||
@@ -133,7 +143,6 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
||||
name: fullName,
|
||||
role,
|
||||
};
|
||||
const appConfig = await getAppConfig({ role: user?.role });
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
const userId = await createUser(user, balanceConfig);
|
||||
user._id = userId;
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
|
||||
const { updateUser, findUser } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
const jwksRsa = require('jwks-rsa');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt');
|
||||
const { isEnabled, findOpenIDUser } = require('@librechat/api');
|
||||
const { updateUser, findUser } = require('~/models');
|
||||
|
||||
/**
|
||||
* @function openIdJwtLogin
|
||||
* @param {import('openid-client').Configuration} openIdConfig - Configuration object for the JWT strategy.
|
||||
@@ -13,6 +14,14 @@ const { isEnabled } = require('~/server/utils');
|
||||
* It uses the jwks-rsa library to retrieve the signing key from a JWKS endpoint.
|
||||
* The strategy extracts the JWT from the Authorization header as a Bearer token.
|
||||
* The JWT is then verified using the signing key, and the user is retrieved from the database.
|
||||
*
|
||||
* Includes email fallback mechanism:
|
||||
* 1. Primary lookup: Search user by openidId (sub claim)
|
||||
* 2. Fallback lookup: If not found, search by email claim
|
||||
* 3. User migration: If found by email without openidId, migrate the user by adding openidId
|
||||
* 4. Provider validation: Ensures users registered with other providers cannot use OpenID
|
||||
*
|
||||
* This enables seamless migration for existing users when SharePoint integration is enabled.
|
||||
*/
|
||||
const openIdJwtLogin = (openIdConfig) => {
|
||||
let jwksRsaOptions = {
|
||||
@@ -34,19 +43,41 @@ const openIdJwtLogin = (openIdConfig) => {
|
||||
},
|
||||
async (payload, done) => {
|
||||
try {
|
||||
const user = await findUser({ openidId: payload?.sub });
|
||||
const { user, error, migration } = await findOpenIDUser({
|
||||
openidId: payload?.sub,
|
||||
email: payload?.email,
|
||||
strategyName: 'openIdJwtLogin',
|
||||
findUser,
|
||||
});
|
||||
|
||||
if (error) {
|
||||
done(null, false, { message: error });
|
||||
return;
|
||||
}
|
||||
|
||||
if (user) {
|
||||
user.id = user._id.toString();
|
||||
|
||||
const updateData = {};
|
||||
if (migration) {
|
||||
updateData.provider = 'openid';
|
||||
updateData.openidId = payload?.sub;
|
||||
}
|
||||
if (!user.role) {
|
||||
user.role = SystemRoles.USER;
|
||||
await updateUser(user.id, { role: user.role });
|
||||
updateData.role = user.role;
|
||||
}
|
||||
|
||||
if (Object.keys(updateData).length > 0) {
|
||||
await updateUser(user.id, updateData);
|
||||
}
|
||||
|
||||
done(null, user);
|
||||
} else {
|
||||
logger.warn(
|
||||
'[openIdJwtLogin] openId JwtStrategy => no user found with the sub claims: ' +
|
||||
payload?.sub,
|
||||
payload?.sub +
|
||||
(payload?.email ? ' or email: ' + payload.email : ''),
|
||||
);
|
||||
done(null, false);
|
||||
}
|
||||
|
||||
@@ -7,8 +7,15 @@ const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { hashToken, logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, ErrorTypes } = require('librechat-data-provider');
|
||||
const { Strategy: OpenIDStrategy } = require('openid-client/passport');
|
||||
const { isEnabled, logHeaders, safeStringify, getBalanceConfig } = require('@librechat/api');
|
||||
const {
|
||||
isEnabled,
|
||||
logHeaders,
|
||||
safeStringify,
|
||||
findOpenIDUser,
|
||||
getBalanceConfig,
|
||||
} = require('@librechat/api');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { isEmailDomainAllowed } = require('~/server/services/domains');
|
||||
const { findUser, createUser, updateUser } = require('~/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
@@ -177,7 +184,7 @@ const getUserInfo = async (config, accessToken, sub) => {
|
||||
const exchangedAccessToken = await exchangeAccessTokenIfNeeded(config, accessToken, sub);
|
||||
return await client.fetchUserInfo(config, exchangedAccessToken, sub);
|
||||
} catch (error) {
|
||||
logger.warn(`[openidStrategy] getUserInfo: Error fetching user info: ${error}`);
|
||||
logger.error('[openidStrategy] getUserInfo: Error fetching user info:', error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
@@ -274,6 +281,221 @@ function convertToUsername(input, defaultValue = '') {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process OpenID authentication tokenset and userinfo
|
||||
* This is the core logic extracted from the passport strategy callback
|
||||
* Can be reused by both the passport strategy and proxy authentication
|
||||
*
|
||||
* @param {Object} tokenset - The OpenID tokenset containing access_token, id_token, etc.
|
||||
* @param {boolean} existingUsersOnly - If true, only existing users will be processed
|
||||
* @returns {Promise<Object>} The authenticated user object with tokenset
|
||||
*/
|
||||
async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
|
||||
const claims = tokenset.claims ? tokenset.claims() : tokenset;
|
||||
const userinfo = {
|
||||
...claims,
|
||||
};
|
||||
|
||||
// Get userinfo from provider if we have access_token and haven't already
|
||||
if (tokenset.access_token) {
|
||||
const providerUserinfo = await getUserInfo(openidConfig, tokenset.access_token, claims.sub);
|
||||
Object.assign(userinfo, providerUserinfo);
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[OpenID Auth] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
|
||||
);
|
||||
throw new Error('Email domain not allowed');
|
||||
}
|
||||
|
||||
const result = await findOpenIDUser({
|
||||
openidId: claims.sub || userinfo.sub,
|
||||
email: claims.email || userinfo.email,
|
||||
strategyName: 'openidStrategy',
|
||||
findUser,
|
||||
});
|
||||
let user = result.user;
|
||||
const error = result.error;
|
||||
|
||||
if (error) {
|
||||
throw new Error(ErrorTypes.AUTH_FAILED);
|
||||
}
|
||||
|
||||
const fullName = getFullName(userinfo);
|
||||
|
||||
/** Required role if configured */
|
||||
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
|
||||
if (requiredRole) {
|
||||
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
|
||||
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
|
||||
|
||||
let decodedToken = '';
|
||||
if (requiredRoleTokenKind === 'access' && tokenset.access_token) {
|
||||
decodedToken = jwtDecode(tokenset.access_token);
|
||||
} else if (requiredRoleTokenKind === 'id' && tokenset.id_token) {
|
||||
decodedToken = jwtDecode(tokenset.id_token);
|
||||
} else if (userinfo.roles) {
|
||||
// If roles are already in userinfo, use them directly
|
||||
const roles = Array.isArray(userinfo.roles) ? userinfo.roles : [userinfo.roles];
|
||||
if (!roles.includes(requiredRole)) {
|
||||
throw new Error(`You must have the "${requiredRole}" role to log in.`);
|
||||
}
|
||||
} else if (requiredRoleParameterPath) {
|
||||
const pathParts = requiredRoleParameterPath.split('.');
|
||||
let found = true;
|
||||
let roles = pathParts.reduce((o, key) => {
|
||||
if (o === null || o === undefined || !(key in o)) {
|
||||
found = false;
|
||||
return [];
|
||||
}
|
||||
return o[key];
|
||||
}, decodedToken);
|
||||
|
||||
if (!found) {
|
||||
logger.error(
|
||||
`[OpenID Auth] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!roles.includes(requiredRole)) {
|
||||
throw new Error(`You must have the "${requiredRole}" role to log in.`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let username = '';
|
||||
if (process.env.OPENID_USERNAME_CLAIM) {
|
||||
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
|
||||
} else {
|
||||
username = convertToUsername(
|
||||
userinfo.preferred_username || userinfo.username || userinfo.email,
|
||||
);
|
||||
}
|
||||
|
||||
if (existingUsersOnly && !user) {
|
||||
throw new Error('User does not exist');
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
user = {
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username,
|
||||
email: userinfo.email || '',
|
||||
emailVerified: userinfo.email_verified || false,
|
||||
name: fullName,
|
||||
idOnTheSource: userinfo.oid,
|
||||
};
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
user = await createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
user.provider = 'openid';
|
||||
user.openidId = userinfo.sub;
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
user.idOnTheSource = userinfo.oid;
|
||||
}
|
||||
|
||||
// Handle avatar
|
||||
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
|
||||
const imageUrl = userinfo.picture;
|
||||
let fileName;
|
||||
if (crypto) {
|
||||
fileName = (await hashToken(userinfo.sub)) + '.png';
|
||||
} else {
|
||||
fileName = userinfo.sub + '.png';
|
||||
}
|
||||
|
||||
const imageBuffer = await downloadImage(
|
||||
imageUrl,
|
||||
openidConfig,
|
||||
tokenset.access_token,
|
||||
userinfo.sub,
|
||||
);
|
||||
if (imageBuffer) {
|
||||
const { saveBuffer } = getStrategyFunctions(
|
||||
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
|
||||
);
|
||||
const imagePath = await saveBuffer({
|
||||
fileName,
|
||||
userId: user._id.toString(),
|
||||
buffer: imageBuffer,
|
||||
});
|
||||
user.avatar = imagePath ?? '';
|
||||
}
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
|
||||
logger.info(
|
||||
`[OpenID Auth] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username}`,
|
||||
{
|
||||
user: {
|
||||
openidId: user.openidId,
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
return { ...user, tokenset };
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {boolean | undefined} [existingUsersOnly]
|
||||
*/
|
||||
function createOpenIDCallback(existingUsersOnly) {
|
||||
return async (tokenset, done) => {
|
||||
try {
|
||||
const user = await processOpenIDAuth(tokenset, existingUsersOnly);
|
||||
done(null, user);
|
||||
} catch (err) {
|
||||
if (err.message === 'Email domain not allowed') {
|
||||
return done(null, false, { message: err.message });
|
||||
}
|
||||
if (err.message === ErrorTypes.AUTH_FAILED) {
|
||||
return done(null, false, { message: err.message });
|
||||
}
|
||||
if (err.message && err.message.includes('role to log in')) {
|
||||
return done(null, false, { message: err.message });
|
||||
}
|
||||
logger.error('[openidStrategy] login failed', err);
|
||||
done(err);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up the OpenID strategy specifically for admin authentication.
|
||||
* @param {Configuration} openidConfig
|
||||
*/
|
||||
const setupOpenIdAdmin = (openidConfig) => {
|
||||
try {
|
||||
if (!openidConfig) {
|
||||
throw new Error('OpenID configuration not initialized');
|
||||
}
|
||||
|
||||
const openidAdminLogin = new CustomOpenIDStrategy(
|
||||
{
|
||||
config: openidConfig,
|
||||
scope: process.env.OPENID_SCOPE,
|
||||
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
|
||||
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
|
||||
callbackURL: process.env.DOMAIN_SERVER + '/api/admin/oauth/openid/callback',
|
||||
},
|
||||
createOpenIDCallback(true),
|
||||
);
|
||||
|
||||
passport.use('openidAdmin', openidAdminLogin);
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy] setupOpenIdAdmin', err);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Sets up the OpenID strategy for authentication.
|
||||
* This function configures the OpenID client, handles proxy settings,
|
||||
@@ -311,10 +533,6 @@ async function setupOpenId() {
|
||||
},
|
||||
);
|
||||
|
||||
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
|
||||
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
|
||||
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
|
||||
const usePKCE = isEnabled(process.env.OPENID_USE_PKCE);
|
||||
logger.info(`[openidStrategy] OpenID authentication configuration`, {
|
||||
generateNonce: shouldGenerateNonce,
|
||||
reason: shouldGenerateNonce
|
||||
@@ -328,155 +546,19 @@ async function setupOpenId() {
|
||||
scope: process.env.OPENID_SCOPE,
|
||||
callbackURL: process.env.DOMAIN_SERVER + process.env.OPENID_CALLBACK_URL,
|
||||
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
|
||||
usePKCE,
|
||||
},
|
||||
async (tokenset, done) => {
|
||||
try {
|
||||
const claims = tokenset.claims();
|
||||
let user = await findUser({ openidId: claims.sub });
|
||||
logger.info(
|
||||
`[openidStrategy] user ${user ? 'found' : 'not found'} with openidId: ${claims.sub}`,
|
||||
);
|
||||
|
||||
if (!user) {
|
||||
user = await findUser({ email: claims.email });
|
||||
logger.info(
|
||||
`[openidStrategy] user ${user ? 'found' : 'not found'} with email: ${
|
||||
claims.email
|
||||
} for openidId: ${claims.sub}`,
|
||||
);
|
||||
}
|
||||
if (user != null && user.provider !== 'openid') {
|
||||
logger.info(
|
||||
`[openidStrategy] Attempted OpenID login by user ${user.email}, was registered with "${user.provider}" provider`,
|
||||
);
|
||||
return done(null, false, {
|
||||
message: ErrorTypes.AUTH_FAILED,
|
||||
});
|
||||
}
|
||||
const userinfo = {
|
||||
...claims,
|
||||
...(await getUserInfo(openidConfig, tokenset.access_token, claims.sub)),
|
||||
};
|
||||
const fullName = getFullName(userinfo);
|
||||
|
||||
if (requiredRole) {
|
||||
let decodedToken = '';
|
||||
if (requiredRoleTokenKind === 'access') {
|
||||
decodedToken = jwtDecode(tokenset.access_token);
|
||||
} else if (requiredRoleTokenKind === 'id') {
|
||||
decodedToken = jwtDecode(tokenset.id_token);
|
||||
}
|
||||
const pathParts = requiredRoleParameterPath.split('.');
|
||||
let found = true;
|
||||
let roles = pathParts.reduce((o, key) => {
|
||||
if (o === null || o === undefined || !(key in o)) {
|
||||
found = false;
|
||||
return [];
|
||||
}
|
||||
return o[key];
|
||||
}, decodedToken);
|
||||
|
||||
if (!found) {
|
||||
logger.error(
|
||||
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!roles.includes(requiredRole)) {
|
||||
return done(null, false, {
|
||||
message: `You must have the "${requiredRole}" role to log in.`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let username = '';
|
||||
if (process.env.OPENID_USERNAME_CLAIM) {
|
||||
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
|
||||
} else {
|
||||
username = convertToUsername(
|
||||
userinfo.preferred_username || userinfo.username || userinfo.email,
|
||||
);
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
user = {
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username,
|
||||
email: userinfo.email || '',
|
||||
emailVerified: userinfo.email_verified || false,
|
||||
name: fullName,
|
||||
idOnTheSource: userinfo.oid,
|
||||
};
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
user = await createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
user.provider = 'openid';
|
||||
user.openidId = userinfo.sub;
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
user.idOnTheSource = userinfo.oid;
|
||||
}
|
||||
|
||||
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
|
||||
/** @type {string | undefined} */
|
||||
const imageUrl = userinfo.picture;
|
||||
|
||||
let fileName;
|
||||
if (crypto) {
|
||||
fileName = (await hashToken(userinfo.sub)) + '.png';
|
||||
} else {
|
||||
fileName = userinfo.sub + '.png';
|
||||
}
|
||||
|
||||
const imageBuffer = await downloadImage(
|
||||
imageUrl,
|
||||
openidConfig,
|
||||
tokenset.access_token,
|
||||
userinfo.sub,
|
||||
);
|
||||
if (imageBuffer) {
|
||||
const { saveBuffer } = getStrategyFunctions(process.env.CDN_PROVIDER);
|
||||
const imagePath = await saveBuffer({
|
||||
fileName,
|
||||
userId: user._id.toString(),
|
||||
buffer: imageBuffer,
|
||||
});
|
||||
user.avatar = imagePath ?? '';
|
||||
}
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
|
||||
logger.info(
|
||||
`[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `,
|
||||
{
|
||||
user: {
|
||||
openidId: user.openidId,
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
done(null, { ...user, tokenset });
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy] login failed', err);
|
||||
done(err);
|
||||
}
|
||||
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
|
||||
},
|
||||
createOpenIDCallback(),
|
||||
);
|
||||
passport.use('openid', openidLogin);
|
||||
setupOpenIdAdmin(openidConfig);
|
||||
return openidConfig;
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy]', err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @function getOpenIdConfig
|
||||
* @description Returns the OpenID client instance.
|
||||
|
||||
@@ -31,6 +31,7 @@ jest.mock('@librechat/data-schemas', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
|
||||
@@ -3,7 +3,6 @@ const { FileSources } = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
||||
const { updateUser, createUser, getUserById } = require('~/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
||||
/**
|
||||
* Updates the avatar URL of an existing user. If the user's avatar URL does not include the query parameter
|
||||
@@ -12,14 +11,15 @@ const { getAppConfig } = require('~/server/services/Config');
|
||||
*
|
||||
* @param {IUser} oldUser - The existing user object that needs to be updated.
|
||||
* @param {string} avatarUrl - The new avatar URL to be set for the user.
|
||||
* @param {AppConfig} appConfig - The application configuration object.
|
||||
*
|
||||
* @returns {Promise<void>}
|
||||
* The function updates the user's avatar and saves the user object. It does not return any value.
|
||||
*
|
||||
* @throws {Error} Throws an error if there's an issue saving the updated user object.
|
||||
*/
|
||||
const handleExistingUser = async (oldUser, avatarUrl) => {
|
||||
const fileStrategy = process.env.CDN_PROVIDER;
|
||||
const handleExistingUser = async (oldUser, avatarUrl, appConfig) => {
|
||||
const fileStrategy = appConfig?.fileStrategy ?? process.env.CDN_PROVIDER;
|
||||
const isLocal = fileStrategy === FileSources.local;
|
||||
|
||||
let updatedAvatar = false;
|
||||
@@ -56,6 +56,7 @@ const handleExistingUser = async (oldUser, avatarUrl) => {
|
||||
* @param {string} params.providerId - The provider-specific ID of the user.
|
||||
* @param {string} params.username - The username of the new user.
|
||||
* @param {string} params.name - The name of the new user.
|
||||
* @param {AppConfig} appConfig - The application configuration object.
|
||||
* @param {boolean} [params.emailVerified=false] - Optional. Indicates whether the user's email is verified. Defaults to false.
|
||||
*
|
||||
* @returns {Promise<User>}
|
||||
@@ -71,6 +72,7 @@ const createSocialUser = async ({
|
||||
providerId,
|
||||
username,
|
||||
name,
|
||||
appConfig,
|
||||
emailVerified,
|
||||
}) => {
|
||||
const update = {
|
||||
@@ -83,10 +85,9 @@ const createSocialUser = async ({
|
||||
emailVerified,
|
||||
};
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
const newUserId = await createUser(update, balanceConfig);
|
||||
const fileStrategy = process.env.CDN_PROVIDER;
|
||||
const fileStrategy = appConfig?.fileStrategy ?? process.env.CDN_PROVIDER;
|
||||
const isLocal = fileStrategy === FileSources.local;
|
||||
|
||||
if (!isLocal) {
|
||||
|
||||
@@ -7,6 +7,7 @@ const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { hashToken, logger } = require('@librechat/data-schemas');
|
||||
const { Strategy: SamlStrategy } = require('@node-saml/passport-saml');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { isEmailDomainAllowed } = require('~/server/services/domains');
|
||||
const { findUser, createUser, updateUser } = require('~/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const paths = require('~/config/paths');
|
||||
@@ -192,16 +193,25 @@ async function setupSaml() {
|
||||
logger.info(`[samlStrategy] SAML authentication received for NameID: ${profile.nameID}`);
|
||||
logger.debug('[samlStrategy] SAML profile:', profile);
|
||||
|
||||
const userEmail = getEmail(profile) || '';
|
||||
const appConfig = await getAppConfig();
|
||||
|
||||
if (!isEmailDomainAllowed(userEmail, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[SAML Strategy] Authentication blocked - email domain not allowed [Email: ${userEmail}]`,
|
||||
);
|
||||
return done(null, false, { message: 'Email domain not allowed' });
|
||||
}
|
||||
|
||||
let user = await findUser({ samlId: profile.nameID });
|
||||
logger.info(
|
||||
`[samlStrategy] User ${user ? 'found' : 'not found'} with SAML ID: ${profile.nameID}`,
|
||||
);
|
||||
|
||||
if (!user) {
|
||||
const email = getEmail(profile) || '';
|
||||
user = await findUser({ email });
|
||||
user = await findUser({ email: userEmail });
|
||||
logger.info(
|
||||
`[samlStrategy] User ${user ? 'found' : 'not found'} with email: ${profile.email}`,
|
||||
`[samlStrategy] User ${user ? 'found' : 'not found'} with email: ${userEmail}`,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -225,11 +235,10 @@ async function setupSaml() {
|
||||
provider: 'saml',
|
||||
samlId: profile.nameID,
|
||||
username,
|
||||
email: getEmail(profile) || '',
|
||||
email: userEmail,
|
||||
emailVerified: true,
|
||||
name: fullName,
|
||||
};
|
||||
const appConfig = await getAppConfig();
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
user = await createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
@@ -250,7 +259,9 @@ async function setupSaml() {
|
||||
fileName = profile.nameID + '.png';
|
||||
}
|
||||
|
||||
const { saveBuffer } = getStrategyFunctions(process.env.CDN_PROVIDER);
|
||||
const { saveBuffer } = getStrategyFunctions(
|
||||
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
|
||||
);
|
||||
const imagePath = await saveBuffer({
|
||||
fileName,
|
||||
userId: user._id.toString(),
|
||||
|
||||
@@ -2,6 +2,8 @@ const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { createSocialUser, handleExistingUser } = require('./process');
|
||||
const { isEmailDomainAllowed } = require('~/server/services/domains');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { findUser } = require('~/models');
|
||||
|
||||
const socialLogin =
|
||||
@@ -12,11 +14,22 @@ const socialLogin =
|
||||
profile,
|
||||
});
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
|
||||
if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[${provider}Login] Authentication blocked - email domain not allowed [Email: ${email}]`,
|
||||
);
|
||||
const error = new Error(ErrorTypes.AUTH_FAILED);
|
||||
error.code = ErrorTypes.AUTH_FAILED;
|
||||
error.message = 'Email domain not allowed';
|
||||
return cb(error);
|
||||
}
|
||||
|
||||
const existingUser = await findUser({ email: email.trim() });
|
||||
const ALLOW_SOCIAL_REGISTRATION = isEnabled(process.env.ALLOW_SOCIAL_REGISTRATION);
|
||||
|
||||
if (existingUser?.provider === provider) {
|
||||
await handleExistingUser(existingUser, avatarUrl);
|
||||
await handleExistingUser(existingUser, avatarUrl, appConfig);
|
||||
return cb(null, existingUser);
|
||||
} else if (existingUser) {
|
||||
logger.info(
|
||||
@@ -28,19 +41,29 @@ const socialLogin =
|
||||
return cb(error);
|
||||
}
|
||||
|
||||
if (ALLOW_SOCIAL_REGISTRATION) {
|
||||
const newUser = await createSocialUser({
|
||||
email,
|
||||
avatarUrl,
|
||||
provider,
|
||||
providerKey: `${provider}Id`,
|
||||
providerId: id,
|
||||
username,
|
||||
name,
|
||||
emailVerified,
|
||||
});
|
||||
return cb(null, newUser);
|
||||
const ALLOW_SOCIAL_REGISTRATION = isEnabled(process.env.ALLOW_SOCIAL_REGISTRATION);
|
||||
if (!ALLOW_SOCIAL_REGISTRATION) {
|
||||
logger.error(
|
||||
`[${provider}Login] Registration blocked - social registration is disabled [Email: ${email}]`,
|
||||
);
|
||||
const error = new Error(ErrorTypes.AUTH_FAILED);
|
||||
error.code = ErrorTypes.AUTH_FAILED;
|
||||
error.message = 'Social registration is disabled';
|
||||
return cb(error);
|
||||
}
|
||||
|
||||
const newUser = await createSocialUser({
|
||||
email,
|
||||
avatarUrl,
|
||||
provider,
|
||||
providerKey: `${provider}Id`,
|
||||
providerId: id,
|
||||
username,
|
||||
name,
|
||||
emailVerified,
|
||||
appConfig,
|
||||
});
|
||||
return cb(null, newUser);
|
||||
} catch (err) {
|
||||
logger.error(`[${provider}Login]`, err);
|
||||
return cb(err);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const { z } = require('zod');
|
||||
|
||||
const MIN_PASSWORD_LENGTH = parseInt(process.env.MIN_PASSWORD_LENGTH, 10) || 8;
|
||||
|
||||
const allowedCharactersRegex = new RegExp(
|
||||
'^[' +
|
||||
'a-zA-Z0-9_.@#$%&*()' + // Basic Latin characters and symbols
|
||||
@@ -32,7 +34,7 @@ const loginSchema = z.object({
|
||||
email: z.string().email(),
|
||||
password: z
|
||||
.string()
|
||||
.min(8)
|
||||
.min(MIN_PASSWORD_LENGTH)
|
||||
.max(128)
|
||||
.refine((value) => value.trim().length > 0, {
|
||||
message: 'Password cannot be only spaces',
|
||||
@@ -50,14 +52,14 @@ const registerSchema = z
|
||||
email: z.string().email(),
|
||||
password: z
|
||||
.string()
|
||||
.min(8)
|
||||
.min(MIN_PASSWORD_LENGTH)
|
||||
.max(128)
|
||||
.refine((value) => value.trim().length > 0, {
|
||||
message: 'Password cannot be only spaces',
|
||||
}),
|
||||
confirm_password: z
|
||||
.string()
|
||||
.min(8)
|
||||
.min(MIN_PASSWORD_LENGTH)
|
||||
.max(128)
|
||||
.refine((value) => value.trim().length > 0, {
|
||||
message: 'Password cannot be only spaces',
|
||||
|
||||
@@ -258,7 +258,7 @@ describe('Zod Schemas', () => {
|
||||
email: 'john@example.com',
|
||||
password: 'password123',
|
||||
confirm_password: 'password123',
|
||||
extraField: 'I shouldn\'t be here',
|
||||
extraField: "I shouldn't be here",
|
||||
});
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
@@ -407,7 +407,7 @@ describe('Zod Schemas', () => {
|
||||
'john{doe}', // Contains `{` and `}`
|
||||
'j', // Only one character
|
||||
'a'.repeat(81), // More than 80 characters
|
||||
'\' OR \'1\'=\'1\'; --', // SQL Injection
|
||||
"' OR '1'='1'; --", // SQL Injection
|
||||
'{$ne: null}', // MongoDB Injection
|
||||
'<script>alert("XSS")</script>', // Basic XSS
|
||||
'"><script>alert("XSS")</script>', // XSS breaking out of an attribute
|
||||
@@ -453,4 +453,64 @@ describe('Zod Schemas', () => {
|
||||
expect(result).toBe('name: String must contain at least 3 character(s)');
|
||||
});
|
||||
});
|
||||
|
||||
describe('MIN_PASSWORD_LENGTH environment variable', () => {
|
||||
// Note: These tests verify the behavior based on whatever MIN_PASSWORD_LENGTH
|
||||
// was set when the validators module was loaded
|
||||
const minLength = parseInt(process.env.MIN_PASSWORD_LENGTH, 10) || 8;
|
||||
|
||||
it('should respect the configured minimum password length for login', () => {
|
||||
// Test password exactly at minimum length
|
||||
const resultValid = loginSchema.safeParse({
|
||||
email: 'test@example.com',
|
||||
password: 'a'.repeat(minLength),
|
||||
});
|
||||
expect(resultValid.success).toBe(true);
|
||||
|
||||
// Test password one character below minimum
|
||||
if (minLength > 1) {
|
||||
const resultInvalid = loginSchema.safeParse({
|
||||
email: 'test@example.com',
|
||||
password: 'a'.repeat(minLength - 1),
|
||||
});
|
||||
expect(resultInvalid.success).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it('should respect the configured minimum password length for registration', () => {
|
||||
// Test password exactly at minimum length
|
||||
const resultValid = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
email: 'john@example.com',
|
||||
password: 'a'.repeat(minLength),
|
||||
confirm_password: 'a'.repeat(minLength),
|
||||
});
|
||||
expect(resultValid.success).toBe(true);
|
||||
|
||||
// Test password one character below minimum
|
||||
if (minLength > 1) {
|
||||
const resultInvalid = registerSchema.safeParse({
|
||||
name: 'John Doe',
|
||||
email: 'john@example.com',
|
||||
password: 'a'.repeat(minLength - 1),
|
||||
confirm_password: 'a'.repeat(minLength - 1),
|
||||
});
|
||||
expect(resultInvalid.success).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle edge case of very short minimum password length', () => {
|
||||
// This test is meaningful only if MIN_PASSWORD_LENGTH is set to a very low value
|
||||
if (minLength <= 3) {
|
||||
const result = loginSchema.safeParse({
|
||||
email: 'test@example.com',
|
||||
password: 'abc',
|
||||
});
|
||||
expect(result.success).toBe(minLength <= 3);
|
||||
} else {
|
||||
// Skip this test if minimum length is > 3
|
||||
expect(true).toBe(true);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -873,6 +873,13 @@
|
||||
* @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports ISession
|
||||
* @typedef {import('@librechat/data-schemas').ISession} ISession
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports IBalance
|
||||
* @typedef {import('@librechat/data-schemas').IBalance} IBalance
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const axios = require('axios');
|
||||
const deriveBaseURL = require('./deriveBaseURL');
|
||||
jest.mock('~/utils', () => {
|
||||
const originalUtils = jest.requireActual('~/utils');
|
||||
jest.mock('@librechat/api', () => {
|
||||
const originalUtils = jest.requireActual('@librechat/api');
|
||||
return {
|
||||
...originalUtils,
|
||||
processModelData: jest.fn((...args) => {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
const tokenHelpers = require('./tokens');
|
||||
const deriveBaseURL = require('./deriveBaseURL');
|
||||
const extractBaseURL = require('./extractBaseURL');
|
||||
const findMessageContent = require('./findMessageContent');
|
||||
@@ -6,6 +5,5 @@ const findMessageContent = require('./findMessageContent');
|
||||
module.exports = {
|
||||
deriveBaseURL,
|
||||
extractBaseURL,
|
||||
...tokenHelpers,
|
||||
findMessageContent,
|
||||
};
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
maxTokensMap,
|
||||
matchModelName,
|
||||
processModelData,
|
||||
getModelMaxTokens,
|
||||
maxOutputTokensMap,
|
||||
findMatchingPattern,
|
||||
getModelMaxTokens,
|
||||
processModelData,
|
||||
matchModelName,
|
||||
maxTokensMap,
|
||||
} = require('./tokens');
|
||||
} = require('@librechat/api');
|
||||
|
||||
describe('getModelMaxTokens', () => {
|
||||
test('should return correct tokens for exact match', () => {
|
||||
@@ -394,7 +394,7 @@ describe('getModelMaxTokens', () => {
|
||||
});
|
||||
|
||||
test('should return correct max output tokens for GPT-5 models', () => {
|
||||
const { getModelMaxOutputTokens } = require('./tokens');
|
||||
const { getModelMaxOutputTokens } = require('@librechat/api');
|
||||
['gpt-5', 'gpt-5-mini', 'gpt-5-nano'].forEach((model) => {
|
||||
expect(getModelMaxOutputTokens(model)).toBe(maxOutputTokensMap[EModelEndpoint.openAI][model]);
|
||||
expect(getModelMaxOutputTokens(model, EModelEndpoint.openAI)).toBe(
|
||||
@@ -407,7 +407,7 @@ describe('getModelMaxTokens', () => {
|
||||
});
|
||||
|
||||
test('should return correct max output tokens for GPT-OSS models', () => {
|
||||
const { getModelMaxOutputTokens } = require('./tokens');
|
||||
const { getModelMaxOutputTokens } = require('@librechat/api');
|
||||
['gpt-oss-20b', 'gpt-oss-120b'].forEach((model) => {
|
||||
expect(getModelMaxOutputTokens(model)).toBe(maxOutputTokensMap[EModelEndpoint.openAI][model]);
|
||||
expect(getModelMaxOutputTokens(model, EModelEndpoint.openAI)).toBe(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
/** v0.8.0-rc4 */
|
||||
module.exports = {
|
||||
roots: ['<rootDir>/src'],
|
||||
testEnvironment: 'jsdom',
|
||||
@@ -28,7 +29,8 @@ module.exports = {
|
||||
'jest-file-loader',
|
||||
'^test/(.*)$': '<rootDir>/test/$1',
|
||||
'^~/(.*)$': '<rootDir>/src/$1',
|
||||
'^librechat-data-provider/react-query$': '<rootDir>/../node_modules/librechat-data-provider/src/react-query',
|
||||
'^librechat-data-provider/react-query$':
|
||||
'<rootDir>/../node_modules/librechat-data-provider/src/react-query',
|
||||
},
|
||||
restoreMocks: true,
|
||||
testResultsProcessor: 'jest-junit',
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/frontend",
|
||||
"version": "v0.8.0-rc3",
|
||||
"version": "v0.8.0-rc4",
|
||||
"description": "",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
@@ -37,6 +37,7 @@
|
||||
"@headlessui/react": "^2.1.2",
|
||||
"@librechat/client": "*",
|
||||
"@marsidev/react-turnstile": "^1.1.0",
|
||||
"@mcp-ui/client": "^5.7.0",
|
||||
"@radix-ui/react-accordion": "^1.1.2",
|
||||
"@radix-ui/react-alert-dialog": "^1.0.2",
|
||||
"@radix-ui/react-checkbox": "^1.0.3",
|
||||
@@ -147,8 +148,8 @@
|
||||
"tailwindcss": "^3.4.1",
|
||||
"ts-jest": "^29.2.5",
|
||||
"typescript": "^5.3.3",
|
||||
"vite": "^6.3.4",
|
||||
"vite-plugin-compression2": "^1.3.3",
|
||||
"vite": "^6.3.6",
|
||||
"vite-plugin-compression2": "^2.2.1",
|
||||
"vite-plugin-node-polyfills": "^0.23.0",
|
||||
"vite-plugin-pwa": "^0.21.2"
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import React, { createContext, useContext, useState } from 'react';
|
||||
import React, { createContext, useContext, useState, useMemo } from 'react';
|
||||
import { Constants, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { MCP, Action, TPlugin, AgentToolType } from 'librechat-data-provider';
|
||||
import type { AgentPanelContextType } from '~/common';
|
||||
import { useAvailableToolsQuery, useGetActionsQuery } from '~/data-provider';
|
||||
import { useLocalize, useGetAgentsConfig } from '~/hooks';
|
||||
import type { AgentPanelContextType, MCPServerInfo } from '~/common';
|
||||
import { useAvailableToolsQuery, useGetActionsQuery, useGetStartupConfig } from '~/data-provider';
|
||||
import { useLocalize, useGetAgentsConfig, useMCPConnectionStatus } from '~/hooks';
|
||||
import { Panel } from '~/common';
|
||||
|
||||
type GroupedToolType = AgentToolType & { tools?: AgentToolType[] };
|
||||
type GroupedToolsRecord = Record<string, GroupedToolType>;
|
||||
|
||||
const AgentPanelContext = createContext<AgentPanelContextType | undefined>(undefined);
|
||||
|
||||
export function useAgentPanelContext() {
|
||||
@@ -33,67 +36,117 @@ export function AgentPanelProvider({ children }: { children: React.ReactNode })
|
||||
enabled: !!agent_id,
|
||||
});
|
||||
|
||||
const tools =
|
||||
pluginTools?.map((tool) => ({
|
||||
tool_id: tool.pluginKey,
|
||||
metadata: tool as TPlugin,
|
||||
agent_id: agent_id || '',
|
||||
})) || [];
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const mcpServerNames = useMemo(
|
||||
() => Object.keys(startupConfig?.mcpServers ?? {}),
|
||||
[startupConfig],
|
||||
);
|
||||
|
||||
const { connectionStatus } = useMCPConnectionStatus({
|
||||
enabled: !!agent_id && mcpServerNames.length > 0,
|
||||
});
|
||||
|
||||
const processedData = useMemo(() => {
|
||||
if (!pluginTools) {
|
||||
return {
|
||||
tools: [],
|
||||
groupedTools: {},
|
||||
mcpServersMap: new Map<string, MCPServerInfo>(),
|
||||
};
|
||||
}
|
||||
|
||||
const tools: AgentToolType[] = [];
|
||||
const groupedTools: GroupedToolsRecord = {};
|
||||
|
||||
const configuredServers = new Set(mcpServerNames);
|
||||
const mcpServersMap = new Map<string, MCPServerInfo>();
|
||||
|
||||
for (const pluginTool of pluginTools) {
|
||||
const tool: AgentToolType = {
|
||||
tool_id: pluginTool.pluginKey,
|
||||
metadata: pluginTool as TPlugin,
|
||||
};
|
||||
|
||||
tools.push(tool);
|
||||
|
||||
const groupedTools = tools?.reduce(
|
||||
(acc, tool) => {
|
||||
if (tool.tool_id.includes(Constants.mcp_delimiter)) {
|
||||
const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter);
|
||||
const groupKey = `${serverName.toLowerCase()}`;
|
||||
if (!acc[groupKey]) {
|
||||
acc[groupKey] = {
|
||||
tool_id: groupKey,
|
||||
metadata: {
|
||||
name: `${serverName}`,
|
||||
pluginKey: groupKey,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
|
||||
icon: tool.metadata.icon || '',
|
||||
} as TPlugin,
|
||||
agent_id: agent_id || '',
|
||||
|
||||
if (!mcpServersMap.has(serverName)) {
|
||||
const metadata = {
|
||||
name: serverName,
|
||||
pluginKey: serverName,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`,
|
||||
icon: pluginTool.icon || '',
|
||||
} as TPlugin;
|
||||
|
||||
mcpServersMap.set(serverName, {
|
||||
serverName,
|
||||
tools: [],
|
||||
};
|
||||
isConfigured: configuredServers.has(serverName),
|
||||
isConnected: connectionStatus?.[serverName]?.connectionState === 'connected',
|
||||
metadata,
|
||||
});
|
||||
}
|
||||
acc[groupKey].tools?.push({
|
||||
tool_id: tool.tool_id,
|
||||
metadata: tool.metadata,
|
||||
agent_id: agent_id || '',
|
||||
});
|
||||
|
||||
mcpServersMap.get(serverName)!.tools.push(tool);
|
||||
} else {
|
||||
acc[tool.tool_id] = {
|
||||
// Non-MCP tool
|
||||
groupedTools[tool.tool_id] = {
|
||||
tool_id: tool.tool_id,
|
||||
metadata: tool.metadata,
|
||||
agent_id: agent_id || '',
|
||||
};
|
||||
}
|
||||
return acc;
|
||||
},
|
||||
{} as Record<string, AgentToolType & { tools?: AgentToolType[] }>,
|
||||
);
|
||||
}
|
||||
|
||||
for (const mcpServerName of mcpServerNames) {
|
||||
if (mcpServersMap.has(mcpServerName)) {
|
||||
continue;
|
||||
}
|
||||
const metadata = {
|
||||
icon: '',
|
||||
name: mcpServerName,
|
||||
pluginKey: mcpServerName,
|
||||
description: `${localize('com_ui_tool_collection_prefix')} ${mcpServerName}`,
|
||||
} as TPlugin;
|
||||
|
||||
mcpServersMap.set(mcpServerName, {
|
||||
tools: [],
|
||||
metadata,
|
||||
isConfigured: true,
|
||||
serverName: mcpServerName,
|
||||
isConnected: connectionStatus?.[mcpServerName]?.connectionState === 'connected',
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
tools,
|
||||
groupedTools,
|
||||
mcpServersMap,
|
||||
};
|
||||
}, [pluginTools, localize, mcpServerNames, connectionStatus]);
|
||||
|
||||
const { agentsConfig, endpointsConfig } = useGetAgentsConfig();
|
||||
|
||||
const value: AgentPanelContextType = {
|
||||
mcp,
|
||||
mcps,
|
||||
/** Query data for actions and tools */
|
||||
tools,
|
||||
action,
|
||||
setMcp,
|
||||
actions,
|
||||
setMcps,
|
||||
agent_id,
|
||||
setAction,
|
||||
pluginTools,
|
||||
activePanel,
|
||||
groupedTools,
|
||||
agentsConfig,
|
||||
startupConfig,
|
||||
setActivePanel,
|
||||
endpointsConfig,
|
||||
setCurrentAgentId,
|
||||
tools: processedData.tools,
|
||||
groupedTools: processedData.groupedTools,
|
||||
mcpServersMap: processedData.mcpServersMap,
|
||||
};
|
||||
|
||||
return <AgentPanelContext.Provider value={value}>{children}</AgentPanelContext.Provider>;
|
||||
|
||||
@@ -2,7 +2,14 @@ import React, { createContext, useContext, useEffect, useRef } from 'react';
|
||||
import { useSetRecoilState } from 'recoil';
|
||||
import { Tools, Constants, LocalStorageKeys, AgentCapabilities } from 'librechat-data-provider';
|
||||
import type { TAgentsEndpoint } from 'librechat-data-provider';
|
||||
import { useSearchApiKeyForm, useGetAgentsConfig, useCodeApiKeyForm, useToolToggle } from '~/hooks';
|
||||
import {
|
||||
useMCPServerManager,
|
||||
useSearchApiKeyForm,
|
||||
useGetAgentsConfig,
|
||||
useCodeApiKeyForm,
|
||||
useToolToggle,
|
||||
} from '~/hooks';
|
||||
import { getTimestampedValue, setTimestamp } from '~/utils/timestamps';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
|
||||
interface BadgeRowContextType {
|
||||
@@ -14,6 +21,7 @@ interface BadgeRowContextType {
|
||||
codeInterpreter: ReturnType<typeof useToolToggle>;
|
||||
codeApiKeyForm: ReturnType<typeof useCodeApiKeyForm>;
|
||||
searchApiKeyForm: ReturnType<typeof useSearchApiKeyForm>;
|
||||
mcpServerManager: ReturnType<typeof useMCPServerManager>;
|
||||
}
|
||||
|
||||
const BadgeRowContext = createContext<BadgeRowContextType | undefined>(undefined);
|
||||
@@ -37,10 +45,11 @@ export default function BadgeRowProvider({
|
||||
isSubmitting,
|
||||
conversationId,
|
||||
}: BadgeRowProviderProps) {
|
||||
const hasInitializedRef = useRef(false);
|
||||
const lastKeyRef = useRef<string>('');
|
||||
const hasInitializedRef = useRef(false);
|
||||
const { agentsConfig } = useGetAgentsConfig();
|
||||
const key = conversationId ?? Constants.NEW_CONVO;
|
||||
|
||||
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(key));
|
||||
|
||||
/** Initialize ephemeralAgent from localStorage on mount and when conversation changes */
|
||||
@@ -53,16 +62,15 @@ export default function BadgeRowProvider({
|
||||
hasInitializedRef.current = true;
|
||||
lastKeyRef.current = key;
|
||||
|
||||
// Load all localStorage values
|
||||
const codeToggleKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${key}`;
|
||||
const webSearchToggleKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${key}`;
|
||||
const fileSearchToggleKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${key}`;
|
||||
const artifactsToggleKey = `${LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_}${key}`;
|
||||
|
||||
const codeToggleValue = localStorage.getItem(codeToggleKey);
|
||||
const webSearchToggleValue = localStorage.getItem(webSearchToggleKey);
|
||||
const fileSearchToggleValue = localStorage.getItem(fileSearchToggleKey);
|
||||
const artifactsToggleValue = localStorage.getItem(artifactsToggleKey);
|
||||
const codeToggleValue = getTimestampedValue(codeToggleKey);
|
||||
const webSearchToggleValue = getTimestampedValue(webSearchToggleKey);
|
||||
const fileSearchToggleValue = getTimestampedValue(fileSearchToggleKey);
|
||||
const artifactsToggleValue = getTimestampedValue(artifactsToggleKey);
|
||||
|
||||
const initialValues: Record<string, any> = {};
|
||||
|
||||
@@ -98,15 +106,37 @@ export default function BadgeRowProvider({
|
||||
}
|
||||
}
|
||||
|
||||
// Always set values for all tools (use defaults if not in localStorage)
|
||||
// If ephemeralAgent is null, create a new object with just our tool values
|
||||
setEphemeralAgent((prev) => ({
|
||||
...(prev || {}),
|
||||
/**
|
||||
* Always set values for all tools (use defaults if not in `localStorage`)
|
||||
* If `ephemeralAgent` is `null`, create a new object with just our tool values
|
||||
*/
|
||||
const finalValues = {
|
||||
[Tools.execute_code]: initialValues[Tools.execute_code] ?? false,
|
||||
[Tools.web_search]: initialValues[Tools.web_search] ?? false,
|
||||
[Tools.file_search]: initialValues[Tools.file_search] ?? false,
|
||||
[AgentCapabilities.artifacts]: initialValues[AgentCapabilities.artifacts] ?? false,
|
||||
};
|
||||
|
||||
setEphemeralAgent((prev) => ({
|
||||
...(prev || {}),
|
||||
...finalValues,
|
||||
}));
|
||||
|
||||
Object.entries(finalValues).forEach(([toolKey, value]) => {
|
||||
if (value !== false) {
|
||||
let storageKey = artifactsToggleKey;
|
||||
if (toolKey === Tools.execute_code) {
|
||||
storageKey = codeToggleKey;
|
||||
} else if (toolKey === Tools.web_search) {
|
||||
storageKey = webSearchToggleKey;
|
||||
} else if (toolKey === Tools.file_search) {
|
||||
storageKey = fileSearchToggleKey;
|
||||
}
|
||||
// Store the value and set timestamp for existing values
|
||||
localStorage.setItem(storageKey, JSON.stringify(value));
|
||||
setTimestamp(storageKey);
|
||||
}
|
||||
});
|
||||
}
|
||||
}, [key, isSubmitting, setEphemeralAgent]);
|
||||
|
||||
@@ -156,6 +186,8 @@ export default function BadgeRowProvider({
|
||||
isAuthenticated: true,
|
||||
});
|
||||
|
||||
const mcpServerManager = useMCPServerManager({ conversationId });
|
||||
|
||||
const value: BadgeRowContextType = {
|
||||
webSearch,
|
||||
artifacts,
|
||||
@@ -165,6 +197,7 @@ export default function BadgeRowProvider({
|
||||
codeApiKeyForm,
|
||||
codeInterpreter,
|
||||
searchApiKeyForm,
|
||||
mcpServerManager,
|
||||
};
|
||||
|
||||
return <BadgeRowContext.Provider value={value}>{children}</BadgeRowContext.Provider>;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user