Compare commits
136 Commits
feat/mcp-p
...
fix/mcp-va
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cbf3ce7559 | ||
|
|
666e5bc7d4 | ||
|
|
f68c61d148 | ||
|
|
37701d3e9b | ||
|
|
084de9a912 | ||
|
|
49f87016a8 | ||
|
|
ef9d9b1276 | ||
|
|
a4ca4b7d9d | ||
|
|
8e6eef04ab | ||
|
|
ec3cbca6e3 | ||
|
|
4639dc3255 | ||
|
|
0ef3fefaec | ||
|
|
37aba18a96 | ||
|
|
2ce6ac74f4 | ||
|
|
9fddb0ff6a | ||
|
|
32f7dbd11f | ||
|
|
79197454f8 | ||
|
|
97e1cdd224 | ||
|
|
d6a65f5a08 | ||
|
|
f4facb7d35 | ||
|
|
545a909953 | ||
|
|
cd436dc6a8 | ||
|
|
e75beb92b3 | ||
|
|
5251246313 | ||
|
|
26f23c6aaf | ||
|
|
1636af1f27 | ||
|
|
b050a0bf1e | ||
|
|
deb928bf80 | ||
|
|
21005b66cc | ||
|
|
3dc9e85fab | ||
|
|
ec67cf2d3a | ||
|
|
1fe977e48f | ||
|
|
01470ef9fd | ||
|
|
bef5c26bed | ||
|
|
9e03fef9db | ||
|
|
283c9cff6f | ||
|
|
0aafdc0a86 | ||
|
|
365e3bca95 | ||
|
|
a01536ddb7 | ||
|
|
8a3ff62ee6 | ||
|
|
74d8a3824c | ||
|
|
62c3f135e7 | ||
|
|
baf3b4ad08 | ||
|
|
e5d08ccdf1 | ||
|
|
5178507b1c | ||
|
|
f797e90d79 | ||
|
|
259224d986 | ||
|
|
13789ab261 | ||
|
|
faaba30af1 | ||
|
|
14660d75ae | ||
|
|
aec1777a90 | ||
|
|
90c43dd451 | ||
|
|
4c754c1190 | ||
|
|
f70e0cf849 | ||
|
|
d0c958ba33 | ||
|
|
0761e65086 | ||
|
|
0bf708915b | ||
|
|
cf59f1ab45 | ||
|
|
445e9eae85 | ||
|
|
cd9c578907 | ||
|
|
ac94c73f23 | ||
|
|
dfef7c31d2 | ||
|
|
0b1b0af741 | ||
|
|
0a169a1ff6 | ||
|
|
4b12ea327a | ||
|
|
35d8ef50f4 | ||
|
|
1dabe96404 | ||
|
|
7f8c327509 | ||
|
|
52bbac3a37 | ||
|
|
62b4f3b795 | ||
|
|
01b012a8fa | ||
|
|
418b5e9070 | ||
|
|
a9f01bb86f | ||
|
|
aeeb860fe0 | ||
|
|
e11e716807 | ||
|
|
e370a87ebe | ||
|
|
170cc340d8 | ||
|
|
f1b29ffb45 | ||
|
|
6aa4bb5a4a | ||
|
|
9f44187351 | ||
|
|
d2e1ca4c4a | ||
|
|
8e869f2274 | ||
|
|
2e1874e596 | ||
|
|
929b433662 | ||
|
|
1e4f1f780c | ||
|
|
4733f10e41 | ||
|
|
110984b48f | ||
|
|
19320f2296 | ||
|
|
8523074e87 | ||
|
|
e4531d682d | ||
|
|
4bbdc4c402 | ||
|
|
8ca4cf3d2f | ||
|
|
13a9bcdd48 | ||
|
|
4b32ec42c6 | ||
|
|
4918899c8d | ||
|
|
7e37211458 | ||
|
|
e57fc83d40 | ||
|
|
550610dba9 | ||
|
|
916cd46221 | ||
|
|
12b08183ff | ||
|
|
f4d97e1672 | ||
|
|
035fa081c1 | ||
|
|
aecf8f19a6 | ||
|
|
35f548a94d | ||
|
|
e60c0cf201 | ||
|
|
5b392f9cb0 | ||
|
|
e0f468da20 | ||
|
|
91a2df4759 | ||
|
|
97a99985fa | ||
|
|
3554625a06 | ||
|
|
a37bf6719c | ||
|
|
e513f50c08 | ||
|
|
f5511e4a4e | ||
|
|
a288ad1d9c | ||
|
|
458580ec87 | ||
|
|
4285d5841c | ||
|
|
5ee55cda4f | ||
|
|
404d40cbef | ||
|
|
f4680b016c | ||
|
|
077224b351 | ||
|
|
9c70d1db96 | ||
|
|
543281da6c | ||
|
|
24800bfbeb | ||
|
|
07e08143e4 | ||
|
|
8ba61a86f4 | ||
|
|
56ad92fb1c | ||
|
|
1ceb52d2b5 | ||
|
|
5d267aa8e2 | ||
|
|
59d00e99f3 | ||
|
|
738d04fac4 | ||
|
|
8a5dbac0f9 | ||
|
|
434289fe92 | ||
|
|
a648ad3d13 | ||
|
|
55d63caaf4 | ||
|
|
313539d1ed | ||
|
|
f869d772f7 |
44
.env.example
44
.env.example
@@ -349,6 +349,11 @@ REGISTRATION_VIOLATION_SCORE=1
|
||||
CONCURRENT_VIOLATION_SCORE=1
|
||||
MESSAGE_VIOLATION_SCORE=1
|
||||
NON_BROWSER_VIOLATION_SCORE=20
|
||||
TTS_VIOLATION_SCORE=0
|
||||
STT_VIOLATION_SCORE=0
|
||||
FORK_VIOLATION_SCORE=0
|
||||
IMPORT_VIOLATION_SCORE=0
|
||||
FILE_UPLOAD_VIOLATION_SCORE=0
|
||||
|
||||
LOGIN_MAX=7
|
||||
LOGIN_WINDOW=5
|
||||
@@ -575,6 +580,10 @@ ALLOW_SHARED_LINKS_PUBLIC=true
|
||||
# If you have another service in front of your LibreChat doing compression, disable express based compression here
|
||||
# DISABLE_COMPRESSION=true
|
||||
|
||||
# If you have gzipped version of uploaded image images in the same folder, this will enable gzip scan and serving of these images
|
||||
# Note: The images folder will be scanned on startup and a ma kept in memory. Be careful for large number of images.
|
||||
# ENABLE_IMAGE_OUTPUT_GZIP_SCAN=true
|
||||
|
||||
#===================================================#
|
||||
# UI #
|
||||
#===================================================#
|
||||
@@ -592,11 +601,40 @@ HELP_AND_FAQ_URL=https://librechat.ai
|
||||
# REDIS Options #
|
||||
#===============#
|
||||
|
||||
# REDIS_URI=10.10.10.10:6379
|
||||
# Enable Redis for caching and session storage
|
||||
# USE_REDIS=true
|
||||
|
||||
# USE_REDIS_CLUSTER=true
|
||||
# REDIS_CA=/path/to/ca.crt
|
||||
# Single Redis instance
|
||||
# REDIS_URI=redis://127.0.0.1:6379
|
||||
|
||||
# Redis cluster (multiple nodes)
|
||||
# REDIS_URI=redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003
|
||||
|
||||
# Redis with TLS/SSL encryption and CA certificate
|
||||
# REDIS_URI=rediss://127.0.0.1:6380
|
||||
# REDIS_CA=/path/to/ca-cert.pem
|
||||
|
||||
# Redis authentication (if required)
|
||||
# REDIS_USERNAME=your_redis_username
|
||||
# REDIS_PASSWORD=your_redis_password
|
||||
|
||||
# Redis key prefix configuration
|
||||
# Use environment variable name for dynamic prefix (recommended for cloud deployments)
|
||||
# REDIS_KEY_PREFIX_VAR=K_REVISION
|
||||
# Or use static prefix directly
|
||||
# REDIS_KEY_PREFIX=librechat
|
||||
|
||||
# Redis connection limits
|
||||
# REDIS_MAX_LISTENERS=40
|
||||
|
||||
# Redis ping interval in seconds (0 = disabled, >0 = enabled)
|
||||
# When set to a positive integer, Redis clients will ping the server at this interval to keep connections alive
|
||||
# When unset or 0, no pinging is performed (recommended for most use cases)
|
||||
# REDIS_PING_INTERVAL=300
|
||||
|
||||
# Force specific cache namespaces to use in-memory storage even when Redis is enabled
|
||||
# Comma-separated list of CacheKeys (e.g., STATIC_CONFIG,ROLES,MESSAGES)
|
||||
# FORCED_IN_MEMORY_CACHE_NAMESPACES=STATIC_CONFIG,ROLES
|
||||
|
||||
#==================================================#
|
||||
# Others #
|
||||
|
||||
2
.github/workflows/backend-review.yml
vendored
2
.github/workflows/backend-review.yml
vendored
@@ -7,7 +7,7 @@ on:
|
||||
- release/*
|
||||
paths:
|
||||
- 'api/**'
|
||||
- 'packages/api/**'
|
||||
- 'packages/**'
|
||||
jobs:
|
||||
tests_Backend:
|
||||
name: Run Backend unit tests
|
||||
|
||||
58
.github/workflows/client.yml
vendored
Normal file
58
.github/workflows/client.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
name: Publish `@librechat/client` to NPM
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'packages/client/package.json'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
reason:
|
||||
description: 'Reason for manual trigger'
|
||||
required: false
|
||||
default: 'Manual publish requested'
|
||||
|
||||
jobs:
|
||||
build-and-publish:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20.x'
|
||||
|
||||
- name: Install client dependencies
|
||||
run: cd packages/client && npm ci
|
||||
|
||||
- name: Build client
|
||||
run: cd packages/client && npm run build
|
||||
|
||||
- name: Set up npm authentication
|
||||
run: echo "//registry.npmjs.org/:_authToken=${{ secrets.PUBLISH_NPM_TOKEN }}" > ~/.npmrc
|
||||
|
||||
- name: Check version change
|
||||
id: check
|
||||
working-directory: packages/client
|
||||
run: |
|
||||
PACKAGE_VERSION=$(node -p "require('./package.json').version")
|
||||
PUBLISHED_VERSION=$(npm view @librechat/client version 2>/dev/null || echo "0.0.0")
|
||||
if [ "$PACKAGE_VERSION" = "$PUBLISHED_VERSION" ]; then
|
||||
echo "No version change, skipping publish"
|
||||
echo "skip=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "Version changed, proceeding with publish"
|
||||
echo "skip=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Pack package
|
||||
if: steps.check.outputs.skip != 'true'
|
||||
working-directory: packages/client
|
||||
run: npm pack
|
||||
|
||||
- name: Publish
|
||||
if: steps.check.outputs.skip != 'true'
|
||||
working-directory: packages/client
|
||||
run: npm publish *.tgz --access public
|
||||
2
.github/workflows/data-schemas.yml
vendored
2
.github/workflows/data-schemas.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18.x'
|
||||
node-version: '20.x'
|
||||
|
||||
- name: Install dependencies
|
||||
run: cd packages/data-schemas && npm ci
|
||||
|
||||
2
.github/workflows/frontend-review.yml
vendored
2
.github/workflows/frontend-review.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
- release/*
|
||||
paths:
|
||||
- 'client/**'
|
||||
- 'packages/**'
|
||||
- 'packages/data-provider/**'
|
||||
|
||||
jobs:
|
||||
tests_frontend_ubuntu:
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
name: Generate Release Changelog PR
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
generate-release-changelog-pr:
|
||||
permissions:
|
||||
contents: write # Needed for pushing commits and creating branches.
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# 1. Checkout the repository (with full history).
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
# 2. Generate the release changelog using our custom configuration.
|
||||
- name: Generate Release Changelog
|
||||
id: generate_release
|
||||
uses: mikepenz/release-changelog-builder-action@v5.1.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
configuration: ".github/configuration-release.json"
|
||||
owner: ${{ github.repository_owner }}
|
||||
repo: ${{ github.event.repository.name }}
|
||||
outputFile: CHANGELOG-release.md
|
||||
|
||||
# 3. Update the main CHANGELOG.md:
|
||||
# - If it doesn't exist, create it with a basic header.
|
||||
# - Remove the "Unreleased" section (if present).
|
||||
# - Prepend the new release changelog above previous releases.
|
||||
# - Remove all temporary files before committing.
|
||||
- name: Update CHANGELOG.md
|
||||
run: |
|
||||
# Determine the release tag, e.g. "v1.2.3"
|
||||
TAG=${GITHUB_REF##*/}
|
||||
echo "Using release tag: $TAG"
|
||||
|
||||
# Ensure CHANGELOG.md exists; if not, create a basic header.
|
||||
if [ ! -f CHANGELOG.md ]; then
|
||||
echo "# Changelog" > CHANGELOG.md
|
||||
echo "" >> CHANGELOG.md
|
||||
echo "All notable changes to this project will be documented in this file." >> CHANGELOG.md
|
||||
echo "" >> CHANGELOG.md
|
||||
fi
|
||||
|
||||
echo "Updating CHANGELOG.md…"
|
||||
|
||||
# Remove the "Unreleased" section (from "## [Unreleased]" until the first occurrence of '---') if it exists.
|
||||
if grep -q "^## \[Unreleased\]" CHANGELOG.md; then
|
||||
awk '/^## \[Unreleased\]/{flag=1} flag && /^---/{flag=0; next} !flag' CHANGELOG.md > CHANGELOG.cleaned
|
||||
else
|
||||
cp CHANGELOG.md CHANGELOG.cleaned
|
||||
fi
|
||||
|
||||
# Split the cleaned file into:
|
||||
# - header.md: content before the first release header ("## [v...").
|
||||
# - tail.md: content from the first release header onward.
|
||||
awk '/^## \[v/{exit} {print}' CHANGELOG.cleaned > header.md
|
||||
awk 'f{print} /^## \[v/{f=1; print}' CHANGELOG.cleaned > tail.md
|
||||
|
||||
# Combine header, the new release changelog, and the tail.
|
||||
echo "Combining updated changelog parts..."
|
||||
cat header.md CHANGELOG-release.md > CHANGELOG.md.new
|
||||
echo "" >> CHANGELOG.md.new
|
||||
cat tail.md >> CHANGELOG.md.new
|
||||
|
||||
mv CHANGELOG.md.new CHANGELOG.md
|
||||
|
||||
# Remove temporary files.
|
||||
rm -f CHANGELOG.cleaned header.md tail.md CHANGELOG-release.md
|
||||
|
||||
echo "Final CHANGELOG.md content:"
|
||||
cat CHANGELOG.md
|
||||
|
||||
# 4. Create (or update) the Pull Request with the updated CHANGELOG.md.
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
sign-commits: true
|
||||
commit-message: "chore: update CHANGELOG for release ${{ github.ref_name }}"
|
||||
base: main
|
||||
branch: "changelog/${{ github.ref_name }}"
|
||||
reviewers: danny-avila
|
||||
title: "📜 docs: Changelog for release ${{ github.ref_name }}"
|
||||
body: |
|
||||
**Description**:
|
||||
- This PR updates the CHANGELOG.md by removing the "Unreleased" section and adding new release notes for release ${{ github.ref_name }} above previous releases.
|
||||
@@ -1,107 +0,0 @@
|
||||
name: Generate Unreleased Changelog PR
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0 * * 1" # Runs every Monday at 00:00 UTC
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
generate-unreleased-changelog-pr:
|
||||
permissions:
|
||||
contents: write # Needed for pushing commits and creating branches.
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# 1. Checkout the repository on main.
|
||||
- name: Checkout Repository on Main
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: main
|
||||
fetch-depth: 0
|
||||
|
||||
# 4. Get the latest version tag.
|
||||
- name: Get Latest Tag
|
||||
id: get_latest_tag
|
||||
run: |
|
||||
LATEST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1) || echo "none")
|
||||
echo "Latest tag: $LATEST_TAG"
|
||||
echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT
|
||||
|
||||
# 5. Generate the Unreleased changelog.
|
||||
- name: Generate Unreleased Changelog
|
||||
id: generate_unreleased
|
||||
uses: mikepenz/release-changelog-builder-action@v5.1.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
configuration: ".github/configuration-unreleased.json"
|
||||
owner: ${{ github.repository_owner }}
|
||||
repo: ${{ github.event.repository.name }}
|
||||
outputFile: CHANGELOG-unreleased.md
|
||||
fromTag: ${{ steps.get_latest_tag.outputs.tag }}
|
||||
toTag: main
|
||||
|
||||
# 7. Update CHANGELOG.md with the new Unreleased section.
|
||||
- name: Update CHANGELOG.md
|
||||
id: update_changelog
|
||||
run: |
|
||||
# Create CHANGELOG.md if it doesn't exist.
|
||||
if [ ! -f CHANGELOG.md ]; then
|
||||
echo "# Changelog" > CHANGELOG.md
|
||||
echo "" >> CHANGELOG.md
|
||||
echo "All notable changes to this project will be documented in this file." >> CHANGELOG.md
|
||||
echo "" >> CHANGELOG.md
|
||||
fi
|
||||
|
||||
echo "Updating CHANGELOG.md…"
|
||||
|
||||
# Extract content before the "## [Unreleased]" (or first version header if missing).
|
||||
if grep -q "^## \[Unreleased\]" CHANGELOG.md; then
|
||||
awk '/^## \[Unreleased\]/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md
|
||||
else
|
||||
awk '/^## \[v/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md
|
||||
fi
|
||||
|
||||
# Append the generated Unreleased changelog.
|
||||
echo "" >> CHANGELOG_TMP.md
|
||||
cat CHANGELOG-unreleased.md >> CHANGELOG_TMP.md
|
||||
echo "" >> CHANGELOG_TMP.md
|
||||
|
||||
# Append the remainder of the original changelog (starting from the first version header).
|
||||
awk 'f{print} /^## \[v/{f=1; print}' CHANGELOG.md >> CHANGELOG_TMP.md
|
||||
|
||||
# Replace the old file with the updated file.
|
||||
mv CHANGELOG_TMP.md CHANGELOG.md
|
||||
|
||||
# Remove the temporary generated file.
|
||||
rm -f CHANGELOG-unreleased.md
|
||||
|
||||
echo "Final CHANGELOG.md:"
|
||||
cat CHANGELOG.md
|
||||
|
||||
# 8. Check if CHANGELOG.md has any updates.
|
||||
- name: Check for CHANGELOG.md changes
|
||||
id: changelog_changes
|
||||
run: |
|
||||
if git diff --quiet CHANGELOG.md; then
|
||||
echo "has_changes=false" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "has_changes=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# 9. Create (or update) the Pull Request only if there are changes.
|
||||
- name: Create Pull Request
|
||||
if: steps.changelog_changes.outputs.has_changes == 'true'
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
base: main
|
||||
branch: "changelog/unreleased-update"
|
||||
sign-commits: true
|
||||
commit-message: "action: update Unreleased changelog"
|
||||
title: "📜 docs: Unreleased Changelog"
|
||||
body: |
|
||||
**Description**:
|
||||
- This PR updates the Unreleased section in CHANGELOG.md.
|
||||
- It compares the current main branch with the latest version tag (determined as ${{ steps.get_latest_tag.outputs.tag }}),
|
||||
regenerates the Unreleased changelog, removes any old Unreleased block, and inserts the new content.
|
||||
3
.github/workflows/i18n-unused-keys.yml
vendored
3
.github/workflows/i18n-unused-keys.yml
vendored
@@ -6,6 +6,7 @@ on:
|
||||
- "client/src/**"
|
||||
- "api/**"
|
||||
- "packages/data-provider/src/**"
|
||||
- "packages/client/**"
|
||||
|
||||
jobs:
|
||||
detect-unused-i18n-keys:
|
||||
@@ -23,7 +24,7 @@ jobs:
|
||||
|
||||
# Define paths
|
||||
I18N_FILE="client/src/locales/en/translation.json"
|
||||
SOURCE_DIRS=("client/src" "api" "packages/data-provider/src")
|
||||
SOURCE_DIRS=("client/src" "api" "packages/data-provider/src" "packages/client")
|
||||
|
||||
# Check if translation file exists
|
||||
if [[ ! -f "$I18N_FILE" ]]; then
|
||||
|
||||
101
.github/workflows/unused-packages.yml
vendored
101
.github/workflows/unused-packages.yml
vendored
@@ -7,6 +7,7 @@ on:
|
||||
- 'package-lock.json'
|
||||
- 'client/**'
|
||||
- 'api/**'
|
||||
- 'packages/client/**'
|
||||
|
||||
jobs:
|
||||
detect-unused-packages:
|
||||
@@ -28,7 +29,7 @@ jobs:
|
||||
|
||||
- name: Validate JSON files
|
||||
run: |
|
||||
for FILE in package.json client/package.json api/package.json; do
|
||||
for FILE in package.json client/package.json api/package.json packages/client/package.json; do
|
||||
if [[ -f "$FILE" ]]; then
|
||||
jq empty "$FILE" || (echo "::error title=Invalid JSON::$FILE is invalid" && exit 1)
|
||||
fi
|
||||
@@ -63,12 +64,31 @@ jobs:
|
||||
local folder=$1
|
||||
local output_file=$2
|
||||
if [[ -d "$folder" ]]; then
|
||||
grep -rEho "require\\(['\"]([a-zA-Z0-9@/._-]+)['\"]\\)" "$folder" --include=\*.{js,ts,mjs,cjs} | \
|
||||
# Extract require() statements
|
||||
grep -rEho "require\\(['\"]([a-zA-Z0-9@/._-]+)['\"]\\)" "$folder" --include=\*.{js,ts,tsx,jsx,mjs,cjs} | \
|
||||
sed -E "s/require\\(['\"]([a-zA-Z0-9@/._-]+)['\"]\\)/\1/" > "$output_file"
|
||||
|
||||
grep -rEho "import .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{js,ts,mjs,cjs} | \
|
||||
# Extract ES6 imports - various patterns
|
||||
# import x from 'module'
|
||||
grep -rEho "import .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{js,ts,tsx,jsx,mjs,cjs} | \
|
||||
sed -E "s/import .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]/\1/" >> "$output_file"
|
||||
|
||||
# import 'module' (side-effect imports)
|
||||
grep -rEho "import ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{js,ts,tsx,jsx,mjs,cjs} | \
|
||||
sed -E "s/import ['\"]([a-zA-Z0-9@/._-]+)['\"]/\1/" >> "$output_file"
|
||||
|
||||
# export { x } from 'module' or export * from 'module'
|
||||
grep -rEho "export .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{js,ts,tsx,jsx,mjs,cjs} | \
|
||||
sed -E "s/export .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]/\1/" >> "$output_file"
|
||||
|
||||
# import type { x } from 'module' (TypeScript)
|
||||
grep -rEho "import type .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{ts,tsx} | \
|
||||
sed -E "s/import type .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]/\1/" >> "$output_file"
|
||||
|
||||
# Remove subpath imports but keep the base package
|
||||
# e.g., '@tanstack/react-query/devtools' becomes '@tanstack/react-query'
|
||||
sed -i -E 's|^(@?[a-zA-Z0-9-]+(/[a-zA-Z0-9-]+)?)/.*|\1|' "$output_file"
|
||||
|
||||
sort -u "$output_file" -o "$output_file"
|
||||
else
|
||||
touch "$output_file"
|
||||
@@ -78,13 +98,80 @@ jobs:
|
||||
extract_deps_from_code "." root_used_code.txt
|
||||
extract_deps_from_code "client" client_used_code.txt
|
||||
extract_deps_from_code "api" api_used_code.txt
|
||||
|
||||
# Extract dependencies used by @librechat/client package
|
||||
extract_deps_from_code "packages/client" packages_client_used_code.txt
|
||||
|
||||
- name: Get @librechat/client dependencies
|
||||
id: get-librechat-client-deps
|
||||
run: |
|
||||
if [[ -f "packages/client/package.json" ]]; then
|
||||
# Get all dependencies from @librechat/client (dependencies, devDependencies, and peerDependencies)
|
||||
DEPS=$(jq -r '.dependencies // {} | keys[]' packages/client/package.json 2>/dev/null || echo "")
|
||||
DEV_DEPS=$(jq -r '.devDependencies // {} | keys[]' packages/client/package.json 2>/dev/null || echo "")
|
||||
PEER_DEPS=$(jq -r '.peerDependencies // {} | keys[]' packages/client/package.json 2>/dev/null || echo "")
|
||||
|
||||
# Combine all dependencies
|
||||
echo "$DEPS" > librechat_client_deps.txt
|
||||
echo "$DEV_DEPS" >> librechat_client_deps.txt
|
||||
echo "$PEER_DEPS" >> librechat_client_deps.txt
|
||||
|
||||
# Also include dependencies that are imported in packages/client
|
||||
cat packages_client_used_code.txt >> librechat_client_deps.txt
|
||||
|
||||
# Remove empty lines and sort
|
||||
grep -v '^$' librechat_client_deps.txt | sort -u > temp_deps.txt
|
||||
mv temp_deps.txt librechat_client_deps.txt
|
||||
else
|
||||
touch librechat_client_deps.txt
|
||||
fi
|
||||
|
||||
- name: Extract Workspace Dependencies
|
||||
id: extract-workspace-deps
|
||||
run: |
|
||||
# Function to get dependencies from a workspace package that are used by another package
|
||||
get_workspace_package_deps() {
|
||||
local package_json=$1
|
||||
local output_file=$2
|
||||
|
||||
# Get all workspace dependencies (starting with @librechat/)
|
||||
if [[ -f "$package_json" ]]; then
|
||||
local workspace_deps=$(jq -r '.dependencies // {} | to_entries[] | select(.key | startswith("@librechat/")) | .key' "$package_json" 2>/dev/null || echo "")
|
||||
|
||||
# For each workspace dependency, get its dependencies
|
||||
for dep in $workspace_deps; do
|
||||
# Convert @librechat/api to packages/api
|
||||
local workspace_path=$(echo "$dep" | sed 's/@librechat\//packages\//')
|
||||
local workspace_package_json="${workspace_path}/package.json"
|
||||
|
||||
if [[ -f "$workspace_package_json" ]]; then
|
||||
# Extract all dependencies from the workspace package
|
||||
jq -r '.dependencies // {} | keys[]' "$workspace_package_json" 2>/dev/null >> "$output_file"
|
||||
# Also extract peerDependencies
|
||||
jq -r '.peerDependencies // {} | keys[]' "$workspace_package_json" 2>/dev/null >> "$output_file"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [[ -f "$output_file" ]]; then
|
||||
sort -u "$output_file" -o "$output_file"
|
||||
else
|
||||
touch "$output_file"
|
||||
fi
|
||||
}
|
||||
|
||||
# Get workspace dependencies for each package
|
||||
get_workspace_package_deps "package.json" root_workspace_deps.txt
|
||||
get_workspace_package_deps "client/package.json" client_workspace_deps.txt
|
||||
get_workspace_package_deps "api/package.json" api_workspace_deps.txt
|
||||
|
||||
- name: Run depcheck for root package.json
|
||||
id: check-root
|
||||
run: |
|
||||
if [[ -f "package.json" ]]; then
|
||||
UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "")
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat root_used_deps.txt root_used_code.txt | sort) || echo "")
|
||||
# Exclude dependencies used in scripts, code, and workspace packages
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat root_used_deps.txt root_used_code.txt root_workspace_deps.txt | sort) || echo "")
|
||||
echo "ROOT_UNUSED<<EOF" >> $GITHUB_ENV
|
||||
echo "$UNUSED" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
@@ -97,7 +184,8 @@ jobs:
|
||||
chmod -R 755 client
|
||||
cd client
|
||||
UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "")
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../client_used_deps.txt ../client_used_code.txt | sort) || echo "")
|
||||
# Exclude dependencies used in scripts, code, and workspace packages
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../client_used_deps.txt ../client_used_code.txt ../client_workspace_deps.txt | sort) || echo "")
|
||||
# Filter out false positives
|
||||
UNUSED=$(echo "$UNUSED" | grep -v "^micromark-extension-llm-math$" || echo "")
|
||||
echo "CLIENT_UNUSED<<EOF" >> $GITHUB_ENV
|
||||
@@ -113,7 +201,8 @@ jobs:
|
||||
chmod -R 755 api
|
||||
cd api
|
||||
UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "")
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../api_used_deps.txt ../api_used_code.txt | sort) || echo "")
|
||||
# Exclude dependencies used in scripts, code, and workspace packages
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../api_used_deps.txt ../api_used_code.txt ../api_workspace_deps.txt | sort) || echo "")
|
||||
echo "API_UNUSED<<EOF" >> $GITHUB_ENV
|
||||
echo "$UNUSED" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -125,3 +125,12 @@ helm/**/.values.yaml
|
||||
|
||||
# SAML Idp cert
|
||||
*.cert
|
||||
|
||||
# AI Assistants
|
||||
/.claude/
|
||||
/.cursor/
|
||||
/.copilot/
|
||||
/.aider/
|
||||
/.openai/
|
||||
/.tabnine/
|
||||
/.codeium
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# v0.7.8
|
||||
# v0.7.9
|
||||
|
||||
# Base node image
|
||||
FROM node:20-alpine AS node
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Dockerfile.multi
|
||||
# v0.7.8
|
||||
# v0.7.9
|
||||
|
||||
# Base for all builds
|
||||
FROM node:20-alpine AS base-min
|
||||
@@ -16,6 +16,7 @@ COPY package*.json ./
|
||||
COPY packages/data-provider/package*.json ./packages/data-provider/
|
||||
COPY packages/api/package*.json ./packages/api/
|
||||
COPY packages/data-schemas/package*.json ./packages/data-schemas/
|
||||
COPY packages/client/package*.json ./packages/client/
|
||||
COPY client/package*.json ./client/
|
||||
COPY api/package*.json ./api/
|
||||
|
||||
@@ -45,11 +46,19 @@ COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/d
|
||||
COPY --from=data-schemas-build /app/packages/data-schemas/dist /app/packages/data-schemas/dist
|
||||
RUN npm run build
|
||||
|
||||
# Build `client` package
|
||||
FROM base AS client-package-build
|
||||
WORKDIR /app/packages/client
|
||||
COPY packages/client ./
|
||||
RUN npm run build
|
||||
|
||||
# Client build
|
||||
FROM base AS client-build
|
||||
WORKDIR /app/client
|
||||
COPY client ./
|
||||
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
|
||||
COPY --from=client-package-build /app/packages/client/dist /app/packages/client/dist
|
||||
COPY --from=client-package-build /app/packages/client/src /app/packages/client/src
|
||||
ENV NODE_OPTIONS="--max-old-space-size=2048"
|
||||
RUN npm run build
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
|
||||
|
||||
- 🤖 **AI Model Selection**:
|
||||
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
|
||||
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure)
|
||||
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
|
||||
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
|
||||
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
|
||||
@@ -66,10 +66,9 @@
|
||||
- 🔦 **Agents & Tools Integration**:
|
||||
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
|
||||
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
|
||||
- Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
|
||||
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
|
||||
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
|
||||
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
|
||||
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
|
||||
- Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions
|
||||
|
||||
- 🔍 **Web Search**:
|
||||
- Search the internet and retrieve relevant information to enhance your AI context
|
||||
|
||||
@@ -13,7 +13,6 @@ const {
|
||||
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { truncateToolCallOutputs } = require('./prompts');
|
||||
const { addSpaceIfNeeded } = require('~/server/utils');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const TextStream = require('./TextStream');
|
||||
const { logger } = require('~/config');
|
||||
@@ -109,12 +108,15 @@ class BaseClient {
|
||||
/**
|
||||
* Abstract method to record token usage. Subclasses must implement this method.
|
||||
* If a correction to the token usage is needed, the method should return an object with the corrected token counts.
|
||||
* Should only be used if `recordCollectedUsage` was not used instead.
|
||||
* @param {string} [model]
|
||||
* @param {number} promptTokens
|
||||
* @param {number} completionTokens
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async recordTokenUsage({ promptTokens, completionTokens }) {
|
||||
async recordTokenUsage({ model, promptTokens, completionTokens }) {
|
||||
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
|
||||
model,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
});
|
||||
@@ -198,6 +200,10 @@ class BaseClient {
|
||||
this.currentMessages[this.currentMessages.length - 1].messageId = head;
|
||||
}
|
||||
|
||||
if (opts.isRegenerate && responseMessageId.endsWith('_')) {
|
||||
responseMessageId = crypto.randomUUID();
|
||||
}
|
||||
|
||||
this.responseMessageId = responseMessageId;
|
||||
|
||||
return {
|
||||
@@ -572,7 +578,7 @@ class BaseClient {
|
||||
});
|
||||
}
|
||||
|
||||
const { generation = '' } = opts;
|
||||
const { editedContent } = opts;
|
||||
|
||||
// It's not necessary to push to currentMessages
|
||||
// depending on subclass implementation of handling messages
|
||||
@@ -587,11 +593,21 @@ class BaseClient {
|
||||
isCreatedByUser: false,
|
||||
model: this.modelOptions?.model ?? this.model,
|
||||
sender: this.sender,
|
||||
text: generation,
|
||||
};
|
||||
this.currentMessages.push(userMessage, latestMessage);
|
||||
} else {
|
||||
latestMessage.text = generation;
|
||||
} else if (editedContent != null) {
|
||||
// Handle editedContent for content parts
|
||||
if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
|
||||
const { index, text, type } = editedContent;
|
||||
if (index >= 0 && index < latestMessage.content.length) {
|
||||
const contentPart = latestMessage.content[index];
|
||||
if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
|
||||
contentPart[ContentTypes.THINK] = text;
|
||||
} else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
|
||||
contentPart[ContentTypes.TEXT] = text;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
this.continued = true;
|
||||
} else {
|
||||
@@ -672,16 +688,32 @@ class BaseClient {
|
||||
};
|
||||
|
||||
if (typeof completion === 'string') {
|
||||
responseMessage.text = addSpaceIfNeeded(generation) + completion;
|
||||
responseMessage.text = completion;
|
||||
} else if (
|
||||
Array.isArray(completion) &&
|
||||
(this.clientName === EModelEndpoint.agents ||
|
||||
isParamEndpoint(this.options.endpoint, this.options.endpointType))
|
||||
) {
|
||||
responseMessage.text = '';
|
||||
responseMessage.content = completion;
|
||||
|
||||
if (!opts.editedContent || this.currentMessages.length === 0) {
|
||||
responseMessage.content = completion;
|
||||
} else {
|
||||
const latestMessage = this.currentMessages[this.currentMessages.length - 1];
|
||||
if (!latestMessage?.content) {
|
||||
responseMessage.content = completion;
|
||||
} else {
|
||||
const existingContent = [...latestMessage.content];
|
||||
const { type: editedType } = opts.editedContent;
|
||||
responseMessage.content = this.mergeEditedContent(
|
||||
existingContent,
|
||||
completion,
|
||||
editedType,
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if (Array.isArray(completion)) {
|
||||
responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
|
||||
responseMessage.text = completion.join('');
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -712,9 +744,13 @@ class BaseClient {
|
||||
} else {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
completionTokens = responseMessage.tokenCount;
|
||||
await this.recordTokenUsage({
|
||||
usage,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
model: responseMessage.model,
|
||||
});
|
||||
}
|
||||
|
||||
await this.recordTokenUsage({ promptTokens, completionTokens, usage });
|
||||
}
|
||||
|
||||
if (userMessagePromise) {
|
||||
@@ -1095,6 +1131,50 @@ class BaseClient {
|
||||
return numTokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges completion content with existing content when editing TEXT or THINK types
|
||||
* @param {Array} existingContent - The existing content array
|
||||
* @param {Array} newCompletion - The new completion content
|
||||
* @param {string} editedType - The type of content being edited
|
||||
* @returns {Array} The merged content array
|
||||
*/
|
||||
mergeEditedContent(existingContent, newCompletion, editedType) {
|
||||
if (!newCompletion.length) {
|
||||
return existingContent.concat(newCompletion);
|
||||
}
|
||||
|
||||
if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
|
||||
return existingContent.concat(newCompletion);
|
||||
}
|
||||
|
||||
const lastIndex = existingContent.length - 1;
|
||||
const lastExisting = existingContent[lastIndex];
|
||||
const firstNew = newCompletion[0];
|
||||
|
||||
if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
|
||||
return existingContent.concat(newCompletion);
|
||||
}
|
||||
|
||||
const mergedContent = [...existingContent];
|
||||
if (editedType === ContentTypes.TEXT) {
|
||||
mergedContent[lastIndex] = {
|
||||
...mergedContent[lastIndex],
|
||||
[ContentTypes.TEXT]:
|
||||
(mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
|
||||
};
|
||||
} else {
|
||||
mergedContent[lastIndex] = {
|
||||
...mergedContent[lastIndex],
|
||||
[ContentTypes.THINK]:
|
||||
(mergedContent[lastIndex][ContentTypes.THINK] || '') +
|
||||
(firstNew[ContentTypes.THINK] || ''),
|
||||
};
|
||||
}
|
||||
|
||||
// Add remaining completion items
|
||||
return mergedContent.concat(newCompletion.slice(1));
|
||||
}
|
||||
|
||||
async sendPayload(payload, opts = {}) {
|
||||
if (opts && typeof opts === 'object') {
|
||||
this.setOptions(opts);
|
||||
|
||||
@@ -237,41 +237,9 @@ const formatAgentMessages = (payload) => {
|
||||
return messages;
|
||||
};
|
||||
|
||||
/**
|
||||
* Formats an array of messages for LangChain, making sure all content fields are strings
|
||||
* @param {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} payload - The array of messages to format.
|
||||
* @returns {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} - The array of formatted LangChain messages, including ToolMessages for tool calls.
|
||||
*/
|
||||
const formatContentStrings = (payload) => {
|
||||
const messages = [];
|
||||
|
||||
for (const message of payload) {
|
||||
if (typeof message.content === 'string') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!Array.isArray(message.content)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Reduce text types to a single string, ignore all other types
|
||||
const content = message.content.reduce((acc, curr) => {
|
||||
if (curr.type === ContentTypes.TEXT) {
|
||||
return `${acc}${curr[ContentTypes.TEXT]}\n`;
|
||||
}
|
||||
return acc;
|
||||
}, '');
|
||||
|
||||
message.content = content.trim();
|
||||
}
|
||||
|
||||
return messages;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
formatMessage,
|
||||
formatFromLangChain,
|
||||
formatAgentMessages,
|
||||
formatContentStrings,
|
||||
formatLangChainMessages,
|
||||
};
|
||||
|
||||
@@ -422,6 +422,46 @@ describe('BaseClient', () => {
|
||||
expect(response).toEqual(expectedResult);
|
||||
});
|
||||
|
||||
test('should replace responseMessageId with new UUID when isRegenerate is true and messageId ends with underscore', async () => {
|
||||
const mockCrypto = require('crypto');
|
||||
const newUUID = 'new-uuid-1234';
|
||||
jest.spyOn(mockCrypto, 'randomUUID').mockReturnValue(newUUID);
|
||||
|
||||
const opts = {
|
||||
isRegenerate: true,
|
||||
responseMessageId: 'existing-message-id_',
|
||||
};
|
||||
|
||||
await TestClient.setMessageOptions(opts);
|
||||
|
||||
expect(TestClient.responseMessageId).toBe(newUUID);
|
||||
expect(TestClient.responseMessageId).not.toBe('existing-message-id_');
|
||||
|
||||
mockCrypto.randomUUID.mockRestore();
|
||||
});
|
||||
|
||||
test('should not replace responseMessageId when isRegenerate is false', async () => {
|
||||
const opts = {
|
||||
isRegenerate: false,
|
||||
responseMessageId: 'existing-message-id_',
|
||||
};
|
||||
|
||||
await TestClient.setMessageOptions(opts);
|
||||
|
||||
expect(TestClient.responseMessageId).toBe('existing-message-id_');
|
||||
});
|
||||
|
||||
test('should not replace responseMessageId when it does not end with underscore', async () => {
|
||||
const opts = {
|
||||
isRegenerate: true,
|
||||
responseMessageId: 'existing-message-id',
|
||||
};
|
||||
|
||||
await TestClient.setMessageOptions(opts);
|
||||
|
||||
expect(TestClient.responseMessageId).toBe('existing-message-id');
|
||||
});
|
||||
|
||||
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
|
||||
const userMessage = 'Second message in the conversation';
|
||||
const opts = {
|
||||
|
||||
@@ -3,8 +3,8 @@ const path = require('path');
|
||||
const OpenAI = require('openai');
|
||||
const fetch = require('node-fetch');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { Tool } = require('@langchain/core/tools');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { FileContext, ContentTypes } = require('librechat-data-provider');
|
||||
const { getImageBasename } = require('~/server/services/Files/images');
|
||||
const extractBaseURL = require('~/utils/extractBaseURL');
|
||||
@@ -46,7 +46,10 @@ class DALLE3 extends Tool {
|
||||
}
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
const proxyAgent = new ProxyAgent(process.env.PROXY);
|
||||
config.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
/** @type {OpenAI} */
|
||||
@@ -163,7 +166,8 @@ Error Message: ${error.message}`);
|
||||
if (this.isAgent) {
|
||||
let fetchOptions = {};
|
||||
if (process.env.PROXY) {
|
||||
fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY);
|
||||
const proxyAgent = new ProxyAgent(process.env.PROXY);
|
||||
fetchOptions.dispatcher = proxyAgent;
|
||||
}
|
||||
const imageResponse = await fetch(theImageUrl, fetchOptions);
|
||||
const arrayBuffer = await imageResponse.arrayBuffer();
|
||||
|
||||
@@ -3,10 +3,10 @@ const axios = require('axios');
|
||||
const { v4 } = require('uuid');
|
||||
const OpenAI = require('openai');
|
||||
const FormData = require('form-data');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
@@ -189,7 +189,10 @@ function createOpenAIImageTools(fields = {}) {
|
||||
}
|
||||
const clientConfig = { ...closureConfig };
|
||||
if (process.env.PROXY) {
|
||||
clientConfig.httpAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
const proxyAgent = new ProxyAgent(process.env.PROXY);
|
||||
clientConfig.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
/** @type {OpenAI} */
|
||||
@@ -335,7 +338,10 @@ Error Message: ${error.message}`);
|
||||
|
||||
const clientConfig = { ...closureConfig };
|
||||
if (process.env.PROXY) {
|
||||
clientConfig.httpAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
const proxyAgent = new ProxyAgent(process.env.PROXY);
|
||||
clientConfig.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
const formData = new FormData();
|
||||
@@ -447,6 +453,19 @@ Error Message: ${error.message}`);
|
||||
baseURL,
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
try {
|
||||
const url = new URL(process.env.PROXY);
|
||||
axiosConfig.proxy = {
|
||||
host: url.hostname.replace(/^\[|\]$/g, ''),
|
||||
port: url.port ? parseInt(url.port, 10) : undefined,
|
||||
protocol: url.protocol.replace(':', ''),
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error parsing proxy URL:', error);
|
||||
}
|
||||
}
|
||||
|
||||
if (process.env.IMAGE_GEN_OAI_AZURE_API_VERSION && process.env.IMAGE_GEN_OAI_BASEURL) {
|
||||
axiosConfig.params = {
|
||||
'api-version': process.env.IMAGE_GEN_OAI_AZURE_API_VERSION,
|
||||
|
||||
94
api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js
Normal file
94
api/app/clients/tools/structured/specs/DALLE3-proxy.spec.js
Normal file
@@ -0,0 +1,94 @@
|
||||
const DALLE3 = require('../DALLE3');
|
||||
const { ProxyAgent } = require('undici');
|
||||
|
||||
const processFileURL = jest.fn();
|
||||
|
||||
jest.mock('~/server/services/Files/images', () => ({
|
||||
getImageBasename: jest.fn().mockImplementation((url) => {
|
||||
const parts = url.split('/');
|
||||
const lastPart = parts.pop();
|
||||
const imageExtensionRegex = /\.(jpg|jpeg|png|gif|bmp|tiff|svg)$/i;
|
||||
if (imageExtensionRegex.test(lastPart)) {
|
||||
return lastPart;
|
||||
}
|
||||
return '';
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('fs', () => {
|
||||
return {
|
||||
existsSync: jest.fn(),
|
||||
mkdirSync: jest.fn(),
|
||||
promises: {
|
||||
writeFile: jest.fn(),
|
||||
readFile: jest.fn(),
|
||||
unlink: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('path', () => {
|
||||
return {
|
||||
resolve: jest.fn(),
|
||||
join: jest.fn(),
|
||||
relative: jest.fn(),
|
||||
extname: jest.fn().mockImplementation((filename) => {
|
||||
return filename.slice(filename.lastIndexOf('.'));
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
describe('DALLE3 Proxy Configuration', () => {
|
||||
let originalEnv;
|
||||
|
||||
beforeAll(() => {
|
||||
originalEnv = { ...process.env };
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetModules();
|
||||
process.env = { ...originalEnv };
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should configure ProxyAgent in fetchOptions.dispatcher when PROXY env is set', () => {
|
||||
// Set proxy environment variable
|
||||
process.env.PROXY = 'http://proxy.example.com:8080';
|
||||
process.env.DALLE_API_KEY = 'test-api-key';
|
||||
|
||||
// Create instance
|
||||
const dalleWithProxy = new DALLE3({ processFileURL });
|
||||
|
||||
// Check that the openai client exists
|
||||
expect(dalleWithProxy.openai).toBeDefined();
|
||||
|
||||
// Check that _options exists and has fetchOptions with a dispatcher
|
||||
expect(dalleWithProxy.openai._options).toBeDefined();
|
||||
expect(dalleWithProxy.openai._options.fetchOptions).toBeDefined();
|
||||
expect(dalleWithProxy.openai._options.fetchOptions.dispatcher).toBeDefined();
|
||||
expect(dalleWithProxy.openai._options.fetchOptions.dispatcher).toBeInstanceOf(ProxyAgent);
|
||||
});
|
||||
|
||||
it('should not configure ProxyAgent when PROXY env is not set', () => {
|
||||
// Ensure PROXY is not set
|
||||
delete process.env.PROXY;
|
||||
process.env.DALLE_API_KEY = 'test-api-key';
|
||||
|
||||
// Create instance
|
||||
const dalleWithoutProxy = new DALLE3({ processFileURL });
|
||||
|
||||
// Check that the openai client exists
|
||||
expect(dalleWithoutProxy.openai).toBeDefined();
|
||||
|
||||
// Check that _options exists but fetchOptions either doesn't exist or doesn't have a dispatcher
|
||||
expect(dalleWithoutProxy.openai._options).toBeDefined();
|
||||
|
||||
// fetchOptions should either not exist or not have a dispatcher
|
||||
if (dalleWithoutProxy.openai._options.fetchOptions) {
|
||||
expect(dalleWithoutProxy.openai._options.fetchOptions.dispatcher).toBeUndefined();
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -11,17 +11,25 @@ const { getFiles } = require('~/models/File');
|
||||
* @param {Object} options
|
||||
* @param {ServerRequest} options.req
|
||||
* @param {Agent['tool_resources']} options.tool_resources
|
||||
* @param {string} [options.agentId] - The agent ID for file access control
|
||||
* @returns {Promise<{
|
||||
* files: Array<{ file_id: string; filename: string }>,
|
||||
* toolContext: string
|
||||
* }>}
|
||||
*/
|
||||
const primeFiles = async (options) => {
|
||||
const { tool_resources } = options;
|
||||
const { tool_resources, req, agentId } = options;
|
||||
const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
|
||||
const agentResourceIds = new Set(file_ids);
|
||||
const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
|
||||
const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
|
||||
const dbFiles = (
|
||||
(await getFiles(
|
||||
{ file_id: { $in: file_ids } },
|
||||
null,
|
||||
{ text: 0 },
|
||||
{ userId: req?.user?.id, agentId },
|
||||
)) ?? []
|
||||
).concat(resourceFiles);
|
||||
|
||||
let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;
|
||||
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
const { mcpToolPattern } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SerpAPI } = require('@langchain/community/tools/serpapi');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
|
||||
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
||||
const {
|
||||
Tools,
|
||||
EToolResources,
|
||||
loadWebSearchAuth,
|
||||
replaceSpecialVars,
|
||||
} = require('librechat-data-provider');
|
||||
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
|
||||
const {
|
||||
availableTools,
|
||||
manifestToolMap,
|
||||
@@ -235,7 +230,7 @@ const loadTools = async ({
|
||||
|
||||
/** @type {Record<string, string>} */
|
||||
const toolContextMap = {};
|
||||
const appTools = (await getCachedTools({ includeGlobal: true })) ?? {};
|
||||
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
|
||||
|
||||
for (const tool of tools) {
|
||||
if (tool === Tools.execute_code) {
|
||||
@@ -245,7 +240,13 @@ const loadTools = async ({
|
||||
authFields: [EnvVar.CODE_API_KEY],
|
||||
});
|
||||
const codeApiKey = authValues[EnvVar.CODE_API_KEY];
|
||||
const { files, toolContext } = await primeCodeFiles(options, codeApiKey);
|
||||
const { files, toolContext } = await primeCodeFiles(
|
||||
{
|
||||
...options,
|
||||
agentId: agent?.id,
|
||||
},
|
||||
codeApiKey,
|
||||
);
|
||||
if (toolContext) {
|
||||
toolContextMap[tool] = toolContext;
|
||||
}
|
||||
@@ -260,7 +261,10 @@ const loadTools = async ({
|
||||
continue;
|
||||
} else if (tool === Tools.file_search) {
|
||||
requestedTools[tool] = async () => {
|
||||
const { files, toolContext } = await primeSearchFiles(options);
|
||||
const { files, toolContext } = await primeSearchFiles({
|
||||
...options,
|
||||
agentId: agent?.id,
|
||||
});
|
||||
if (toolContext) {
|
||||
toolContextMap[tool] = toolContext;
|
||||
}
|
||||
@@ -294,7 +298,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
});
|
||||
};
|
||||
continue;
|
||||
} else if (tool && appTools[tool] && mcpToolPattern.test(tool)) {
|
||||
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
|
||||
requestedTools[tool] = async () =>
|
||||
createMCPTool({
|
||||
req: options.req,
|
||||
|
||||
62
api/cache/cacheConfig.js
vendored
Normal file
62
api/cache/cacheConfig.js
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
const fs = require('fs');
|
||||
const { math, isEnabled } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
|
||||
// To ensure that different deployments do not interfere with each other's cache, we use a prefix for the Redis keys.
|
||||
// This prefix is usually the deployment ID, which is often passed to the container or pod as an env var.
|
||||
// Set REDIS_KEY_PREFIX_VAR to the env var that contains the deployment ID.
|
||||
const REDIS_KEY_PREFIX_VAR = process.env.REDIS_KEY_PREFIX_VAR;
|
||||
const REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX;
|
||||
if (REDIS_KEY_PREFIX_VAR && REDIS_KEY_PREFIX) {
|
||||
throw new Error('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
|
||||
}
|
||||
|
||||
const USE_REDIS = isEnabled(process.env.USE_REDIS);
|
||||
if (USE_REDIS && !process.env.REDIS_URI) {
|
||||
throw new Error('USE_REDIS is enabled but REDIS_URI is not set.');
|
||||
}
|
||||
|
||||
// Comma-separated list of cache namespaces that should be forced to use in-memory storage
|
||||
// even when Redis is enabled. This allows selective performance optimization for specific caches.
|
||||
const FORCED_IN_MEMORY_CACHE_NAMESPACES = process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES
|
||||
? process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES.split(',').map((key) => key.trim())
|
||||
: [];
|
||||
|
||||
// Validate against CacheKeys enum
|
||||
if (FORCED_IN_MEMORY_CACHE_NAMESPACES.length > 0) {
|
||||
const validKeys = Object.values(CacheKeys);
|
||||
const invalidKeys = FORCED_IN_MEMORY_CACHE_NAMESPACES.filter((key) => !validKeys.includes(key));
|
||||
|
||||
if (invalidKeys.length > 0) {
|
||||
throw new Error(
|
||||
`Invalid cache keys in FORCED_IN_MEMORY_CACHE_NAMESPACES: ${invalidKeys.join(', ')}. Valid keys: ${validKeys.join(', ')}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const cacheConfig = {
|
||||
FORCED_IN_MEMORY_CACHE_NAMESPACES,
|
||||
USE_REDIS,
|
||||
REDIS_URI: process.env.REDIS_URI,
|
||||
REDIS_USERNAME: process.env.REDIS_USERNAME,
|
||||
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
|
||||
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
|
||||
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
|
||||
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
|
||||
REDIS_PING_INTERVAL: math(process.env.REDIS_PING_INTERVAL, 0),
|
||||
/** Max delay between reconnection attempts in ms */
|
||||
REDIS_RETRY_MAX_DELAY: math(process.env.REDIS_RETRY_MAX_DELAY, 3000),
|
||||
/** Max number of reconnection attempts (0 = infinite) */
|
||||
REDIS_RETRY_MAX_ATTEMPTS: math(process.env.REDIS_RETRY_MAX_ATTEMPTS, 10),
|
||||
/** Connection timeout in ms */
|
||||
REDIS_CONNECT_TIMEOUT: math(process.env.REDIS_CONNECT_TIMEOUT, 10000),
|
||||
/** Queue commands when disconnected */
|
||||
REDIS_ENABLE_OFFLINE_QUEUE: isEnabled(process.env.REDIS_ENABLE_OFFLINE_QUEUE ?? 'true'),
|
||||
|
||||
CI: isEnabled(process.env.CI),
|
||||
DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
|
||||
|
||||
BAN_DURATION: math(process.env.BAN_DURATION, 7200000), // 2 hours
|
||||
};
|
||||
|
||||
module.exports = { cacheConfig };
|
||||
157
api/cache/cacheConfig.spec.js
vendored
Normal file
157
api/cache/cacheConfig.spec.js
vendored
Normal file
@@ -0,0 +1,157 @@
|
||||
const fs = require('fs');
|
||||
|
||||
describe('cacheConfig', () => {
|
||||
let originalEnv;
|
||||
let originalReadFileSync;
|
||||
|
||||
beforeEach(() => {
|
||||
originalEnv = { ...process.env };
|
||||
originalReadFileSync = fs.readFileSync;
|
||||
|
||||
// Clear all related env vars first
|
||||
delete process.env.REDIS_URI;
|
||||
delete process.env.REDIS_CA;
|
||||
delete process.env.REDIS_KEY_PREFIX_VAR;
|
||||
delete process.env.REDIS_KEY_PREFIX;
|
||||
delete process.env.USE_REDIS;
|
||||
delete process.env.REDIS_PING_INTERVAL;
|
||||
delete process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES;
|
||||
|
||||
// Clear require cache
|
||||
jest.resetModules();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
fs.readFileSync = originalReadFileSync;
|
||||
jest.resetModules();
|
||||
});
|
||||
|
||||
describe('REDIS_KEY_PREFIX validation and resolution', () => {
|
||||
test('should throw error when both REDIS_KEY_PREFIX_VAR and REDIS_KEY_PREFIX are set', () => {
|
||||
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
|
||||
process.env.REDIS_KEY_PREFIX = 'manual-prefix';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).toThrow('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
|
||||
});
|
||||
|
||||
test('should resolve REDIS_KEY_PREFIX from variable reference', () => {
|
||||
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
|
||||
process.env.DEPLOYMENT_ID = 'test-deployment-123';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('test-deployment-123');
|
||||
});
|
||||
|
||||
test('should use direct REDIS_KEY_PREFIX value', () => {
|
||||
process.env.REDIS_KEY_PREFIX = 'direct-prefix';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('direct-prefix');
|
||||
});
|
||||
|
||||
test('should default to empty string when no prefix is configured', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
|
||||
});
|
||||
|
||||
test('should handle empty variable reference', () => {
|
||||
process.env.REDIS_KEY_PREFIX_VAR = 'EMPTY_VAR';
|
||||
process.env.EMPTY_VAR = '';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
|
||||
});
|
||||
|
||||
test('should handle undefined variable reference', () => {
|
||||
process.env.REDIS_KEY_PREFIX_VAR = 'UNDEFINED_VAR';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('USE_REDIS and REDIS_URI validation', () => {
|
||||
test('should throw error when USE_REDIS is enabled but REDIS_URI is not set', () => {
|
||||
process.env.USE_REDIS = 'true';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
|
||||
});
|
||||
|
||||
test('should not throw error when USE_REDIS is enabled and REDIS_URI is set', () => {
|
||||
process.env.USE_REDIS = 'true';
|
||||
process.env.REDIS_URI = 'redis://localhost:6379';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
test('should handle empty REDIS_URI when USE_REDIS is enabled', () => {
|
||||
process.env.USE_REDIS = 'true';
|
||||
process.env.REDIS_URI = '';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('REDIS_CA file reading', () => {
|
||||
test('should be null when REDIS_CA is not set', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_CA).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('REDIS_PING_INTERVAL configuration', () => {
|
||||
test('should default to 0 when REDIS_PING_INTERVAL is not set', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_PING_INTERVAL).toBe(0);
|
||||
});
|
||||
|
||||
test('should use provided REDIS_PING_INTERVAL value', () => {
|
||||
process.env.REDIS_PING_INTERVAL = '300';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_PING_INTERVAL).toBe(300);
|
||||
});
|
||||
});
|
||||
|
||||
describe('FORCED_IN_MEMORY_CACHE_NAMESPACES validation', () => {
|
||||
test('should parse comma-separated cache keys correctly', () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, STATIC_CONFIG ,MESSAGES ';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([
|
||||
'ROLES',
|
||||
'STATIC_CONFIG',
|
||||
'MESSAGES',
|
||||
]);
|
||||
});
|
||||
|
||||
test('should throw error for invalid cache keys', () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = 'INVALID_KEY,ROLES';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).toThrow('Invalid cache keys in FORCED_IN_MEMORY_CACHE_NAMESPACES: INVALID_KEY');
|
||||
});
|
||||
|
||||
test('should handle empty string gracefully', () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = '';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([]);
|
||||
});
|
||||
|
||||
test('should handle undefined env var gracefully', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([]);
|
||||
});
|
||||
});
|
||||
});
|
||||
108
api/cache/cacheFactory.js
vendored
Normal file
108
api/cache/cacheFactory.js
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
const KeyvRedis = require('@keyv/redis').default;
|
||||
const { Keyv } = require('keyv');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { Time } = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { RedisStore: ConnectRedis } = require('connect-redis');
|
||||
const MemoryStore = require('memorystore')(require('express-session'));
|
||||
const { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } = require('./redisClients');
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
const { violationFile } = require('./keyvFiles');
|
||||
|
||||
/**
|
||||
* Creates a cache instance using Redis or a fallback store. Suitable for general caching needs.
|
||||
* @param {string} namespace - The cache namespace.
|
||||
* @param {number} [ttl] - Time to live for cache entries.
|
||||
* @param {object} [fallbackStore] - Optional fallback store if Redis is not used.
|
||||
* @returns {Keyv} Cache instance.
|
||||
*/
|
||||
const standardCache = (namespace, ttl = undefined, fallbackStore = undefined) => {
|
||||
if (
|
||||
cacheConfig.USE_REDIS &&
|
||||
!cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES?.includes(namespace)
|
||||
) {
|
||||
try {
|
||||
const keyvRedis = new KeyvRedis(keyvRedisClient);
|
||||
const cache = new Keyv(keyvRedis, { namespace, ttl });
|
||||
keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
|
||||
keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
|
||||
|
||||
cache.on('error', (err) => {
|
||||
logger.error(`Cache error in namespace ${namespace}:`, err);
|
||||
});
|
||||
|
||||
return cache;
|
||||
} catch (err) {
|
||||
logger.error(`Failed to create Redis cache for namespace ${namespace}:`, err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
if (fallbackStore) return new Keyv({ store: fallbackStore, namespace, ttl });
|
||||
return new Keyv({ namespace, ttl });
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a cache instance for storing violation data.
|
||||
* Uses a file-based fallback store if Redis is not enabled.
|
||||
* @param {string} namespace - The cache namespace for violations.
|
||||
* @param {number} [ttl] - Time to live for cache entries.
|
||||
* @returns {Keyv} Cache instance for violations.
|
||||
*/
|
||||
const violationCache = (namespace, ttl = undefined) => {
|
||||
return standardCache(`violations:${namespace}`, ttl, violationFile);
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a session cache instance using Redis or in-memory store.
|
||||
* @param {string} namespace - The session namespace.
|
||||
* @param {number} [ttl] - Time to live for session entries.
|
||||
* @returns {MemoryStore | ConnectRedis} Session store instance.
|
||||
*/
|
||||
const sessionCache = (namespace, ttl = undefined) => {
|
||||
namespace = namespace.endsWith(':') ? namespace : `${namespace}:`;
|
||||
if (!cacheConfig.USE_REDIS) return new MemoryStore({ ttl, checkPeriod: Time.ONE_DAY });
|
||||
const store = new ConnectRedis({ client: ioredisClient, ttl, prefix: namespace });
|
||||
if (ioredisClient) {
|
||||
ioredisClient.on('error', (err) => {
|
||||
logger.error(`Session store Redis error for namespace ${namespace}:`, err);
|
||||
});
|
||||
}
|
||||
return store;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a rate limiter cache using Redis.
|
||||
* @param {string} prefix - The key prefix for rate limiting.
|
||||
* @returns {RedisStore|undefined} RedisStore instance or undefined if Redis is not used.
|
||||
*/
|
||||
const limiterCache = (prefix) => {
|
||||
if (!prefix) throw new Error('prefix is required');
|
||||
if (!cacheConfig.USE_REDIS) return undefined;
|
||||
prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;
|
||||
|
||||
try {
|
||||
if (!ioredisClient) {
|
||||
logger.warn(`Redis client not available for rate limiter with prefix ${prefix}`);
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return new RedisStore({ sendCommand, prefix });
|
||||
} catch (err) {
|
||||
logger.error(`Failed to create Redis rate limiter for prefix ${prefix}:`, err);
|
||||
return undefined;
|
||||
}
|
||||
};
|
||||
|
||||
const sendCommand = (...args) => {
|
||||
if (!ioredisClient) {
|
||||
logger.warn('Redis client not available for command execution');
|
||||
return Promise.reject(new Error('Redis client not available'));
|
||||
}
|
||||
|
||||
return ioredisClient.call(...args).catch((err) => {
|
||||
logger.error('Redis command execution failed:', err);
|
||||
throw err;
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = { standardCache, sessionCache, violationCache, limiterCache };
|
||||
432
api/cache/cacheFactory.spec.js
vendored
Normal file
432
api/cache/cacheFactory.spec.js
vendored
Normal file
@@ -0,0 +1,432 @@
|
||||
const { Time } = require('librechat-data-provider');
|
||||
|
||||
// Mock dependencies first
|
||||
const mockKeyvRedis = {
|
||||
namespace: '',
|
||||
keyPrefixSeparator: '',
|
||||
};
|
||||
|
||||
const mockKeyv = jest.fn().mockReturnValue({
|
||||
mock: 'keyv',
|
||||
on: jest.fn(),
|
||||
});
|
||||
const mockConnectRedis = jest.fn().mockReturnValue({ mock: 'connectRedis' });
|
||||
const mockMemoryStore = jest.fn().mockReturnValue({ mock: 'memoryStore' });
|
||||
const mockRedisStore = jest.fn().mockReturnValue({ mock: 'redisStore' });
|
||||
|
||||
const mockIoredisClient = {
|
||||
call: jest.fn(),
|
||||
on: jest.fn(),
|
||||
};
|
||||
|
||||
const mockKeyvRedisClient = {};
|
||||
const mockViolationFile = {};
|
||||
|
||||
// Mock modules before requiring the main module
|
||||
jest.mock('@keyv/redis', () => ({
|
||||
default: jest.fn().mockImplementation(() => mockKeyvRedis),
|
||||
}));
|
||||
|
||||
jest.mock('keyv', () => ({
|
||||
Keyv: mockKeyv,
|
||||
}));
|
||||
|
||||
jest.mock('./cacheConfig', () => ({
|
||||
cacheConfig: {
|
||||
USE_REDIS: false,
|
||||
REDIS_KEY_PREFIX: 'test',
|
||||
FORCED_IN_MEMORY_CACHE_NAMESPACES: [],
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('./redisClients', () => ({
|
||||
keyvRedisClient: mockKeyvRedisClient,
|
||||
ioredisClient: mockIoredisClient,
|
||||
GLOBAL_PREFIX_SEPARATOR: '::',
|
||||
}));
|
||||
|
||||
jest.mock('./keyvFiles', () => ({
|
||||
violationFile: mockViolationFile,
|
||||
}));
|
||||
|
||||
jest.mock('connect-redis', () => ({ RedisStore: mockConnectRedis }));
|
||||
|
||||
jest.mock('memorystore', () => jest.fn(() => mockMemoryStore));
|
||||
|
||||
jest.mock('rate-limit-redis', () => ({
|
||||
RedisStore: mockRedisStore,
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
info: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Import after mocking
|
||||
const { standardCache, sessionCache, violationCache, limiterCache } = require('./cacheFactory');
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
|
||||
describe('cacheFactory', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Reset cache config mock
|
||||
cacheConfig.USE_REDIS = false;
|
||||
cacheConfig.REDIS_KEY_PREFIX = 'test';
|
||||
cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES = [];
|
||||
});
|
||||
|
||||
describe('redisCache', () => {
|
||||
it('should create Redis cache when USE_REDIS is true', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
|
||||
standardCache(namespace, ttl);
|
||||
|
||||
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
|
||||
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
|
||||
expect(mockKeyvRedis.namespace).toBe(cacheConfig.REDIS_KEY_PREFIX);
|
||||
expect(mockKeyvRedis.keyPrefixSeparator).toBe('::');
|
||||
});
|
||||
|
||||
it('should create Redis cache with undefined ttl when not provided', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'test-namespace';
|
||||
|
||||
standardCache(namespace);
|
||||
|
||||
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl: undefined });
|
||||
});
|
||||
|
||||
it('should use fallback store when USE_REDIS is false and fallbackStore is provided', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
const fallbackStore = { some: 'store' };
|
||||
|
||||
standardCache(namespace, ttl, fallbackStore);
|
||||
|
||||
expect(mockKeyv).toHaveBeenCalledWith({ store: fallbackStore, namespace, ttl });
|
||||
});
|
||||
|
||||
it('should create default Keyv instance when USE_REDIS is false and no fallbackStore', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
|
||||
standardCache(namespace, ttl);
|
||||
|
||||
expect(mockKeyv).toHaveBeenCalledWith({ namespace, ttl });
|
||||
});
|
||||
|
||||
it('should handle namespace and ttl as undefined', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
|
||||
standardCache();
|
||||
|
||||
expect(mockKeyv).toHaveBeenCalledWith({ namespace: undefined, ttl: undefined });
|
||||
});
|
||||
|
||||
it('should use fallback when namespace is in FORCED_IN_MEMORY_CACHE_NAMESPACES', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES = ['forced-memory'];
|
||||
const namespace = 'forced-memory';
|
||||
const ttl = 3600;
|
||||
|
||||
standardCache(namespace, ttl);
|
||||
|
||||
expect(require('@keyv/redis').default).not.toHaveBeenCalled();
|
||||
expect(mockKeyv).toHaveBeenCalledWith({ namespace, ttl });
|
||||
});
|
||||
|
||||
it('should use Redis when namespace is not in FORCED_IN_MEMORY_CACHE_NAMESPACES', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES = ['other-namespace'];
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
|
||||
standardCache(namespace, ttl);
|
||||
|
||||
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
|
||||
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
|
||||
});
|
||||
|
||||
it('should throw error when Redis cache creation fails', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
const testError = new Error('Redis connection failed');
|
||||
|
||||
const KeyvRedis = require('@keyv/redis').default;
|
||||
KeyvRedis.mockImplementationOnce(() => {
|
||||
throw testError;
|
||||
});
|
||||
|
||||
expect(() => standardCache(namespace, ttl)).toThrow('Redis connection failed');
|
||||
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
`Failed to create Redis cache for namespace ${namespace}:`,
|
||||
testError,
|
||||
);
|
||||
|
||||
expect(mockKeyv).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('violationCache', () => {
|
||||
it('should create violation cache with prefixed namespace', () => {
|
||||
const namespace = 'test-violations';
|
||||
const ttl = 7200;
|
||||
|
||||
// We can't easily mock the internal redisCache call since it's in the same module
|
||||
// But we can test that the function executes without throwing
|
||||
expect(() => violationCache(namespace, ttl)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should create violation cache with undefined ttl', () => {
|
||||
const namespace = 'test-violations';
|
||||
|
||||
violationCache(namespace);
|
||||
|
||||
// The function should call redisCache with violations: prefixed namespace
|
||||
// Since we can't easily mock the internal redisCache call, we test the behavior
|
||||
expect(() => violationCache(namespace)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should handle undefined namespace', () => {
|
||||
expect(() => violationCache(undefined)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sessionCache', () => {
|
||||
it('should return MemoryStore when USE_REDIS is false', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const namespace = 'sessions';
|
||||
const ttl = 86400;
|
||||
|
||||
const result = sessionCache(namespace, ttl);
|
||||
|
||||
expect(mockMemoryStore).toHaveBeenCalledWith({ ttl, checkPeriod: Time.ONE_DAY });
|
||||
expect(result).toBe(mockMemoryStore());
|
||||
});
|
||||
|
||||
it('should return ConnectRedis when USE_REDIS is true', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
const ttl = 86400;
|
||||
|
||||
const result = sessionCache(namespace, ttl);
|
||||
|
||||
expect(mockConnectRedis).toHaveBeenCalledWith({
|
||||
client: mockIoredisClient,
|
||||
ttl,
|
||||
prefix: `${namespace}:`,
|
||||
});
|
||||
expect(result).toBe(mockConnectRedis());
|
||||
});
|
||||
|
||||
it('should add colon to namespace if not present', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
|
||||
sessionCache(namespace);
|
||||
|
||||
expect(mockConnectRedis).toHaveBeenCalledWith({
|
||||
client: mockIoredisClient,
|
||||
ttl: undefined,
|
||||
prefix: 'sessions:',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not add colon to namespace if already present', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions:';
|
||||
|
||||
sessionCache(namespace);
|
||||
|
||||
expect(mockConnectRedis).toHaveBeenCalledWith({
|
||||
client: mockIoredisClient,
|
||||
ttl: undefined,
|
||||
prefix: 'sessions:',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle undefined ttl', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const namespace = 'sessions';
|
||||
|
||||
sessionCache(namespace);
|
||||
|
||||
expect(mockMemoryStore).toHaveBeenCalledWith({
|
||||
ttl: undefined,
|
||||
checkPeriod: Time.ONE_DAY,
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw error when ConnectRedis constructor fails', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
const ttl = 86400;
|
||||
|
||||
// Mock ConnectRedis to throw an error during construction
|
||||
const redisError = new Error('Redis connection failed');
|
||||
mockConnectRedis.mockImplementationOnce(() => {
|
||||
throw redisError;
|
||||
});
|
||||
|
||||
// The error should propagate up, not be caught
|
||||
expect(() => sessionCache(namespace, ttl)).toThrow('Redis connection failed');
|
||||
|
||||
// Verify that MemoryStore was NOT used as fallback
|
||||
expect(mockMemoryStore).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should register error handler but let errors propagate to Express', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
|
||||
// Create a mock session store with middleware methods
|
||||
const mockSessionStore = {
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
mockConnectRedis.mockReturnValue(mockSessionStore);
|
||||
|
||||
const store = sessionCache(namespace);
|
||||
|
||||
// Verify error handler was registered
|
||||
expect(mockIoredisClient.on).toHaveBeenCalledWith('error', expect.any(Function));
|
||||
|
||||
// Get the error handler
|
||||
const errorHandler = mockIoredisClient.on.mock.calls.find((call) => call[0] === 'error')[1];
|
||||
|
||||
// Simulate an error from Redis during a session operation
|
||||
const redisError = new Error('Socket closed unexpectedly');
|
||||
|
||||
// The error handler should log but not swallow the error
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
errorHandler(redisError);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
`Session store Redis error for namespace ${namespace}::`,
|
||||
redisError,
|
||||
);
|
||||
|
||||
// Now simulate what happens when session middleware tries to use the store
|
||||
const callback = jest.fn();
|
||||
mockSessionStore.get.mockImplementation((sid, cb) => {
|
||||
cb(new Error('Redis connection lost'));
|
||||
});
|
||||
|
||||
// Call the store's get method (as Express session would)
|
||||
store.get('test-session-id', callback);
|
||||
|
||||
// The error should be passed to the callback, not swallowed
|
||||
expect(callback).toHaveBeenCalledWith(new Error('Redis connection lost'));
|
||||
});
|
||||
|
||||
it('should handle null ioredisClient gracefully', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
|
||||
// Temporarily set ioredisClient to null (simulating connection not established)
|
||||
const originalClient = require('./redisClients').ioredisClient;
|
||||
require('./redisClients').ioredisClient = null;
|
||||
|
||||
// ConnectRedis might accept null client but would fail on first use
|
||||
// The important thing is it doesn't throw uncaught exceptions during construction
|
||||
const store = sessionCache(namespace);
|
||||
expect(store).toBeDefined();
|
||||
|
||||
// Restore original client
|
||||
require('./redisClients').ioredisClient = originalClient;
|
||||
});
|
||||
});
|
||||
|
||||
describe('limiterCache', () => {
|
||||
it('should return undefined when USE_REDIS is false', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const result = limiterCache('prefix');
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return RedisStore when USE_REDIS is true', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const result = limiterCache('rate-limit');
|
||||
|
||||
expect(mockRedisStore).toHaveBeenCalledWith({
|
||||
sendCommand: expect.any(Function),
|
||||
prefix: `rate-limit:`,
|
||||
});
|
||||
expect(result).toBe(mockRedisStore());
|
||||
});
|
||||
|
||||
it('should add colon to prefix if not present', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
limiterCache('rate-limit');
|
||||
|
||||
expect(mockRedisStore).toHaveBeenCalledWith({
|
||||
sendCommand: expect.any(Function),
|
||||
prefix: 'rate-limit:',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not add colon to prefix if already present', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
limiterCache('rate-limit:');
|
||||
|
||||
expect(mockRedisStore).toHaveBeenCalledWith({
|
||||
sendCommand: expect.any(Function),
|
||||
prefix: 'rate-limit:',
|
||||
});
|
||||
});
|
||||
|
||||
it('should pass sendCommand function that calls ioredisClient.call', async () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
mockIoredisClient.call.mockResolvedValue('test-value');
|
||||
|
||||
limiterCache('rate-limit');
|
||||
|
||||
const sendCommandCall = mockRedisStore.mock.calls[0][0];
|
||||
const sendCommand = sendCommandCall.sendCommand;
|
||||
|
||||
// Test that sendCommand properly delegates to ioredisClient.call
|
||||
const args = ['GET', 'test-key'];
|
||||
const result = await sendCommand(...args);
|
||||
|
||||
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
|
||||
expect(result).toBe('test-value');
|
||||
});
|
||||
|
||||
it('should handle sendCommand errors properly', async () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
|
||||
// Mock the call method to reject with an error
|
||||
const testError = new Error('Redis error');
|
||||
mockIoredisClient.call.mockRejectedValue(testError);
|
||||
|
||||
limiterCache('rate-limit');
|
||||
|
||||
const sendCommandCall = mockRedisStore.mock.calls[0][0];
|
||||
const sendCommand = sendCommandCall.sendCommand;
|
||||
|
||||
// Test that sendCommand properly handles errors
|
||||
const args = ['GET', 'test-key'];
|
||||
|
||||
await expect(sendCommand(...args)).rejects.toThrow('Redis error');
|
||||
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
|
||||
});
|
||||
|
||||
it('should handle undefined prefix', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
expect(() => limiterCache()).toThrow('prefix is required');
|
||||
});
|
||||
});
|
||||
});
|
||||
165
api/cache/getLogStores.js
vendored
165
api/cache/getLogStores.js
vendored
@@ -1,113 +1,53 @@
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
const { Keyv } = require('keyv');
|
||||
const { isEnabled, math } = require('@librechat/api');
|
||||
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
|
||||
const { logFile, violationFile } = require('./keyvFiles');
|
||||
const keyvRedis = require('./keyvRedis');
|
||||
const { logFile } = require('./keyvFiles');
|
||||
const keyvMongo = require('./keyvMongo');
|
||||
|
||||
const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE, CI } = process.env ?? {};
|
||||
|
||||
const duration = math(BAN_DURATION, 7200000);
|
||||
const isRedisEnabled = isEnabled(USE_REDIS);
|
||||
const debugMemoryCache = isEnabled(DEBUG_MEMORY_CACHE);
|
||||
|
||||
const createViolationInstance = (namespace) => {
|
||||
const config = isRedisEnabled ? { store: keyvRedis } : { store: violationFile, namespace };
|
||||
return new Keyv(config);
|
||||
};
|
||||
|
||||
// Serve cache from memory so no need to clear it on startup/exit
|
||||
const pending_req = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.PENDING_REQ });
|
||||
|
||||
const config = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
|
||||
|
||||
const roles = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ROLES });
|
||||
|
||||
const mcpTools = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.MCP_TOOLS });
|
||||
|
||||
const audioRuns = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const messages = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.ONE_MINUTE })
|
||||
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.ONE_MINUTE });
|
||||
|
||||
const flows = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.FLOWS, ttl: Time.ONE_MINUTE * 3 });
|
||||
|
||||
const tokenConfig = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
|
||||
|
||||
const genTitle = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
|
||||
|
||||
const s3ExpiryInterval = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });
|
||||
|
||||
const modelQueries = isEnabled(process.env.USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
|
||||
|
||||
const abortKeys = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const openIdExchangedTokensCache = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.OPENID_EXCHANGED_TOKENS, ttl: Time.TEN_MINUTES });
|
||||
const { standardCache, sessionCache, violationCache } = require('./cacheFactory');
|
||||
|
||||
const namespaces = {
|
||||
[CacheKeys.ROLES]: roles,
|
||||
[CacheKeys.MCP_TOOLS]: mcpTools,
|
||||
[CacheKeys.CONFIG_STORE]: config,
|
||||
[CacheKeys.PENDING_REQ]: pending_req,
|
||||
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
|
||||
[CacheKeys.ENCODED_DOMAINS]: new Keyv({
|
||||
[ViolationTypes.GENERAL]: new Keyv({ store: logFile, namespace: 'violations' }),
|
||||
[ViolationTypes.LOGINS]: violationCache(ViolationTypes.LOGINS),
|
||||
[ViolationTypes.CONCURRENT]: violationCache(ViolationTypes.CONCURRENT),
|
||||
[ViolationTypes.NON_BROWSER]: violationCache(ViolationTypes.NON_BROWSER),
|
||||
[ViolationTypes.MESSAGE_LIMIT]: violationCache(ViolationTypes.MESSAGE_LIMIT),
|
||||
[ViolationTypes.REGISTRATIONS]: violationCache(ViolationTypes.REGISTRATIONS),
|
||||
[ViolationTypes.TOKEN_BALANCE]: violationCache(ViolationTypes.TOKEN_BALANCE),
|
||||
[ViolationTypes.TTS_LIMIT]: violationCache(ViolationTypes.TTS_LIMIT),
|
||||
[ViolationTypes.STT_LIMIT]: violationCache(ViolationTypes.STT_LIMIT),
|
||||
[ViolationTypes.CONVO_ACCESS]: violationCache(ViolationTypes.CONVO_ACCESS),
|
||||
[ViolationTypes.TOOL_CALL_LIMIT]: violationCache(ViolationTypes.TOOL_CALL_LIMIT),
|
||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: violationCache(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||
[ViolationTypes.VERIFY_EMAIL_LIMIT]: violationCache(ViolationTypes.VERIFY_EMAIL_LIMIT),
|
||||
[ViolationTypes.RESET_PASSWORD_LIMIT]: violationCache(ViolationTypes.RESET_PASSWORD_LIMIT),
|
||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: violationCache(ViolationTypes.ILLEGAL_MODEL_REQUEST),
|
||||
[ViolationTypes.BAN]: new Keyv({
|
||||
store: keyvMongo,
|
||||
namespace: CacheKeys.ENCODED_DOMAINS,
|
||||
ttl: 0,
|
||||
namespace: CacheKeys.BANS,
|
||||
ttl: cacheConfig.BAN_DURATION,
|
||||
}),
|
||||
general: new Keyv({ store: logFile, namespace: 'violations' }),
|
||||
concurrent: createViolationInstance('concurrent'),
|
||||
non_browser: createViolationInstance('non_browser'),
|
||||
message_limit: createViolationInstance('message_limit'),
|
||||
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
|
||||
registrations: createViolationInstance('registrations'),
|
||||
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
|
||||
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
|
||||
[ViolationTypes.CONVO_ACCESS]: createViolationInstance(ViolationTypes.CONVO_ACCESS),
|
||||
[ViolationTypes.TOOL_CALL_LIMIT]: createViolationInstance(ViolationTypes.TOOL_CALL_LIMIT),
|
||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||
[ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
|
||||
[ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
|
||||
ViolationTypes.RESET_PASSWORD_LIMIT,
|
||||
|
||||
[CacheKeys.OPENID_SESSION]: sessionCache(CacheKeys.OPENID_SESSION),
|
||||
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
|
||||
|
||||
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
|
||||
[CacheKeys.MCP_TOOLS]: standardCache(CacheKeys.MCP_TOOLS),
|
||||
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
|
||||
[CacheKeys.STATIC_CONFIG]: standardCache(CacheKeys.STATIC_CONFIG),
|
||||
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
|
||||
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
|
||||
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),
|
||||
[CacheKeys.TOKEN_CONFIG]: standardCache(CacheKeys.TOKEN_CONFIG, Time.THIRTY_MINUTES),
|
||||
[CacheKeys.GEN_TITLE]: standardCache(CacheKeys.GEN_TITLE, Time.TWO_MINUTES),
|
||||
[CacheKeys.S3_EXPIRY_INTERVAL]: standardCache(CacheKeys.S3_EXPIRY_INTERVAL, Time.THIRTY_MINUTES),
|
||||
[CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES),
|
||||
[CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES),
|
||||
[CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE),
|
||||
[CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 3),
|
||||
[CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache(
|
||||
CacheKeys.OPENID_EXCHANGED_TOKENS,
|
||||
Time.TEN_MINUTES,
|
||||
),
|
||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
||||
ViolationTypes.ILLEGAL_MODEL_REQUEST,
|
||||
),
|
||||
logins: createViolationInstance('logins'),
|
||||
[CacheKeys.ABORT_KEYS]: abortKeys,
|
||||
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
|
||||
[CacheKeys.GEN_TITLE]: genTitle,
|
||||
[CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
|
||||
[CacheKeys.MODEL_QUERIES]: modelQueries,
|
||||
[CacheKeys.AUDIO_RUNS]: audioRuns,
|
||||
[CacheKeys.MESSAGES]: messages,
|
||||
[CacheKeys.FLOWS]: flows,
|
||||
[CacheKeys.OPENID_EXCHANGED_TOKENS]: openIdExchangedTokensCache,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -116,7 +56,10 @@ const namespaces = {
|
||||
*/
|
||||
function getTTLStores() {
|
||||
return Object.values(namespaces).filter(
|
||||
(store) => store instanceof Keyv && typeof store.opts?.ttl === 'number' && store.opts.ttl > 0,
|
||||
(store) =>
|
||||
store instanceof Keyv &&
|
||||
parseInt(store.opts?.ttl ?? '0') > 0 &&
|
||||
!store.opts?.store?.constructor?.name?.includes('Redis'), // Only include non-Redis stores
|
||||
);
|
||||
}
|
||||
|
||||
@@ -152,18 +95,18 @@ async function clearExpiredFromCache(cache) {
|
||||
if (data?.expires && data.expires <= expiryTime) {
|
||||
const deleted = await cache.opts.store.delete(key);
|
||||
if (!deleted) {
|
||||
debugMemoryCache &&
|
||||
cacheConfig.DEBUG_MEMORY_CACHE &&
|
||||
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
|
||||
continue;
|
||||
}
|
||||
cleared++;
|
||||
}
|
||||
} catch (error) {
|
||||
debugMemoryCache &&
|
||||
cacheConfig.DEBUG_MEMORY_CACHE &&
|
||||
console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
|
||||
const deleted = await cache.opts.store.delete(key);
|
||||
if (!deleted) {
|
||||
debugMemoryCache &&
|
||||
cacheConfig.DEBUG_MEMORY_CACHE &&
|
||||
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
|
||||
continue;
|
||||
}
|
||||
@@ -172,7 +115,7 @@ async function clearExpiredFromCache(cache) {
|
||||
}
|
||||
|
||||
if (cleared > 0) {
|
||||
debugMemoryCache &&
|
||||
cacheConfig.DEBUG_MEMORY_CACHE &&
|
||||
console.log(
|
||||
`[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`,
|
||||
);
|
||||
@@ -213,7 +156,7 @@ async function clearAllExpiredFromCache() {
|
||||
}
|
||||
}
|
||||
|
||||
if (!isRedisEnabled && !isEnabled(CI)) {
|
||||
if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
|
||||
/** @type {Set<NodeJS.Timeout>} */
|
||||
const cleanupIntervals = new Set();
|
||||
|
||||
@@ -224,7 +167,7 @@ if (!isRedisEnabled && !isEnabled(CI)) {
|
||||
|
||||
cleanupIntervals.add(cleanup);
|
||||
|
||||
if (debugMemoryCache) {
|
||||
if (cacheConfig.DEBUG_MEMORY_CACHE) {
|
||||
const monitor = setInterval(() => {
|
||||
const ttlStores = getTTLStores();
|
||||
const memory = process.memoryUsage();
|
||||
@@ -245,13 +188,13 @@ if (!isRedisEnabled && !isEnabled(CI)) {
|
||||
}
|
||||
|
||||
const dispose = () => {
|
||||
debugMemoryCache && console.log('[Cache] Cleaning up and shutting down...');
|
||||
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Cleaning up and shutting down...');
|
||||
cleanupIntervals.forEach((interval) => clearInterval(interval));
|
||||
cleanupIntervals.clear();
|
||||
|
||||
// One final cleanup before exit
|
||||
clearAllExpiredFromCache().then(() => {
|
||||
debugMemoryCache && console.log('[Cache] Final cleanup completed');
|
||||
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Final cleanup completed');
|
||||
process.exit(0);
|
||||
});
|
||||
};
|
||||
|
||||
92
api/cache/ioredisClient.js
vendored
92
api/cache/ioredisClient.js
vendored
@@ -1,92 +0,0 @@
|
||||
const fs = require('fs');
|
||||
const Redis = require('ioredis');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_MAX_LISTENERS } = process.env;
|
||||
|
||||
/** @type {import('ioredis').Redis | import('ioredis').Cluster} */
|
||||
let ioredisClient;
|
||||
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
|
||||
|
||||
function mapURI(uri) {
|
||||
const regex =
|
||||
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
|
||||
const match = uri.match(regex);
|
||||
|
||||
if (match) {
|
||||
const { scheme, user, password, host, port } = match.groups;
|
||||
|
||||
return {
|
||||
scheme: scheme || 'none',
|
||||
user: user || null,
|
||||
password: password || null,
|
||||
host: host || null,
|
||||
port: port || null,
|
||||
};
|
||||
} else {
|
||||
const parts = uri.split(':');
|
||||
if (parts.length === 2) {
|
||||
return {
|
||||
scheme: 'none',
|
||||
user: null,
|
||||
password: null,
|
||||
host: parts[0],
|
||||
port: parts[1],
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
scheme: 'none',
|
||||
user: null,
|
||||
password: null,
|
||||
host: uri,
|
||||
port: null,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (REDIS_URI && isEnabled(USE_REDIS)) {
|
||||
let redisOptions = null;
|
||||
|
||||
if (REDIS_CA) {
|
||||
const ca = fs.readFileSync(REDIS_CA);
|
||||
redisOptions = { tls: { ca } };
|
||||
}
|
||||
|
||||
if (isEnabled(USE_REDIS_CLUSTER)) {
|
||||
const hosts = REDIS_URI.split(',').map((item) => {
|
||||
var value = mapURI(item);
|
||||
|
||||
return {
|
||||
host: value.host,
|
||||
port: value.port,
|
||||
};
|
||||
});
|
||||
ioredisClient = new Redis.Cluster(hosts, { redisOptions });
|
||||
} else {
|
||||
ioredisClient = new Redis(REDIS_URI, redisOptions);
|
||||
}
|
||||
|
||||
ioredisClient.on('ready', () => {
|
||||
logger.info('IoRedis connection ready');
|
||||
});
|
||||
ioredisClient.on('reconnecting', () => {
|
||||
logger.info('IoRedis connection reconnecting');
|
||||
});
|
||||
ioredisClient.on('end', () => {
|
||||
logger.info('IoRedis connection ended');
|
||||
});
|
||||
ioredisClient.on('close', () => {
|
||||
logger.info('IoRedis connection closed');
|
||||
});
|
||||
ioredisClient.on('error', (err) => logger.error('IoRedis connection error:', err));
|
||||
ioredisClient.setMaxListeners(redis_max_listeners);
|
||||
logger.info(
|
||||
'[Optional] IoRedis initialized for rate limiters. If you have issues, disable Redis or restart the server.',
|
||||
);
|
||||
} else {
|
||||
logger.info('[Optional] IoRedis not initialized for rate limiters.');
|
||||
}
|
||||
|
||||
module.exports = ioredisClient;
|
||||
109
api/cache/keyvRedis.js
vendored
109
api/cache/keyvRedis.js
vendored
@@ -1,109 +0,0 @@
|
||||
const fs = require('fs');
|
||||
const ioredis = require('ioredis');
|
||||
const KeyvRedis = require('@keyv/redis').default;
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } =
|
||||
process.env;
|
||||
|
||||
let keyvRedis;
|
||||
const redis_prefix = REDIS_KEY_PREFIX || '';
|
||||
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
|
||||
|
||||
function mapURI(uri) {
|
||||
const regex =
|
||||
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
|
||||
const match = uri.match(regex);
|
||||
|
||||
if (match) {
|
||||
const { scheme, user, password, host, port } = match.groups;
|
||||
|
||||
return {
|
||||
scheme: scheme || 'none',
|
||||
user: user || null,
|
||||
password: password || null,
|
||||
host: host || null,
|
||||
port: port || null,
|
||||
};
|
||||
} else {
|
||||
const parts = uri.split(':');
|
||||
if (parts.length === 2) {
|
||||
return {
|
||||
scheme: 'none',
|
||||
user: null,
|
||||
password: null,
|
||||
host: parts[0],
|
||||
port: parts[1],
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
scheme: 'none',
|
||||
user: null,
|
||||
password: null,
|
||||
host: uri,
|
||||
port: null,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (REDIS_URI && isEnabled(USE_REDIS)) {
|
||||
let redisOptions = null;
|
||||
/** @type {import('@keyv/redis').KeyvRedisOptions} */
|
||||
let keyvOpts = {
|
||||
useRedisSets: false,
|
||||
keyPrefix: redis_prefix,
|
||||
};
|
||||
|
||||
if (REDIS_CA) {
|
||||
const ca = fs.readFileSync(REDIS_CA);
|
||||
redisOptions = { tls: { ca } };
|
||||
}
|
||||
|
||||
if (isEnabled(USE_REDIS_CLUSTER)) {
|
||||
const hosts = REDIS_URI.split(',').map((item) => {
|
||||
var value = mapURI(item);
|
||||
|
||||
return {
|
||||
host: value.host,
|
||||
port: value.port,
|
||||
};
|
||||
});
|
||||
const cluster = new ioredis.Cluster(hosts, { redisOptions });
|
||||
keyvRedis = new KeyvRedis(cluster, keyvOpts);
|
||||
} else {
|
||||
keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts);
|
||||
}
|
||||
|
||||
const pingInterval = setInterval(
|
||||
() => {
|
||||
logger.debug('KeyvRedis ping');
|
||||
keyvRedis.client.ping().catch((err) => logger.error('Redis keep-alive ping failed:', err));
|
||||
},
|
||||
5 * 60 * 1000,
|
||||
);
|
||||
|
||||
keyvRedis.on('ready', () => {
|
||||
logger.info('KeyvRedis connection ready');
|
||||
});
|
||||
keyvRedis.on('reconnecting', () => {
|
||||
logger.info('KeyvRedis connection reconnecting');
|
||||
});
|
||||
keyvRedis.on('end', () => {
|
||||
logger.info('KeyvRedis connection ended');
|
||||
});
|
||||
keyvRedis.on('close', () => {
|
||||
clearInterval(pingInterval);
|
||||
logger.info('KeyvRedis connection closed');
|
||||
});
|
||||
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
|
||||
keyvRedis.setMaxListeners(redis_max_listeners);
|
||||
logger.info(
|
||||
'[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.',
|
||||
);
|
||||
} else {
|
||||
logger.info('[Optional] Redis not initialized.');
|
||||
}
|
||||
|
||||
module.exports = keyvRedis;
|
||||
5
api/cache/logViolation.js
vendored
5
api/cache/logViolation.js
vendored
@@ -1,4 +1,5 @@
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const getLogStores = require('./getLogStores');
|
||||
const banViolation = require('./banViolation');
|
||||
|
||||
@@ -9,14 +10,14 @@ const banViolation = require('./banViolation');
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {string} type - The type of violation.
|
||||
* @param {Object} errorMessage - The error message to log.
|
||||
* @param {number} [score=1] - The severity of the violation. Defaults to 1
|
||||
* @param {number | string} [score=1] - The severity of the violation. Defaults to 1
|
||||
*/
|
||||
const logViolation = async (req, res, type, errorMessage, score = 1) => {
|
||||
const userId = req.user?.id ?? req.user?._id;
|
||||
if (!userId) {
|
||||
return;
|
||||
}
|
||||
const logs = getLogStores('general');
|
||||
const logs = getLogStores(ViolationTypes.GENERAL);
|
||||
const violationLogs = getLogStores(type);
|
||||
const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;
|
||||
|
||||
|
||||
204
api/cache/redisClients.js
vendored
Normal file
204
api/cache/redisClients.js
vendored
Normal file
@@ -0,0 +1,204 @@
|
||||
const IoRedis = require('ioredis');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createClient, createCluster } = require('@keyv/redis');
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
|
||||
const GLOBAL_PREFIX_SEPARATOR = '::';
|
||||
|
||||
const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri));
|
||||
const username = urls?.[0].username || cacheConfig.REDIS_USERNAME;
|
||||
const password = urls?.[0].password || cacheConfig.REDIS_PASSWORD;
|
||||
const ca = cacheConfig.REDIS_CA;
|
||||
|
||||
/** @type {import('ioredis').Redis | import('ioredis').Cluster | null} */
|
||||
let ioredisClient = null;
|
||||
if (cacheConfig.USE_REDIS) {
|
||||
/** @type {import('ioredis').RedisOptions | import('ioredis').ClusterOptions} */
|
||||
const redisOptions = {
|
||||
username: username,
|
||||
password: password,
|
||||
tls: ca ? { ca } : undefined,
|
||||
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
|
||||
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
|
||||
retryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 50, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
reconnectOnError: (err) => {
|
||||
const targetError = 'READONLY';
|
||||
if (err.message.includes(targetError)) {
|
||||
logger.warn('ioredis reconnecting due to READONLY error');
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
connectTimeout: cacheConfig.REDIS_CONNECT_TIMEOUT,
|
||||
maxRetriesPerRequest: 3,
|
||||
};
|
||||
|
||||
ioredisClient =
|
||||
urls.length === 1
|
||||
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
|
||||
: new IoRedis.Cluster(cacheConfig.REDIS_URI, {
|
||||
redisOptions,
|
||||
clusterRetryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis cluster giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis cluster reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
});
|
||||
|
||||
ioredisClient.on('error', (err) => {
|
||||
logger.error('ioredis client error:', err);
|
||||
});
|
||||
|
||||
ioredisClient.on('connect', () => {
|
||||
logger.info('ioredis client connected');
|
||||
});
|
||||
|
||||
ioredisClient.on('ready', () => {
|
||||
logger.info('ioredis client ready');
|
||||
});
|
||||
|
||||
ioredisClient.on('reconnecting', (delay) => {
|
||||
logger.info(`ioredis client reconnecting in ${delay}ms`);
|
||||
});
|
||||
|
||||
ioredisClient.on('close', () => {
|
||||
logger.warn('ioredis client connection closed');
|
||||
});
|
||||
|
||||
/** Ping Interval to keep the Redis server connection alive (if enabled) */
|
||||
let pingInterval = null;
|
||||
const clearPingInterval = () => {
|
||||
if (pingInterval) {
|
||||
clearInterval(pingInterval);
|
||||
pingInterval = null;
|
||||
}
|
||||
};
|
||||
|
||||
if (cacheConfig.REDIS_PING_INTERVAL > 0) {
|
||||
pingInterval = setInterval(() => {
|
||||
if (ioredisClient && ioredisClient.status === 'ready') {
|
||||
ioredisClient.ping().catch((err) => {
|
||||
logger.error('ioredis ping failed:', err);
|
||||
});
|
||||
}
|
||||
}, cacheConfig.REDIS_PING_INTERVAL * 1000);
|
||||
ioredisClient.on('close', clearPingInterval);
|
||||
ioredisClient.on('end', clearPingInterval);
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {import('@keyv/redis').RedisClient | import('@keyv/redis').RedisCluster | null} */
|
||||
let keyvRedisClient = null;
|
||||
if (cacheConfig.USE_REDIS) {
|
||||
/**
|
||||
* ** WARNING ** Keyv Redis client does not support Prefix like ioredis above.
|
||||
* The prefix feature will be handled by the Keyv-Redis store in cacheFactory.js
|
||||
* @type {import('@keyv/redis').RedisClientOptions | import('@keyv/redis').RedisClusterOptions}
|
||||
*/
|
||||
const redisOptions = {
|
||||
username,
|
||||
password,
|
||||
socket: {
|
||||
tls: ca != null,
|
||||
ca,
|
||||
connectTimeout: cacheConfig.REDIS_CONNECT_TIMEOUT,
|
||||
reconnectStrategy: (retries) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
retries > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`@keyv/redis client giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return new Error('Max reconnection attempts reached');
|
||||
}
|
||||
const delay = Math.min(retries * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`@keyv/redis reconnecting... attempt ${retries}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
},
|
||||
disableOfflineQueue: !cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
};
|
||||
|
||||
keyvRedisClient =
|
||||
urls.length === 1
|
||||
? createClient({ url: cacheConfig.REDIS_URI, ...redisOptions })
|
||||
: createCluster({
|
||||
rootNodes: cacheConfig.REDIS_URI.split(',').map((url) => ({ url })),
|
||||
defaults: redisOptions,
|
||||
});
|
||||
|
||||
keyvRedisClient.setMaxListeners(cacheConfig.REDIS_MAX_LISTENERS);
|
||||
|
||||
keyvRedisClient.on('error', (err) => {
|
||||
logger.error('@keyv/redis client error:', err);
|
||||
});
|
||||
|
||||
keyvRedisClient.on('connect', () => {
|
||||
logger.info('@keyv/redis client connected');
|
||||
});
|
||||
|
||||
keyvRedisClient.on('ready', () => {
|
||||
logger.info('@keyv/redis client ready');
|
||||
});
|
||||
|
||||
keyvRedisClient.on('reconnecting', () => {
|
||||
logger.info('@keyv/redis client reconnecting...');
|
||||
});
|
||||
|
||||
keyvRedisClient.on('disconnect', () => {
|
||||
logger.warn('@keyv/redis client disconnected');
|
||||
});
|
||||
|
||||
keyvRedisClient.connect().catch((err) => {
|
||||
logger.error('@keyv/redis initial connection failed:', err);
|
||||
throw err;
|
||||
});
|
||||
|
||||
/** Ping Interval to keep the Redis server connection alive (if enabled) */
|
||||
let pingInterval = null;
|
||||
const clearPingInterval = () => {
|
||||
if (pingInterval) {
|
||||
clearInterval(pingInterval);
|
||||
pingInterval = null;
|
||||
}
|
||||
};
|
||||
|
||||
if (cacheConfig.REDIS_PING_INTERVAL > 0) {
|
||||
pingInterval = setInterval(() => {
|
||||
if (keyvRedisClient && keyvRedisClient.isReady) {
|
||||
keyvRedisClient.ping().catch((err) => {
|
||||
logger.error('@keyv/redis ping failed:', err);
|
||||
});
|
||||
}
|
||||
}, cacheConfig.REDIS_PING_INTERVAL * 1000);
|
||||
keyvRedisClient.on('disconnect', clearPingInterval);
|
||||
keyvRedisClient.on('end', clearPingInterval);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR };
|
||||
@@ -61,7 +61,7 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter)
|
||||
const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => {
|
||||
const { model, ...model_parameters } = _m;
|
||||
/** @type {Record<string, FunctionTool>} */
|
||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||
const availableTools = await getCachedTools({ userId: req.user.id, includeGlobal: true });
|
||||
/** @type {TEphemeralAgent | null} */
|
||||
const ephemeralAgent = req.body.ephemeralAgent;
|
||||
const mcpServers = new Set(ephemeralAgent?.mcp);
|
||||
@@ -90,7 +90,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
}
|
||||
|
||||
const instructions = req.body.promptPrefix;
|
||||
return {
|
||||
const result = {
|
||||
id: agent_id,
|
||||
instructions,
|
||||
provider: endpoint,
|
||||
@@ -98,6 +98,11 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
const { Conversation } = require('~/db/models');
|
||||
|
||||
|
||||
572
api/models/Conversation.spec.js
Normal file
572
api/models/Conversation.spec.js
Normal file
@@ -0,0 +1,572 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const {
|
||||
deleteNullOrEmptyConversations,
|
||||
searchConversation,
|
||||
getConvosByCursor,
|
||||
getConvosQueried,
|
||||
getConvoFiles,
|
||||
getConvoTitle,
|
||||
deleteConvos,
|
||||
saveConvo,
|
||||
getConvo,
|
||||
} = require('./Conversation');
|
||||
jest.mock('~/server/services/Config/getCustomConfig');
|
||||
jest.mock('./Message');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
|
||||
const { Conversation } = require('~/db/models');
|
||||
|
||||
describe('Conversation Operations', () => {
|
||||
let mongoServer;
|
||||
let mockReq;
|
||||
let mockConversationData;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear database
|
||||
await Conversation.deleteMany({});
|
||||
|
||||
// Reset mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Default mock implementations
|
||||
getMessages.mockResolvedValue([]);
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 0 });
|
||||
|
||||
mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {},
|
||||
};
|
||||
|
||||
mockConversationData = {
|
||||
conversationId: uuidv4(),
|
||||
title: 'Test Conversation',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
};
|
||||
});
|
||||
|
||||
describe('saveConvo', () => {
|
||||
it('should save a conversation for an authenticated user', async () => {
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.user).toBe('user123');
|
||||
expect(result.title).toBe('Test Conversation');
|
||||
expect(result.endpoint).toBe(EModelEndpoint.openAI);
|
||||
|
||||
// Verify the conversation was actually saved to the database
|
||||
const savedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
expect(savedConvo).toBeTruthy();
|
||||
expect(savedConvo.title).toBe('Test Conversation');
|
||||
});
|
||||
|
||||
it('should query messages when saving a conversation', async () => {
|
||||
// Mock messages as ObjectIds
|
||||
const mongoose = require('mongoose');
|
||||
const mockMessages = [new mongoose.Types.ObjectId(), new mongoose.Types.ObjectId()];
|
||||
getMessages.mockResolvedValue(mockMessages);
|
||||
|
||||
await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
// Verify that getMessages was called with correct parameters
|
||||
expect(getMessages).toHaveBeenCalledWith(
|
||||
{ conversationId: mockConversationData.conversationId },
|
||||
'_id',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle newConversationId when provided', async () => {
|
||||
const newConversationId = uuidv4();
|
||||
const result = await saveConvo(mockReq, {
|
||||
...mockConversationData,
|
||||
newConversationId,
|
||||
});
|
||||
|
||||
expect(result.conversationId).toBe(newConversationId);
|
||||
});
|
||||
|
||||
it('should handle unsetFields metadata', async () => {
|
||||
const metadata = {
|
||||
unsetFields: { someField: 1 },
|
||||
};
|
||||
|
||||
await saveConvo(mockReq, mockConversationData, metadata);
|
||||
|
||||
const savedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
expect(savedConvo.someField).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTemporary conversation handling', () => {
|
||||
it('should save a conversation with expiredAt when isTemporary is true', async () => {
|
||||
// Mock custom config with 24 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
const afterSave = new Date();
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
expect(result.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 24 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 24 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should save a conversation without expiredAt when isTemporary is false', async () => {
|
||||
mockReq.body = { isTemporary: false };
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should save a conversation without expiredAt when isTemporary is not provided', async () => {
|
||||
// No isTemporary in body
|
||||
mockReq.body = {};
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
// Mock custom config with 48 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 48,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 48 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 48 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock custom config with less than minimum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 0.5, // Half hour - should be clamped to 1 hour
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 1 hour in the future (minimum)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 1 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock custom config with more than maximum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 10000, // Should be clamped to 8760 hours
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 8760 hours (1 year) in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 8760 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle getCustomConfig errors gracefully', async () => {
|
||||
// Mock getCustomConfig to throw an error
|
||||
getCustomConfig.mockRejectedValue(new Error('Config service unavailable'));
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
// Should still save the conversation but with expiredAt as null
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getCustomConfig to return empty config
|
||||
getCustomConfig.mockResolvedValue({});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Default retention is 30 days (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 30 * 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should update expiredAt when saving existing temporary conversation', async () => {
|
||||
// First save a temporary conversation
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const firstSave = await saveConvo(mockReq, mockConversationData);
|
||||
const originalExpiredAt = firstSave.expiredAt;
|
||||
|
||||
// Wait a bit to ensure time difference
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Save again with same conversationId but different title
|
||||
const updatedData = { ...mockConversationData, title: 'Updated Title' };
|
||||
const secondSave = await saveConvo(mockReq, updatedData);
|
||||
|
||||
// Should update title and create new expiredAt
|
||||
expect(secondSave.title).toBe('Updated Title');
|
||||
expect(secondSave.expiredAt).toBeDefined();
|
||||
expect(new Date(secondSave.expiredAt).getTime()).toBeGreaterThan(
|
||||
new Date(originalExpiredAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not set expiredAt when updating non-temporary conversation', async () => {
|
||||
// First save a non-temporary conversation
|
||||
mockReq.body = { isTemporary: false };
|
||||
const firstSave = await saveConvo(mockReq, mockConversationData);
|
||||
expect(firstSave.expiredAt).toBeNull();
|
||||
|
||||
// Update without isTemporary flag
|
||||
mockReq.body = {};
|
||||
const updatedData = { ...mockConversationData, title: 'Updated Title' };
|
||||
const secondSave = await saveConvo(mockReq, updatedData);
|
||||
|
||||
expect(secondSave.title).toBe('Updated Title');
|
||||
expect(secondSave.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should filter out expired conversations in getConvosByCursor', async () => {
|
||||
// Create some test conversations
|
||||
const nonExpiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Non-expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
updatedAt: new Date(),
|
||||
});
|
||||
|
||||
await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Future expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000), // 24 hours from now
|
||||
updatedAt: new Date(),
|
||||
});
|
||||
|
||||
// Mock Meili search
|
||||
Conversation.meiliSearch = jest.fn().mockResolvedValue({ hits: [] });
|
||||
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
// Should only return conversations with null or non-existent expiredAt
|
||||
expect(result.conversations).toHaveLength(1);
|
||||
expect(result.conversations[0].conversationId).toBe(nonExpiredConvo.conversationId);
|
||||
});
|
||||
|
||||
it('should filter out expired conversations in getConvosQueried', async () => {
|
||||
// Create test conversations
|
||||
const nonExpiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Non-expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
});
|
||||
|
||||
const expiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000),
|
||||
});
|
||||
|
||||
const convoIds = [
|
||||
{ conversationId: nonExpiredConvo.conversationId },
|
||||
{ conversationId: expiredConvo.conversationId },
|
||||
];
|
||||
|
||||
const result = await getConvosQueried('user123', convoIds);
|
||||
|
||||
// Should only return the non-expired conversation
|
||||
expect(result.conversations).toHaveLength(1);
|
||||
expect(result.conversations[0].conversationId).toBe(nonExpiredConvo.conversationId);
|
||||
expect(result.convoMap[nonExpiredConvo.conversationId]).toBeDefined();
|
||||
expect(result.convoMap[expiredConvo.conversationId]).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('searchConversation', () => {
|
||||
it('should find a conversation by conversationId', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await searchConversation(mockConversationData.conversationId);
|
||||
|
||||
expect(result).toBeTruthy();
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.user).toBe('user123');
|
||||
expect(result.title).toBeUndefined(); // Only returns conversationId and user
|
||||
});
|
||||
|
||||
it('should return null if conversation not found', async () => {
|
||||
const result = await searchConversation('non-existent-id');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvo', () => {
|
||||
it('should retrieve a conversation for a user', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test Conversation',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvo('user123', mockConversationData.conversationId);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.user).toBe('user123');
|
||||
expect(result.title).toBe('Test Conversation');
|
||||
});
|
||||
|
||||
it('should return null if conversation not found', async () => {
|
||||
const result = await getConvo('user123', 'non-existent-id');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvoTitle', () => {
|
||||
it('should return the conversation title', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test Title',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoTitle('user123', mockConversationData.conversationId);
|
||||
expect(result).toBe('Test Title');
|
||||
});
|
||||
|
||||
it('should return null if conversation has no title', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: null,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoTitle('user123', mockConversationData.conversationId);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return "New Chat" if conversation not found', async () => {
|
||||
const result = await getConvoTitle('user123', 'non-existent-id');
|
||||
expect(result).toBe('New Chat');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvoFiles', () => {
|
||||
it('should return conversation files', async () => {
|
||||
const files = ['file1', 'file2'];
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
files,
|
||||
});
|
||||
|
||||
const result = await getConvoFiles(mockConversationData.conversationId);
|
||||
expect(result).toEqual(files);
|
||||
});
|
||||
|
||||
it('should return empty array if no files', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoFiles(mockConversationData.conversationId);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return empty array if conversation not found', async () => {
|
||||
const result = await getConvoFiles('non-existent-id');
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteConvos', () => {
|
||||
it('should delete conversations and associated messages', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'To Delete',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 5 });
|
||||
|
||||
const result = await deleteConvos('user123', {
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
|
||||
expect(result.deletedCount).toBe(1);
|
||||
expect(result.messages.deletedCount).toBe(5);
|
||||
expect(deleteMessages).toHaveBeenCalledWith({
|
||||
conversationId: { $in: [mockConversationData.conversationId] },
|
||||
});
|
||||
|
||||
// Verify conversation was deleted
|
||||
const deletedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
expect(deletedConvo).toBeNull();
|
||||
});
|
||||
|
||||
it('should throw error if no conversations found', async () => {
|
||||
await expect(deleteConvos('user123', { conversationId: 'non-existent' })).rejects.toThrow(
|
||||
'Conversation not found or already deleted.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteNullOrEmptyConversations', () => {
|
||||
it('should delete conversations with null, empty, or missing conversationIds', async () => {
|
||||
// Since conversationId is required by the schema, we can't create documents with null/missing IDs
|
||||
// This test should verify the function works when such documents exist (e.g., from data corruption)
|
||||
|
||||
// For this test, let's create a valid conversation and verify the function doesn't delete it
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user4',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 0 });
|
||||
|
||||
const result = await deleteNullOrEmptyConversations();
|
||||
|
||||
expect(result.conversations.deletedCount).toBe(0); // No invalid conversations to delete
|
||||
expect(result.messages.deletedCount).toBe(0);
|
||||
|
||||
// Verify valid conversation remains
|
||||
const remainingConvos = await Conversation.find({});
|
||||
expect(remainingConvos).toHaveLength(1);
|
||||
expect(remainingConvos[0].conversationId).toBe(mockConversationData.conversationId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle database errors in saveConvo', async () => {
|
||||
// Force a database error by disconnecting
|
||||
await mongoose.disconnect();
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result).toEqual({ message: 'Error saving conversation' });
|
||||
|
||||
// Reconnect for other tests
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,7 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EToolResources, FileContext } = require('librechat-data-provider');
|
||||
const { EToolResources, FileContext, Constants } = require('librechat-data-provider');
|
||||
const { getProjectByName } = require('./Project');
|
||||
const { getAgent } = require('./Agent');
|
||||
const { File } = require('~/db/models');
|
||||
|
||||
/**
|
||||
@@ -12,17 +14,124 @@ const findFileById = async (file_id, options = {}) => {
|
||||
return await File.findOne({ file_id, ...options }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Checks if a user has access to multiple files through a shared agent (batch operation)
|
||||
* @param {string} userId - The user ID to check access for
|
||||
* @param {string[]} fileIds - Array of file IDs to check
|
||||
* @param {string} agentId - The agent ID that might grant access
|
||||
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
|
||||
*/
|
||||
const hasAccessToFilesViaAgent = async (userId, fileIds, agentId, checkCollaborative = true) => {
|
||||
const accessMap = new Map();
|
||||
|
||||
// Initialize all files as no access
|
||||
fileIds.forEach((fileId) => accessMap.set(fileId, false));
|
||||
|
||||
try {
|
||||
const agent = await getAgent({ id: agentId });
|
||||
|
||||
if (!agent) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if user is the author - if so, grant access to all files
|
||||
if (agent.author.toString() === userId) {
|
||||
fileIds.forEach((fileId) => accessMap.set(fileId, true));
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if agent is shared with the user via projects
|
||||
if (!agent.projectIds || agent.projectIds.length === 0) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if agent is in global project
|
||||
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
|
||||
if (
|
||||
!globalProject ||
|
||||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
|
||||
) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Agent is globally shared - check if it's collaborative
|
||||
if (checkCollaborative && !agent.isCollaborative) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check which files are actually attached
|
||||
const attachedFileIds = new Set();
|
||||
if (agent.tool_resources) {
|
||||
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
|
||||
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
|
||||
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Grant access only to files that are attached to this agent
|
||||
fileIds.forEach((fileId) => {
|
||||
if (attachedFileIds.has(fileId)) {
|
||||
accessMap.set(fileId, true);
|
||||
}
|
||||
});
|
||||
|
||||
return accessMap;
|
||||
} catch (error) {
|
||||
logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
|
||||
return accessMap;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves files matching a given filter, sorted by the most recently updated.
|
||||
* @param {Object} filter - The filter criteria to apply.
|
||||
* @param {Object} [_sortOptions] - Optional sort parameters.
|
||||
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
|
||||
* Default excludes the 'text' field.
|
||||
* @param {Object} [options] - Additional options
|
||||
* @param {string} [options.userId] - User ID for access control
|
||||
* @param {string} [options.agentId] - Agent ID that might grant access to files
|
||||
* @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
|
||||
*/
|
||||
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
|
||||
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => {
|
||||
const sortOptions = { updatedAt: -1, ..._sortOptions };
|
||||
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
|
||||
const files = await File.find(filter).select(selectFields).sort(sortOptions).lean();
|
||||
|
||||
// If userId and agentId are provided, filter files based on access
|
||||
if (options.userId && options.agentId) {
|
||||
// Collect file IDs that need access check
|
||||
const filesToCheck = [];
|
||||
const ownedFiles = [];
|
||||
|
||||
for (const file of files) {
|
||||
if (file.user && file.user.toString() === options.userId) {
|
||||
ownedFiles.push(file);
|
||||
} else {
|
||||
filesToCheck.push(file);
|
||||
}
|
||||
}
|
||||
|
||||
if (filesToCheck.length === 0) {
|
||||
return ownedFiles;
|
||||
}
|
||||
|
||||
// Batch check access for all non-owned files
|
||||
const fileIds = filesToCheck.map((f) => f.file_id);
|
||||
const accessMap = await hasAccessToFilesViaAgent(
|
||||
options.userId,
|
||||
fileIds,
|
||||
options.agentId,
|
||||
false,
|
||||
);
|
||||
|
||||
// Filter files based on access
|
||||
const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
|
||||
|
||||
return [...ownedFiles, ...accessibleFiles];
|
||||
}
|
||||
|
||||
return files;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -176,4 +285,5 @@ module.exports = {
|
||||
deleteFiles,
|
||||
deleteFileByFilter,
|
||||
batchUpdateFiles,
|
||||
hasAccessToFilesViaAgent,
|
||||
};
|
||||
|
||||
264
api/models/File.spec.js
Normal file
264
api/models/File.spec.js
Normal file
@@ -0,0 +1,264 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { fileSchema } = require('@librechat/data-schemas');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
const { projectSchema } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
const { getFiles, createFile } = require('./File');
|
||||
const { getProjectByName } = require('./Project');
|
||||
const { createAgent } = require('./Agent');
|
||||
|
||||
let File;
|
||||
let Agent;
|
||||
let Project;
|
||||
|
||||
describe('File Access Control', () => {
|
||||
let mongoServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
File = mongoose.models.File || mongoose.model('File', fileSchema);
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await File.deleteMany({});
|
||||
await Agent.deleteMany({});
|
||||
await Project.deleteMany({});
|
||||
});
|
||||
|
||||
describe('hasAccessToFilesViaAgent', () => {
|
||||
it('should efficiently check access for multiple files at once', async () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
const authorId = new mongoose.Types.ObjectId().toString();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
|
||||
|
||||
// Create files
|
||||
for (const fileId of fileIds) {
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: fileId,
|
||||
filename: `file-${fileId}.txt`,
|
||||
filepath: `/uploads/${fileId}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Create agent with only first two files attached
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileIds[0], fileIds[1]],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Get or create global project
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
|
||||
// Share agent globally
|
||||
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
|
||||
|
||||
// Check access for all files
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
|
||||
|
||||
// Should have access only to the first two files
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
expect(accessMap.get(fileIds[2])).toBe(false);
|
||||
expect(accessMap.get(fileIds[3])).toBe(false);
|
||||
});
|
||||
|
||||
it('should grant access to all files when user is the agent author', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId().toString();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
|
||||
|
||||
// Create agent
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileIds[0]], // Only one file attached
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Check access as the author
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);
|
||||
|
||||
// Author should have access to all files
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
expect(accessMap.get(fileIds[2])).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle non-existent agent gracefully', async () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');
|
||||
|
||||
// Should have no access to any files
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should deny access when agent is not collaborative', async () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
const authorId = new mongoose.Types.ObjectId().toString();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create agent with files but isCollaborative: false
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Non-Collaborative Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: false,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Get or create global project
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
|
||||
// Share agent globally
|
||||
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
|
||||
|
||||
// Check access for files
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
|
||||
|
||||
// Should have no access to any files when isCollaborative is false
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getFiles with agent access control', () => {
|
||||
test('should return files owned by user and files accessible through agent', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
const ownedFileId = `file_${uuidv4()}`;
|
||||
const sharedFileId = `file_${uuidv4()}`;
|
||||
const inaccessibleFileId = `file_${uuidv4()}`;
|
||||
|
||||
// Create/get global project using getProjectByName which will upsert
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);
|
||||
|
||||
// Create agent with shared file
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Shared Agent',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
projectIds: [globalProject._id],
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [sharedFileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Create files
|
||||
await createFile({
|
||||
file_id: ownedFileId,
|
||||
user: userId,
|
||||
filename: 'owned.txt',
|
||||
filepath: '/uploads/owned.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
|
||||
await createFile({
|
||||
file_id: sharedFileId,
|
||||
user: authorId,
|
||||
filename: 'shared.txt',
|
||||
filepath: '/uploads/shared.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 200,
|
||||
embedded: true,
|
||||
});
|
||||
|
||||
await createFile({
|
||||
file_id: inaccessibleFileId,
|
||||
user: authorId,
|
||||
filename: 'inaccessible.txt',
|
||||
filepath: '/uploads/inaccessible.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 300,
|
||||
});
|
||||
|
||||
// Get files with access control
|
||||
const files = await getFiles(
|
||||
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
|
||||
null,
|
||||
{ text: 0 },
|
||||
{ userId: userId.toString(), agentId },
|
||||
);
|
||||
|
||||
expect(files).toHaveLength(2);
|
||||
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
|
||||
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
|
||||
expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
|
||||
});
|
||||
|
||||
test('should return all files when no userId/agentId provided', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const fileId1 = `file_${uuidv4()}`;
|
||||
const fileId2 = `file_${uuidv4()}`;
|
||||
|
||||
await createFile({
|
||||
file_id: fileId1,
|
||||
user: userId,
|
||||
filename: 'file1.txt',
|
||||
filepath: '/uploads/file1.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
|
||||
await createFile({
|
||||
file_id: fileId2,
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
filename: 'file2.txt',
|
||||
filepath: '/uploads/file2.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 200,
|
||||
});
|
||||
|
||||
const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
|
||||
expect(files).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,7 @@
|
||||
const { z } = require('zod');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { Message } = require('~/db/models');
|
||||
|
||||
const idSchema = z.string().uuid();
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { messageSchema } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
const {
|
||||
saveMessage,
|
||||
getMessages,
|
||||
updateMessage,
|
||||
deleteMessages,
|
||||
bulkSaveMessages,
|
||||
updateMessageText,
|
||||
deleteMessagesSince,
|
||||
} = require('./Message');
|
||||
|
||||
jest.mock('~/server/services/Config/getCustomConfig');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IMessage>}
|
||||
*/
|
||||
@@ -117,21 +121,21 @@ describe('Message Operations', () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
// Create multiple messages in the same conversation
|
||||
const message1 = await saveMessage(mockReq, {
|
||||
await saveMessage(mockReq, {
|
||||
messageId: 'msg1',
|
||||
conversationId,
|
||||
text: 'First message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
const message2 = await saveMessage(mockReq, {
|
||||
await saveMessage(mockReq, {
|
||||
messageId: 'msg2',
|
||||
conversationId,
|
||||
text: 'Second message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
const message3 = await saveMessage(mockReq, {
|
||||
await saveMessage(mockReq, {
|
||||
messageId: 'msg3',
|
||||
conversationId,
|
||||
text: 'Third message',
|
||||
@@ -314,4 +318,265 @@ describe('Message Operations', () => {
|
||||
expect(messages[0].text).toBe('Victim message');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTemporary message handling', () => {
|
||||
beforeEach(() => {
|
||||
// Reset mocks before each test
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should save a message with expiredAt when isTemporary is true', async () => {
|
||||
// Mock custom config with 24 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
const afterSave = new Date();
|
||||
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
expect(result.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 24 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 24 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should save a message without expiredAt when isTemporary is false', async () => {
|
||||
mockReq.body = { isTemporary: false };
|
||||
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should save a message without expiredAt when isTemporary is not provided', async () => {
|
||||
// No isTemporary in body
|
||||
mockReq.body = {};
|
||||
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
// Mock custom config with 48 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 48,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 48 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 48 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock custom config with less than minimum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 0.5, // Half hour - should be clamped to 1 hour
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 1 hour in the future (minimum)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 1 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock custom config with more than maximum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 10000, // Should be clamped to 8760 hours
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 8760 hours (1 year) in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 8760 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle getCustomConfig errors gracefully', async () => {
|
||||
// Mock getCustomConfig to throw an error
|
||||
getCustomConfig.mockRejectedValue(new Error('Config service unavailable'));
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
// Should still save the message but with expiredAt as null
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getCustomConfig to return empty config
|
||||
getCustomConfig.mockResolvedValue({});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Default retention is 30 days (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 30 * 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not update expiredAt on message update', async () => {
|
||||
// First save a temporary message
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const savedMessage = await saveMessage(mockReq, mockMessageData);
|
||||
const originalExpiredAt = savedMessage.expiredAt;
|
||||
|
||||
// Now update the message without isTemporary flag
|
||||
mockReq.body = {};
|
||||
const updatedMessage = await updateMessage(mockReq, {
|
||||
messageId: 'msg123',
|
||||
text: 'Updated text',
|
||||
});
|
||||
|
||||
// expiredAt should not be in the returned updated message object
|
||||
expect(updatedMessage.expiredAt).toBeUndefined();
|
||||
|
||||
// Verify in database that expiredAt wasn't changed
|
||||
const dbMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
|
||||
expect(dbMessage.expiredAt).toEqual(originalExpiredAt);
|
||||
});
|
||||
|
||||
it('should preserve expiredAt when saving existing temporary message', async () => {
|
||||
// First save a temporary message
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const firstSave = await saveMessage(mockReq, mockMessageData);
|
||||
const originalExpiredAt = firstSave.expiredAt;
|
||||
|
||||
// Wait a bit to ensure time difference
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Save again with same messageId but different text
|
||||
const updatedData = { ...mockMessageData, text: 'Updated text' };
|
||||
const secondSave = await saveMessage(mockReq, updatedData);
|
||||
|
||||
// Should update text but create new expiredAt
|
||||
expect(secondSave.text).toBe('Updated text');
|
||||
expect(secondSave.expiredAt).toBeDefined();
|
||||
expect(new Date(secondSave.expiredAt).getTime()).toBeGreaterThan(
|
||||
new Date(originalExpiredAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle bulk operations with temporary messages', async () => {
|
||||
// This test verifies bulkSaveMessages doesn't interfere with expiredAt
|
||||
const messages = [
|
||||
{
|
||||
messageId: 'bulk1',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Bulk message 1',
|
||||
user: 'user123',
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000),
|
||||
},
|
||||
{
|
||||
messageId: 'bulk2',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Bulk message 2',
|
||||
user: 'user123',
|
||||
expiredAt: null,
|
||||
},
|
||||
];
|
||||
|
||||
await bulkSaveMessages(messages);
|
||||
|
||||
const savedMessages = await Message.find({
|
||||
messageId: { $in: ['bulk1', 'bulk2'] },
|
||||
}).lean();
|
||||
|
||||
expect(savedMessages).toHaveLength(2);
|
||||
|
||||
const bulk1 = savedMessages.find((m) => m.messageId === 'bulk1');
|
||||
const bulk2 = savedMessages.find((m) => m.messageId === 'bulk2');
|
||||
|
||||
expect(bulk1.expiredAt).toBeDefined();
|
||||
expect(bulk2.expiredAt).toBeNull();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -135,10 +135,11 @@ const tokenValues = Object.assign(
|
||||
'grok-2-1212': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
|
||||
'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
|
||||
'grok-3-mini': { prompt: 0.3, completion: 0.5 },
|
||||
'grok-3-fast': { prompt: 5.0, completion: 25.0 },
|
||||
'grok-3': { prompt: 3.0, completion: 15.0 },
|
||||
'grok-4': { prompt: 3.0, completion: 15.0 },
|
||||
'grok-beta': { prompt: 5.0, completion: 15.0 },
|
||||
'mistral-large': { prompt: 2.0, completion: 6.0 },
|
||||
'pixtral-large': { prompt: 2.0, completion: 6.0 },
|
||||
|
||||
@@ -636,6 +636,15 @@ describe('Grok Model Tests - Pricing', () => {
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct prompt and completion rates for Grok 4 model', () => {
|
||||
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['grok-4'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'completion' })).toBe(
|
||||
tokenValues['grok-4'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
|
||||
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['grok-3'].prompt,
|
||||
@@ -662,6 +671,15 @@ describe('Grok Model Tests - Pricing', () => {
|
||||
tokenValues['grok-3-mini-fast'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct prompt and completion rates for Grok 4 model with prefixes', () => {
|
||||
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['grok-4'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'completion' })).toBe(
|
||||
tokenValues['grok-4'].completion,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "v0.7.8",
|
||||
"version": "v0.7.9",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
@@ -44,19 +44,21 @@
|
||||
"@googleapis/youtube": "^20.0.0",
|
||||
"@keyv/redis": "^4.3.3",
|
||||
"@langchain/community": "^0.3.47",
|
||||
"@langchain/core": "^0.3.60",
|
||||
"@langchain/core": "^0.3.62",
|
||||
"@langchain/google-genai": "^0.2.13",
|
||||
"@langchain/google-vertexai": "^0.2.13",
|
||||
"@langchain/openai": "^0.5.18",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.4.46",
|
||||
"@librechat/agents": "^2.4.68",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@node-saml/passport-saml": "^5.0.0",
|
||||
"@modelcontextprotocol/sdk": "^1.17.0",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
"axios": "^1.8.2",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"compression": "^1.7.4",
|
||||
"connect-redis": "^7.1.0",
|
||||
"compression": "^1.8.1",
|
||||
"connect-redis": "^8.1.0",
|
||||
"cookie": "^0.7.2",
|
||||
"cookie-parser": "^1.4.7",
|
||||
"cors": "^2.8.5",
|
||||
@@ -66,10 +68,11 @@
|
||||
"express": "^4.21.2",
|
||||
"express-mongo-sanitize": "^2.2.0",
|
||||
"express-rate-limit": "^7.4.1",
|
||||
"express-session": "^1.18.1",
|
||||
"express-session": "^1.18.2",
|
||||
"express-static-gzip": "^2.2.0",
|
||||
"file-type": "^18.7.0",
|
||||
"firebase": "^11.0.2",
|
||||
"form-data": "^4.0.4",
|
||||
"googleapis": "^126.0.1",
|
||||
"handlebars": "^4.7.7",
|
||||
"https-proxy-agent": "^7.0.6",
|
||||
@@ -87,12 +90,12 @@
|
||||
"mime": "^3.0.0",
|
||||
"module-alias": "^2.2.3",
|
||||
"mongoose": "^8.12.1",
|
||||
"multer": "^2.0.1",
|
||||
"multer": "^2.0.2",
|
||||
"nanoid": "^3.3.7",
|
||||
"node-fetch": "^2.7.0",
|
||||
"nodemailer": "^6.9.15",
|
||||
"ollama": "^0.5.0",
|
||||
"openai": "^4.96.2",
|
||||
"openai": "^5.10.1",
|
||||
"openai-chat-tokens": "^0.2.8",
|
||||
"openid-client": "^6.5.0",
|
||||
"passport": "^0.6.0",
|
||||
@@ -117,7 +120,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"jest": "^29.7.0",
|
||||
"mongodb-memory-server": "^10.1.3",
|
||||
"mongodb-memory-server": "^10.1.4",
|
||||
"nodemon": "^3.0.3",
|
||||
"supertest": "^7.1.0"
|
||||
}
|
||||
|
||||
@@ -24,17 +24,23 @@ const handleValidationError = (err, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
module.exports = (err, req, res, next) => {
|
||||
module.exports = (err, _req, res, _next) => {
|
||||
try {
|
||||
if (err.name === 'ValidationError') {
|
||||
return (err = handleValidationError(err, res));
|
||||
return handleValidationError(err, res);
|
||||
}
|
||||
if (err.code && err.code == 11000) {
|
||||
return (err = handleDuplicateKeyError(err, res));
|
||||
return handleDuplicateKeyError(err, res);
|
||||
}
|
||||
} catch (err) {
|
||||
// Special handling for errors like SyntaxError
|
||||
if (err.statusCode && err.body) {
|
||||
return res.status(err.statusCode).send(err.body);
|
||||
}
|
||||
|
||||
logger.error('ErrorController => error', err);
|
||||
res.status(500).send('An unknown error occurred.');
|
||||
return res.status(500).send('An unknown error occurred.');
|
||||
} catch (err) {
|
||||
logger.error('ErrorController => processing error', err);
|
||||
return res.status(500).send('Processing error in ErrorController.');
|
||||
}
|
||||
};
|
||||
|
||||
241
api/server/controllers/ErrorController.spec.js
Normal file
241
api/server/controllers/ErrorController.spec.js
Normal file
@@ -0,0 +1,241 @@
|
||||
const errorController = require('./ErrorController');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
// Mock the logger
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('ErrorController', () => {
|
||||
let mockReq, mockRes, mockNext;
|
||||
|
||||
beforeEach(() => {
|
||||
mockReq = {};
|
||||
mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
send: jest.fn(),
|
||||
};
|
||||
mockNext = jest.fn();
|
||||
logger.error.mockClear();
|
||||
});
|
||||
|
||||
describe('ValidationError handling', () => {
|
||||
it('should handle ValidationError with single error', () => {
|
||||
const validationError = {
|
||||
name: 'ValidationError',
|
||||
errors: {
|
||||
email: { message: 'Email is required', path: 'email' },
|
||||
},
|
||||
};
|
||||
|
||||
errorController(validationError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: '["Email is required"]',
|
||||
fields: '["email"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
|
||||
});
|
||||
|
||||
it('should handle ValidationError with multiple errors', () => {
|
||||
const validationError = {
|
||||
name: 'ValidationError',
|
||||
errors: {
|
||||
email: { message: 'Email is required', path: 'email' },
|
||||
password: { message: 'Password is required', path: 'password' },
|
||||
},
|
||||
};
|
||||
|
||||
errorController(validationError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: '"Email is required Password is required"',
|
||||
fields: '["email","password"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
|
||||
});
|
||||
|
||||
it('should handle ValidationError with empty errors object', () => {
|
||||
const validationError = {
|
||||
name: 'ValidationError',
|
||||
errors: {},
|
||||
};
|
||||
|
||||
errorController(validationError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: '[]',
|
||||
fields: '[]',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Duplicate key error handling', () => {
|
||||
it('should handle duplicate key error (code 11000)', () => {
|
||||
const duplicateKeyError = {
|
||||
code: 11000,
|
||||
keyValue: { email: 'test@example.com' },
|
||||
};
|
||||
|
||||
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: 'An document with that ["email"] already exists.',
|
||||
fields: '["email"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
|
||||
});
|
||||
|
||||
it('should handle duplicate key error with multiple fields', () => {
|
||||
const duplicateKeyError = {
|
||||
code: 11000,
|
||||
keyValue: { email: 'test@example.com', username: 'testuser' },
|
||||
};
|
||||
|
||||
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: 'An document with that ["email","username"] already exists.',
|
||||
fields: '["email","username"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
|
||||
});
|
||||
|
||||
it('should handle error with code 11000 as string', () => {
|
||||
const duplicateKeyError = {
|
||||
code: '11000',
|
||||
keyValue: { email: 'test@example.com' },
|
||||
};
|
||||
|
||||
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: 'An document with that ["email"] already exists.',
|
||||
fields: '["email"]',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('SyntaxError handling', () => {
|
||||
it('should handle errors with statusCode and body', () => {
|
||||
const syntaxError = {
|
||||
statusCode: 400,
|
||||
body: 'Invalid JSON syntax',
|
||||
};
|
||||
|
||||
errorController(syntaxError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
|
||||
});
|
||||
|
||||
it('should handle errors with different statusCode and body', () => {
|
||||
const customError = {
|
||||
statusCode: 422,
|
||||
body: { error: 'Unprocessable entity' },
|
||||
};
|
||||
|
||||
errorController(customError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(422);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
|
||||
});
|
||||
|
||||
it('should handle error with statusCode but no body', () => {
|
||||
const partialError = {
|
||||
statusCode: 400,
|
||||
};
|
||||
|
||||
errorController(partialError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
});
|
||||
|
||||
it('should handle error with body but no statusCode', () => {
|
||||
const partialError = {
|
||||
body: 'Some error message',
|
||||
};
|
||||
|
||||
errorController(partialError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Unknown error handling', () => {
|
||||
it('should handle unknown errors', () => {
|
||||
const unknownError = new Error('Some unknown error');
|
||||
|
||||
errorController(unknownError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError);
|
||||
});
|
||||
|
||||
it('should handle errors with code other than 11000', () => {
|
||||
const mongoError = {
|
||||
code: 11100,
|
||||
message: 'Some MongoDB error',
|
||||
};
|
||||
|
||||
errorController(mongoError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError);
|
||||
});
|
||||
|
||||
it('should handle null/undefined errors', () => {
|
||||
errorController(null, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'ErrorController => processing error',
|
||||
expect.any(Error),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Catch block handling', () => {
|
||||
beforeEach(() => {
|
||||
// Restore logger mock to normal behavior for these tests
|
||||
logger.error.mockRestore();
|
||||
logger.error = jest.fn();
|
||||
});
|
||||
|
||||
it('should handle errors when logger.error throws', () => {
|
||||
// Create fresh mocks for this test
|
||||
const freshMockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
send: jest.fn(),
|
||||
};
|
||||
|
||||
// Mock logger to throw on the first call, succeed on the second
|
||||
logger.error
|
||||
.mockImplementationOnce(() => {
|
||||
throw new Error('Logger error');
|
||||
})
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const testError = new Error('Test error');
|
||||
|
||||
errorController(testError, mockReq, freshMockRes, mockNext);
|
||||
|
||||
expect(freshMockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
||||
expect(logger.error).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,11 +1,10 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, AuthType } = require('librechat-data-provider');
|
||||
const { CacheKeys, AuthType, Constants } = require('librechat-data-provider');
|
||||
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const { getToolkitKey } = require('~/server/services/ToolService');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { availableTools } = require('~/app/clients/tools');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Filters out duplicate plugins from the list of plugins.
|
||||
@@ -139,15 +138,21 @@ function createGetServerTools() {
|
||||
*/
|
||||
const getAvailableTools = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user?.id;
|
||||
const customConfig = await getCustomConfig();
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedTools = await cache.get(CacheKeys.TOOLS);
|
||||
if (cachedTools) {
|
||||
res.status(200).json(cachedTools);
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
const cachedUserTools = await getCachedTools({ userId });
|
||||
const userPlugins = convertMCPToolsToPlugins(cachedUserTools, customConfig);
|
||||
|
||||
if (cachedToolsArray && userPlugins) {
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
|
||||
res.status(200).json(dedupedTools);
|
||||
return;
|
||||
}
|
||||
|
||||
// If not in cache, build from manifest
|
||||
let pluginManifest = availableTools;
|
||||
const customConfig = await getCustomConfig();
|
||||
if (customConfig?.mcpServers != null) {
|
||||
const mcpManager = getMCPManager();
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
@@ -173,7 +178,7 @@ const getAvailableTools = async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {};
|
||||
|
||||
const toolsOutput = [];
|
||||
for (const plugin of authenticatedPlugins) {
|
||||
@@ -218,16 +223,70 @@ const getAvailableTools = async (req, res) => {
|
||||
|
||||
toolsOutput.push(toolToAdd);
|
||||
}
|
||||
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
res.status(200).json(finalTools);
|
||||
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...finalTools]);
|
||||
|
||||
res.status(200).json(dedupedTools);
|
||||
} catch (error) {
|
||||
logger.error('[getAvailableTools]', error);
|
||||
res.status(500).json({ message: error.message });
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Converts MCP function format tools to plugin format
|
||||
* @param {Object} functionTools - Object with function format tools
|
||||
* @param {Object} customConfig - Custom configuration for MCP servers
|
||||
* @returns {Array} Array of plugin objects
|
||||
*/
|
||||
function convertMCPToolsToPlugins(functionTools, customConfig) {
|
||||
const plugins = [];
|
||||
|
||||
for (const [toolKey, toolData] of Object.entries(functionTools)) {
|
||||
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const functionData = toolData.function;
|
||||
const parts = toolKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
|
||||
const serverConfig = customConfig?.mcpServers?.[serverName];
|
||||
|
||||
const plugin = {
|
||||
name: parts[0], // Use the tool name without server suffix
|
||||
pluginKey: toolKey,
|
||||
description: functionData.description || '',
|
||||
authenticated: true,
|
||||
icon: serverConfig?.iconPath,
|
||||
};
|
||||
|
||||
// Build authConfig for MCP tools
|
||||
if (!serverConfig?.customUserVars) {
|
||||
plugin.authConfig = [];
|
||||
plugins.push(plugin);
|
||||
continue;
|
||||
}
|
||||
|
||||
const customVarKeys = Object.keys(serverConfig.customUserVars);
|
||||
if (customVarKeys.length === 0) {
|
||||
plugin.authConfig = [];
|
||||
} else {
|
||||
plugin.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
|
||||
authField: key,
|
||||
label: value.title || key,
|
||||
description: value.description || '',
|
||||
}));
|
||||
}
|
||||
|
||||
plugins.push(plugin);
|
||||
}
|
||||
|
||||
return plugins;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getAvailableTools,
|
||||
getAvailablePluginsController,
|
||||
|
||||
89
api/server/controllers/PluginController.spec.js
Normal file
89
api/server/controllers/PluginController.spec.js
Normal file
@@ -0,0 +1,89 @@
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
// Mock the dependencies
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCustomConfig: jest.fn(),
|
||||
getCachedTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/ToolService', () => ({
|
||||
getToolkitKey: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(() => ({
|
||||
loadManifestTools: jest.fn().mockResolvedValue([]),
|
||||
})),
|
||||
getFlowStateManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/app/clients/tools', () => ({
|
||||
availableTools: [],
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(),
|
||||
}));
|
||||
|
||||
// Import the actual module with the function we want to test
|
||||
const { getAvailableTools } = require('./PluginController');
|
||||
|
||||
describe('PluginController', () => {
|
||||
describe('plugin.icon behavior', () => {
|
||||
let mockReq, mockRes, mockCache;
|
||||
|
||||
const callGetAvailableToolsWithMCPServer = async (mcpServers) => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue({ mcpServers });
|
||||
|
||||
const functionTools = {
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: {
|
||||
function: { name: 'test-tool', description: 'A test tool' },
|
||||
},
|
||||
};
|
||||
getCachedTools.mockResolvedValueOnce(functionTools);
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
return responseData.find((tool) => tool.name === 'test-tool');
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockReq = { user: { id: 'test-user-id' } };
|
||||
mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() };
|
||||
mockCache = { get: jest.fn(), set: jest.fn() };
|
||||
getLogStores.mockReturnValue(mockCache);
|
||||
});
|
||||
|
||||
it('should set plugin.icon when iconPath is defined', async () => {
|
||||
const mcpServers = {
|
||||
'test-server': {
|
||||
iconPath: '/path/to/icon.png',
|
||||
},
|
||||
};
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(mcpServers);
|
||||
expect(testTool.icon).toBe('/path/to/icon.png');
|
||||
});
|
||||
|
||||
it('should set plugin.icon to undefined when iconPath is not defined', async () => {
|
||||
const mcpServers = {
|
||||
'test-server': {},
|
||||
};
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(mcpServers);
|
||||
expect(testTool.icon).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,11 +1,5 @@
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
FileSources,
|
||||
webSearchKeys,
|
||||
extractWebSearchEnvVars,
|
||||
} = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { webSearchKeys, extractWebSearchEnvVars } = require('@librechat/api');
|
||||
const {
|
||||
getFiles,
|
||||
updateUser,
|
||||
@@ -20,6 +14,7 @@ const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/service
|
||||
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
|
||||
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { Tools, Constants, FileSources } = require('librechat-data-provider');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { Transaction, Balance, User } = require('~/db/models');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
@@ -180,14 +175,16 @@ const updateUserPluginsController = async (req, res) => {
|
||||
try {
|
||||
const mcpManager = getMCPManager(user.id);
|
||||
if (mcpManager) {
|
||||
// Extract server name from pluginKey (format: "mcp_<serverName>")
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
logger.info(
|
||||
`[updateUserPluginsController] Disconnecting MCP connections for user ${user.id} after plugin auth update for ${pluginKey}.`,
|
||||
`[updateUserPluginsController] Disconnecting MCP server ${serverName} for user ${user.id} after plugin auth update for ${pluginKey}.`,
|
||||
);
|
||||
await mcpManager.disconnectUserConnections(user.id);
|
||||
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||
}
|
||||
} catch (disconnectError) {
|
||||
logger.error(
|
||||
`[updateUserPluginsController] Error disconnecting MCP connections for user ${user.id} after plugin auth update:`,
|
||||
`[updateUserPluginsController] Error disconnecting MCP connection for user ${user.id} after plugin auth update:`,
|
||||
disconnectError,
|
||||
);
|
||||
// Do not fail the request for this, but log it.
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
require('events').EventEmitter.defaultMaxListeners = 100;
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { DynamicStructuredTool } = require('@langchain/core/tools');
|
||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const {
|
||||
sendEvent,
|
||||
createRun,
|
||||
Tokenizer,
|
||||
checkAccess,
|
||||
memoryInstructions,
|
||||
formatContentStrings,
|
||||
createMemoryProcessor,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Callback,
|
||||
Providers,
|
||||
GraphEvents,
|
||||
TitleMethod,
|
||||
formatMessage,
|
||||
formatAgentMessages,
|
||||
formatContentStrings,
|
||||
getTokenCountForMessage,
|
||||
createMetadataAggregator,
|
||||
} = require('@librechat/agents');
|
||||
@@ -24,20 +27,22 @@ const {
|
||||
VisionModes,
|
||||
ContentTypes,
|
||||
EModelEndpoint,
|
||||
KnownEndpoints,
|
||||
PermissionTypes,
|
||||
isAgentsEndpoint,
|
||||
AgentCapabilities,
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { DynamicStructuredTool } = require('@langchain/core/tools');
|
||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const { createGetMCPAuthMap, checkCapability } = require('~/server/services/Config');
|
||||
const {
|
||||
findPluginAuthsByKeys,
|
||||
getFormattedMemories,
|
||||
deleteMemory,
|
||||
setMemory,
|
||||
} = require('~/models');
|
||||
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
|
||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
@@ -54,6 +59,7 @@ const omitTitleOptions = new Set([
|
||||
'thinkingBudget',
|
||||
'includeThoughts',
|
||||
'maxOutputTokens',
|
||||
'additionalModelRequestFields',
|
||||
]);
|
||||
|
||||
/**
|
||||
@@ -65,13 +71,15 @@ const payloadParser = ({ req, agent, endpoint }) => {
|
||||
if (isAgentsEndpoint(endpoint)) {
|
||||
return { model: undefined };
|
||||
} else if (endpoint === EModelEndpoint.bedrock) {
|
||||
return bedrockInputSchema.parse(agent.model_parameters);
|
||||
const parsedValues = bedrockInputSchema.parse(agent.model_parameters);
|
||||
if (parsedValues.thinking == null) {
|
||||
parsedValues.thinking = false;
|
||||
}
|
||||
return parsedValues;
|
||||
}
|
||||
return req.body.endpointOption.model_parameters;
|
||||
};
|
||||
|
||||
const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
|
||||
|
||||
const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
|
||||
|
||||
function createTokenCounter(encoding) {
|
||||
@@ -452,6 +460,12 @@ class AgentClient extends BaseClient {
|
||||
res: this.options.res,
|
||||
agent: prelimAgent,
|
||||
allowedProviders,
|
||||
endpointOption: {
|
||||
endpoint:
|
||||
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
|
||||
? EModelEndpoint.agents
|
||||
: memoryConfig.agent?.provider,
|
||||
},
|
||||
});
|
||||
|
||||
if (!agent) {
|
||||
@@ -525,7 +539,10 @@ class AgentClient extends BaseClient {
|
||||
messagesToProcess = [...messages.slice(-messageWindowSize)];
|
||||
}
|
||||
}
|
||||
return await this.processMemory(messagesToProcess);
|
||||
|
||||
const bufferString = getBufferString(messagesToProcess);
|
||||
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
|
||||
return await this.processMemory([bufferMessage]);
|
||||
} catch (error) {
|
||||
logger.error('Memory Agent failed to process memory', error);
|
||||
}
|
||||
@@ -697,17 +714,12 @@ class AgentClient extends BaseClient {
|
||||
version: 'v2',
|
||||
};
|
||||
|
||||
const getUserMCPAuthMap = await createGetMCPAuthMap();
|
||||
|
||||
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
|
||||
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
|
||||
payload,
|
||||
this.indexTokenCountMap,
|
||||
toolSet,
|
||||
);
|
||||
if (legacyContentEndpoints.has(this.options.agent.endpoint?.toLowerCase())) {
|
||||
initialMessages = formatContentStrings(initialMessages);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
@@ -723,6 +735,9 @@ class AgentClient extends BaseClient {
|
||||
if (i > 0) {
|
||||
this.model = agent.model_parameters.model;
|
||||
}
|
||||
if (i > 0 && config.signal == null) {
|
||||
config.signal = abortController.signal;
|
||||
}
|
||||
if (agent.recursion_limit && typeof agent.recursion_limit === 'number') {
|
||||
config.recursionLimit = agent.recursion_limit;
|
||||
}
|
||||
@@ -771,6 +786,9 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
|
||||
let messages = _messages;
|
||||
if (agent.useLegacyContent === true) {
|
||||
messages = formatContentStrings(messages);
|
||||
}
|
||||
if (
|
||||
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
|
||||
'prompt-caching',
|
||||
@@ -819,10 +837,11 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
|
||||
try {
|
||||
if (getUserMCPAuthMap) {
|
||||
config.configurable.userMCPAuthMap = await getUserMCPAuthMap({
|
||||
if (await hasCustomUserVars()) {
|
||||
config.configurable.userMCPAuthMap = await getMCPAuthMap({
|
||||
tools: agent.tools,
|
||||
userId: this.options.req.user.id,
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
@@ -998,25 +1017,40 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
|
||||
const { req, res, agent } = this.options;
|
||||
const endpoint = agent.endpoint;
|
||||
let endpoint = agent.endpoint;
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
let clientOptions = {
|
||||
maxTokens: 75,
|
||||
model: agent.model_parameters.model,
|
||||
model: agent.model || agent.model_parameters.model,
|
||||
};
|
||||
|
||||
const { getOptions, overrideProvider, customEndpointConfig } =
|
||||
await getProviderConfig(endpoint);
|
||||
let titleProviderConfig = await getProviderConfig(endpoint);
|
||||
|
||||
/** @type {TEndpoint | undefined} */
|
||||
const endpointConfig = req.app.locals[endpoint] ?? customEndpointConfig;
|
||||
const endpointConfig =
|
||||
req.app.locals.all ?? req.app.locals[endpoint] ?? titleProviderConfig.customEndpointConfig;
|
||||
if (!endpointConfig) {
|
||||
logger.warn(
|
||||
'[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config',
|
||||
);
|
||||
}
|
||||
|
||||
if (endpointConfig?.titleEndpoint && endpointConfig.titleEndpoint !== endpoint) {
|
||||
try {
|
||||
titleProviderConfig = await getProviderConfig(endpointConfig.titleEndpoint);
|
||||
endpoint = endpointConfig.titleEndpoint;
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`[api/server/controllers/agents/client.js #titleConvo] Error getting title endpoint config for ${endpointConfig.titleEndpoint}, falling back to default`,
|
||||
error,
|
||||
);
|
||||
// Fall back to original provider config
|
||||
endpoint = agent.endpoint;
|
||||
titleProviderConfig = await getProviderConfig(endpoint);
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
endpointConfig &&
|
||||
endpointConfig.titleModel &&
|
||||
@@ -1025,7 +1059,7 @@ class AgentClient extends BaseClient {
|
||||
clientOptions.model = endpointConfig.titleModel;
|
||||
}
|
||||
|
||||
const options = await getOptions({
|
||||
const options = await titleProviderConfig.getOptions({
|
||||
req,
|
||||
res,
|
||||
optionsOnly: true,
|
||||
@@ -1034,12 +1068,18 @@ class AgentClient extends BaseClient {
|
||||
endpointOption: { model_parameters: clientOptions },
|
||||
});
|
||||
|
||||
let provider = options.provider ?? overrideProvider ?? agent.provider;
|
||||
let provider = options.provider ?? titleProviderConfig.overrideProvider ?? agent.provider;
|
||||
if (
|
||||
endpoint === EModelEndpoint.azureOpenAI &&
|
||||
options.llmConfig?.azureOpenAIApiInstanceName == null
|
||||
) {
|
||||
provider = Providers.OPENAI;
|
||||
} else if (
|
||||
endpoint === EModelEndpoint.azureOpenAI &&
|
||||
options.llmConfig?.azureOpenAIApiInstanceName != null &&
|
||||
provider !== Providers.AZURE
|
||||
) {
|
||||
provider = Providers.AZURE;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
@@ -1061,16 +1101,23 @@ class AgentClient extends BaseClient {
|
||||
),
|
||||
);
|
||||
|
||||
if (provider === Providers.GOOGLE) {
|
||||
if (
|
||||
provider === Providers.GOOGLE &&
|
||||
(endpointConfig?.titleMethod === TitleMethod.FUNCTIONS ||
|
||||
endpointConfig?.titleMethod === TitleMethod.STRUCTURED)
|
||||
) {
|
||||
clientOptions.json = true;
|
||||
}
|
||||
|
||||
try {
|
||||
const titleResult = await this.run.generateTitle({
|
||||
provider,
|
||||
clientOptions,
|
||||
inputText: text,
|
||||
contentParts: this.contentParts,
|
||||
clientOptions,
|
||||
titleMethod: endpointConfig?.titleMethod,
|
||||
titlePrompt: endpointConfig?.titlePrompt,
|
||||
titlePromptTemplate: endpointConfig?.titlePromptTemplate,
|
||||
chainOptions: {
|
||||
signal: abortController.signal,
|
||||
callbacks: [
|
||||
@@ -1118,8 +1165,52 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
}
|
||||
|
||||
/** Silent method, as `recordCollectedUsage` is used instead */
|
||||
async recordTokenUsage() {}
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {number} params.promptTokens
|
||||
* @param {number} params.completionTokens
|
||||
* @param {OpenAIUsageMetadata} [params.usage]
|
||||
* @param {string} [params.model]
|
||||
* @param {string} [params.context='message']
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async recordTokenUsage({ model, promptTokens, completionTokens, usage, context = 'message' }) {
|
||||
try {
|
||||
await spendTokens(
|
||||
{
|
||||
model,
|
||||
context,
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
},
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
if (
|
||||
usage &&
|
||||
typeof usage === 'object' &&
|
||||
'reasoning_tokens' in usage &&
|
||||
typeof usage.reasoning_tokens === 'number'
|
||||
) {
|
||||
await spendTokens(
|
||||
{
|
||||
model,
|
||||
context: 'reasoning',
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
},
|
||||
{ completionTokens: usage.reasoning_tokens },
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #recordTokenUsage] Error recording token usage',
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
getEncoding() {
|
||||
return 'o200k_base';
|
||||
|
||||
730
api/server/controllers/agents/client.test.js
Normal file
730
api/server/controllers/agents/client.test.js
Normal file
@@ -0,0 +1,730 @@
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { Constants, EModelEndpoint } = require('librechat-data-provider');
|
||||
const AgentClient = require('./client');
|
||||
|
||||
jest.mock('@librechat/agents', () => ({
|
||||
...jest.requireActual('@librechat/agents'),
|
||||
createMetadataAggregator: () => ({
|
||||
handleLLMEnd: jest.fn(),
|
||||
collected: [],
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('AgentClient - titleConvo', () => {
|
||||
let client;
|
||||
let mockRun;
|
||||
let mockReq;
|
||||
let mockRes;
|
||||
let mockAgent;
|
||||
let mockOptions;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Mock run object
|
||||
mockRun = {
|
||||
generateTitle: jest.fn().mockResolvedValue({
|
||||
title: 'Generated Title',
|
||||
}),
|
||||
};
|
||||
|
||||
// Mock agent - with both endpoint and provider
|
||||
mockAgent = {
|
||||
id: 'agent-123',
|
||||
endpoint: EModelEndpoint.openAI, // Use a valid provider as endpoint for getProviderConfig
|
||||
provider: EModelEndpoint.openAI, // Add provider property
|
||||
model_parameters: {
|
||||
model: 'gpt-4',
|
||||
},
|
||||
};
|
||||
|
||||
// Mock request and response
|
||||
mockReq = {
|
||||
app: {
|
||||
locals: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
// Match the agent endpoint
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
titlePrompt: 'Custom title prompt',
|
||||
titleMethod: 'structured',
|
||||
titlePromptTemplate: 'Template: {{content}}',
|
||||
},
|
||||
},
|
||||
},
|
||||
user: {
|
||||
id: 'user-123',
|
||||
},
|
||||
body: {
|
||||
model: 'gpt-4',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
key: null,
|
||||
},
|
||||
};
|
||||
|
||||
mockRes = {};
|
||||
|
||||
// Mock options
|
||||
mockOptions = {
|
||||
req: mockReq,
|
||||
res: mockRes,
|
||||
agent: mockAgent,
|
||||
endpointTokenConfig: {},
|
||||
};
|
||||
|
||||
// Create client instance
|
||||
client = new AgentClient(mockOptions);
|
||||
client.run = mockRun;
|
||||
client.responseMessageId = 'response-123';
|
||||
client.conversationId = 'convo-123';
|
||||
client.contentParts = [{ type: 'text', text: 'Test content' }];
|
||||
client.recordCollectedUsage = jest.fn().mockResolvedValue(); // Mock as async function that resolves
|
||||
});
|
||||
|
||||
describe('titleConvo method', () => {
|
||||
it('should throw error if run is not initialized', async () => {
|
||||
client.run = null;
|
||||
|
||||
await expect(
|
||||
client.titleConvo({ text: 'Test', abortController: new AbortController() }),
|
||||
).rejects.toThrow('Run not initialized');
|
||||
});
|
||||
|
||||
it('should use titlePrompt from endpoint config', async () => {
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
titlePrompt: 'Custom title prompt',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use titlePromptTemplate from endpoint config', async () => {
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
titlePromptTemplate: 'Template: {{content}}',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use titleMethod from endpoint config', async () => {
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: Providers.OPENAI,
|
||||
titleMethod: 'structured',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use titleModel from endpoint config when provided', async () => {
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Check that generateTitle was called with correct clientOptions
|
||||
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
|
||||
expect(generateTitleCall.clientOptions.model).toBe('gpt-3.5-turbo');
|
||||
});
|
||||
|
||||
it('should handle missing endpoint config gracefully', async () => {
|
||||
// Remove endpoint config
|
||||
mockReq.app.locals[EModelEndpoint.openAI] = undefined;
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
titlePrompt: undefined,
|
||||
titlePromptTemplate: undefined,
|
||||
titleMethod: undefined,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use agent model when titleModel is not provided', async () => {
|
||||
// Remove titleModel from config
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titleModel;
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
|
||||
expect(generateTitleCall.clientOptions.model).toBe('gpt-4'); // Should use agent's model
|
||||
});
|
||||
|
||||
it('should not use titleModel when it equals CURRENT_MODEL constant', async () => {
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleModel = Constants.CURRENT_MODEL;
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
|
||||
expect(generateTitleCall.clientOptions.model).toBe('gpt-4'); // Should use agent's model
|
||||
});
|
||||
|
||||
it('should pass all required parameters to generateTitle', async () => {
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith({
|
||||
provider: expect.any(String),
|
||||
inputText: text,
|
||||
contentParts: client.contentParts,
|
||||
clientOptions: expect.objectContaining({
|
||||
model: 'gpt-3.5-turbo',
|
||||
}),
|
||||
titlePrompt: 'Custom title prompt',
|
||||
titlePromptTemplate: 'Template: {{content}}',
|
||||
titleMethod: 'structured',
|
||||
chainOptions: expect.objectContaining({
|
||||
signal: abortController.signal,
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
it('should record collected usage after title generation', async () => {
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
expect(client.recordCollectedUsage).toHaveBeenCalledWith({
|
||||
model: 'gpt-3.5-turbo',
|
||||
context: 'title',
|
||||
collectedUsage: expect.any(Array),
|
||||
});
|
||||
});
|
||||
|
||||
it('should return the generated title', async () => {
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
const result = await client.titleConvo({ text, abortController });
|
||||
|
||||
expect(result).toBe('Generated Title');
|
||||
});
|
||||
|
||||
it('should handle errors gracefully and return undefined', async () => {
|
||||
mockRun.generateTitle.mockRejectedValue(new Error('Title generation failed'));
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
const result = await client.titleConvo({ text, abortController });
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should pass titleEndpoint configuration to generateTitle', async () => {
|
||||
// Mock the API key just for this test
|
||||
const originalApiKey = process.env.ANTHROPIC_API_KEY;
|
||||
process.env.ANTHROPIC_API_KEY = 'test-api-key';
|
||||
|
||||
// Add titleEndpoint to the config
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleEndpoint = EModelEndpoint.anthropic;
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleMethod = 'structured';
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titlePrompt = 'Custom title prompt';
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titlePromptTemplate = 'Custom template';
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify generateTitle was called with the custom configuration
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
titleMethod: 'structured',
|
||||
provider: Providers.ANTHROPIC,
|
||||
titlePrompt: 'Custom title prompt',
|
||||
titlePromptTemplate: 'Custom template',
|
||||
}),
|
||||
);
|
||||
|
||||
// Restore the original API key
|
||||
if (originalApiKey) {
|
||||
process.env.ANTHROPIC_API_KEY = originalApiKey;
|
||||
} else {
|
||||
delete process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
});
|
||||
|
||||
it('should use all config when endpoint config is missing', async () => {
|
||||
// Remove endpoint-specific config
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titleModel;
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titlePrompt;
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titleMethod;
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titlePromptTemplate;
|
||||
|
||||
// Set 'all' config
|
||||
mockReq.app.locals.all = {
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titlePrompt: 'All config title prompt',
|
||||
titleMethod: 'completion',
|
||||
titlePromptTemplate: 'All config template: {{content}}',
|
||||
};
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify generateTitle was called with 'all' config values
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
titleMethod: 'completion',
|
||||
titlePrompt: 'All config title prompt',
|
||||
titlePromptTemplate: 'All config template: {{content}}',
|
||||
}),
|
||||
);
|
||||
|
||||
// Check that the model was set from 'all' config
|
||||
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
|
||||
expect(generateTitleCall.clientOptions.model).toBe('gpt-4o-mini');
|
||||
});
|
||||
|
||||
it('should prioritize all config over endpoint config for title settings', async () => {
|
||||
// Set both endpoint and 'all' config
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleModel = 'gpt-3.5-turbo';
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titlePrompt = 'Endpoint title prompt';
|
||||
mockReq.app.locals[EModelEndpoint.openAI].titleMethod = 'structured';
|
||||
// Remove titlePromptTemplate from endpoint config to test fallback
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI].titlePromptTemplate;
|
||||
|
||||
mockReq.app.locals.all = {
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titlePrompt: 'All config title prompt',
|
||||
titleMethod: 'completion',
|
||||
titlePromptTemplate: 'All config template',
|
||||
};
|
||||
|
||||
const text = 'Test conversation text';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify 'all' config takes precedence over endpoint config
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
titleMethod: 'completion',
|
||||
titlePrompt: 'All config title prompt',
|
||||
titlePromptTemplate: 'All config template',
|
||||
}),
|
||||
);
|
||||
|
||||
// Check that the model was set from 'all' config
|
||||
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
|
||||
expect(generateTitleCall.clientOptions.model).toBe('gpt-4o-mini');
|
||||
});
|
||||
|
||||
it('should use all config with titleEndpoint and verify provider switch', async () => {
|
||||
// Mock the API key for the titleEndpoint provider
|
||||
const originalApiKey = process.env.ANTHROPIC_API_KEY;
|
||||
process.env.ANTHROPIC_API_KEY = 'test-anthropic-key';
|
||||
|
||||
// Remove endpoint-specific config to test 'all' config
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI];
|
||||
|
||||
// Set comprehensive 'all' config with all new title options
|
||||
mockReq.app.locals.all = {
|
||||
titleConvo: true,
|
||||
titleModel: 'claude-3-haiku-20240307',
|
||||
titleMethod: 'completion', // Testing the new default method
|
||||
titlePrompt: 'Generate a concise, descriptive title for this conversation',
|
||||
titlePromptTemplate: 'Conversation summary: {{content}}',
|
||||
titleEndpoint: EModelEndpoint.anthropic, // Should switch provider to Anthropic
|
||||
};
|
||||
|
||||
const text = 'Test conversation about AI and machine learning';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify all config values were used
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: Providers.ANTHROPIC, // Critical: Verify provider switched to Anthropic
|
||||
titleMethod: 'completion',
|
||||
titlePrompt: 'Generate a concise, descriptive title for this conversation',
|
||||
titlePromptTemplate: 'Conversation summary: {{content}}',
|
||||
inputText: text,
|
||||
contentParts: client.contentParts,
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify the model was set from 'all' config
|
||||
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
|
||||
expect(generateTitleCall.clientOptions.model).toBe('claude-3-haiku-20240307');
|
||||
|
||||
// Verify other client options are set correctly
|
||||
expect(generateTitleCall.clientOptions).toMatchObject({
|
||||
model: 'claude-3-haiku-20240307',
|
||||
// Note: Anthropic's getOptions may set its own maxTokens value
|
||||
});
|
||||
|
||||
// Restore the original API key
|
||||
if (originalApiKey) {
|
||||
process.env.ANTHROPIC_API_KEY = originalApiKey;
|
||||
} else {
|
||||
delete process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
});
|
||||
|
||||
it('should test all titleMethod options from all config', async () => {
|
||||
// Test each titleMethod: 'completion', 'functions', 'structured'
|
||||
const titleMethods = ['completion', 'functions', 'structured'];
|
||||
|
||||
for (const method of titleMethods) {
|
||||
// Clear previous calls
|
||||
mockRun.generateTitle.mockClear();
|
||||
|
||||
// Remove endpoint config
|
||||
delete mockReq.app.locals[EModelEndpoint.openAI];
|
||||
|
||||
// Set 'all' config with specific titleMethod
|
||||
mockReq.app.locals.all = {
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titleMethod: method,
|
||||
titlePrompt: `Testing ${method} method`,
|
||||
titlePromptTemplate: `Template for ${method}: {{content}}`,
|
||||
};
|
||||
|
||||
const text = `Test conversation for ${method} method`;
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify the correct titleMethod was used
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
titleMethod: method,
|
||||
titlePrompt: `Testing ${method} method`,
|
||||
titlePromptTemplate: `Template for ${method}: {{content}}`,
|
||||
}),
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
describe('Azure-specific title generation', () => {
|
||||
let originalEnv;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Save original environment variables
|
||||
originalEnv = { ...process.env };
|
||||
|
||||
// Mock Azure API keys
|
||||
process.env.AZURE_OPENAI_API_KEY = 'test-azure-key';
|
||||
process.env.AZURE_API_KEY = 'test-azure-key';
|
||||
process.env.EASTUS_API_KEY = 'test-eastus-key';
|
||||
process.env.EASTUS2_API_KEY = 'test-eastus2-key';
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Restore environment variables
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should use OPENAI provider for Azure serverless endpoints', async () => {
|
||||
// Set up Azure endpoint with serverless config
|
||||
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.provider = EModelEndpoint.azureOpenAI;
|
||||
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
|
||||
titleConvo: true,
|
||||
titleModel: 'grok-3',
|
||||
titleMethod: 'completion',
|
||||
titlePrompt: 'Azure serverless title prompt',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'grok-3': {
|
||||
group: 'Azure AI Foundry',
|
||||
deploymentName: 'grok-3',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'Azure AI Foundry': {
|
||||
apiKey: '${AZURE_API_KEY}',
|
||||
baseURL: 'https://test.services.ai.azure.com/models',
|
||||
version: '2024-05-01-preview',
|
||||
serverless: true,
|
||||
models: {
|
||||
'grok-3': {
|
||||
deploymentName: 'grok-3',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockReq.body.model = 'grok-3';
|
||||
|
||||
const text = 'Test Azure serverless conversation';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify provider was switched to OPENAI for serverless
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: Providers.OPENAI, // Should be OPENAI for serverless
|
||||
titleMethod: 'completion',
|
||||
titlePrompt: 'Azure serverless title prompt',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use AZURE provider for Azure endpoints with instanceName', async () => {
|
||||
// Set up Azure endpoint
|
||||
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.provider = EModelEndpoint.azureOpenAI;
|
||||
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4o',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Azure instance title prompt',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'gpt-4o': {
|
||||
group: 'eastus',
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
eastus: {
|
||||
apiKey: '${EASTUS_API_KEY}',
|
||||
instanceName: 'region-instance',
|
||||
version: '2024-02-15-preview',
|
||||
models: {
|
||||
'gpt-4o': {
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockReq.body.model = 'gpt-4o';
|
||||
|
||||
const text = 'Test Azure instance conversation';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify provider remains AZURE with instanceName
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: Providers.AZURE,
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Azure instance title prompt',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle Azure titleModel with CURRENT_MODEL constant', async () => {
|
||||
// Set up Azure endpoint
|
||||
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.provider = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.model_parameters.model = 'gpt-4o-latest';
|
||||
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
|
||||
titleConvo: true,
|
||||
titleModel: Constants.CURRENT_MODEL,
|
||||
titleMethod: 'functions',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'gpt-4o-latest': {
|
||||
group: 'region-eastus',
|
||||
deploymentName: 'gpt-4o-mini',
|
||||
version: '2024-02-15-preview',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'region-eastus': {
|
||||
apiKey: '${EASTUS2_API_KEY}',
|
||||
instanceName: 'test-instance',
|
||||
version: '2024-12-01-preview',
|
||||
models: {
|
||||
'gpt-4o-latest': {
|
||||
deploymentName: 'gpt-4o-mini',
|
||||
version: '2024-02-15-preview',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockReq.body.model = 'gpt-4o-latest';
|
||||
|
||||
const text = 'Test Azure current model';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify it uses the correct model when titleModel is CURRENT_MODEL
|
||||
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
|
||||
// When CURRENT_MODEL is used with Azure, the model gets mapped to the deployment name
|
||||
// In this case, 'gpt-4o-latest' is mapped to 'gpt-4o-mini' deployment
|
||||
expect(generateTitleCall.clientOptions.model).toBe('gpt-4o-mini');
|
||||
// Also verify that CURRENT_MODEL constant was not passed as the model
|
||||
expect(generateTitleCall.clientOptions.model).not.toBe(Constants.CURRENT_MODEL);
|
||||
});
|
||||
|
||||
it('should handle Azure with multiple model groups', async () => {
|
||||
// Set up Azure endpoint
|
||||
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.provider = EModelEndpoint.azureOpenAI;
|
||||
mockReq.app.locals[EModelEndpoint.azureOpenAI] = {
|
||||
titleConvo: true,
|
||||
titleModel: 'o1-mini',
|
||||
titleMethod: 'completion',
|
||||
streamRate: 35,
|
||||
modelGroupMap: {
|
||||
'gpt-4o': {
|
||||
group: 'eastus',
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
'o1-mini': {
|
||||
group: 'region-eastus',
|
||||
deploymentName: 'o1-mini',
|
||||
},
|
||||
'codex-mini': {
|
||||
group: 'codex-mini',
|
||||
deploymentName: 'codex-mini',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
eastus: {
|
||||
apiKey: '${EASTUS_API_KEY}',
|
||||
instanceName: 'region-eastus',
|
||||
version: '2024-02-15-preview',
|
||||
models: {
|
||||
'gpt-4o': {
|
||||
deploymentName: 'gpt-4o',
|
||||
},
|
||||
},
|
||||
},
|
||||
'region-eastus': {
|
||||
apiKey: '${EASTUS2_API_KEY}',
|
||||
instanceName: 'region-eastus2',
|
||||
version: '2024-12-01-preview',
|
||||
models: {
|
||||
'o1-mini': {
|
||||
deploymentName: 'o1-mini',
|
||||
},
|
||||
},
|
||||
},
|
||||
'codex-mini': {
|
||||
apiKey: '${AZURE_API_KEY}',
|
||||
baseURL: 'https://example.cognitiveservices.azure.com/openai/',
|
||||
version: '2025-04-01-preview',
|
||||
serverless: true,
|
||||
models: {
|
||||
'codex-mini': {
|
||||
deploymentName: 'codex-mini',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockReq.body.model = 'o1-mini';
|
||||
|
||||
const text = 'Test Azure multi-group conversation';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify correct model and provider are used
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: Providers.AZURE,
|
||||
titleMethod: 'completion',
|
||||
}),
|
||||
);
|
||||
|
||||
const generateTitleCall = mockRun.generateTitle.mock.calls[0][0];
|
||||
expect(generateTitleCall.clientOptions.model).toBe('o1-mini');
|
||||
expect(generateTitleCall.clientOptions.maxTokens).toBeUndefined(); // o1 models shouldn't have maxTokens
|
||||
});
|
||||
|
||||
it('should use all config as fallback for Azure endpoints', async () => {
|
||||
// Set up Azure endpoint with minimal config
|
||||
mockAgent.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockAgent.provider = EModelEndpoint.azureOpenAI;
|
||||
mockReq.body.endpoint = EModelEndpoint.azureOpenAI;
|
||||
mockReq.body.model = 'gpt-4';
|
||||
|
||||
// Remove Azure-specific config
|
||||
delete mockReq.app.locals[EModelEndpoint.azureOpenAI];
|
||||
|
||||
// Set 'all' config as fallback with a serverless Azure config
|
||||
mockReq.app.locals.all = {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Fallback title prompt from all config',
|
||||
titlePromptTemplate: 'Template: {{content}}',
|
||||
modelGroupMap: {
|
||||
'gpt-4': {
|
||||
group: 'default-group',
|
||||
deploymentName: 'gpt-4',
|
||||
},
|
||||
},
|
||||
groupMap: {
|
||||
'default-group': {
|
||||
apiKey: '${AZURE_API_KEY}',
|
||||
baseURL: 'https://default.openai.azure.com/',
|
||||
version: '2024-02-15-preview',
|
||||
serverless: true,
|
||||
models: {
|
||||
'gpt-4': {
|
||||
deploymentName: 'gpt-4',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const text = 'Test Azure with all config fallback';
|
||||
const abortController = new AbortController();
|
||||
|
||||
await client.titleConvo({ text, abortController });
|
||||
|
||||
// Verify all config is used
|
||||
expect(mockRun.generateTitle).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: Providers.OPENAI, // Should be OPENAI when no instanceName
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Fallback title prompt from all config',
|
||||
titlePromptTemplate: 'Template: {{content}}',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -12,10 +12,14 @@ const { saveMessage } = require('~/models');
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
text,
|
||||
isRegenerate,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
isContinued = false,
|
||||
editedContent = null,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
responseMessageId: editedResponseMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
let sender;
|
||||
@@ -67,7 +71,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
handler();
|
||||
}
|
||||
} catch (e) {
|
||||
// Ignore cleanup errors
|
||||
logger.error('[AgentController] Error in cleanup handler', e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -155,7 +159,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
try {
|
||||
res.removeListener('close', closeHandler);
|
||||
} catch (e) {
|
||||
// Ignore
|
||||
logger.error('[AgentController] Error removing close listener', e);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -163,10 +167,15 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
user: userId,
|
||||
onStart,
|
||||
getReqData,
|
||||
isContinued,
|
||||
isRegenerate,
|
||||
editedContent,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
responseMessageId: editedResponseMessageId,
|
||||
progressOptions: {
|
||||
res,
|
||||
},
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
const { z } = require('zod');
|
||||
const fs = require('fs').promises;
|
||||
const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
@@ -8,6 +10,7 @@ const {
|
||||
SystemRoles,
|
||||
EToolResources,
|
||||
actionDelimiter,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getAgent,
|
||||
@@ -30,6 +33,7 @@ const { deleteFileByFilter } = require('~/models/File');
|
||||
const systemTools = {
|
||||
[Tools.execute_code]: true,
|
||||
[Tools.file_search]: true,
|
||||
[Tools.web_search]: true,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -42,9 +46,13 @@ const systemTools = {
|
||||
*/
|
||||
const createAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body;
|
||||
const validatedData = agentCreateSchema.parse(req.body);
|
||||
const { tools = [], ...agentData } = removeNullishValues(validatedData);
|
||||
|
||||
const { id: userId } = req.user;
|
||||
|
||||
agentData.id = `agent_${nanoid()}`;
|
||||
agentData.author = userId;
|
||||
agentData.tools = [];
|
||||
|
||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||
@@ -58,19 +66,13 @@ const createAgentHandler = async (req, res) => {
|
||||
}
|
||||
}
|
||||
|
||||
Object.assign(agentData, {
|
||||
author: userId,
|
||||
name,
|
||||
description,
|
||||
instructions,
|
||||
provider,
|
||||
model,
|
||||
});
|
||||
|
||||
agentData.id = `agent_${nanoid()}`;
|
||||
const agent = await createAgent(agentData);
|
||||
res.status(201).json(agent);
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.error('[/Agents] Validation error', error.errors);
|
||||
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
|
||||
}
|
||||
logger.error('[/Agents] Error creating agent', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
@@ -154,14 +156,16 @@ const getAgentHandler = async (req, res) => {
|
||||
const updateAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const id = req.params.id;
|
||||
const { projectIds, removeProjectIds, ...updateData } = req.body;
|
||||
const validatedData = agentUpdateSchema.parse(req.body);
|
||||
const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
|
||||
const isAdmin = req.user.role === SystemRoles.ADMIN;
|
||||
const existingAgent = await getAgent({ id });
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||
|
||||
if (!hasEditPermission) {
|
||||
@@ -200,6 +204,11 @@ const updateAgentHandler = async (req, res) => {
|
||||
|
||||
return res.json(updatedAgent);
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.error('[/Agents/:id] Validation error', error.errors);
|
||||
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
|
||||
}
|
||||
|
||||
logger.error('[/Agents/:id] Error updating Agent', error);
|
||||
|
||||
if (error.statusCode === 409) {
|
||||
@@ -382,6 +391,22 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
return res.status(400).json({ message: 'Agent ID is required' });
|
||||
}
|
||||
|
||||
const isAdmin = req.user.role === SystemRoles.ADMIN;
|
||||
const existingAgent = await getAgent({ id: agent_id });
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||
|
||||
if (!hasEditPermission) {
|
||||
return res.status(403).json({
|
||||
error: 'You do not have permission to modify this non-collaborative agent',
|
||||
});
|
||||
}
|
||||
|
||||
const buffer = await fs.readFile(req.file.path);
|
||||
|
||||
const fileStrategy = req.app.locals.fileStrategy;
|
||||
@@ -404,14 +429,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
source: fileStrategy,
|
||||
};
|
||||
|
||||
let _avatar;
|
||||
try {
|
||||
const agent = await getAgent({ id: agent_id });
|
||||
_avatar = agent.avatar;
|
||||
} catch (error) {
|
||||
logger.error('[/:agent_id/avatar] Error fetching agent', error);
|
||||
_avatar = {};
|
||||
}
|
||||
let _avatar = existingAgent.avatar;
|
||||
|
||||
if (_avatar && _avatar.source) {
|
||||
const { deleteFile } = getStrategyFunctions(_avatar.source);
|
||||
@@ -433,7 +451,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
};
|
||||
|
||||
promises.push(
|
||||
await updateAgent({ id: agent_id, author: req.user.id }, data, {
|
||||
await updateAgent({ id: agent_id }, data, {
|
||||
updatingUserId: req.user.id,
|
||||
}),
|
||||
);
|
||||
|
||||
659
api/server/controllers/agents/v1.spec.js
Normal file
659
api/server/controllers/agents/v1.spec.js
Normal file
@@ -0,0 +1,659 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
|
||||
// Only mock the dependencies that are not database-related
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCachedTools: jest.fn().mockResolvedValue({
|
||||
web_search: true,
|
||||
execute_code: true,
|
||||
file_search: true,
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Project', () => ({
|
||||
getProjectByName: jest.fn().mockResolvedValue(null),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/images/avatar', () => ({
|
||||
resizeAvatar: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||
refreshS3Url: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
filterFile: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Action', () => ({
|
||||
updateAction: jest.fn(),
|
||||
getActions: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/File', () => ({
|
||||
deleteFileByFilter: jest.fn(),
|
||||
}));
|
||||
|
||||
const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
|
||||
*/
|
||||
let Agent;
|
||||
|
||||
describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
let mongoServer;
|
||||
let mockReq;
|
||||
let mockRes;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Agent.deleteMany({});
|
||||
|
||||
// Reset all mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Setup mock request and response objects
|
||||
mockReq = {
|
||||
user: {
|
||||
id: new mongoose.Types.ObjectId().toString(),
|
||||
role: 'USER',
|
||||
},
|
||||
body: {},
|
||||
params: {},
|
||||
app: {
|
||||
locals: {
|
||||
fileStrategy: 'local',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('createAgentHandler', () => {
|
||||
test('should create agent with allowed fields only', async () => {
|
||||
const validData = {
|
||||
name: 'Test Agent',
|
||||
description: 'A test agent',
|
||||
instructions: 'Be helpful',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['web_search'],
|
||||
model_parameters: { temperature: 0.7 },
|
||||
tool_resources: {
|
||||
file_search: { file_ids: ['file1', 'file2'] },
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = validData;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.name).toBe('Test Agent');
|
||||
expect(createdAgent.description).toBe('A test agent');
|
||||
expect(createdAgent.provider).toBe('openai');
|
||||
expect(createdAgent.model).toBe('gpt-4');
|
||||
expect(createdAgent.author.toString()).toBe(mockReq.user.id);
|
||||
expect(createdAgent.tools).toContain('web_search');
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb).toBeDefined();
|
||||
expect(agentInDb.name).toBe('Test Agent');
|
||||
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
|
||||
});
|
||||
|
||||
test('should reject creation with unauthorized fields (mass assignment protection)', async () => {
|
||||
const maliciousData = {
|
||||
// Required fields
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Malicious Agent',
|
||||
|
||||
// Unauthorized fields that should be stripped
|
||||
author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
|
||||
authorName: 'Hacker', // Should be stripped
|
||||
isCollaborative: true, // Should be stripped on creation
|
||||
versions: [], // Should be stripped
|
||||
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||
id: 'custom_agent_id', // Should be overridden
|
||||
createdAt: new Date('2020-01-01'), // Should be stripped
|
||||
updatedAt: new Date('2020-01-01'), // Should be stripped
|
||||
};
|
||||
|
||||
mockReq.body = maliciousData;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unauthorized fields were not set
|
||||
expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
|
||||
expect(createdAgent.authorName).toBeUndefined();
|
||||
expect(createdAgent.isCollaborative).toBeFalsy();
|
||||
expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
|
||||
expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
|
||||
expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix
|
||||
|
||||
// Verify timestamps are recent (not the malicious dates)
|
||||
const createdTime = new Date(createdAgent.createdAt).getTime();
|
||||
const now = Date.now();
|
||||
expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
|
||||
expect(agentInDb.authorName).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should validate required fields', async () => {
|
||||
const invalidData = {
|
||||
name: 'Missing Required Fields',
|
||||
// Missing provider and model
|
||||
};
|
||||
|
||||
mockReq.body = invalidData;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
details: expect.any(Array),
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify nothing was created in database
|
||||
const count = await Agent.countDocuments();
|
||||
expect(count).toBe(0);
|
||||
});
|
||||
|
||||
test('should handle tool_resources validation', async () => {
|
||||
const dataWithInvalidToolResources = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Tool Resources',
|
||||
tool_resources: {
|
||||
// Valid resources
|
||||
file_search: {
|
||||
file_ids: ['file1', 'file2'],
|
||||
vector_store_ids: ['vs1'],
|
||||
},
|
||||
execute_code: {
|
||||
file_ids: ['file3'],
|
||||
},
|
||||
// Invalid resource (should be stripped by schema)
|
||||
invalid_resource: {
|
||||
file_ids: ['file4'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithInvalidToolResources;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.tool_resources).toBeDefined();
|
||||
expect(createdAgent.tool_resources.file_search).toBeDefined();
|
||||
expect(createdAgent.tool_resources.execute_code).toBeDefined();
|
||||
expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should handle avatar validation', async () => {
|
||||
const dataWithAvatar = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Avatar',
|
||||
avatar: {
|
||||
filepath: 'https://example.com/avatar.png',
|
||||
source: 's3',
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithAvatar;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.avatar).toEqual({
|
||||
filepath: 'https://example.com/avatar.png',
|
||||
source: 's3',
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle invalid avatar format', async () => {
|
||||
const dataWithInvalidAvatar = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Invalid Avatar',
|
||||
avatar: 'just-a-string', // Invalid format
|
||||
};
|
||||
|
||||
mockReq.body = dataWithInvalidAvatar;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateAgentHandler', () => {
|
||||
let existingAgentId;
|
||||
let existingAgentAuthorId;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create an existing agent for update tests
|
||||
existingAgentAuthorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
author: existingAgentAuthorId,
|
||||
description: 'Original description',
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
description: 'Original description',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
existingAgentId = agent.id;
|
||||
});
|
||||
|
||||
test('should update agent with allowed fields only', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString(); // Set as author
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated Agent',
|
||||
description: 'Updated description',
|
||||
model: 'gpt-4',
|
||||
isCollaborative: true, // This IS allowed in updates
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(400);
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Updated Agent');
|
||||
expect(updatedAgent.description).toBe('Updated description');
|
||||
expect(updatedAgent.model).toBe('gpt-4');
|
||||
expect(updatedAgent.isCollaborative).toBe(true);
|
||||
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Updated Agent');
|
||||
expect(agentInDb.isCollaborative).toBe(true);
|
||||
});
|
||||
|
||||
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated Name',
|
||||
|
||||
// Unauthorized fields that should be stripped
|
||||
author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author
|
||||
authorName: 'Hacker', // Should be stripped
|
||||
id: 'different_agent_id', // Should be stripped
|
||||
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||
versions: [], // Should be stripped
|
||||
createdAt: new Date('2020-01-01'), // Should be stripped
|
||||
updatedAt: new Date('2020-01-01'), // Should be stripped
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unauthorized fields were not changed
|
||||
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed
|
||||
expect(updatedAgent.authorName).toBeUndefined();
|
||||
expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed
|
||||
expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString());
|
||||
expect(agentInDb.id).toBe(existingAgentId);
|
||||
});
|
||||
|
||||
test('should reject update from non-author when not collaborative', async () => {
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Unauthorized Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'You do not have permission to modify this non-collaborative agent',
|
||||
});
|
||||
|
||||
// Verify agent was not modified in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Original Agent');
|
||||
});
|
||||
|
||||
test('should allow update from non-author when collaborative', async () => {
|
||||
// First make the agent collaborative
|
||||
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
|
||||
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Collaborative Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Collaborative Update');
|
||||
// Author field should be removed for non-author
|
||||
expect(updatedAgent.author).toBeUndefined();
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Collaborative Update');
|
||||
});
|
||||
|
||||
test('should allow admin to update any agent', async () => {
|
||||
const adminUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = adminUserId;
|
||||
mockReq.user.role = 'ADMIN'; // Set as admin
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Admin Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Admin Update');
|
||||
});
|
||||
|
||||
test('should handle projectIds updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
|
||||
const projectId1 = new mongoose.Types.ObjectId().toString();
|
||||
const projectId2 = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
mockReq.body = {
|
||||
projectIds: [projectId1, projectId2],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent).toBeDefined();
|
||||
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
|
||||
});
|
||||
|
||||
test('should validate tool_resources in updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
tool_resources: {
|
||||
ocr: {
|
||||
file_ids: ['ocr1', 'ocr2'],
|
||||
},
|
||||
execute_code: {
|
||||
file_ids: ['img1'],
|
||||
},
|
||||
// Invalid tool resource
|
||||
invalid_tool: {
|
||||
file_ids: ['invalid'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tool_resources).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.ocr).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should return 404 for non-existent agent', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
|
||||
mockReq.body = {
|
||||
name: 'Update Non-existent',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(404);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
|
||||
});
|
||||
|
||||
test('should handle validation errors properly', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
model_parameters: 'invalid-not-an-object', // Should be an object
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
details: expect.any(Array),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Mass Assignment Attack Scenarios', () => {
|
||||
test('should prevent setting system fields during creation', async () => {
|
||||
const systemFields = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'System Fields Test',
|
||||
|
||||
// System fields that should never be settable by users
|
||||
__v: 99,
|
||||
_id: new mongoose.Types.ObjectId(),
|
||||
versions: [
|
||||
{
|
||||
name: 'Fake Version',
|
||||
provider: 'fake',
|
||||
model: 'fake-model',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockReq.body = systemFields;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify system fields were not affected
|
||||
expect(createdAgent.__v).not.toBe(99);
|
||||
expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version
|
||||
expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation
|
||||
expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.__v).not.toBe(99);
|
||||
});
|
||||
|
||||
test('should prevent privilege escalation through isCollaborative', async () => {
|
||||
// Create a non-collaborative agent
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// Try to make it collaborative as a different user
|
||||
const attackerId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = attackerId;
|
||||
mockReq.params.id = agent.id;
|
||||
mockReq.body = {
|
||||
isCollaborative: true, // Trying to escalate privileges
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
// Should be rejected
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
|
||||
// Verify in database that it's still not collaborative
|
||||
const agentInDb = await Agent.findOne({ id: agent.id });
|
||||
expect(agentInDb.isCollaborative).toBe(false);
|
||||
});
|
||||
|
||||
test('should prevent author hijacking', async () => {
|
||||
const originalAuthorId = new mongoose.Types.ObjectId();
|
||||
const attackerId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Admin creates an agent
|
||||
mockReq.user.id = originalAuthorId.toString();
|
||||
mockReq.user.role = 'ADMIN';
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Admin Agent',
|
||||
author: attackerId.toString(), // Trying to set different author
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Author should be the actual user, not the attempted value
|
||||
expect(createdAgent.author.toString()).toBe(originalAuthorId.toString());
|
||||
expect(createdAgent.author.toString()).not.toBe(attackerId.toString());
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.author.toString()).toBe(originalAuthorId.toString());
|
||||
});
|
||||
|
||||
test('should strip unknown fields to prevent future vulnerabilities', async () => {
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Future Proof Test',
|
||||
|
||||
// Unknown fields that might be added in future
|
||||
superAdminAccess: true,
|
||||
bypassAllChecks: true,
|
||||
internalFlag: 'secret',
|
||||
futureFeature: 'exploit',
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unknown fields were stripped
|
||||
expect(createdAgent.superAdminAccess).toBeUndefined();
|
||||
expect(createdAgent.bypassAllChecks).toBeUndefined();
|
||||
expect(createdAgent.internalFlag).toBeUndefined();
|
||||
expect(createdAgent.futureFeature).toBeUndefined();
|
||||
|
||||
// Also check in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean();
|
||||
expect(agentInDb.superAdminAccess).toBeUndefined();
|
||||
expect(agentInDb.bypassAllChecks).toBeUndefined();
|
||||
expect(agentInDb.internalFlag).toBeUndefined();
|
||||
expect(agentInDb.futureFeature).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,14 +1,13 @@
|
||||
const { nanoid } = require('nanoid');
|
||||
const { EnvVar } = require('@librechat/agents');
|
||||
const { checkAccess } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { checkAccess, loadWebSearchAuth } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
AuthType,
|
||||
Permissions,
|
||||
ToolCallTypes,
|
||||
PermissionTypes,
|
||||
loadWebSearchAuth,
|
||||
} = require('librechat-data-provider');
|
||||
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
|
||||
@@ -16,7 +16,7 @@ const { connectDb, indexSync } = require('~/db');
|
||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const errorController = require('./controllers/ErrorController');
|
||||
const initializeMCP = require('./services/initializeMCP');
|
||||
const initializeMCPs = require('./services/initializeMCPs');
|
||||
const configureSocialLogins = require('./socialLogins');
|
||||
const AppService = require('./services/AppService');
|
||||
const staticCache = require('./utils/staticCache');
|
||||
@@ -55,7 +55,6 @@ const startServer = async () => {
|
||||
|
||||
/* Middleware */
|
||||
app.use(noIndex);
|
||||
app.use(errorController);
|
||||
app.use(express.json({ limit: '3mb' }));
|
||||
app.use(express.urlencoded({ extended: true, limit: '3mb' }));
|
||||
app.use(mongoSanitize());
|
||||
@@ -121,6 +120,9 @@ const startServer = async () => {
|
||||
app.use('/api/tags', routes.tags);
|
||||
app.use('/api/mcp', routes.mcp);
|
||||
|
||||
// Add the error controller one more time after all routes
|
||||
app.use(errorController);
|
||||
|
||||
app.use((req, res) => {
|
||||
res.set({
|
||||
'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',
|
||||
@@ -144,7 +146,7 @@ const startServer = async () => {
|
||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
initializeMCP(app);
|
||||
initializeMCPs(app);
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const request = require('supertest');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const mongoose = require('mongoose');
|
||||
@@ -59,6 +58,30 @@ describe('Server Configuration', () => {
|
||||
expect(response.headers['pragma']).toBe('no-cache');
|
||||
expect(response.headers['expires']).toBe('0');
|
||||
});
|
||||
|
||||
it('should return 500 for unknown errors via ErrorController', async () => {
|
||||
// Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
|
||||
|
||||
// Mock MongoDB operations to fail
|
||||
const originalFindOne = mongoose.models.User.findOne;
|
||||
const mockError = new Error('MongoDB operation failed');
|
||||
mongoose.models.User.findOne = jest.fn().mockImplementation(() => {
|
||||
throw mockError;
|
||||
});
|
||||
|
||||
try {
|
||||
const response = await request(app).post('/api/auth/login').send({
|
||||
email: 'test@example.com',
|
||||
password: 'password123',
|
||||
});
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
expect(response.text).toBe('An unknown error occurred.');
|
||||
} finally {
|
||||
// Restore original function
|
||||
mongoose.models.User.findOne = originalFindOne;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
const { handleError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
EndpointURLs,
|
||||
@@ -14,7 +15,6 @@ const openAI = require('~/server/services/Endpoints/openAI');
|
||||
const agents = require('~/server/services/Endpoints/agents');
|
||||
const custom = require('~/server/services/Endpoints/custom');
|
||||
const google = require('~/server/services/Endpoints/google');
|
||||
const { handleError } = require('~/server/utils');
|
||||
|
||||
const buildFunction = {
|
||||
[EModelEndpoint.openAI]: openAI.buildOptions,
|
||||
|
||||
@@ -18,7 +18,6 @@ const message = 'Your account has been temporarily banned due to violations of o
|
||||
* @function
|
||||
* @param {Object} req - Express Request object.
|
||||
* @param {Object} res - Express Response object.
|
||||
* @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request.
|
||||
*
|
||||
* @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
|
||||
*/
|
||||
@@ -135,6 +134,7 @@ const checkBan = async (req, res, next = () => {}) => {
|
||||
return await banResponse(req, res);
|
||||
} catch (error) {
|
||||
logger.error('Error in checkBan middleware:', error);
|
||||
return next(error);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { Time, CacheKeys } = require('librechat-data-provider');
|
||||
const { Time, CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { logViolation, getLogStores } = require('~/cache');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
@@ -37,7 +37,7 @@ const concurrentLimiter = async (req, res, next) => {
|
||||
|
||||
const userId = req.user?.id ?? req.user?._id ?? '';
|
||||
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
|
||||
const type = 'concurrent';
|
||||
const type = ViolationTypes.CONCURRENT;
|
||||
|
||||
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
|
||||
const pendingRequests = +((await cache.get(key)) ?? 0);
|
||||
|
||||
79
api/server/middleware/limiters/forkLimiters.js
Normal file
79
api/server/middleware/limiters/forkLimiters.js
Normal file
@@ -0,0 +1,79 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const FORK_IP_MAX = parseInt(process.env.FORK_IP_MAX) || 30;
|
||||
const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1;
|
||||
const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7;
|
||||
const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1;
|
||||
const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE;
|
||||
|
||||
const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000;
|
||||
const forkIpMax = FORK_IP_MAX;
|
||||
const forkIpWindowInMinutes = forkIpWindowMs / 60000;
|
||||
|
||||
const forkUserWindowMs = FORK_USER_WINDOW * 60 * 1000;
|
||||
const forkUserMax = FORK_USER_MAX;
|
||||
const forkUserWindowInMinutes = forkUserWindowMs / 60000;
|
||||
|
||||
return {
|
||||
forkIpWindowMs,
|
||||
forkIpMax,
|
||||
forkIpWindowInMinutes,
|
||||
forkUserWindowMs,
|
||||
forkUserMax,
|
||||
forkUserWindowInMinutes,
|
||||
forkViolationScore: FORK_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createForkHandler = (ip = true) => {
|
||||
const {
|
||||
forkIpMax,
|
||||
forkUserMax,
|
||||
forkViolationScore,
|
||||
forkIpWindowInMinutes,
|
||||
forkUserWindowInMinutes,
|
||||
} = getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: ip ? forkIpMax : forkUserMax,
|
||||
limiter: ip ? 'ip' : 'user',
|
||||
windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, forkViolationScore);
|
||||
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
|
||||
};
|
||||
};
|
||||
|
||||
const createForkLimiters = () => {
|
||||
const { forkIpWindowMs, forkIpMax, forkUserWindowMs, forkUserMax } = getEnvironmentVariables();
|
||||
|
||||
const ipLimiterOptions = {
|
||||
windowMs: forkIpWindowMs,
|
||||
max: forkIpMax,
|
||||
handler: createForkHandler(),
|
||||
store: limiterCache('fork_ip_limiter'),
|
||||
};
|
||||
const userLimiterOptions = {
|
||||
windowMs: forkUserWindowMs,
|
||||
max: forkUserMax,
|
||||
handler: createForkHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id;
|
||||
},
|
||||
store: limiterCache('fork_user_limiter'),
|
||||
};
|
||||
|
||||
const forkIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const forkUserLimiter = rateLimit(userLimiterOptions);
|
||||
return { forkIpLimiter, forkUserLimiter };
|
||||
};
|
||||
|
||||
module.exports = { createForkLimiters };
|
||||
@@ -1,16 +1,14 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
|
||||
const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
|
||||
const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
|
||||
const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
|
||||
const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE;
|
||||
|
||||
const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
|
||||
const importIpMax = IMPORT_IP_MAX;
|
||||
@@ -27,12 +25,18 @@ const getEnvironmentVariables = () => {
|
||||
importUserWindowMs,
|
||||
importUserMax,
|
||||
importUserWindowInMinutes,
|
||||
importViolationScore: IMPORT_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createImportHandler = (ip = true) => {
|
||||
const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } =
|
||||
getEnvironmentVariables();
|
||||
const {
|
||||
importIpMax,
|
||||
importUserMax,
|
||||
importViolationScore,
|
||||
importIpWindowInMinutes,
|
||||
importUserWindowInMinutes,
|
||||
} = getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
||||
@@ -43,7 +47,7 @@ const createImportHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, importViolationScore);
|
||||
res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
|
||||
};
|
||||
};
|
||||
@@ -56,6 +60,7 @@ const createImportLimiters = () => {
|
||||
windowMs: importIpWindowMs,
|
||||
max: importIpMax,
|
||||
handler: createImportHandler(),
|
||||
store: limiterCache('import_ip_limiter'),
|
||||
};
|
||||
const userLimiterOptions = {
|
||||
windowMs: importUserWindowMs,
|
||||
@@ -64,23 +69,9 @@ const createImportLimiters = () => {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
store: limiterCache('import_user_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for import rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'import_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'import_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const importIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const importUserLimiter = rateLimit(userLimiterOptions);
|
||||
return { importIpLimiter, importUserLimiter };
|
||||
|
||||
@@ -4,6 +4,7 @@ const createSTTLimiters = require('./sttLimiters');
|
||||
const loginLimiter = require('./loginLimiter');
|
||||
const importLimiters = require('./importLimiters');
|
||||
const uploadLimiters = require('./uploadLimiters');
|
||||
const forkLimiters = require('./forkLimiters');
|
||||
const registerLimiter = require('./registerLimiter');
|
||||
const toolCallLimiter = require('./toolCallLimiter');
|
||||
const messageLimiters = require('./messageLimiters');
|
||||
@@ -14,6 +15,7 @@ module.exports = {
|
||||
...uploadLimiters,
|
||||
...importLimiters,
|
||||
...messageLimiters,
|
||||
...forkLimiters,
|
||||
loginLimiter,
|
||||
registerLimiter,
|
||||
toolCallLimiter,
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
|
||||
const windowMs = LOGIN_WINDOW * 60 * 1000;
|
||||
@@ -12,7 +11,7 @@ const windowInMinutes = windowMs / 60000;
|
||||
const message = `Too many login attempts, please try again after ${windowInMinutes} minutes.`;
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const type = 'logins';
|
||||
const type = ViolationTypes.LOGINS;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max,
|
||||
@@ -28,17 +27,9 @@ const limiterOptions = {
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
store: limiterCache('login_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for login rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'login_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const loginLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = loginLimiter;
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const denyRequest = require('~/server/middleware/denyRequest');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
MESSAGE_IP_MAX = 40,
|
||||
MESSAGE_IP_WINDOW = 1,
|
||||
MESSAGE_USER_MAX = 40,
|
||||
MESSAGE_USER_WINDOW = 1,
|
||||
MESSAGE_VIOLATION_SCORE: score,
|
||||
} = process.env;
|
||||
|
||||
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
|
||||
@@ -31,7 +30,7 @@ const userWindowInMinutes = userWindowMs / 60000;
|
||||
*/
|
||||
const createHandler = (ip = true) => {
|
||||
return async (req, res) => {
|
||||
const type = 'message_limit';
|
||||
const type = ViolationTypes.MESSAGE_LIMIT;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: ip ? ipMax : userMax,
|
||||
@@ -39,7 +38,7 @@ const createHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
return await denyRequest(req, res, errorMessage);
|
||||
};
|
||||
};
|
||||
@@ -51,6 +50,7 @@ const ipLimiterOptions = {
|
||||
windowMs: ipWindowMs,
|
||||
max: ipMax,
|
||||
handler: createHandler(),
|
||||
store: limiterCache('message_ip_limiter'),
|
||||
};
|
||||
|
||||
const userLimiterOptions = {
|
||||
@@ -60,23 +60,9 @@ const userLimiterOptions = {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
store: limiterCache('message_user_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for message rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'message_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'message_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Message request rate limiter by IP
|
||||
*/
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
|
||||
const windowMs = REGISTER_WINDOW * 60 * 1000;
|
||||
@@ -12,7 +11,7 @@ const windowInMinutes = windowMs / 60000;
|
||||
const message = `Too many accounts created, please try again after ${windowInMinutes} minutes`;
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const type = 'registrations';
|
||||
const type = ViolationTypes.REGISTRATIONS;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max,
|
||||
@@ -28,17 +27,9 @@ const limiterOptions = {
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
store: limiterCache('register_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for register rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'register_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const registerLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = registerLimiter;
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
RESET_PASSWORD_WINDOW = 2,
|
||||
@@ -33,17 +31,9 @@ const limiterOptions = {
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
store: limiterCache('reset_password_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for reset password rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'reset_password_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const resetPasswordLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = resetPasswordLimiter;
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
|
||||
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
|
||||
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
|
||||
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
|
||||
const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE;
|
||||
|
||||
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
|
||||
const sttIpMax = STT_IP_MAX;
|
||||
@@ -27,11 +25,12 @@ const getEnvironmentVariables = () => {
|
||||
sttUserWindowMs,
|
||||
sttUserMax,
|
||||
sttUserWindowInMinutes,
|
||||
sttViolationScore: STT_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createSTTHandler = (ip = true) => {
|
||||
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
|
||||
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
@@ -43,7 +42,7 @@ const createSTTHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, sttViolationScore);
|
||||
res.status(429).json({ message: 'Too many STT requests. Try again later' });
|
||||
};
|
||||
};
|
||||
@@ -55,6 +54,7 @@ const createSTTLimiters = () => {
|
||||
windowMs: sttIpWindowMs,
|
||||
max: sttIpMax,
|
||||
handler: createSTTHandler(),
|
||||
store: limiterCache('stt_ip_limiter'),
|
||||
};
|
||||
|
||||
const userLimiterOptions = {
|
||||
@@ -64,23 +64,9 @@ const createSTTLimiters = () => {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
store: limiterCache('stt_user_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for STT rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'stt_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'stt_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const sttIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const sttUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { TOOL_CALL_VIOLATION_SCORE: score } = process.env;
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const type = ViolationTypes.TOOL_CALL_LIMIT;
|
||||
@@ -15,7 +14,7 @@ const handler = async (req, res) => {
|
||||
windowInMinutes: 1,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, 0);
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
|
||||
};
|
||||
|
||||
@@ -26,17 +25,9 @@ const limiterOptions = {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id;
|
||||
},
|
||||
store: limiterCache('tool_call_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for tool call rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'tool_call_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const toolCallLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = toolCallLimiter;
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
|
||||
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
|
||||
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
|
||||
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
|
||||
const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE;
|
||||
|
||||
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
|
||||
const ttsIpMax = TTS_IP_MAX;
|
||||
@@ -27,11 +25,12 @@ const getEnvironmentVariables = () => {
|
||||
ttsUserWindowMs,
|
||||
ttsUserMax,
|
||||
ttsUserWindowInMinutes,
|
||||
ttsViolationScore: TTS_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createTTSHandler = (ip = true) => {
|
||||
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
|
||||
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
@@ -43,7 +42,7 @@ const createTTSHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, ttsViolationScore);
|
||||
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
|
||||
};
|
||||
};
|
||||
@@ -55,32 +54,19 @@ const createTTSLimiters = () => {
|
||||
windowMs: ttsIpWindowMs,
|
||||
max: ttsIpMax,
|
||||
handler: createTTSHandler(),
|
||||
store: limiterCache('tts_ip_limiter'),
|
||||
};
|
||||
|
||||
const userLimiterOptions = {
|
||||
windowMs: ttsUserWindowMs,
|
||||
max: ttsUserMax,
|
||||
handler: createTTSHandler(false),
|
||||
store: limiterCache('tts_user_limiter'),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for TTS rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'tts_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'tts_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const ttsIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const ttsUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100;
|
||||
const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
|
||||
const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
|
||||
const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
|
||||
const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE;
|
||||
|
||||
const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
|
||||
const fileUploadIpMax = FILE_UPLOAD_IP_MAX;
|
||||
@@ -27,6 +25,7 @@ const getEnvironmentVariables = () => {
|
||||
fileUploadUserWindowMs,
|
||||
fileUploadUserMax,
|
||||
fileUploadUserWindowInMinutes,
|
||||
fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -36,6 +35,7 @@ const createFileUploadHandler = (ip = true) => {
|
||||
fileUploadIpWindowInMinutes,
|
||||
fileUploadUserMax,
|
||||
fileUploadUserWindowInMinutes,
|
||||
fileUploadViolationScore,
|
||||
} = getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
@@ -47,7 +47,7 @@ const createFileUploadHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
await logViolation(req, res, type, errorMessage, fileUploadViolationScore);
|
||||
res.status(429).json({ message: 'Too many file upload requests. Try again later' });
|
||||
};
|
||||
};
|
||||
@@ -60,6 +60,7 @@ const createFileLimiters = () => {
|
||||
windowMs: fileUploadIpWindowMs,
|
||||
max: fileUploadIpMax,
|
||||
handler: createFileUploadHandler(),
|
||||
store: limiterCache('file_upload_ip_limiter'),
|
||||
};
|
||||
|
||||
const userLimiterOptions = {
|
||||
@@ -69,23 +70,9 @@ const createFileLimiters = () => {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
store: limiterCache('file_upload_user_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for file upload rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'file_upload_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'file_upload_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const fileUploadIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const fileUploadUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
VERIFY_EMAIL_WINDOW = 2,
|
||||
@@ -33,17 +31,9 @@ const limiterOptions = {
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
store: limiterCache('verify_email_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for verify email rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'verify_email_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const verifyEmailLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = verifyEmailLimiter;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
const uap = require('ua-parser-js');
|
||||
const { handleError } = require('../utils');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { handleError } = require('@librechat/api');
|
||||
const { logViolation } = require('../../cache');
|
||||
|
||||
/**
|
||||
@@ -21,7 +22,7 @@ async function uaParser(req, res, next) {
|
||||
const ua = uap(req.headers['user-agent']);
|
||||
|
||||
if (!ua.browser.name) {
|
||||
const type = 'non_browser';
|
||||
const type = ViolationTypes.NON_BROWSER;
|
||||
await logViolation(req, res, type, { type }, score);
|
||||
return handleError(res, { message: 'Illegal request' });
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { handleError } = require('../utils');
|
||||
const { handleError } = require('@librechat/api');
|
||||
|
||||
function validateEndpoint(req, res, next) {
|
||||
const { endpoint: _endpoint, endpointType } = req.body;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { handleError } = require('@librechat/api');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
||||
const { handleError } = require('~/server/utils');
|
||||
const { logViolation } = require('~/cache');
|
||||
/**
|
||||
* Validates the model of the request.
|
||||
|
||||
162
api/server/routes/__tests__/static.spec.js
Normal file
162
api/server/routes/__tests__/static.spec.js
Normal file
@@ -0,0 +1,162 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const zlib = require('zlib');
|
||||
|
||||
// Create test setup
|
||||
const mockTestDir = path.join(__dirname, 'test-static-route');
|
||||
|
||||
// Mock the paths module to point to our test directory
|
||||
jest.mock('~/config/paths', () => ({
|
||||
imageOutput: mockTestDir,
|
||||
}));
|
||||
|
||||
describe('Static Route Integration', () => {
|
||||
let app;
|
||||
let staticRoute;
|
||||
let testDir;
|
||||
let testImagePath;
|
||||
|
||||
beforeAll(() => {
|
||||
// Create a test directory and files
|
||||
testDir = mockTestDir;
|
||||
testImagePath = path.join(testDir, 'test-image.jpg');
|
||||
|
||||
if (!fs.existsSync(testDir)) {
|
||||
fs.mkdirSync(testDir, { recursive: true });
|
||||
}
|
||||
|
||||
// Create a test image file
|
||||
fs.writeFileSync(testImagePath, 'fake-image-data');
|
||||
|
||||
// Create a gzipped version of the test image (for gzip scanning tests)
|
||||
fs.writeFileSync(testImagePath + '.gz', zlib.gzipSync('fake-image-data'));
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
// Clean up test files
|
||||
if (fs.existsSync(testDir)) {
|
||||
fs.rmSync(testDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
// Helper function to set up static route with specific config
|
||||
const setupStaticRoute = (skipGzipScan = false) => {
|
||||
if (skipGzipScan) {
|
||||
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
|
||||
} else {
|
||||
process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN = 'true';
|
||||
}
|
||||
|
||||
staticRoute = require('../static');
|
||||
app.use('/images', staticRoute);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
// Clear the module cache to get fresh imports
|
||||
jest.resetModules();
|
||||
|
||||
app = express();
|
||||
|
||||
// Clear environment variables
|
||||
delete process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN;
|
||||
delete process.env.NODE_ENV;
|
||||
});
|
||||
|
||||
describe('route functionality', () => {
|
||||
it('should serve static image files', async () => {
|
||||
process.env.NODE_ENV = 'production';
|
||||
setupStaticRoute();
|
||||
|
||||
const response = await request(app).get('/images/test-image.jpg').expect(200);
|
||||
|
||||
expect(response.body.toString()).toBe('fake-image-data');
|
||||
});
|
||||
|
||||
it('should return 404 for non-existent files', async () => {
|
||||
setupStaticRoute();
|
||||
|
||||
const response = await request(app).get('/images/nonexistent.jpg');
|
||||
expect(response.status).toBe(404);
|
||||
});
|
||||
});
|
||||
|
||||
describe('cache behavior', () => {
|
||||
it('should set cache headers for images in production', async () => {
|
||||
process.env.NODE_ENV = 'production';
|
||||
setupStaticRoute();
|
||||
|
||||
const response = await request(app).get('/images/test-image.jpg').expect(200);
|
||||
|
||||
expect(response.headers['cache-control']).toBe('public, max-age=172800, s-maxage=86400');
|
||||
});
|
||||
|
||||
it('should not set cache headers in development', async () => {
|
||||
process.env.NODE_ENV = 'development';
|
||||
setupStaticRoute();
|
||||
|
||||
const response = await request(app).get('/images/test-image.jpg').expect(200);
|
||||
|
||||
// Our middleware should not set the production cache-control header in development
|
||||
expect(response.headers['cache-control']).not.toBe('public, max-age=172800, s-maxage=86400');
|
||||
});
|
||||
});
|
||||
|
||||
describe('gzip compression behavior', () => {
|
||||
beforeEach(() => {
|
||||
process.env.NODE_ENV = 'production';
|
||||
});
|
||||
|
||||
it('should serve gzipped files when gzip scanning is enabled', async () => {
|
||||
setupStaticRoute(false); // Enable gzip scanning
|
||||
|
||||
const response = await request(app)
|
||||
.get('/images/test-image.jpg')
|
||||
.set('Accept-Encoding', 'gzip')
|
||||
.expect(200);
|
||||
|
||||
expect(response.headers['content-encoding']).toBe('gzip');
|
||||
expect(response.body.toString()).toBe('fake-image-data');
|
||||
});
|
||||
|
||||
it('should not serve gzipped files when gzip scanning is disabled', async () => {
|
||||
setupStaticRoute(true); // Disable gzip scanning
|
||||
|
||||
const response = await request(app)
|
||||
.get('/images/test-image.jpg')
|
||||
.set('Accept-Encoding', 'gzip')
|
||||
.expect(200);
|
||||
|
||||
expect(response.headers['content-encoding']).toBeUndefined();
|
||||
expect(response.body.toString()).toBe('fake-image-data');
|
||||
});
|
||||
});
|
||||
|
||||
describe('path configuration', () => {
|
||||
it('should use the configured imageOutput path', async () => {
|
||||
setupStaticRoute();
|
||||
|
||||
const response = await request(app).get('/images/test-image.jpg').expect(200);
|
||||
|
||||
expect(response.body.toString()).toBe('fake-image-data');
|
||||
});
|
||||
|
||||
it('should serve from subdirectories', async () => {
|
||||
// Create a subdirectory with a file
|
||||
const subDir = path.join(testDir, 'thumbs');
|
||||
fs.mkdirSync(subDir, { recursive: true });
|
||||
const thumbPath = path.join(subDir, 'thumb.jpg');
|
||||
fs.writeFileSync(thumbPath, 'thumbnail-data');
|
||||
|
||||
setupStaticRoute();
|
||||
|
||||
const response = await request(app).get('/images/thumbs/thumb.jpg').expect(200);
|
||||
|
||||
expect(response.body.toString()).toBe('thumbnail-data');
|
||||
|
||||
// Clean up
|
||||
fs.rmSync(subDir, { recursive: true, force: true });
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -3,6 +3,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getLdapConfig } = require('~/server/services/Config/ldap');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { getLogStores } = require('~/cache');
|
||||
@@ -102,10 +103,16 @@ router.get('/', async function (req, res) {
|
||||
payload.mcpServers = {};
|
||||
const config = await getCustomConfig();
|
||||
if (config?.mcpServers != null) {
|
||||
const mcpManager = getMCPManager();
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
|
||||
for (const serverName in config.mcpServers) {
|
||||
const serverConfig = config.mcpServers[serverName];
|
||||
payload.mcpServers[serverName] = {
|
||||
customUserVars: serverConfig?.customUserVars || {},
|
||||
chatMenu: serverConfig?.chatMenu,
|
||||
isOAuth: oauthServers.has(serverName),
|
||||
startup: serverConfig?.startup,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
const multer = require('multer');
|
||||
const express = require('express');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
|
||||
const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork');
|
||||
const { createImportLimiters, createForkLimiters } = require('~/server/middleware');
|
||||
const { storage, importFileFilter } = require('~/server/routes/files/multer');
|
||||
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||
const { importConversations } = require('~/server/utils/import');
|
||||
const { createImportLimiters } = require('~/server/middleware');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
const { isEnabled, sleep } = require('~/server/utils');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const assistantClients = {
|
||||
[EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'),
|
||||
@@ -43,6 +44,7 @@ router.get('/', async (req, res) => {
|
||||
});
|
||||
res.status(200).json(result);
|
||||
} catch (error) {
|
||||
logger.error('Error fetching conversations', error);
|
||||
res.status(500).json({ error: 'Error fetching conversations' });
|
||||
}
|
||||
});
|
||||
@@ -156,6 +158,7 @@ router.post('/update', async (req, res) => {
|
||||
});
|
||||
|
||||
const { importIpLimiter, importUserLimiter } = createImportLimiters();
|
||||
const { forkIpLimiter, forkUserLimiter } = createForkLimiters();
|
||||
const upload = multer({ storage: storage, fileFilter: importFileFilter });
|
||||
|
||||
/**
|
||||
@@ -189,7 +192,7 @@ router.post(
|
||||
* @param {express.Response<TForkConvoResponse>} res - Express response object.
|
||||
* @returns {Promise<void>} - The response after forking the conversation.
|
||||
*/
|
||||
router.post('/fork', async (req, res) => {
|
||||
router.post('/fork', forkIpLimiter, forkUserLimiter, async (req, res) => {
|
||||
try {
|
||||
/** @type {TForkConvoRequest} */
|
||||
const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body;
|
||||
|
||||
282
api/server/routes/files/files.agents.test.js
Normal file
282
api/server/routes/files/files.agents.test.js
Normal file
@@ -0,0 +1,282 @@
|
||||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
|
||||
// Mock dependencies
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
processDeleteRequest: jest.fn().mockResolvedValue({}),
|
||||
filterFile: jest.fn(),
|
||||
processFileUpload: jest.fn(),
|
||||
processAgentFileUpload: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(() => ({})),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/assistants/helpers', () => ({
|
||||
getOpenAIClient: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||
refreshS3FileUrls: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(() => ({
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const { createFile } = require('~/models/File');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
|
||||
// Import the router after mocks
|
||||
const router = require('./files');
|
||||
|
||||
describe('File Routes - Agent Files Endpoint', () => {
|
||||
let app;
|
||||
let mongoServer;
|
||||
let authorId;
|
||||
let otherUserId;
|
||||
let agentId;
|
||||
let fileId1;
|
||||
let fileId2;
|
||||
let fileId3;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
|
||||
// Initialize models
|
||||
require('~/db/models');
|
||||
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Mock authentication middleware
|
||||
app.use((req, res, next) => {
|
||||
req.user = { id: otherUserId || 'default-user' };
|
||||
req.app = { locals: {} };
|
||||
next();
|
||||
});
|
||||
|
||||
app.use('/files', router);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Clear database
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
|
||||
authorId = new mongoose.Types.ObjectId().toString();
|
||||
otherUserId = new mongoose.Types.ObjectId().toString();
|
||||
agentId = uuidv4();
|
||||
fileId1 = uuidv4();
|
||||
fileId2 = uuidv4();
|
||||
fileId3 = uuidv4();
|
||||
|
||||
// Create files
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: fileId1,
|
||||
filename: 'agent-file1.txt',
|
||||
filepath: `/uploads/${authorId}/${fileId1}`,
|
||||
bytes: 1024,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: fileId2,
|
||||
filename: 'agent-file2.txt',
|
||||
filepath: `/uploads/${authorId}/${fileId2}`,
|
||||
bytes: 2048,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
await createFile({
|
||||
user: otherUserId,
|
||||
file_id: fileId3,
|
||||
filename: 'user-file.txt',
|
||||
filepath: `/uploads/${otherUserId}/${fileId3}`,
|
||||
bytes: 512,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
// Create an agent with files attached
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId1, fileId2],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Share the agent globally
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
if (globalProject) {
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
|
||||
}
|
||||
});
|
||||
|
||||
describe('GET /files/agent/:agent_id', () => {
|
||||
it('should return files accessible through the agent for non-author', async () => {
|
||||
const response = await request(app).get(`/files/agent/${agentId}`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toHaveLength(2); // Only agent files, not user-owned files
|
||||
|
||||
const fileIds = response.body.map((f) => f.file_id);
|
||||
expect(fileIds).toContain(fileId1);
|
||||
expect(fileIds).toContain(fileId2);
|
||||
expect(fileIds).not.toContain(fileId3); // User's own file not included
|
||||
});
|
||||
|
||||
it('should return 400 when agent_id is not provided', async () => {
|
||||
const response = await request(app).get('/files/agent/');
|
||||
|
||||
expect(response.status).toBe(404); // Express returns 404 for missing route parameter
|
||||
});
|
||||
|
||||
it('should return empty array for non-existent agent', async () => {
|
||||
const response = await request(app).get('/files/agent/non-existent-agent');
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toEqual([]); // Empty array for non-existent agent
|
||||
});
|
||||
|
||||
it('should return empty array when agent is not collaborative', async () => {
|
||||
// Create a non-collaborative agent
|
||||
const nonCollabAgentId = uuidv4();
|
||||
await createAgent({
|
||||
id: nonCollabAgentId,
|
||||
name: 'Non-Collaborative Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: false,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId1],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Share it globally
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
if (globalProject) {
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: nonCollabAgentId }, { projectIds: [globalProject._id] });
|
||||
}
|
||||
|
||||
const response = await request(app).get(`/files/agent/${nonCollabAgentId}`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toEqual([]); // Empty array when not collaborative
|
||||
});
|
||||
|
||||
it('should return agent files for agent author', async () => {
|
||||
// Create a new app instance with author authentication
|
||||
const authorApp = express();
|
||||
authorApp.use(express.json());
|
||||
authorApp.use((req, res, next) => {
|
||||
req.user = { id: authorId };
|
||||
req.app = { locals: {} };
|
||||
next();
|
||||
});
|
||||
authorApp.use('/files', router);
|
||||
|
||||
const response = await request(authorApp).get(`/files/agent/${agentId}`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toHaveLength(2); // Agent files for author
|
||||
|
||||
const fileIds = response.body.map((f) => f.file_id);
|
||||
expect(fileIds).toContain(fileId1);
|
||||
expect(fileIds).toContain(fileId2);
|
||||
expect(fileIds).not.toContain(fileId3); // User's own file not included
|
||||
});
|
||||
|
||||
it('should return files uploaded by other users to shared agent for author', async () => {
|
||||
// Create a file uploaded by another user
|
||||
const otherUserFileId = uuidv4();
|
||||
const anotherUserId = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
await createFile({
|
||||
user: anotherUserId,
|
||||
file_id: otherUserFileId,
|
||||
filename: 'other-user-file.txt',
|
||||
filepath: `/uploads/${anotherUserId}/${otherUserFileId}`,
|
||||
bytes: 4096,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
// Update agent to include the file uploaded by another user
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId1, fileId2, otherUserFileId],
|
||||
},
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// Create app instance with author authentication
|
||||
const authorApp = express();
|
||||
authorApp.use(express.json());
|
||||
authorApp.use((req, res, next) => {
|
||||
req.user = { id: authorId };
|
||||
req.app = { locals: {} };
|
||||
next();
|
||||
});
|
||||
authorApp.use('/files', router);
|
||||
|
||||
const response = await request(authorApp).get(`/files/agent/${agentId}`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toHaveLength(3); // Including file from another user
|
||||
|
||||
const fileIds = response.body.map((f) => f.file_id);
|
||||
expect(fileIds).toContain(fileId1);
|
||||
expect(fileIds).toContain(fileId2);
|
||||
expect(fileIds).toContain(otherUserFileId); // File uploaded by another user
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -5,6 +5,7 @@ const {
|
||||
Time,
|
||||
isUUID,
|
||||
CacheKeys,
|
||||
Constants,
|
||||
FileSources,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
@@ -16,11 +17,12 @@ const {
|
||||
processDeleteRequest,
|
||||
processAgentFileUpload,
|
||||
} = require('~/server/services/Files/process');
|
||||
const { getFiles, batchUpdateFiles, hasAccessToFilesViaAgent } = require('~/models/File');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
|
||||
const { getFiles, batchUpdateFiles } = require('~/models/File');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { getAssistant } = require('~/models/Assistant');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getLogStores } = require('~/cache');
|
||||
@@ -50,6 +52,68 @@ router.get('/', async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Get files specific to an agent
|
||||
* @route GET /files/agent/:agent_id
|
||||
* @param {string} agent_id - The agent ID to get files for
|
||||
* @returns {Promise<TFile[]>} Array of files attached to the agent
|
||||
*/
|
||||
router.get('/agent/:agent_id', async (req, res) => {
|
||||
try {
|
||||
const { agent_id } = req.params;
|
||||
const userId = req.user.id;
|
||||
|
||||
if (!agent_id) {
|
||||
return res.status(400).json({ error: 'Agent ID is required' });
|
||||
}
|
||||
|
||||
// Get the agent to check ownership and attached files
|
||||
const agent = await getAgent({ id: agent_id });
|
||||
|
||||
if (!agent) {
|
||||
// No agent found, return empty array
|
||||
return res.status(200).json([]);
|
||||
}
|
||||
|
||||
// Check if user has access to the agent
|
||||
if (agent.author.toString() !== userId) {
|
||||
// Non-authors need the agent to be globally shared and collaborative
|
||||
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
|
||||
|
||||
if (
|
||||
!globalProject ||
|
||||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString()) ||
|
||||
!agent.isCollaborative
|
||||
) {
|
||||
return res.status(200).json([]);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all file IDs from agent's tool resources
|
||||
const agentFileIds = [];
|
||||
if (agent.tool_resources) {
|
||||
for (const [, resource] of Object.entries(agent.tool_resources)) {
|
||||
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
|
||||
agentFileIds.push(...resource.file_ids);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no files attached to agent, return empty array
|
||||
if (agentFileIds.length === 0) {
|
||||
return res.status(200).json([]);
|
||||
}
|
||||
|
||||
// Get only the files attached to this agent
|
||||
const files = await getFiles({ file_id: { $in: agentFileIds } }, null, { text: 0 });
|
||||
|
||||
res.status(200).json(files);
|
||||
} catch (error) {
|
||||
logger.error('[/files/agent/:agent_id] Error fetching agent files:', error);
|
||||
res.status(500).json({ error: 'Failed to fetch agent files' });
|
||||
}
|
||||
});
|
||||
|
||||
router.get('/config', async (req, res) => {
|
||||
try {
|
||||
res.status(200).json(req.app.locals.fileConfig);
|
||||
@@ -86,11 +150,62 @@ router.delete('/', async (req, res) => {
|
||||
|
||||
const fileIds = files.map((file) => file.file_id);
|
||||
const dbFiles = await getFiles({ file_id: { $in: fileIds } });
|
||||
const unauthorizedFiles = dbFiles.filter((file) => file.user.toString() !== req.user.id);
|
||||
|
||||
const ownedFiles = [];
|
||||
const nonOwnedFiles = [];
|
||||
const fileMap = new Map();
|
||||
|
||||
for (const file of dbFiles) {
|
||||
fileMap.set(file.file_id, file);
|
||||
if (file.user.toString() === req.user.id) {
|
||||
ownedFiles.push(file);
|
||||
} else {
|
||||
nonOwnedFiles.push(file);
|
||||
}
|
||||
}
|
||||
|
||||
// If all files are owned by the user, no need for further checks
|
||||
if (nonOwnedFiles.length === 0) {
|
||||
await processDeleteRequest({ req, files: ownedFiles });
|
||||
logger.debug(
|
||||
`[/files] Files deleted successfully: ${ownedFiles
|
||||
.filter((f) => f.file_id)
|
||||
.map((f) => f.file_id)
|
||||
.join(', ')}`,
|
||||
);
|
||||
res.status(200).json({ message: 'Files deleted successfully' });
|
||||
return;
|
||||
}
|
||||
|
||||
// Check access for non-owned files
|
||||
let authorizedFiles = [...ownedFiles];
|
||||
let unauthorizedFiles = [];
|
||||
|
||||
if (req.body.agent_id && nonOwnedFiles.length > 0) {
|
||||
// Batch check access for all non-owned files
|
||||
const nonOwnedFileIds = nonOwnedFiles.map((f) => f.file_id);
|
||||
const accessMap = await hasAccessToFilesViaAgent(
|
||||
req.user.id,
|
||||
nonOwnedFileIds,
|
||||
req.body.agent_id,
|
||||
);
|
||||
|
||||
// Separate authorized and unauthorized files
|
||||
for (const file of nonOwnedFiles) {
|
||||
if (accessMap.get(file.file_id)) {
|
||||
authorizedFiles.push(file);
|
||||
} else {
|
||||
unauthorizedFiles.push(file);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No agent context, all non-owned files are unauthorized
|
||||
unauthorizedFiles = nonOwnedFiles;
|
||||
}
|
||||
|
||||
if (unauthorizedFiles.length > 0) {
|
||||
return res.status(403).json({
|
||||
message: 'You can only delete your own files',
|
||||
message: 'You can only delete files you have access to',
|
||||
unauthorizedFiles: unauthorizedFiles.map((f) => f.file_id),
|
||||
});
|
||||
}
|
||||
@@ -131,10 +246,10 @@ router.delete('/', async (req, res) => {
|
||||
.json({ message: 'File associations removed successfully from Azure Assistant' });
|
||||
}
|
||||
|
||||
await processDeleteRequest({ req, files: dbFiles });
|
||||
await processDeleteRequest({ req, files: authorizedFiles });
|
||||
|
||||
logger.debug(
|
||||
`[/files] Files deleted successfully: ${files
|
||||
`[/files] Files deleted successfully: ${authorizedFiles
|
||||
.filter((f) => f.file_id)
|
||||
.map((f) => f.file_id)
|
||||
.join(', ')}`,
|
||||
|
||||
302
api/server/routes/files/files.test.js
Normal file
302
api/server/routes/files/files.test.js
Normal file
@@ -0,0 +1,302 @@
|
||||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
|
||||
// Mock dependencies
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
processDeleteRequest: jest.fn().mockResolvedValue({}),
|
||||
filterFile: jest.fn(),
|
||||
processFileUpload: jest.fn(),
|
||||
processAgentFileUpload: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(() => ({})),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/controllers/assistants/helpers', () => ({
|
||||
getOpenAIClient: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||
refreshS3FileUrls: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(() => ({
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const { createFile } = require('~/models/File');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
|
||||
// Import the router after mocks
|
||||
const router = require('./files');
|
||||
|
||||
describe('File Routes - Delete with Agent Access', () => {
|
||||
let app;
|
||||
let mongoServer;
|
||||
let authorId;
|
||||
let otherUserId;
|
||||
let agentId;
|
||||
let fileId;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
|
||||
// Initialize models
|
||||
require('~/db/models');
|
||||
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Mock authentication middleware
|
||||
app.use((req, res, next) => {
|
||||
req.user = { id: otherUserId || 'default-user' };
|
||||
req.app = { locals: {} };
|
||||
next();
|
||||
});
|
||||
|
||||
app.use('/files', router);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Clear database
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
|
||||
authorId = new mongoose.Types.ObjectId().toString();
|
||||
otherUserId = new mongoose.Types.ObjectId().toString();
|
||||
fileId = uuidv4();
|
||||
|
||||
// Create a file owned by the author
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: fileId,
|
||||
filename: 'test.txt',
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
bytes: 1024,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
// Create an agent with the file attached
|
||||
const agent = await createAgent({
|
||||
id: uuidv4(),
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
agentId = agent.id;
|
||||
|
||||
// Share the agent globally
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
if (globalProject) {
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
|
||||
}
|
||||
});
|
||||
|
||||
describe('DELETE /files', () => {
|
||||
it('should allow deleting files owned by the user', async () => {
|
||||
// Create a file owned by the current user
|
||||
const userFileId = uuidv4();
|
||||
await createFile({
|
||||
user: otherUserId,
|
||||
file_id: userFileId,
|
||||
filename: 'user-file.txt',
|
||||
filepath: `/uploads/${otherUserId}/${userFileId}`,
|
||||
bytes: 1024,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
files: [
|
||||
{
|
||||
file_id: userFileId,
|
||||
filepath: `/uploads/${otherUserId}/${userFileId}`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.message).toBe('Files deleted successfully');
|
||||
expect(processDeleteRequest).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should prevent deleting files not owned by user without agent context', async () => {
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
files: [
|
||||
{
|
||||
file_id: fileId,
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.message).toBe('You can only delete files you have access to');
|
||||
expect(response.body.unauthorizedFiles).toContain(fileId);
|
||||
expect(processDeleteRequest).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow deleting files accessible through shared agent', async () => {
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
agent_id: agentId,
|
||||
files: [
|
||||
{
|
||||
file_id: fileId,
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.message).toBe('Files deleted successfully');
|
||||
expect(processDeleteRequest).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should prevent deleting files not attached to the specified agent', async () => {
|
||||
// Create another file not attached to the agent
|
||||
const unattachedFileId = uuidv4();
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: unattachedFileId,
|
||||
filename: 'unattached.txt',
|
||||
filepath: `/uploads/${authorId}/${unattachedFileId}`,
|
||||
bytes: 1024,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
agent_id: agentId,
|
||||
files: [
|
||||
{
|
||||
file_id: unattachedFileId,
|
||||
filepath: `/uploads/${authorId}/${unattachedFileId}`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.message).toBe('You can only delete files you have access to');
|
||||
expect(response.body.unauthorizedFiles).toContain(unattachedFileId);
|
||||
});
|
||||
|
||||
it('should handle mixed authorized and unauthorized files', async () => {
|
||||
// Create a file owned by the current user
|
||||
const userFileId = uuidv4();
|
||||
await createFile({
|
||||
user: otherUserId,
|
||||
file_id: userFileId,
|
||||
filename: 'user-file.txt',
|
||||
filepath: `/uploads/${otherUserId}/${userFileId}`,
|
||||
bytes: 1024,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
// Create an unauthorized file
|
||||
const unauthorizedFileId = uuidv4();
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: unauthorizedFileId,
|
||||
filename: 'unauthorized.txt',
|
||||
filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
|
||||
bytes: 1024,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
agent_id: agentId,
|
||||
files: [
|
||||
{
|
||||
file_id: fileId, // Authorized through agent
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
},
|
||||
{
|
||||
file_id: userFileId, // Owned by user
|
||||
filepath: `/uploads/${otherUserId}/${userFileId}`,
|
||||
},
|
||||
{
|
||||
file_id: unauthorizedFileId, // Not authorized
|
||||
filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.message).toBe('You can only delete files you have access to');
|
||||
expect(response.body.unauthorizedFiles).toContain(unauthorizedFileId);
|
||||
expect(response.body.unauthorizedFiles).not.toContain(fileId);
|
||||
expect(response.body.unauthorizedFiles).not.toContain(userFileId);
|
||||
});
|
||||
|
||||
it('should prevent deleting files when agent is not collaborative', async () => {
|
||||
// Update the agent to be non-collaborative
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: agentId }, { isCollaborative: false });
|
||||
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
agent_id: agentId,
|
||||
files: [
|
||||
{
|
||||
file_id: fileId,
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.message).toBe('You can only delete files you have access to');
|
||||
expect(response.body.unauthorizedFiles).toContain(fileId);
|
||||
expect(processDeleteRequest).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -477,7 +477,9 @@ describe('Multer Configuration', () => {
|
||||
done(new Error('Expected mkdirSync to throw an error but no error was thrown'));
|
||||
} catch (error) {
|
||||
// This is the expected behavior - mkdirSync throws synchronously for invalid paths
|
||||
expect(error.code).toBe('EACCES');
|
||||
// On Linux, this typically returns EACCES (permission denied)
|
||||
// On macOS/Darwin, this returns ENOENT (no such file or directory)
|
||||
expect(['EACCES', 'ENOENT']).toContain(error.code);
|
||||
done();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
const { Router } = require('express');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { getFlowStateManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const {
|
||||
getMCPServers,
|
||||
getMCPServer,
|
||||
createMCPServer,
|
||||
updateMCPServer,
|
||||
deleteMCPServer,
|
||||
} = require('@librechat/api');
|
||||
|
||||
const router = Router();
|
||||
|
||||
@@ -124,9 +121,73 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager);
|
||||
logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route');
|
||||
|
||||
// For system-level OAuth, we need to store the tokens and retry the connection
|
||||
if (flowState.userId === 'system') {
|
||||
logger.debug(`[MCP OAuth] System-level OAuth completed for ${serverName}`);
|
||||
// Try to establish the MCP connection with the new tokens
|
||||
try {
|
||||
const mcpManager = getMCPManager(flowState.userId);
|
||||
logger.debug(`[MCP OAuth] Attempting to reconnect ${serverName} with new OAuth tokens`);
|
||||
|
||||
// For user-level OAuth, try to establish the connection
|
||||
if (flowState.userId !== 'system') {
|
||||
// We need to get the user object - in this case we'll need to reconstruct it
|
||||
const user = { id: flowState.userId };
|
||||
|
||||
// Try to establish connection with the new tokens
|
||||
const userConnection = await mcpManager.getUserConnection({
|
||||
user,
|
||||
serverName,
|
||||
flowManager,
|
||||
tokenMethods: {
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
deleteTokens,
|
||||
},
|
||||
});
|
||||
|
||||
logger.info(
|
||||
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
|
||||
);
|
||||
|
||||
// Fetch and cache tools now that we have a successful connection
|
||||
const userTools = (await getCachedTools({ userId: flowState.userId })) || {};
|
||||
|
||||
// Remove any old tools from this server in the user's cache
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
for (const key of Object.keys(userTools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
delete userTools[key];
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new tools from this server
|
||||
const tools = await userConnection.fetchTools();
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
userTools[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Save the updated user tool cache
|
||||
await setCachedTools(userTools, { userId: flowState.userId });
|
||||
|
||||
logger.debug(
|
||||
`[MCP OAuth] Cached ${tools.length} tools for ${serverName} user ${flowState.userId}`,
|
||||
);
|
||||
} else {
|
||||
logger.debug(`[MCP OAuth] System-level OAuth completed for ${serverName}`);
|
||||
}
|
||||
} catch (error) {
|
||||
// Don't fail the OAuth callback if reconnection fails - the tokens are still saved
|
||||
logger.warn(
|
||||
`[MCP OAuth] Failed to reconnect ${serverName} after OAuth, but tokens are saved:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
|
||||
/** ID of the flow that the tool/connection is waiting for */
|
||||
@@ -210,43 +271,358 @@ router.get('/oauth/status/:flowId', async (req, res) => {
|
||||
});
|
||||
|
||||
/**
|
||||
* Get all MCP servers for the authenticated user
|
||||
* @route GET /api/mcp
|
||||
* @returns {Array} Array of MCP servers
|
||||
* Cancel OAuth flow
|
||||
* This endpoint cancels a pending OAuth flow
|
||||
*/
|
||||
router.get('/', requireJwtAuth, getMCPServers);
|
||||
router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const user = req.user;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
logger.info(`[MCP OAuth Cancel] Cancelling OAuth flow for ${serverName} by user ${user.id}`);
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
|
||||
// Generate the flow ID for this user/server combination
|
||||
const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName);
|
||||
|
||||
// Check if flow exists
|
||||
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (!flowState) {
|
||||
logger.debug(`[MCP OAuth Cancel] No active flow found for ${serverName}`);
|
||||
return res.json({
|
||||
success: true,
|
||||
message: 'No active OAuth flow to cancel',
|
||||
});
|
||||
}
|
||||
|
||||
// Cancel the flow by marking it as failed
|
||||
await flowManager.completeFlow(flowId, 'mcp_oauth', null, 'User cancelled OAuth flow');
|
||||
|
||||
logger.info(`[MCP OAuth Cancel] Successfully cancelled OAuth flow for ${serverName}`);
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
message: `OAuth flow for ${serverName} cancelled successfully`,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[MCP OAuth Cancel] Failed to cancel OAuth flow', error);
|
||||
res.status(500).json({ error: 'Failed to cancel OAuth flow' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Get a single MCP server by ID
|
||||
* @route GET /api/mcp/:mcp_id
|
||||
* @param {string} mcp_id - The ID of the MCP server to fetch
|
||||
* @returns {object} MCP server data
|
||||
* Reinitialize MCP server
|
||||
* This endpoint allows reinitializing a specific MCP server
|
||||
*/
|
||||
router.get('/:mcp_id', requireJwtAuth, getMCPServer);
|
||||
router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const user = req.user;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
|
||||
|
||||
const printConfig = false;
|
||||
const config = await loadCustomConfig(printConfig);
|
||||
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
});
|
||||
}
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const mcpManager = getMCPManager();
|
||||
|
||||
await mcpManager.disconnectServer(serverName);
|
||||
logger.info(`[MCP Reinitialize] Disconnected existing server: ${serverName}`);
|
||||
|
||||
const serverConfig = config.mcpServers[serverName];
|
||||
mcpManager.mcpConfigs[serverName] = serverConfig;
|
||||
let customUserVars = {};
|
||||
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
||||
for (const varName of Object.keys(serverConfig.customUserVars)) {
|
||||
try {
|
||||
const value = await getUserPluginAuthValue(user.id, varName, false);
|
||||
if (value) {
|
||||
customUserVars[varName] = value;
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(`[MCP Reinitialize] Error fetching ${varName} for user ${user.id}:`, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let userConnection = null;
|
||||
let oauthRequired = false;
|
||||
let oauthUrl = null;
|
||||
|
||||
try {
|
||||
userConnection = await mcpManager.getUserConnection({
|
||||
user,
|
||||
serverName,
|
||||
flowManager,
|
||||
customUserVars,
|
||||
tokenMethods: {
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
deleteTokens,
|
||||
},
|
||||
returnOnOAuth: true, // Return immediately when OAuth is initiated
|
||||
// Add OAuth handlers to capture the OAuth URL when needed
|
||||
oauthStart: async (authURL) => {
|
||||
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
|
||||
oauthUrl = authURL;
|
||||
oauthRequired = true;
|
||||
},
|
||||
});
|
||||
|
||||
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
|
||||
} catch (err) {
|
||||
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
|
||||
logger.info(
|
||||
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
||||
);
|
||||
|
||||
// Check if this is an OAuth error - if so, the flow state should be set up now
|
||||
const isOAuthError =
|
||||
err.message?.includes('OAuth') ||
|
||||
err.message?.includes('authentication') ||
|
||||
err.message?.includes('401');
|
||||
|
||||
const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early';
|
||||
|
||||
if (isOAuthError || oauthRequired || isOAuthFlowInitiated) {
|
||||
logger.info(
|
||||
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
|
||||
);
|
||||
oauthRequired = true;
|
||||
// Don't return error - continue so frontend can handle OAuth
|
||||
} else {
|
||||
logger.error(
|
||||
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
|
||||
err,
|
||||
);
|
||||
return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
|
||||
}
|
||||
}
|
||||
|
||||
// Only fetch and cache tools if we successfully connected (no OAuth required)
|
||||
if (userConnection && !oauthRequired) {
|
||||
const userTools = (await getCachedTools({ userId: user.id })) || {};
|
||||
|
||||
// Remove any old tools from this server in the user's cache
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
for (const key of Object.keys(userTools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
delete userTools[key];
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new tools from this server
|
||||
const tools = await userConnection.fetchTools();
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
userTools[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Save the updated user tool cache
|
||||
await setCachedTools(userTools, { userId: user.id });
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
||||
);
|
||||
|
||||
const getResponseMessage = () => {
|
||||
if (oauthRequired) {
|
||||
return `MCP server '${serverName}' ready for OAuth authentication`;
|
||||
}
|
||||
if (userConnection) {
|
||||
return `MCP server '${serverName}' reinitialized successfully`;
|
||||
}
|
||||
return `Failed to reinitialize MCP server '${serverName}'`;
|
||||
};
|
||||
|
||||
res.json({
|
||||
success: userConnection && !oauthRequired,
|
||||
message: getResponseMessage(),
|
||||
serverName,
|
||||
oauthRequired,
|
||||
oauthUrl,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[MCP Reinitialize] Unexpected error', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Create a new MCP server
|
||||
* @route POST /api/mcp/add
|
||||
* @param {object} req.body - MCP server data
|
||||
* @returns {object} Created MCP server with populated tools
|
||||
* Get connection status for all MCP servers
|
||||
* This endpoint returns all app level and user-scoped connection statuses from MCPManager without disconnecting idle connections
|
||||
*/
|
||||
router.post('/add', requireJwtAuth, createMCPServer);
|
||||
router.get('/connection/status', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const user = req.user;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
|
||||
user.id,
|
||||
);
|
||||
const connectionStatus = {};
|
||||
|
||||
for (const [serverName] of Object.entries(mcpConfig)) {
|
||||
connectionStatus[serverName] = await getServerConnectionStatus(
|
||||
user.id,
|
||||
serverName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
}
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
connectionStatus,
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.message === 'MCP config not found') {
|
||||
return res.status(404).json({ error: error.message });
|
||||
}
|
||||
logger.error('[MCP Connection Status] Failed to get connection status', error);
|
||||
res.status(500).json({ error: 'Failed to get connection status' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Update an existing MCP server
|
||||
* @route PUT /api/mcp/:mcp_id
|
||||
* @param {string} mcp_id - The ID of the MCP server to update
|
||||
* @param {object} req.body - Updated MCP server data
|
||||
* @returns {object} Updated MCP server with populated tools
|
||||
* Get connection status for a single MCP server
|
||||
* This endpoint returns the connection status for a specific server for a given user
|
||||
*/
|
||||
router.put('/:mcp_id', requireJwtAuth, updateMCPServer);
|
||||
router.get('/connection/status/:serverName', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const user = req.user;
|
||||
const { serverName } = req.params;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
if (!serverName) {
|
||||
return res.status(400).json({ error: 'Server name is required' });
|
||||
}
|
||||
|
||||
const { mcpConfig, appConnections, userConnections, oauthServers } = await getMCPSetupData(
|
||||
user.id,
|
||||
);
|
||||
|
||||
if (!mcpConfig[serverName]) {
|
||||
return res
|
||||
.status(404)
|
||||
.json({ error: `MCP server '${serverName}' not found in configuration` });
|
||||
}
|
||||
|
||||
const serverStatus = await getServerConnectionStatus(
|
||||
user.id,
|
||||
serverName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
serverName,
|
||||
connectionStatus: serverStatus.connectionState,
|
||||
requiresOAuth: serverStatus.requiresOAuth,
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.message === 'MCP config not found') {
|
||||
return res.status(404).json({ error: error.message });
|
||||
}
|
||||
logger.error(
|
||||
`[MCP Per-Server Status] Failed to get connection status for ${req.params.serverName}`,
|
||||
error,
|
||||
);
|
||||
res.status(500).json({ error: 'Failed to get connection status' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Delete an MCP server
|
||||
* @route DELETE /api/mcp/:mcp_id
|
||||
* @param {string} mcp_id - The ID of the MCP server to delete
|
||||
* @returns {object} Deletion confirmation
|
||||
* Check which authentication values exist for a specific MCP server
|
||||
* This endpoint returns only boolean flags indicating if values are set, not the actual values
|
||||
*/
|
||||
router.delete('/:mcp_id', requireJwtAuth, deleteMCPServer);
|
||||
router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const { serverName } = req.params;
|
||||
const user = req.user;
|
||||
|
||||
if (!user?.id) {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
const printConfig = false;
|
||||
const config = await loadCustomConfig(printConfig);
|
||||
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
});
|
||||
}
|
||||
|
||||
const serverConfig = config.mcpServers[serverName];
|
||||
const pluginKey = `${Constants.mcp_prefix}${serverName}`;
|
||||
const authValueFlags = {};
|
||||
|
||||
// Check existence of saved values for each custom user variable (don't fetch actual values)
|
||||
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
||||
for (const varName of Object.keys(serverConfig.customUserVars)) {
|
||||
try {
|
||||
const value = await getUserPluginAuthValue(user.id, varName, false, pluginKey);
|
||||
// Only store boolean flag indicating if value exists
|
||||
authValueFlags[varName] = !!(value && value.length > 0);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[MCP Auth Value Flags] Error checking ${varName} for user ${user.id}:`,
|
||||
err,
|
||||
);
|
||||
// Default to false if we can't check
|
||||
authValueFlags[varName] = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
serverName,
|
||||
authValueFlags,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[MCP Auth Value Flags] Failed to check auth value flags for ${req.params.serverName}`,
|
||||
error,
|
||||
);
|
||||
res.status(500).json({ error: 'Failed to check auth value flags' });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -172,40 +172,68 @@ router.patch('/preferences', checkMemoryOptOut, async (req, res) => {
|
||||
/**
|
||||
* PATCH /memories/:key
|
||||
* Updates the value of an existing memory entry for the authenticated user.
|
||||
* Body: { value: string }
|
||||
* Body: { key?: string, value: string }
|
||||
* Returns 200 and { updated: true, memory: <updatedDoc> } when successful.
|
||||
*/
|
||||
router.patch('/:key', checkMemoryUpdate, async (req, res) => {
|
||||
const { key } = req.params;
|
||||
const { value } = req.body || {};
|
||||
const { key: urlKey } = req.params;
|
||||
const { key: bodyKey, value } = req.body || {};
|
||||
|
||||
if (typeof value !== 'string' || value.trim() === '') {
|
||||
return res.status(400).json({ error: 'Value is required and must be a non-empty string.' });
|
||||
}
|
||||
|
||||
// Use the key from the body if provided, otherwise use the key from the URL
|
||||
const newKey = bodyKey || urlKey;
|
||||
|
||||
try {
|
||||
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
|
||||
|
||||
const memories = await getAllUserMemories(req.user.id);
|
||||
const existingMemory = memories.find((m) => m.key === key);
|
||||
const existingMemory = memories.find((m) => m.key === urlKey);
|
||||
|
||||
if (!existingMemory) {
|
||||
return res.status(404).json({ error: 'Memory not found.' });
|
||||
}
|
||||
|
||||
const result = await setMemory({
|
||||
userId: req.user.id,
|
||||
key,
|
||||
value,
|
||||
tokenCount,
|
||||
});
|
||||
// If the key is changing, we need to handle it specially
|
||||
if (newKey !== urlKey) {
|
||||
const keyExists = memories.find((m) => m.key === newKey);
|
||||
if (keyExists) {
|
||||
return res.status(409).json({ error: 'Memory with this key already exists.' });
|
||||
}
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(500).json({ error: 'Failed to update memory.' });
|
||||
const createResult = await createMemory({
|
||||
userId: req.user.id,
|
||||
key: newKey,
|
||||
value,
|
||||
tokenCount,
|
||||
});
|
||||
|
||||
if (!createResult.ok) {
|
||||
return res.status(500).json({ error: 'Failed to create new memory.' });
|
||||
}
|
||||
|
||||
const deleteResult = await deleteMemory({ userId: req.user.id, key: urlKey });
|
||||
if (!deleteResult.ok) {
|
||||
return res.status(500).json({ error: 'Failed to delete old memory.' });
|
||||
}
|
||||
} else {
|
||||
// Key is not changing, just update the value
|
||||
const result = await setMemory({
|
||||
userId: req.user.id,
|
||||
key: newKey,
|
||||
value,
|
||||
tokenCount,
|
||||
});
|
||||
|
||||
if (!result.ok) {
|
||||
return res.status(500).json({ error: 'Failed to update memory.' });
|
||||
}
|
||||
}
|
||||
|
||||
const updatedMemories = await getAllUserMemories(req.user.id);
|
||||
const updatedMemory = updatedMemories.find((m) => m.key === key);
|
||||
const updatedMemory = updatedMemories.find((m) => m.key === newKey);
|
||||
|
||||
res.json({ updated: true, memory: updatedMemory });
|
||||
} catch (error) {
|
||||
|
||||
@@ -235,12 +235,13 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) =
|
||||
return res.status(400).json({ error: 'Content part not found' });
|
||||
}
|
||||
|
||||
if (updatedContent[index].type !== ContentTypes.TEXT) {
|
||||
const currentPartType = updatedContent[index].type;
|
||||
if (currentPartType !== ContentTypes.TEXT && currentPartType !== ContentTypes.THINK) {
|
||||
return res.status(400).json({ error: 'Cannot update non-text content' });
|
||||
}
|
||||
|
||||
const oldText = updatedContent[index].text;
|
||||
updatedContent[index] = { type: ContentTypes.TEXT, text };
|
||||
const oldText = updatedContent[index][currentPartType];
|
||||
updatedContent[index] = { type: currentPartType, [currentPartType]: text };
|
||||
|
||||
let tokenCount = message.tokenCount;
|
||||
if (tokenCount !== undefined) {
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
const express = require('express');
|
||||
const staticCache = require('../utils/staticCache');
|
||||
const paths = require('~/config/paths');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
const skipGzipScan = !isEnabled(process.env.ENABLE_IMAGE_OUTPUT_GZIP_SCAN);
|
||||
|
||||
const router = express.Router();
|
||||
router.use(staticCache(paths.imageOutput));
|
||||
router.use(staticCache(paths.imageOutput, { skipGzipScan }));
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
const { agentsConfigSetup, loadWebSearchConfig } = require('@librechat/api');
|
||||
const {
|
||||
FileSources,
|
||||
loadOCRConfig,
|
||||
EModelEndpoint,
|
||||
loadMemoryConfig,
|
||||
getConfigDefaults,
|
||||
loadWebSearchConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { agentsConfigSetup } = require('@librechat/api');
|
||||
const {
|
||||
checkHealth,
|
||||
checkConfig,
|
||||
@@ -158,6 +157,10 @@ const AppService = async (app) => {
|
||||
}
|
||||
});
|
||||
|
||||
if (endpoints?.all) {
|
||||
endpointLocals.all = endpoints.all;
|
||||
}
|
||||
|
||||
app.locals = {
|
||||
...defaultLocals,
|
||||
fileConfig: config?.fileConfig,
|
||||
|
||||
@@ -152,12 +152,14 @@ describe('AppService', () => {
|
||||
filteredTools: undefined,
|
||||
includedTools: undefined,
|
||||
webSearch: {
|
||||
safeSearch: 1,
|
||||
jinaApiKey: '${JINA_API_KEY}',
|
||||
cohereApiKey: '${COHERE_API_KEY}',
|
||||
serperApiKey: '${SERPER_API_KEY}',
|
||||
searxngApiKey: '${SEARXNG_API_KEY}',
|
||||
firecrawlApiKey: '${FIRECRAWL_API_KEY}',
|
||||
firecrawlApiUrl: '${FIRECRAWL_API_URL}',
|
||||
jinaApiKey: '${JINA_API_KEY}',
|
||||
safeSearch: 1,
|
||||
serperApiKey: '${SERPER_API_KEY}',
|
||||
searxngInstanceUrl: '${SEARXNG_INSTANCE_URL}',
|
||||
},
|
||||
memory: undefined,
|
||||
agents: {
|
||||
@@ -541,6 +543,206 @@ describe('AppService', () => {
|
||||
expect(process.env.IMPORT_USER_MAX).toEqual('initialUserMax');
|
||||
expect(process.env.IMPORT_USER_WINDOW).toEqual('initialUserWindow');
|
||||
});
|
||||
|
||||
it('should correctly configure endpoint with titlePrompt, titleMethod, and titlePromptTemplate', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Custom title prompt for conversation',
|
||||
titlePromptTemplate: 'Summarize this conversation: {{conversation}}',
|
||||
},
|
||||
[EModelEndpoint.assistants]: {
|
||||
titleMethod: 'functions',
|
||||
titlePrompt: 'Generate a title for this assistant conversation',
|
||||
titlePromptTemplate: 'Assistant conversation template: {{messages}}',
|
||||
},
|
||||
[EModelEndpoint.azureOpenAI]: {
|
||||
groups: azureGroups,
|
||||
titleConvo: true,
|
||||
titleMethod: 'completion',
|
||||
titleModel: 'gpt-4',
|
||||
titlePrompt: 'Azure title prompt',
|
||||
titlePromptTemplate: 'Azure conversation: {{context}}',
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await AppService(app);
|
||||
|
||||
// Check OpenAI endpoint configuration
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.openAI);
|
||||
expect(app.locals[EModelEndpoint.openAI]).toEqual(
|
||||
expect.objectContaining({
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Custom title prompt for conversation',
|
||||
titlePromptTemplate: 'Summarize this conversation: {{conversation}}',
|
||||
}),
|
||||
);
|
||||
|
||||
// Check Assistants endpoint configuration
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.assistants);
|
||||
expect(app.locals[EModelEndpoint.assistants]).toMatchObject({
|
||||
titleMethod: 'functions',
|
||||
titlePrompt: 'Generate a title for this assistant conversation',
|
||||
titlePromptTemplate: 'Assistant conversation template: {{messages}}',
|
||||
});
|
||||
|
||||
// Check Azure OpenAI endpoint configuration
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.azureOpenAI);
|
||||
expect(app.locals[EModelEndpoint.azureOpenAI]).toEqual(
|
||||
expect.objectContaining({
|
||||
titleConvo: true,
|
||||
titleMethod: 'completion',
|
||||
titleModel: 'gpt-4',
|
||||
titlePrompt: 'Azure title prompt',
|
||||
titlePromptTemplate: 'Azure conversation: {{context}}',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should configure Agent endpoint with title generation settings', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
endpoints: {
|
||||
[EModelEndpoint.agents]: {
|
||||
disableBuilder: false,
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Generate a descriptive title for this agent conversation',
|
||||
titlePromptTemplate: 'Agent conversation summary: {{content}}',
|
||||
recursionLimit: 15,
|
||||
capabilities: [AgentCapabilities.tools, AgentCapabilities.actions],
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
|
||||
expect(app.locals[EModelEndpoint.agents]).toMatchObject({
|
||||
disableBuilder: false,
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Generate a descriptive title for this agent conversation',
|
||||
titlePromptTemplate: 'Agent conversation summary: {{content}}',
|
||||
recursionLimit: 15,
|
||||
capabilities: expect.arrayContaining([AgentCapabilities.tools, AgentCapabilities.actions]),
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle missing title configuration options with defaults', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
titleConvo: true,
|
||||
// titlePrompt and titlePromptTemplate are not provided
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.openAI);
|
||||
expect(app.locals[EModelEndpoint.openAI]).toMatchObject({
|
||||
titleConvo: true,
|
||||
});
|
||||
// Check that the optional fields are undefined when not provided
|
||||
expect(app.locals[EModelEndpoint.openAI].titlePrompt).toBeUndefined();
|
||||
expect(app.locals[EModelEndpoint.openAI].titlePromptTemplate).toBeUndefined();
|
||||
expect(app.locals[EModelEndpoint.openAI].titleMethod).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should correctly configure titleEndpoint when specified', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
endpoints: {
|
||||
[EModelEndpoint.openAI]: {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
titleEndpoint: EModelEndpoint.anthropic,
|
||||
titlePrompt: 'Generate a concise title',
|
||||
},
|
||||
[EModelEndpoint.agents]: {
|
||||
titleEndpoint: 'custom-provider',
|
||||
titleMethod: 'structured',
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await AppService(app);
|
||||
|
||||
// Check OpenAI endpoint has titleEndpoint
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.openAI);
|
||||
expect(app.locals[EModelEndpoint.openAI]).toMatchObject({
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
titleEndpoint: EModelEndpoint.anthropic,
|
||||
titlePrompt: 'Generate a concise title',
|
||||
});
|
||||
|
||||
// Check Agents endpoint has titleEndpoint
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.agents);
|
||||
expect(app.locals[EModelEndpoint.agents]).toMatchObject({
|
||||
titleEndpoint: 'custom-provider',
|
||||
titleMethod: 'structured',
|
||||
});
|
||||
});
|
||||
|
||||
it('should correctly configure all endpoint when specified', async () => {
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
endpoints: {
|
||||
all: {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Default title prompt for all endpoints',
|
||||
titlePromptTemplate: 'Default template: {{conversation}}',
|
||||
titleEndpoint: EModelEndpoint.anthropic,
|
||||
streamRate: 50,
|
||||
},
|
||||
[EModelEndpoint.openAI]: {
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await AppService(app);
|
||||
|
||||
// Check that 'all' endpoint config is loaded
|
||||
expect(app.locals).toHaveProperty('all');
|
||||
expect(app.locals.all).toMatchObject({
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-4o-mini',
|
||||
titleMethod: 'structured',
|
||||
titlePrompt: 'Default title prompt for all endpoints',
|
||||
titlePromptTemplate: 'Default template: {{conversation}}',
|
||||
titleEndpoint: EModelEndpoint.anthropic,
|
||||
streamRate: 50,
|
||||
});
|
||||
|
||||
// Check that OpenAI endpoint has its own config
|
||||
expect(app.locals).toHaveProperty(EModelEndpoint.openAI);
|
||||
expect(app.locals[EModelEndpoint.openAI]).toMatchObject({
|
||||
titleConvo: true,
|
||||
titleModel: 'gpt-3.5-turbo',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('AppService updating app.locals and issuing warnings', () => {
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { isEnabled, getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { normalizeEndpointName, isEnabled } = require('~/server/utils');
|
||||
const { normalizeEndpointName } = require('~/server/utils');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
const { getCachedTools } = require('./getCachedTools');
|
||||
const { findPluginAuthsByKeys } = require('~/models');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
@@ -13,8 +11,8 @@ const getLogStores = require('~/cache/getLogStores');
|
||||
* @returns {Promise<TCustomConfig | null>}
|
||||
* */
|
||||
async function getCustomConfig() {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
return (await cache.get(CacheKeys.CUSTOM_CONFIG)) || (await loadCustomConfig());
|
||||
const cache = getLogStores(CacheKeys.STATIC_CONFIG);
|
||||
return (await cache.get(CacheKeys.LIBRECHAT_YAML_CONFIG)) || (await loadCustomConfig());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -55,46 +53,44 @@ const getCustomEndpointConfig = async (endpoint) => {
|
||||
);
|
||||
};
|
||||
|
||||
async function createGetMCPAuthMap() {
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId
|
||||
* @param {GenericTool[]} [params.tools]
|
||||
* @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys
|
||||
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
|
||||
*/
|
||||
async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
|
||||
try {
|
||||
if (!tools || tools.length === 0) {
|
||||
return;
|
||||
}
|
||||
return await getUserMCPAuthMap({
|
||||
tools,
|
||||
userId,
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
|
||||
err,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
async function hasCustomUserVars() {
|
||||
const customConfig = await getCustomConfig();
|
||||
const mcpServers = customConfig?.mcpServers;
|
||||
const hasCustomUserVars = Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
|
||||
if (!hasCustomUserVars) {
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {GenericTool[]} [params.tools]
|
||||
* @param {string} params.userId
|
||||
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
|
||||
*/
|
||||
return async function ({ tools, userId }) {
|
||||
try {
|
||||
if (!tools || tools.length === 0) {
|
||||
return;
|
||||
}
|
||||
const appTools = await getCachedTools({
|
||||
userId,
|
||||
});
|
||||
return await getUserMCPAuthMap({
|
||||
tools,
|
||||
userId,
|
||||
appTools,
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
|
||||
err,
|
||||
);
|
||||
}
|
||||
};
|
||||
return Object.values(mcpServers ?? {}).some((server) => server.customUserVars);
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getMCPAuthMap,
|
||||
getCustomConfig,
|
||||
getBalanceConfig,
|
||||
createGetMCPAuthMap,
|
||||
hasCustomUserVars,
|
||||
getCustomEndpointConfig,
|
||||
};
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider');
|
||||
const {
|
||||
CacheKeys,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
orderEndpointsConfig,
|
||||
defaultAgentCapabilities,
|
||||
} = require('librechat-data-provider');
|
||||
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
|
||||
const loadConfigEndpoints = require('./loadConfigEndpoints');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
@@ -80,8 +86,12 @@ async function getEndpointsConfig(req) {
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
const checkCapability = async (req, capability) => {
|
||||
const isAgents = isAgentsEndpoint(req.body?.original_endpoint || req.body?.endpoint);
|
||||
const endpointsConfig = await getEndpointsConfig(req);
|
||||
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
|
||||
const capabilities =
|
||||
isAgents || endpointsConfig?.[EModelEndpoint.agents]?.capabilities != null
|
||||
? (endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [])
|
||||
: defaultAgentCapabilities;
|
||||
return capabilities.includes(capability);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { loadServiceKey, isUserProvided } = require('@librechat/api');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { isUserProvided } = require('~/server/utils');
|
||||
const { config } = require('./EndpointService');
|
||||
|
||||
const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, googleKey } = config;
|
||||
@@ -11,36 +11,28 @@ const { openAIApiKey, azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go
|
||||
* @param {Express.Request} req - The request object
|
||||
*/
|
||||
async function loadAsyncEndpoints(req) {
|
||||
let i = 0;
|
||||
let serviceKey, googleUserProvides;
|
||||
const serviceKeyPath =
|
||||
process.env.GOOGLE_SERVICE_KEY_FILE_PATH ||
|
||||
path.join(__dirname, '../../..', 'data', 'auth.json');
|
||||
|
||||
try {
|
||||
if (process.env.GOOGLE_SERVICE_KEY_FILE_PATH) {
|
||||
const absolutePath = path.isAbsolute(serviceKeyPath)
|
||||
? serviceKeyPath
|
||||
: path.resolve(serviceKeyPath);
|
||||
const fileContent = fs.readFileSync(absolutePath, 'utf8');
|
||||
serviceKey = JSON.parse(fileContent);
|
||||
} else {
|
||||
serviceKey = require('~/data/auth.json');
|
||||
}
|
||||
} catch {
|
||||
if (i === 0) {
|
||||
i++;
|
||||
/** Check if GOOGLE_KEY is provided at all(including 'user_provided') */
|
||||
const isGoogleKeyProvided = googleKey && googleKey.trim() !== '';
|
||||
|
||||
if (isGoogleKeyProvided) {
|
||||
/** If GOOGLE_KEY is provided, check if it's user_provided */
|
||||
googleUserProvides = isUserProvided(googleKey);
|
||||
} else {
|
||||
/** Only attempt to load service key if GOOGLE_KEY is not provided */
|
||||
const serviceKeyPath =
|
||||
process.env.GOOGLE_SERVICE_KEY_FILE || path.join(__dirname, '../../..', 'data', 'auth.json');
|
||||
|
||||
try {
|
||||
serviceKey = await loadServiceKey(serviceKeyPath);
|
||||
} catch (error) {
|
||||
logger.error('Error loading service key', error);
|
||||
serviceKey = null;
|
||||
}
|
||||
}
|
||||
|
||||
if (isUserProvided(googleKey)) {
|
||||
googleUserProvides = true;
|
||||
if (i <= 1) {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
const google = serviceKey || googleKey ? { userProvide: googleUserProvides } : false;
|
||||
const google = serviceKey || isGoogleKeyProvided ? { userProvide: googleUserProvides } : false;
|
||||
|
||||
const useAzure = req.app.locals[EModelEndpoint.azureOpenAI]?.plugins;
|
||||
const gptPlugins =
|
||||
|
||||
@@ -25,7 +25,7 @@ let i = 0;
|
||||
* @function loadCustomConfig
|
||||
* @returns {Promise<TCustomConfig | null>} A promise that resolves to null or the custom config object.
|
||||
* */
|
||||
async function loadCustomConfig() {
|
||||
async function loadCustomConfig(printConfig = true) {
|
||||
// Use CONFIG_PATH if set, otherwise fallback to defaultConfigPath
|
||||
const configPath = process.env.CONFIG_PATH || defaultConfigPath;
|
||||
|
||||
@@ -108,9 +108,11 @@ https://www.librechat.ai/docs/configuration/stt_tts`);
|
||||
|
||||
return null;
|
||||
} else {
|
||||
logger.info('Custom config file loaded:');
|
||||
logger.info(JSON.stringify(customConfig, null, 2));
|
||||
logger.debug('Custom config:', customConfig);
|
||||
if (printConfig) {
|
||||
logger.info('Custom config file loaded:');
|
||||
logger.info(JSON.stringify(customConfig, null, 2));
|
||||
logger.debug('Custom config:', customConfig);
|
||||
}
|
||||
}
|
||||
|
||||
(customConfig.endpoints?.custom ?? [])
|
||||
@@ -118,8 +120,8 @@ https://www.librechat.ai/docs/configuration/stt_tts`);
|
||||
.forEach((endpoint) => parseCustomParams(endpoint.name, endpoint.customParams));
|
||||
|
||||
if (customConfig.cache) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.set(CacheKeys.CUSTOM_CONFIG, customConfig);
|
||||
const cache = getLogStores(CacheKeys.STATIC_CONFIG);
|
||||
await cache.set(CacheKeys.LIBRECHAT_YAML_CONFIG, customConfig);
|
||||
}
|
||||
|
||||
if (result.data.modelSpecs) {
|
||||
|
||||
@@ -8,11 +8,12 @@ const {
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
EToolResources,
|
||||
isAgentsEndpoint,
|
||||
replaceSpecialVars,
|
||||
providerEndpointMap,
|
||||
} = require('librechat-data-provider');
|
||||
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { getFiles, getToolFilesByIds } = require('~/models/File');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
@@ -42,7 +43,11 @@ const initializeAgent = async ({
|
||||
allowedProviders,
|
||||
isInitialAgent = false,
|
||||
}) => {
|
||||
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
|
||||
if (
|
||||
isAgentsEndpoint(endpointOption?.endpoint) &&
|
||||
allowedProviders.size > 0 &&
|
||||
!allowedProviders.has(agent.provider)
|
||||
) {
|
||||
throw new Error(
|
||||
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
|
||||
);
|
||||
@@ -82,10 +87,11 @@ const initializeAgent = async ({
|
||||
attachments: currentFiles,
|
||||
tool_resources: agent.tool_resources,
|
||||
requestFileSet: new Set(requestFiles?.map((file) => file.file_id)),
|
||||
agentId: agent.id,
|
||||
});
|
||||
|
||||
const provider = agent.provider;
|
||||
const { tools, toolContextMap } =
|
||||
const { tools: structuredTools, toolContextMap } =
|
||||
(await loadTools?.({
|
||||
req,
|
||||
res,
|
||||
@@ -98,7 +104,7 @@ const initializeAgent = async ({
|
||||
|
||||
agent.endpoint = provider;
|
||||
const { getOptions, overrideProvider } = await getProviderConfig(provider);
|
||||
if (overrideProvider) {
|
||||
if (overrideProvider !== agent.provider) {
|
||||
agent.provider = overrideProvider;
|
||||
}
|
||||
|
||||
@@ -125,7 +131,7 @@ const initializeAgent = async ({
|
||||
);
|
||||
const agentMaxContextTokens = optionalChainWithEmptyCheck(
|
||||
maxContextTokens,
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider], options.endpointTokenConfig),
|
||||
4096,
|
||||
);
|
||||
|
||||
@@ -140,6 +146,24 @@ const initializeAgent = async ({
|
||||
agent.provider = options.provider;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').GenericTool[]} */
|
||||
let tools = options.tools?.length ? options.tools : structuredTools;
|
||||
if (
|
||||
(agent.provider === Providers.GOOGLE || agent.provider === Providers.VERTEXAI) &&
|
||||
options.tools?.length &&
|
||||
structuredTools?.length
|
||||
) {
|
||||
throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`);
|
||||
} else if (
|
||||
(agent.provider === Providers.OPENAI ||
|
||||
agent.provider === Providers.AZURE ||
|
||||
agent.provider === Providers.ANTHROPIC) &&
|
||||
options.tools?.length &&
|
||||
structuredTools?.length
|
||||
) {
|
||||
tools = structuredTools.concat(options.tools);
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
agent.model_parameters = { ...options.llmConfig };
|
||||
if (options.configOptions) {
|
||||
@@ -166,7 +190,8 @@ const initializeAgent = async ({
|
||||
attachments,
|
||||
resendFiles,
|
||||
toolContextMap,
|
||||
maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9,
|
||||
useLegacyContent: !!options.useLegacyContent,
|
||||
maxContextTokens: Math.round((agentMaxContextTokens - maxTokens) * 0.9),
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -78,7 +78,17 @@ function getLLMConfig(apiKey, options = {}) {
|
||||
requestOptions.anthropicApiUrl = options.reverseProxyUrl;
|
||||
}
|
||||
|
||||
const tools = [];
|
||||
|
||||
if (mergedOptions.web_search) {
|
||||
tools.push({
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
tools,
|
||||
/** @type {AnthropicClientOptions} */
|
||||
llmConfig: removeNullishValues(requestOptions),
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const OpenAI = require('openai');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { ErrorTypes, EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
getUserKeyValues,
|
||||
@@ -59,7 +59,10 @@ const initializeClient = async ({ req, res, endpointOption, version, initAppClie
|
||||
}
|
||||
|
||||
if (PROXY) {
|
||||
opts.httpAgent = new HttpsProxyAgent(PROXY);
|
||||
const proxyAgent = new ProxyAgent(PROXY);
|
||||
opts.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
if (OPENAI_ORGANIZATION) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// const OpenAI = require('openai');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { getUserKey, getUserKeyExpiry, getUserKeyValues } = require('~/server/services/UserService');
|
||||
const initializeClient = require('./initalize');
|
||||
@@ -107,6 +107,7 @@ describe('initializeClient', () => {
|
||||
const res = {};
|
||||
|
||||
const { openai } = await initializeClient({ req, res });
|
||||
expect(openai.httpAgent).toBeInstanceOf(HttpsProxyAgent);
|
||||
expect(openai.fetchOptions).toBeDefined();
|
||||
expect(openai.fetchOptions.dispatcher).toBeInstanceOf(ProxyAgent);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const OpenAI = require('openai');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { constructAzureURL, isUserProvided, resolveHeaders } = require('@librechat/api');
|
||||
const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider');
|
||||
const {
|
||||
@@ -158,7 +158,10 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie
|
||||
}
|
||||
|
||||
if (PROXY) {
|
||||
opts.httpAgent = new HttpsProxyAgent(PROXY);
|
||||
const proxyAgent = new ProxyAgent(PROXY);
|
||||
opts.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
if (OPENAI_ORGANIZATION) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// const OpenAI = require('openai');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { getUserKey, getUserKeyExpiry, getUserKeyValues } = require('~/server/services/UserService');
|
||||
const initializeClient = require('./initialize');
|
||||
@@ -107,6 +107,7 @@ describe('initializeClient', () => {
|
||||
const res = {};
|
||||
|
||||
const { openai } = await initializeClient({ req, res });
|
||||
expect(openai.httpAgent).toBeInstanceOf(HttpsProxyAgent);
|
||||
expect(openai.fetchOptions).toBeDefined();
|
||||
expect(openai.fetchOptions.dispatcher).toBeInstanceOf(ProxyAgent);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -139,7 +139,11 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
||||
);
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
const options = getOpenAIConfig(apiKey, clientOptions, endpoint);
|
||||
if (!customOptions.streamRate) {
|
||||
if (options != null) {
|
||||
options.useLegacyContent = true;
|
||||
options.endpointTokenConfig = endpointTokenConfig;
|
||||
}
|
||||
if (!clientOptions.streamRate) {
|
||||
return options;
|
||||
}
|
||||
options.llmConfig.callbacks = [
|
||||
@@ -156,6 +160,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
||||
}
|
||||
|
||||
return {
|
||||
useLegacyContent: true,
|
||||
llmConfig: modelOptions,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { getGoogleConfig, isEnabled } = require('@librechat/api');
|
||||
const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
|
||||
const { getGoogleConfig, isEnabled, loadServiceKey } = require('@librechat/api');
|
||||
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
|
||||
const { GoogleClient } = require('~/app');
|
||||
|
||||
@@ -18,21 +17,24 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||
|
||||
let serviceKey = {};
|
||||
|
||||
try {
|
||||
if (process.env.GOOGLE_SERVICE_KEY_FILE_PATH) {
|
||||
/** Check if GOOGLE_KEY is provided at all (including 'user_provided') */
|
||||
const isGoogleKeyProvided =
|
||||
(GOOGLE_KEY && GOOGLE_KEY.trim() !== '') || (isUserProvided && userKey != null);
|
||||
|
||||
if (!isGoogleKeyProvided) {
|
||||
/** Only attempt to load service key if GOOGLE_KEY is not provided */
|
||||
try {
|
||||
const serviceKeyPath =
|
||||
process.env.GOOGLE_SERVICE_KEY_FILE_PATH ||
|
||||
path.join(__dirname, '../../../../..', 'data', 'auth.json');
|
||||
const absolutePath = path.isAbsolute(serviceKeyPath)
|
||||
? serviceKeyPath
|
||||
: path.resolve(serviceKeyPath);
|
||||
const fileContent = fs.readFileSync(absolutePath, 'utf8');
|
||||
serviceKey = JSON.parse(fileContent);
|
||||
} else {
|
||||
serviceKey = require('~/data/auth.json');
|
||||
process.env.GOOGLE_SERVICE_KEY_FILE ||
|
||||
path.join(__dirname, '../../../..', 'data', 'auth.json');
|
||||
serviceKey = await loadServiceKey(serviceKeyPath);
|
||||
if (!serviceKey) {
|
||||
serviceKey = {};
|
||||
}
|
||||
} catch (_e) {
|
||||
// Service key loading failed, but that's okay if not required
|
||||
serviceKey = {};
|
||||
}
|
||||
} catch (_e) {
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
const credentials = isUserProvided
|
||||
|
||||
@@ -7,6 +7,16 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize');
|
||||
const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
|
||||
/** Check if the provider is a known custom provider
|
||||
* @param {string | undefined} [provider] - The provider string
|
||||
* @returns {boolean} - True if the provider is a known custom provider, false otherwise
|
||||
*/
|
||||
function isKnownCustomProvider(provider) {
|
||||
return [Providers.XAI, Providers.OLLAMA, Providers.DEEPSEEK, Providers.OPENROUTER].includes(
|
||||
provider?.toLowerCase() || '',
|
||||
);
|
||||
}
|
||||
|
||||
const providerConfigMap = {
|
||||
[Providers.XAI]: initCustom,
|
||||
[Providers.OLLAMA]: initCustom,
|
||||
@@ -24,13 +34,13 @@ const providerConfigMap = {
|
||||
* @param {string} provider - The provider string
|
||||
* @returns {Promise<{
|
||||
* getOptions: Function,
|
||||
* overrideProvider?: string,
|
||||
* overrideProvider: string,
|
||||
* customEndpointConfig?: TEndpoint
|
||||
* }>}
|
||||
*/
|
||||
async function getProviderConfig(provider) {
|
||||
let getOptions = providerConfigMap[provider];
|
||||
let overrideProvider;
|
||||
let overrideProvider = provider;
|
||||
/** @type {TEndpoint | undefined} */
|
||||
let customEndpointConfig;
|
||||
|
||||
@@ -46,6 +56,13 @@ async function getProviderConfig(provider) {
|
||||
overrideProvider = Providers.OPENAI;
|
||||
}
|
||||
|
||||
if (isKnownCustomProvider(overrideProvider) && !customEndpointConfig) {
|
||||
customEndpointConfig = await getCustomEndpointConfig(provider);
|
||||
if (!customEndpointConfig) {
|
||||
throw new Error(`Provider ${provider} not supported`);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
getOptions,
|
||||
overrideProvider,
|
||||
|
||||
@@ -65,19 +65,20 @@ const initializeClient = async ({
|
||||
const isAzureOpenAI = endpoint === EModelEndpoint.azureOpenAI;
|
||||
/** @type {false | TAzureConfig} */
|
||||
const azureConfig = isAzureOpenAI && req.app.locals[EModelEndpoint.azureOpenAI];
|
||||
|
||||
let serverless = false;
|
||||
if (isAzureOpenAI && azureConfig) {
|
||||
const { modelGroupMap, groupMap } = azureConfig;
|
||||
const {
|
||||
azureOptions,
|
||||
baseURL,
|
||||
headers = {},
|
||||
serverless,
|
||||
serverless: _serverless,
|
||||
} = mapModelToAzureConfig({
|
||||
modelName,
|
||||
modelGroupMap,
|
||||
groupMap,
|
||||
});
|
||||
serverless = _serverless;
|
||||
|
||||
clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
|
||||
clientOptions.headers = resolveHeaders(
|
||||
@@ -143,6 +144,9 @@ const initializeClient = async ({
|
||||
clientOptions = Object.assign({ modelOptions }, clientOptions);
|
||||
clientOptions.modelOptions.user = req.user.id;
|
||||
const options = getOpenAIConfig(apiKey, clientOptions);
|
||||
if (options != null && serverless === true) {
|
||||
options.useLegacyContent = true;
|
||||
}
|
||||
const streamRate = clientOptions.streamRate;
|
||||
if (!streamRate) {
|
||||
return options;
|
||||
|
||||
@@ -152,6 +152,7 @@ async function getSessionInfo(fileIdentifier, apiKey) {
|
||||
* @param {Object} options
|
||||
* @param {ServerRequest} options.req
|
||||
* @param {Agent['tool_resources']} options.tool_resources
|
||||
* @param {string} [options.agentId] - The agent ID for file access control
|
||||
* @param {string} apiKey
|
||||
* @returns {Promise<{
|
||||
* files: Array<{ id: string; session_id: string; name: string }>,
|
||||
@@ -159,11 +160,18 @@ async function getSessionInfo(fileIdentifier, apiKey) {
|
||||
* }>}
|
||||
*/
|
||||
const primeFiles = async (options, apiKey) => {
|
||||
const { tool_resources } = options;
|
||||
const { tool_resources, req, agentId } = options;
|
||||
const file_ids = tool_resources?.[EToolResources.execute_code]?.file_ids ?? [];
|
||||
const agentResourceIds = new Set(file_ids);
|
||||
const resourceFiles = tool_resources?.[EToolResources.execute_code]?.files ?? [];
|
||||
const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
|
||||
const dbFiles = (
|
||||
(await getFiles(
|
||||
{ file_id: { $in: file_ids } },
|
||||
null,
|
||||
{ text: 0 },
|
||||
{ userId: req?.user?.id, agentId },
|
||||
)) ?? []
|
||||
).concat(resourceFiles);
|
||||
|
||||
const files = [];
|
||||
const sessions = new Map();
|
||||
|
||||
@@ -2,17 +2,17 @@ const { z } = require('zod');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Time, CacheKeys, StepTypes } = require('librechat-data-provider');
|
||||
const { sendEvent, normalizeServerName, MCPOAuthHandler } = require('@librechat/api');
|
||||
const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents');
|
||||
const { Constants, ContentTypes, isAssistantsEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
Constants,
|
||||
ContentTypes,
|
||||
isAssistantsEndpoint,
|
||||
convertJsonSchemaToZod,
|
||||
} = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
sendEvent,
|
||||
MCPOAuthHandler,
|
||||
normalizeServerName,
|
||||
convertWithResolvedRefs,
|
||||
} = require('@librechat/api');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { getCachedTools } = require('./Config');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getCachedTools, loadCustomConfig } = require('./Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
@@ -104,7 +104,7 @@ function createAbortHandler({ userId, serverName, toolName, flowManager }) {
|
||||
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
|
||||
*/
|
||||
async function createMCPTool({ req, res, toolKey, provider: _provider }) {
|
||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||
const availableTools = await getCachedTools({ userId: req.user?.id, includeGlobal: true });
|
||||
const toolDefinition = availableTools?.[toolKey]?.function;
|
||||
if (!toolDefinition) {
|
||||
logger.error(`Tool ${toolKey} not found in available tools`);
|
||||
@@ -113,7 +113,7 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
|
||||
/** @type {LCTool} */
|
||||
const { description, parameters } = toolDefinition;
|
||||
const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE;
|
||||
let schema = convertJsonSchemaToZod(parameters, {
|
||||
let schema = convertWithResolvedRefs(parameters, {
|
||||
allowEmptyObject: !isGoogle,
|
||||
transformOneOfAnyOf: true,
|
||||
});
|
||||
@@ -235,9 +235,127 @@ async function createMCPTool({ req, res, toolKey, provider: _provider }) {
|
||||
responseFormat: AgentConstants.CONTENT_AND_ARTIFACT,
|
||||
});
|
||||
toolInstance.mcp = true;
|
||||
toolInstance.mcpRawServerName = serverName;
|
||||
return toolInstance;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get MCP setup data including config, connections, and OAuth servers
|
||||
* @param {string} userId - The user ID
|
||||
* @returns {Object} Object containing mcpConfig, appConnections, userConnections, and oauthServers
|
||||
*/
|
||||
async function getMCPSetupData(userId) {
|
||||
const printConfig = false;
|
||||
const config = await loadCustomConfig(printConfig);
|
||||
const mcpConfig = config?.mcpServers;
|
||||
|
||||
if (!mcpConfig) {
|
||||
throw new Error('MCP config not found');
|
||||
}
|
||||
|
||||
const mcpManager = getMCPManager(userId);
|
||||
const appConnections = mcpManager.getAllConnections() || new Map();
|
||||
const userConnections = mcpManager.getUserConnections(userId) || new Map();
|
||||
const oauthServers = mcpManager.getOAuthServers() || new Set();
|
||||
|
||||
return {
|
||||
mcpConfig,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check OAuth flow status for a user and server
|
||||
* @param {string} userId - The user ID
|
||||
* @param {string} serverName - The server name
|
||||
* @returns {Object} Object containing hasActiveFlow and hasFailedFlow flags
|
||||
*/
|
||||
async function checkOAuthFlowStatus(userId, serverName) {
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const flowId = MCPOAuthHandler.generateFlowId(userId, serverName);
|
||||
|
||||
try {
|
||||
const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
if (!flowState) {
|
||||
return { hasActiveFlow: false, hasFailedFlow: false };
|
||||
}
|
||||
|
||||
const flowAge = Date.now() - flowState.createdAt;
|
||||
const flowTTL = flowState.ttl || 180000; // Default 3 minutes
|
||||
|
||||
if (flowState.status === 'FAILED' || flowAge > flowTTL) {
|
||||
logger.debug(`[MCP Connection Status] Found failed OAuth flow for ${serverName}`, {
|
||||
flowId,
|
||||
status: flowState.status,
|
||||
flowAge,
|
||||
flowTTL,
|
||||
timedOut: flowAge > flowTTL,
|
||||
});
|
||||
return { hasActiveFlow: false, hasFailedFlow: true };
|
||||
}
|
||||
|
||||
if (flowState.status === 'PENDING') {
|
||||
logger.debug(`[MCP Connection Status] Found active OAuth flow for ${serverName}`, {
|
||||
flowId,
|
||||
flowAge,
|
||||
flowTTL,
|
||||
});
|
||||
return { hasActiveFlow: true, hasFailedFlow: false };
|
||||
}
|
||||
|
||||
return { hasActiveFlow: false, hasFailedFlow: false };
|
||||
} catch (error) {
|
||||
logger.error(`[MCP Connection Status] Error checking OAuth flows for ${serverName}:`, error);
|
||||
return { hasActiveFlow: false, hasFailedFlow: false };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get connection status for a specific MCP server
|
||||
* @param {string} userId - The user ID
|
||||
* @param {string} serverName - The server name
|
||||
* @param {Map} appConnections - App-level connections
|
||||
* @param {Map} userConnections - User-level connections
|
||||
* @param {Set} oauthServers - Set of OAuth servers
|
||||
* @returns {Object} Object containing requiresOAuth and connectionState
|
||||
*/
|
||||
async function getServerConnectionStatus(
|
||||
userId,
|
||||
serverName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
) {
|
||||
const getConnectionState = () =>
|
||||
appConnections.get(serverName)?.connectionState ??
|
||||
userConnections.get(serverName)?.connectionState ??
|
||||
'disconnected';
|
||||
|
||||
const baseConnectionState = getConnectionState();
|
||||
let finalConnectionState = baseConnectionState;
|
||||
|
||||
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
finalConnectionState = 'connecting';
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
requiresOAuth: oauthServers.has(serverName),
|
||||
connectionState: finalConnectionState,
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
createMCPTool,
|
||||
getMCPSetupData,
|
||||
checkOAuthFlowStatus,
|
||||
getServerConnectionStatus,
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user