Compare commits
210 Commits
refactor/o
...
feat/price
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b034624690 | ||
|
|
adff605c50 | ||
|
|
465c81adee | ||
|
|
cb8e76e27e | ||
|
|
4d9e17efe1 | ||
|
|
95ebef13df | ||
|
|
4fb9d7bdff | ||
|
|
0edfecf44a | ||
|
|
ba8c09b361 | ||
|
|
794fe6fd11 | ||
|
|
97ac52fc6c | ||
|
|
1a947607a5 | ||
|
|
1745708418 | ||
|
|
14aedac1e1 | ||
|
|
a820d79bfc | ||
|
|
3b1c07ff46 | ||
|
|
c1b0f13360 | ||
|
|
637bbd2e29 | ||
|
|
30e1b421ba | ||
|
|
fb89f60470 | ||
|
|
5245aeea8f | ||
|
|
dd93db40bc | ||
|
|
136cf1d5a8 | ||
|
|
751522087a | ||
|
|
7fe830acfc | ||
|
|
cdfe686987 | ||
|
|
5b5723343c | ||
|
|
30c24a66f6 | ||
|
|
ecf9733bc1 | ||
|
|
133312fb40 | ||
|
|
b62ffb533c | ||
|
|
d75fb76338 | ||
|
|
51f2d43fed | ||
|
|
e3a645e8fb | ||
|
|
180046a3c5 | ||
|
|
916742ab9d | ||
|
|
d91f34dd42 | ||
|
|
5676976564 | ||
|
|
85aa3e7d9c | ||
|
|
a2ff6613c5 | ||
|
|
8d6cb5eee0 | ||
|
|
31445e391a | ||
|
|
04c3a5a861 | ||
|
|
5667cc9702 | ||
|
|
c0f95f971a | ||
|
|
f125f5bd32 | ||
|
|
f3eca8c7a7 | ||
|
|
f22e5f965e | ||
|
|
749f539dfc | ||
|
|
1247207afe | ||
|
|
5c0e9d8fbb | ||
|
|
957fa7a994 | ||
|
|
751c2e1d17 | ||
|
|
519645c0b0 | ||
|
|
0d0a318c3c | ||
|
|
588e0c4611 | ||
|
|
79144a6365 | ||
|
|
ca53c20370 | ||
|
|
d635503f49 | ||
|
|
920966f895 | ||
|
|
c46e0d3ecc | ||
|
|
c6ecf0095b | ||
|
|
7de6f6e44c | ||
|
|
035f85c3ba | ||
|
|
6f6a34d126 | ||
|
|
fff1f1cf27 | ||
|
|
1869854d70 | ||
|
|
4dd2998592 | ||
|
|
a4a174b3dc | ||
|
|
65c83317aa | ||
|
|
e95e0052da | ||
|
|
0ecafcd38e | ||
|
|
cadfe14abe | ||
|
|
75dd6fb28b | ||
|
|
eef93024d5 | ||
|
|
cd73cb0b3e | ||
|
|
e705b09280 | ||
|
|
23bd4dfbfd | ||
|
|
df17582103 | ||
|
|
d79b80a4bf | ||
|
|
45da421e7d | ||
|
|
122ff416ac | ||
|
|
b66bf93b31 | ||
|
|
6d791e3e12 | ||
|
|
f9b12517b0 | ||
|
|
195e1e9eb2 | ||
|
|
47aa90df1d | ||
|
|
460eac36f6 | ||
|
|
3a47deac07 | ||
|
|
49e8443ec5 | ||
|
|
d16f93b5f7 | ||
|
|
20b29bbfa6 | ||
|
|
e2a6937ca6 | ||
|
|
005a0cb84a | ||
|
|
beabe38311 | ||
|
|
62315be197 | ||
|
|
a26597a696 | ||
|
|
8772b04d1d | ||
|
|
7742b18c9c | ||
|
|
b75b799e34 | ||
|
|
43add11b05 | ||
|
|
1764de53a5 | ||
|
|
c0511b9a5f | ||
|
|
2483623c88 | ||
|
|
229d6f2dfe | ||
|
|
d5ec838218 | ||
|
|
15d7a3d221 | ||
|
|
c3e88b97c8 | ||
|
|
ba424666f8 | ||
|
|
ea3b671182 | ||
|
|
f209f616c9 | ||
|
|
961af515d5 | ||
|
|
a362963017 | ||
|
|
78d735f35c | ||
|
|
48f6f8f2f8 | ||
|
|
74bc0440f0 | ||
|
|
18d5a75cdc | ||
|
|
a820863e8b | ||
|
|
9a210971f5 | ||
|
|
e1ad235f17 | ||
|
|
4a0b329e3e | ||
|
|
a22359de5e | ||
|
|
bbfe4002eb | ||
|
|
94426a3cae | ||
|
|
e559f0f4dc | ||
|
|
15c9c7e1f4 | ||
|
|
ac641e7cba | ||
|
|
1915d7b195 | ||
|
|
c2f4b383f2 | ||
|
|
939af59950 | ||
|
|
7d08da1a8a | ||
|
|
543b617e1c | ||
|
|
c827fdd10e | ||
|
|
ac608ded46 | ||
|
|
0e00f357a6 | ||
|
|
c465d7b732 | ||
|
|
aba0a93d1d | ||
|
|
a49b2b2833 | ||
|
|
e0ebb7097e | ||
|
|
9a79635012 | ||
|
|
ce19abc968 | ||
|
|
49cd3894aa | ||
|
|
da4aa37493 | ||
|
|
5a14ee9c6a | ||
|
|
3394aa5030 | ||
|
|
cee0579e0e | ||
|
|
d6c173c94b | ||
|
|
beff848a3f | ||
|
|
b9bc3123d6 | ||
|
|
639c7ad6ad | ||
|
|
822e2310ce | ||
|
|
2a0a8f6beb | ||
|
|
a6fd32a15a | ||
|
|
80a1a57fde | ||
|
|
3576391482 | ||
|
|
55557f7cc8 | ||
|
|
d7d02766ea | ||
|
|
627f0bffe5 | ||
|
|
8d1d95371f | ||
|
|
8bcdc041b2 | ||
|
|
9b6395d955 | ||
|
|
ad1503abdc | ||
|
|
a6d7ebf22e | ||
|
|
cebf140bce | ||
|
|
cc0cf359a2 | ||
|
|
3547873bc4 | ||
|
|
50b7bd6643 | ||
|
|
81186312ef | ||
|
|
4ec7bcb60f | ||
|
|
c78fd0fc83 | ||
|
|
d711fc7852 | ||
|
|
6af7efd0f4 | ||
|
|
d57e7aec73 | ||
|
|
e4e25aaf2b | ||
|
|
e8ddd279fd | ||
|
|
b742c8c7f9 | ||
|
|
803ade8601 | ||
|
|
dcd96c29c5 | ||
|
|
53c31b85d0 | ||
|
|
d07c2b3475 | ||
|
|
a434d28579 | ||
|
|
d82a63642d | ||
|
|
9585db14ba | ||
|
|
c191af6c9b | ||
|
|
39346d6b8e | ||
|
|
28d63dab71 | ||
|
|
49d1cefe71 | ||
|
|
0262c25989 | ||
|
|
90b037a67f | ||
|
|
fc8fd489d6 | ||
|
|
81b32e400a | ||
|
|
ae732b2ebc | ||
|
|
7e7e75714e | ||
|
|
ff54cbffd9 | ||
|
|
74e029e78f | ||
|
|
75324e1c7e | ||
|
|
949682ef0f | ||
|
|
66bd419baa | ||
|
|
aa42759ffd | ||
|
|
52e59e40be | ||
|
|
a955097faf | ||
|
|
b6413b06bc | ||
|
|
e6cebdf2b6 | ||
|
|
3eb6debe6a | ||
|
|
8780a78165 | ||
|
|
9dbf153489 | ||
|
|
4799593e1a | ||
|
|
a199b87478 | ||
|
|
007570b5c6 | ||
|
|
5943d5346c |
28
.env.example
28
.env.example
@@ -40,6 +40,13 @@ NO_INDEX=true
|
||||
# Defaulted to 1.
|
||||
TRUST_PROXY=1
|
||||
|
||||
# Minimum password length for user authentication
|
||||
# Default: 8
|
||||
# Note: When using LDAP authentication, you may want to set this to 1
|
||||
# to bypass local password validation, as LDAP servers handle their own
|
||||
# password policies.
|
||||
# MIN_PASSWORD_LENGTH=8
|
||||
|
||||
#===============#
|
||||
# JSON Logging #
|
||||
#===============#
|
||||
@@ -660,6 +667,10 @@ HELP_AND_FAQ_URL=https://librechat.ai
|
||||
# REDIS_URI=rediss://127.0.0.1:6380
|
||||
# REDIS_CA=/path/to/ca-cert.pem
|
||||
|
||||
# Elasticache may need to use an alternate dnsLookup for TLS connections. see "Special Note: Aws Elasticache Clusters with TLS" on this webpage: https://www.npmjs.com/package/ioredis
|
||||
# Enable alternative dnsLookup for redis
|
||||
# REDIS_USE_ALTERNATIVE_DNS_LOOKUP=true
|
||||
|
||||
# Redis authentication (if required)
|
||||
# REDIS_USERNAME=your_redis_username
|
||||
# REDIS_PASSWORD=your_redis_password
|
||||
@@ -679,8 +690,8 @@ HELP_AND_FAQ_URL=https://librechat.ai
|
||||
# REDIS_PING_INTERVAL=300
|
||||
|
||||
# Force specific cache namespaces to use in-memory storage even when Redis is enabled
|
||||
# Comma-separated list of CacheKeys (e.g., STATIC_CONFIG,ROLES,MESSAGES)
|
||||
# FORCED_IN_MEMORY_CACHE_NAMESPACES=STATIC_CONFIG,ROLES
|
||||
# Comma-separated list of CacheKeys (e.g., ROLES,MESSAGES)
|
||||
# FORCED_IN_MEMORY_CACHE_NAMESPACES=ROLES,MESSAGES
|
||||
|
||||
#==================================================#
|
||||
# Others #
|
||||
@@ -742,3 +753,16 @@ OPENWEATHER_API_KEY=
|
||||
# JINA_API_KEY=your_jina_api_key
|
||||
# or
|
||||
# COHERE_API_KEY=your_cohere_api_key
|
||||
|
||||
#======================#
|
||||
# MCP Configuration #
|
||||
#======================#
|
||||
|
||||
# Treat 401/403 responses as OAuth requirement when no oauth metadata found
|
||||
# MCP_OAUTH_ON_AUTH_ERROR=true
|
||||
|
||||
# Timeout for OAuth detection requests in milliseconds
|
||||
# MCP_OAUTH_DETECTION_TIMEOUT=5000
|
||||
|
||||
# Cache connection status checks for this many milliseconds to avoid expensive verification
|
||||
# MCP_CONNECTION_CHECK_TTL=60000
|
||||
|
||||
4
.github/CONTRIBUTING.md
vendored
4
.github/CONTRIBUTING.md
vendored
@@ -147,7 +147,7 @@ Apply the following naming conventions to branches, labels, and other Git-relate
|
||||
## 8. Module Import Conventions
|
||||
|
||||
- `npm` packages first,
|
||||
- from shortest line (top) to longest (bottom)
|
||||
- from longest line (top) to shortest (bottom)
|
||||
|
||||
- Followed by typescript types (pertains to data-provider and client workspaces)
|
||||
- longest line (top) to shortest (bottom)
|
||||
@@ -157,6 +157,8 @@ Apply the following naming conventions to branches, labels, and other Git-relate
|
||||
- longest line (top) to shortest (bottom)
|
||||
- imports with alias `~` treated the same as relative import with respect to line length
|
||||
|
||||
**Note:** ESLint will automatically enforce these import conventions when you run `npm run lint --fix` or through pre-commit hooks.
|
||||
|
||||
---
|
||||
|
||||
Please ensure that you adapt this summary to fit the specific context and nuances of your project.
|
||||
|
||||
12
.github/workflows/data-provider.yml
vendored
12
.github/workflows/data-provider.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Node.js Package
|
||||
name: Publish `librechat-data-provider` to NPM
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -6,6 +6,12 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- 'packages/data-provider/package.json'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
reason:
|
||||
description: 'Reason for manual trigger'
|
||||
required: false
|
||||
default: 'Manual publish requested'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@@ -14,7 +20,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 16
|
||||
node-version: 20
|
||||
- run: cd packages/data-provider && npm ci
|
||||
- run: cd packages/data-provider && npm run build
|
||||
|
||||
@@ -25,7 +31,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 16
|
||||
node-version: 20
|
||||
registry-url: 'https://registry.npmjs.org'
|
||||
- run: cd packages/data-provider && npm ci
|
||||
- run: cd packages/data-provider && npm run build
|
||||
|
||||
33
.github/workflows/i18n-unused-keys.yml
vendored
33
.github/workflows/i18n-unused-keys.yml
vendored
@@ -1,5 +1,10 @@
|
||||
name: Detect Unused i18next Strings
|
||||
|
||||
# This workflow checks for unused i18n keys in translation files.
|
||||
# It has special handling for:
|
||||
# - com_ui_special_var_* keys that are dynamically constructed
|
||||
# - com_agents_category_* keys that are stored in the database and used dynamically
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
@@ -7,6 +12,7 @@ on:
|
||||
- "api/**"
|
||||
- "packages/data-provider/src/**"
|
||||
- "packages/client/**"
|
||||
- "packages/data-schemas/src/**"
|
||||
|
||||
jobs:
|
||||
detect-unused-i18n-keys:
|
||||
@@ -24,7 +30,7 @@ jobs:
|
||||
|
||||
# Define paths
|
||||
I18N_FILE="client/src/locales/en/translation.json"
|
||||
SOURCE_DIRS=("client/src" "api" "packages/data-provider/src" "packages/client")
|
||||
SOURCE_DIRS=("client/src" "api" "packages/data-provider/src" "packages/client" "packages/data-schemas/src")
|
||||
|
||||
# Check if translation file exists
|
||||
if [[ ! -f "$I18N_FILE" ]]; then
|
||||
@@ -52,6 +58,31 @@ jobs:
|
||||
fi
|
||||
done
|
||||
|
||||
# Also check if the key is directly used somewhere
|
||||
if [[ "$FOUND" == false ]]; then
|
||||
for DIR in "${SOURCE_DIRS[@]}"; do
|
||||
if grep -r --include=\*.{js,jsx,ts,tsx} -q "$KEY" "$DIR"; then
|
||||
FOUND=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
fi
|
||||
# Special case for agent category keys that are dynamically used from database
|
||||
elif [[ "$KEY" == com_agents_category_* ]]; then
|
||||
# Check if agent category localization is being used
|
||||
for DIR in "${SOURCE_DIRS[@]}"; do
|
||||
# Check for dynamic category label/description usage
|
||||
if grep -r --include=\*.{js,jsx,ts,tsx} -E "category\.(label|description).*startsWith.*['\"]com_" "$DIR" > /dev/null 2>&1 || \
|
||||
# Check for the method that defines these keys
|
||||
grep -r --include=\*.{js,jsx,ts,tsx} "ensureDefaultCategories" "$DIR" > /dev/null 2>&1 || \
|
||||
# Check for direct usage in agentCategory.ts
|
||||
grep -r --include=\*.ts -E "label:.*['\"]$KEY['\"]" "$DIR" > /dev/null 2>&1 || \
|
||||
grep -r --include=\*.ts -E "description:.*['\"]$KEY['\"]" "$DIR" > /dev/null 2>&1; then
|
||||
FOUND=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# Also check if the key is directly used somewhere
|
||||
if [[ "$FOUND" == false ]]; then
|
||||
for DIR in "${SOURCE_DIRS[@]}"; do
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -137,3 +137,4 @@ helm/**/.values.yaml
|
||||
/.openai/
|
||||
/.tabnine/
|
||||
/.codeium
|
||||
*.local.md
|
||||
|
||||
21
Dockerfile
21
Dockerfile
@@ -1,4 +1,4 @@
|
||||
# v0.8.0-rc2
|
||||
# v0.8.0-rc4
|
||||
|
||||
# Base node image
|
||||
FROM node:20-alpine AS node
|
||||
@@ -19,24 +19,31 @@ WORKDIR /app
|
||||
|
||||
USER node
|
||||
|
||||
COPY --chown=node:node . .
|
||||
COPY --chown=node:node package.json package-lock.json ./
|
||||
COPY --chown=node:node api/package.json ./api/package.json
|
||||
COPY --chown=node:node client/package.json ./client/package.json
|
||||
COPY --chown=node:node packages/data-provider/package.json ./packages/data-provider/package.json
|
||||
COPY --chown=node:node packages/data-schemas/package.json ./packages/data-schemas/package.json
|
||||
COPY --chown=node:node packages/api/package.json ./packages/api/package.json
|
||||
|
||||
RUN \
|
||||
# Allow mounting of these files, which have no default
|
||||
touch .env ; \
|
||||
# Create directories for the volumes to inherit the correct permissions
|
||||
mkdir -p /app/client/public/images /app/api/logs ; \
|
||||
mkdir -p /app/client/public/images /app/api/logs /app/uploads ; \
|
||||
npm config set fetch-retry-maxtimeout 600000 ; \
|
||||
npm config set fetch-retries 5 ; \
|
||||
npm config set fetch-retry-mintimeout 15000 ; \
|
||||
npm install --no-audit; \
|
||||
npm ci --no-audit
|
||||
|
||||
COPY --chown=node:node . .
|
||||
|
||||
RUN \
|
||||
# React client build
|
||||
NODE_OPTIONS="--max-old-space-size=2048" npm run frontend; \
|
||||
npm prune --production; \
|
||||
npm cache clean --force
|
||||
|
||||
RUN mkdir -p /app/client/public/images /app/api/logs
|
||||
|
||||
# Node API setup
|
||||
EXPOSE 3080
|
||||
ENV HOST=0.0.0.0
|
||||
@@ -47,4 +54,4 @@ CMD ["npm", "run", "backend"]
|
||||
# WORKDIR /usr/share/nginx/html
|
||||
# COPY --from=node /app/client/dist /usr/share/nginx/html
|
||||
# COPY client/nginx.conf /etc/nginx/conf.d/default.conf
|
||||
# ENTRYPOINT ["nginx", "-g", "daemon off;"]
|
||||
# ENTRYPOINT ["nginx", "-g", "daemon off;"]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Dockerfile.multi
|
||||
# v0.8.0-rc2
|
||||
# v0.8.0-rc4
|
||||
|
||||
# Base for all builds
|
||||
FROM node:20-alpine AS base-min
|
||||
|
||||
16
README.md
16
README.md
@@ -65,14 +65,17 @@
|
||||
|
||||
- 🔦 **Agents & Tools Integration**:
|
||||
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
|
||||
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
|
||||
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
|
||||
- No-Code Custom Assistants: Build specialized, AI-driven helpers
|
||||
- Agent Marketplace: Discover and deploy community-built agents
|
||||
- Collaborative Sharing: Share agents with specific users and groups
|
||||
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
|
||||
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
|
||||
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
|
||||
|
||||
- 🔍 **Web Search**:
|
||||
- Search the internet and retrieve relevant information to enhance your AI context
|
||||
- Combines search providers, content scrapers, and result rerankers for optimal results
|
||||
- **Customizable Jina Reranking**: Configure custom Jina API URLs for reranking services
|
||||
- **[Learn More →](https://www.librechat.ai/docs/features/web_search)**
|
||||
|
||||
- 🪄 **Generative UI with Code Artifacts**:
|
||||
@@ -87,15 +90,18 @@
|
||||
- Create, Save, & Share Custom Presets
|
||||
- Switch between AI Endpoints and Presets mid-chat
|
||||
- Edit, Resubmit, and Continue Messages with Conversation branching
|
||||
- Create and share prompts with specific users and groups
|
||||
- [Fork Messages & Conversations](https://www.librechat.ai/docs/features/fork) for Advanced Context control
|
||||
|
||||
- 💬 **Multimodal & File Interactions**:
|
||||
- Upload and analyze images with Claude 3, GPT-4.5, GPT-4o, o1, Llama-Vision, and Gemini 📸
|
||||
- Chat with Files using Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, & Google 🗃️
|
||||
|
||||
- 🌎 **Multilingual UI**:
|
||||
- English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro
|
||||
- Русский, 日本語, Svenska, 한국어, Tiếng Việt, 繁體中文, العربية, Türkçe, Nederlands, עברית
|
||||
- 🌎 **Multilingual UI**:
|
||||
- English, 中文 (简体), 中文 (繁體), العربية, Deutsch, Español, Français, Italiano
|
||||
- Polski, Português (PT), Português (BR), Русский, 日本語, Svenska, 한국어, Tiếng Việt
|
||||
- Türkçe, Nederlands, עברית, Català, Čeština, Dansk, Eesti, فارسی
|
||||
- Suomi, Magyar, Հայերեն, Bahasa Indonesia, ქართული, Latviešu, ไทย, ئۇيغۇرچە
|
||||
|
||||
- 🧠 **Reasoning UI**:
|
||||
- Dynamic Reasoning UI for Chain-of-Thought/Reasoning AI models like DeepSeek-R1
|
||||
|
||||
@@ -10,7 +10,17 @@ const {
|
||||
validateVisionModel,
|
||||
} = require('librechat-data-provider');
|
||||
const { SplitStreamHandler: _Handler } = require('@librechat/agents');
|
||||
const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api');
|
||||
const {
|
||||
Tokenizer,
|
||||
createFetch,
|
||||
matchModelName,
|
||||
getClaudeHeaders,
|
||||
getModelMaxTokens,
|
||||
configureReasoning,
|
||||
checkPromptCacheSupport,
|
||||
getModelMaxOutputTokens,
|
||||
createStreamEventHandlers,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
truncateText,
|
||||
formatMessage,
|
||||
@@ -19,12 +29,6 @@ const {
|
||||
parseParamFromPrompt,
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const {
|
||||
getClaudeHeaders,
|
||||
configureReasoning,
|
||||
checkPromptCacheSupport,
|
||||
} = require('~/server/services/Endpoints/anthropic/helpers');
|
||||
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { sleep } = require('~/server/utils');
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const crypto = require('crypto');
|
||||
const fetch = require('node-fetch');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getBalanceConfig } = require('@librechat/api');
|
||||
const {
|
||||
supportsBalanceCheck,
|
||||
isAgentsEndpoint,
|
||||
@@ -15,7 +17,6 @@ const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { truncateToolCallOutputs } = require('./prompts');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const TextStream = require('./TextStream');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
class BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
@@ -37,6 +38,8 @@ class BaseClient {
|
||||
this.conversationId;
|
||||
/** @type {string} */
|
||||
this.responseMessageId;
|
||||
/** @type {string} */
|
||||
this.parentMessageId;
|
||||
/** @type {TAttachment[]} */
|
||||
this.attachments;
|
||||
/** The key for the usage object's input tokens
|
||||
@@ -110,13 +113,15 @@ class BaseClient {
|
||||
* If a correction to the token usage is needed, the method should return an object with the corrected token counts.
|
||||
* Should only be used if `recordCollectedUsage` was not used instead.
|
||||
* @param {string} [model]
|
||||
* @param {AppConfig['balance']} [balance]
|
||||
* @param {number} promptTokens
|
||||
* @param {number} completionTokens
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async recordTokenUsage({ model, promptTokens, completionTokens }) {
|
||||
async recordTokenUsage({ model, balance, promptTokens, completionTokens }) {
|
||||
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
|
||||
model,
|
||||
balance,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
});
|
||||
@@ -185,7 +190,8 @@ class BaseClient {
|
||||
this.user = user;
|
||||
const saveOptions = this.getSaveOptions();
|
||||
this.abortController = opts.abortController ?? new AbortController();
|
||||
const conversationId = overrideConvoId ?? opts.conversationId ?? crypto.randomUUID();
|
||||
const requestConvoId = overrideConvoId ?? opts.conversationId;
|
||||
const conversationId = requestConvoId ?? crypto.randomUUID();
|
||||
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
|
||||
const userMessageId =
|
||||
overrideUserMessageId ?? opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
@@ -210,11 +216,12 @@ class BaseClient {
|
||||
...opts,
|
||||
user,
|
||||
head,
|
||||
saveOptions,
|
||||
userMessageId,
|
||||
requestConvoId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
userMessageId,
|
||||
responseMessageId,
|
||||
saveOptions,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -226,6 +233,7 @@ class BaseClient {
|
||||
sender: 'User',
|
||||
text,
|
||||
isCreatedByUser: true,
|
||||
targetModel: this.modelOptions?.model ?? this.model,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -233,11 +241,12 @@ class BaseClient {
|
||||
const {
|
||||
user,
|
||||
head,
|
||||
saveOptions,
|
||||
userMessageId,
|
||||
requestConvoId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
userMessageId,
|
||||
responseMessageId,
|
||||
saveOptions,
|
||||
} = await this.setMessageOptions(opts);
|
||||
|
||||
const userMessage = opts.isEdited
|
||||
@@ -259,7 +268,8 @@ class BaseClient {
|
||||
}
|
||||
|
||||
if (typeof opts?.onStart === 'function') {
|
||||
opts.onStart(userMessage, responseMessageId);
|
||||
const isNewConvo = !requestConvoId && parentMessageId === Constants.NO_PARENT;
|
||||
opts.onStart(userMessage, responseMessageId, isNewConvo);
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -565,6 +575,7 @@ class BaseClient {
|
||||
}
|
||||
|
||||
async sendMessage(message, opts = {}) {
|
||||
const appConfig = this.options.req?.config;
|
||||
/** @type {Promise<TMessage>} */
|
||||
let userMessagePromise;
|
||||
const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
|
||||
@@ -614,15 +625,19 @@ class BaseClient {
|
||||
this.currentMessages.push(userMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
* When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
|
||||
* this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
|
||||
*/
|
||||
const parentMessageId = isEdited ? head : userMessage.messageId;
|
||||
this.parentMessageId = parentMessageId;
|
||||
let {
|
||||
prompt: payload,
|
||||
tokenCountMap,
|
||||
promptTokens,
|
||||
} = await this.buildMessages(
|
||||
this.currentMessages,
|
||||
// When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
|
||||
// this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
|
||||
isEdited ? head : userMessage.messageId,
|
||||
parentMessageId,
|
||||
this.getBuildMessagesOptions(opts),
|
||||
opts,
|
||||
);
|
||||
@@ -647,9 +662,9 @@ class BaseClient {
|
||||
}
|
||||
}
|
||||
|
||||
const balance = this.options.req?.app?.locals?.balance;
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
if (
|
||||
balance?.enabled &&
|
||||
balanceConfig?.enabled &&
|
||||
supportsBalanceCheck[this.options.endpointType ?? this.options.endpoint]
|
||||
) {
|
||||
await checkBalance({
|
||||
@@ -748,6 +763,7 @@ class BaseClient {
|
||||
usage,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
balance: balanceConfig,
|
||||
model: responseMessage.model,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const { google } = require('googleapis');
|
||||
const { getModelMaxTokens } = require('@librechat/api');
|
||||
const { concat } = require('@langchain/core/utils/stream');
|
||||
const { ChatVertexAI } = require('@langchain/google-vertexai');
|
||||
const { Tokenizer, getSafetySettings } = require('@librechat/api');
|
||||
@@ -21,7 +22,6 @@ const {
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const {
|
||||
|
||||
@@ -7,7 +7,9 @@ const {
|
||||
createFetch,
|
||||
resolveHeaders,
|
||||
constructAzureURL,
|
||||
getModelMaxTokens,
|
||||
genAzureChatCompletion,
|
||||
getModelMaxOutputTokens,
|
||||
createStreamEventHandlers,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
@@ -31,16 +33,16 @@ const {
|
||||
titleInstruction,
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { addSpaceIfNeeded, sleep } = require('~/server/utils');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const { createLLM, RunManager } = require('./llm');
|
||||
const { summaryBuffer } = require('./memory');
|
||||
const { runTitleChain } = require('./chains');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const { tokenSplit } = require('./document');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { createLLM } = require('./llm');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
class OpenAIClient extends BaseClient {
|
||||
@@ -618,10 +620,6 @@ class OpenAIClient extends BaseClient {
|
||||
temperature = 0.2,
|
||||
max_tokens,
|
||||
streaming,
|
||||
context,
|
||||
tokenBuffer,
|
||||
initialMessageCount,
|
||||
conversationId,
|
||||
}) {
|
||||
const modelOptions = {
|
||||
modelName: modelName ?? model,
|
||||
@@ -653,8 +651,10 @@ class OpenAIClient extends BaseClient {
|
||||
if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
|
||||
configOptions.baseOptions = {
|
||||
headers: resolveHeaders({
|
||||
...headers,
|
||||
...configOptions?.baseOptions?.headers,
|
||||
headers: {
|
||||
...headers,
|
||||
...configOptions?.baseOptions?.headers,
|
||||
},
|
||||
}),
|
||||
};
|
||||
}
|
||||
@@ -664,22 +664,12 @@ class OpenAIClient extends BaseClient {
|
||||
configOptions.httpsAgent = new HttpsProxyAgent(this.options.proxy);
|
||||
}
|
||||
|
||||
const { req, res, debug } = this.options;
|
||||
const runManager = new RunManager({ req, res, debug, abortController: this.abortController });
|
||||
this.runManager = runManager;
|
||||
|
||||
const llm = createLLM({
|
||||
modelOptions,
|
||||
configOptions,
|
||||
openAIApiKey: this.apiKey,
|
||||
azure: this.azure,
|
||||
streaming,
|
||||
callbacks: runManager.createCallbacks({
|
||||
context,
|
||||
tokenBuffer,
|
||||
conversationId: this.conversationId ?? conversationId,
|
||||
initialMessageCount,
|
||||
}),
|
||||
});
|
||||
|
||||
return llm;
|
||||
@@ -700,6 +690,7 @@ class OpenAIClient extends BaseClient {
|
||||
* In case of failure, it will return the default title, "New Chat".
|
||||
*/
|
||||
async titleConvo({ text, conversationId, responseText = '' }) {
|
||||
const appConfig = this.options.req?.config;
|
||||
this.conversationId = conversationId;
|
||||
|
||||
if (this.options.attachments) {
|
||||
@@ -728,8 +719,7 @@ class OpenAIClient extends BaseClient {
|
||||
max_tokens: 16,
|
||||
};
|
||||
|
||||
/** @type {TAzureConfig | undefined} */
|
||||
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
|
||||
const azureConfig = appConfig?.endpoints?.[EModelEndpoint.azureOpenAI];
|
||||
|
||||
const resetTitleOptions = !!(
|
||||
(this.azure && azureConfig) ||
|
||||
@@ -749,7 +739,7 @@ class OpenAIClient extends BaseClient {
|
||||
groupMap,
|
||||
});
|
||||
|
||||
this.options.headers = resolveHeaders(headers);
|
||||
this.options.headers = resolveHeaders({ headers });
|
||||
this.options.reverseProxyUrl = baseURL ?? null;
|
||||
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
|
||||
this.apiKey = azureOptions.azureOpenAIApiKey;
|
||||
@@ -1118,6 +1108,7 @@ ${convo}
|
||||
}
|
||||
|
||||
async chatCompletion({ payload, onProgress, abortController = null }) {
|
||||
const appConfig = this.options.req?.config;
|
||||
let error = null;
|
||||
let intermediateReply = [];
|
||||
const errorCallback = (err) => (error = err);
|
||||
@@ -1163,8 +1154,7 @@ ${convo}
|
||||
opts.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy);
|
||||
}
|
||||
|
||||
/** @type {TAzureConfig | undefined} */
|
||||
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
|
||||
const azureConfig = appConfig?.endpoints?.[EModelEndpoint.azureOpenAI];
|
||||
|
||||
if (
|
||||
(this.azure && this.isVisionModel && azureConfig) ||
|
||||
@@ -1181,7 +1171,7 @@ ${convo}
|
||||
modelGroupMap,
|
||||
groupMap,
|
||||
});
|
||||
opts.defaultHeaders = resolveHeaders(headers);
|
||||
opts.defaultHeaders = resolveHeaders({ headers });
|
||||
this.langchainProxy = extractBaseURL(baseURL);
|
||||
this.apiKey = azureOptions.azureOpenAIApiKey;
|
||||
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
const { promptTokensEstimate } = require('openai-chat-tokens');
|
||||
const { EModelEndpoint, supportsBalanceCheck } = require('librechat-data-provider');
|
||||
const { formatFromLangChain } = require('~/app/clients/prompts');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const createStartHandler = ({
|
||||
context,
|
||||
conversationId,
|
||||
tokenBuffer = 0,
|
||||
initialMessageCount,
|
||||
manager,
|
||||
}) => {
|
||||
return async (_llm, _messages, runId, parentRunId, extraParams) => {
|
||||
const { invocation_params } = extraParams;
|
||||
const { model, functions, function_call } = invocation_params;
|
||||
const messages = _messages[0].map(formatFromLangChain);
|
||||
|
||||
logger.debug(`[createStartHandler] handleChatModelStart: ${context}`, {
|
||||
model,
|
||||
function_call,
|
||||
});
|
||||
|
||||
if (context !== 'title') {
|
||||
logger.debug(`[createStartHandler] handleChatModelStart: ${context}`, {
|
||||
functions,
|
||||
});
|
||||
}
|
||||
|
||||
const payload = { messages };
|
||||
let prelimPromptTokens = 1;
|
||||
|
||||
if (functions) {
|
||||
payload.functions = functions;
|
||||
prelimPromptTokens += 2;
|
||||
}
|
||||
|
||||
if (function_call) {
|
||||
payload.function_call = function_call;
|
||||
prelimPromptTokens -= 5;
|
||||
}
|
||||
|
||||
prelimPromptTokens += promptTokensEstimate(payload);
|
||||
logger.debug('[createStartHandler]', {
|
||||
prelimPromptTokens,
|
||||
tokenBuffer,
|
||||
});
|
||||
prelimPromptTokens += tokenBuffer;
|
||||
|
||||
try {
|
||||
const balance = await getBalanceConfig();
|
||||
if (balance?.enabled && supportsBalanceCheck[EModelEndpoint.openAI]) {
|
||||
const generations =
|
||||
initialMessageCount && messages.length > initialMessageCount
|
||||
? messages.slice(initialMessageCount)
|
||||
: null;
|
||||
await checkBalance({
|
||||
req: manager.req,
|
||||
res: manager.res,
|
||||
txData: {
|
||||
user: manager.user,
|
||||
tokenType: 'prompt',
|
||||
amount: prelimPromptTokens,
|
||||
debug: manager.debug,
|
||||
generations,
|
||||
model,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(`[createStartHandler][${context}] checkBalance error`, err);
|
||||
manager.abortController.abort();
|
||||
if (context === 'summary' || context === 'plugins') {
|
||||
manager.addRun(runId, { conversationId, error: err.message });
|
||||
throw new Error(err);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
manager.addRun(runId, {
|
||||
model,
|
||||
messages,
|
||||
functions,
|
||||
function_call,
|
||||
runId,
|
||||
parentRunId,
|
||||
conversationId,
|
||||
prelimPromptTokens,
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = createStartHandler;
|
||||
@@ -1,5 +0,0 @@
|
||||
const createStartHandler = require('./createStartHandler');
|
||||
|
||||
module.exports = {
|
||||
createStartHandler,
|
||||
};
|
||||
@@ -1,105 +0,0 @@
|
||||
const { createStartHandler } = require('~/app/clients/callbacks');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
class RunManager {
|
||||
constructor(fields) {
|
||||
const { req, res, abortController, debug } = fields;
|
||||
this.abortController = abortController;
|
||||
this.user = req.user.id;
|
||||
this.req = req;
|
||||
this.res = res;
|
||||
this.debug = debug;
|
||||
this.runs = new Map();
|
||||
this.convos = new Map();
|
||||
}
|
||||
|
||||
addRun(runId, runData) {
|
||||
if (!this.runs.has(runId)) {
|
||||
this.runs.set(runId, runData);
|
||||
if (runData.conversationId) {
|
||||
this.convos.set(runData.conversationId, runId);
|
||||
}
|
||||
return runData;
|
||||
} else {
|
||||
const existingData = this.runs.get(runId);
|
||||
const update = { ...existingData, ...runData };
|
||||
this.runs.set(runId, update);
|
||||
if (update.conversationId) {
|
||||
this.convos.set(update.conversationId, runId);
|
||||
}
|
||||
return update;
|
||||
}
|
||||
}
|
||||
|
||||
removeRun(runId) {
|
||||
if (this.runs.has(runId)) {
|
||||
this.runs.delete(runId);
|
||||
} else {
|
||||
logger.error(`[api/app/clients/llm/RunManager] Run with ID ${runId} does not exist.`);
|
||||
}
|
||||
}
|
||||
|
||||
getAllRuns() {
|
||||
return Array.from(this.runs.values());
|
||||
}
|
||||
|
||||
getRunById(runId) {
|
||||
return this.runs.get(runId);
|
||||
}
|
||||
|
||||
getRunByConversationId(conversationId) {
|
||||
const runId = this.convos.get(conversationId);
|
||||
return { run: this.runs.get(runId), runId };
|
||||
}
|
||||
|
||||
createCallbacks(metadata) {
|
||||
return [
|
||||
{
|
||||
handleChatModelStart: createStartHandler({ ...metadata, manager: this }),
|
||||
handleLLMEnd: async (output, runId, _parentRunId) => {
|
||||
const { llmOutput, ..._output } = output;
|
||||
logger.debug(`[RunManager] handleLLMEnd: ${JSON.stringify(metadata)}`, {
|
||||
runId,
|
||||
_parentRunId,
|
||||
llmOutput,
|
||||
});
|
||||
|
||||
if (metadata.context !== 'title') {
|
||||
logger.debug('[RunManager] handleLLMEnd:', {
|
||||
output: _output,
|
||||
});
|
||||
}
|
||||
|
||||
const { tokenUsage } = output.llmOutput;
|
||||
const run = this.getRunById(runId);
|
||||
this.removeRun(runId);
|
||||
|
||||
const txData = {
|
||||
user: this.user,
|
||||
model: run?.model ?? 'gpt-3.5-turbo',
|
||||
...metadata,
|
||||
};
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
},
|
||||
handleLLMError: async (err) => {
|
||||
logger.error(`[RunManager] handleLLMError: ${JSON.stringify(metadata)}`, err);
|
||||
if (metadata.context === 'title') {
|
||||
return;
|
||||
} else if (metadata.context === 'plugins') {
|
||||
throw new Error(err);
|
||||
}
|
||||
const { conversationId } = metadata;
|
||||
const { run } = this.getRunByConversationId(conversationId);
|
||||
if (run && run.error) {
|
||||
const { error } = run;
|
||||
throw new Error(error);
|
||||
}
|
||||
},
|
||||
},
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = RunManager;
|
||||
@@ -1,9 +1,7 @@
|
||||
const createLLM = require('./createLLM');
|
||||
const RunManager = require('./RunManager');
|
||||
const createCoherePayload = require('./createCoherePayload');
|
||||
|
||||
module.exports = {
|
||||
createLLM,
|
||||
RunManager,
|
||||
createCoherePayload,
|
||||
};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
const axios = require('axios');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||
const { isEnabled, generateShortLivedToken } = require('@librechat/api');
|
||||
|
||||
const footer = `Use the context as your learned knowledge to better answer the user.
|
||||
|
||||
|
||||
@@ -245,7 +245,7 @@ describe('AnthropicClient', () => {
|
||||
});
|
||||
|
||||
describe('Claude 4 model headers', () => {
|
||||
it('should add "prompt-caching" beta header for claude-sonnet-4 model', () => {
|
||||
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
model: 'claude-sonnet-4-20250514',
|
||||
@@ -255,10 +255,30 @@ describe('AnthropicClient', () => {
|
||||
expect(anthropicClient._options.defaultHeaders).toBeDefined();
|
||||
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31',
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model formats', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelVariations = [
|
||||
'claude-sonnet-4-20250514',
|
||||
'claude-sonnet-4-latest',
|
||||
'anthropic/claude-sonnet-4-20250514',
|
||||
];
|
||||
|
||||
modelVariations.forEach((model) => {
|
||||
const modelOptions = { model };
|
||||
client.setOptions({ modelOptions, promptCache: true });
|
||||
const anthropicClient = client.getClient(modelOptions);
|
||||
expect(anthropicClient._options.defaultHeaders).toBeDefined();
|
||||
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" beta header for claude-opus-4 model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
@@ -273,20 +293,6 @@ describe('AnthropicClient', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" beta header for claude-4-sonnet model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
model: 'claude-4-sonnet-20250514',
|
||||
};
|
||||
client.setOptions({ modelOptions, promptCache: true });
|
||||
const anthropicClient = client.getClient(modelOptions);
|
||||
expect(anthropicClient._options.defaultHeaders).toBeDefined();
|
||||
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31',
|
||||
);
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" beta header for claude-4-opus model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
|
||||
@@ -2,6 +2,14 @@ const { Constants } = require('librechat-data-provider');
|
||||
const { initializeFakeClient } = require('./FakeClient');
|
||||
|
||||
jest.mock('~/db/connect');
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn().mockResolvedValue({
|
||||
// Default app config for tests
|
||||
paths: { uploads: '/tmp' },
|
||||
fileStrategy: 'local',
|
||||
memory: { disabled: false },
|
||||
}),
|
||||
}));
|
||||
jest.mock('~/models', () => ({
|
||||
User: jest.fn(),
|
||||
Key: jest.fn(),
|
||||
@@ -579,6 +587,8 @@ describe('BaseClient', () => {
|
||||
expect(onStart).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ text: 'Hello, world!' }),
|
||||
expect.any(String),
|
||||
/** `isNewConvo` */
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const { getModelMaxTokens } = require('@librechat/api');
|
||||
const BaseClient = require('../BaseClient');
|
||||
const { getModelMaxTokens } = require('../../../utils');
|
||||
|
||||
class FakeClient extends BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const availableTools = require('./manifest.json');
|
||||
const manifest = require('./manifest');
|
||||
|
||||
// Structured Tools
|
||||
const DALLE3 = require('./structured/DALLE3');
|
||||
@@ -13,23 +13,8 @@ const TraversaalSearch = require('./structured/TraversaalSearch');
|
||||
const createOpenAIImageTools = require('./structured/OpenAIImageTools');
|
||||
const TavilySearchResults = require('./structured/TavilySearchResults');
|
||||
|
||||
/** @type {Record<string, TPlugin | undefined>} */
|
||||
const manifestToolMap = {};
|
||||
|
||||
/** @type {Array<TPlugin>} */
|
||||
const toolkits = [];
|
||||
|
||||
availableTools.forEach((tool) => {
|
||||
manifestToolMap[tool.pluginKey] = tool;
|
||||
if (tool.toolkit === true) {
|
||||
toolkits.push(tool);
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = {
|
||||
toolkits,
|
||||
availableTools,
|
||||
manifestToolMap,
|
||||
...manifest,
|
||||
// Structured Tools
|
||||
DALLE3,
|
||||
FluxAPI,
|
||||
|
||||
20
api/app/clients/tools/manifest.js
Normal file
20
api/app/clients/tools/manifest.js
Normal file
@@ -0,0 +1,20 @@
|
||||
const availableTools = require('./manifest.json');
|
||||
|
||||
/** @type {Record<string, TPlugin | undefined>} */
|
||||
const manifestToolMap = {};
|
||||
|
||||
/** @type {Array<TPlugin>} */
|
||||
const toolkits = [];
|
||||
|
||||
availableTools.forEach((tool) => {
|
||||
manifestToolMap[tool.pluginKey] = tool;
|
||||
if (tool.toolkit === true) {
|
||||
toolkits.push(tool);
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = {
|
||||
toolkits,
|
||||
availableTools,
|
||||
manifestToolMap,
|
||||
};
|
||||
@@ -49,7 +49,7 @@
|
||||
"pluginKey": "image_gen_oai",
|
||||
"toolkit": true,
|
||||
"description": "Image Generation and Editing using OpenAI's latest state-of-the-art models",
|
||||
"icon": "/assets/image_gen_oai.png",
|
||||
"icon": "assets/image_gen_oai.png",
|
||||
"authConfig": [
|
||||
{
|
||||
"authField": "IMAGE_GEN_OAI_API_KEY",
|
||||
@@ -75,7 +75,7 @@
|
||||
"name": "Browser",
|
||||
"pluginKey": "web-browser",
|
||||
"description": "Scrape and summarize webpage data",
|
||||
"icon": "/assets/web-browser.svg",
|
||||
"icon": "assets/web-browser.svg",
|
||||
"authConfig": [
|
||||
{
|
||||
"authField": "OPENAI_API_KEY",
|
||||
@@ -170,7 +170,7 @@
|
||||
"name": "OpenWeather",
|
||||
"pluginKey": "open_weather",
|
||||
"description": "Get weather forecasts and historical data from the OpenWeather API",
|
||||
"icon": "/assets/openweather.png",
|
||||
"icon": "assets/openweather.png",
|
||||
"authConfig": [
|
||||
{
|
||||
"authField": "OPENWEATHER_API_KEY",
|
||||
|
||||
@@ -5,10 +5,10 @@ const fetch = require('node-fetch');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { Tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getImageBasename } = require('@librechat/api');
|
||||
const { FileContext, ContentTypes } = require('librechat-data-provider');
|
||||
const { getImageBasename } = require('~/server/services/Files/images');
|
||||
const extractBaseURL = require('~/utils/extractBaseURL');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
const displayMessage =
|
||||
"DALL-E displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.";
|
||||
|
||||
@@ -1,69 +1,16 @@
|
||||
const { z } = require('zod');
|
||||
const axios = require('axios');
|
||||
const { v4 } = require('uuid');
|
||||
const OpenAI = require('openai');
|
||||
const FormData = require('form-data');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logAxiosError, oaiToolkit } = require('@librechat/api');
|
||||
const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const extractBaseURL = require('~/utils/extractBaseURL');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
||||
/** Default descriptions for image generation tool */
|
||||
const DEFAULT_IMAGE_GEN_DESCRIPTION = `
|
||||
Generates high-quality, original images based solely on text, not using any uploaded reference images.
|
||||
|
||||
When to use \`image_gen_oai\`:
|
||||
- To create entirely new images from detailed text descriptions that do NOT reference any image files.
|
||||
|
||||
When NOT to use \`image_gen_oai\`:
|
||||
- If the user has uploaded any images and requests modifications, enhancements, or remixing based on those uploads → use \`image_edit_oai\` instead.
|
||||
|
||||
Generated image IDs will be returned in the response, so you can refer to them in future requests made to \`image_edit_oai\`.
|
||||
`.trim();
|
||||
|
||||
/** Default description for image editing tool */
|
||||
const DEFAULT_IMAGE_EDIT_DESCRIPTION =
|
||||
`Generates high-quality, original images based on text and one or more uploaded/referenced images.
|
||||
|
||||
When to use \`image_edit_oai\`:
|
||||
- The user wants to modify, extend, or remix one **or more** uploaded images, either:
|
||||
- Previously generated, or in the current request (both to be included in the \`image_ids\` array).
|
||||
- Always when the user refers to uploaded images for editing, enhancement, remixing, style transfer, or combining elements.
|
||||
- Any current or existing images are to be used as visual guides.
|
||||
- If there are any files in the current request, they are more likely than not expected as references for image edit requests.
|
||||
|
||||
When NOT to use \`image_edit_oai\`:
|
||||
- Brand-new generations that do not rely on an existing image → use \`image_gen_oai\` instead.
|
||||
|
||||
Both generated and referenced image IDs will be returned in the response, so you can refer to them in future requests made to \`image_edit_oai\`.
|
||||
`.trim();
|
||||
|
||||
/** Default prompt descriptions */
|
||||
const DEFAULT_IMAGE_GEN_PROMPT_DESCRIPTION = `Describe the image you want in detail.
|
||||
Be highly specific—break your idea into layers:
|
||||
(1) main concept and subject,
|
||||
(2) composition and position,
|
||||
(3) lighting and mood,
|
||||
(4) style, medium, or camera details,
|
||||
(5) important features (age, expression, clothing, etc.),
|
||||
(6) background.
|
||||
Use positive, descriptive language and specify what should be included, not what to avoid.
|
||||
List number and characteristics of people/objects, and mention style/technical requirements (e.g., "DSLR photo, 85mm lens, golden hour").
|
||||
Do not reference any uploaded images—use for new image creation from text only.`;
|
||||
|
||||
const DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION = `Describe the changes, enhancements, or new ideas to apply to the uploaded image(s).
|
||||
Be highly specific—break your request into layers:
|
||||
(1) main concept or transformation,
|
||||
(2) specific edits/replacements or composition guidance,
|
||||
(3) desired style, mood, or technique,
|
||||
(4) features/items to keep, change, or add (such as objects, people, clothing, lighting, etc.).
|
||||
Use positive, descriptive language and clarify what should be included or changed, not what to avoid.
|
||||
Always base this prompt on the most recently uploaded reference images.`;
|
||||
|
||||
const displayMessage =
|
||||
"The tool displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.";
|
||||
|
||||
@@ -91,22 +38,6 @@ function returnValue(value) {
|
||||
return value;
|
||||
}
|
||||
|
||||
const getImageGenDescription = () => {
|
||||
return process.env.IMAGE_GEN_OAI_DESCRIPTION || DEFAULT_IMAGE_GEN_DESCRIPTION;
|
||||
};
|
||||
|
||||
const getImageEditDescription = () => {
|
||||
return process.env.IMAGE_EDIT_OAI_DESCRIPTION || DEFAULT_IMAGE_EDIT_DESCRIPTION;
|
||||
};
|
||||
|
||||
const getImageGenPromptDescription = () => {
|
||||
return process.env.IMAGE_GEN_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_GEN_PROMPT_DESCRIPTION;
|
||||
};
|
||||
|
||||
const getImageEditPromptDescription = () => {
|
||||
return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
|
||||
};
|
||||
|
||||
function createAbortHandler() {
|
||||
return function () {
|
||||
logger.debug('[ImageGenOAI] Image generation aborted');
|
||||
@@ -121,7 +52,9 @@ function createAbortHandler() {
|
||||
* @param {string} fields.IMAGE_GEN_OAI_API_KEY - The OpenAI API key
|
||||
* @param {boolean} [fields.override] - Whether to override the API key check, necessary for app initialization
|
||||
* @param {MongoFile[]} [fields.imageFiles] - The images to be used for editing
|
||||
* @returns {Array} - Array of image tools
|
||||
* @param {string} [fields.imageOutputType] - The image output type configuration
|
||||
* @param {string} [fields.fileStrategy] - The file storage strategy
|
||||
* @returns {Array<ReturnType<tool>>} - Array of image tools
|
||||
*/
|
||||
function createOpenAIImageTools(fields = {}) {
|
||||
/** @type {boolean} Used to initialize the Tool without necessary variables. */
|
||||
@@ -131,8 +64,8 @@ function createOpenAIImageTools(fields = {}) {
|
||||
throw new Error('This tool is only available for agents.');
|
||||
}
|
||||
const { req } = fields;
|
||||
const imageOutputType = req?.app.locals.imageOutputType || EImageOutputType.PNG;
|
||||
const appFileStrategy = req?.app.locals.fileStrategy;
|
||||
const imageOutputType = fields.imageOutputType || EImageOutputType.PNG;
|
||||
const appFileStrategy = fields.fileStrategy;
|
||||
|
||||
const getApiKey = () => {
|
||||
const apiKey = process.env.IMAGE_GEN_OAI_API_KEY ?? '';
|
||||
@@ -285,46 +218,7 @@ Error Message: ${error.message}`);
|
||||
];
|
||||
return [response, { content, file_ids }];
|
||||
},
|
||||
{
|
||||
name: 'image_gen_oai',
|
||||
description: getImageGenDescription(),
|
||||
schema: z.object({
|
||||
prompt: z.string().max(32000).describe(getImageGenPromptDescription()),
|
||||
background: z
|
||||
.enum(['transparent', 'opaque', 'auto'])
|
||||
.optional()
|
||||
.describe(
|
||||
'Sets transparency for the background. Must be one of transparent, opaque or auto (default). When transparent, the output format should be png or webp.',
|
||||
),
|
||||
/*
|
||||
n: z
|
||||
.number()
|
||||
.int()
|
||||
.min(1)
|
||||
.max(10)
|
||||
.optional()
|
||||
.describe('The number of images to generate. Must be between 1 and 10.'),
|
||||
output_compression: z
|
||||
.number()
|
||||
.int()
|
||||
.min(0)
|
||||
.max(100)
|
||||
.optional()
|
||||
.describe('The compression level (0-100%) for webp or jpeg formats. Defaults to 100.'),
|
||||
*/
|
||||
quality: z
|
||||
.enum(['auto', 'high', 'medium', 'low'])
|
||||
.optional()
|
||||
.describe('The quality of the image. One of auto (default), high, medium, or low.'),
|
||||
size: z
|
||||
.enum(['auto', '1024x1024', '1536x1024', '1024x1536'])
|
||||
.optional()
|
||||
.describe(
|
||||
'The size of the generated image. One of 1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), or auto (default).',
|
||||
),
|
||||
}),
|
||||
responseFormat: 'content_and_artifact',
|
||||
},
|
||||
oaiToolkit.image_gen_oai,
|
||||
);
|
||||
|
||||
/**
|
||||
@@ -517,48 +411,7 @@ Error Message: ${error.message || 'Unknown error'}`);
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'image_edit_oai',
|
||||
description: getImageEditDescription(),
|
||||
schema: z.object({
|
||||
image_ids: z
|
||||
.array(z.string())
|
||||
.min(1)
|
||||
.describe(
|
||||
`
|
||||
IDs (image ID strings) of previously generated or uploaded images that should guide the edit.
|
||||
|
||||
Guidelines:
|
||||
- If the user's request depends on any prior image(s), copy their image IDs into the \`image_ids\` array (in the same order the user refers to them).
|
||||
- Never invent or hallucinate IDs; only use IDs that are still visible in the conversation context.
|
||||
- If no earlier image is relevant, omit the field entirely.
|
||||
`.trim(),
|
||||
),
|
||||
prompt: z.string().max(32000).describe(getImageEditPromptDescription()),
|
||||
/*
|
||||
n: z
|
||||
.number()
|
||||
.int()
|
||||
.min(1)
|
||||
.max(10)
|
||||
.optional()
|
||||
.describe('The number of images to generate. Must be between 1 and 10. Defaults to 1.'),
|
||||
*/
|
||||
quality: z
|
||||
.enum(['auto', 'high', 'medium', 'low'])
|
||||
.optional()
|
||||
.describe(
|
||||
'The quality of the image. One of auto (default), high, medium, or low. High/medium/low only supported for gpt-image-1.',
|
||||
),
|
||||
size: z
|
||||
.enum(['auto', '1024x1024', '1536x1024', '1024x1536', '256x256', '512x512'])
|
||||
.optional()
|
||||
.describe(
|
||||
'The size of the generated images. For gpt-image-1: auto (default), 1024x1024, 1536x1024, 1024x1536. For dall-e-2: 256x256, 512x512, 1024x1024.',
|
||||
),
|
||||
}),
|
||||
responseFormat: 'content_and_artifact',
|
||||
},
|
||||
oaiToolkit.image_edit_oai,
|
||||
);
|
||||
|
||||
return [imageGenTool, imageEditTool];
|
||||
|
||||
@@ -11,14 +11,14 @@ const paths = require('~/config/paths');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const displayMessage =
|
||||
'Stable Diffusion displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.';
|
||||
"Stable Diffusion displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.";
|
||||
|
||||
class StableDiffusionAPI extends Tool {
|
||||
constructor(fields) {
|
||||
super();
|
||||
/** @type {string} User ID */
|
||||
this.userId = fields.userId;
|
||||
/** @type {Express.Request | undefined} Express Request object, only provided by ToolService */
|
||||
/** @type {ServerRequest | undefined} Express Request object, only provided by ToolService */
|
||||
this.req = fields.req;
|
||||
/** @type {boolean} Used to initialize the Tool without necessary variables. */
|
||||
this.override = fields.override ?? false;
|
||||
@@ -44,7 +44,7 @@ class StableDiffusionAPI extends Tool {
|
||||
// "negative_prompt":"semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, out of frame, low quality, ugly, mutation, deformed"
|
||||
// - Generate images only once per human query unless explicitly requested by the user`;
|
||||
this.description =
|
||||
'You can generate images using text with \'stable-diffusion\'. This tool is exclusively for visual content.';
|
||||
"You can generate images using text with 'stable-diffusion'. This tool is exclusively for visual content.";
|
||||
this.schema = z.object({
|
||||
prompt: z
|
||||
.string()
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
const { z } = require('zod');
|
||||
const { ytToolkit } = require('@librechat/api');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { youtube } = require('@googleapis/youtube');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { YoutubeTranscript } = require('youtube-transcript');
|
||||
const { getApiKey } = require('./credentials');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
function extractVideoId(url) {
|
||||
const rawIdRegex = /^[a-zA-Z0-9_-]{11}$/;
|
||||
@@ -29,7 +29,7 @@ function parseTranscript(transcriptResponse) {
|
||||
.map((entry) => entry.text.trim())
|
||||
.filter((text) => text)
|
||||
.join(' ')
|
||||
.replaceAll('&#39;', '\'');
|
||||
.replaceAll('&#39;', "'");
|
||||
}
|
||||
|
||||
function createYouTubeTools(fields = {}) {
|
||||
@@ -42,160 +42,94 @@ function createYouTubeTools(fields = {}) {
|
||||
auth: apiKey,
|
||||
});
|
||||
|
||||
const searchTool = tool(
|
||||
async ({ query, maxResults = 5 }) => {
|
||||
const response = await youtubeClient.search.list({
|
||||
part: 'snippet',
|
||||
q: query,
|
||||
type: 'video',
|
||||
maxResults: maxResults || 5,
|
||||
});
|
||||
const result = response.data.items.map((item) => ({
|
||||
title: item.snippet.title,
|
||||
description: item.snippet.description,
|
||||
url: `https://www.youtube.com/watch?v=${item.id.videoId}`,
|
||||
}));
|
||||
return JSON.stringify(result, null, 2);
|
||||
},
|
||||
{
|
||||
name: 'youtube_search',
|
||||
description: `Search for YouTube videos by keyword or phrase.
|
||||
- Required: query (search terms to find videos)
|
||||
- Optional: maxResults (number of videos to return, 1-50, default: 5)
|
||||
- Returns: List of videos with titles, descriptions, and URLs
|
||||
- Use for: Finding specific videos, exploring content, research
|
||||
Example: query="cooking pasta tutorials" maxResults=3`,
|
||||
schema: z.object({
|
||||
query: z.string().describe('Search query terms'),
|
||||
maxResults: z.number().int().min(1).max(50).optional().describe('Number of results (1-50)'),
|
||||
}),
|
||||
},
|
||||
);
|
||||
const searchTool = tool(async ({ query, maxResults = 5 }) => {
|
||||
const response = await youtubeClient.search.list({
|
||||
part: 'snippet',
|
||||
q: query,
|
||||
type: 'video',
|
||||
maxResults: maxResults || 5,
|
||||
});
|
||||
const result = response.data.items.map((item) => ({
|
||||
title: item.snippet.title,
|
||||
description: item.snippet.description,
|
||||
url: `https://www.youtube.com/watch?v=${item.id.videoId}`,
|
||||
}));
|
||||
return JSON.stringify(result, null, 2);
|
||||
}, ytToolkit.youtube_search);
|
||||
|
||||
const infoTool = tool(
|
||||
async ({ url }) => {
|
||||
const videoId = extractVideoId(url);
|
||||
if (!videoId) {
|
||||
throw new Error('Invalid YouTube URL or video ID');
|
||||
}
|
||||
const infoTool = tool(async ({ url }) => {
|
||||
const videoId = extractVideoId(url);
|
||||
if (!videoId) {
|
||||
throw new Error('Invalid YouTube URL or video ID');
|
||||
}
|
||||
|
||||
const response = await youtubeClient.videos.list({
|
||||
part: 'snippet,statistics',
|
||||
id: videoId,
|
||||
});
|
||||
const response = await youtubeClient.videos.list({
|
||||
part: 'snippet,statistics',
|
||||
id: videoId,
|
||||
});
|
||||
|
||||
if (!response.data.items?.length) {
|
||||
throw new Error('Video not found');
|
||||
}
|
||||
const video = response.data.items[0];
|
||||
if (!response.data.items?.length) {
|
||||
throw new Error('Video not found');
|
||||
}
|
||||
const video = response.data.items[0];
|
||||
|
||||
const result = {
|
||||
title: video.snippet.title,
|
||||
description: video.snippet.description,
|
||||
views: video.statistics.viewCount,
|
||||
likes: video.statistics.likeCount,
|
||||
comments: video.statistics.commentCount,
|
||||
};
|
||||
return JSON.stringify(result, null, 2);
|
||||
},
|
||||
{
|
||||
name: 'youtube_info',
|
||||
description: `Get detailed metadata and statistics for a specific YouTube video.
|
||||
- Required: url (full YouTube URL or video ID)
|
||||
- Returns: Video title, description, view count, like count, comment count
|
||||
- Use for: Getting video metrics and basic metadata
|
||||
- DO NOT USE FOR VIDEO SUMMARIES, USE TRANSCRIPTS FOR COMPREHENSIVE ANALYSIS
|
||||
- Accepts both full URLs and video IDs
|
||||
Example: url="https://youtube.com/watch?v=abc123" or url="abc123"`,
|
||||
schema: z.object({
|
||||
url: z.string().describe('YouTube video URL or ID'),
|
||||
}),
|
||||
},
|
||||
);
|
||||
const result = {
|
||||
title: video.snippet.title,
|
||||
description: video.snippet.description,
|
||||
views: video.statistics.viewCount,
|
||||
likes: video.statistics.likeCount,
|
||||
comments: video.statistics.commentCount,
|
||||
};
|
||||
return JSON.stringify(result, null, 2);
|
||||
}, ytToolkit.youtube_info);
|
||||
|
||||
const commentsTool = tool(
|
||||
async ({ url, maxResults = 10 }) => {
|
||||
const videoId = extractVideoId(url);
|
||||
if (!videoId) {
|
||||
throw new Error('Invalid YouTube URL or video ID');
|
||||
}
|
||||
const commentsTool = tool(async ({ url, maxResults = 10 }) => {
|
||||
const videoId = extractVideoId(url);
|
||||
if (!videoId) {
|
||||
throw new Error('Invalid YouTube URL or video ID');
|
||||
}
|
||||
|
||||
const response = await youtubeClient.commentThreads.list({
|
||||
part: 'snippet',
|
||||
videoId,
|
||||
maxResults: maxResults || 10,
|
||||
});
|
||||
const response = await youtubeClient.commentThreads.list({
|
||||
part: 'snippet',
|
||||
videoId,
|
||||
maxResults: maxResults || 10,
|
||||
});
|
||||
|
||||
const result = response.data.items.map((item) => ({
|
||||
author: item.snippet.topLevelComment.snippet.authorDisplayName,
|
||||
text: item.snippet.topLevelComment.snippet.textDisplay,
|
||||
likes: item.snippet.topLevelComment.snippet.likeCount,
|
||||
}));
|
||||
return JSON.stringify(result, null, 2);
|
||||
},
|
||||
{
|
||||
name: 'youtube_comments',
|
||||
description: `Retrieve top-level comments from a YouTube video.
|
||||
- Required: url (full YouTube URL or video ID)
|
||||
- Optional: maxResults (number of comments, 1-50, default: 10)
|
||||
- Returns: Comment text, author names, like counts
|
||||
- Use for: Sentiment analysis, audience feedback, engagement review
|
||||
Example: url="abc123" maxResults=20`,
|
||||
schema: z.object({
|
||||
url: z.string().describe('YouTube video URL or ID'),
|
||||
maxResults: z
|
||||
.number()
|
||||
.int()
|
||||
.min(1)
|
||||
.max(50)
|
||||
.optional()
|
||||
.describe('Number of comments to retrieve'),
|
||||
}),
|
||||
},
|
||||
);
|
||||
const result = response.data.items.map((item) => ({
|
||||
author: item.snippet.topLevelComment.snippet.authorDisplayName,
|
||||
text: item.snippet.topLevelComment.snippet.textDisplay,
|
||||
likes: item.snippet.topLevelComment.snippet.likeCount,
|
||||
}));
|
||||
return JSON.stringify(result, null, 2);
|
||||
}, ytToolkit.youtube_comments);
|
||||
|
||||
const transcriptTool = tool(
|
||||
async ({ url }) => {
|
||||
const videoId = extractVideoId(url);
|
||||
if (!videoId) {
|
||||
throw new Error('Invalid YouTube URL or video ID');
|
||||
const transcriptTool = tool(async ({ url }) => {
|
||||
const videoId = extractVideoId(url);
|
||||
if (!videoId) {
|
||||
throw new Error('Invalid YouTube URL or video ID');
|
||||
}
|
||||
|
||||
try {
|
||||
try {
|
||||
const transcript = await YoutubeTranscript.fetchTranscript(videoId, { lang: 'en' });
|
||||
return parseTranscript(transcript);
|
||||
} catch (e) {
|
||||
logger.error(e);
|
||||
}
|
||||
|
||||
try {
|
||||
try {
|
||||
const transcript = await YoutubeTranscript.fetchTranscript(videoId, { lang: 'en' });
|
||||
return parseTranscript(transcript);
|
||||
} catch (e) {
|
||||
logger.error(e);
|
||||
}
|
||||
|
||||
try {
|
||||
const transcript = await YoutubeTranscript.fetchTranscript(videoId, { lang: 'de' });
|
||||
return parseTranscript(transcript);
|
||||
} catch (e) {
|
||||
logger.error(e);
|
||||
}
|
||||
|
||||
const transcript = await YoutubeTranscript.fetchTranscript(videoId);
|
||||
const transcript = await YoutubeTranscript.fetchTranscript(videoId, { lang: 'de' });
|
||||
return parseTranscript(transcript);
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to fetch transcript: ${error.message}`);
|
||||
} catch (e) {
|
||||
logger.error(e);
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'youtube_transcript',
|
||||
description: `Fetch and parse the transcript/captions of a YouTube video.
|
||||
- Required: url (full YouTube URL or video ID)
|
||||
- Returns: Full video transcript as plain text
|
||||
- Use for: Content analysis, summarization, translation reference
|
||||
- This is the "Go-to" tool for analyzing actual video content
|
||||
- Attempts to fetch English first, then German, then any available language
|
||||
Example: url="https://youtube.com/watch?v=abc123"`,
|
||||
schema: z.object({
|
||||
url: z.string().describe('YouTube video URL or ID'),
|
||||
}),
|
||||
},
|
||||
);
|
||||
|
||||
const transcript = await YoutubeTranscript.fetchTranscript(videoId);
|
||||
return parseTranscript(transcript);
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to fetch transcript: ${error.message}`);
|
||||
}
|
||||
}, ytToolkit.youtube_transcript);
|
||||
|
||||
return [searchTool, infoTool, commentsTool, transcriptTool];
|
||||
}
|
||||
|
||||
@@ -1,43 +1,9 @@
|
||||
const DALLE3 = require('../DALLE3');
|
||||
const { ProxyAgent } = require('undici');
|
||||
|
||||
jest.mock('tiktoken');
|
||||
const processFileURL = jest.fn();
|
||||
|
||||
jest.mock('~/server/services/Files/images', () => ({
|
||||
getImageBasename: jest.fn().mockImplementation((url) => {
|
||||
const parts = url.split('/');
|
||||
const lastPart = parts.pop();
|
||||
const imageExtensionRegex = /\.(jpg|jpeg|png|gif|bmp|tiff|svg)$/i;
|
||||
if (imageExtensionRegex.test(lastPart)) {
|
||||
return lastPart;
|
||||
}
|
||||
return '';
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('fs', () => {
|
||||
return {
|
||||
existsSync: jest.fn(),
|
||||
mkdirSync: jest.fn(),
|
||||
promises: {
|
||||
writeFile: jest.fn(),
|
||||
readFile: jest.fn(),
|
||||
unlink: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('path', () => {
|
||||
return {
|
||||
resolve: jest.fn(),
|
||||
join: jest.fn(),
|
||||
relative: jest.fn(),
|
||||
extname: jest.fn().mockImplementation((filename) => {
|
||||
return filename.slice(filename.lastIndexOf('.'));
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
describe('DALLE3 Proxy Configuration', () => {
|
||||
let originalEnv;
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
const OpenAI = require('openai');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const DALLE3 = require('../DALLE3');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
jest.mock('openai');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => {
|
||||
return {
|
||||
logger: {
|
||||
@@ -26,25 +25,6 @@ jest.mock('tiktoken', () => {
|
||||
|
||||
const processFileURL = jest.fn();
|
||||
|
||||
jest.mock('~/server/services/Files/images', () => ({
|
||||
getImageBasename: jest.fn().mockImplementation((url) => {
|
||||
// Split the URL by '/'
|
||||
const parts = url.split('/');
|
||||
|
||||
// Get the last part of the URL
|
||||
const lastPart = parts.pop();
|
||||
|
||||
// Check if the last part of the URL matches the image extension regex
|
||||
const imageExtensionRegex = /\.(jpg|jpeg|png|gif|bmp|tiff|svg)$/i;
|
||||
if (imageExtensionRegex.test(lastPart)) {
|
||||
return lastPart;
|
||||
}
|
||||
|
||||
// If the regex test fails, return an empty string
|
||||
return '';
|
||||
}),
|
||||
}));
|
||||
|
||||
const generate = jest.fn();
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
|
||||
@@ -2,9 +2,9 @@ const { z } = require('zod');
|
||||
const axios = require('axios');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateShortLivedToken } = require('@librechat/api');
|
||||
const { Tools, EToolResources } = require('librechat-data-provider');
|
||||
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
||||
/**
|
||||
@@ -71,9 +71,10 @@ const primeFiles = async (options) => {
|
||||
* @param {ServerRequest} options.req
|
||||
* @param {Array<{ file_id: string; filename: string }>} options.files
|
||||
* @param {string} [options.entity_id]
|
||||
* @param {boolean} [options.fileCitations=false] - Whether to include citation instructions
|
||||
* @returns
|
||||
*/
|
||||
const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
const createFileSearchTool = async ({ req, files, entity_id, fileCitations = false }) => {
|
||||
return tool(
|
||||
async ({ query }) => {
|
||||
if (files.length === 0) {
|
||||
@@ -142,9 +143,9 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
const formattedString = formattedResults
|
||||
.map(
|
||||
(result, index) =>
|
||||
`File: ${result.filename}\nAnchor: \\ue202turn0file${index} (${result.filename})\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${
|
||||
result.content
|
||||
}\n`,
|
||||
`File: ${result.filename}${
|
||||
fileCitations ? `\nAnchor: \\ue202turn0file${index} (${result.filename})` : ''
|
||||
}\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${result.content}\n`,
|
||||
)
|
||||
.join('\n---\n');
|
||||
|
||||
@@ -158,12 +159,14 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
pageRelevance: result.page ? { [result.page]: 1.0 - result.distance } : {},
|
||||
}));
|
||||
|
||||
return [formattedString, { [Tools.file_search]: { sources } }];
|
||||
return [formattedString, { [Tools.file_search]: { sources, fileCitations } }];
|
||||
},
|
||||
{
|
||||
name: Tools.file_search,
|
||||
responseFormat: 'content_and_artifact',
|
||||
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.
|
||||
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.${
|
||||
fileCitations
|
||||
? `
|
||||
|
||||
**CITE FILE SEARCH RESULTS:**
|
||||
Use anchor markers immediately after statements derived from file content. Reference the filename in your text:
|
||||
@@ -171,7 +174,9 @@ Use anchor markers immediately after statements derived from file content. Refer
|
||||
- Page reference: "According to report.docx... \\ue202turn0file1"
|
||||
- Multi-file: "Multiple sources confirm... \\ue200\\ue202turn0file0\\ue202turn0file1\\ue201"
|
||||
|
||||
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`,
|
||||
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`
|
||||
: ''
|
||||
}`,
|
||||
schema: z.object({
|
||||
query: z
|
||||
.string()
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SerpAPI } = require('@langchain/community/tools/serpapi');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
|
||||
const { mcpToolPattern, loadWebSearchAuth, checkAccess } = require('@librechat/api');
|
||||
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
||||
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
Permissions,
|
||||
EToolResources,
|
||||
PermissionTypes,
|
||||
replaceSpecialVars,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
availableTools,
|
||||
manifestToolMap,
|
||||
@@ -24,9 +31,10 @@ const {
|
||||
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
|
||||
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { createMCPTool } = require('~/server/services/MCP');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
/**
|
||||
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
|
||||
@@ -121,27 +129,37 @@ const getAuthFields = (toolKey) => {
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {object} object
|
||||
* @param {string} object.user
|
||||
* @param {Pick<Agent, 'id' | 'provider' | 'model'>} [object.agent]
|
||||
* @param {string} [object.model]
|
||||
* @param {EModelEndpoint} [object.endpoint]
|
||||
* @param {LoadToolOptions} [object.options]
|
||||
* @param {boolean} [object.useSpecs]
|
||||
* @param {Array<string>} object.tools
|
||||
* @param {boolean} [object.functions]
|
||||
* @param {boolean} [object.returnMap]
|
||||
* @param {object} params
|
||||
* @param {string} params.user
|
||||
* @param {Record<string, Record<string, string>>} [object.userMCPAuthMap]
|
||||
* @param {AbortSignal} [object.signal]
|
||||
* @param {Pick<Agent, 'id' | 'provider' | 'model'>} [params.agent]
|
||||
* @param {string} [params.model]
|
||||
* @param {EModelEndpoint} [params.endpoint]
|
||||
* @param {LoadToolOptions} [params.options]
|
||||
* @param {boolean} [params.useSpecs]
|
||||
* @param {Array<string>} params.tools
|
||||
* @param {boolean} [params.functions]
|
||||
* @param {boolean} [params.returnMap]
|
||||
* @param {AppConfig['webSearch']} [params.webSearch]
|
||||
* @param {AppConfig['fileStrategy']} [params.fileStrategy]
|
||||
* @param {AppConfig['imageOutputType']} [params.imageOutputType]
|
||||
* @returns {Promise<{ loadedTools: Tool[], toolContextMap: Object<string, any> } | Record<string,Tool>>}
|
||||
*/
|
||||
const loadTools = async ({
|
||||
user,
|
||||
agent,
|
||||
model,
|
||||
signal,
|
||||
endpoint,
|
||||
userMCPAuthMap,
|
||||
tools = [],
|
||||
options = {},
|
||||
functions = true,
|
||||
returnMap = false,
|
||||
webSearch,
|
||||
fileStrategy,
|
||||
imageOutputType,
|
||||
}) => {
|
||||
const toolConstructors = {
|
||||
flux: FluxAPI,
|
||||
@@ -200,6 +218,8 @@ const loadTools = async ({
|
||||
...authValues,
|
||||
isAgent: !!agent,
|
||||
req: options.req,
|
||||
imageOutputType,
|
||||
fileStrategy,
|
||||
imageFiles,
|
||||
});
|
||||
},
|
||||
@@ -215,7 +235,7 @@ const loadTools = async ({
|
||||
const imageGenOptions = {
|
||||
isAgent: !!agent,
|
||||
req: options.req,
|
||||
fileStrategy: options.fileStrategy,
|
||||
fileStrategy,
|
||||
processFileURL: options.processFileURL,
|
||||
returnMetadata: options.returnMetadata,
|
||||
uploadImageBuffer: options.uploadImageBuffer,
|
||||
@@ -231,6 +251,7 @@ const loadTools = async ({
|
||||
/** @type {Record<string, string>} */
|
||||
const toolContextMap = {};
|
||||
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
|
||||
const requestedMCPTools = {};
|
||||
|
||||
for (const tool of tools) {
|
||||
if (tool === Tools.execute_code) {
|
||||
@@ -268,15 +289,36 @@ const loadTools = async ({
|
||||
if (toolContext) {
|
||||
toolContextMap[tool] = toolContext;
|
||||
}
|
||||
return createFileSearchTool({ req: options.req, files, entity_id: agent?.id });
|
||||
|
||||
/** @type {boolean | undefined} Check if user has FILE_CITATIONS permission */
|
||||
let fileCitations;
|
||||
if (fileCitations == null && options.req?.user != null) {
|
||||
try {
|
||||
fileCitations = await checkAccess({
|
||||
user: options.req.user,
|
||||
permissionType: PermissionTypes.FILE_CITATIONS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[handleTools] FILE_CITATIONS permission check failed:', error);
|
||||
fileCitations = false;
|
||||
}
|
||||
}
|
||||
|
||||
return createFileSearchTool({
|
||||
req: options.req,
|
||||
files,
|
||||
entity_id: agent?.id,
|
||||
fileCitations,
|
||||
});
|
||||
};
|
||||
continue;
|
||||
} else if (tool === Tools.web_search) {
|
||||
const webSearchConfig = options?.req?.app?.locals?.webSearch;
|
||||
const result = await loadWebSearchAuth({
|
||||
userId: user,
|
||||
loadAuthValues,
|
||||
webSearchConfig,
|
||||
webSearchConfig: webSearch,
|
||||
});
|
||||
const { onSearchResults, onGetHighlights } = options?.[Tools.web_search] ?? {};
|
||||
requestedTools[tool] = async () => {
|
||||
@@ -299,14 +341,45 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
};
|
||||
continue;
|
||||
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
|
||||
requestedTools[tool] = async () =>
|
||||
const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
|
||||
if (toolName === Constants.mcp_server) {
|
||||
/** Placeholder used for UI purposes */
|
||||
continue;
|
||||
}
|
||||
if (serverName && options.req?.config?.mcpConfig?.[serverName] == null) {
|
||||
logger.warn(
|
||||
`MCP server "${serverName}" for "${toolName}" tool is not configured${agent?.id != null && agent.id ? ` but attached to "${agent.id}"` : ''}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (toolName === Constants.mcp_all) {
|
||||
const currentMCPGenerator = async (index) =>
|
||||
createMCPTools({
|
||||
req: options.req,
|
||||
res: options.res,
|
||||
index,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
model: agent?.model ?? model,
|
||||
provider: agent?.provider ?? endpoint,
|
||||
signal,
|
||||
});
|
||||
requestedMCPTools[serverName] = [currentMCPGenerator];
|
||||
continue;
|
||||
}
|
||||
const currentMCPGenerator = async (index) =>
|
||||
createMCPTool({
|
||||
index,
|
||||
req: options.req,
|
||||
res: options.res,
|
||||
toolKey: tool,
|
||||
userMCPAuthMap,
|
||||
model: agent?.model ?? model,
|
||||
provider: agent?.provider ?? endpoint,
|
||||
signal,
|
||||
});
|
||||
requestedMCPTools[serverName] = requestedMCPTools[serverName] || [];
|
||||
requestedMCPTools[serverName].push(currentMCPGenerator);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -346,6 +419,34 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
}
|
||||
|
||||
const loadedTools = (await Promise.all(toolPromises)).flatMap((plugin) => plugin || []);
|
||||
const mcpToolPromises = [];
|
||||
/** MCP server tools are initialized sequentially by server */
|
||||
let index = -1;
|
||||
for (const [serverName, generators] of Object.entries(requestedMCPTools)) {
|
||||
index++;
|
||||
for (const generator of generators) {
|
||||
try {
|
||||
if (generator && generators.length === 1) {
|
||||
mcpToolPromises.push(
|
||||
generator(index).catch((error) => {
|
||||
logger.error(`Error loading ${serverName} tools:`, error);
|
||||
return null;
|
||||
}),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
const mcpTool = await generator(index);
|
||||
if (Array.isArray(mcpTool)) {
|
||||
loadedTools.push(...mcpTool);
|
||||
} else if (mcpTool) {
|
||||
loadedTools.push(mcpTool);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error loading MCP tool for server ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
}
|
||||
loadedTools.push(...(await Promise.all(mcpToolPromises)).flatMap((plugin) => plugin || []));
|
||||
return { loadedTools, toolContextMap };
|
||||
};
|
||||
|
||||
|
||||
@@ -9,6 +9,27 @@ const mockPluginService = {
|
||||
|
||||
jest.mock('~/server/services/PluginService', () => mockPluginService);
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getAppConfig: jest.fn().mockResolvedValue({
|
||||
// Default app config for tool tests
|
||||
paths: { uploads: '/tmp' },
|
||||
fileStrategy: 'local',
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
}),
|
||||
getCachedTools: jest.fn().mockResolvedValue({
|
||||
// Default cached tools for tests
|
||||
dalle: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'dalle',
|
||||
description: 'DALL-E image generation',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
const { BaseLLM } = require('@langchain/openai');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
|
||||
|
||||
6
api/cache/cacheConfig.js
vendored
6
api/cache/cacheConfig.js
vendored
@@ -52,7 +52,11 @@ const cacheConfig = {
|
||||
REDIS_CONNECT_TIMEOUT: math(process.env.REDIS_CONNECT_TIMEOUT, 10000),
|
||||
/** Queue commands when disconnected */
|
||||
REDIS_ENABLE_OFFLINE_QUEUE: isEnabled(process.env.REDIS_ENABLE_OFFLINE_QUEUE ?? 'true'),
|
||||
|
||||
/** flag to modify redis connection by adding dnsLookup this is required when connecting to elasticache for ioredis
|
||||
* see "Special Note: Aws Elasticache Clusters with TLS" on this webpage: https://www.npmjs.com/package/ioredis **/
|
||||
REDIS_USE_ALTERNATIVE_DNS_LOOKUP: isEnabled(process.env.REDIS_USE_ALTERNATIVE_DNS_LOOKUP),
|
||||
/** Enable redis cluster without the need of multiple URIs */
|
||||
USE_REDIS_CLUSTER: isEnabled(process.env.USE_REDIS_CLUSTER ?? 'false'),
|
||||
CI: isEnabled(process.env.CI),
|
||||
DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
|
||||
|
||||
|
||||
36
api/cache/cacheConfig.spec.js
vendored
36
api/cache/cacheConfig.spec.js
vendored
@@ -14,6 +14,7 @@ describe('cacheConfig', () => {
|
||||
delete process.env.REDIS_KEY_PREFIX_VAR;
|
||||
delete process.env.REDIS_KEY_PREFIX;
|
||||
delete process.env.USE_REDIS;
|
||||
delete process.env.USE_REDIS_CLUSTER;
|
||||
delete process.env.REDIS_PING_INTERVAL;
|
||||
delete process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES;
|
||||
|
||||
@@ -101,6 +102,38 @@ describe('cacheConfig', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('USE_REDIS_CLUSTER configuration', () => {
|
||||
test('should default to false when USE_REDIS_CLUSTER is not set', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.USE_REDIS_CLUSTER).toBe(false);
|
||||
});
|
||||
|
||||
test('should be false when USE_REDIS_CLUSTER is set to false', () => {
|
||||
process.env.USE_REDIS_CLUSTER = 'false';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.USE_REDIS_CLUSTER).toBe(false);
|
||||
});
|
||||
|
||||
test('should be true when USE_REDIS_CLUSTER is set to true', () => {
|
||||
process.env.USE_REDIS_CLUSTER = 'true';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.USE_REDIS_CLUSTER).toBe(true);
|
||||
});
|
||||
|
||||
test('should work with USE_REDIS enabled and REDIS_URI set', () => {
|
||||
process.env.USE_REDIS_CLUSTER = 'true';
|
||||
process.env.USE_REDIS = 'true';
|
||||
process.env.REDIS_URI = 'redis://localhost:6379';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.USE_REDIS_CLUSTER).toBe(true);
|
||||
expect(cacheConfig.USE_REDIS).toBe(true);
|
||||
expect(cacheConfig.REDIS_URI).toBe('redis://localhost:6379');
|
||||
});
|
||||
});
|
||||
|
||||
describe('REDIS_CA file reading', () => {
|
||||
test('should be null when REDIS_CA is not set', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
@@ -124,12 +157,11 @@ describe('cacheConfig', () => {
|
||||
|
||||
describe('FORCED_IN_MEMORY_CACHE_NAMESPACES validation', () => {
|
||||
test('should parse comma-separated cache keys correctly', () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, STATIC_CONFIG ,MESSAGES ';
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, MESSAGES ';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([
|
||||
'ROLES',
|
||||
'STATIC_CONFIG',
|
||||
'MESSAGES',
|
||||
]);
|
||||
});
|
||||
|
||||
3
api/cache/getLogStores.js
vendored
3
api/cache/getLogStores.js
vendored
@@ -31,9 +31,8 @@ const namespaces = {
|
||||
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
|
||||
|
||||
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
|
||||
[CacheKeys.MCP_TOOLS]: standardCache(CacheKeys.MCP_TOOLS),
|
||||
[CacheKeys.APP_CONFIG]: standardCache(CacheKeys.APP_CONFIG),
|
||||
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
|
||||
[CacheKeys.STATIC_CONFIG]: standardCache(CacheKeys.STATIC_CONFIG),
|
||||
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
|
||||
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
|
||||
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),
|
||||
|
||||
48
api/cache/redisClients.js
vendored
48
api/cache/redisClients.js
vendored
@@ -38,7 +38,7 @@ if (cacheConfig.USE_REDIS) {
|
||||
const targetError = 'READONLY';
|
||||
if (err.message.includes(targetError)) {
|
||||
logger.warn('ioredis reconnecting due to READONLY error');
|
||||
return true;
|
||||
return 2; // Return retry delay instead of boolean
|
||||
}
|
||||
return false;
|
||||
},
|
||||
@@ -48,26 +48,32 @@ if (cacheConfig.USE_REDIS) {
|
||||
};
|
||||
|
||||
ioredisClient =
|
||||
urls.length === 1
|
||||
urls.length === 1 && !cacheConfig.USE_REDIS_CLUSTER
|
||||
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
|
||||
: new IoRedis.Cluster(cacheConfig.REDIS_URI, {
|
||||
redisOptions,
|
||||
clusterRetryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis cluster giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis cluster reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
: new IoRedis.Cluster(
|
||||
urls.map((url) => ({ host: url.hostname, port: parseInt(url.port, 10) || 6379 })),
|
||||
{
|
||||
...(cacheConfig.REDIS_USE_ALTERNATIVE_DNS_LOOKUP
|
||||
? { dnsLookup: (address, callback) => callback(null, address) }
|
||||
: {}),
|
||||
redisOptions,
|
||||
clusterRetryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis cluster giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis cluster reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
});
|
||||
);
|
||||
|
||||
ioredisClient.on('error', (err) => {
|
||||
logger.error('ioredis client error:', err);
|
||||
@@ -145,10 +151,10 @@ if (cacheConfig.USE_REDIS) {
|
||||
};
|
||||
|
||||
keyvRedisClient =
|
||||
urls.length === 1
|
||||
urls.length === 1 && !cacheConfig.USE_REDIS_CLUSTER
|
||||
? createClient({ url: cacheConfig.REDIS_URI, ...redisOptions })
|
||||
: createCluster({
|
||||
rootNodes: cacheConfig.REDIS_URI.split(',').map((url) => ({ url })),
|
||||
rootNodes: urls.map((url) => ({ url: url.href })),
|
||||
defaults: redisOptions,
|
||||
});
|
||||
|
||||
|
||||
@@ -1,27 +1,13 @@
|
||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||
const { EventSource } = require('eventsource');
|
||||
const { Time } = require('librechat-data-provider');
|
||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||
const logger = require('./winston');
|
||||
|
||||
global.EventSource = EventSource;
|
||||
|
||||
/** @type {MCPManager} */
|
||||
let mcpManager = null;
|
||||
let flowManager = null;
|
||||
|
||||
/**
|
||||
* @param {string} [userId] - Optional user ID, to avoid disconnecting the current user.
|
||||
* @returns {MCPManager}
|
||||
*/
|
||||
function getMCPManager(userId) {
|
||||
if (!mcpManager) {
|
||||
mcpManager = MCPManager.getInstance();
|
||||
} else {
|
||||
mcpManager.checkIdleConnections(userId);
|
||||
}
|
||||
return mcpManager;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Keyv} flowsCache
|
||||
* @returns {FlowStateManager}
|
||||
@@ -37,6 +23,7 @@ function getFlowStateManager(flowsCache) {
|
||||
|
||||
module.exports = {
|
||||
logger,
|
||||
getMCPManager,
|
||||
createMCPManager: MCPManager.createInstance,
|
||||
getMCPManager: MCPManager.getInstance,
|
||||
getFlowStateManager,
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@ const mongoose = require('mongoose');
|
||||
const crypto = require('node:crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ResourceType, SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } =
|
||||
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_all, mcp_delimiter } =
|
||||
require('librechat-data-provider').Constants;
|
||||
const {
|
||||
removeAgentFromAllProjects,
|
||||
@@ -78,6 +78,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
|
||||
const addedServers = new Set();
|
||||
if (mcpServers.size > 0) {
|
||||
for (const toolName of Object.keys(availableTools)) {
|
||||
if (!toolName.includes(mcp_delimiter)) {
|
||||
@@ -85,9 +86,17 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
}
|
||||
const mcpServer = toolName.split(mcp_delimiter)?.[1];
|
||||
if (mcpServer && mcpServers.has(mcpServer)) {
|
||||
addedServers.add(mcpServer);
|
||||
tools.push(toolName);
|
||||
}
|
||||
}
|
||||
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
}
|
||||
}
|
||||
|
||||
const instructions = req.body.promptPrefix;
|
||||
@@ -672,7 +681,7 @@ const getListAgents = async (searchParameter) => {
|
||||
* This function also updates the corresponding projects to include or exclude the agent ID.
|
||||
*
|
||||
* @param {Object} params - Parameters for updating the agent's projects.
|
||||
* @param {MongoUser} params.user - Parameters for updating the agent's projects.
|
||||
* @param {IUser} params.user - Parameters for updating the agent's projects.
|
||||
* @param {string} params.agentId - The ID of the agent to update.
|
||||
* @param {string[]} [params.projectIds] - Array of project IDs to add to the agent.
|
||||
* @param {string[]} [params.removeProjectIds] - Array of project IDs to remove from the agent.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
const { Conversation } = require('~/db/models');
|
||||
|
||||
@@ -102,8 +101,8 @@ module.exports = {
|
||||
|
||||
if (req?.body?.isTemporary) {
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
update.expiredAt = createTempChatExpirationDate(customConfig);
|
||||
const appConfig = req.config;
|
||||
update.expiredAt = createTempChatExpirationDate(appConfig?.interfaceConfig);
|
||||
} catch (err) {
|
||||
logger.error('Error creating temporary chat expiration date:', err);
|
||||
logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
|
||||
@@ -113,8 +112,17 @@ module.exports = {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
/** @type {{ $set: Partial<TConversation>; $unset?: Record<keyof TConversation, number> }} */
|
||||
/** @type {{ $set: Partial<TConversation>; $addToSet?: Record<string, any>; $unset?: Record<keyof TConversation, number> }} */
|
||||
const updateOperation = { $set: update };
|
||||
|
||||
if (convo.model && convo.endpoint) {
|
||||
updateOperation.$addToSet = {
|
||||
modelHistory: {
|
||||
model: convo.model,
|
||||
endpoint: convo.endpoint,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) {
|
||||
updateOperation.$unset = metadata.unsetFields;
|
||||
}
|
||||
|
||||
@@ -13,9 +13,8 @@ const {
|
||||
saveConvo,
|
||||
getConvo,
|
||||
} = require('./Conversation');
|
||||
jest.mock('~/server/services/Config/getCustomConfig');
|
||||
jest.mock('~/server/services/Config/app');
|
||||
jest.mock('./Message');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
|
||||
const { Conversation } = require('~/db/models');
|
||||
@@ -50,6 +49,11 @@ describe('Conversation Operations', () => {
|
||||
mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {},
|
||||
config: {
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 24, // Default 24 hours
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockConversationData = {
|
||||
@@ -118,12 +122,8 @@ describe('Conversation Operations', () => {
|
||||
|
||||
describe('isTemporary conversation handling', () => {
|
||||
it('should save a conversation with expiredAt when isTemporary is true', async () => {
|
||||
// Mock custom config with 24 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
// Mock app config with 24 hour retention
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 24;
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -167,12 +167,8 @@ describe('Conversation Operations', () => {
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
// Mock custom config with 48 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 48,
|
||||
},
|
||||
});
|
||||
// Mock app config with 48 hour retention
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 48;
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -194,12 +190,8 @@ describe('Conversation Operations', () => {
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock custom config with less than minimum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 0.5, // Half hour - should be clamped to 1 hour
|
||||
},
|
||||
});
|
||||
// Mock app config with less than minimum retention
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 0.5; // Half hour - should be clamped to 1 hour
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -221,12 +213,8 @@ describe('Conversation Operations', () => {
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock custom config with more than maximum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 10000, // Should be clamped to 8760 hours
|
||||
},
|
||||
});
|
||||
// Mock app config with more than maximum retention
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 10000; // Should be clamped to 8760 hours
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -247,22 +235,36 @@ describe('Conversation Operations', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle getCustomConfig errors gracefully', async () => {
|
||||
// Mock getCustomConfig to throw an error
|
||||
getCustomConfig.mockRejectedValue(new Error('Config service unavailable'));
|
||||
it('should handle missing config gracefully', async () => {
|
||||
// Simulate missing config - should use default retention period
|
||||
delete mockReq.config;
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
const afterSave = new Date();
|
||||
|
||||
// Should still save the conversation but with expiredAt as null
|
||||
// Should still save the conversation with default retention period (30 days)
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
expect(result.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 30 days in the future (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 720 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 720 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getCustomConfig to return empty config
|
||||
getCustomConfig.mockResolvedValue({});
|
||||
// Mock getAppConfig to return empty config
|
||||
mockReq.config = {}; // Empty config
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -285,11 +287,7 @@ describe('Conversation Operations', () => {
|
||||
|
||||
it('should update expiredAt when saving existing temporary conversation', async () => {
|
||||
// First save a temporary conversation
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 24;
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const firstSave = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
@@ -211,7 +211,67 @@ describe('File Access Control', () => {
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should deny access when user only has VIEW permission', async () => {
|
||||
it('should deny access when user only has VIEW permission and needs access for deletion', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent with files
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'View-Only Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access for files
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId,
|
||||
isDelete: true,
|
||||
});
|
||||
|
||||
// Should have no access to any files when only VIEW permission
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should grant access when user has VIEW permission', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
@@ -265,9 +325,8 @@ describe('File Access Control', () => {
|
||||
agentId,
|
||||
});
|
||||
|
||||
// Should have no access to any files when only VIEW permission
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
const { z } = require('zod');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { Message } = require('~/db/models');
|
||||
|
||||
const idSchema = z.string().uuid();
|
||||
@@ -11,7 +10,7 @@ const idSchema = z.string().uuid();
|
||||
*
|
||||
* @async
|
||||
* @function saveMessage
|
||||
* @param {Express.Request} req - The request object containing user information.
|
||||
* @param {ServerRequest} req - The request object containing user information.
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.iconURL - The URL of the sender's icon.
|
||||
@@ -57,8 +56,8 @@ async function saveMessage(req, params, metadata) {
|
||||
|
||||
if (req?.body?.isTemporary) {
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
update.expiredAt = createTempChatExpirationDate(customConfig);
|
||||
const appConfig = req.config;
|
||||
update.expiredAt = createTempChatExpirationDate(appConfig?.interfaceConfig);
|
||||
} catch (err) {
|
||||
logger.error('Error creating temporary chat expiration date:', err);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
|
||||
@@ -13,8 +13,7 @@ const {
|
||||
deleteMessagesSince,
|
||||
} = require('./Message');
|
||||
|
||||
jest.mock('~/server/services/Config/getCustomConfig');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
jest.mock('~/server/services/Config/app');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IMessage>}
|
||||
@@ -44,6 +43,11 @@ describe('Message Operations', () => {
|
||||
|
||||
mockReq = {
|
||||
user: { id: 'user123' },
|
||||
config: {
|
||||
interfaceConfig: {
|
||||
temporaryChatRetention: 24, // Default 24 hours
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockMessageData = {
|
||||
@@ -326,12 +330,8 @@ describe('Message Operations', () => {
|
||||
});
|
||||
|
||||
it('should save a message with expiredAt when isTemporary is true', async () => {
|
||||
// Mock custom config with 24 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
// Mock app config with 24 hour retention
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 24;
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -375,12 +375,8 @@ describe('Message Operations', () => {
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
// Mock custom config with 48 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 48,
|
||||
},
|
||||
});
|
||||
// Mock app config with 48 hour retention
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 48;
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -402,12 +398,8 @@ describe('Message Operations', () => {
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock custom config with less than minimum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 0.5, // Half hour - should be clamped to 1 hour
|
||||
},
|
||||
});
|
||||
// Mock app config with less than minimum retention
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 0.5; // Half hour - should be clamped to 1 hour
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -429,12 +421,8 @@ describe('Message Operations', () => {
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock custom config with more than maximum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 10000, // Should be clamped to 8760 hours
|
||||
},
|
||||
});
|
||||
// Mock app config with more than maximum retention
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 10000; // Should be clamped to 8760 hours
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -455,22 +443,36 @@ describe('Message Operations', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle getCustomConfig errors gracefully', async () => {
|
||||
// Mock getCustomConfig to throw an error
|
||||
getCustomConfig.mockRejectedValue(new Error('Config service unavailable'));
|
||||
it('should handle missing config gracefully', async () => {
|
||||
// Simulate missing config - should use default retention period
|
||||
delete mockReq.config;
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
const afterSave = new Date();
|
||||
|
||||
// Should still save the message but with expiredAt as null
|
||||
// Should still save the message with default retention period (30 days)
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
expect(result.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 30 days in the future (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 720 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 720 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getCustomConfig to return empty config
|
||||
getCustomConfig.mockResolvedValue({});
|
||||
// Mock getAppConfig to return empty config
|
||||
mockReq.config = {}; // Empty config
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
@@ -493,11 +495,7 @@ describe('Message Operations', () => {
|
||||
|
||||
it('should not update expiredAt on message update', async () => {
|
||||
// First save a temporary message
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 24;
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const savedMessage = await saveMessage(mockReq, mockMessageData);
|
||||
@@ -520,11 +518,7 @@ describe('Message Operations', () => {
|
||||
|
||||
it('should preserve expiredAt when saving existing temporary message', async () => {
|
||||
// First save a temporary message
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
mockReq.config.interfaceConfig.temporaryChatRetention = 24;
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const firstSave = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
@@ -269,7 +269,7 @@ async function getListPromptGroupsByAccess({
|
||||
const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };
|
||||
|
||||
// Add cursor condition
|
||||
if (after) {
|
||||
if (after && typeof after === 'string' && after !== 'undefined' && after !== 'null') {
|
||||
try {
|
||||
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
|
||||
const { updatedAt, _id } = cursor;
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
const { Transaction, Balance } = require('~/db/models');
|
||||
|
||||
@@ -187,20 +186,23 @@ async function createAutoRefillTransaction(txData) {
|
||||
|
||||
/**
|
||||
* Static method to create a transaction and update the balance
|
||||
* @param {txData} txData - Transaction data.
|
||||
* @param {txData} _txData - Transaction data.
|
||||
*/
|
||||
async function createTransaction(txData) {
|
||||
async function createTransaction(_txData) {
|
||||
const { balance, transactions, ...txData } = _txData;
|
||||
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (transactions?.enabled === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const transaction = new Transaction(txData);
|
||||
transaction.endpointTokenConfig = txData.endpointTokenConfig;
|
||||
calculateTokenValue(transaction);
|
||||
|
||||
await transaction.save();
|
||||
|
||||
const balance = await getBalanceConfig();
|
||||
if (!balance?.enabled) {
|
||||
return;
|
||||
}
|
||||
@@ -221,9 +223,14 @@ async function createTransaction(txData) {
|
||||
|
||||
/**
|
||||
* Static method to create a structured transaction and update the balance
|
||||
* @param {txData} txData - Transaction data.
|
||||
* @param {txData} _txData - Transaction data.
|
||||
*/
|
||||
async function createStructuredTransaction(txData) {
|
||||
async function createStructuredTransaction(_txData) {
|
||||
const { balance, transactions, ...txData } = _txData;
|
||||
if (transactions?.enabled === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const transaction = new Transaction({
|
||||
...txData,
|
||||
endpointTokenConfig: txData.endpointTokenConfig,
|
||||
@@ -233,7 +240,6 @@ async function createStructuredTransaction(txData) {
|
||||
|
||||
await transaction.save();
|
||||
|
||||
const balance = await getBalanceConfig();
|
||||
if (!balance?.enabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const { getMultiplier, getCacheMultiplier } = require('./tx');
|
||||
const { createTransaction } = require('./Transaction');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
// Mock the custom config module so we can control the balance flag.
|
||||
jest.mock('~/server/services/Config');
|
||||
const { createTransaction, createStructuredTransaction } = require('./Transaction');
|
||||
const { Balance, Transaction } = require('~/db/models');
|
||||
|
||||
let mongoServer;
|
||||
beforeAll(async () => {
|
||||
@@ -23,8 +19,6 @@ afterAll(async () => {
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.connection.dropDatabase();
|
||||
// Default: enable balance updates in tests.
|
||||
getBalanceConfig.mockResolvedValue({ enabled: true });
|
||||
});
|
||||
|
||||
describe('Regular Token Spending Tests', () => {
|
||||
@@ -41,6 +35,7 @@ describe('Regular Token Spending Tests', () => {
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
@@ -74,6 +69,7 @@ describe('Regular Token Spending Tests', () => {
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
@@ -104,6 +100,7 @@ describe('Regular Token Spending Tests', () => {
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {};
|
||||
@@ -128,6 +125,7 @@ describe('Regular Token Spending Tests', () => {
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = { promptTokens: 100 };
|
||||
@@ -143,8 +141,7 @@ describe('Regular Token Spending Tests', () => {
|
||||
});
|
||||
|
||||
test('spendTokens should not update balance when balance feature is disabled', async () => {
|
||||
// Arrange: Override the config to disable balance updates.
|
||||
getBalanceConfig.mockResolvedValue({ balance: { enabled: false } });
|
||||
// Arrange: Balance config is now passed directly in txData
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
@@ -156,6 +153,7 @@ describe('Regular Token Spending Tests', () => {
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
@@ -186,6 +184,7 @@ describe('Structured Token Spending Tests', () => {
|
||||
model,
|
||||
context: 'message',
|
||||
endpointTokenConfig: null,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
@@ -239,6 +238,7 @@ describe('Structured Token Spending Tests', () => {
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'message',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
@@ -271,6 +271,7 @@ describe('Structured Token Spending Tests', () => {
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'message',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
@@ -302,6 +303,7 @@ describe('Structured Token Spending Tests', () => {
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'message',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {};
|
||||
@@ -328,6 +330,7 @@ describe('Structured Token Spending Tests', () => {
|
||||
conversationId: 'test-convo',
|
||||
model,
|
||||
context: 'incomplete',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
@@ -364,6 +367,7 @@ describe('NaN Handling Tests', () => {
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: NaN,
|
||||
tokenType: 'prompt',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
// Act
|
||||
@@ -375,3 +379,188 @@ describe('NaN Handling Tests', () => {
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Transactions Config Tests', () => {
|
||||
test('createTransaction should not save when transactions.enabled is false', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(0);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createTransaction should save when transactions.enabled is true', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created
|
||||
expect(result).toBeDefined();
|
||||
expect(result.balance).toBeLessThan(initialBalance);
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].rawAmount).toBe(-100);
|
||||
});
|
||||
|
||||
test('createTransaction should save when balance.enabled is true even if transactions config is missing', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
balance: { enabled: true },
|
||||
// No transactions config provided
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created (backward compatibility)
|
||||
expect(result).toBeDefined();
|
||||
expect(result.balance).toBeLessThan(initialBalance);
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
});
|
||||
|
||||
test('createTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'gpt-3.5-turbo';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'test',
|
||||
endpointTokenConfig: null,
|
||||
rawAmount: -100,
|
||||
tokenType: 'prompt',
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created but balance unchanged
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].rawAmount).toBe(-100);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createStructuredTransaction should not save when transactions.enabled is false', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'message',
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
transactions: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createStructuredTransaction(txData);
|
||||
|
||||
// Assert: No transaction should be created
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(0);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
|
||||
test('createStructuredTransaction should save transaction but not update balance when balance is disabled but transactions enabled', async () => {
|
||||
// Arrange
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const initialBalance = 10000000;
|
||||
await Balance.create({ user: userId, tokenCredits: initialBalance });
|
||||
|
||||
const model = 'claude-3-5-sonnet';
|
||||
const txData = {
|
||||
user: userId,
|
||||
conversationId: 'test-conversation-id',
|
||||
model,
|
||||
context: 'message',
|
||||
tokenType: 'prompt',
|
||||
inputTokens: -10,
|
||||
writeTokens: -100,
|
||||
readTokens: -5,
|
||||
transactions: { enabled: true },
|
||||
balance: { enabled: false },
|
||||
};
|
||||
|
||||
// Act
|
||||
const result = await createStructuredTransaction(txData);
|
||||
|
||||
// Assert: Transaction should be created but balance unchanged
|
||||
expect(result).toBeUndefined();
|
||||
const transactions = await Transaction.find({ user: userId });
|
||||
expect(transactions).toHaveLength(1);
|
||||
expect(transactions[0].inputTokens).toBe(-10);
|
||||
expect(transactions[0].writeTokens).toBe(-100);
|
||||
expect(transactions[0].readTokens).toBe(-5);
|
||||
const balance = await Balance.findOne({ user: userId });
|
||||
expect(balance.tokenCredits).toBe(initialBalance);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -118,7 +118,7 @@ const addIntervalToDate = (date, value, unit) => {
|
||||
* @async
|
||||
* @function
|
||||
* @param {Object} params - The function parameters.
|
||||
* @param {Express.Request} params.req - The Express request object.
|
||||
* @param {ServerRequest} params.req - The Express request object.
|
||||
* @param {Express.Response} params.res - The Express response object.
|
||||
* @param {Object} params.txData - The transaction data.
|
||||
* @param {string} params.txData.user - The user ID or identifier.
|
||||
|
||||
@@ -1,47 +1,9 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { buildTree } = require('librechat-data-provider');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { getMessages, bulkSaveMessages } = require('./Message');
|
||||
const { Message } = require('~/db/models');
|
||||
|
||||
// Original version of buildTree function
|
||||
function buildTree({ messages, fileMap }) {
|
||||
if (messages === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const messageMap = {};
|
||||
const rootMessages = [];
|
||||
const childrenCount = {};
|
||||
|
||||
messages.forEach((message) => {
|
||||
const parentId = message.parentMessageId ?? '';
|
||||
childrenCount[parentId] = (childrenCount[parentId] || 0) + 1;
|
||||
|
||||
const extendedMessage = {
|
||||
...message,
|
||||
children: [],
|
||||
depth: 0,
|
||||
siblingIndex: childrenCount[parentId] - 1,
|
||||
};
|
||||
|
||||
if (message.files && fileMap) {
|
||||
extendedMessage.files = message.files.map((file) => fileMap[file.file_id ?? ''] ?? file);
|
||||
}
|
||||
|
||||
messageMap[message.messageId] = extendedMessage;
|
||||
|
||||
const parentMessage = messageMap[parentId];
|
||||
if (parentMessage) {
|
||||
parentMessage.children.push(extendedMessage);
|
||||
extendedMessage.depth = parentMessage.depth + 1;
|
||||
} else {
|
||||
rootMessages.push(extendedMessage);
|
||||
}
|
||||
});
|
||||
|
||||
return rootMessages;
|
||||
}
|
||||
|
||||
let mongod;
|
||||
beforeAll(async () => {
|
||||
mongod = await MongoMemoryServer.create();
|
||||
|
||||
@@ -24,8 +24,15 @@ const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversa
|
||||
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
||||
const { File } = require('~/db/models');
|
||||
|
||||
const seedDatabase = async () => {
|
||||
await methods.initializeRoles();
|
||||
await methods.seedDefaultRoles();
|
||||
await methods.ensureDefaultCategories();
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
...methods,
|
||||
seedDatabase,
|
||||
comparePassword,
|
||||
findFileById,
|
||||
createFile,
|
||||
|
||||
24
api/models/interface.js
Normal file
24
api/models/interface.js
Normal file
@@ -0,0 +1,24 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { updateInterfacePermissions: updateInterfacePerms } = require('@librechat/api');
|
||||
const { getRoleByName, updateAccessPermissions } = require('./Role');
|
||||
|
||||
/**
|
||||
* Update interface permissions based on app configuration.
|
||||
* Must be done independently from loading the app config.
|
||||
* @param {AppConfig} appConfig
|
||||
*/
|
||||
async function updateInterfacePermissions(appConfig) {
|
||||
try {
|
||||
await updateInterfacePerms({
|
||||
appConfig,
|
||||
getRoleByName,
|
||||
updateAccessPermissions,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error updating interface permissions:', error);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
updateInterfacePermissions,
|
||||
};
|
||||
@@ -5,13 +5,7 @@ const { createTransaction, createStructuredTransaction } = require('./Transactio
|
||||
*
|
||||
* @function
|
||||
* @async
|
||||
* @param {Object} txData - Transaction data.
|
||||
* @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID.
|
||||
* @param {String} txData.conversationId - The ID of the conversation.
|
||||
* @param {String} txData.model - The model name.
|
||||
* @param {String} txData.context - The context in which the transaction is made.
|
||||
* @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
|
||||
* @param {String} [txData.valueKey] - The value key (optional).
|
||||
* @param {txData} txData - Transaction data.
|
||||
* @param {Object} tokenUsage - The number of tokens used.
|
||||
* @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.
|
||||
* @param {Number} tokenUsage.completionTokens - The number of completion tokens used.
|
||||
@@ -69,13 +63,7 @@ const spendTokens = async (txData, tokenUsage) => {
|
||||
*
|
||||
* @function
|
||||
* @async
|
||||
* @param {Object} txData - Transaction data.
|
||||
* @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID.
|
||||
* @param {String} txData.conversationId - The ID of the conversation.
|
||||
* @param {String} txData.model - The model name.
|
||||
* @param {String} txData.context - The context in which the transaction is made.
|
||||
* @param {EndpointTokenConfig} [txData.endpointTokenConfig] - The current endpoint token config.
|
||||
* @param {String} [txData.valueKey] - The value key (optional).
|
||||
* @param {txData} txData - Transaction data.
|
||||
* @param {Object} tokenUsage - The number of tokens used.
|
||||
* @param {Object} tokenUsage.promptTokens - The number of prompt tokens used.
|
||||
* @param {Number} tokenUsage.promptTokens.input - The number of input tokens.
|
||||
|
||||
@@ -5,7 +5,6 @@ const { createTransaction, createAutoRefillTransaction } = require('./Transactio
|
||||
|
||||
require('~/db/models');
|
||||
|
||||
// Mock the logger to prevent console output during tests
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
@@ -13,10 +12,6 @@ jest.mock('~/config', () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock the Config service
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
jest.mock('~/server/services/Config');
|
||||
|
||||
describe('spendTokens', () => {
|
||||
let mongoServer;
|
||||
let userId;
|
||||
@@ -44,8 +39,7 @@ describe('spendTokens', () => {
|
||||
// Create a new user ID for each test
|
||||
userId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Mock the balance config to be enabled by default
|
||||
getBalanceConfig.mockResolvedValue({ enabled: true });
|
||||
// Balance config is now passed directly in txData
|
||||
});
|
||||
|
||||
it('should create transactions for both prompt and completion tokens', async () => {
|
||||
@@ -60,6 +54,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
@@ -98,6 +93,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
@@ -127,6 +123,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
const tokenUsage = {};
|
||||
|
||||
@@ -138,8 +135,7 @@ describe('spendTokens', () => {
|
||||
});
|
||||
|
||||
it('should not update balance when the balance feature is disabled', async () => {
|
||||
// Override configuration: disable balance updates
|
||||
getBalanceConfig.mockResolvedValue({ enabled: false });
|
||||
// Balance is now passed directly in txData
|
||||
// Create a balance for the user
|
||||
await Balance.create({
|
||||
user: userId,
|
||||
@@ -151,6 +147,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'test',
|
||||
balance: { enabled: false },
|
||||
};
|
||||
const tokenUsage = {
|
||||
promptTokens: 100,
|
||||
@@ -180,6 +177,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo',
|
||||
model: 'gpt-4', // Using a more expensive model
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
// Spending more tokens than the user has balance for
|
||||
@@ -233,6 +231,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo-1',
|
||||
model: 'gpt-4',
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage1 = {
|
||||
@@ -252,6 +251,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo-2',
|
||||
model: 'gpt-4',
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage2 = {
|
||||
@@ -292,6 +292,7 @@ describe('spendTokens', () => {
|
||||
tokenType: 'completion',
|
||||
rawAmount: -100,
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
});
|
||||
|
||||
console.log('Direct Transaction.create result:', directResult);
|
||||
@@ -316,6 +317,7 @@ describe('spendTokens', () => {
|
||||
conversationId: `test-convo-${model}`,
|
||||
model,
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage = {
|
||||
@@ -352,6 +354,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo-1',
|
||||
model: 'claude-3-5-sonnet',
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage1 = {
|
||||
@@ -375,6 +378,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo-2',
|
||||
model: 'claude-3-5-sonnet',
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
const tokenUsage2 = {
|
||||
@@ -426,6 +430,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo',
|
||||
model: 'claude-3-5-sonnet', // Using a model that supports structured tokens
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
// Spending more tokens than the user has balance for
|
||||
@@ -505,6 +510,7 @@ describe('spendTokens', () => {
|
||||
conversationId,
|
||||
user: userId,
|
||||
model: usage.model,
|
||||
balance: { enabled: true },
|
||||
};
|
||||
|
||||
// Calculate expected spend for this transaction
|
||||
@@ -617,6 +623,7 @@ describe('spendTokens', () => {
|
||||
tokenType: 'credits',
|
||||
context: 'concurrent-refill-test',
|
||||
rawAmount: refillAmount,
|
||||
balance: { enabled: true },
|
||||
}),
|
||||
);
|
||||
}
|
||||
@@ -683,6 +690,7 @@ describe('spendTokens', () => {
|
||||
conversationId: 'test-convo',
|
||||
model: 'claude-3-5-sonnet',
|
||||
context: 'test',
|
||||
balance: { enabled: true },
|
||||
};
|
||||
const tokenUsage = {
|
||||
promptTokens: {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { matchModelName } = require('../utils/tokens');
|
||||
const { matchModelName } = require('@librechat/api');
|
||||
const defaultRate = 6;
|
||||
|
||||
/**
|
||||
|
||||
@@ -3,7 +3,7 @@ const bcrypt = require('bcryptjs');
|
||||
/**
|
||||
* Compares the provided password with the user's password.
|
||||
*
|
||||
* @param {MongoUser} user - The user to compare the password for.
|
||||
* @param {IUser} user - The user to compare the password for.
|
||||
* @param {string} candidatePassword - The password to test against the user's password.
|
||||
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the password matches.
|
||||
*/
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "v0.8.0-rc2",
|
||||
"version": "v0.8.0-rc4",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
@@ -49,14 +49,14 @@
|
||||
"@langchain/google-vertexai": "^0.2.13",
|
||||
"@langchain/openai": "^0.5.18",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.4.75",
|
||||
"@librechat/agents": "^2.4.79",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||
"@modelcontextprotocol/sdk": "^1.17.1",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
"axios": "^1.8.2",
|
||||
"axios": "^1.12.1",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"compression": "^1.8.1",
|
||||
"connect-redis": "^8.1.0",
|
||||
@@ -97,7 +97,6 @@
|
||||
"nodemailer": "^6.9.15",
|
||||
"ollama": "^0.5.0",
|
||||
"openai": "^5.10.1",
|
||||
"openai-chat-tokens": "^0.2.8",
|
||||
"openid-client": "^6.5.0",
|
||||
"passport": "^0.6.0",
|
||||
"passport-apple": "^2.0.2",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { logger } = require('~/config');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
|
||||
// WeakMap to hold temporary data associated with requests
|
||||
/** WeakMap to hold temporary data associated with requests */
|
||||
const requestDataMap = new WeakMap();
|
||||
|
||||
const FinalizationRegistry = global.FinalizationRegistry || null;
|
||||
@@ -23,7 +23,7 @@ const clientRegistry = FinalizationRegistry
|
||||
} else {
|
||||
logger.debug('[FinalizationRegistry] Cleaning up client');
|
||||
}
|
||||
} catch (e) {
|
||||
} catch {
|
||||
// Ignore errors
|
||||
}
|
||||
})
|
||||
@@ -55,6 +55,9 @@ function disposeClient(client) {
|
||||
if (client.responseMessageId) {
|
||||
client.responseMessageId = null;
|
||||
}
|
||||
if (client.parentMessageId) {
|
||||
client.parentMessageId = null;
|
||||
}
|
||||
if (client.message_file_map) {
|
||||
client.message_file_map = null;
|
||||
}
|
||||
@@ -334,7 +337,7 @@ function disposeClient(client) {
|
||||
}
|
||||
}
|
||||
client.options = null;
|
||||
} catch (e) {
|
||||
} catch {
|
||||
// Ignore errors during disposal
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ const refreshController = async (req, res) => {
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
const token = setOpenIDAuthTokens(tokenset, res);
|
||||
const token = setOpenIDAuthTokens(tokenset, res, user._id.toString());
|
||||
return res.status(200).send({ token, user });
|
||||
} catch (error) {
|
||||
logger.error('[refreshController] OpenID token refresh error', error);
|
||||
@@ -84,7 +84,7 @@ const refreshController = async (req, res) => {
|
||||
}
|
||||
try {
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
const user = await getUserById(payload.id, '-password -__v -totpSecret');
|
||||
const user = await getUserById(payload.id, '-password -__v -totpSecret -backupCodes');
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { loadOverrideConfig } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
async function overrideController(req, res) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
let overrideConfig = await cache.get(CacheKeys.OVERRIDE_CONFIG);
|
||||
if (overrideConfig) {
|
||||
res.send(overrideConfig);
|
||||
return;
|
||||
} else if (overrideConfig === false) {
|
||||
res.send(false);
|
||||
return;
|
||||
}
|
||||
overrideConfig = await loadOverrideConfig();
|
||||
const { endpointsConfig, modelsConfig } = overrideConfig;
|
||||
if (endpointsConfig) {
|
||||
await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig);
|
||||
}
|
||||
if (modelsConfig) {
|
||||
await cache.set(CacheKeys.MODELS_CONFIG, modelsConfig);
|
||||
}
|
||||
await cache.set(CacheKeys.OVERRIDE_CONFIG, overrideConfig);
|
||||
res.send(JSON.stringify(overrideConfig));
|
||||
}
|
||||
|
||||
module.exports = overrideController;
|
||||
@@ -364,7 +364,7 @@ const getUserEffectivePermissions = async (req, res) => {
|
||||
*/
|
||||
const searchPrincipals = async (req, res) => {
|
||||
try {
|
||||
const { q: query, limit = 20, type } = req.query;
|
||||
const { q: query, limit = 20, types } = req.query;
|
||||
|
||||
if (!query || query.trim().length === 0) {
|
||||
return res.status(400).json({
|
||||
@@ -379,22 +379,34 @@ const searchPrincipals = async (req, res) => {
|
||||
}
|
||||
|
||||
const searchLimit = Math.min(Math.max(1, parseInt(limit) || 10), 50);
|
||||
const typeFilter = [PrincipalType.USER, PrincipalType.GROUP, PrincipalType.ROLE].includes(type)
|
||||
? type
|
||||
: null;
|
||||
|
||||
const localResults = await searchLocalPrincipals(query.trim(), searchLimit, typeFilter);
|
||||
let typeFilters = null;
|
||||
if (types) {
|
||||
const typesArray = Array.isArray(types) ? types : types.split(',');
|
||||
const validTypes = typesArray.filter((t) =>
|
||||
[PrincipalType.USER, PrincipalType.GROUP, PrincipalType.ROLE].includes(t),
|
||||
);
|
||||
typeFilters = validTypes.length > 0 ? validTypes : null;
|
||||
}
|
||||
|
||||
const localResults = await searchLocalPrincipals(query.trim(), searchLimit, typeFilters);
|
||||
let allPrincipals = [...localResults];
|
||||
|
||||
const useEntraId = entraIdPrincipalFeatureEnabled(req.user);
|
||||
|
||||
if (useEntraId && localResults.length < searchLimit) {
|
||||
try {
|
||||
const graphTypeMap = {
|
||||
user: 'users',
|
||||
group: 'groups',
|
||||
null: 'all',
|
||||
};
|
||||
let graphType = 'all';
|
||||
if (typeFilters && typeFilters.length === 1) {
|
||||
const graphTypeMap = {
|
||||
[PrincipalType.USER]: 'users',
|
||||
[PrincipalType.GROUP]: 'groups',
|
||||
};
|
||||
const mappedType = graphTypeMap[typeFilters[0]];
|
||||
if (mappedType) {
|
||||
graphType = mappedType;
|
||||
}
|
||||
}
|
||||
|
||||
const authHeader = req.headers.authorization;
|
||||
const accessToken =
|
||||
@@ -405,7 +417,7 @@ const searchPrincipals = async (req, res) => {
|
||||
accessToken,
|
||||
req.user.openidId,
|
||||
query.trim(),
|
||||
graphTypeMap[typeFilter],
|
||||
graphType,
|
||||
searchLimit - localResults.length,
|
||||
);
|
||||
|
||||
@@ -436,21 +448,22 @@ const searchPrincipals = async (req, res) => {
|
||||
_searchScore: calculateRelevanceScore(item, query.trim()),
|
||||
}));
|
||||
|
||||
allPrincipals = sortPrincipalsByRelevance(scoredResults)
|
||||
const finalResults = sortPrincipalsByRelevance(scoredResults)
|
||||
.slice(0, searchLimit)
|
||||
.map((result) => {
|
||||
const { _searchScore, ...resultWithoutScore } = result;
|
||||
return resultWithoutScore;
|
||||
});
|
||||
|
||||
res.status(200).json({
|
||||
query: query.trim(),
|
||||
limit: searchLimit,
|
||||
type: typeFilter,
|
||||
results: allPrincipals,
|
||||
count: allPrincipals.length,
|
||||
types: typeFilters,
|
||||
results: finalResults,
|
||||
count: finalResults.length,
|
||||
sources: {
|
||||
local: allPrincipals.filter((r) => r.source === 'local').length,
|
||||
entra: allPrincipals.filter((r) => r.source === 'entra').length,
|
||||
local: finalResults.filter((r) => r.source === 'local').length,
|
||||
entra: finalResults.filter((r) => r.source === 'entra').length,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
|
||||
@@ -4,11 +4,13 @@ const {
|
||||
getToolkitKey,
|
||||
checkPluginAuth,
|
||||
filterUniquePlugins,
|
||||
convertMCPToolToPlugin,
|
||||
convertMCPToolsToPlugins,
|
||||
} = require('@librechat/api');
|
||||
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const { getCachedTools, setCachedTools, mergeUserTools } = require('~/server/services/Config');
|
||||
const { availableTools, toolkits } = require('~/app/clients/tools');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const getAvailablePluginsController = async (req, res) => {
|
||||
@@ -20,8 +22,10 @@ const getAvailablePluginsController = async (req, res) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||
const { filteredTools = [], includedTools = [] } = req.app.locals;
|
||||
const { filteredTools = [], includedTools = [] } = appConfig;
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
const pluginManifest = availableTools;
|
||||
|
||||
const uniquePlugins = filterUniquePlugins(pluginManifest);
|
||||
@@ -47,45 +51,6 @@ const getAvailablePluginsController = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
function createServerToolsCallback() {
|
||||
/**
|
||||
* @param {string} serverName
|
||||
* @param {TPlugin[] | null} serverTools
|
||||
*/
|
||||
return async function (serverName, serverTools) {
|
||||
try {
|
||||
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
|
||||
if (!serverName || !mcpToolsCache) {
|
||||
return;
|
||||
}
|
||||
await mcpToolsCache.set(serverName, serverTools);
|
||||
logger.debug(`MCP tools for ${serverName} added to cache.`);
|
||||
} catch (error) {
|
||||
logger.error('Error retrieving MCP tools from cache:', error);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function createGetServerTools() {
|
||||
/**
|
||||
* Retrieves cached server tools
|
||||
* @param {string} serverName
|
||||
* @returns {Promise<TPlugin[] | null>}
|
||||
*/
|
||||
return async function (serverName) {
|
||||
try {
|
||||
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
|
||||
if (!mcpToolsCache) {
|
||||
return null;
|
||||
}
|
||||
return await mcpToolsCache.get(serverName);
|
||||
} catch (error) {
|
||||
logger.error('Error retrieving MCP tools from cache:', error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file.
|
||||
*
|
||||
@@ -101,37 +66,71 @@ function createGetServerTools() {
|
||||
const getAvailableTools = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user?.id;
|
||||
const customConfig = await getCustomConfig();
|
||||
if (!userId) {
|
||||
logger.warn('[getAvailableTools] User ID not found in request');
|
||||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
const cachedUserTools = await getCachedTools({ userId });
|
||||
const userPlugins = convertMCPToolsToPlugins({ functionTools: cachedUserTools, customConfig });
|
||||
|
||||
if (cachedToolsArray != null && userPlugins != null) {
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
|
||||
const appConfig = req.config ?? (await getAppConfig({ role: req.user?.role }));
|
||||
|
||||
/** @type {TPlugin[]} */
|
||||
let mcpPlugins;
|
||||
if (appConfig?.mcpConfig) {
|
||||
const mcpManager = getMCPManager();
|
||||
mcpPlugins =
|
||||
cachedUserTools != null
|
||||
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, mcpManager })
|
||||
: undefined;
|
||||
}
|
||||
|
||||
if (
|
||||
cachedToolsArray != null &&
|
||||
(appConfig?.mcpConfig != null ? mcpPlugins != null && mcpPlugins.length > 0 : true)
|
||||
) {
|
||||
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...cachedToolsArray]);
|
||||
res.status(200).json(dedupedTools);
|
||||
return;
|
||||
}
|
||||
|
||||
// If not in cache, build from manifest
|
||||
/** @type {Record<string, FunctionTool> | null} Get tool definitions to filter which tools are actually available */
|
||||
let toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
let prelimCachedTools;
|
||||
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
let pluginManifest = availableTools;
|
||||
if (customConfig?.mcpServers != null) {
|
||||
const mcpManager = getMCPManager();
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
|
||||
const serverToolsCallback = createServerToolsCallback();
|
||||
const getServerTools = createGetServerTools();
|
||||
const mcpTools = await mcpManager.loadManifestTools({
|
||||
flowManager,
|
||||
serverToolsCallback,
|
||||
getServerTools,
|
||||
});
|
||||
pluginManifest = [...mcpTools, ...pluginManifest];
|
||||
|
||||
if (appConfig?.mcpConfig != null) {
|
||||
try {
|
||||
const mcpManager = getMCPManager();
|
||||
const mcpTools = await mcpManager.getAllToolFunctions(userId);
|
||||
prelimCachedTools = prelimCachedTools ?? {};
|
||||
for (const [toolKey, toolData] of Object.entries(mcpTools)) {
|
||||
const plugin = convertMCPToolToPlugin({
|
||||
toolKey,
|
||||
toolData,
|
||||
mcpManager,
|
||||
});
|
||||
if (plugin) {
|
||||
pluginManifest.push(plugin);
|
||||
}
|
||||
prelimCachedTools[toolKey] = toolData;
|
||||
}
|
||||
await mergeUserTools({ userId, cachedUserTools, userTools: prelimCachedTools });
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'[getAvailableTools] Error loading MCP Tools, servers may still be initializing:',
|
||||
error,
|
||||
);
|
||||
}
|
||||
} else if (prelimCachedTools != null) {
|
||||
await setCachedTools(prelimCachedTools, { isGlobal: true });
|
||||
}
|
||||
|
||||
/** @type {TPlugin[]} */
|
||||
/** @type {TPlugin[]} Deduplicate and authenticate plugins */
|
||||
const uniquePlugins = filterUniquePlugins(pluginManifest);
|
||||
|
||||
const authenticatedPlugins = uniquePlugins.map((plugin) => {
|
||||
if (checkPluginAuth(plugin)) {
|
||||
return { ...plugin, authenticated: true };
|
||||
@@ -140,8 +139,7 @@ const getAvailableTools = async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {};
|
||||
|
||||
/** Filter plugins based on availability and add MCP-specific auth config */
|
||||
const toolsOutput = [];
|
||||
for (const plugin of authenticatedPlugins) {
|
||||
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
|
||||
@@ -157,41 +155,36 @@ const getAvailableTools = async (req, res) => {
|
||||
|
||||
const toolToAdd = { ...plugin };
|
||||
|
||||
if (!plugin.pluginKey.includes(Constants.mcp_delimiter)) {
|
||||
toolsOutput.push(toolToAdd);
|
||||
continue;
|
||||
}
|
||||
if (plugin.pluginKey.includes(Constants.mcp_delimiter)) {
|
||||
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
const serverConfig = appConfig?.mcpConfig?.[serverName];
|
||||
|
||||
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
const serverConfig = customConfig?.mcpServers?.[serverName];
|
||||
|
||||
if (!serverConfig?.customUserVars) {
|
||||
toolsOutput.push(toolToAdd);
|
||||
continue;
|
||||
}
|
||||
|
||||
const customVarKeys = Object.keys(serverConfig.customUserVars);
|
||||
|
||||
if (customVarKeys.length === 0) {
|
||||
toolToAdd.authConfig = [];
|
||||
toolToAdd.authenticated = true;
|
||||
} else {
|
||||
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
|
||||
authField: key,
|
||||
label: value.title || key,
|
||||
description: value.description || '',
|
||||
}));
|
||||
toolToAdd.authenticated = false;
|
||||
if (serverConfig?.customUserVars) {
|
||||
const customVarKeys = Object.keys(serverConfig.customUserVars);
|
||||
if (customVarKeys.length === 0) {
|
||||
toolToAdd.authConfig = [];
|
||||
toolToAdd.authenticated = true;
|
||||
} else {
|
||||
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(
|
||||
([key, value]) => ({
|
||||
authField: key,
|
||||
label: value.title || key,
|
||||
description: value.description || '',
|
||||
}),
|
||||
);
|
||||
toolToAdd.authenticated = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolsOutput.push(toolToAdd);
|
||||
}
|
||||
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...finalTools]);
|
||||
|
||||
const dedupedTools = filterUniquePlugins([...(mcpPlugins ?? []), ...finalTools]);
|
||||
res.status(200).json(dedupedTools);
|
||||
} catch (error) {
|
||||
logger.error('[getAvailableTools]', error);
|
||||
|
||||
@@ -1,27 +1,31 @@
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const { getCachedTools, getAppConfig } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
// Mock the dependencies
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCustomConfig: jest.fn(),
|
||||
getCachedTools: jest.fn(),
|
||||
getAppConfig: jest.fn().mockResolvedValue({
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
}),
|
||||
setCachedTools: jest.fn(),
|
||||
mergeUserTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/ToolService', () => ({
|
||||
getToolkitKey: jest.fn(),
|
||||
}));
|
||||
// loadAndFormatTools mock removed - no longer used in PluginController
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(() => ({
|
||||
loadManifestTools: jest.fn().mockResolvedValue([]),
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
})),
|
||||
getFlowStateManager: jest.fn(),
|
||||
}));
|
||||
@@ -35,71 +39,87 @@ jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
getToolkitKey: jest.fn(),
|
||||
checkPluginAuth: jest.fn(),
|
||||
filterUniquePlugins: jest.fn(),
|
||||
convertMCPToolsToPlugins: jest.fn(),
|
||||
}));
|
||||
|
||||
// Import the actual module with the function we want to test
|
||||
const { getAvailableTools, getAvailablePluginsController } = require('./PluginController');
|
||||
const {
|
||||
filterUniquePlugins,
|
||||
checkPluginAuth,
|
||||
convertMCPToolsToPlugins,
|
||||
getToolkitKey,
|
||||
} = require('@librechat/api');
|
||||
|
||||
describe('PluginController', () => {
|
||||
let mockReq, mockRes, mockCache;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockReq = { user: { id: 'test-user-id' } };
|
||||
mockReq = {
|
||||
user: { id: 'test-user-id' },
|
||||
config: {
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
},
|
||||
};
|
||||
mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() };
|
||||
mockCache = { get: jest.fn(), set: jest.fn() };
|
||||
getLogStores.mockReturnValue(mockCache);
|
||||
|
||||
// Clear availableTools and toolkits arrays before each test
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
require('~/app/clients/tools').toolkits.length = 0;
|
||||
|
||||
// Reset getCachedTools mock to ensure clean state
|
||||
getCachedTools.mockReset();
|
||||
|
||||
// Reset getAppConfig mock to ensure clean state with default values
|
||||
getAppConfig.mockReset();
|
||||
getAppConfig.mockResolvedValue({
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAvailablePluginsController', () => {
|
||||
beforeEach(() => {
|
||||
mockReq.app = { locals: { filteredTools: [], includedTools: [] } };
|
||||
});
|
||||
|
||||
it('should use filterUniquePlugins to remove duplicate plugins', async () => {
|
||||
// Add plugins with duplicates to availableTools
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First duplicate' },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
];
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(...mockPlugins);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue(mockPlugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
// Configure getAppConfig to return the expected config
|
||||
getAppConfig.mockResolvedValueOnce({
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
});
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(filterUniquePlugins).toHaveBeenCalled();
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
// The response includes authenticated: true for each plugin when checkPluginAuth returns true
|
||||
expect(mockRes.json).toHaveBeenCalledWith([
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First', authenticated: true },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second', authenticated: true },
|
||||
]);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
// The real filterUniquePlugins should have removed the duplicate
|
||||
expect(responseData).toHaveLength(2);
|
||||
expect(responseData[0].pluginKey).toBe('key1');
|
||||
expect(responseData[1].pluginKey).toBe('key2');
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify plugin authentication', async () => {
|
||||
// checkPluginAuth returns false for plugins without authConfig
|
||||
// so authenticated property won't be added
|
||||
const mockPlugin = { name: 'Plugin1', pluginKey: 'key1', description: 'First' };
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(mockPlugin);
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue([mockPlugin]);
|
||||
checkPluginAuth.mockReturnValueOnce(true);
|
||||
|
||||
// Configure getAppConfig to return the expected config
|
||||
getAppConfig.mockResolvedValueOnce({
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
});
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(checkPluginAuth).toHaveBeenCalledWith(mockPlugin);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData[0].authenticated).toBe(true);
|
||||
// The real checkPluginAuth returns false for plugins without authConfig, so authenticated property is not added
|
||||
expect(responseData[0].authenticated).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return cached plugins when available', async () => {
|
||||
@@ -111,8 +131,7 @@ describe('PluginController', () => {
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(filterUniquePlugins).not.toHaveBeenCalled();
|
||||
expect(checkPluginAuth).not.toHaveBeenCalled();
|
||||
// When cache is hit, we return immediately without processing
|
||||
expect(mockRes.json).toHaveBeenCalledWith(cachedPlugins);
|
||||
});
|
||||
|
||||
@@ -122,10 +141,14 @@ describe('PluginController', () => {
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
];
|
||||
|
||||
mockReq.app.locals.includedTools = ['key1'];
|
||||
require('~/app/clients/tools').availableTools.push(...mockPlugins);
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue(mockPlugins);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
// Configure getAppConfig to return config with includedTools
|
||||
getAppConfig.mockResolvedValueOnce({
|
||||
filteredTools: [],
|
||||
includedTools: ['key1'],
|
||||
});
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
@@ -139,70 +162,126 @@ describe('PluginController', () => {
|
||||
it('should use convertMCPToolsToPlugins for user-specific MCP tools', async () => {
|
||||
const mockUserTools = {
|
||||
[`tool1${Constants.mcp_delimiter}server1`]: {
|
||||
function: { name: 'tool1', description: 'Tool 1' },
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `tool1${Constants.mcp_delimiter}server1`,
|
||||
description: 'Tool 1',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
const mockConvertedPlugins = [
|
||||
{
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}server1`,
|
||||
description: 'Tool 1',
|
||||
},
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue(mockConvertedPlugins);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
mockReq.config = {
|
||||
mcpConfig: {
|
||||
server1: {},
|
||||
},
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return empty tools initially (since getAllToolFunctions is called)
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// Mock second call to return tool definitions (includeGlobal: true)
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: mockUserTools,
|
||||
customConfig: null,
|
||||
});
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData).toBeDefined();
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
expect(responseData.length).toBeGreaterThan(0);
|
||||
const convertedTool = responseData.find(
|
||||
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}server1`,
|
||||
);
|
||||
expect(convertedTool).toBeDefined();
|
||||
// The real convertMCPToolsToPlugins extracts the name from the delimiter
|
||||
expect(convertedTool.name).toBe('tool1');
|
||||
});
|
||||
|
||||
it('should use filterUniquePlugins to deduplicate combined tools', async () => {
|
||||
const mockUserPlugins = [
|
||||
{ name: 'UserTool', pluginKey: 'user-tool', description: 'User tool' },
|
||||
];
|
||||
const mockManifestPlugins = [
|
||||
const mockUserTools = {
|
||||
'user-tool': {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'user-tool',
|
||||
description: 'User tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockCachedPlugins = [
|
||||
{ name: 'user-tool', pluginKey: 'user-tool', description: 'Duplicate user tool' },
|
||||
{ name: 'ManifestTool', pluginKey: 'manifest-tool', description: 'Manifest tool' },
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(mockManifestPlugins);
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
convertMCPToolsToPlugins.mockReturnValue(mockUserPlugins);
|
||||
filterUniquePlugins.mockReturnValue([...mockUserPlugins, ...mockManifestPlugins]);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
mockCache.get.mockResolvedValue(mockCachedPlugins);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should be called to deduplicate the combined array
|
||||
expect(filterUniquePlugins).toHaveBeenLastCalledWith([
|
||||
...mockUserPlugins,
|
||||
...mockManifestPlugins,
|
||||
]);
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
// The real filterUniquePlugins should have deduplicated tools with same pluginKey
|
||||
const userToolCount = responseData.filter((tool) => tool.pluginKey === 'user-tool').length;
|
||||
expect(userToolCount).toBe(1);
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify authentication status', async () => {
|
||||
const mockPlugin = { name: 'Tool1', pluginKey: 'tool1', description: 'Tool 1' };
|
||||
// Add a plugin to availableTools that will be checked
|
||||
const mockPlugin = {
|
||||
name: 'Tool1',
|
||||
pluginKey: 'tool1',
|
||||
description: 'Tool 1',
|
||||
// No authConfig means checkPluginAuth returns false
|
||||
};
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(mockPlugin);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockPlugin]);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
// First call returns null for user tools
|
||||
getCachedTools.mockResolvedValueOnce(null);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock getCachedTools second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({ tool1: true });
|
||||
// Second call (with includeGlobal: true) returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
tool1: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'tool1',
|
||||
description: 'Tool 1',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(checkPluginAuth).toHaveBeenCalledWith(mockPlugin);
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
const tool = responseData.find((t) => t.pluginKey === 'tool1');
|
||||
expect(tool).toBeDefined();
|
||||
// The real checkPluginAuth returns false for plugins without authConfig, so authenticated property is not added
|
||||
expect(tool.authenticated).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should use getToolkitKey for toolkit validation', async () => {
|
||||
@@ -213,83 +292,106 @@ describe('PluginController', () => {
|
||||
toolkit: true,
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockToolkit]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
getToolkitKey.mockReturnValue('toolkit1');
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
require('~/app/clients/tools').availableTools.push(mockToolkit);
|
||||
|
||||
// Mock getCachedTools second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({
|
||||
toolkit1_function: true,
|
||||
// Mock toolkits to have a mapping
|
||||
require('~/app/clients/tools').toolkits.push({
|
||||
name: 'Toolkit1',
|
||||
pluginKey: 'toolkit1',
|
||||
tools: ['toolkit1_function'],
|
||||
});
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// First call returns null for user tools
|
||||
getCachedTools.mockResolvedValueOnce(null);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Second call (with includeGlobal: true) returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
toolkit1_function: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'toolkit1_function',
|
||||
description: 'Toolkit function',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(getToolkitKey).toHaveBeenCalled();
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
const toolkit = responseData.find((t) => t.pluginKey === 'toolkit1');
|
||||
expect(toolkit).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('plugin.icon behavior', () => {
|
||||
const callGetAvailableToolsWithMCPServer = async (mcpServers) => {
|
||||
const callGetAvailableToolsWithMCPServer = async (serverConfig) => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue({ mcpServers });
|
||||
|
||||
const functionTools = {
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: {
|
||||
function: { name: 'test-tool', description: 'A test tool' },
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
description: 'A test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockConvertedPlugin = {
|
||||
name: 'test-tool',
|
||||
pluginKey: `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
description: 'A test tool',
|
||||
icon: mcpServers['test-server']?.iconPath,
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
// Mock the MCP manager to return tools and server config
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue(functionTools),
|
||||
getRawConfig: jest.fn().mockReturnValue(serverConfig),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// First call returns empty user tools
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Mock getAppConfig to return the mcpConfig
|
||||
mockReq.config = {
|
||||
mcpConfig: {
|
||||
'test-server': serverConfig,
|
||||
},
|
||||
};
|
||||
|
||||
// Second call (with includeGlobal: true) returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce(functionTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue([mockConvertedPlugin]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
getToolkitKey.mockReturnValue(undefined);
|
||||
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
return responseData.find((tool) => tool.name === 'test-tool');
|
||||
return responseData.find(
|
||||
(tool) => tool.pluginKey === `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
);
|
||||
};
|
||||
|
||||
it('should set plugin.icon when iconPath is defined', async () => {
|
||||
const mcpServers = {
|
||||
'test-server': {
|
||||
iconPath: '/path/to/icon.png',
|
||||
},
|
||||
const serverConfig = {
|
||||
iconPath: '/path/to/icon.png',
|
||||
};
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(mcpServers);
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(serverConfig);
|
||||
expect(testTool.icon).toBe('/path/to/icon.png');
|
||||
});
|
||||
|
||||
it('should set plugin.icon to undefined when iconPath is not defined', async () => {
|
||||
const mcpServers = {
|
||||
'test-server': {},
|
||||
};
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(mcpServers);
|
||||
const serverConfig = {};
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(serverConfig);
|
||||
expect(testTool.icon).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('helper function integration', () => {
|
||||
it('should properly handle MCP tools with custom user variables', async () => {
|
||||
const customConfig = {
|
||||
mcpServers: {
|
||||
const appConfig = {
|
||||
mcpConfig: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: { title: 'API Key', description: 'Your API key' },
|
||||
@@ -298,45 +400,43 @@ describe('PluginController', () => {
|
||||
},
|
||||
};
|
||||
|
||||
// We need to test the actual flow where MCP manager tools are included
|
||||
const mcpManagerTools = [
|
||||
{
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
authenticated: true,
|
||||
// Mock MCP tools returned by getAllToolFunctions
|
||||
const mcpToolFunctions = {
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
// Mock the MCP manager to return tools
|
||||
const mockMCPManager = {
|
||||
loadManifestTools: jest.fn().mockResolvedValue(mcpManagerTools),
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue(mcpToolFunctions),
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {
|
||||
API_KEY: { title: 'API Key', description: 'Your API key' },
|
||||
},
|
||||
}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue(customConfig);
|
||||
mockReq.config = appConfig;
|
||||
|
||||
// First call returns user tools (empty in this case)
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Mock convertMCPToolsToPlugins to return empty array for user tools
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
|
||||
// Mock filterUniquePlugins to pass through
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
|
||||
// Mock checkPluginAuth
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
// Second call returns tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
// Second call (with includeGlobal: true) returns tool definitions including our MCP tool
|
||||
getCachedTools.mockResolvedValueOnce(mcpToolFunctions);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
|
||||
// Find the MCP tool in the response
|
||||
const mcpTool = responseData.find(
|
||||
@@ -372,26 +472,36 @@ describe('PluginController', () => {
|
||||
|
||||
it('should handle null cachedTools and cachedUserTools', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
convertMCPToolsToPlugins.mockReturnValue(undefined);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
// First call returns null for user tools
|
||||
getCachedTools.mockResolvedValueOnce(null);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return no tools
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// Second call (with includeGlobal: true) returns empty object instead of null
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: null,
|
||||
customConfig: null,
|
||||
});
|
||||
// Should handle null values gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle when getCachedTools returns undefined', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue(undefined);
|
||||
convertMCPToolsToPlugins.mockReturnValue(undefined);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock getCachedTools to return undefined for both calls
|
||||
getCachedTools.mockReset();
|
||||
@@ -399,37 +509,72 @@ describe('PluginController', () => {
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: undefined,
|
||||
customConfig: null,
|
||||
});
|
||||
// Should handle undefined values gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle cachedToolsArray and userPlugins both being defined', async () => {
|
||||
it('should handle `cachedToolsArray` and `mcpPlugins` both being defined', async () => {
|
||||
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
|
||||
// Use MCP delimiter for the user tool so convertMCPToolsToPlugins works
|
||||
const userTools = {
|
||||
'user-tool': { function: { name: 'user-tool', description: 'User tool' } },
|
||||
[`user-tool${Constants.mcp_delimiter}server1`]: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `user-tool${Constants.mcp_delimiter}server1`,
|
||||
description: 'User tool',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
};
|
||||
const userPlugins = [{ name: 'UserTool', pluginKey: 'user-tool', description: 'User tool' }];
|
||||
|
||||
mockCache.get.mockResolvedValue(cachedTools);
|
||||
getCachedTools.mockResolvedValue(userTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue(userPlugins);
|
||||
filterUniquePlugins.mockReturnValue([...userPlugins, ...cachedTools]);
|
||||
getCachedTools.mockResolvedValueOnce(userTools);
|
||||
mockReq.config = {
|
||||
mcpConfig: {
|
||||
server1: {},
|
||||
},
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Mock MCP manager to return empty tools initially
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
// The controller expects a second call to getCachedTools
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
'cached-tool': { type: 'function', function: { name: 'cached-tool' } },
|
||||
[`user-tool${Constants.mcp_delimiter}server1`]:
|
||||
userTools[`user-tool${Constants.mcp_delimiter}server1`],
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([...userPlugins, ...cachedTools]);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
// Should have both cached and user tools
|
||||
expect(responseData.length).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it('should handle empty toolDefinitions object', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
// Reset getCachedTools to ensure clean state
|
||||
getCachedTools.mockReset();
|
||||
getCachedTools.mockResolvedValue({});
|
||||
mockReq.config = {}; // No mcpConfig at all
|
||||
|
||||
// Ensure no plugins are available
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
|
||||
// Reset MCP manager to default state
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
@@ -438,8 +583,8 @@ describe('PluginController', () => {
|
||||
});
|
||||
|
||||
it('should handle MCP tools without customUserVars', async () => {
|
||||
const customConfig = {
|
||||
mcpServers: {
|
||||
const appConfig = {
|
||||
mcpConfig: {
|
||||
'test-server': {
|
||||
// No customUserVars defined
|
||||
},
|
||||
@@ -448,43 +593,59 @@ describe('PluginController', () => {
|
||||
|
||||
const mockUserTools = {
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: {
|
||||
function: { name: 'tool1', description: 'Tool 1' },
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Mock the MCP manager to return the tools
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue(mockUserTools),
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
// No customUserVars defined
|
||||
}),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue(customConfig);
|
||||
mockReq.config = appConfig;
|
||||
// First call returns empty user tools
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Second call (with includeGlobal: true) returns the tool definitions
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
const mockPlugin = {
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
};
|
||||
|
||||
convertMCPToolsToPlugins.mockReturnValue([mockPlugin]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
// Ensure no plugins in availableTools for clean test
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData[0].authenticated).toBe(true);
|
||||
// The actual implementation doesn't set authConfig on tools without customUserVars
|
||||
expect(responseData[0].authConfig).toEqual([]);
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
expect(responseData.length).toBeGreaterThan(0);
|
||||
|
||||
const mcpTool = responseData.find(
|
||||
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}test-server`,
|
||||
);
|
||||
|
||||
expect(mcpTool).toBeDefined();
|
||||
expect(mcpTool.authenticated).toBe(true);
|
||||
// The actual implementation sets authConfig to empty array when no customUserVars
|
||||
expect(mcpTool.authConfig).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle req.app.locals with undefined filteredTools and includedTools', async () => {
|
||||
mockReq.app = { locals: {} };
|
||||
it('should handle undefined filteredTools and includedTools', async () => {
|
||||
mockReq.config = {};
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue([]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
// Configure getAppConfig to return config with undefined properties
|
||||
// The controller will use default values [] for filteredTools and includedTools
|
||||
getAppConfig.mockResolvedValueOnce({});
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
@@ -500,16 +661,21 @@ describe('PluginController', () => {
|
||||
toolkit: true,
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockToolkit]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
getToolkitKey.mockReturnValue(undefined);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
// No need to mock app.locals anymore as it's not used
|
||||
|
||||
// Mock getCachedTools second call to return null
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce(null);
|
||||
// Add the toolkit to availableTools
|
||||
require('~/app/clients/tools').availableTools.push(mockToolkit);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
// First call returns empty object
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
mockReq.config = {
|
||||
mcpConfig: null,
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
};
|
||||
|
||||
// Second call (with includeGlobal: true) returns empty object to avoid null reference error
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ const verify2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { token, backupCode } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
const user = await getUserById(userId, '_id totpSecret backupCodes');
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
@@ -79,7 +79,7 @@ const confirm2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { token } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
const user = await getUserById(userId, '_id totpSecret');
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
@@ -105,7 +105,7 @@ const disable2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { token, backupCode } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
const user = await getUserById(userId, '_id totpSecret backupCodes');
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA is not setup for this user' });
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { webSearchKeys, extractWebSearchEnvVars, normalizeHttpError } = require('@librechat/api');
|
||||
const {
|
||||
webSearchKeys,
|
||||
extractWebSearchEnvVars,
|
||||
normalizeHttpError,
|
||||
MCPTokenStorage,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
getFiles,
|
||||
updateUser,
|
||||
@@ -16,16 +21,30 @@ const { verifyEmail, resendVerificationEmail } = require('~/server/services/Auth
|
||||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { Tools, Constants, FileSources } = require('librechat-data-provider');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { Transaction, Balance, User } = require('~/db/models');
|
||||
const { Transaction, Balance, User, Token } = require('~/db/models');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { deleteAllSharedLinks } = require('~/models');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { clearMCPServerTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
const { findToken } = require('~/models');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
/** @type {MongoUser} */
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
/** @type {IUser} */
|
||||
const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
|
||||
/**
|
||||
* These fields should not exist due to secure field selection, but deletion
|
||||
* is done in case of alternate database incompatibility with Mongo API
|
||||
* */
|
||||
delete userData.password;
|
||||
delete userData.totpSecret;
|
||||
if (req.app.locals.fileStrategy === FileSources.s3 && userData.avatar) {
|
||||
delete userData.backupCodes;
|
||||
if (appConfig.fileStrategy === FileSources.s3 && userData.avatar) {
|
||||
const avatarNeedsRefresh = needsRefresh(userData.avatar, 3600);
|
||||
if (!avatarNeedsRefresh) {
|
||||
return res.status(200).send(userData);
|
||||
@@ -81,6 +100,7 @@ const deleteUserFiles = async (req) => {
|
||||
};
|
||||
|
||||
const updateUserPluginsController = async (req, res) => {
|
||||
const appConfig = await getAppConfig({ role: req.user?.role });
|
||||
const { user } = req;
|
||||
const { pluginKey, action, auth, isEntityTool } = req.body;
|
||||
try {
|
||||
@@ -125,7 +145,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
|
||||
if (pluginKey === Tools.web_search) {
|
||||
/** @type {TCustomConfig['webSearch']} */
|
||||
const webSearchConfig = req.app.locals?.webSearch;
|
||||
const webSearchConfig = appConfig?.webSearch;
|
||||
keys = extractWebSearchEnvVars({
|
||||
keys: action === 'install' ? keys : webSearchKeys,
|
||||
config: webSearchConfig,
|
||||
@@ -153,6 +173,15 @@ const updateUserPluginsController = async (req, res) => {
|
||||
);
|
||||
({ status, message } = normalizeHttpError(authService));
|
||||
}
|
||||
try {
|
||||
// if the MCP server uses OAuth, perform a full cleanup and token revocation
|
||||
await maybeUninstallOAuthMCP(user.id, pluginKey, appConfig);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[updateUserPluginsController] Error uninstalling OAuth MCP for ${pluginKey}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// This handles:
|
||||
// 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}).
|
||||
@@ -178,7 +207,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
// Extract server name from pluginKey (format: "mcp_<serverName>")
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
logger.info(
|
||||
`[updateUserPluginsController] Disconnecting MCP server ${serverName} for user ${user.id} after plugin auth update for ${pluginKey}.`,
|
||||
`[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`,
|
||||
);
|
||||
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||
}
|
||||
@@ -260,6 +289,97 @@ const resendVerificationController = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* OAuth MCP specific uninstall logic
|
||||
*/
|
||||
const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => {
|
||||
if (!pluginKey.startsWith(Constants.mcp_prefix)) {
|
||||
// this is not an MCP server, so nothing to do here
|
||||
return;
|
||||
}
|
||||
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
const mcpManager = getMCPManager(userId);
|
||||
const serverConfig = mcpManager.getRawConfig(serverName) ?? appConfig?.mcpServers?.[serverName];
|
||||
|
||||
if (!mcpManager.getOAuthServers().has(serverName)) {
|
||||
// this server does not use OAuth, so nothing to do here as well
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. get client info used for revocation (client id, secret)
|
||||
const clientTokenData = await MCPTokenStorage.getClientInfoAndMetadata({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
});
|
||||
if (clientTokenData == null) {
|
||||
return;
|
||||
}
|
||||
const { clientInfo, clientMetadata } = clientTokenData;
|
||||
|
||||
// 2. get decrypted tokens before deletion
|
||||
const tokens = await MCPTokenStorage.getTokens({
|
||||
userId,
|
||||
serverName,
|
||||
findToken,
|
||||
});
|
||||
|
||||
// 3. revoke OAuth tokens at the provider
|
||||
const revocationEndpoint =
|
||||
serverConfig.oauth?.revocation_endpoint ?? clientMetadata.revocation_endpoint;
|
||||
const revocationEndpointAuthMethodsSupported =
|
||||
serverConfig.oauth?.revocation_endpoint_auth_methods_supported ??
|
||||
clientMetadata.revocation_endpoint_auth_methods_supported;
|
||||
|
||||
if (tokens?.access_token) {
|
||||
try {
|
||||
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.access_token, 'access', {
|
||||
serverUrl: serverConfig.url,
|
||||
clientId: clientInfo.client_id,
|
||||
clientSecret: clientInfo.client_secret ?? '',
|
||||
revocationEndpoint,
|
||||
revocationEndpointAuthMethodsSupported,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth access token for ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
if (tokens?.refresh_token) {
|
||||
try {
|
||||
await MCPOAuthHandler.revokeOAuthToken(serverName, tokens.refresh_token, 'refresh', {
|
||||
serverUrl: serverConfig.url,
|
||||
clientId: clientInfo.client_id,
|
||||
clientSecret: clientInfo.client_secret ?? '',
|
||||
revocationEndpoint,
|
||||
revocationEndpointAuthMethodsSupported,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`Error revoking OAuth refresh token for ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
// 4. delete tokens from the DB after revocation attempts
|
||||
await MCPTokenStorage.deleteUserTokens({
|
||||
userId,
|
||||
serverName,
|
||||
deleteToken: async (filter) => {
|
||||
await Token.deleteOne(filter);
|
||||
},
|
||||
});
|
||||
|
||||
// 5. clear the flow state for the OAuth tokens
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
|
||||
await flowManager.deleteFlow(flowId, 'mcp_get_tokens');
|
||||
await flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
|
||||
// 6. clear the tools cache for the server
|
||||
await clearMCPServerTools({ userId, serverName });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getUserController,
|
||||
getTermsStatusController,
|
||||
|
||||
342
api/server/controllers/agents/__tests__/callbacks.spec.js
Normal file
342
api/server/controllers/agents/__tests__/callbacks.spec.js
Normal file
@@ -0,0 +1,342 @@
|
||||
const { Tools } = require('librechat-data-provider');
|
||||
|
||||
// Mock all dependencies before requiring the module
|
||||
jest.mock('nanoid', () => ({
|
||||
nanoid: jest.fn(() => 'mock-id'),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
sendEvent: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/agents', () => ({
|
||||
EnvVar: { CODE_API_KEY: 'CODE_API_KEY' },
|
||||
Providers: { GOOGLE: 'google' },
|
||||
GraphEvents: {},
|
||||
getMessageId: jest.fn(),
|
||||
ToolEndHandler: jest.fn(),
|
||||
handleToolCalls: jest.fn(),
|
||||
ChatModelStreamHandler: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/Citations', () => ({
|
||||
processFileCitations: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/Code/process', () => ({
|
||||
processCodeOutput: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
saveBase64Image: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('createToolEndCallback', () => {
|
||||
let req, res, artifactPromises, createToolEndCallback;
|
||||
let logger;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Get the mocked logger
|
||||
logger = require('@librechat/data-schemas').logger;
|
||||
|
||||
// Now require the module after all mocks are set up
|
||||
const callbacks = require('../callbacks');
|
||||
createToolEndCallback = callbacks.createToolEndCallback;
|
||||
|
||||
req = {
|
||||
user: { id: 'user123' },
|
||||
};
|
||||
res = {
|
||||
headersSent: false,
|
||||
write: jest.fn(),
|
||||
};
|
||||
artifactPromises = [];
|
||||
});
|
||||
|
||||
describe('ui_resources artifact handling', () => {
|
||||
it('should process ui_resources artifact and return attachment when headers not sent', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {
|
||||
0: { type: 'button', label: 'Click me' },
|
||||
1: { type: 'input', placeholder: 'Enter text' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
|
||||
// Wait for all promises to resolve
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
// When headers are not sent, it returns attachment without writing
|
||||
expect(res.write).not.toHaveBeenCalled();
|
||||
|
||||
const attachment = results[0];
|
||||
expect(attachment).toEqual({
|
||||
type: Tools.ui_resources,
|
||||
messageId: 'run456',
|
||||
toolCallId: 'tool123',
|
||||
conversationId: 'thread789',
|
||||
[Tools.ui_resources]: {
|
||||
0: { type: 'button', label: 'Click me' },
|
||||
1: { type: 'input', placeholder: 'Enter text' },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should write to response when headers are already sent', async () => {
|
||||
res.headersSent = true;
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {
|
||||
0: { type: 'carousel', items: [] },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
expect(res.write).toHaveBeenCalled();
|
||||
expect(results[0]).toEqual({
|
||||
type: Tools.ui_resources,
|
||||
messageId: 'run456',
|
||||
toolCallId: 'tool123',
|
||||
conversationId: 'thread789',
|
||||
[Tools.ui_resources]: {
|
||||
0: { type: 'carousel', items: [] },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle errors when processing ui_resources', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
// Mock res.write to throw an error
|
||||
res.headersSent = true;
|
||||
res.write.mockImplementation(() => {
|
||||
throw new Error('Write failed');
|
||||
});
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {
|
||||
0: { type: 'test' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'Error processing artifact content:',
|
||||
expect.any(Error),
|
||||
);
|
||||
expect(results[0]).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle multiple artifacts including ui_resources', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {
|
||||
0: { type: 'chart', data: [] },
|
||||
},
|
||||
},
|
||||
[Tools.web_search]: {
|
||||
results: ['result1', 'result2'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
// Both ui_resources and web_search should be processed
|
||||
expect(artifactPromises).toHaveLength(2);
|
||||
expect(results).toHaveLength(2);
|
||||
|
||||
// Check ui_resources attachment
|
||||
const uiResourceAttachment = results.find((r) => r?.type === Tools.ui_resources);
|
||||
expect(uiResourceAttachment).toBeTruthy();
|
||||
expect(uiResourceAttachment[Tools.ui_resources]).toEqual({
|
||||
0: { type: 'chart', data: [] },
|
||||
});
|
||||
|
||||
// Check web_search attachment
|
||||
const webSearchAttachment = results.find((r) => r?.type === Tools.web_search);
|
||||
expect(webSearchAttachment).toBeTruthy();
|
||||
expect(webSearchAttachment[Tools.web_search]).toEqual({
|
||||
results: ['result1', 'result2'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should not process artifacts when output has no artifacts', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
content: 'Some regular content',
|
||||
// No artifact property
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
|
||||
expect(artifactPromises).toHaveLength(0);
|
||||
expect(res.write).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle empty ui_resources data object', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: {},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
expect(results[0]).toEqual({
|
||||
type: Tools.ui_resources,
|
||||
messageId: 'run456',
|
||||
toolCallId: 'tool123',
|
||||
conversationId: 'thread789',
|
||||
[Tools.ui_resources]: {},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle ui_resources with complex nested data', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const complexData = {
|
||||
0: {
|
||||
type: 'form',
|
||||
fields: [
|
||||
{ name: 'field1', type: 'text', required: true },
|
||||
{ name: 'field2', type: 'select', options: ['a', 'b', 'c'] },
|
||||
],
|
||||
nested: {
|
||||
deep: {
|
||||
value: 123,
|
||||
array: [1, 2, 3],
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const output = {
|
||||
tool_call_id: 'tool123',
|
||||
artifact: {
|
||||
[Tools.ui_resources]: {
|
||||
data: complexData,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output }, metadata);
|
||||
const results = await Promise.all(artifactPromises);
|
||||
|
||||
expect(results[0][Tools.ui_resources]).toEqual(complexData);
|
||||
});
|
||||
|
||||
it('should handle when output is undefined', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback({ output: undefined }, metadata);
|
||||
|
||||
expect(artifactPromises).toHaveLength(0);
|
||||
expect(res.write).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle when data parameter is undefined', async () => {
|
||||
const toolEndCallback = createToolEndCallback({ req, res, artifactPromises });
|
||||
|
||||
const metadata = {
|
||||
run_id: 'run456',
|
||||
thread_id: 'thread789',
|
||||
};
|
||||
|
||||
await toolEndCallback(undefined, metadata);
|
||||
|
||||
expect(artifactPromises).toHaveLength(0);
|
||||
expect(res.write).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -246,6 +246,7 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
||||
const attachment = await processFileCitations({
|
||||
user,
|
||||
metadata,
|
||||
appConfig: req.config,
|
||||
toolArtifact: output.artifact,
|
||||
toolCallId: output.tool_call_id,
|
||||
});
|
||||
@@ -264,6 +265,30 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
||||
);
|
||||
}
|
||||
|
||||
// TODO: a lot of duplicated code in createToolEndCallback
|
||||
// we should refactor this to use a helper function in a follow-up PR
|
||||
if (output.artifact[Tools.ui_resources]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
const attachment = {
|
||||
type: Tools.ui_resources,
|
||||
messageId: metadata.run_id,
|
||||
toolCallId: output.tool_call_id,
|
||||
conversationId: metadata.thread_id,
|
||||
[Tools.ui_resources]: output.artifact[Tools.ui_resources].data,
|
||||
};
|
||||
if (!res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing artifact content:', error);
|
||||
return null;
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
if (output.artifact[Tools.web_search]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
|
||||
@@ -7,8 +7,12 @@ const {
|
||||
createRun,
|
||||
Tokenizer,
|
||||
checkAccess,
|
||||
logAxiosError,
|
||||
resolveHeaders,
|
||||
getBalanceConfig,
|
||||
memoryInstructions,
|
||||
formatContentStrings,
|
||||
getTransactionsConfig,
|
||||
createMemoryProcessor,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
@@ -33,18 +37,13 @@ const {
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
findPluginAuthsByKeys,
|
||||
getFormattedMemories,
|
||||
deleteMemory,
|
||||
setMemory,
|
||||
} = require('~/models');
|
||||
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
|
||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||
const { checkCapability } = require('~/server/services/Config');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { loadAgent } = require('~/models/Agent');
|
||||
@@ -90,11 +89,10 @@ function createTokenCounter(encoding) {
|
||||
}
|
||||
|
||||
function logToolError(graph, error, toolId) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
|
||||
logAxiosError({
|
||||
error,
|
||||
toolId,
|
||||
);
|
||||
message: `[api/server/controllers/agents/client.js #chatCompletion] Tool Error "${toolId}"`,
|
||||
});
|
||||
}
|
||||
|
||||
class AgentClient extends BaseClient {
|
||||
@@ -451,8 +449,8 @@ class AgentClient extends BaseClient {
|
||||
);
|
||||
return;
|
||||
}
|
||||
/** @type {TCustomConfig['memory']} */
|
||||
const memoryConfig = this.options.req?.app?.locals?.memory;
|
||||
const appConfig = this.options.req.config;
|
||||
const memoryConfig = appConfig.memory;
|
||||
if (!memoryConfig || memoryConfig.disabled === true) {
|
||||
return;
|
||||
}
|
||||
@@ -460,7 +458,7 @@ class AgentClient extends BaseClient {
|
||||
/** @type {Agent} */
|
||||
let prelimAgent;
|
||||
const allowedProviders = new Set(
|
||||
this.options.req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders,
|
||||
appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders,
|
||||
);
|
||||
try {
|
||||
if (memoryConfig.agent?.id != null && memoryConfig.agent.id !== this.options.agent.id) {
|
||||
@@ -582,8 +580,8 @@ class AgentClient extends BaseClient {
|
||||
if (this.processMemory == null) {
|
||||
return;
|
||||
}
|
||||
/** @type {TCustomConfig['memory']} */
|
||||
const memoryConfig = this.options.req?.app?.locals?.memory;
|
||||
const appConfig = this.options.req.config;
|
||||
const memoryConfig = appConfig.memory;
|
||||
const messageWindowSize = memoryConfig?.messageWindowSize ?? 5;
|
||||
|
||||
let messagesToProcess = [...messages];
|
||||
@@ -615,6 +613,7 @@ class AgentClient extends BaseClient {
|
||||
await this.chatCompletion({
|
||||
payload,
|
||||
onProgress: opts.onProgress,
|
||||
userMCPAuthMap: opts.userMCPAuthMap,
|
||||
abortController: opts.abortController,
|
||||
});
|
||||
return this.contentParts;
|
||||
@@ -624,9 +623,17 @@ class AgentClient extends BaseClient {
|
||||
* @param {Object} params
|
||||
* @param {string} [params.model]
|
||||
* @param {string} [params.context='message']
|
||||
* @param {AppConfig['balance']} [params.balance]
|
||||
* @param {AppConfig['transactions']} [params.transactions]
|
||||
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
|
||||
*/
|
||||
async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) {
|
||||
async recordCollectedUsage({
|
||||
model,
|
||||
balance,
|
||||
transactions,
|
||||
context = 'message',
|
||||
collectedUsage = this.collectedUsage,
|
||||
}) {
|
||||
if (!collectedUsage || !collectedUsage.length) {
|
||||
return;
|
||||
}
|
||||
@@ -648,6 +655,8 @@ class AgentClient extends BaseClient {
|
||||
|
||||
const txMetadata = {
|
||||
context,
|
||||
balance,
|
||||
transactions,
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
@@ -747,7 +756,13 @@ class AgentClient extends BaseClient {
|
||||
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
|
||||
}
|
||||
|
||||
async chatCompletion({ payload, abortController = null }) {
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {string | ChatCompletionMessageParam[]} params.payload
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @param {AbortController} [params.abortController]
|
||||
*/
|
||||
async chatCompletion({ payload, userMCPAuthMap, abortController = null }) {
|
||||
/** @type {Partial<GraphRunnableConfig>} */
|
||||
let config;
|
||||
/** @type {ReturnType<createRun>} */
|
||||
@@ -759,8 +774,9 @@ class AgentClient extends BaseClient {
|
||||
abortController = new AbortController();
|
||||
}
|
||||
|
||||
/** @type {TCustomConfig['endpoints']['agents']} */
|
||||
const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
|
||||
const appConfig = this.options.req.config;
|
||||
/** @type {AppConfig['endpoints']['agents']} */
|
||||
const agentsEConfig = appConfig.endpoints?.[EModelEndpoint.agents];
|
||||
|
||||
config = {
|
||||
configurable: {
|
||||
@@ -768,6 +784,11 @@ class AgentClient extends BaseClient {
|
||||
last_agent_index: this.agentConfigs?.size ?? 0,
|
||||
user_id: this.user ?? this.options.req.user?.id,
|
||||
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
||||
requestBody: {
|
||||
messageId: this.responseMessageId,
|
||||
conversationId: this.conversationId,
|
||||
parentMessageId: this.parentMessageId,
|
||||
},
|
||||
user: this.options.req.user,
|
||||
},
|
||||
recursionLimit: agentsEConfig?.recursionLimit ?? 25,
|
||||
@@ -851,11 +872,10 @@ class AgentClient extends BaseClient {
|
||||
if (agent.useLegacyContent === true) {
|
||||
messages = formatContentStrings(messages);
|
||||
}
|
||||
if (
|
||||
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
|
||||
'prompt-caching',
|
||||
)
|
||||
) {
|
||||
const defaultHeaders =
|
||||
agent.model_parameters?.clientOptions?.defaultHeaders ??
|
||||
agent.model_parameters?.configuration?.defaultHeaders;
|
||||
if (defaultHeaders?.['anthropic-beta']?.includes('prompt-caching')) {
|
||||
messages = addCacheControl(messages);
|
||||
}
|
||||
|
||||
@@ -863,6 +883,16 @@ class AgentClient extends BaseClient {
|
||||
memoryPromise = this.runMemory(messages);
|
||||
}
|
||||
|
||||
/** Resolve request-based headers for Custom Endpoints. Note: if this is added to
|
||||
* non-custom endpoints, needs consideration of varying provider header configs.
|
||||
*/
|
||||
if (agent.model_parameters?.configuration?.defaultHeaders != null) {
|
||||
agent.model_parameters.configuration.defaultHeaders = resolveHeaders({
|
||||
headers: agent.model_parameters.configuration.defaultHeaders,
|
||||
body: config.configurable.requestBody,
|
||||
});
|
||||
}
|
||||
|
||||
run = await createRun({
|
||||
agent,
|
||||
req: this.options.req,
|
||||
@@ -898,21 +928,9 @@ class AgentClient extends BaseClient {
|
||||
run.Graph.contentData = contentData;
|
||||
}
|
||||
|
||||
try {
|
||||
if (await hasCustomUserVars()) {
|
||||
config.configurable.userMCPAuthMap = await getMCPAuthMap({
|
||||
tools: agent.tools,
|
||||
userId: this.options.req.user.id,
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`,
|
||||
err,
|
||||
);
|
||||
if (userMCPAuthMap != null) {
|
||||
config.configurable.userMCPAuthMap = userMCPAuthMap;
|
||||
}
|
||||
|
||||
await run.processStream({ messages }, config, {
|
||||
keepContent: i !== 0,
|
||||
tokenCounter: createTokenCounter(this.getEncoding()),
|
||||
@@ -1035,7 +1053,13 @@ class AgentClient extends BaseClient {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
|
||||
await this.recordCollectedUsage({ context: 'message' });
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
const transactionsConfig = getTransactionsConfig(appConfig);
|
||||
await this.recordCollectedUsage({
|
||||
context: 'message',
|
||||
balance: balanceConfig,
|
||||
transactions: transactionsConfig,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
|
||||
@@ -1076,19 +1100,21 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
|
||||
const { req, res, agent } = this.options;
|
||||
const appConfig = req.config;
|
||||
let endpoint = agent.endpoint;
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
let clientOptions = {
|
||||
maxTokens: 75,
|
||||
model: agent.model || agent.model_parameters.model,
|
||||
};
|
||||
|
||||
let titleProviderConfig = await getProviderConfig(endpoint);
|
||||
let titleProviderConfig = getProviderConfig({ provider: endpoint, appConfig });
|
||||
|
||||
/** @type {TEndpoint | undefined} */
|
||||
const endpointConfig =
|
||||
req.app.locals.all ?? req.app.locals[endpoint] ?? titleProviderConfig.customEndpointConfig;
|
||||
appConfig.endpoints?.all ??
|
||||
appConfig.endpoints?.[endpoint] ??
|
||||
titleProviderConfig.customEndpointConfig;
|
||||
if (!endpointConfig) {
|
||||
logger.warn(
|
||||
'[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config',
|
||||
@@ -1097,7 +1123,10 @@ class AgentClient extends BaseClient {
|
||||
|
||||
if (endpointConfig?.titleEndpoint && endpointConfig.titleEndpoint !== endpoint) {
|
||||
try {
|
||||
titleProviderConfig = await getProviderConfig(endpointConfig.titleEndpoint);
|
||||
titleProviderConfig = getProviderConfig({
|
||||
provider: endpointConfig.titleEndpoint,
|
||||
appConfig,
|
||||
});
|
||||
endpoint = endpointConfig.titleEndpoint;
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
@@ -1106,7 +1135,7 @@ class AgentClient extends BaseClient {
|
||||
);
|
||||
// Fall back to original provider config
|
||||
endpoint = agent.endpoint;
|
||||
titleProviderConfig = await getProviderConfig(endpoint);
|
||||
titleProviderConfig = getProviderConfig({ provider: endpoint, appConfig });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1147,15 +1176,13 @@ class AgentClient extends BaseClient {
|
||||
clientOptions.configuration = options.configOptions;
|
||||
}
|
||||
|
||||
const shouldRemoveMaxTokens = /\b(o\d|gpt-[5-9])\b/i.test(clientOptions.model);
|
||||
if (shouldRemoveMaxTokens && clientOptions.maxTokens != null) {
|
||||
if (clientOptions.maxTokens != null) {
|
||||
delete clientOptions.maxTokens;
|
||||
} else if (!shouldRemoveMaxTokens && !clientOptions.maxTokens) {
|
||||
clientOptions.maxTokens = 75;
|
||||
}
|
||||
if (shouldRemoveMaxTokens && clientOptions?.modelKwargs?.max_completion_tokens != null) {
|
||||
if (clientOptions?.modelKwargs?.max_completion_tokens != null) {
|
||||
delete clientOptions.modelKwargs.max_completion_tokens;
|
||||
} else if (shouldRemoveMaxTokens && clientOptions?.modelKwargs?.max_output_tokens != null) {
|
||||
}
|
||||
if (clientOptions?.modelKwargs?.max_output_tokens != null) {
|
||||
delete clientOptions.modelKwargs.max_output_tokens;
|
||||
}
|
||||
|
||||
@@ -1173,6 +1200,20 @@ class AgentClient extends BaseClient {
|
||||
clientOptions.json = true;
|
||||
}
|
||||
|
||||
/** Resolve request-based headers for Custom Endpoints. Note: if this is added to
|
||||
* non-custom endpoints, needs consideration of varying provider header configs.
|
||||
*/
|
||||
if (clientOptions?.configuration?.defaultHeaders != null) {
|
||||
clientOptions.configuration.defaultHeaders = resolveHeaders({
|
||||
headers: clientOptions.configuration.defaultHeaders,
|
||||
body: {
|
||||
messageId: this.responseMessageId,
|
||||
conversationId: this.conversationId,
|
||||
parentMessageId: this.parentMessageId,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
const titleResult = await this.run.generateTitle({
|
||||
provider,
|
||||
@@ -1211,10 +1252,14 @@ class AgentClient extends BaseClient {
|
||||
};
|
||||
});
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
const transactionsConfig = getTransactionsConfig(appConfig);
|
||||
await this.recordCollectedUsage({
|
||||
model: clientOptions.model,
|
||||
context: 'title',
|
||||
collectedUsage,
|
||||
context: 'title',
|
||||
model: clientOptions.model,
|
||||
balance: balanceConfig,
|
||||
transactions: transactionsConfig,
|
||||
}).catch((err) => {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage',
|
||||
@@ -1233,17 +1278,26 @@ class AgentClient extends BaseClient {
|
||||
* @param {object} params
|
||||
* @param {number} params.promptTokens
|
||||
* @param {number} params.completionTokens
|
||||
* @param {OpenAIUsageMetadata} [params.usage]
|
||||
* @param {string} [params.model]
|
||||
* @param {OpenAIUsageMetadata} [params.usage]
|
||||
* @param {AppConfig['balance']} [params.balance]
|
||||
* @param {string} [params.context='message']
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async recordTokenUsage({ model, promptTokens, completionTokens, usage, context = 'message' }) {
|
||||
async recordTokenUsage({
|
||||
model,
|
||||
usage,
|
||||
balance,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
context = 'message',
|
||||
}) {
|
||||
try {
|
||||
await spendTokens(
|
||||
{
|
||||
model,
|
||||
context,
|
||||
balance,
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
@@ -1260,6 +1314,7 @@ class AgentClient extends BaseClient {
|
||||
await spendTokens(
|
||||
{
|
||||
model,
|
||||
balance,
|
||||
context: 'reasoning',
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
|
||||
@@ -41,8 +41,16 @@ describe('AgentClient - titleConvo', () => {
|
||||
|
||||
// Mock request and response
|
||||
mockReq = {
|
||||
app: {
|
||||
locals: {
|
||||
user: {
|
||||
id: 'user-123',
|
||||
},
|
||||
body: {
|
||||
model: 'gpt-4',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
key: null,
|
||||
},
|
||||
config: {
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
// Match the agent endpoint
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
@@ -52,14 +60,6 @@ describe('AgentClient - titleConvo', () => {
|
||||
},
|
||||
},
|
||||
},
|
||||
user: {
|
||||
id: 'user-123',
|
||||
},
|
||||
body: {
|
||||
model: 'gpt-4',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
key: null,
|
||||
},
|
||||
};
|
||||
|
||||
mockRes = {};
|
||||
@@ -143,7 +143,7 @@ describe('AgentClient - titleConvo', () => {
|
||||
|
||||
it('should handle missing endpoint config gracefully', async () => {
|
||||
// Remove endpoint config
|
||||
mockReq.app.locals[EModelEndpoint.openAI] = undefined;
|
||||
mockReq.config = { endpoints: {} };
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
@@ -161,7 +161,16 @@ describe('AgentClient - titleConvo', () => {
|
||||
|
||||
it('should use agent model when titleModel is not provided', async () => {
|
||||
// Remove titleModel from config
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titleModel;
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
titlePrompt: 'Custom title prompt',
|
||||
titleMethod: 'structured',
|
||||
titlePromptTemplate: 'Template: {{content}}',
|
||||
// titleModel is omitted
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
@@ -173,7 +182,16 @@ describe('AgentClient - titleConvo', () => {
|
||||
});
|
||||
|
||||
it('should not use titleModel when it equals CURRENT_MODEL constant', async () => {
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleModel = Constants.CURRENT_MODEL;
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
titleModel: Constants.CURRENT_MODEL,
|
||||
titlePrompt: 'Custom title prompt',
|
||||
titleMethod: 'structured',
|
||||
titlePromptTemplate: 'Template: {{content}}',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
@@ -216,6 +234,12 @@ describe('AgentClient - titleConvo', () => {
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'title',
|
||||
collectedUsage: expect.any(Array),
|
||||
balance: {
|
||||
enabled: false,
|
||||
},
|
||||
transactions: {
|
||||
enabled: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -245,10 +269,17 @@ describe('AgentClient - titleConvo', () => {
|
||||
process.env.ANTHROPIC_API_KEY = 'test-api-key';
|
||||
|
||||
// Add titleEndpoint to the config
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleEndpoint = EModelEndpoint.anthropic;
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleMethod = 'structured';
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titlePrompt = 'Custom title prompt';
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titlePromptTemplate = 'Custom template';
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
titleEndpoint: EModelEndpoint.anthropic,
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Custom title prompt',
|
||||
titlePromptTemplate: 'Custom template',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
@@ -274,18 +305,16 @@ describe('AgentClient - titleConvo', () => {
|
||||
});
|
||||
|
||||
it('should use all config when endpoint config is missing', async () => {
|
||||
// Remove endpoint-specific config
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titleModel;
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titlePrompt;
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titleMethod;
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titlePromptTemplate;
|
||||
|
||||
// Set 'all' config
|
||||
mockReq.app.locals.all = {
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titlePrompt: 'All config title prompt',
|
||||
titleMethod: 'completion',
|
||||
titlePromptTemplate: 'All config template: {{content}}',
|
||||
// Set 'all' config without endpoint-specific config
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
all: {
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titlePrompt: 'All config title prompt',
|
||||
titleMethod: 'completion',
|
||||
titlePromptTemplate: 'All config template: {{content}}',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const text = 'Test conversation text';
|
||||
@@ -309,17 +338,21 @@ describe('AgentClient - titleConvo', () => {
|
||||
|
||||
it('should prioritize all config over endpoint config for title settings', async () => {
|
||||
// Set both endpoint and 'all' config
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleModel = 'gpt-3.5-turbo';
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titlePrompt = 'Endpoint title prompt';
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleMethod = 'structured';
|
||||
// Remove titlePromptTemplate from endpoint config to test fallback
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titlePromptTemplate;
|
||||
|
||||
mockReq.app.locals.all = {
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titlePrompt: 'All config title prompt',
|
||||
titleMethod: 'completion',
|
||||
titlePromptTemplate: 'All config template',
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
titlePrompt: 'Endpoint title prompt',
|
||||
titleMethod: 'structured',
|
||||
// titlePromptTemplate is omitted to test fallback
|
||||
},
|
||||
all: {
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titlePrompt: 'All config title prompt',
|
||||
titleMethod: 'completion',
|
||||
titlePromptTemplate: 'All config template',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const text = 'Test conversation text';
|
||||
@@ -346,17 +379,18 @@ describe('AgentClient - titleConvo', () => {
|
||||
const originalApiKey = process.env.ANTHROPIC_API_KEY;
|
||||
process.env.ANTHROPIC_API_KEY = 'test-anthropic-key';
|
||||
|
||||
// Remove endpoint-specific config to test 'all' config
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI];
|
||||
|
||||
// Set comprehensive 'all' config with all new title options
|
||||
mockReq.app.locals.all = {
|
||||
titleConvo: true,
|
||||
titleModel: 'claude-3-haiku-20240307',
|
||||
titleMethod: 'completion', // Testing the new default method
|
||||
titlePrompt: 'Generate a concise, descriptive title for this conversation',
|
||||
titlePromptTemplate: 'Conversation summary: {{content}}',
|
||||
titleEndpoint: EModelEndpoint.anthropic, // Should switch provider to Anthropic
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
all: {
|
||||
titleConvo: true,
|
||||
titleModel: 'claude-3-haiku-20240307',
|
||||
titleMethod: 'completion', // Testing the new default method
|
||||
titlePrompt: 'Generate a concise, descriptive title for this conversation',
|
||||
titlePromptTemplate: 'Conversation summary: {{content}}',
|
||||
titleEndpoint: EModelEndpoint.anthropic, // Should switch provider to Anthropic
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const text = 'Test conversation about AI and machine learning';
|
||||
@@ -402,15 +436,16 @@ describe('AgentClient - titleConvo', () => {
|
||||
// Clear previous calls
|
||||
mockRun.generateTitle.mockClear();
|
||||
|
||||
// Remove endpoint config
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI];
|
||||
|
||||
// Set 'all' config with specific titleMethod
|
||||
mockReq.app.locals.all = {
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titleMethod: method,
|
||||
titlePrompt: `Testing ${method} method`,
|
||||
titlePromptTemplate: `Template for ${method}: {{content}}`,
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
all: {
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titleMethod: method,
|
||||
titlePrompt: `Testing ${method} method`,
|
||||
titlePromptTemplate: `Template for ${method}: {{content}}`,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const text = `Test conversation for ${method} method`;
|
||||
@@ -455,29 +490,33 @@ describe('AgentClient - titleConvo', () => {
|
||||
// Set up Azure endpoint with serverless config
|
||||
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.provider = EModelEndpoint.azureOpenAI;
|
||||
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
|
||||
titleConvo: true,
|
||||
titleModel: 'grok-3',
|
||||
titleMethod: 'completion',
|
||||
titlePrompt: 'Azure serverless title prompt',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'grok-3': {
|
||||
group: 'Azure AI Foundry',
|
||||
deploymentName: 'grok-3',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'Azure AI Foundry': {
|
||||
apiKey: '${AZURE_API_KEY}',
|
||||
baseURL: 'https://test.services.ai.azure.com/models',
|
||||
version: '2024-05-01-preview',
|
||||
serverless: true,
|
||||
models: {
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
[EModelEndpoint.azureOpenAI]: {
|
||||
titleConvo: true,
|
||||
titleModel: 'grok-3',
|
||||
titleMethod: 'completion',
|
||||
titlePrompt: 'Azure serverless title prompt',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'grok-3': {
|
||||
group: 'Azure AI Foundry',
|
||||
deploymentName: 'grok-3',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'Azure AI Foundry': {
|
||||
apiKey: '${AZURE_API_KEY}',
|
||||
baseURL: 'https://test.services.ai.azure.com/models',
|
||||
version: '2024-05-01-preview',
|
||||
serverless: true,
|
||||
models: {
|
||||
'grok-3': {
|
||||
deploymentName: 'grok-3',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -503,28 +542,32 @@ describe('AgentClient - titleConvo', () => {
|
||||
// Set up Azure endpoint
|
||||
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.provider = EModelEndpoint.azureOpenAI;
|
||||
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4o',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Azure instance title prompt',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'gpt-4o': {
|
||||
group: 'eastus',
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
eastus: {
|
||||
apiKey: '${EASTUS_API_KEY}',
|
||||
instanceName: 'region-instance',
|
||||
version: '2024-02-15-preview',
|
||||
models: {
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
[EModelEndpoint.azureOpenAI]: {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4o',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Azure instance title prompt',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'gpt-4o': {
|
||||
group: 'eastus',
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
eastus: {
|
||||
apiKey: '${EASTUS_API_KEY}',
|
||||
instanceName: 'region-instance',
|
||||
version: '2024-02-15-preview',
|
||||
models: {
|
||||
'gpt-4o': {
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -551,29 +594,33 @@ describe('AgentClient - titleConvo', () => {
|
||||
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.provider = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.model_parameters.model = 'gpt-4o-latest';
|
||||
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
|
||||
titleConvo: true,
|
||||
titleModel: Constants.CURRENT_MODEL,
|
||||
titleMethod: 'functions',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'gpt-4o-latest': {
|
||||
group: 'region-eastus',
|
||||
deploymentName: 'gpt-4o-mini',
|
||||
version: '2024-02-15-preview',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'region-eastus': {
|
||||
apiKey: '${EASTUS2_API_KEY}',
|
||||
instanceName: 'test-instance',
|
||||
version: '2024-12-01-preview',
|
||||
models: {
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
[EModelEndpoint.azureOpenAI]: {
|
||||
titleConvo: true,
|
||||
titleModel: Constants.CURRENT_MODEL,
|
||||
titleMethod: 'functions',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'gpt-4o-latest': {
|
||||
group: 'region-eastus',
|
||||
deploymentName: 'gpt-4o-mini',
|
||||
version: '2024-02-15-preview',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'region-eastus': {
|
||||
apiKey: '${EASTUS2_API_KEY}',
|
||||
instanceName: 'test-instance',
|
||||
version: '2024-12-01-preview',
|
||||
models: {
|
||||
'gpt-4o-latest': {
|
||||
deploymentName: 'gpt-4o-mini',
|
||||
version: '2024-02-15-preview',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -598,56 +645,60 @@ describe('AgentClient - titleConvo', () => {
|
||||
// Set up Azure endpoint
|
||||
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.provider = EModelEndpoint.azureOpenAI;
|
||||
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
|
||||
titleConvo: true,
|
||||
titleModel: 'o1-mini',
|
||||
titleMethod: 'completion',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'gpt-4o': {
|
||||
group: 'eastus',
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
'o1-mini': {
|
||||
group: 'region-eastus',
|
||||
deploymentName: 'o1-mini',
|
||||
},
|
||||
'codex-mini': {
|
||||
group: 'codex-mini',
|
||||
deploymentName: 'codex-mini',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
eastus: {
|
||||
apiKey: '${EASTUS_API_KEY}',
|
||||
instanceName: 'region-eastus',
|
||||
version: '2024-02-15-preview',
|
||||
models: {
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
[EModelEndpoint.azureOpenAI]: {
|
||||
titleConvo: true,
|
||||
titleModel: 'o1-mini',
|
||||
titleMethod: 'completion',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'gpt-4o': {
|
||||
group: 'eastus',
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
},
|
||||
},
|
||||
'region-eastus': {
|
||||
apiKey: '${EASTUS2_API_KEY}',
|
||||
instanceName: 'region-eastus2',
|
||||
version: '2024-12-01-preview',
|
||||
models: {
|
||||
'o1-mini': {
|
||||
group: 'region-eastus',
|
||||
deploymentName: 'o1-mini',
|
||||
},
|
||||
},
|
||||
},
|
||||
'codex-mini': {
|
||||
apiKey: '${AZURE_API_KEY}',
|
||||
baseURL: 'https://example.cognitiveservices.azure.com/openai/',
|
||||
version: '2025-04-01-preview',
|
||||
serverless: true,
|
||||
models: {
|
||||
'codex-mini': {
|
||||
group: 'codex-mini',
|
||||
deploymentName: 'codex-mini',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
eastus: {
|
||||
apiKey: '${EASTUS_API_KEY}',
|
||||
instanceName: 'region-eastus',
|
||||
version: '2024-02-15-preview',
|
||||
models: {
|
||||
'gpt-4o': {
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
},
|
||||
},
|
||||
'region-eastus': {
|
||||
apiKey: '${EASTUS2_API_KEY}',
|
||||
instanceName: 'region-eastus2',
|
||||
version: '2024-12-01-preview',
|
||||
models: {
|
||||
'o1-mini': {
|
||||
deploymentName: 'o1-mini',
|
||||
},
|
||||
},
|
||||
},
|
||||
'codex-mini': {
|
||||
apiKey: '${AZURE_API_KEY}',
|
||||
baseURL: 'https://example.cognitiveservices.azure.com/openai/',
|
||||
version: '2025-04-01-preview',
|
||||
serverless: true,
|
||||
models: {
|
||||
'codex-mini': {
|
||||
deploymentName: 'codex-mini',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -679,33 +730,34 @@ describe('AgentClient - titleConvo', () => {
|
||||
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockReq.body.model = 'gpt-4';
|
||||
|
||||
// Remove Azure-specific config
|
||||
delete mockReq.app.locals[EModelEndpoint.azureOpenAI];
|
||||
|
||||
// Set 'all' config as fallback with a serverless Azure config
|
||||
mockReq.app.locals.all = {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Fallback title prompt from all config',
|
||||
titlePromptTemplate: 'Template: {{content}}',
|
||||
modelGroupMap: {
|
||||
'gpt-4': {
|
||||
group: 'default-group',
|
||||
deploymentName: 'gpt-4',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'default-group': {
|
||||
apiKey: '${AZURE_API_KEY}',
|
||||
baseURL: 'https://default.openai.azure.com/',
|
||||
version: '2024-02-15-preview',
|
||||
serverless: true,
|
||||
models: {
|
||||
mockReq.config = {
|
||||
endpoints: {
|
||||
all: {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Fallback title prompt from all config',
|
||||
titlePromptTemplate: 'Template: {{content}}',
|
||||
modelGroupMap: {
|
||||
'gpt-4': {
|
||||
group: 'default-group',
|
||||
deploymentName: 'gpt-4',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'default-group': {
|
||||
apiKey: '${AZURE_API_KEY}',
|
||||
baseURL: 'https://default.openai.azure.com/',
|
||||
version: '2024-02-15-preview',
|
||||
serverless: true,
|
||||
models: {
|
||||
'gpt-4': {
|
||||
deploymentName: 'gpt-4',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -982,13 +1034,6 @@ describe('AgentClient - titleConvo', () => {
|
||||
};
|
||||
|
||||
mockReq = {
|
||||
app: {
|
||||
locals: {
|
||||
memory: {
|
||||
messageWindowSize: 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
user: {
|
||||
id: 'user-123',
|
||||
personalization: {
|
||||
@@ -997,6 +1042,13 @@ describe('AgentClient - titleConvo', () => {
|
||||
},
|
||||
};
|
||||
|
||||
// Mock getAppConfig for memory tests
|
||||
mockReq.config = {
|
||||
memory: {
|
||||
messageWindowSize: 3,
|
||||
},
|
||||
};
|
||||
|
||||
mockRes = {};
|
||||
|
||||
mockOptions = {
|
||||
|
||||
@@ -21,7 +21,7 @@ const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerDependencies
|
||||
* @property {Express.Request} req - The Express request object
|
||||
* @property {ServerRequest} req - The Express request object
|
||||
* @property {Express.Response} res - The Express response object
|
||||
* @property {() => ErrorHandlerContext} getContext - Function to get the current context
|
||||
* @property {string} [originPath] - The origin path for the error handler
|
||||
|
||||
@@ -9,6 +9,24 @@ const {
|
||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||
const { saveMessage } = require('~/models');
|
||||
|
||||
function createCloseHandler(abortController) {
|
||||
return function (manual) {
|
||||
if (!manual) {
|
||||
logger.debug('[AgentController] Request closed');
|
||||
}
|
||||
if (!abortController) {
|
||||
return;
|
||||
} else if (abortController.signal.aborted) {
|
||||
return;
|
||||
} else if (abortController.requestCompleted) {
|
||||
return;
|
||||
}
|
||||
|
||||
abortController.abort();
|
||||
logger.debug('[AgentController] Request aborted on close');
|
||||
};
|
||||
}
|
||||
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
text,
|
||||
@@ -31,7 +49,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let userMessagePromise;
|
||||
let getAbortData;
|
||||
let client = null;
|
||||
// Initialize as an array
|
||||
let cleanupHandlers = [];
|
||||
|
||||
const newConvo = !conversationId;
|
||||
@@ -62,9 +79,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
// Create a function to handle final cleanup
|
||||
const performCleanup = () => {
|
||||
logger.debug('[AgentController] Performing cleanup');
|
||||
// Make sure cleanupHandlers is an array before iterating
|
||||
if (Array.isArray(cleanupHandlers)) {
|
||||
// Execute all cleanup handlers
|
||||
for (const handler of cleanupHandlers) {
|
||||
try {
|
||||
if (typeof handler === 'function') {
|
||||
@@ -105,8 +120,33 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
};
|
||||
|
||||
try {
|
||||
/** @type {{ client: TAgentClient }} */
|
||||
const result = await initializeClient({ req, res, endpointOption });
|
||||
let prelimAbortController = new AbortController();
|
||||
const prelimCloseHandler = createCloseHandler(prelimAbortController);
|
||||
res.on('close', prelimCloseHandler);
|
||||
const removePrelimHandler = (manual) => {
|
||||
try {
|
||||
prelimCloseHandler(manual);
|
||||
res.removeListener('close', prelimCloseHandler);
|
||||
} catch (e) {
|
||||
logger.error('[AgentController] Error removing close listener', e);
|
||||
}
|
||||
};
|
||||
cleanupHandlers.push(removePrelimHandler);
|
||||
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
|
||||
const result = await initializeClient({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
signal: prelimAbortController.signal,
|
||||
});
|
||||
if (prelimAbortController.signal?.aborted) {
|
||||
prelimAbortController = null;
|
||||
throw new Error('Request was aborted before initialization could complete');
|
||||
} else {
|
||||
prelimAbortController = null;
|
||||
removePrelimHandler(true);
|
||||
cleanupHandlers.pop();
|
||||
}
|
||||
client = result.client;
|
||||
|
||||
// Register client with finalization registry if available
|
||||
@@ -138,22 +178,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
};
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
// Simple handler to avoid capturing scope
|
||||
const closeHandler = () => {
|
||||
logger.debug('[AgentController] Request closed');
|
||||
if (!abortController) {
|
||||
return;
|
||||
} else if (abortController.signal.aborted) {
|
||||
return;
|
||||
} else if (abortController.requestCompleted) {
|
||||
return;
|
||||
}
|
||||
|
||||
abortController.abort();
|
||||
logger.debug('[AgentController] Request aborted on close');
|
||||
};
|
||||
|
||||
const closeHandler = createCloseHandler(abortController);
|
||||
res.on('close', closeHandler);
|
||||
cleanupHandlers.push(() => {
|
||||
try {
|
||||
@@ -175,6 +200,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
userMCPAuthMap: result.userMCPAuthMap,
|
||||
responseMessageId: editedResponseMessageId,
|
||||
progressOptions: {
|
||||
res,
|
||||
|
||||
@@ -5,6 +5,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
SystemRoles,
|
||||
FileSources,
|
||||
ResourceType,
|
||||
@@ -69,9 +70,9 @@ const createAgentHandler = async (req, res) => {
|
||||
for (const tool of tools) {
|
||||
if (availableTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
}
|
||||
|
||||
if (systemTools[tool]) {
|
||||
} else if (systemTools[tool]) {
|
||||
agentData.tools.push(tool);
|
||||
} else if (tool.includes(Constants.mcp_delimiter)) {
|
||||
agentData.tools.push(tool);
|
||||
}
|
||||
}
|
||||
@@ -487,6 +488,7 @@ const getListAgentsHandler = async (req, res) => {
|
||||
*/
|
||||
const uploadAgentAvatarHandler = async (req, res) => {
|
||||
try {
|
||||
const appConfig = req.config;
|
||||
filterFile({ req, file: req.file, image: true, isAvatar: true });
|
||||
const { agent_id } = req.params;
|
||||
if (!agent_id) {
|
||||
@@ -510,9 +512,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
}
|
||||
|
||||
const buffer = await fs.readFile(req.file.path);
|
||||
|
||||
const fileStrategy = getFileStrategy(req.app.locals, { isAvatar: true });
|
||||
|
||||
const fileStrategy = getFileStrategy(appConfig, { isAvatar: true });
|
||||
const resizedBuffer = await resizeAvatar({
|
||||
userId: req.user.id,
|
||||
input: buffer,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -34,7 +34,6 @@ const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
|
||||
/**
|
||||
@@ -47,6 +46,7 @@ const { getOpenAIClient } = require('./helpers');
|
||||
* @returns {void}
|
||||
*/
|
||||
const chatV1 = async (req, res) => {
|
||||
const appConfig = req.config;
|
||||
logger.debug('[/assistants/chat/] req.body', req.body);
|
||||
|
||||
const {
|
||||
@@ -251,8 +251,8 @@ const chatV1 = async (req, res) => {
|
||||
}
|
||||
|
||||
const checkBalanceBeforeRun = async () => {
|
||||
const balance = req.app?.locals?.balance;
|
||||
if (!balance?.enabled) {
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
if (!balanceConfig?.enabled) {
|
||||
return;
|
||||
}
|
||||
const transactions =
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -31,19 +31,19 @@ const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
* @desc Chat with an assistant
|
||||
* @access Public
|
||||
* @param {Express.Request} req - The request object, containing the request data.
|
||||
* @param {ServerRequest} req - The request object, containing the request data.
|
||||
* @param {Express.Response} res - The response object, used to send back a response.
|
||||
* @returns {void}
|
||||
*/
|
||||
const chatV2 = async (req, res) => {
|
||||
logger.debug('[/assistants/chat/] req.body', req.body);
|
||||
const appConfig = req.config;
|
||||
|
||||
/** @type {{files: MongoFile[]}} */
|
||||
const {
|
||||
@@ -126,8 +126,8 @@ const chatV2 = async (req, res) => {
|
||||
}
|
||||
|
||||
const checkBalanceBeforeRun = async () => {
|
||||
const balance = req.app?.locals?.balance;
|
||||
if (!balance?.enabled) {
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
if (!balanceConfig?.enabled) {
|
||||
return;
|
||||
}
|
||||
const transactions =
|
||||
@@ -374,9 +374,9 @@ const chatV2 = async (req, res) => {
|
||||
};
|
||||
|
||||
/** @type {undefined | TAssistantEndpoint} */
|
||||
const config = req.app.locals[endpoint] ?? {};
|
||||
const config = appConfig.endpoints?.[endpoint] ?? {};
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const allConfig = req.app.locals.all;
|
||||
const allConfig = appConfig.endpoints?.all;
|
||||
|
||||
const streamRunManager = new StreamRunManager({
|
||||
req,
|
||||
|
||||
@@ -22,7 +22,7 @@ const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerDependencies
|
||||
* @property {Express.Request} req - The Express request object
|
||||
* @property {ServerRequest} req - The Express request object
|
||||
* @property {Express.Response} res - The Express response object
|
||||
* @property {() => ErrorHandlerContext} getContext - Function to get the current context
|
||||
* @property {string} [originPath] - The origin path for the error handler
|
||||
|
||||
@@ -11,7 +11,7 @@ const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { getEndpointsConfig } = require('~/server/services/Config');
|
||||
|
||||
/**
|
||||
* @param {Express.Request} req
|
||||
* @param {ServerRequest} req
|
||||
* @param {string} [endpoint]
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
@@ -210,6 +210,7 @@ async function getOpenAIClient({ req, res, endpointOption, initAppClient, overri
|
||||
* @returns {Promise<AssistantListResponse>} 200 - success response - application/json
|
||||
*/
|
||||
const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
|
||||
const appConfig = req.config;
|
||||
const {
|
||||
limit = 100,
|
||||
order = 'desc',
|
||||
@@ -230,20 +231,20 @@ const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
|
||||
if (endpoint === EModelEndpoint.assistants) {
|
||||
({ body } = await listAllAssistants({ req, res, version, query }));
|
||||
} else if (endpoint === EModelEndpoint.azureAssistants) {
|
||||
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
|
||||
const azureConfig = appConfig.endpoints?.[EModelEndpoint.azureOpenAI];
|
||||
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
|
||||
}
|
||||
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
return body;
|
||||
} else if (!req.app.locals[endpoint]) {
|
||||
} else if (!appConfig.endpoints?.[endpoint]) {
|
||||
return body;
|
||||
}
|
||||
|
||||
body.data = filterAssistants({
|
||||
userId: req.user.id,
|
||||
assistants: body.data,
|
||||
assistantsConfig: req.app.locals[endpoint],
|
||||
assistantsConfig: appConfig.endpoints?.[endpoint],
|
||||
});
|
||||
return body;
|
||||
};
|
||||
|
||||
@@ -258,8 +258,9 @@ function filterAssistantDocs({ documents, userId, assistantsConfig = {} }) {
|
||||
*/
|
||||
const getAssistantDocuments = async (req, res) => {
|
||||
try {
|
||||
const appConfig = req.config;
|
||||
const endpoint = req.query;
|
||||
const assistantsConfig = req.app.locals[endpoint];
|
||||
const assistantsConfig = appConfig.endpoints?.[endpoint];
|
||||
const documents = await getAssistants(
|
||||
{},
|
||||
{
|
||||
@@ -296,6 +297,7 @@ const getAssistantDocuments = async (req, res) => {
|
||||
*/
|
||||
const uploadAssistantAvatar = async (req, res) => {
|
||||
try {
|
||||
const appConfig = req.config;
|
||||
filterFile({ req, file: req.file, image: true, isAvatar: true });
|
||||
const { assistant_id } = req.params;
|
||||
if (!assistant_id) {
|
||||
@@ -337,7 +339,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
||||
const metadata = {
|
||||
..._metadata,
|
||||
avatar: image.filepath,
|
||||
avatar_source: req.app.locals.fileStrategy,
|
||||
avatar_source: appConfig.fileStrategy,
|
||||
};
|
||||
|
||||
const promises = [];
|
||||
@@ -347,7 +349,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
||||
{
|
||||
avatar: {
|
||||
filepath: image.filepath,
|
||||
source: req.app.locals.fileStrategy,
|
||||
source: appConfig.fileStrategy,
|
||||
},
|
||||
user: req.user.id,
|
||||
},
|
||||
|
||||
@@ -94,7 +94,7 @@ const createAssistant = async (req, res) => {
|
||||
/**
|
||||
* Modifies an assistant.
|
||||
* @param {object} params
|
||||
* @param {Express.Request} params.req
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {OpenAIClient} params.openai
|
||||
* @param {string} params.assistant_id
|
||||
* @param {AssistantUpdateParams} params.updateData
|
||||
@@ -199,7 +199,7 @@ const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
|
||||
/**
|
||||
* Modifies an assistant with the resource file id.
|
||||
* @param {object} params
|
||||
* @param {Express.Request} params.req
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {OpenAIClient} params.openai
|
||||
* @param {string} params.assistant_id
|
||||
* @param {string} params.tool_resource
|
||||
@@ -227,7 +227,7 @@ const addResourceFileId = async ({ req, openai, assistant_id, tool_resource, fil
|
||||
/**
|
||||
* Deletes a file ID from an assistant's resource.
|
||||
* @param {object} params
|
||||
* @param {Express.Request} params.req
|
||||
* @param {ServerRequest} params.req
|
||||
* @param {OpenAIClient} params.openai
|
||||
* @param {string} params.assistant_id
|
||||
* @param {string} [params.tool_resource]
|
||||
|
||||
@@ -22,10 +22,11 @@ const verify2FAWithTempToken = async (req, res) => {
|
||||
try {
|
||||
payload = jwt.verify(tempToken, process.env.JWT_SECRET);
|
||||
} catch (err) {
|
||||
logger.error('Failed to verify temporary token:', err);
|
||||
return res.status(401).json({ message: 'Invalid or expired temporary token' });
|
||||
}
|
||||
|
||||
const user = await getUserById(payload.userId);
|
||||
const user = await getUserById(payload.userId, '+totpSecret +backupCodes');
|
||||
if (!user || !user.twoFactorEnabled) {
|
||||
return res.status(400).json({ message: '2FA is not enabled for this user' });
|
||||
}
|
||||
@@ -42,11 +43,11 @@ const verify2FAWithTempToken = async (req, res) => {
|
||||
return res.status(401).json({ message: 'Invalid 2FA code or backup code' });
|
||||
}
|
||||
|
||||
// Prepare user data to return (omit sensitive fields).
|
||||
const userData = user.toObject ? user.toObject() : { ...user };
|
||||
delete userData.password;
|
||||
delete userData.__v;
|
||||
delete userData.password;
|
||||
delete userData.totpSecret;
|
||||
delete userData.backupCodes;
|
||||
userData.id = user._id.toString();
|
||||
|
||||
const authToken = await setAuthTokens(user._id, res);
|
||||
|
||||
@@ -35,9 +35,10 @@ const toolAccessPermType = {
|
||||
*/
|
||||
const verifyWebSearchAuth = async (req, res) => {
|
||||
try {
|
||||
const appConfig = req.config;
|
||||
const userId = req.user.id;
|
||||
/** @type {TCustomConfig['webSearch']} */
|
||||
const webSearchConfig = req.app.locals?.webSearch || {};
|
||||
const webSearchConfig = appConfig?.webSearch || {};
|
||||
const result = await loadWebSearchAuth({
|
||||
userId,
|
||||
loadAuthValues,
|
||||
@@ -110,6 +111,7 @@ const verifyToolAuth = async (req, res) => {
|
||||
*/
|
||||
const callTool = async (req, res) => {
|
||||
try {
|
||||
const appConfig = req.config;
|
||||
const { toolId = '' } = req.params;
|
||||
if (!fieldsMap[toolId]) {
|
||||
logger.warn(`[${toolId}/call] User ${req.user.id} attempted call to invalid tool`);
|
||||
@@ -155,8 +157,10 @@ const callTool = async (req, res) => {
|
||||
returnMetadata: true,
|
||||
processFileURL,
|
||||
uploadImageBuffer,
|
||||
fileStrategy: req.app.locals.fileStrategy,
|
||||
},
|
||||
webSearch: appConfig.webSearch,
|
||||
fileStrategy: appConfig.fileStrategy,
|
||||
imageOutputType: appConfig.imageOutputType,
|
||||
});
|
||||
|
||||
const tool = loadedTools[0];
|
||||
|
||||
@@ -12,13 +12,16 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { isEnabled, ErrorController } = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const { updateInterfacePermissions } = require('~/models/interface');
|
||||
const { checkMigrations } = require('./services/start/migration');
|
||||
const initializeMCPs = require('./services/initializeMCPs');
|
||||
const configureSocialLogins = require('./socialLogins');
|
||||
const AppService = require('./services/AppService');
|
||||
const { getAppConfig } = require('./services/Config');
|
||||
const staticCache = require('./utils/staticCache');
|
||||
const noIndex = require('./middleware/noIndex');
|
||||
const { seedDatabase } = require('~/models');
|
||||
const routes = require('./routes');
|
||||
|
||||
const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION, TRUST_PROXY } = process.env ?? {};
|
||||
@@ -44,10 +47,25 @@ const startServer = async () => {
|
||||
app.disable('x-powered-by');
|
||||
app.set('trust proxy', trusted_proxy);
|
||||
|
||||
await AppService(app);
|
||||
await seedDatabase();
|
||||
|
||||
const indexPath = path.join(app.locals.paths.dist, 'index.html');
|
||||
const indexHTML = fs.readFileSync(indexPath, 'utf8');
|
||||
const appConfig = await getAppConfig();
|
||||
await updateInterfacePermissions(appConfig);
|
||||
const indexPath = path.join(appConfig.paths.dist, 'index.html');
|
||||
let indexHTML = fs.readFileSync(indexPath, 'utf8');
|
||||
|
||||
// In order to provide support to serving the application in a sub-directory
|
||||
// We need to update the base href if the DOMAIN_CLIENT is specified and not the root path
|
||||
if (process.env.DOMAIN_CLIENT) {
|
||||
const clientUrl = new URL(process.env.DOMAIN_CLIENT);
|
||||
const baseHref = clientUrl.pathname.endsWith('/')
|
||||
? clientUrl.pathname
|
||||
: `${clientUrl.pathname}/`;
|
||||
if (baseHref !== '/') {
|
||||
logger.info(`Setting base href to ${baseHref}`);
|
||||
indexHTML = indexHTML.replace(/base href="\/"/, `base href="${baseHref}"`);
|
||||
}
|
||||
}
|
||||
|
||||
app.get('/health', (_req, res) => res.status(200).send('OK'));
|
||||
|
||||
@@ -65,10 +83,9 @@ const startServer = async () => {
|
||||
console.warn('Response compression has been disabled via DISABLE_COMPRESSION.');
|
||||
}
|
||||
|
||||
// Serve static assets with aggressive caching
|
||||
app.use(staticCache(app.locals.paths.dist));
|
||||
app.use(staticCache(app.locals.paths.fonts));
|
||||
app.use(staticCache(app.locals.paths.assets));
|
||||
app.use(staticCache(appConfig.paths.dist));
|
||||
app.use(staticCache(appConfig.paths.fonts));
|
||||
app.use(staticCache(appConfig.paths.assets));
|
||||
|
||||
if (!ALLOW_SOCIAL_LOGIN) {
|
||||
console.warn('Social logins are disabled. Set ALLOW_SOCIAL_LOGIN=true to enable them.');
|
||||
@@ -109,7 +126,7 @@ const startServer = async () => {
|
||||
app.use('/api/config', routes.config);
|
||||
app.use('/api/assistants', routes.assistants);
|
||||
app.use('/api/files', await routes.files.initialize());
|
||||
app.use('/images/', validateImageRequest, routes.staticRoute);
|
||||
app.use('/images/', createValidateImageRequest(appConfig.secureImageLinks), routes.staticRoute);
|
||||
app.use('/api/share', routes.share);
|
||||
app.use('/api/roles', routes.roles);
|
||||
app.use('/api/agents', routes.agents);
|
||||
@@ -131,7 +148,8 @@ const startServer = async () => {
|
||||
|
||||
const lang = req.cookies.lang || req.headers['accept-language']?.split(',')[0] || 'en-US';
|
||||
const saneLang = lang.replace(/"/g, '"');
|
||||
const updatedIndexHtml = indexHTML.replace(/lang="en-US"/g, `lang="${saneLang}"`);
|
||||
let updatedIndexHtml = indexHTML.replace(/lang="en-US"/g, `lang="${saneLang}"`);
|
||||
|
||||
res.type('html');
|
||||
res.send(updatedIndexHtml);
|
||||
});
|
||||
@@ -145,7 +163,7 @@ const startServer = async () => {
|
||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
initializeMCPs(app);
|
||||
initializeMCPs().then(() => checkMigrations());
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -3,9 +3,27 @@ const request = require('supertest');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
jest.mock('~/server/services/Config/loadCustomConfig', () => {
|
||||
return jest.fn(() => Promise.resolve({}));
|
||||
});
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
loadCustomConfig: jest.fn(() => Promise.resolve({})),
|
||||
getAppConfig: jest.fn().mockResolvedValue({
|
||||
paths: {
|
||||
uploads: '/tmp',
|
||||
dist: '/tmp/dist',
|
||||
fonts: '/tmp/fonts',
|
||||
assets: '/tmp/assets',
|
||||
},
|
||||
fileStrategy: 'local',
|
||||
imageOutputType: 'PNG',
|
||||
}),
|
||||
setCachedTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/app/clients/tools', () => ({
|
||||
createOpenAIImageTools: jest.fn(() => []),
|
||||
createYouTubeTools: jest.fn(() => []),
|
||||
manifestToolMap: {},
|
||||
toolkits: [],
|
||||
}));
|
||||
|
||||
describe('Server Configuration', () => {
|
||||
// Increase the default timeout to allow for Mongo cleanup
|
||||
@@ -31,6 +49,22 @@ describe('Server Configuration', () => {
|
||||
});
|
||||
|
||||
beforeAll(async () => {
|
||||
// Create the required directories and files for the test
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
const dirs = ['/tmp/dist', '/tmp/fonts', '/tmp/assets'];
|
||||
dirs.forEach((dir) => {
|
||||
if (!fs.existsSync(dir)) {
|
||||
fs.mkdirSync(dir, { recursive: true });
|
||||
}
|
||||
});
|
||||
|
||||
fs.writeFileSync(
|
||||
path.join('/tmp/dist', 'index.html'),
|
||||
'<!DOCTYPE html><html><head><title>LibreChat</title></head><body><div id="root"></div></body></html>',
|
||||
);
|
||||
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
process.env.MONGO_URI = mongoServer.getUri();
|
||||
process.env.PORT = '0'; // Use a random available port
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { countTokens, isEnabled, sendEvent } = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||
const { isAssistantsEndpoint, ErrorTypes, Constants } = require('librechat-data-provider');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
@@ -11,6 +11,10 @@ const { abortRun } = require('./abortRun');
|
||||
|
||||
const abortDataMap = new WeakMap();
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {boolean}
|
||||
*/
|
||||
function cleanupAbortController(abortKey) {
|
||||
if (!abortControllers.has(abortKey)) {
|
||||
return false;
|
||||
@@ -71,6 +75,20 @@ function cleanupAbortController(abortKey) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {function(): void}
|
||||
*/
|
||||
function createCleanUpHandler(abortKey) {
|
||||
return function () {
|
||||
try {
|
||||
cleanupAbortController(abortKey);
|
||||
} catch {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
let { abortKey, endpoint } = req.body;
|
||||
|
||||
@@ -172,11 +190,15 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
/**
|
||||
* @param {TMessage} userMessage
|
||||
* @param {string} responseMessageId
|
||||
* @param {boolean} [isNewConvo]
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId) => {
|
||||
const onStart = (userMessage, responseMessageId, isNewConvo) => {
|
||||
sendEvent(res, { message: userMessage, created: true });
|
||||
|
||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const prelimAbortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const abortKey = isNewConvo
|
||||
? `${prelimAbortKey}${Constants.COMMON_DIVIDER}${Constants.NEW_CONVO}`
|
||||
: prelimAbortKey;
|
||||
getReqData({ abortKey });
|
||||
const prevRequest = abortControllers.get(abortKey);
|
||||
const { overrideUserMessageId } = req?.body ?? {};
|
||||
@@ -194,16 +216,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
};
|
||||
|
||||
abortControllers.set(addedAbortKey, { abortController, ...minimalOptions });
|
||||
|
||||
// Use a simple function for cleanup to avoid capturing context
|
||||
const cleanupHandler = () => {
|
||||
try {
|
||||
cleanupAbortController(addedAbortKey);
|
||||
} catch (e) {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
|
||||
const cleanupHandler = createCleanUpHandler(addedAbortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
return;
|
||||
}
|
||||
@@ -216,16 +229,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
};
|
||||
|
||||
abortControllers.set(abortKey, { abortController, ...minimalOptions });
|
||||
|
||||
// Use a simple function for cleanup to avoid capturing context
|
||||
const cleanupHandler = () => {
|
||||
try {
|
||||
cleanupAbortController(abortKey);
|
||||
} catch (e) {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
|
||||
const cleanupHandler = createCleanUpHandler(abortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
};
|
||||
|
||||
@@ -364,15 +368,7 @@ const handleAbortError = async (res, req, error, data) => {
|
||||
};
|
||||
}
|
||||
|
||||
// Create a simple callback without capturing parent scope
|
||||
const callback = async () => {
|
||||
try {
|
||||
cleanupAbortController(conversationId);
|
||||
} catch (e) {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
|
||||
const callback = createCleanUpHandler(conversationId);
|
||||
await sendError(req, res, options, callback);
|
||||
};
|
||||
|
||||
|
||||
@@ -12,8 +12,9 @@ const { handleAbortError } = require('~/server/middleware/abortMiddleware');
|
||||
const validateAssistant = async (req, res, next) => {
|
||||
const { endpoint, conversationId, assistant_id, messageId } = req.body;
|
||||
|
||||
const appConfig = req.config;
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals?.[endpoint];
|
||||
const assistantsConfig = appConfig.endpoints?.[endpoint];
|
||||
if (!assistantsConfig) {
|
||||
return next();
|
||||
}
|
||||
|
||||
@@ -20,8 +20,9 @@ const validateAuthor = async ({ req, openai, overrideEndpoint, overrideAssistant
|
||||
const assistant_id =
|
||||
overrideAssistantId ?? req.params.id ?? req.body.assistant_id ?? req.query.assistant_id;
|
||||
|
||||
const appConfig = req.config;
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals?.[endpoint];
|
||||
const assistantsConfig = appConfig.endpoints?.[endpoint];
|
||||
if (!assistantsConfig) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -40,9 +40,10 @@ async function buildEndpointOption(req, res, next) {
|
||||
return handleError(res, { text: 'Error parsing conversation' });
|
||||
}
|
||||
|
||||
if (req.app.locals.modelSpecs?.list && req.app.locals.modelSpecs?.enforce) {
|
||||
const appConfig = req.config;
|
||||
if (appConfig.modelSpecs?.list && appConfig.modelSpecs?.enforce) {
|
||||
/** @type {{ list: TModelSpec[] }}*/
|
||||
const { list } = req.app.locals.modelSpecs;
|
||||
const { list } = appConfig.modelSpecs;
|
||||
const { spec } = parsedBody;
|
||||
|
||||
if (!spec) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEmailDomainAllowed } = require('~/server/services/domains');
|
||||
const { logger } = require('~/config');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
||||
/**
|
||||
* Checks the domain's social login is allowed
|
||||
@@ -10,15 +11,25 @@ const { logger } = require('~/config');
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {Function} next - Next middleware function.
|
||||
*
|
||||
* @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if the domain's email is allowed
|
||||
* @returns {Promise<void>} - Calls next middleware if the domain's email is allowed, otherwise redirects to login
|
||||
*/
|
||||
const checkDomainAllowed = async (req, res, next = () => {}) => {
|
||||
const email = req?.user?.email;
|
||||
if (email && !(await isEmailDomainAllowed(email))) {
|
||||
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
|
||||
return res.redirect('/login');
|
||||
} else {
|
||||
return next();
|
||||
const checkDomainAllowed = async (req, res, next) => {
|
||||
try {
|
||||
const email = req?.user?.email;
|
||||
const appConfig = await getAppConfig({
|
||||
role: req?.user?.role,
|
||||
});
|
||||
|
||||
if (email && !isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(`[Social Login] [Social Login not allowed] [Email: ${email}]`);
|
||||
res.redirect('/login');
|
||||
return;
|
||||
}
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error('[checkDomainAllowed] Error checking domain:', error);
|
||||
res.redirect('/login');
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
27
api/server/middleware/config/app.js
Normal file
27
api/server/middleware/config/app.js
Normal file
@@ -0,0 +1,27 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
|
||||
const configMiddleware = async (req, res, next) => {
|
||||
try {
|
||||
const userRole = req.user?.role;
|
||||
req.config = await getAppConfig({ role: userRole });
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error('Config middleware error:', {
|
||||
error: error.message,
|
||||
userRole: req.user?.role,
|
||||
path: req.path,
|
||||
});
|
||||
|
||||
try {
|
||||
req.config = await getAppConfig();
|
||||
next();
|
||||
} catch (fallbackError) {
|
||||
logger.error('Fallback config middleware error:', fallbackError);
|
||||
next(fallbackError);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = configMiddleware;
|
||||
@@ -82,7 +82,7 @@ const sendError = async (req, res, options, callback) => {
|
||||
|
||||
/**
|
||||
* Sends the response based on whether headers have been sent or not.
|
||||
* @param {Express.Request} req - The server response.
|
||||
* @param {ServerRequest} req - The server response.
|
||||
* @param {Express.Response} res - The server response.
|
||||
* @param {Object} data - The data to be sent.
|
||||
* @param {string} [errorMessage] - The error message, if any.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
const validatePasswordReset = require('./validatePasswordReset');
|
||||
const validateRegistration = require('./validateRegistration');
|
||||
const validateImageRequest = require('./validateImageRequest');
|
||||
const buildEndpointOption = require('./buildEndpointOption');
|
||||
const validateMessageReq = require('./validateMessageReq');
|
||||
const checkDomainAllowed = require('./checkDomainAllowed');
|
||||
@@ -9,11 +8,11 @@ const validateEndpoint = require('./validateEndpoint');
|
||||
const requireLocalAuth = require('./requireLocalAuth');
|
||||
const canDeleteAccount = require('./canDeleteAccount');
|
||||
const accessResources = require('./accessResources');
|
||||
const setBalanceConfig = require('./setBalanceConfig');
|
||||
const requireLdapAuth = require('./requireLdapAuth');
|
||||
const abortMiddleware = require('./abortMiddleware');
|
||||
const checkInviteUser = require('./checkInviteUser');
|
||||
const requireJwtAuth = require('./requireJwtAuth');
|
||||
const configMiddleware = require('./config/app');
|
||||
const validateModel = require('./validateModel');
|
||||
const moderateText = require('./moderateText');
|
||||
const logHeaders = require('./logHeaders');
|
||||
@@ -44,12 +43,11 @@ module.exports = {
|
||||
requireLocalAuth,
|
||||
canDeleteAccount,
|
||||
validateEndpoint,
|
||||
setBalanceConfig,
|
||||
configMiddleware,
|
||||
concurrentLimiter,
|
||||
checkDomainAllowed,
|
||||
validateMessageReq,
|
||||
buildEndpointOption,
|
||||
validateRegistration,
|
||||
validateImageRequest,
|
||||
validatePasswordReset,
|
||||
};
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
/**
|
||||
* Middleware to synchronize user balance settings with current balance configuration.
|
||||
* @function
|
||||
* @param {Object} req - Express request object containing user information.
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {import('express').NextFunction} next - Next middleware function.
|
||||
*/
|
||||
const setBalanceConfig = async (req, res, next) => {
|
||||
try {
|
||||
const balanceConfig = await getBalanceConfig();
|
||||
if (!balanceConfig?.enabled) {
|
||||
return next();
|
||||
}
|
||||
if (balanceConfig.startBalance == null) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const userId = req.user._id;
|
||||
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
|
||||
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord);
|
||||
|
||||
if (Object.keys(updateFields).length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
await Balance.findOneAndUpdate(
|
||||
{ user: userId },
|
||||
{ $set: updateFields },
|
||||
{ upsert: true, new: true },
|
||||
);
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error('Error setting user balance:', error);
|
||||
next(error);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Build an object containing fields that need updating
|
||||
* @param {Object} config - The balance configuration
|
||||
* @param {Object|null} userRecord - The user's current balance record, if any
|
||||
* @returns {Object} Fields that need updating
|
||||
*/
|
||||
function buildUpdateFields(config, userRecord) {
|
||||
const updateFields = {};
|
||||
|
||||
// Ensure user record has the required fields
|
||||
if (!userRecord) {
|
||||
updateFields.user = userRecord?.user;
|
||||
updateFields.tokenCredits = config.startBalance;
|
||||
}
|
||||
|
||||
if (userRecord?.tokenCredits == null && config.startBalance != null) {
|
||||
updateFields.tokenCredits = config.startBalance;
|
||||
}
|
||||
|
||||
const isAutoRefillConfigValid =
|
||||
config.autoRefillEnabled &&
|
||||
config.refillIntervalValue != null &&
|
||||
config.refillIntervalUnit != null &&
|
||||
config.refillAmount != null;
|
||||
|
||||
if (!isAutoRefillConfigValid) {
|
||||
return updateFields;
|
||||
}
|
||||
|
||||
if (userRecord?.autoRefillEnabled !== config.autoRefillEnabled) {
|
||||
updateFields.autoRefillEnabled = config.autoRefillEnabled;
|
||||
}
|
||||
|
||||
if (userRecord?.refillIntervalValue !== config.refillIntervalValue) {
|
||||
updateFields.refillIntervalValue = config.refillIntervalValue;
|
||||
}
|
||||
|
||||
if (userRecord?.refillIntervalUnit !== config.refillIntervalUnit) {
|
||||
updateFields.refillIntervalUnit = config.refillIntervalUnit;
|
||||
}
|
||||
|
||||
if (userRecord?.refillAmount !== config.refillAmount) {
|
||||
updateFields.refillAmount = config.refillAmount;
|
||||
}
|
||||
|
||||
return updateFields;
|
||||
}
|
||||
|
||||
module.exports = setBalanceConfig;
|
||||
@@ -1,13 +1,18 @@
|
||||
const jwt = require('jsonwebtoken');
|
||||
const validateImageRequest = require('~/server/middleware/validateImageRequest');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const createValidateImageRequest = require('~/server/middleware/validateImageRequest');
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
isEnabled: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('validateImageRequest middleware', () => {
|
||||
let req, res, next;
|
||||
let req, res, next, validateImageRequest;
|
||||
const validObjectId = '65cfb246f7ecadb8b1e8036b';
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
req = {
|
||||
app: { locals: { secureImageLinks: true } },
|
||||
headers: {},
|
||||
originalUrl: '',
|
||||
};
|
||||
@@ -17,109 +22,278 @@ describe('validateImageRequest middleware', () => {
|
||||
};
|
||||
next = jest.fn();
|
||||
process.env.JWT_REFRESH_SECRET = 'test-secret';
|
||||
process.env.OPENID_REUSE_TOKENS = 'false';
|
||||
|
||||
// Default: OpenID token reuse disabled
|
||||
isEnabled.mockReturnValue(false);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
test('should call next() if secureImageLinks is false', () => {
|
||||
req.app.locals.secureImageLinks = false;
|
||||
validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
describe('Factory function', () => {
|
||||
test('should return a pass-through middleware if secureImageLinks is false', async () => {
|
||||
const middleware = createValidateImageRequest(false);
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 401 if refresh token is not provided', () => {
|
||||
validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.send).toHaveBeenCalledWith('Unauthorized');
|
||||
});
|
||||
|
||||
test('should return 403 if refresh token is invalid', () => {
|
||||
req.headers.cookie = 'refreshToken=invalid-token';
|
||||
validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should return 403 if refresh token is expired', () => {
|
||||
const expiredToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${expiredToken}`;
|
||||
validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should call next() for valid image path', () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/example.jpg`;
|
||||
validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 for invalid image path', () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
|
||||
validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should return 403 for invalid ObjectId format', () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/123/example.jpg'; // Invalid ObjectId
|
||||
validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
// File traversal tests
|
||||
test('should prevent file traversal attempts', () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
|
||||
const traversalAttempts = [
|
||||
`/images/${validObjectId}/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
|
||||
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
|
||||
];
|
||||
|
||||
traversalAttempts.forEach((attempt) => {
|
||||
req.originalUrl = attempt;
|
||||
validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
jest.clearAllMocks();
|
||||
test('should return validation middleware if secureImageLinks is true', async () => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.send).toHaveBeenCalledWith('Unauthorized');
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle URL encoded characters in valid paths', () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
|
||||
validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
describe('Standard LibreChat token flow', () => {
|
||||
beforeEach(() => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
});
|
||||
|
||||
test('should return 401 if refresh token is not provided', async () => {
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.send).toHaveBeenCalledWith('Unauthorized');
|
||||
});
|
||||
|
||||
test('should return 403 if refresh token is invalid', async () => {
|
||||
req.headers.cookie = 'refreshToken=invalid-token';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should return 403 if refresh token is expired', async () => {
|
||||
const expiredToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${expiredToken}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should call next() for valid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 for invalid image path', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/example.jpg'; // Different ObjectId
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should allow agent avatar pattern for any valid ObjectId', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should prevent file traversal attempts', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
|
||||
const traversalAttempts = [
|
||||
`/images/${validObjectId}/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/..%2F..%2F..%2Fetc%2Fpasswd`,
|
||||
`/images/${validObjectId}/image.jpg/../../../etc/passwd`,
|
||||
`/images/${validObjectId}/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd`,
|
||||
];
|
||||
|
||||
for (const attempt of traversalAttempts) {
|
||||
req.originalUrl = attempt;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
jest.clearAllMocks();
|
||||
// Reset mocks for next iteration
|
||||
res.status = jest.fn().mockReturnThis();
|
||||
res.send = jest.fn();
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle URL encoded characters in valid paths', async () => {
|
||||
const validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/image%20with%20spaces.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('OpenID token flow', () => {
|
||||
beforeEach(() => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
// Enable OpenID token reuse
|
||||
isEnabled.mockReturnValue(true);
|
||||
process.env.OPENID_REUSE_TOKENS = 'true';
|
||||
});
|
||||
|
||||
test('should return 403 if no OpenID user ID cookie when token_provider is openid', async () => {
|
||||
req.headers.cookie = 'refreshToken=dummy-token; token_provider=openid';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should validate JWT-signed user ID for OpenID flow', async () => {
|
||||
const signedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
|
||||
req.originalUrl = `/images/${validObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 for invalid JWT-signed user ID', async () => {
|
||||
req.headers.cookie =
|
||||
'refreshToken=dummy-token; token_provider=openid; openid_user_id=invalid-jwt';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should return 403 for expired JWT-signed user ID', async () => {
|
||||
const expiredSignedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) - 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${expiredSignedUserId}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should validate image path against JWT-signed user ID', async () => {
|
||||
const signedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
const differentObjectId = '65cfb246f7ecadb8b1e8036c';
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
|
||||
req.originalUrl = `/images/${differentObjectId}/example.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should allow agent avatars in OpenID flow', async () => {
|
||||
const signedUserId = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
req.headers.cookie = `refreshToken=dummy-token; token_provider=openid; openid_user_id=${signedUserId}`;
|
||||
req.originalUrl = '/images/65cfb246f7ecadb8b1e8036c/agent-avatar-12345.png';
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Security edge cases', () => {
|
||||
let validToken;
|
||||
|
||||
beforeEach(() => {
|
||||
validateImageRequest = createValidateImageRequest(true);
|
||||
validToken = jwt.sign(
|
||||
{ id: validObjectId, exp: Math.floor(Date.now() / 1000) + 3600 },
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
);
|
||||
});
|
||||
|
||||
test('should handle very long image filenames', async () => {
|
||||
const longFilename = 'a'.repeat(1000) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle URLs with maximum practical length', async () => {
|
||||
// Most browsers support URLs up to ~2000 characters
|
||||
const longFilename = 'x'.repeat(1900) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${longFilename}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should accept URLs just under the 2048 limit', async () => {
|
||||
// Create a URL exactly 2047 characters long
|
||||
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
|
||||
const filenameLength = 2047 - baseLength;
|
||||
const filename = 'a'.repeat(filenameLength) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${filename}`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle malformed URL encoding gracefully', async () => {
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/test%ZZinvalid.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should reject URLs with null bytes', async () => {
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/test\x00.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should handle URLs with repeated slashes', async () => {
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}//test.jpg`;
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
|
||||
test('should reject extremely long URLs as potential DoS', async () => {
|
||||
// Create a URL longer than 2048 characters
|
||||
const baseLength = `/images/${validObjectId}/`.length + '.jpg'.length;
|
||||
const filenameLength = 2049 - baseLength; // Ensure total length exceeds 2048
|
||||
const extremelyLongFilename = 'x'.repeat(filenameLength) + '.jpg';
|
||||
req.headers.cookie = `refreshToken=${validToken}`;
|
||||
req.originalUrl = `/images/${validObjectId}/${extremelyLongFilename}`;
|
||||
// Verify our test URL is actually too long
|
||||
expect(req.originalUrl.length).toBeGreaterThan(2048);
|
||||
await validateImageRequest(req, res, next);
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.send).toHaveBeenCalledWith('Access Denied');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -15,7 +15,7 @@ const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {}
|
||||
* If the `cache` store is not available, the middleware will skip its logic.
|
||||
*
|
||||
* @function
|
||||
* @param {Express.Request} req - Express request object containing user information.
|
||||
* @param {ServerRequest} req - Express request object containing user information.
|
||||
* @param {Express.Response} res - Express response object.
|
||||
* @param {function} next - Express next middleware function.
|
||||
* @throws {Error} Throws an error if the user doesn't have access to the conversation.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
const cookies = require('cookie');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { logger } = require('~/config');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
|
||||
const OBJECT_ID_LENGTH = 24;
|
||||
const OBJECT_ID_PATTERN = /^[0-9a-f]{24}$/i;
|
||||
@@ -21,49 +22,129 @@ function isValidObjectId(id) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Middleware to validate image request.
|
||||
* Must be set by `secureImageLinks` via custom config file.
|
||||
* Validates a LibreChat refresh token
|
||||
* @param {string} refreshToken - The refresh token to validate
|
||||
* @returns {{valid: boolean, userId?: string, error?: string}} - Validation result
|
||||
*/
|
||||
function validateImageRequest(req, res, next) {
|
||||
if (!req.app.locals.secureImageLinks) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
|
||||
if (!refreshToken) {
|
||||
logger.warn('[validateImageRequest] Refresh token not provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
let payload;
|
||||
function validateToken(refreshToken) {
|
||||
try {
|
||||
payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
|
||||
if (!isValidObjectId(payload.id)) {
|
||||
return { valid: false, error: 'Invalid User ID' };
|
||||
}
|
||||
|
||||
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
|
||||
if (payload.exp < currentTimeInSeconds) {
|
||||
return { valid: false, error: 'Refresh token expired' };
|
||||
}
|
||||
|
||||
return { valid: true, userId: payload.id };
|
||||
} catch (err) {
|
||||
logger.warn('[validateImageRequest]', err);
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
if (!isValidObjectId(payload.id)) {
|
||||
logger.warn('[validateImageRequest] Invalid User ID');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const currentTimeInSeconds = Math.floor(Date.now() / 1000);
|
||||
if (payload.exp < currentTimeInSeconds) {
|
||||
logger.warn('[validateImageRequest] Refresh token expired');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const fullPath = decodeURIComponent(req.originalUrl);
|
||||
const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`);
|
||||
|
||||
if (pathPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
next();
|
||||
} else {
|
||||
logger.warn('[validateImageRequest] Invalid image path');
|
||||
res.status(403).send('Access Denied');
|
||||
logger.warn('[validateToken]', err);
|
||||
return { valid: false, error: 'Invalid token' };
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = validateImageRequest;
|
||||
/**
|
||||
* Factory to create the `validateImageRequest` middleware with configured secureImageLinks
|
||||
* @param {boolean} [secureImageLinks] - Whether secure image links are enabled
|
||||
*/
|
||||
function createValidateImageRequest(secureImageLinks) {
|
||||
if (!secureImageLinks) {
|
||||
return (_req, _res, next) => next();
|
||||
}
|
||||
/**
|
||||
* Middleware to validate image request.
|
||||
* Supports both LibreChat refresh tokens and OpenID JWT tokens.
|
||||
* Must be set by `secureImageLinks` via custom config file.
|
||||
*/
|
||||
return async function validateImageRequest(req, res, next) {
|
||||
try {
|
||||
const cookieHeader = req.headers.cookie;
|
||||
if (!cookieHeader) {
|
||||
logger.warn('[validateImageRequest] No cookies provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
const parsedCookies = cookies.parse(cookieHeader);
|
||||
const refreshToken = parsedCookies.refreshToken;
|
||||
|
||||
if (!refreshToken) {
|
||||
logger.warn('[validateImageRequest] Token not provided');
|
||||
return res.status(401).send('Unauthorized');
|
||||
}
|
||||
|
||||
const tokenProvider = parsedCookies.token_provider;
|
||||
let userIdForPath;
|
||||
|
||||
if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
const openidUserId = parsedCookies.openid_user_id;
|
||||
if (!openidUserId) {
|
||||
logger.warn('[validateImageRequest] No OpenID user ID cookie found');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const validationResult = validateToken(openidUserId);
|
||||
if (!validationResult.valid) {
|
||||
logger.warn(`[validateImageRequest] ${validationResult.error}`);
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
userIdForPath = validationResult.userId;
|
||||
} else {
|
||||
const validationResult = validateToken(refreshToken);
|
||||
if (!validationResult.valid) {
|
||||
logger.warn(`[validateImageRequest] ${validationResult.error}`);
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
userIdForPath = validationResult.userId;
|
||||
}
|
||||
|
||||
if (!userIdForPath) {
|
||||
logger.warn('[validateImageRequest] No user ID available for path validation');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const MAX_URL_LENGTH = 2048;
|
||||
if (req.originalUrl.length > MAX_URL_LENGTH) {
|
||||
logger.warn('[validateImageRequest] URL too long');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
if (req.originalUrl.includes('\x00')) {
|
||||
logger.warn('[validateImageRequest] URL contains null byte');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
let fullPath;
|
||||
try {
|
||||
fullPath = decodeURIComponent(req.originalUrl);
|
||||
} catch {
|
||||
logger.warn('[validateImageRequest] Invalid URL encoding');
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
const agentAvatarPattern = /^\/images\/[a-f0-9]{24}\/agent-[^/]*$/;
|
||||
if (agentAvatarPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
return next();
|
||||
}
|
||||
|
||||
const escapedUserId = userIdForPath.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const pathPattern = new RegExp(`^/images/${escapedUserId}/[^/]+$`);
|
||||
|
||||
if (pathPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
next();
|
||||
} else {
|
||||
logger.warn('[validateImageRequest] Invalid image path');
|
||||
res.status(403).send('Access Denied');
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[validateImageRequest] Error:', error);
|
||||
res.status(500).send('Internal Server Error');
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = createValidateImageRequest;
|
||||
|
||||
@@ -6,7 +6,7 @@ const { logViolation } = require('~/cache');
|
||||
* Validates the model of the request.
|
||||
*
|
||||
* @async
|
||||
* @param {Express.Request} req - The Express request object.
|
||||
* @param {ServerRequest} req - The Express request object.
|
||||
* @param {Express.Response} res - The Express response object.
|
||||
* @param {Function} next - The Express next function.
|
||||
*/
|
||||
|
||||
680
api/server/routes/__tests__/costs.spec.js
Normal file
680
api/server/routes/__tests__/costs.spec.js
Normal file
@@ -0,0 +1,680 @@
|
||||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
createMethods: jest.fn(() => ({})),
|
||||
createModels: jest.fn(() => ({})),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/middleware', () => ({
|
||||
requireJwtAuth: (req, res, next) => next(),
|
||||
validateMessageReq: (req, res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getConvo: jest.fn(),
|
||||
saveConvo: jest.fn(),
|
||||
saveMessage: jest.fn(),
|
||||
getMessage: jest.fn(),
|
||||
getMessages: jest.fn(),
|
||||
updateMessage: jest.fn(),
|
||||
deleteMessages: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/db/models', () => {
|
||||
let User, Message, Transaction, Conversation;
|
||||
|
||||
return {
|
||||
get User() {
|
||||
return User;
|
||||
},
|
||||
get Message() {
|
||||
return Message;
|
||||
},
|
||||
get Transaction() {
|
||||
return Transaction;
|
||||
},
|
||||
get Conversation() {
|
||||
return Conversation;
|
||||
},
|
||||
setUser: (model) => {
|
||||
User = model;
|
||||
},
|
||||
setMessage: (model) => {
|
||||
Message = model;
|
||||
},
|
||||
setTransaction: (model) => {
|
||||
Transaction = model;
|
||||
},
|
||||
setConversation: (model) => {
|
||||
Conversation = model;
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
describe('Costs Endpoint', () => {
|
||||
let app;
|
||||
let mongoServer;
|
||||
let messagesRouter;
|
||||
let User, Message, Transaction, Conversation;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
|
||||
const userSchema = new mongoose.Schema({
|
||||
_id: String,
|
||||
name: String,
|
||||
email: String,
|
||||
});
|
||||
|
||||
const conversationSchema = new mongoose.Schema({
|
||||
conversationId: String,
|
||||
user: String,
|
||||
title: String,
|
||||
createdAt: Date,
|
||||
});
|
||||
|
||||
const messageSchema = new mongoose.Schema({
|
||||
messageId: String,
|
||||
conversationId: String,
|
||||
user: String,
|
||||
isCreatedByUser: Boolean,
|
||||
tokenCount: Number,
|
||||
createdAt: Date,
|
||||
});
|
||||
|
||||
const transactionSchema = new mongoose.Schema({
|
||||
conversationId: String,
|
||||
user: String,
|
||||
tokenType: String,
|
||||
tokenValue: Number,
|
||||
createdAt: Date,
|
||||
});
|
||||
|
||||
User = mongoose.model('User', userSchema);
|
||||
Conversation = mongoose.model('Conversation', conversationSchema);
|
||||
Message = mongoose.model('Message', messageSchema);
|
||||
Transaction = mongoose.model('Transaction', transactionSchema);
|
||||
|
||||
const dbModels = require('~/db/models');
|
||||
dbModels.setUser(User);
|
||||
dbModels.setMessage(Message);
|
||||
dbModels.setTransaction(Transaction);
|
||||
dbModels.setConversation(Conversation);
|
||||
|
||||
require('~/db/models');
|
||||
|
||||
try {
|
||||
messagesRouter = require('../messages');
|
||||
} catch (error) {
|
||||
console.error('Error loading messages router:', error);
|
||||
throw error;
|
||||
}
|
||||
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = { id: 'test-user-id' };
|
||||
next();
|
||||
});
|
||||
app.use('/api/messages', messagesRouter);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await User.deleteMany({});
|
||||
await Conversation.deleteMany({});
|
||||
await Message.deleteMany({});
|
||||
await Transaction.deleteMany({});
|
||||
});
|
||||
|
||||
describe('GET /:conversationId/costs', () => {
|
||||
const conversationId = 'test-conversation-123';
|
||||
const userId = 'test-user-id';
|
||||
|
||||
it('should return cost data for valid conversation', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const aiMessage = new Message({
|
||||
messageId: 'ai-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: 150,
|
||||
createdAt: new Date('2024-01-01T10:01:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage.save(), aiMessage.save()]);
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const completionTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'completion',
|
||||
tokenValue: 750000,
|
||||
createdAt: new Date('2024-01-01T10:01:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction.save(), completionTransaction.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toMatchObject({
|
||||
conversationId,
|
||||
totals: {
|
||||
prompt: { usd: 0.5, tokenCount: 100 },
|
||||
completion: { usd: 0.75, tokenCount: 150 },
|
||||
total: { usd: 1.25, tokenCount: 250 },
|
||||
},
|
||||
perMessage: [
|
||||
{ messageId: 'user-msg-1', tokenType: 'prompt', tokenCount: 100, usd: 0.5 },
|
||||
{ messageId: 'ai-msg-1', tokenType: 'completion', tokenCount: 150, usd: 0.75 },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('should return empty data for conversation with no messages', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toMatchObject({
|
||||
conversationId,
|
||||
totals: {
|
||||
prompt: { usd: 0, tokenCount: 0 },
|
||||
completion: { usd: 0, tokenCount: 0 },
|
||||
total: { usd: 0, tokenCount: 0 },
|
||||
},
|
||||
perMessage: [],
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle messages without transactions', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const aiMessage = new Message({
|
||||
messageId: 'ai-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: 150,
|
||||
createdAt: new Date('2024-01-01T10:01:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage.save(), aiMessage.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0);
|
||||
expect(response.body.totals.completion.usd).toBe(0);
|
||||
expect(response.body.totals.total.usd).toBe(0);
|
||||
});
|
||||
|
||||
it('should aggregate multiple transactions correctly', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction1 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 300000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const promptTransaction2 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 200000,
|
||||
createdAt: new Date('2024-01-01T10:00:45Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
expect(response.body.perMessage[0].usd).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should handle null tokenCount values', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: null,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.tokenCount).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle null tokenValue in transactions', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: null,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await promptTransaction.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle negative tokenValue using Math.abs', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: -500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await promptTransaction.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should filter by user correctly', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const otherUserId = 'other-user-id';
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const otherUserMessage = new Message({
|
||||
messageId: 'other-user-msg-1',
|
||||
conversationId,
|
||||
user: otherUserId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 200,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage.save(), otherUserMessage.save()]);
|
||||
|
||||
const userTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const otherUserTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: otherUserId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 1000000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userTransaction.save(), otherUserTransaction.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
expect(response.body.perMessage).toHaveLength(1);
|
||||
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
|
||||
});
|
||||
|
||||
it('should filter transactions by tokenType', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const otherTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'other',
|
||||
tokenValue: 1000000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction.save(), otherTransaction.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
expect(response.body.totals.completion.usd).toBe(0);
|
||||
expect(response.body.totals.total.usd).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should map transactions to messages chronologically', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage1 = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const userMessage2 = new Message({
|
||||
messageId: 'user-msg-2',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 200,
|
||||
createdAt: new Date('2024-01-01T10:01:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage1.save(), userMessage2.save()]);
|
||||
|
||||
const promptTransaction1 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const promptTransaction2 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 1000000,
|
||||
createdAt: new Date('2024-01-01T10:01:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.perMessage).toHaveLength(2);
|
||||
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
|
||||
expect(response.body.perMessage[0].usd).toBe(0.5);
|
||||
expect(response.body.perMessage[1].messageId).toBe('user-msg-2');
|
||||
expect(response.body.perMessage[1].usd).toBe(1.0);
|
||||
});
|
||||
|
||||
it('should handle database errors', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
await mongoose.connection.close();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
expect(response.body).toHaveProperty('error');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -4,12 +4,14 @@ const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
@@ -36,6 +38,7 @@ jest.mock('~/models', () => ({
|
||||
updateToken: jest.fn(),
|
||||
createToken: jest.fn(),
|
||||
deleteTokens: jest.fn(),
|
||||
findPluginAuthsByKeys: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
@@ -44,6 +47,10 @@ jest.mock('~/server/services/Config', () => ({
|
||||
loadCustomConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config/mcpToolsCache', () => ({
|
||||
updateMCPUserTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
getMCPSetupData: jest.fn(),
|
||||
getServerConnectionStatus: jest.fn(),
|
||||
@@ -66,6 +73,10 @@ jest.mock('~/server/middleware', () => ({
|
||||
requireJwtAuth: (req, res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Tools/mcp', () => ({
|
||||
reinitMCPServer: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('MCP Routes', () => {
|
||||
let app;
|
||||
let mongoServer;
|
||||
@@ -494,12 +505,9 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 500 when token retrieval throws an unexpected error', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockRejectedValue(new Error('Database connection failed')),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
getLogStores.mockImplementation(() => {
|
||||
throw new Error('Database connection failed');
|
||||
});
|
||||
|
||||
const response = await request(app).get('/api/mcp/oauth/tokens/test-user-id:error-flow');
|
||||
|
||||
@@ -563,8 +571,8 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
describe('POST /oauth/cancel/:serverName', () => {
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
it('should cancel OAuth flow successfully', async () => {
|
||||
const mockFlowManager = {
|
||||
@@ -644,15 +652,15 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
describe('POST /:serverName/reinitialize', () => {
|
||||
const { loadCustomConfig } = require('~/server/services/Config');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
|
||||
it('should return 404 when server is not found in configuration', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'other-server': {},
|
||||
},
|
||||
});
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue(null),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
|
||||
const response = await request(app).post('/api/mcp/non-existent-server/reinitialize');
|
||||
|
||||
@@ -663,16 +671,11 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle OAuth requirement during reinitialize', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'oauth-server': {
|
||||
customUserVars: {},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const mockMcpManager = {
|
||||
disconnectServer: jest.fn().mockResolvedValue(),
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {},
|
||||
}),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => {
|
||||
if (oauthStart) {
|
||||
@@ -685,12 +688,19 @@ describe('MCP Routes', () => {
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||
success: true,
|
||||
message: "MCP server 'oauth-server' ready for OAuth authentication",
|
||||
serverName: 'oauth-server',
|
||||
oauthRequired: true,
|
||||
oauthUrl: 'https://oauth.example.com/auth',
|
||||
});
|
||||
|
||||
const response = await request(app).post('/api/mcp/oauth-server/reinitialize');
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toEqual({
|
||||
success: 'https://oauth.example.com/auth',
|
||||
success: true,
|
||||
message: "MCP server 'oauth-server' ready for OAuth authentication",
|
||||
serverName: 'oauth-server',
|
||||
oauthRequired: true,
|
||||
@@ -699,14 +709,9 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 500 when reinitialize fails with non-OAuth error', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'error-server': {},
|
||||
},
|
||||
});
|
||||
|
||||
const mockMcpManager = {
|
||||
disconnectServer: jest.fn().mockResolvedValue(),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')),
|
||||
};
|
||||
@@ -714,6 +719,7 @@ describe('MCP Routes', () => {
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue(null);
|
||||
|
||||
const response = await request(app).post('/api/mcp/error-server/reinitialize');
|
||||
|
||||
@@ -724,7 +730,13 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 500 when unexpected error occurs', async () => {
|
||||
loadCustomConfig.mockRejectedValue(new Error('Config loading failed'));
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockImplementation(() => {
|
||||
throw new Error('Config loading failed');
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||
|
||||
@@ -747,29 +759,17 @@ describe('MCP Routes', () => {
|
||||
expect(response.body).toEqual({ error: 'User not authenticated' });
|
||||
});
|
||||
|
||||
it('should handle errors when fetching custom user variables', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: 'test-key-var',
|
||||
SECRET_TOKEN: 'test-secret-var',
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
getUserPluginAuthValue
|
||||
.mockResolvedValueOnce('test-api-key-value')
|
||||
.mockRejectedValueOnce(new Error('Database error'));
|
||||
|
||||
it('should successfully reinitialize server and cache tools', async () => {
|
||||
const mockUserConnection = {
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
fetchTools: jest.fn().mockResolvedValue([
|
||||
{ name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } },
|
||||
{ name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } },
|
||||
]),
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
disconnectServer: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
};
|
||||
|
||||
@@ -778,44 +778,86 @@ describe('MCP Routes', () => {
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
updateMCPUserTools.mockResolvedValue();
|
||||
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||
success: true,
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
serverName: 'test-server',
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
});
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.success).toBe(true);
|
||||
expect(response.body).toEqual({
|
||||
success: true,
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
serverName: 'test-server',
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
});
|
||||
expect(mockMcpManager.disconnectUserConnection).toHaveBeenCalledWith(
|
||||
'test-user-id',
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return failure message when reinitialize completely fails', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {},
|
||||
},
|
||||
});
|
||||
it('should handle server with custom user variables', async () => {
|
||||
const mockUserConnection = {
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
disconnectServer: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getUserConnection: jest.fn().mockResolvedValue(null),
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
endpoint: 'http://test-server.com',
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
}),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
require('@librechat/api').getUserMCPAuthMap.mockResolvedValue({
|
||||
'mcp:test-server': {
|
||||
API_KEY: 'api-key-value',
|
||||
},
|
||||
});
|
||||
require('~/models').findPluginAuthsByKeys.mockResolvedValue([
|
||||
{ key: 'API_KEY', value: 'api-key-value' },
|
||||
]);
|
||||
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
getCachedTools.mockResolvedValue({
|
||||
[`existing-tool${Constants.mcp_delimiter}test-server`]: { type: 'function' },
|
||||
});
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
updateMCPUserTools.mockResolvedValue();
|
||||
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||
success: true,
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
serverName: 'test-server',
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
});
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.success).toBe(false);
|
||||
expect(response.body.message).toBe("Failed to reinitialize MCP server 'test-server'");
|
||||
expect(response.body.success).toBe(true);
|
||||
expect(require('@librechat/api').getUserMCPAuthMap).toHaveBeenCalledWith({
|
||||
userId: 'test-user-id',
|
||||
servers: ['test-server'],
|
||||
findPluginAuthsByKeys: require('~/models').findPluginAuthsByKeys,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -984,21 +1026,19 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
describe('GET /:serverName/auth-values', () => {
|
||||
const { loadCustomConfig } = require('~/server/services/Config');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
|
||||
it('should return auth value flags for server', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
SECRET_TOKEN: 'another-env-var',
|
||||
},
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
SECRET_TOKEN: 'another-env-var',
|
||||
},
|
||||
},
|
||||
});
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce('');
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
@@ -1017,11 +1057,11 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 404 when server is not found in configuration', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'other-server': {},
|
||||
},
|
||||
});
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue(null),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/non-existent-server/auth-values');
|
||||
|
||||
@@ -1032,16 +1072,15 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle errors when checking auth values', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
},
|
||||
});
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
getUserPluginAuthValue.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
@@ -1057,7 +1096,13 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 500 when auth values check throws unexpected error', async () => {
|
||||
loadCustomConfig.mockRejectedValue(new Error('Config loading failed'));
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockImplementation(() => {
|
||||
throw new Error('Config loading failed');
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
|
||||
@@ -1066,14 +1111,13 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle customUserVars that is not an object', async () => {
|
||||
const { loadCustomConfig } = require('~/server/services/Config');
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: 'not-an-object',
|
||||
},
|
||||
},
|
||||
});
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: 'not-an-object',
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
|
||||
@@ -1097,98 +1141,6 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('POST /:serverName/reinitialize - Tool Deletion Coverage', () => {
|
||||
it('should handle null cached tools during reinitialize (triggers || {} fallback)', async () => {
|
||||
const { loadCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
|
||||
const mockUserConnection = {
|
||||
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
disconnectServer: jest.fn(),
|
||||
initializeServer: jest.fn(),
|
||||
mcpConfigs: {},
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': { env: { API_KEY: 'test-key' } },
|
||||
},
|
||||
});
|
||||
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
|
||||
|
||||
expect(response.body).toEqual({
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
success: true,
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
serverName: 'test-server',
|
||||
});
|
||||
});
|
||||
|
||||
it('should delete existing cached tools during successful reinitialize', async () => {
|
||||
const {
|
||||
loadCustomConfig,
|
||||
getCachedTools,
|
||||
setCachedTools,
|
||||
} = require('~/server/services/Config');
|
||||
|
||||
const mockUserConnection = {
|
||||
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
disconnectServer: jest.fn(),
|
||||
initializeServer: jest.fn(),
|
||||
mcpConfigs: {},
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': { env: { API_KEY: 'test-key' } },
|
||||
},
|
||||
});
|
||||
|
||||
const existingTools = {
|
||||
'old-tool_mcp_test-server': { type: 'function' },
|
||||
'other-tool_mcp_other-server': { type: 'function' },
|
||||
};
|
||||
getCachedTools.mockResolvedValue(existingTools);
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
|
||||
|
||||
expect(response.body).toEqual({
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
success: true,
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
serverName: 'test-server',
|
||||
});
|
||||
|
||||
expect(setCachedTools).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
'new-tool_mcp_test-server': expect.any(Object),
|
||||
'other-tool_mcp_other-server': { type: 'function' },
|
||||
}),
|
||||
{ userId: 'test-user-id' },
|
||||
);
|
||||
expect(setCachedTools).toHaveBeenCalledWith(
|
||||
expect.not.objectContaining({
|
||||
'old-tool_mcp_test-server': expect.anything(),
|
||||
}),
|
||||
{ userId: 'test-user-id' },
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
|
||||
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
|
||||
@@ -83,7 +83,11 @@ router.post(
|
||||
}
|
||||
|
||||
let metadata = await encryptMetadata(removeNullishValues(_metadata, true));
|
||||
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
|
||||
const appConfig = req.config;
|
||||
const isDomainAllowed = await isActionDomainAllowed(
|
||||
metadata.domain,
|
||||
appConfig?.actions?.allowedDomains,
|
||||
);
|
||||
if (!isDomainAllowed) {
|
||||
return res.status(400).json({ message: 'Domain not allowed' });
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ const {
|
||||
checkBan,
|
||||
requireJwtAuth,
|
||||
messageIpLimiter,
|
||||
configMiddleware,
|
||||
concurrentLimiter,
|
||||
messageUserLimiter,
|
||||
} = require('~/server/middleware');
|
||||
@@ -22,6 +23,8 @@ router.use(uaParser);
|
||||
router.use('/', v1);
|
||||
|
||||
const chatRouter = express.Router();
|
||||
chatRouter.use(configMiddleware);
|
||||
|
||||
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||
chatRouter.use(concurrentLimiter);
|
||||
}
|
||||
@@ -37,6 +40,4 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
|
||||
chatRouter.use('/', chat);
|
||||
router.use('/chat', chatRouter);
|
||||
|
||||
// Add marketplace routes
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const express = require('express');
|
||||
const { callTool, verifyToolAuth, getToolCalls } = require('~/server/controllers/tools');
|
||||
const { getAvailableTools } = require('~/server/controllers/PluginController');
|
||||
const { toolCallLimiter } = require('~/server/middleware/limiters');
|
||||
const { toolCallLimiter } = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const express = require('express');
|
||||
const { generateCheckAccess } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions, PermissionBits } = require('librechat-data-provider');
|
||||
const { requireJwtAuth, canAccessAgentResource } = require('~/server/middleware');
|
||||
const { requireJwtAuth, configMiddleware, canAccessAgentResource } = require('~/server/middleware');
|
||||
const v1 = require('~/server/controllers/agents/v1');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const actions = require('./actions');
|
||||
@@ -36,17 +36,17 @@ router.use(requireJwtAuth);
|
||||
* Agent actions route.
|
||||
* @route GET|POST /agents/actions
|
||||
*/
|
||||
router.use('/actions', actions);
|
||||
router.use('/actions', configMiddleware, actions);
|
||||
|
||||
/**
|
||||
* Get a list of available tools for agents.
|
||||
* @route GET /agents/tools
|
||||
*/
|
||||
router.use('/tools', tools);
|
||||
router.use('/tools', configMiddleware, tools);
|
||||
|
||||
/**
|
||||
* Get all agent categories with counts
|
||||
* @route GET /agents/marketplace/categories
|
||||
* @route GET /agents/categories
|
||||
*/
|
||||
router.get('/categories', v1.getAgentCategories);
|
||||
/**
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
const express = require('express');
|
||||
const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { actionDelimiter, EModelEndpoint, removeNullishValues } = require('librechat-data-provider');
|
||||
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { updateAction, getActions, deleteAction } = require('~/models/Action');
|
||||
const { updateAssistantDoc, getAssistant } = require('~/models/Assistant');
|
||||
const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
@@ -21,6 +21,7 @@ const router = express.Router();
|
||||
*/
|
||||
router.post('/:assistant_id', async (req, res) => {
|
||||
try {
|
||||
const appConfig = req.config;
|
||||
const { assistant_id } = req.params;
|
||||
|
||||
/** @type {{ functions: FunctionTool[], action_id: string, metadata: ActionMetadata }} */
|
||||
@@ -30,7 +31,10 @@ router.post('/:assistant_id', async (req, res) => {
|
||||
}
|
||||
|
||||
let metadata = await encryptMetadata(removeNullishValues(_metadata, true));
|
||||
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
|
||||
const isDomainAllowed = await isActionDomainAllowed(
|
||||
metadata.domain,
|
||||
appConfig?.actions?.allowedDomains,
|
||||
);
|
||||
if (!isDomainAllowed) {
|
||||
return res.status(400).json({ message: 'Domain not allowed' });
|
||||
}
|
||||
@@ -125,7 +129,7 @@ router.post('/:assistant_id', async (req, res) => {
|
||||
}
|
||||
|
||||
/* Map Azure OpenAI model to the assistant as defined by config */
|
||||
if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
|
||||
if (appConfig.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
|
||||
updatedAssistant = {
|
||||
...updatedAssistant,
|
||||
model: req.body.model,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user