Compare commits
5 Commits
feat/add-b
...
feature/cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3fb97b7c4a | ||
|
|
0417a38a6e | ||
|
|
b6de2a8557 | ||
|
|
a77ad276f1 | ||
|
|
8703b9c2d8 |
58
.env.example
58
.env.example
@@ -58,7 +58,7 @@ DEBUG_CONSOLE=false
|
||||
# Endpoints #
|
||||
#===================================================#
|
||||
|
||||
# ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic
|
||||
# ENDPOINTS=openAI,assistants,azureOpenAI,google,gptPlugins,anthropic
|
||||
|
||||
PROXY=
|
||||
|
||||
@@ -142,10 +142,10 @@ GOOGLE_KEY=user_provided
|
||||
# GOOGLE_AUTH_HEADER=true
|
||||
|
||||
# Gemini API (AI Studio)
|
||||
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite
|
||||
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
|
||||
|
||||
# Vertex AI
|
||||
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
|
||||
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
|
||||
|
||||
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
|
||||
|
||||
@@ -349,11 +349,6 @@ REGISTRATION_VIOLATION_SCORE=1
|
||||
CONCURRENT_VIOLATION_SCORE=1
|
||||
MESSAGE_VIOLATION_SCORE=1
|
||||
NON_BROWSER_VIOLATION_SCORE=20
|
||||
TTS_VIOLATION_SCORE=0
|
||||
STT_VIOLATION_SCORE=0
|
||||
FORK_VIOLATION_SCORE=0
|
||||
IMPORT_VIOLATION_SCORE=0
|
||||
FILE_UPLOAD_VIOLATION_SCORE=0
|
||||
|
||||
LOGIN_MAX=7
|
||||
LOGIN_WINDOW=5
|
||||
@@ -442,8 +437,6 @@ OPENID_REQUIRED_ROLE_PARAMETER_PATH=
|
||||
OPENID_USERNAME_CLAIM=
|
||||
# Set to determine which user info property returned from OpenID Provider to store as the User's name
|
||||
OPENID_NAME_CLAIM=
|
||||
# Optional audience parameter for OpenID authorization requests
|
||||
OPENID_AUDIENCE=
|
||||
|
||||
OPENID_BUTTON_LABEL=
|
||||
OPENID_IMAGE_URL=
|
||||
@@ -460,8 +453,8 @@ OPENID_REUSE_TOKENS=
|
||||
OPENID_JWKS_URL_CACHE_ENABLED=
|
||||
OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching
|
||||
#Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint.
|
||||
OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED=
|
||||
OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API
|
||||
OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED=
|
||||
OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API
|
||||
# Set to true to use the OpenID Connect end session endpoint for logout
|
||||
OPENID_USE_END_SESSION_ENDPOINT=
|
||||
|
||||
@@ -582,10 +575,6 @@ ALLOW_SHARED_LINKS_PUBLIC=true
|
||||
# If you have another service in front of your LibreChat doing compression, disable express based compression here
|
||||
# DISABLE_COMPRESSION=true
|
||||
|
||||
# If you have gzipped version of uploaded image images in the same folder, this will enable gzip scan and serving of these images
|
||||
# Note: The images folder will be scanned on startup and a ma kept in memory. Be careful for large number of images.
|
||||
# ENABLE_IMAGE_OUTPUT_GZIP_SCAN=true
|
||||
|
||||
#===================================================#
|
||||
# UI #
|
||||
#===================================================#
|
||||
@@ -603,40 +592,11 @@ HELP_AND_FAQ_URL=https://librechat.ai
|
||||
# REDIS Options #
|
||||
#===============#
|
||||
|
||||
# Enable Redis for caching and session storage
|
||||
# REDIS_URI=10.10.10.10:6379
|
||||
# USE_REDIS=true
|
||||
|
||||
# Single Redis instance
|
||||
# REDIS_URI=redis://127.0.0.1:6379
|
||||
|
||||
# Redis cluster (multiple nodes)
|
||||
# REDIS_URI=redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003
|
||||
|
||||
# Redis with TLS/SSL encryption and CA certificate
|
||||
# REDIS_URI=rediss://127.0.0.1:6380
|
||||
# REDIS_CA=/path/to/ca-cert.pem
|
||||
|
||||
# Redis authentication (if required)
|
||||
# REDIS_USERNAME=your_redis_username
|
||||
# REDIS_PASSWORD=your_redis_password
|
||||
|
||||
# Redis key prefix configuration
|
||||
# Use environment variable name for dynamic prefix (recommended for cloud deployments)
|
||||
# REDIS_KEY_PREFIX_VAR=K_REVISION
|
||||
# Or use static prefix directly
|
||||
# REDIS_KEY_PREFIX=librechat
|
||||
|
||||
# Redis connection limits
|
||||
# REDIS_MAX_LISTENERS=40
|
||||
|
||||
# Redis ping interval in seconds (0 = disabled, >0 = enabled)
|
||||
# When set to a positive integer, Redis clients will ping the server at this interval to keep connections alive
|
||||
# When unset or 0, no pinging is performed (recommended for most use cases)
|
||||
# REDIS_PING_INTERVAL=300
|
||||
|
||||
# Force specific cache namespaces to use in-memory storage even when Redis is enabled
|
||||
# Comma-separated list of CacheKeys (e.g., STATIC_CONFIG,ROLES,MESSAGES)
|
||||
# FORCED_IN_MEMORY_CACHE_NAMESPACES=STATIC_CONFIG,ROLES
|
||||
# USE_REDIS_CLUSTER=true
|
||||
# REDIS_CA=/path/to/ca.crt
|
||||
|
||||
#==================================================#
|
||||
# Others #
|
||||
@@ -697,4 +657,4 @@ OPENWEATHER_API_KEY=
|
||||
# Reranker (Required)
|
||||
# JINA_API_KEY=your_jina_api_key
|
||||
# or
|
||||
# COHERE_API_KEY=your_cohere_api_key
|
||||
# COHERE_API_KEY=your_cohere_api_key
|
||||
2
.github/workflows/backend-review.yml
vendored
2
.github/workflows/backend-review.yml
vendored
@@ -7,7 +7,7 @@ on:
|
||||
- release/*
|
||||
paths:
|
||||
- 'api/**'
|
||||
- 'packages/**'
|
||||
- 'packages/api/**'
|
||||
jobs:
|
||||
tests_Backend:
|
||||
name: Run Backend unit tests
|
||||
|
||||
58
.github/workflows/client.yml
vendored
58
.github/workflows/client.yml
vendored
@@ -1,58 +0,0 @@
|
||||
name: Publish `@librechat/client` to NPM
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'packages/client/package.json'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
reason:
|
||||
description: 'Reason for manual trigger'
|
||||
required: false
|
||||
default: 'Manual publish requested'
|
||||
|
||||
jobs:
|
||||
build-and-publish:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20.x'
|
||||
|
||||
- name: Install client dependencies
|
||||
run: cd packages/client && npm ci
|
||||
|
||||
- name: Build client
|
||||
run: cd packages/client && npm run build
|
||||
|
||||
- name: Set up npm authentication
|
||||
run: echo "//registry.npmjs.org/:_authToken=${{ secrets.PUBLISH_NPM_TOKEN }}" > ~/.npmrc
|
||||
|
||||
- name: Check version change
|
||||
id: check
|
||||
working-directory: packages/client
|
||||
run: |
|
||||
PACKAGE_VERSION=$(node -p "require('./package.json').version")
|
||||
PUBLISHED_VERSION=$(npm view @librechat/client version 2>/dev/null || echo "0.0.0")
|
||||
if [ "$PACKAGE_VERSION" = "$PUBLISHED_VERSION" ]; then
|
||||
echo "No version change, skipping publish"
|
||||
echo "skip=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "Version changed, proceeding with publish"
|
||||
echo "skip=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Pack package
|
||||
if: steps.check.outputs.skip != 'true'
|
||||
working-directory: packages/client
|
||||
run: npm pack
|
||||
|
||||
- name: Publish
|
||||
if: steps.check.outputs.skip != 'true'
|
||||
working-directory: packages/client
|
||||
run: npm publish *.tgz --access public
|
||||
2
.github/workflows/data-schemas.yml
vendored
2
.github/workflows/data-schemas.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20.x'
|
||||
node-version: '18.x'
|
||||
|
||||
- name: Install dependencies
|
||||
run: cd packages/data-schemas && npm ci
|
||||
|
||||
2
.github/workflows/frontend-review.yml
vendored
2
.github/workflows/frontend-review.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
- release/*
|
||||
paths:
|
||||
- 'client/**'
|
||||
- 'packages/data-provider/**'
|
||||
- 'packages/**'
|
||||
|
||||
jobs:
|
||||
tests_frontend_ubuntu:
|
||||
|
||||
95
.github/workflows/generate-release-changelog-pr.yml
vendored
Normal file
95
.github/workflows/generate-release-changelog-pr.yml
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
name: Generate Release Changelog PR
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
generate-release-changelog-pr:
|
||||
permissions:
|
||||
contents: write # Needed for pushing commits and creating branches.
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# 1. Checkout the repository (with full history).
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
# 2. Generate the release changelog using our custom configuration.
|
||||
- name: Generate Release Changelog
|
||||
id: generate_release
|
||||
uses: mikepenz/release-changelog-builder-action@v5.1.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
configuration: ".github/configuration-release.json"
|
||||
owner: ${{ github.repository_owner }}
|
||||
repo: ${{ github.event.repository.name }}
|
||||
outputFile: CHANGELOG-release.md
|
||||
|
||||
# 3. Update the main CHANGELOG.md:
|
||||
# - If it doesn't exist, create it with a basic header.
|
||||
# - Remove the "Unreleased" section (if present).
|
||||
# - Prepend the new release changelog above previous releases.
|
||||
# - Remove all temporary files before committing.
|
||||
- name: Update CHANGELOG.md
|
||||
run: |
|
||||
# Determine the release tag, e.g. "v1.2.3"
|
||||
TAG=${GITHUB_REF##*/}
|
||||
echo "Using release tag: $TAG"
|
||||
|
||||
# Ensure CHANGELOG.md exists; if not, create a basic header.
|
||||
if [ ! -f CHANGELOG.md ]; then
|
||||
echo "# Changelog" > CHANGELOG.md
|
||||
echo "" >> CHANGELOG.md
|
||||
echo "All notable changes to this project will be documented in this file." >> CHANGELOG.md
|
||||
echo "" >> CHANGELOG.md
|
||||
fi
|
||||
|
||||
echo "Updating CHANGELOG.md…"
|
||||
|
||||
# Remove the "Unreleased" section (from "## [Unreleased]" until the first occurrence of '---') if it exists.
|
||||
if grep -q "^## \[Unreleased\]" CHANGELOG.md; then
|
||||
awk '/^## \[Unreleased\]/{flag=1} flag && /^---/{flag=0; next} !flag' CHANGELOG.md > CHANGELOG.cleaned
|
||||
else
|
||||
cp CHANGELOG.md CHANGELOG.cleaned
|
||||
fi
|
||||
|
||||
# Split the cleaned file into:
|
||||
# - header.md: content before the first release header ("## [v...").
|
||||
# - tail.md: content from the first release header onward.
|
||||
awk '/^## \[v/{exit} {print}' CHANGELOG.cleaned > header.md
|
||||
awk 'f{print} /^## \[v/{f=1; print}' CHANGELOG.cleaned > tail.md
|
||||
|
||||
# Combine header, the new release changelog, and the tail.
|
||||
echo "Combining updated changelog parts..."
|
||||
cat header.md CHANGELOG-release.md > CHANGELOG.md.new
|
||||
echo "" >> CHANGELOG.md.new
|
||||
cat tail.md >> CHANGELOG.md.new
|
||||
|
||||
mv CHANGELOG.md.new CHANGELOG.md
|
||||
|
||||
# Remove temporary files.
|
||||
rm -f CHANGELOG.cleaned header.md tail.md CHANGELOG-release.md
|
||||
|
||||
echo "Final CHANGELOG.md content:"
|
||||
cat CHANGELOG.md
|
||||
|
||||
# 4. Create (or update) the Pull Request with the updated CHANGELOG.md.
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
sign-commits: true
|
||||
commit-message: "chore: update CHANGELOG for release ${{ github.ref_name }}"
|
||||
base: main
|
||||
branch: "changelog/${{ github.ref_name }}"
|
||||
reviewers: danny-avila
|
||||
title: "📜 docs: Changelog for release ${{ github.ref_name }}"
|
||||
body: |
|
||||
**Description**:
|
||||
- This PR updates the CHANGELOG.md by removing the "Unreleased" section and adding new release notes for release ${{ github.ref_name }} above previous releases.
|
||||
107
.github/workflows/generate-unreleased-changelog-pr.yml
vendored
Normal file
107
.github/workflows/generate-unreleased-changelog-pr.yml
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
name: Generate Unreleased Changelog PR
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0 * * 1" # Runs every Monday at 00:00 UTC
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
generate-unreleased-changelog-pr:
|
||||
permissions:
|
||||
contents: write # Needed for pushing commits and creating branches.
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# 1. Checkout the repository on main.
|
||||
- name: Checkout Repository on Main
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: main
|
||||
fetch-depth: 0
|
||||
|
||||
# 4. Get the latest version tag.
|
||||
- name: Get Latest Tag
|
||||
id: get_latest_tag
|
||||
run: |
|
||||
LATEST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1) || echo "none")
|
||||
echo "Latest tag: $LATEST_TAG"
|
||||
echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT
|
||||
|
||||
# 5. Generate the Unreleased changelog.
|
||||
- name: Generate Unreleased Changelog
|
||||
id: generate_unreleased
|
||||
uses: mikepenz/release-changelog-builder-action@v5.1.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
configuration: ".github/configuration-unreleased.json"
|
||||
owner: ${{ github.repository_owner }}
|
||||
repo: ${{ github.event.repository.name }}
|
||||
outputFile: CHANGELOG-unreleased.md
|
||||
fromTag: ${{ steps.get_latest_tag.outputs.tag }}
|
||||
toTag: main
|
||||
|
||||
# 7. Update CHANGELOG.md with the new Unreleased section.
|
||||
- name: Update CHANGELOG.md
|
||||
id: update_changelog
|
||||
run: |
|
||||
# Create CHANGELOG.md if it doesn't exist.
|
||||
if [ ! -f CHANGELOG.md ]; then
|
||||
echo "# Changelog" > CHANGELOG.md
|
||||
echo "" >> CHANGELOG.md
|
||||
echo "All notable changes to this project will be documented in this file." >> CHANGELOG.md
|
||||
echo "" >> CHANGELOG.md
|
||||
fi
|
||||
|
||||
echo "Updating CHANGELOG.md…"
|
||||
|
||||
# Extract content before the "## [Unreleased]" (or first version header if missing).
|
||||
if grep -q "^## \[Unreleased\]" CHANGELOG.md; then
|
||||
awk '/^## \[Unreleased\]/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md
|
||||
else
|
||||
awk '/^## \[v/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md
|
||||
fi
|
||||
|
||||
# Append the generated Unreleased changelog.
|
||||
echo "" >> CHANGELOG_TMP.md
|
||||
cat CHANGELOG-unreleased.md >> CHANGELOG_TMP.md
|
||||
echo "" >> CHANGELOG_TMP.md
|
||||
|
||||
# Append the remainder of the original changelog (starting from the first version header).
|
||||
awk 'f{print} /^## \[v/{f=1; print}' CHANGELOG.md >> CHANGELOG_TMP.md
|
||||
|
||||
# Replace the old file with the updated file.
|
||||
mv CHANGELOG_TMP.md CHANGELOG.md
|
||||
|
||||
# Remove the temporary generated file.
|
||||
rm -f CHANGELOG-unreleased.md
|
||||
|
||||
echo "Final CHANGELOG.md:"
|
||||
cat CHANGELOG.md
|
||||
|
||||
# 8. Check if CHANGELOG.md has any updates.
|
||||
- name: Check for CHANGELOG.md changes
|
||||
id: changelog_changes
|
||||
run: |
|
||||
if git diff --quiet CHANGELOG.md; then
|
||||
echo "has_changes=false" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "has_changes=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# 9. Create (or update) the Pull Request only if there are changes.
|
||||
- name: Create Pull Request
|
||||
if: steps.changelog_changes.outputs.has_changes == 'true'
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
base: main
|
||||
branch: "changelog/unreleased-update"
|
||||
sign-commits: true
|
||||
commit-message: "action: update Unreleased changelog"
|
||||
title: "📜 docs: Unreleased Changelog"
|
||||
body: |
|
||||
**Description**:
|
||||
- This PR updates the Unreleased section in CHANGELOG.md.
|
||||
- It compares the current main branch with the latest version tag (determined as ${{ steps.get_latest_tag.outputs.tag }}),
|
||||
regenerates the Unreleased changelog, removes any old Unreleased block, and inserts the new content.
|
||||
3
.github/workflows/i18n-unused-keys.yml
vendored
3
.github/workflows/i18n-unused-keys.yml
vendored
@@ -6,7 +6,6 @@ on:
|
||||
- "client/src/**"
|
||||
- "api/**"
|
||||
- "packages/data-provider/src/**"
|
||||
- "packages/client/**"
|
||||
|
||||
jobs:
|
||||
detect-unused-i18n-keys:
|
||||
@@ -24,7 +23,7 @@ jobs:
|
||||
|
||||
# Define paths
|
||||
I18N_FILE="client/src/locales/en/translation.json"
|
||||
SOURCE_DIRS=("client/src" "api" "packages/data-provider/src" "packages/client")
|
||||
SOURCE_DIRS=("client/src" "api" "packages/data-provider/src")
|
||||
|
||||
# Check if translation file exists
|
||||
if [[ ! -f "$I18N_FILE" ]]; then
|
||||
|
||||
2
.github/workflows/locize-i18n-sync.yml
vendored
2
.github/workflows/locize-i18n-sync.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
|
||||
# 2. Download translation files from locize.
|
||||
- name: Download Translations from locize
|
||||
uses: locize/download@v2
|
||||
uses: locize/download@v1
|
||||
with:
|
||||
project-id: ${{ secrets.LOCIZE_PROJECT_ID }}
|
||||
path: "client/src/locales"
|
||||
|
||||
101
.github/workflows/unused-packages.yml
vendored
101
.github/workflows/unused-packages.yml
vendored
@@ -7,7 +7,6 @@ on:
|
||||
- 'package-lock.json'
|
||||
- 'client/**'
|
||||
- 'api/**'
|
||||
- 'packages/client/**'
|
||||
|
||||
jobs:
|
||||
detect-unused-packages:
|
||||
@@ -29,7 +28,7 @@ jobs:
|
||||
|
||||
- name: Validate JSON files
|
||||
run: |
|
||||
for FILE in package.json client/package.json api/package.json packages/client/package.json; do
|
||||
for FILE in package.json client/package.json api/package.json; do
|
||||
if [[ -f "$FILE" ]]; then
|
||||
jq empty "$FILE" || (echo "::error title=Invalid JSON::$FILE is invalid" && exit 1)
|
||||
fi
|
||||
@@ -64,31 +63,12 @@ jobs:
|
||||
local folder=$1
|
||||
local output_file=$2
|
||||
if [[ -d "$folder" ]]; then
|
||||
# Extract require() statements
|
||||
grep -rEho "require\\(['\"]([a-zA-Z0-9@/._-]+)['\"]\\)" "$folder" --include=\*.{js,ts,tsx,jsx,mjs,cjs} | \
|
||||
grep -rEho "require\\(['\"]([a-zA-Z0-9@/._-]+)['\"]\\)" "$folder" --include=\*.{js,ts,mjs,cjs} | \
|
||||
sed -E "s/require\\(['\"]([a-zA-Z0-9@/._-]+)['\"]\\)/\1/" > "$output_file"
|
||||
|
||||
# Extract ES6 imports - various patterns
|
||||
# import x from 'module'
|
||||
grep -rEho "import .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{js,ts,tsx,jsx,mjs,cjs} | \
|
||||
grep -rEho "import .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{js,ts,mjs,cjs} | \
|
||||
sed -E "s/import .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]/\1/" >> "$output_file"
|
||||
|
||||
# import 'module' (side-effect imports)
|
||||
grep -rEho "import ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{js,ts,tsx,jsx,mjs,cjs} | \
|
||||
sed -E "s/import ['\"]([a-zA-Z0-9@/._-]+)['\"]/\1/" >> "$output_file"
|
||||
|
||||
# export { x } from 'module' or export * from 'module'
|
||||
grep -rEho "export .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{js,ts,tsx,jsx,mjs,cjs} | \
|
||||
sed -E "s/export .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]/\1/" >> "$output_file"
|
||||
|
||||
# import type { x } from 'module' (TypeScript)
|
||||
grep -rEho "import type .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]" "$folder" --include=\*.{ts,tsx} | \
|
||||
sed -E "s/import type .* from ['\"]([a-zA-Z0-9@/._-]+)['\"]/\1/" >> "$output_file"
|
||||
|
||||
# Remove subpath imports but keep the base package
|
||||
# e.g., '@tanstack/react-query/devtools' becomes '@tanstack/react-query'
|
||||
sed -i -E 's|^(@?[a-zA-Z0-9-]+(/[a-zA-Z0-9-]+)?)/.*|\1|' "$output_file"
|
||||
|
||||
sort -u "$output_file" -o "$output_file"
|
||||
else
|
||||
touch "$output_file"
|
||||
@@ -98,80 +78,13 @@ jobs:
|
||||
extract_deps_from_code "." root_used_code.txt
|
||||
extract_deps_from_code "client" client_used_code.txt
|
||||
extract_deps_from_code "api" api_used_code.txt
|
||||
|
||||
# Extract dependencies used by @librechat/client package
|
||||
extract_deps_from_code "packages/client" packages_client_used_code.txt
|
||||
|
||||
- name: Get @librechat/client dependencies
|
||||
id: get-librechat-client-deps
|
||||
run: |
|
||||
if [[ -f "packages/client/package.json" ]]; then
|
||||
# Get all dependencies from @librechat/client (dependencies, devDependencies, and peerDependencies)
|
||||
DEPS=$(jq -r '.dependencies // {} | keys[]' packages/client/package.json 2>/dev/null || echo "")
|
||||
DEV_DEPS=$(jq -r '.devDependencies // {} | keys[]' packages/client/package.json 2>/dev/null || echo "")
|
||||
PEER_DEPS=$(jq -r '.peerDependencies // {} | keys[]' packages/client/package.json 2>/dev/null || echo "")
|
||||
|
||||
# Combine all dependencies
|
||||
echo "$DEPS" > librechat_client_deps.txt
|
||||
echo "$DEV_DEPS" >> librechat_client_deps.txt
|
||||
echo "$PEER_DEPS" >> librechat_client_deps.txt
|
||||
|
||||
# Also include dependencies that are imported in packages/client
|
||||
cat packages_client_used_code.txt >> librechat_client_deps.txt
|
||||
|
||||
# Remove empty lines and sort
|
||||
grep -v '^$' librechat_client_deps.txt | sort -u > temp_deps.txt
|
||||
mv temp_deps.txt librechat_client_deps.txt
|
||||
else
|
||||
touch librechat_client_deps.txt
|
||||
fi
|
||||
|
||||
- name: Extract Workspace Dependencies
|
||||
id: extract-workspace-deps
|
||||
run: |
|
||||
# Function to get dependencies from a workspace package that are used by another package
|
||||
get_workspace_package_deps() {
|
||||
local package_json=$1
|
||||
local output_file=$2
|
||||
|
||||
# Get all workspace dependencies (starting with @librechat/)
|
||||
if [[ -f "$package_json" ]]; then
|
||||
local workspace_deps=$(jq -r '.dependencies // {} | to_entries[] | select(.key | startswith("@librechat/")) | .key' "$package_json" 2>/dev/null || echo "")
|
||||
|
||||
# For each workspace dependency, get its dependencies
|
||||
for dep in $workspace_deps; do
|
||||
# Convert @librechat/api to packages/api
|
||||
local workspace_path=$(echo "$dep" | sed 's/@librechat\//packages\//')
|
||||
local workspace_package_json="${workspace_path}/package.json"
|
||||
|
||||
if [[ -f "$workspace_package_json" ]]; then
|
||||
# Extract all dependencies from the workspace package
|
||||
jq -r '.dependencies // {} | keys[]' "$workspace_package_json" 2>/dev/null >> "$output_file"
|
||||
# Also extract peerDependencies
|
||||
jq -r '.peerDependencies // {} | keys[]' "$workspace_package_json" 2>/dev/null >> "$output_file"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [[ -f "$output_file" ]]; then
|
||||
sort -u "$output_file" -o "$output_file"
|
||||
else
|
||||
touch "$output_file"
|
||||
fi
|
||||
}
|
||||
|
||||
# Get workspace dependencies for each package
|
||||
get_workspace_package_deps "package.json" root_workspace_deps.txt
|
||||
get_workspace_package_deps "client/package.json" client_workspace_deps.txt
|
||||
get_workspace_package_deps "api/package.json" api_workspace_deps.txt
|
||||
|
||||
- name: Run depcheck for root package.json
|
||||
id: check-root
|
||||
run: |
|
||||
if [[ -f "package.json" ]]; then
|
||||
UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "")
|
||||
# Exclude dependencies used in scripts, code, and workspace packages
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat root_used_deps.txt root_used_code.txt root_workspace_deps.txt | sort) || echo "")
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat root_used_deps.txt root_used_code.txt | sort) || echo "")
|
||||
echo "ROOT_UNUSED<<EOF" >> $GITHUB_ENV
|
||||
echo "$UNUSED" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
@@ -184,8 +97,7 @@ jobs:
|
||||
chmod -R 755 client
|
||||
cd client
|
||||
UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "")
|
||||
# Exclude dependencies used in scripts, code, and workspace packages
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../client_used_deps.txt ../client_used_code.txt ../client_workspace_deps.txt | sort) || echo "")
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../client_used_deps.txt ../client_used_code.txt | sort) || echo "")
|
||||
# Filter out false positives
|
||||
UNUSED=$(echo "$UNUSED" | grep -v "^micromark-extension-llm-math$" || echo "")
|
||||
echo "CLIENT_UNUSED<<EOF" >> $GITHUB_ENV
|
||||
@@ -201,8 +113,7 @@ jobs:
|
||||
chmod -R 755 api
|
||||
cd api
|
||||
UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "")
|
||||
# Exclude dependencies used in scripts, code, and workspace packages
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../api_used_deps.txt ../api_used_code.txt ../api_workspace_deps.txt | sort) || echo "")
|
||||
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../api_used_deps.txt ../api_used_code.txt | sort) || echo "")
|
||||
echo "API_UNUSED<<EOF" >> $GITHUB_ENV
|
||||
echo "$UNUSED" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
|
||||
12
.gitignore
vendored
12
.gitignore
vendored
@@ -13,9 +13,6 @@ pids
|
||||
*.seed
|
||||
.git
|
||||
|
||||
# CI/CD data
|
||||
test-image*
|
||||
|
||||
# Directory for instrumented libs generated by jscoverage/JSCover
|
||||
lib-cov
|
||||
|
||||
@@ -128,12 +125,3 @@ helm/**/.values.yaml
|
||||
|
||||
# SAML Idp cert
|
||||
*.cert
|
||||
|
||||
# AI Assistants
|
||||
/.claude/
|
||||
/.cursor/
|
||||
/.copilot/
|
||||
/.aider/
|
||||
/.openai/
|
||||
/.tabnine/
|
||||
/.codeium
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# v0.8.0-rc1
|
||||
# v0.7.8
|
||||
|
||||
# Base node image
|
||||
FROM node:20-alpine AS node
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Dockerfile.multi
|
||||
# v0.8.0-rc1
|
||||
# v0.7.8
|
||||
|
||||
# Base for all builds
|
||||
FROM node:20-alpine AS base-min
|
||||
@@ -16,7 +16,6 @@ COPY package*.json ./
|
||||
COPY packages/data-provider/package*.json ./packages/data-provider/
|
||||
COPY packages/api/package*.json ./packages/api/
|
||||
COPY packages/data-schemas/package*.json ./packages/data-schemas/
|
||||
COPY packages/client/package*.json ./packages/client/
|
||||
COPY client/package*.json ./client/
|
||||
COPY api/package*.json ./api/
|
||||
|
||||
@@ -46,19 +45,11 @@ COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/d
|
||||
COPY --from=data-schemas-build /app/packages/data-schemas/dist /app/packages/data-schemas/dist
|
||||
RUN npm run build
|
||||
|
||||
# Build `client` package
|
||||
FROM base AS client-package-build
|
||||
WORKDIR /app/packages/client
|
||||
COPY packages/client ./
|
||||
RUN npm run build
|
||||
|
||||
# Client build
|
||||
FROM base AS client-build
|
||||
WORKDIR /app/client
|
||||
COPY client ./
|
||||
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
|
||||
COPY --from=client-package-build /app/packages/client/dist /app/packages/client/dist
|
||||
COPY --from=client-package-build /app/packages/client/src /app/packages/client/src
|
||||
ENV NODE_OPTIONS="--max-old-space-size=2048"
|
||||
RUN npm run build
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
|
||||
|
||||
- 🤖 **AI Model Selection**:
|
||||
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure)
|
||||
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
|
||||
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
|
||||
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
|
||||
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
|
||||
@@ -66,9 +66,10 @@
|
||||
- 🔦 **Agents & Tools Integration**:
|
||||
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
|
||||
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
|
||||
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
|
||||
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
|
||||
- Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
|
||||
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
|
||||
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
|
||||
- Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions
|
||||
|
||||
- 🔍 **Web Search**:
|
||||
- Search the internet and retrieve relevant information to enhance your AI context
|
||||
|
||||
@@ -13,6 +13,7 @@ const {
|
||||
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { truncateToolCallOutputs } = require('./prompts');
|
||||
const { addSpaceIfNeeded } = require('~/server/utils');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const TextStream = require('./TextStream');
|
||||
const { logger } = require('~/config');
|
||||
@@ -108,15 +109,12 @@ class BaseClient {
|
||||
/**
|
||||
* Abstract method to record token usage. Subclasses must implement this method.
|
||||
* If a correction to the token usage is needed, the method should return an object with the corrected token counts.
|
||||
* Should only be used if `recordCollectedUsage` was not used instead.
|
||||
* @param {string} [model]
|
||||
* @param {number} promptTokens
|
||||
* @param {number} completionTokens
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async recordTokenUsage({ model, promptTokens, completionTokens }) {
|
||||
async recordTokenUsage({ promptTokens, completionTokens }) {
|
||||
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
|
||||
model,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
});
|
||||
@@ -200,10 +198,6 @@ class BaseClient {
|
||||
this.currentMessages[this.currentMessages.length - 1].messageId = head;
|
||||
}
|
||||
|
||||
if (opts.isRegenerate && responseMessageId.endsWith('_')) {
|
||||
responseMessageId = crypto.randomUUID();
|
||||
}
|
||||
|
||||
this.responseMessageId = responseMessageId;
|
||||
|
||||
return {
|
||||
@@ -578,7 +572,7 @@ class BaseClient {
|
||||
});
|
||||
}
|
||||
|
||||
const { editedContent } = opts;
|
||||
const { generation = '' } = opts;
|
||||
|
||||
// It's not necessary to push to currentMessages
|
||||
// depending on subclass implementation of handling messages
|
||||
@@ -593,21 +587,11 @@ class BaseClient {
|
||||
isCreatedByUser: false,
|
||||
model: this.modelOptions?.model ?? this.model,
|
||||
sender: this.sender,
|
||||
text: generation,
|
||||
};
|
||||
this.currentMessages.push(userMessage, latestMessage);
|
||||
} else if (editedContent != null) {
|
||||
// Handle editedContent for content parts
|
||||
if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
|
||||
const { index, text, type } = editedContent;
|
||||
if (index >= 0 && index < latestMessage.content.length) {
|
||||
const contentPart = latestMessage.content[index];
|
||||
if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
|
||||
contentPart[ContentTypes.THINK] = text;
|
||||
} else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
|
||||
contentPart[ContentTypes.TEXT] = text;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
latestMessage.text = generation;
|
||||
}
|
||||
this.continued = true;
|
||||
} else {
|
||||
@@ -688,32 +672,16 @@ class BaseClient {
|
||||
};
|
||||
|
||||
if (typeof completion === 'string') {
|
||||
responseMessage.text = completion;
|
||||
responseMessage.text = addSpaceIfNeeded(generation) + completion;
|
||||
} else if (
|
||||
Array.isArray(completion) &&
|
||||
(this.clientName === EModelEndpoint.agents ||
|
||||
isParamEndpoint(this.options.endpoint, this.options.endpointType))
|
||||
) {
|
||||
responseMessage.text = '';
|
||||
|
||||
if (!opts.editedContent || this.currentMessages.length === 0) {
|
||||
responseMessage.content = completion;
|
||||
} else {
|
||||
const latestMessage = this.currentMessages[this.currentMessages.length - 1];
|
||||
if (!latestMessage?.content) {
|
||||
responseMessage.content = completion;
|
||||
} else {
|
||||
const existingContent = [...latestMessage.content];
|
||||
const { type: editedType } = opts.editedContent;
|
||||
responseMessage.content = this.mergeEditedContent(
|
||||
existingContent,
|
||||
completion,
|
||||
editedType,
|
||||
);
|
||||
}
|
||||
}
|
||||
responseMessage.content = completion;
|
||||
} else if (Array.isArray(completion)) {
|
||||
responseMessage.text = completion.join('');
|
||||
responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -744,13 +712,9 @@ class BaseClient {
|
||||
} else {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
completionTokens = responseMessage.tokenCount;
|
||||
await this.recordTokenUsage({
|
||||
usage,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
model: responseMessage.model,
|
||||
});
|
||||
}
|
||||
|
||||
await this.recordTokenUsage({ promptTokens, completionTokens, usage });
|
||||
}
|
||||
|
||||
if (userMessagePromise) {
|
||||
@@ -828,8 +792,7 @@ class BaseClient {
|
||||
|
||||
userMessage.tokenCount = userMessageTokenCount;
|
||||
/*
|
||||
Note: `AgentController` saves the user message if not saved here
|
||||
(noted by `savedMessageIds`), so we update the count of its `userMessage` reference
|
||||
Note: `AskController` saves the user message, so we update the count of its `userMessage` reference
|
||||
*/
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
@@ -838,8 +801,7 @@ class BaseClient {
|
||||
}
|
||||
/*
|
||||
Note: we update the user message to be sure it gets the calculated token count;
|
||||
though `AgentController` saves the user message if not saved here
|
||||
(noted by `savedMessageIds`), EditController does not
|
||||
though `AskController` saves the user message, EditController does not
|
||||
*/
|
||||
await userMessagePromise;
|
||||
await this.updateMessageInDatabase({
|
||||
@@ -1131,50 +1093,6 @@ class BaseClient {
|
||||
return numTokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges completion content with existing content when editing TEXT or THINK types
|
||||
* @param {Array} existingContent - The existing content array
|
||||
* @param {Array} newCompletion - The new completion content
|
||||
* @param {string} editedType - The type of content being edited
|
||||
* @returns {Array} The merged content array
|
||||
*/
|
||||
mergeEditedContent(existingContent, newCompletion, editedType) {
|
||||
if (!newCompletion.length) {
|
||||
return existingContent.concat(newCompletion);
|
||||
}
|
||||
|
||||
if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
|
||||
return existingContent.concat(newCompletion);
|
||||
}
|
||||
|
||||
const lastIndex = existingContent.length - 1;
|
||||
const lastExisting = existingContent[lastIndex];
|
||||
const firstNew = newCompletion[0];
|
||||
|
||||
if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
|
||||
return existingContent.concat(newCompletion);
|
||||
}
|
||||
|
||||
const mergedContent = [...existingContent];
|
||||
if (editedType === ContentTypes.TEXT) {
|
||||
mergedContent[lastIndex] = {
|
||||
...mergedContent[lastIndex],
|
||||
[ContentTypes.TEXT]:
|
||||
(mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
|
||||
};
|
||||
} else {
|
||||
mergedContent[lastIndex] = {
|
||||
...mergedContent[lastIndex],
|
||||
[ContentTypes.THINK]:
|
||||
(mergedContent[lastIndex][ContentTypes.THINK] || '') +
|
||||
(firstNew[ContentTypes.THINK] || ''),
|
||||
};
|
||||
}
|
||||
|
||||
// Add remaining completion items
|
||||
return mergedContent.concat(newCompletion.slice(1));
|
||||
}
|
||||
|
||||
async sendPayload(payload, opts = {}) {
|
||||
if (opts && typeof opts === 'object') {
|
||||
this.setOptions(opts);
|
||||
|
||||
804
api/app/clients/ChatGPTClient.js
Normal file
804
api/app/clients/ChatGPTClient.js
Normal file
@@ -0,0 +1,804 @@
|
||||
const { Keyv } = require('keyv');
|
||||
const crypto = require('crypto');
|
||||
const { CohereClient } = require('cohere-ai');
|
||||
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
|
||||
const { constructAzureURL, genAzureChatCompletion } = require('@librechat/api');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const {
|
||||
ImageDetail,
|
||||
EModelEndpoint,
|
||||
resolveHeaders,
|
||||
CohereConstants,
|
||||
mapModelToAzureConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { createContextHandlers } = require('./prompts');
|
||||
const { createCoherePayload } = require('./llm');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const CHATGPT_MODEL = 'gpt-3.5-turbo';
|
||||
const tokenizersCache = {};
|
||||
|
||||
class ChatGPTClient extends BaseClient {
|
||||
constructor(apiKey, options = {}, cacheOptions = {}) {
|
||||
super(apiKey, options, cacheOptions);
|
||||
|
||||
cacheOptions.namespace = cacheOptions.namespace || 'chatgpt';
|
||||
this.conversationsCache = new Keyv(cacheOptions);
|
||||
this.setOptions(options);
|
||||
}
|
||||
|
||||
setOptions(options) {
|
||||
if (this.options && !this.options.replaceOptions) {
|
||||
// nested options aren't spread properly, so we need to do this manually
|
||||
this.options.modelOptions = {
|
||||
...this.options.modelOptions,
|
||||
...options.modelOptions,
|
||||
};
|
||||
delete options.modelOptions;
|
||||
// now we can merge options
|
||||
this.options = {
|
||||
...this.options,
|
||||
...options,
|
||||
};
|
||||
} else {
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
if (this.options.openaiApiKey) {
|
||||
this.apiKey = this.options.openaiApiKey;
|
||||
}
|
||||
|
||||
const modelOptions = this.options.modelOptions || {};
|
||||
this.modelOptions = {
|
||||
...modelOptions,
|
||||
// set some good defaults (check for undefined in some cases because they may be 0)
|
||||
model: modelOptions.model || CHATGPT_MODEL,
|
||||
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
|
||||
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
|
||||
presence_penalty:
|
||||
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
|
||||
stop: modelOptions.stop,
|
||||
};
|
||||
|
||||
this.isChatGptModel = this.modelOptions.model.includes('gpt-');
|
||||
const { isChatGptModel } = this;
|
||||
this.isUnofficialChatGptModel =
|
||||
this.modelOptions.model.startsWith('text-chat') ||
|
||||
this.modelOptions.model.startsWith('text-davinci-002-render');
|
||||
const { isUnofficialChatGptModel } = this;
|
||||
|
||||
// Davinci models have a max context length of 4097 tokens.
|
||||
this.maxContextTokens = this.options.maxContextTokens || (isChatGptModel ? 4095 : 4097);
|
||||
// I decided to reserve 1024 tokens for the response.
|
||||
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
|
||||
// Earlier messages will be dropped until the prompt is within the limit.
|
||||
this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
|
||||
this.maxPromptTokens =
|
||||
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
|
||||
|
||||
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
|
||||
throw new Error(
|
||||
`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
|
||||
this.maxPromptTokens + this.maxResponseTokens
|
||||
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
|
||||
);
|
||||
}
|
||||
|
||||
this.userLabel = this.options.userLabel || 'User';
|
||||
this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
|
||||
|
||||
if (isChatGptModel) {
|
||||
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
|
||||
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
|
||||
// without tripping the stop sequences, so I'm using "||>" instead.
|
||||
this.startToken = '||>';
|
||||
this.endToken = '';
|
||||
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
|
||||
} else if (isUnofficialChatGptModel) {
|
||||
this.startToken = '<|im_start|>';
|
||||
this.endToken = '<|im_end|>';
|
||||
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
|
||||
'<|im_start|>': 100264,
|
||||
'<|im_end|>': 100265,
|
||||
});
|
||||
} else {
|
||||
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
|
||||
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
|
||||
// as a single token. So we're using this instead.
|
||||
this.startToken = '||>';
|
||||
this.endToken = '';
|
||||
try {
|
||||
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
|
||||
} catch {
|
||||
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
|
||||
}
|
||||
}
|
||||
|
||||
if (!this.modelOptions.stop) {
|
||||
const stopTokens = [this.startToken];
|
||||
if (this.endToken && this.endToken !== this.startToken) {
|
||||
stopTokens.push(this.endToken);
|
||||
}
|
||||
stopTokens.push(`\n${this.userLabel}:`);
|
||||
stopTokens.push('<|diff_marker|>');
|
||||
// I chose not to do one for `chatGptLabel` because I've never seen it happen
|
||||
this.modelOptions.stop = stopTokens;
|
||||
}
|
||||
|
||||
if (this.options.reverseProxyUrl) {
|
||||
this.completionsUrl = this.options.reverseProxyUrl;
|
||||
} else if (isChatGptModel) {
|
||||
this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
|
||||
} else {
|
||||
this.completionsUrl = 'https://api.openai.com/v1/completions';
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
||||
if (tokenizersCache[encoding]) {
|
||||
return tokenizersCache[encoding];
|
||||
}
|
||||
let tokenizer;
|
||||
if (isModelName) {
|
||||
tokenizer = encodingForModel(encoding, extendSpecialTokens);
|
||||
} else {
|
||||
tokenizer = getEncoding(encoding, extendSpecialTokens);
|
||||
}
|
||||
tokenizersCache[encoding] = tokenizer;
|
||||
return tokenizer;
|
||||
}
|
||||
|
||||
/** @type {getCompletion} */
|
||||
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
|
||||
if (!abortController) {
|
||||
abortController = new AbortController();
|
||||
}
|
||||
|
||||
let modelOptions = { ...this.modelOptions };
|
||||
if (typeof onProgress === 'function') {
|
||||
modelOptions.stream = true;
|
||||
}
|
||||
if (this.isChatGptModel) {
|
||||
modelOptions.messages = input;
|
||||
} else {
|
||||
modelOptions.prompt = input;
|
||||
}
|
||||
|
||||
if (this.useOpenRouter && modelOptions.prompt) {
|
||||
delete modelOptions.stop;
|
||||
}
|
||||
|
||||
const { debug } = this.options;
|
||||
let baseURL = this.completionsUrl;
|
||||
if (debug) {
|
||||
console.debug();
|
||||
console.debug(baseURL);
|
||||
console.debug(modelOptions);
|
||||
console.debug();
|
||||
}
|
||||
|
||||
const opts = {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
};
|
||||
|
||||
if (this.isVisionModel) {
|
||||
modelOptions.max_tokens = 4000;
|
||||
}
|
||||
|
||||
/** @type {TAzureConfig | undefined} */
|
||||
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
|
||||
|
||||
const isAzure = this.azure || this.options.azure;
|
||||
if (
|
||||
(isAzure && this.isVisionModel && azureConfig) ||
|
||||
(azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
|
||||
) {
|
||||
const { modelGroupMap, groupMap } = azureConfig;
|
||||
const {
|
||||
azureOptions,
|
||||
baseURL,
|
||||
headers = {},
|
||||
serverless,
|
||||
} = mapModelToAzureConfig({
|
||||
modelName: modelOptions.model,
|
||||
modelGroupMap,
|
||||
groupMap,
|
||||
});
|
||||
opts.headers = resolveHeaders(headers);
|
||||
this.langchainProxy = extractBaseURL(baseURL);
|
||||
this.apiKey = azureOptions.azureOpenAIApiKey;
|
||||
|
||||
const groupName = modelGroupMap[modelOptions.model].group;
|
||||
this.options.addParams = azureConfig.groupMap[groupName].addParams;
|
||||
this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
|
||||
// Note: `forcePrompt` not re-assigned as only chat models are vision models
|
||||
|
||||
this.azure = !serverless && azureOptions;
|
||||
this.azureEndpoint =
|
||||
!serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
|
||||
if (serverless === true) {
|
||||
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
|
||||
? { 'api-version': azureOptions.azureOpenAIApiVersion }
|
||||
: undefined;
|
||||
this.options.headers['api-key'] = this.apiKey;
|
||||
}
|
||||
}
|
||||
|
||||
if (this.options.defaultQuery) {
|
||||
opts.defaultQuery = this.options.defaultQuery;
|
||||
}
|
||||
|
||||
if (this.options.headers) {
|
||||
opts.headers = { ...opts.headers, ...this.options.headers };
|
||||
}
|
||||
|
||||
if (isAzure) {
|
||||
// Azure does not accept `model` in the body, so we need to remove it.
|
||||
delete modelOptions.model;
|
||||
|
||||
baseURL = this.langchainProxy
|
||||
? constructAzureURL({
|
||||
baseURL: this.langchainProxy,
|
||||
azureOptions: this.azure,
|
||||
})
|
||||
: this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
|
||||
|
||||
if (this.options.forcePrompt) {
|
||||
baseURL += '/completions';
|
||||
} else {
|
||||
baseURL += '/chat/completions';
|
||||
}
|
||||
|
||||
opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
|
||||
opts.headers = { ...opts.headers, 'api-key': this.apiKey };
|
||||
} else if (this.apiKey) {
|
||||
opts.headers.Authorization = `Bearer ${this.apiKey}`;
|
||||
}
|
||||
|
||||
if (process.env.OPENAI_ORGANIZATION) {
|
||||
opts.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION;
|
||||
}
|
||||
|
||||
if (this.useOpenRouter) {
|
||||
opts.headers['HTTP-Referer'] = 'https://librechat.ai';
|
||||
opts.headers['X-Title'] = 'LibreChat';
|
||||
}
|
||||
|
||||
/* hacky fixes for Mistral AI API:
|
||||
- Re-orders system message to the top of the messages payload, as not allowed anywhere else
|
||||
- If there is only one message and it's a system message, change the role to user
|
||||
*/
|
||||
if (baseURL.includes('https://api.mistral.ai/v1') && modelOptions.messages) {
|
||||
const { messages } = modelOptions;
|
||||
|
||||
const systemMessageIndex = messages.findIndex((msg) => msg.role === 'system');
|
||||
|
||||
if (systemMessageIndex > 0) {
|
||||
const [systemMessage] = messages.splice(systemMessageIndex, 1);
|
||||
messages.unshift(systemMessage);
|
||||
}
|
||||
|
||||
modelOptions.messages = messages;
|
||||
|
||||
if (messages.length === 1 && messages[0].role === 'system') {
|
||||
modelOptions.messages[0].role = 'user';
|
||||
}
|
||||
}
|
||||
|
||||
if (this.options.addParams && typeof this.options.addParams === 'object') {
|
||||
modelOptions = {
|
||||
...modelOptions,
|
||||
...this.options.addParams,
|
||||
};
|
||||
logger.debug('[ChatGPTClient] chatCompletion: added params', {
|
||||
addParams: this.options.addParams,
|
||||
modelOptions,
|
||||
});
|
||||
}
|
||||
|
||||
if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
|
||||
this.options.dropParams.forEach((param) => {
|
||||
delete modelOptions[param];
|
||||
});
|
||||
logger.debug('[ChatGPTClient] chatCompletion: dropped params', {
|
||||
dropParams: this.options.dropParams,
|
||||
modelOptions,
|
||||
});
|
||||
}
|
||||
|
||||
if (baseURL.startsWith(CohereConstants.API_URL)) {
|
||||
const payload = createCoherePayload({ modelOptions });
|
||||
return await this.cohereChatCompletion({ payload, onTokenProgress });
|
||||
}
|
||||
|
||||
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
|
||||
baseURL = baseURL.split('v1')[0] + 'v1/completions';
|
||||
} else if (
|
||||
baseURL.includes('v1') &&
|
||||
!baseURL.includes('/chat/completions') &&
|
||||
this.isChatCompletion
|
||||
) {
|
||||
baseURL = baseURL.split('v1')[0] + 'v1/chat/completions';
|
||||
}
|
||||
|
||||
const BASE_URL = new URL(baseURL);
|
||||
if (opts.defaultQuery) {
|
||||
Object.entries(opts.defaultQuery).forEach(([key, value]) => {
|
||||
BASE_URL.searchParams.append(key, value);
|
||||
});
|
||||
delete opts.defaultQuery;
|
||||
}
|
||||
|
||||
const completionsURL = BASE_URL.toString();
|
||||
opts.body = JSON.stringify(modelOptions);
|
||||
|
||||
if (modelOptions.stream) {
|
||||
return new Promise(async (resolve, reject) => {
|
||||
try {
|
||||
let done = false;
|
||||
await fetchEventSource(completionsURL, {
|
||||
...opts,
|
||||
signal: abortController.signal,
|
||||
async onopen(response) {
|
||||
if (response.status === 200) {
|
||||
return;
|
||||
}
|
||||
if (debug) {
|
||||
console.debug(response);
|
||||
}
|
||||
let error;
|
||||
try {
|
||||
const body = await response.text();
|
||||
error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
|
||||
error.status = response.status;
|
||||
error.json = JSON.parse(body);
|
||||
} catch {
|
||||
error = error || new Error(`Failed to send message. HTTP ${response.status}`);
|
||||
}
|
||||
throw error;
|
||||
},
|
||||
onclose() {
|
||||
if (debug) {
|
||||
console.debug('Server closed the connection unexpectedly, returning...');
|
||||
}
|
||||
// workaround for private API not sending [DONE] event
|
||||
if (!done) {
|
||||
onProgress('[DONE]');
|
||||
resolve();
|
||||
}
|
||||
},
|
||||
onerror(err) {
|
||||
if (debug) {
|
||||
console.debug(err);
|
||||
}
|
||||
// rethrow to stop the operation
|
||||
throw err;
|
||||
},
|
||||
onmessage(message) {
|
||||
if (debug) {
|
||||
console.debug(message);
|
||||
}
|
||||
if (!message.data || message.event === 'ping') {
|
||||
return;
|
||||
}
|
||||
if (message.data === '[DONE]') {
|
||||
onProgress('[DONE]');
|
||||
resolve();
|
||||
done = true;
|
||||
return;
|
||||
}
|
||||
onProgress(JSON.parse(message.data));
|
||||
},
|
||||
});
|
||||
} catch (err) {
|
||||
reject(err);
|
||||
}
|
||||
});
|
||||
}
|
||||
const response = await fetch(completionsURL, {
|
||||
...opts,
|
||||
signal: abortController.signal,
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
const body = await response.text();
|
||||
const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
|
||||
error.status = response.status;
|
||||
try {
|
||||
error.json = JSON.parse(body);
|
||||
} catch {
|
||||
error.body = body;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/** @type {cohereChatCompletion} */
|
||||
async cohereChatCompletion({ payload, onTokenProgress }) {
|
||||
const cohere = new CohereClient({
|
||||
token: this.apiKey,
|
||||
environment: this.completionsUrl,
|
||||
});
|
||||
|
||||
if (!payload.stream) {
|
||||
const chatResponse = await cohere.chat(payload);
|
||||
return chatResponse.text;
|
||||
}
|
||||
|
||||
const chatStream = await cohere.chatStream(payload);
|
||||
let reply = '';
|
||||
for await (const message of chatStream) {
|
||||
if (!message) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (message.eventType === 'text-generation' && message.text) {
|
||||
onTokenProgress(message.text);
|
||||
reply += message.text;
|
||||
}
|
||||
/*
|
||||
Cohere API Chinese Unicode character replacement hotfix.
|
||||
Should be un-commented when the following issue is resolved:
|
||||
https://github.com/cohere-ai/cohere-typescript/issues/151
|
||||
|
||||
else if (message.eventType === 'stream-end' && message.response) {
|
||||
reply = message.response.text;
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
return reply;
|
||||
}
|
||||
|
||||
async generateTitle(userMessage, botMessage) {
|
||||
const instructionsPayload = {
|
||||
role: 'system',
|
||||
content: `Write an extremely concise subtitle for this conversation with no more than a few words. All words should be capitalized. Exclude punctuation.
|
||||
|
||||
||>Message:
|
||||
${userMessage.message}
|
||||
||>Response:
|
||||
${botMessage.message}
|
||||
|
||||
||>Title:`,
|
||||
};
|
||||
|
||||
const titleGenClientOptions = JSON.parse(JSON.stringify(this.options));
|
||||
titleGenClientOptions.modelOptions = {
|
||||
model: 'gpt-3.5-turbo',
|
||||
temperature: 0,
|
||||
presence_penalty: 0,
|
||||
frequency_penalty: 0,
|
||||
};
|
||||
const titleGenClient = new ChatGPTClient(this.apiKey, titleGenClientOptions);
|
||||
const result = await titleGenClient.getCompletion([instructionsPayload], null);
|
||||
// remove any non-alphanumeric characters, replace multiple spaces with 1, and then trim
|
||||
return result.choices[0].message.content
|
||||
.replace(/[^a-zA-Z0-9' ]/g, '')
|
||||
.replace(/\s+/g, ' ')
|
||||
.trim();
|
||||
}
|
||||
|
||||
async sendMessage(message, opts = {}) {
|
||||
if (opts.clientOptions && typeof opts.clientOptions === 'object') {
|
||||
this.setOptions(opts.clientOptions);
|
||||
}
|
||||
|
||||
const conversationId = opts.conversationId || crypto.randomUUID();
|
||||
const parentMessageId = opts.parentMessageId || crypto.randomUUID();
|
||||
|
||||
let conversation =
|
||||
typeof opts.conversation === 'object'
|
||||
? opts.conversation
|
||||
: await this.conversationsCache.get(conversationId);
|
||||
|
||||
let isNewConversation = false;
|
||||
if (!conversation) {
|
||||
conversation = {
|
||||
messages: [],
|
||||
createdAt: Date.now(),
|
||||
};
|
||||
isNewConversation = true;
|
||||
}
|
||||
|
||||
const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation;
|
||||
|
||||
const userMessage = {
|
||||
id: crypto.randomUUID(),
|
||||
parentMessageId,
|
||||
role: 'User',
|
||||
message,
|
||||
};
|
||||
conversation.messages.push(userMessage);
|
||||
|
||||
// Doing it this way instead of having each message be a separate element in the array seems to be more reliable,
|
||||
// especially when it comes to keeping the AI in character. It also seems to improve coherency and context retention.
|
||||
const { prompt: payload, context } = await this.buildPrompt(
|
||||
conversation.messages,
|
||||
userMessage.id,
|
||||
{
|
||||
isChatGptModel: this.isChatGptModel,
|
||||
promptPrefix: opts.promptPrefix,
|
||||
},
|
||||
);
|
||||
|
||||
if (this.options.keepNecessaryMessagesOnly) {
|
||||
conversation.messages = context;
|
||||
}
|
||||
|
||||
let reply = '';
|
||||
let result = null;
|
||||
if (typeof opts.onProgress === 'function') {
|
||||
await this.getCompletion(
|
||||
payload,
|
||||
(progressMessage) => {
|
||||
if (progressMessage === '[DONE]') {
|
||||
return;
|
||||
}
|
||||
const token = this.isChatGptModel
|
||||
? progressMessage.choices[0].delta.content
|
||||
: progressMessage.choices[0].text;
|
||||
// first event's delta content is always undefined
|
||||
if (!token) {
|
||||
return;
|
||||
}
|
||||
if (this.options.debug) {
|
||||
console.debug(token);
|
||||
}
|
||||
if (token === this.endToken) {
|
||||
return;
|
||||
}
|
||||
opts.onProgress(token);
|
||||
reply += token;
|
||||
},
|
||||
opts.abortController || new AbortController(),
|
||||
);
|
||||
} else {
|
||||
result = await this.getCompletion(
|
||||
payload,
|
||||
null,
|
||||
opts.abortController || new AbortController(),
|
||||
);
|
||||
if (this.options.debug) {
|
||||
console.debug(JSON.stringify(result));
|
||||
}
|
||||
if (this.isChatGptModel) {
|
||||
reply = result.choices[0].message.content;
|
||||
} else {
|
||||
reply = result.choices[0].text.replace(this.endToken, '');
|
||||
}
|
||||
}
|
||||
|
||||
// avoids some rendering issues when using the CLI app
|
||||
if (this.options.debug) {
|
||||
console.debug();
|
||||
}
|
||||
|
||||
reply = reply.trim();
|
||||
|
||||
const replyMessage = {
|
||||
id: crypto.randomUUID(),
|
||||
parentMessageId: userMessage.id,
|
||||
role: 'ChatGPT',
|
||||
message: reply,
|
||||
};
|
||||
conversation.messages.push(replyMessage);
|
||||
|
||||
const returnData = {
|
||||
response: replyMessage.message,
|
||||
conversationId,
|
||||
parentMessageId: replyMessage.parentMessageId,
|
||||
messageId: replyMessage.id,
|
||||
details: result || {},
|
||||
};
|
||||
|
||||
if (shouldGenerateTitle) {
|
||||
conversation.title = await this.generateTitle(userMessage, replyMessage);
|
||||
returnData.title = conversation.title;
|
||||
}
|
||||
|
||||
await this.conversationsCache.set(conversationId, conversation);
|
||||
|
||||
if (this.options.returnConversation) {
|
||||
returnData.conversation = conversation;
|
||||
}
|
||||
|
||||
return returnData;
|
||||
}
|
||||
|
||||
async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) {
|
||||
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
|
||||
|
||||
// Handle attachments and create augmentedPrompt
|
||||
if (this.options.attachments) {
|
||||
const attachments = await this.options.attachments;
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
|
||||
if (this.message_file_map) {
|
||||
this.message_file_map[lastMessage.messageId] = attachments;
|
||||
} else {
|
||||
this.message_file_map = {
|
||||
[lastMessage.messageId]: attachments,
|
||||
};
|
||||
}
|
||||
|
||||
const files = await this.addImageURLs(lastMessage, attachments);
|
||||
this.options.attachments = files;
|
||||
|
||||
this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text);
|
||||
}
|
||||
|
||||
if (this.message_file_map) {
|
||||
this.contextHandlers = createContextHandlers(
|
||||
this.options.req,
|
||||
messages[messages.length - 1].text,
|
||||
);
|
||||
}
|
||||
|
||||
// Calculate image token cost and process embedded files
|
||||
messages.forEach((message, i) => {
|
||||
if (this.message_file_map && this.message_file_map[message.messageId]) {
|
||||
const attachments = this.message_file_map[message.messageId];
|
||||
for (const file of attachments) {
|
||||
if (file.embedded) {
|
||||
this.contextHandlers?.processFile(file);
|
||||
continue;
|
||||
}
|
||||
|
||||
messages[i].tokenCount =
|
||||
(messages[i].tokenCount || 0) +
|
||||
this.calculateImageTokenCost({
|
||||
width: file.width,
|
||||
height: file.height,
|
||||
detail: this.options.imageDetail ?? ImageDetail.auto,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (this.contextHandlers) {
|
||||
this.augmentedPrompt = await this.contextHandlers.createContext();
|
||||
promptPrefix = this.augmentedPrompt + promptPrefix;
|
||||
}
|
||||
|
||||
if (promptPrefix) {
|
||||
// If the prompt prefix doesn't end with the end token, add it.
|
||||
if (!promptPrefix.endsWith(`${this.endToken}`)) {
|
||||
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
|
||||
}
|
||||
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
|
||||
}
|
||||
const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond.
|
||||
|
||||
const instructionsPayload = {
|
||||
role: 'system',
|
||||
content: promptPrefix,
|
||||
};
|
||||
|
||||
const messagePayload = {
|
||||
role: 'system',
|
||||
content: promptSuffix,
|
||||
};
|
||||
|
||||
let currentTokenCount;
|
||||
if (isChatGptModel) {
|
||||
currentTokenCount =
|
||||
this.getTokenCountForMessage(instructionsPayload) +
|
||||
this.getTokenCountForMessage(messagePayload);
|
||||
} else {
|
||||
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
|
||||
}
|
||||
let promptBody = '';
|
||||
const maxTokenCount = this.maxPromptTokens;
|
||||
|
||||
const context = [];
|
||||
|
||||
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
|
||||
// Do this within a recursive async function so that it doesn't block the event loop for too long.
|
||||
const buildPromptBody = async () => {
|
||||
if (currentTokenCount < maxTokenCount && messages.length > 0) {
|
||||
const message = messages.pop();
|
||||
const roleLabel =
|
||||
message?.isCreatedByUser || message?.role?.toLowerCase() === 'user'
|
||||
? this.userLabel
|
||||
: this.chatGptLabel;
|
||||
const messageString = `${this.startToken}${roleLabel}:\n${
|
||||
message?.text ?? message?.message
|
||||
}${this.endToken}\n`;
|
||||
let newPromptBody;
|
||||
if (promptBody || isChatGptModel) {
|
||||
newPromptBody = `${messageString}${promptBody}`;
|
||||
} else {
|
||||
// Always insert prompt prefix before the last user message, if not gpt-3.5-turbo.
|
||||
// This makes the AI obey the prompt instructions better, which is important for custom instructions.
|
||||
// After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things
|
||||
// like "what's the last thing I wrote?".
|
||||
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
|
||||
}
|
||||
|
||||
context.unshift(message);
|
||||
|
||||
const tokenCountForMessage = this.getTokenCount(messageString);
|
||||
const newTokenCount = currentTokenCount + tokenCountForMessage;
|
||||
if (newTokenCount > maxTokenCount) {
|
||||
if (promptBody) {
|
||||
// This message would put us over the token limit, so don't add it.
|
||||
return false;
|
||||
}
|
||||
// This is the first message, so we can't add it. Just throw an error.
|
||||
throw new Error(
|
||||
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
|
||||
);
|
||||
}
|
||||
promptBody = newPromptBody;
|
||||
currentTokenCount = newTokenCount;
|
||||
// wait for next tick to avoid blocking the event loop
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
return buildPromptBody();
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
await buildPromptBody();
|
||||
|
||||
const prompt = `${promptBody}${promptSuffix}`;
|
||||
if (isChatGptModel) {
|
||||
messagePayload.content = prompt;
|
||||
// Add 3 tokens for Assistant Label priming after all messages have been counted.
|
||||
currentTokenCount += 3;
|
||||
}
|
||||
|
||||
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
|
||||
this.modelOptions.max_tokens = Math.min(
|
||||
this.maxContextTokens - currentTokenCount,
|
||||
this.maxResponseTokens,
|
||||
);
|
||||
|
||||
if (isChatGptModel) {
|
||||
return { prompt: [instructionsPayload, messagePayload], context };
|
||||
}
|
||||
return { prompt, context, promptTokens: currentTokenCount };
|
||||
}
|
||||
|
||||
getTokenCount(text) {
|
||||
return this.gptEncoder.encode(text, 'all').length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Algorithm adapted from "6. Counting tokens for chat API calls" of
|
||||
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
*
|
||||
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
|
||||
*
|
||||
* @param {Object} message
|
||||
*/
|
||||
getTokenCountForMessage(message) {
|
||||
// Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
|
||||
let tokensPerMessage = 3;
|
||||
let tokensPerName = 1;
|
||||
|
||||
if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
|
||||
tokensPerMessage = 4;
|
||||
tokensPerName = -1;
|
||||
}
|
||||
|
||||
let numTokens = tokensPerMessage;
|
||||
for (let [key, value] of Object.entries(message)) {
|
||||
numTokens += this.getTokenCount(value);
|
||||
if (key === 'name') {
|
||||
numTokens += tokensPerName;
|
||||
}
|
||||
}
|
||||
|
||||
return numTokens;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = ChatGPTClient;
|
||||
@@ -1,7 +1,7 @@
|
||||
const { google } = require('googleapis');
|
||||
const { Tokenizer } = require('@librechat/api');
|
||||
const { concat } = require('@langchain/core/utils/stream');
|
||||
const { ChatVertexAI } = require('@langchain/google-vertexai');
|
||||
const { Tokenizer, getSafetySettings } = require('@librechat/api');
|
||||
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
|
||||
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
|
||||
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
|
||||
@@ -12,13 +12,13 @@ const {
|
||||
endpointSettings,
|
||||
parseTextParts,
|
||||
EModelEndpoint,
|
||||
googleSettings,
|
||||
ContentTypes,
|
||||
VisionModes,
|
||||
ErrorTypes,
|
||||
Constants,
|
||||
AuthKeys,
|
||||
} = require('librechat-data-provider');
|
||||
const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
@@ -166,16 +166,6 @@ class GoogleClient extends BaseClient {
|
||||
);
|
||||
}
|
||||
|
||||
// Add thinking configuration
|
||||
this.modelOptions.thinkingConfig = {
|
||||
thinkingBudget:
|
||||
(this.modelOptions.thinking ?? googleSettings.thinking.default)
|
||||
? this.modelOptions.thinkingBudget
|
||||
: 0,
|
||||
};
|
||||
delete this.modelOptions.thinking;
|
||||
delete this.modelOptions.thinkingBudget;
|
||||
|
||||
this.sender =
|
||||
this.options.sender ??
|
||||
getResponseSender({
|
||||
|
||||
@@ -5,7 +5,6 @@ const {
|
||||
isEnabled,
|
||||
Tokenizer,
|
||||
createFetch,
|
||||
resolveHeaders,
|
||||
constructAzureURL,
|
||||
genAzureChatCompletion,
|
||||
createStreamEventHandlers,
|
||||
@@ -16,6 +15,7 @@ const {
|
||||
ContentTypes,
|
||||
parseTextParts,
|
||||
EModelEndpoint,
|
||||
resolveHeaders,
|
||||
KnownEndpoints,
|
||||
openAISettings,
|
||||
ImageDetailCost,
|
||||
@@ -37,6 +37,7 @@ const { addSpaceIfNeeded, sleep } = require('~/server/utils');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const { createLLM, RunManager } = require('./llm');
|
||||
const ChatGPTClient = require('./ChatGPTClient');
|
||||
const { summaryBuffer } = require('./memory');
|
||||
const { runTitleChain } = require('./chains');
|
||||
const { tokenSplit } = require('./document');
|
||||
@@ -46,6 +47,12 @@ const { logger } = require('~/config');
|
||||
class OpenAIClient extends BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
super(apiKey, options);
|
||||
this.ChatGPTClient = new ChatGPTClient();
|
||||
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
|
||||
/** @type {getCompletion} */
|
||||
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
|
||||
/** @type {cohereChatCompletion} */
|
||||
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
|
||||
this.contextStrategy = options.contextStrategy
|
||||
? options.contextStrategy.toLowerCase()
|
||||
: 'discard';
|
||||
@@ -372,12 +379,23 @@ class OpenAIClient extends BaseClient {
|
||||
return files;
|
||||
}
|
||||
|
||||
async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) {
|
||||
async buildMessages(
|
||||
messages,
|
||||
parentMessageId,
|
||||
{ isChatCompletion = false, promptPrefix = null },
|
||||
opts,
|
||||
) {
|
||||
let orderedMessages = this.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId,
|
||||
summary: this.shouldSummarize,
|
||||
});
|
||||
if (!isChatCompletion) {
|
||||
return await this.buildPrompt(orderedMessages, {
|
||||
isChatGptModel: isChatCompletion,
|
||||
promptPrefix,
|
||||
});
|
||||
}
|
||||
|
||||
let payload;
|
||||
let instructions;
|
||||
@@ -653,10 +671,8 @@ class OpenAIClient extends BaseClient {
|
||||
if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
|
||||
configOptions.baseOptions = {
|
||||
headers: resolveHeaders({
|
||||
headers: {
|
||||
...headers,
|
||||
...configOptions?.baseOptions?.headers,
|
||||
},
|
||||
...headers,
|
||||
...configOptions?.baseOptions?.headers,
|
||||
}),
|
||||
};
|
||||
}
|
||||
@@ -751,7 +767,7 @@ class OpenAIClient extends BaseClient {
|
||||
groupMap,
|
||||
});
|
||||
|
||||
this.options.headers = resolveHeaders({ headers });
|
||||
this.options.headers = resolveHeaders(headers);
|
||||
this.options.reverseProxyUrl = baseURL ?? null;
|
||||
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
|
||||
this.apiKey = azureOptions.azureOpenAIApiKey;
|
||||
@@ -1183,7 +1199,7 @@ ${convo}
|
||||
modelGroupMap,
|
||||
groupMap,
|
||||
});
|
||||
opts.defaultHeaders = resolveHeaders({ headers });
|
||||
opts.defaultHeaders = resolveHeaders(headers);
|
||||
this.langchainProxy = extractBaseURL(baseURL);
|
||||
this.apiKey = azureOptions.azureOpenAIApiKey;
|
||||
|
||||
@@ -1224,9 +1240,7 @@ ${convo}
|
||||
}
|
||||
|
||||
if (this.isOmni === true && modelOptions.max_tokens != null) {
|
||||
const paramName =
|
||||
modelOptions.useResponsesApi === true ? 'max_output_tokens' : 'max_completion_tokens';
|
||||
modelOptions[paramName] = modelOptions.max_tokens;
|
||||
modelOptions.max_completion_tokens = modelOptions.max_tokens;
|
||||
delete modelOptions.max_tokens;
|
||||
}
|
||||
if (this.isOmni === true && modelOptions.temperature != null) {
|
||||
|
||||
542
api/app/clients/PluginsClient.js
Normal file
542
api/app/clients/PluginsClient.js
Normal file
@@ -0,0 +1,542 @@
|
||||
const OpenAIClient = require('./OpenAIClient');
|
||||
const { CallbackManager } = require('@langchain/core/callbacks/manager');
|
||||
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
|
||||
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
|
||||
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
|
||||
const { processFileURL } = require('~/server/services/Files/process');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { formatLangChainMessages } = require('./prompts');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const { loadTools } = require('./tools/util');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
class PluginsClient extends OpenAIClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
super(apiKey, options);
|
||||
this.sender = options.sender ?? 'Assistant';
|
||||
this.tools = [];
|
||||
this.actions = [];
|
||||
this.setOptions(options);
|
||||
this.openAIApiKey = this.apiKey;
|
||||
this.executor = null;
|
||||
}
|
||||
|
||||
setOptions(options) {
|
||||
this.agentOptions = { ...options.agentOptions };
|
||||
this.functionsAgent = this.agentOptions?.agent === 'functions';
|
||||
this.agentIsGpt3 = this.agentOptions?.model?.includes('gpt-3');
|
||||
|
||||
super.setOptions(options);
|
||||
|
||||
this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');
|
||||
|
||||
if (this.options.reverseProxyUrl) {
|
||||
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
|
||||
}
|
||||
}
|
||||
|
||||
getSaveOptions() {
|
||||
return {
|
||||
artifacts: this.options.artifacts,
|
||||
chatGptLabel: this.options.chatGptLabel,
|
||||
modelLabel: this.options.modelLabel,
|
||||
promptPrefix: this.options.promptPrefix,
|
||||
tools: this.options.tools,
|
||||
...this.modelOptions,
|
||||
agentOptions: this.agentOptions,
|
||||
iconURL: this.options.iconURL,
|
||||
greeting: this.options.greeting,
|
||||
spec: this.options.spec,
|
||||
};
|
||||
}
|
||||
|
||||
saveLatestAction(action) {
|
||||
this.actions.push(action);
|
||||
}
|
||||
|
||||
getFunctionModelName(input) {
|
||||
if (/-(?!0314)\d{4}/.test(input)) {
|
||||
return input;
|
||||
} else if (input.includes('gpt-3.5-turbo')) {
|
||||
return 'gpt-3.5-turbo';
|
||||
} else if (input.includes('gpt-4')) {
|
||||
return 'gpt-4';
|
||||
} else {
|
||||
return 'gpt-3.5-turbo';
|
||||
}
|
||||
}
|
||||
|
||||
getBuildMessagesOptions(opts) {
|
||||
return {
|
||||
isChatCompletion: true,
|
||||
promptPrefix: opts.promptPrefix,
|
||||
abortController: opts.abortController,
|
||||
};
|
||||
}
|
||||
|
||||
async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
|
||||
const modelOptions = {
|
||||
modelName: this.agentOptions.model,
|
||||
temperature: this.agentOptions.temperature,
|
||||
};
|
||||
|
||||
const model = this.initializeLLM({
|
||||
...modelOptions,
|
||||
context: 'plugins',
|
||||
initialMessageCount: this.currentMessages.length + 1,
|
||||
});
|
||||
|
||||
logger.debug(
|
||||
`[PluginsClient] Agent Model: ${model.modelName} | Temp: ${model.temperature} | Functions: ${this.functionsAgent}`,
|
||||
);
|
||||
|
||||
// Map Messages to Langchain format
|
||||
const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), {
|
||||
userName: this.options?.name,
|
||||
});
|
||||
logger.debug('[PluginsClient] pastMessages: ' + pastMessages.length);
|
||||
|
||||
// TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS)
|
||||
const memory = new BufferMemory({
|
||||
llm: model,
|
||||
chatHistory: new ChatMessageHistory(pastMessages),
|
||||
});
|
||||
|
||||
const { loadedTools } = await loadTools({
|
||||
user,
|
||||
model,
|
||||
tools: this.options.tools,
|
||||
functions: this.functionsAgent,
|
||||
options: {
|
||||
memory,
|
||||
signal: this.abortController.signal,
|
||||
openAIApiKey: this.openAIApiKey,
|
||||
conversationId: this.conversationId,
|
||||
fileStrategy: this.options.req.app.locals.fileStrategy,
|
||||
processFileURL,
|
||||
message,
|
||||
},
|
||||
useSpecs: true,
|
||||
});
|
||||
|
||||
if (loadedTools.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.tools = loadedTools;
|
||||
|
||||
logger.debug('[PluginsClient] Requested Tools', this.options.tools);
|
||||
logger.debug(
|
||||
'[PluginsClient] Loaded Tools',
|
||||
this.tools.map((tool) => tool.name),
|
||||
);
|
||||
|
||||
const handleAction = (action, runId, callback = null) => {
|
||||
this.saveLatestAction(action);
|
||||
|
||||
logger.debug('[PluginsClient] Latest Agent Action ', this.actions[this.actions.length - 1]);
|
||||
|
||||
if (typeof callback === 'function') {
|
||||
callback(action, runId);
|
||||
}
|
||||
};
|
||||
|
||||
// initialize agent
|
||||
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
|
||||
|
||||
let customInstructions = (this.options.promptPrefix ?? '').trim();
|
||||
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
|
||||
customInstructions = `${customInstructions ?? ''}\n${this.options.artifactsPrompt}`.trim();
|
||||
}
|
||||
|
||||
this.executor = await initializer({
|
||||
model,
|
||||
signal,
|
||||
pastMessages,
|
||||
tools: this.tools,
|
||||
customInstructions,
|
||||
verbose: this.options.debug,
|
||||
returnIntermediateSteps: true,
|
||||
customName: this.options.chatGptLabel,
|
||||
currentDateString: this.currentDateString,
|
||||
callbackManager: CallbackManager.fromHandlers({
|
||||
async handleAgentAction(action, runId) {
|
||||
handleAction(action, runId, onAgentAction);
|
||||
},
|
||||
async handleChainEnd(action) {
|
||||
if (typeof onChainEnd === 'function') {
|
||||
onChainEnd(action);
|
||||
}
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
logger.debug('[PluginsClient] Loaded agent.');
|
||||
}
|
||||
|
||||
async executorCall(message, { signal, stream, onToolStart, onToolEnd }) {
|
||||
let errorMessage = '';
|
||||
const maxAttempts = 1;
|
||||
|
||||
for (let attempts = 1; attempts <= maxAttempts; attempts++) {
|
||||
const errorInput = buildErrorInput({
|
||||
message,
|
||||
errorMessage,
|
||||
actions: this.actions,
|
||||
functionsAgent: this.functionsAgent,
|
||||
});
|
||||
const input = attempts > 1 ? errorInput : message;
|
||||
|
||||
logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`);
|
||||
|
||||
if (errorMessage.length > 0) {
|
||||
logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));
|
||||
}
|
||||
|
||||
try {
|
||||
this.result = await this.executor.call({ input, signal }, [
|
||||
{
|
||||
async handleToolStart(...args) {
|
||||
await onToolStart(...args);
|
||||
},
|
||||
async handleToolEnd(...args) {
|
||||
await onToolEnd(...args);
|
||||
},
|
||||
async handleLLMEnd(output) {
|
||||
const { generations } = output;
|
||||
const { text } = generations[0][0];
|
||||
if (text && typeof stream === 'function') {
|
||||
await stream(text);
|
||||
}
|
||||
},
|
||||
},
|
||||
]);
|
||||
break; // Exit the loop if the function call is successful
|
||||
} catch (err) {
|
||||
logger.error('[PluginsClient] executorCall error:', err);
|
||||
if (attempts === maxAttempts) {
|
||||
const { run } = this.runManager.getRunByConversationId(this.conversationId);
|
||||
const defaultOutput = `Encountered an error while attempting to respond: ${err.message}`;
|
||||
this.result.output = run && run.error ? run.error : defaultOutput;
|
||||
this.result.errorMessage = run && run.error ? run.error : err.message;
|
||||
this.result.intermediateSteps = this.actions;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage} responseMessage
|
||||
* @param {Partial<TMessage>} saveOptions
|
||||
* @param {string} user
|
||||
* @returns
|
||||
*/
|
||||
async handleResponseMessage(responseMessage, saveOptions, user) {
|
||||
const { output, errorMessage, ...result } = this.result;
|
||||
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
|
||||
output,
|
||||
errorMessage,
|
||||
...result,
|
||||
});
|
||||
const { error } = responseMessage;
|
||||
if (!error) {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
responseMessage.completionTokens = this.getTokenCount(responseMessage.text);
|
||||
}
|
||||
|
||||
// Record usage only when completion is skipped as it is already recorded in the agent phase.
|
||||
if (!this.agentOptions.skipCompletion && !error) {
|
||||
await this.recordTokenUsage(responseMessage);
|
||||
}
|
||||
|
||||
const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
delete responseMessage.tokenCount;
|
||||
return { ...responseMessage, ...result, databasePromise };
|
||||
}
|
||||
|
||||
async sendMessage(message, opts = {}) {
|
||||
/** @type {Promise<TMessage>} */
|
||||
let userMessagePromise;
|
||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
|
||||
this.options.tools = tools;
|
||||
} else {
|
||||
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
|
||||
this.options.tools = tools;
|
||||
}
|
||||
|
||||
// If a message is edited, no tools can be used.
|
||||
const completionMode = this.options.tools.length === 0 || opts.isEdited;
|
||||
if (completionMode) {
|
||||
this.setOptions(opts);
|
||||
return super.sendMessage(message, opts);
|
||||
}
|
||||
|
||||
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
|
||||
const {
|
||||
user,
|
||||
conversationId,
|
||||
responseMessageId,
|
||||
saveOptions,
|
||||
userMessage,
|
||||
onAgentAction,
|
||||
onChainEnd,
|
||||
onToolStart,
|
||||
onToolEnd,
|
||||
} = await this.handleStartMethods(message, opts);
|
||||
|
||||
if (opts.progressCallback) {
|
||||
opts.onProgress = opts.progressCallback.call(null, {
|
||||
...(opts.progressOptions ?? {}),
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
}
|
||||
|
||||
this.currentMessages.push(userMessage);
|
||||
|
||||
let {
|
||||
prompt: payload,
|
||||
tokenCountMap,
|
||||
promptTokens,
|
||||
} = await this.buildMessages(
|
||||
this.currentMessages,
|
||||
userMessage.messageId,
|
||||
this.getBuildMessagesOptions({
|
||||
promptPrefix: null,
|
||||
abortController: this.abortController,
|
||||
}),
|
||||
);
|
||||
|
||||
if (tokenCountMap) {
|
||||
logger.debug('[PluginsClient] tokenCountMap', { tokenCountMap });
|
||||
if (tokenCountMap[userMessage.messageId]) {
|
||||
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
|
||||
logger.debug('[PluginsClient] userMessage.tokenCount', userMessage.tokenCount);
|
||||
}
|
||||
this.handleTokenCountMap(tokenCountMap);
|
||||
}
|
||||
|
||||
this.result = {};
|
||||
if (payload) {
|
||||
this.currentMessages = payload;
|
||||
}
|
||||
|
||||
if (!this.skipSaveUserMessage) {
|
||||
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessagePromise,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const balance = this.options.req?.app?.locals?.balance;
|
||||
if (balance?.enabled) {
|
||||
await checkBalance({
|
||||
req: this.options.req,
|
||||
res: this.options.res,
|
||||
txData: {
|
||||
user: this.user,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
debug: this.options.debug,
|
||||
model: this.modelOptions.model,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const responseMessage = {
|
||||
endpoint: EModelEndpoint.gptPlugins,
|
||||
iconURL: this.options.iconURL,
|
||||
messageId: responseMessageId,
|
||||
conversationId,
|
||||
parentMessageId: userMessage.messageId,
|
||||
isCreatedByUser: false,
|
||||
model: this.modelOptions.model,
|
||||
sender: this.sender,
|
||||
promptTokens,
|
||||
};
|
||||
|
||||
await this.initialize({
|
||||
user,
|
||||
message,
|
||||
onAgentAction,
|
||||
onChainEnd,
|
||||
signal: this.abortController.signal,
|
||||
onProgress: opts.onProgress,
|
||||
});
|
||||
|
||||
// const stream = async (text) => {
|
||||
// await this.generateTextStream.call(this, text, opts.onProgress, { delay: 1 });
|
||||
// };
|
||||
await this.executorCall(message, {
|
||||
signal: this.abortController.signal,
|
||||
// stream,
|
||||
onToolStart,
|
||||
onToolEnd,
|
||||
});
|
||||
|
||||
// If message was aborted mid-generation
|
||||
if (this.result?.errorMessage?.length > 0 && this.result?.errorMessage?.includes('cancel')) {
|
||||
responseMessage.text = 'Cancelled.';
|
||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
// If error occurred during generation (likely token_balance)
|
||||
if (this.result?.errorMessage?.length > 0) {
|
||||
responseMessage.error = true;
|
||||
responseMessage.text = this.result.output;
|
||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) {
|
||||
const partialText = opts.getPartialText();
|
||||
const trimmedPartial = opts.getPartialText().replaceAll(':::plugin:::\n', '');
|
||||
responseMessage.text =
|
||||
trimmedPartial.length === 0 ? `${partialText}${this.result.output}` : partialText;
|
||||
addImages(this.result.intermediateSteps, responseMessage);
|
||||
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
|
||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
if (this.agentOptions.skipCompletion && this.result.output) {
|
||||
responseMessage.text = this.result.output;
|
||||
addImages(this.result.intermediateSteps, responseMessage);
|
||||
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
|
||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
logger.debug('[PluginsClient] Completion phase: this.result', this.result);
|
||||
|
||||
const promptPrefix = buildPromptPrefix({
|
||||
result: this.result,
|
||||
message,
|
||||
functionsAgent: this.functionsAgent,
|
||||
});
|
||||
|
||||
logger.debug('[PluginsClient]', { promptPrefix });
|
||||
|
||||
payload = await this.buildCompletionPrompt({
|
||||
messages: this.currentMessages,
|
||||
promptPrefix,
|
||||
});
|
||||
|
||||
logger.debug('[PluginsClient] buildCompletionPrompt Payload', payload);
|
||||
responseMessage.text = await this.sendCompletion(payload, opts);
|
||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
async buildCompletionPrompt({ messages, promptPrefix: _promptPrefix }) {
|
||||
logger.debug('[PluginsClient] buildCompletionPrompt messages', messages);
|
||||
|
||||
const orderedMessages = messages;
|
||||
let promptPrefix = _promptPrefix.trim();
|
||||
// If the prompt prefix doesn't end with the end token, add it.
|
||||
if (!promptPrefix.endsWith(`${this.endToken}`)) {
|
||||
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
|
||||
}
|
||||
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
|
||||
const promptSuffix = `${this.startToken}${this.chatGptLabel ?? 'Assistant'}:\n`;
|
||||
|
||||
const instructionsPayload = {
|
||||
role: 'system',
|
||||
content: promptPrefix,
|
||||
};
|
||||
|
||||
const messagePayload = {
|
||||
role: 'system',
|
||||
content: promptSuffix,
|
||||
};
|
||||
|
||||
if (this.isGpt3) {
|
||||
instructionsPayload.role = 'user';
|
||||
messagePayload.role = 'user';
|
||||
instructionsPayload.content += `\n${promptSuffix}`;
|
||||
}
|
||||
|
||||
// testing if this works with browser endpoint
|
||||
if (!this.isGpt3 && this.options.reverseProxyUrl) {
|
||||
instructionsPayload.role = 'user';
|
||||
}
|
||||
|
||||
let currentTokenCount =
|
||||
this.getTokenCountForMessage(instructionsPayload) +
|
||||
this.getTokenCountForMessage(messagePayload);
|
||||
|
||||
let promptBody = '';
|
||||
const maxTokenCount = this.maxPromptTokens;
|
||||
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
|
||||
// Do this within a recursive async function so that it doesn't block the event loop for too long.
|
||||
const buildPromptBody = async () => {
|
||||
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
|
||||
const message = orderedMessages.pop();
|
||||
const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
|
||||
const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
|
||||
let messageString = `${this.startToken}${roleLabel}:\n${
|
||||
message.text ?? message.content ?? ''
|
||||
}${this.endToken}\n`;
|
||||
let newPromptBody = `${messageString}${promptBody}`;
|
||||
|
||||
const tokenCountForMessage = this.getTokenCount(messageString);
|
||||
const newTokenCount = currentTokenCount + tokenCountForMessage;
|
||||
if (newTokenCount > maxTokenCount) {
|
||||
if (promptBody) {
|
||||
// This message would put us over the token limit, so don't add it.
|
||||
return false;
|
||||
}
|
||||
// This is the first message, so we can't add it. Just throw an error.
|
||||
throw new Error(
|
||||
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
|
||||
);
|
||||
}
|
||||
promptBody = newPromptBody;
|
||||
currentTokenCount = newTokenCount;
|
||||
// wait for next tick to avoid blocking the event loop
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
return buildPromptBody();
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
await buildPromptBody();
|
||||
const prompt = promptBody;
|
||||
messagePayload.content = prompt;
|
||||
// Add 2 tokens for metadata after all messages have been counted.
|
||||
currentTokenCount += 2;
|
||||
|
||||
if (this.isGpt3 && messagePayload.content.length > 0) {
|
||||
const context = 'Chat History:\n';
|
||||
messagePayload.content = `${context}${prompt}`;
|
||||
currentTokenCount += this.getTokenCount(context);
|
||||
}
|
||||
|
||||
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
|
||||
this.modelOptions.max_tokens = Math.min(
|
||||
this.maxContextTokens - currentTokenCount,
|
||||
this.maxResponseTokens,
|
||||
);
|
||||
|
||||
if (this.isGpt3) {
|
||||
messagePayload.content += promptSuffix;
|
||||
return [instructionsPayload, messagePayload];
|
||||
}
|
||||
|
||||
const result = [messagePayload, instructionsPayload];
|
||||
|
||||
if (this.functionsAgent && !this.isGpt3) {
|
||||
result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`;
|
||||
}
|
||||
|
||||
return result.filter((message) => message.content.length > 0);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = PluginsClient;
|
||||
@@ -1,11 +1,15 @@
|
||||
const ChatGPTClient = require('./ChatGPTClient');
|
||||
const OpenAIClient = require('./OpenAIClient');
|
||||
const PluginsClient = require('./PluginsClient');
|
||||
const GoogleClient = require('./GoogleClient');
|
||||
const TextStream = require('./TextStream');
|
||||
const AnthropicClient = require('./AnthropicClient');
|
||||
const toolUtils = require('./tools/util');
|
||||
|
||||
module.exports = {
|
||||
ChatGPTClient,
|
||||
OpenAIClient,
|
||||
PluginsClient,
|
||||
GoogleClient,
|
||||
TextStream,
|
||||
AnthropicClient,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
const axios = require('axios');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const footer = `Use the context as your learned knowledge to better answer the user.
|
||||
|
||||
@@ -19,7 +18,7 @@ function createContextHandlers(req, userMessageContent) {
|
||||
const queryPromises = [];
|
||||
const processedFiles = [];
|
||||
const processedIds = new Set();
|
||||
const jwtToken = generateShortLivedToken(req.user.id);
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);
|
||||
|
||||
const query = async (file) => {
|
||||
@@ -97,35 +96,35 @@ function createContextHandlers(req, userMessageContent) {
|
||||
resolvedQueries.length === 0
|
||||
? '\n\tThe semantic search did not return any results.'
|
||||
: resolvedQueries
|
||||
.map((queryResult, index) => {
|
||||
const file = processedFiles[index];
|
||||
let contextItems = queryResult.data;
|
||||
.map((queryResult, index) => {
|
||||
const file = processedFiles[index];
|
||||
let contextItems = queryResult.data;
|
||||
|
||||
const generateContext = (currentContext) =>
|
||||
`
|
||||
const generateContext = (currentContext) =>
|
||||
`
|
||||
<file>
|
||||
<filename>${file.filename}</filename>
|
||||
<context>${currentContext}
|
||||
</context>
|
||||
</file>`;
|
||||
|
||||
if (useFullContext) {
|
||||
return generateContext(`\n${contextItems}`);
|
||||
}
|
||||
if (useFullContext) {
|
||||
return generateContext(`\n${contextItems}`);
|
||||
}
|
||||
|
||||
contextItems = queryResult.data
|
||||
.map((item) => {
|
||||
const pageContent = item[0].page_content;
|
||||
return `
|
||||
contextItems = queryResult.data
|
||||
.map((item) => {
|
||||
const pageContent = item[0].page_content;
|
||||
return `
|
||||
<contextItem>
|
||||
<![CDATA[${pageContent?.trim()}]]>
|
||||
</contextItem>`;
|
||||
})
|
||||
.join('');
|
||||
})
|
||||
.join('');
|
||||
|
||||
return generateContext(contextItems);
|
||||
})
|
||||
.join('');
|
||||
return generateContext(contextItems);
|
||||
})
|
||||
.join('');
|
||||
|
||||
if (useFullContext) {
|
||||
const prompt = `${header}
|
||||
|
||||
@@ -237,9 +237,41 @@ const formatAgentMessages = (payload) => {
|
||||
return messages;
|
||||
};
|
||||
|
||||
/**
|
||||
* Formats an array of messages for LangChain, making sure all content fields are strings
|
||||
* @param {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} payload - The array of messages to format.
|
||||
* @returns {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} - The array of formatted LangChain messages, including ToolMessages for tool calls.
|
||||
*/
|
||||
const formatContentStrings = (payload) => {
|
||||
const messages = [];
|
||||
|
||||
for (const message of payload) {
|
||||
if (typeof message.content === 'string') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!Array.isArray(message.content)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Reduce text types to a single string, ignore all other types
|
||||
const content = message.content.reduce((acc, curr) => {
|
||||
if (curr.type === ContentTypes.TEXT) {
|
||||
return `${acc}${curr[ContentTypes.TEXT]}\n`;
|
||||
}
|
||||
return acc;
|
||||
}, '');
|
||||
|
||||
message.content = content.trim();
|
||||
}
|
||||
|
||||
return messages;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
formatMessage,
|
||||
formatFromLangChain,
|
||||
formatAgentMessages,
|
||||
formatContentStrings,
|
||||
formatLangChainMessages,
|
||||
};
|
||||
|
||||
@@ -422,46 +422,6 @@ describe('BaseClient', () => {
|
||||
expect(response).toEqual(expectedResult);
|
||||
});
|
||||
|
||||
test('should replace responseMessageId with new UUID when isRegenerate is true and messageId ends with underscore', async () => {
|
||||
const mockCrypto = require('crypto');
|
||||
const newUUID = 'new-uuid-1234';
|
||||
jest.spyOn(mockCrypto, 'randomUUID').mockReturnValue(newUUID);
|
||||
|
||||
const opts = {
|
||||
isRegenerate: true,
|
||||
responseMessageId: 'existing-message-id_',
|
||||
};
|
||||
|
||||
await TestClient.setMessageOptions(opts);
|
||||
|
||||
expect(TestClient.responseMessageId).toBe(newUUID);
|
||||
expect(TestClient.responseMessageId).not.toBe('existing-message-id_');
|
||||
|
||||
mockCrypto.randomUUID.mockRestore();
|
||||
});
|
||||
|
||||
test('should not replace responseMessageId when isRegenerate is false', async () => {
|
||||
const opts = {
|
||||
isRegenerate: false,
|
||||
responseMessageId: 'existing-message-id_',
|
||||
};
|
||||
|
||||
await TestClient.setMessageOptions(opts);
|
||||
|
||||
expect(TestClient.responseMessageId).toBe('existing-message-id_');
|
||||
});
|
||||
|
||||
test('should not replace responseMessageId when it does not end with underscore', async () => {
|
||||
const opts = {
|
||||
isRegenerate: true,
|
||||
responseMessageId: 'existing-message-id',
|
||||
};
|
||||
|
||||
await TestClient.setMessageOptions(opts);
|
||||
|
||||
expect(TestClient.responseMessageId).toBe('existing-message-id');
|
||||
});
|
||||
|
||||
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
|
||||
const userMessage = 'Second message in the conversation';
|
||||
const opts = {
|
||||
|
||||
@@ -531,6 +531,44 @@ describe('OpenAIClient', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessage/getCompletion/chatCompletion', () => {
|
||||
afterEach(() => {
|
||||
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
|
||||
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
|
||||
});
|
||||
|
||||
it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => {
|
||||
const model = 'text-davinci-003';
|
||||
const onProgress = jest.fn().mockImplementation(() => ({}));
|
||||
|
||||
const testClient = new OpenAIClient('test-api-key', {
|
||||
...defaultOptions,
|
||||
modelOptions: { model },
|
||||
});
|
||||
|
||||
const getCompletion = jest.spyOn(testClient, 'getCompletion');
|
||||
await testClient.sendMessage('Hi mom!', { onProgress });
|
||||
|
||||
expect(getCompletion).toHaveBeenCalled();
|
||||
expect(getCompletion.mock.calls.length).toBe(1);
|
||||
|
||||
expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n');
|
||||
|
||||
expect(fetchEventSource).toHaveBeenCalled();
|
||||
expect(fetchEventSource.mock.calls.length).toBe(1);
|
||||
|
||||
// Check if the first argument (url) is correct
|
||||
const firstCallArgs = fetchEventSource.mock.calls[0];
|
||||
|
||||
const expectedURL = 'https://api.openai.com/v1/completions';
|
||||
expect(firstCallArgs[0]).toBe(expectedURL);
|
||||
|
||||
const requestBody = JSON.parse(firstCallArgs[1].body);
|
||||
expect(requestBody).toHaveProperty('model');
|
||||
expect(requestBody.model).toBe(model);
|
||||
});
|
||||
});
|
||||
|
||||
describe('checkVisionRequest functionality', () => {
|
||||
let client;
|
||||
const attachments = [{ type: 'image/png' }];
|
||||
|
||||
314
api/app/clients/specs/PluginsClient.test.js
Normal file
314
api/app/clients/specs/PluginsClient.test.js
Normal file
@@ -0,0 +1,314 @@
|
||||
const crypto = require('crypto');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
|
||||
const PluginsClient = require('../PluginsClient');
|
||||
|
||||
jest.mock('~/db/connect');
|
||||
jest.mock('~/models/Conversation', () => {
|
||||
return function () {
|
||||
return {
|
||||
save: jest.fn(),
|
||||
deleteConvos: jest.fn(),
|
||||
};
|
||||
};
|
||||
});
|
||||
|
||||
const defaultAzureOptions = {
|
||||
azureOpenAIApiInstanceName: 'your-instance-name',
|
||||
azureOpenAIApiDeploymentName: 'your-deployment-name',
|
||||
azureOpenAIApiVersion: '2020-07-01-preview',
|
||||
};
|
||||
|
||||
describe('PluginsClient', () => {
|
||||
let TestAgent;
|
||||
let options = {
|
||||
tools: [],
|
||||
modelOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
temperature: 0,
|
||||
max_tokens: 2,
|
||||
},
|
||||
agentOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
},
|
||||
};
|
||||
let parentMessageId;
|
||||
let conversationId;
|
||||
const fakeMessages = [];
|
||||
const userMessage = 'Hello, ChatGPT!';
|
||||
const apiKey = 'fake-api-key';
|
||||
|
||||
beforeEach(() => {
|
||||
TestAgent = new PluginsClient(apiKey, options);
|
||||
TestAgent.loadHistory = jest
|
||||
.fn()
|
||||
.mockImplementation((conversationId, parentMessageId = null) => {
|
||||
if (!conversationId) {
|
||||
TestAgent.currentMessages = [];
|
||||
return Promise.resolve([]);
|
||||
}
|
||||
|
||||
const orderedMessages = TestAgent.constructor.getMessagesForConversation({
|
||||
messages: fakeMessages,
|
||||
parentMessageId,
|
||||
});
|
||||
|
||||
const chatMessages = orderedMessages.map((msg) =>
|
||||
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
|
||||
? new HumanMessage(msg.text)
|
||||
: new AIMessage(msg.text),
|
||||
);
|
||||
|
||||
TestAgent.currentMessages = orderedMessages;
|
||||
return Promise.resolve(chatMessages);
|
||||
});
|
||||
TestAgent.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => {
|
||||
if (opts && typeof opts === 'object') {
|
||||
TestAgent.setOptions(opts);
|
||||
}
|
||||
const conversationId = opts.conversationId || crypto.randomUUID();
|
||||
const parentMessageId = opts.parentMessageId || Constants.NO_PARENT;
|
||||
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
|
||||
this.pastMessages = await TestAgent.loadHistory(
|
||||
conversationId,
|
||||
TestAgent.options?.parentMessageId,
|
||||
);
|
||||
|
||||
const userMessage = {
|
||||
text: message,
|
||||
sender: 'ChatGPT',
|
||||
isCreatedByUser: true,
|
||||
messageId: userMessageId,
|
||||
parentMessageId,
|
||||
conversationId,
|
||||
};
|
||||
|
||||
const response = {
|
||||
sender: 'ChatGPT',
|
||||
text: 'Hello, User!',
|
||||
isCreatedByUser: false,
|
||||
messageId: crypto.randomUUID(),
|
||||
parentMessageId: userMessage.messageId,
|
||||
conversationId,
|
||||
};
|
||||
|
||||
fakeMessages.push(userMessage);
|
||||
fakeMessages.push(response);
|
||||
return response;
|
||||
});
|
||||
});
|
||||
|
||||
test('initializes PluginsClient without crashing', () => {
|
||||
expect(TestAgent).toBeInstanceOf(PluginsClient);
|
||||
});
|
||||
|
||||
test('check setOptions function', () => {
|
||||
expect(TestAgent.agentIsGpt3).toBe(true);
|
||||
});
|
||||
|
||||
describe('sendMessage', () => {
|
||||
test('sendMessage should return a response message', async () => {
|
||||
const expectedResult = expect.objectContaining({
|
||||
sender: 'ChatGPT',
|
||||
text: expect.any(String),
|
||||
isCreatedByUser: false,
|
||||
messageId: expect.any(String),
|
||||
parentMessageId: expect.any(String),
|
||||
conversationId: expect.any(String),
|
||||
});
|
||||
|
||||
const response = await TestAgent.sendMessage(userMessage);
|
||||
parentMessageId = response.messageId;
|
||||
conversationId = response.conversationId;
|
||||
expect(response).toEqual(expectedResult);
|
||||
});
|
||||
|
||||
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
|
||||
const userMessage = 'Second message in the conversation';
|
||||
const opts = {
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
};
|
||||
|
||||
const expectedResult = expect.objectContaining({
|
||||
sender: 'ChatGPT',
|
||||
text: expect.any(String),
|
||||
isCreatedByUser: false,
|
||||
messageId: expect.any(String),
|
||||
parentMessageId: expect.any(String),
|
||||
conversationId: opts.conversationId,
|
||||
});
|
||||
|
||||
const response = await TestAgent.sendMessage(userMessage, opts);
|
||||
parentMessageId = response.messageId;
|
||||
expect(response.conversationId).toEqual(conversationId);
|
||||
expect(response).toEqual(expectedResult);
|
||||
});
|
||||
|
||||
test('should return chat history', async () => {
|
||||
const chatMessages = await TestAgent.loadHistory(conversationId, parentMessageId);
|
||||
expect(TestAgent.currentMessages).toHaveLength(4);
|
||||
expect(chatMessages[0].text).toEqual(userMessage);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getFunctionModelName', () => {
|
||||
let client;
|
||||
|
||||
beforeEach(() => {
|
||||
client = new PluginsClient('dummy_api_key');
|
||||
});
|
||||
|
||||
test('should return the input when it includes a dash followed by four digits', () => {
|
||||
expect(client.getFunctionModelName('-1234')).toBe('-1234');
|
||||
expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview');
|
||||
});
|
||||
|
||||
test('should return the input for all function-capable models (`0613` models and above)', () => {
|
||||
expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613');
|
||||
expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613');
|
||||
expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613');
|
||||
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613');
|
||||
expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
|
||||
expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview');
|
||||
expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106');
|
||||
});
|
||||
|
||||
test('should return the corresponding model if input is non-function capable (`0314` models)', () => {
|
||||
expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4');
|
||||
expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4');
|
||||
expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo');
|
||||
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo');
|
||||
});
|
||||
|
||||
test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => {
|
||||
expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo');
|
||||
});
|
||||
|
||||
test('should return "gpt-4" when the input includes "gpt-4"', () => {
|
||||
expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4');
|
||||
});
|
||||
|
||||
test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => {
|
||||
expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo');
|
||||
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Azure OpenAI tests specific to Plugins', () => {
|
||||
// TODO: add more tests for Azure OpenAI integration with Plugins
|
||||
// let client;
|
||||
// beforeEach(() => {
|
||||
// client = new PluginsClient('dummy_api_key');
|
||||
// });
|
||||
|
||||
test('should not call getFunctionModelName when azure options are set', () => {
|
||||
const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName');
|
||||
const model = 'gpt-4-turbo';
|
||||
|
||||
// note, without the azure change in PR #1766, `getFunctionModelName` is called twice
|
||||
const testClient = new PluginsClient('dummy_api_key', {
|
||||
agentOptions: {
|
||||
model,
|
||||
agent: 'functions',
|
||||
},
|
||||
azure: defaultAzureOptions,
|
||||
});
|
||||
|
||||
expect(spy).not.toHaveBeenCalled();
|
||||
expect(testClient.agentOptions.model).toBe(model);
|
||||
|
||||
spy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessage with filtered tools', () => {
|
||||
let TestAgent;
|
||||
const apiKey = 'fake-api-key';
|
||||
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
|
||||
|
||||
beforeEach(() => {
|
||||
TestAgent = new PluginsClient(apiKey, {
|
||||
tools: mockTools,
|
||||
modelOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
temperature: 0,
|
||||
max_tokens: 2,
|
||||
},
|
||||
agentOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
},
|
||||
});
|
||||
|
||||
TestAgent.options.req = {
|
||||
app: {
|
||||
locals: {},
|
||||
},
|
||||
};
|
||||
|
||||
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
|
||||
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
const tools = TestAgent.options.tools.filter((plugin) =>
|
||||
includedTools.includes(plugin.name),
|
||||
);
|
||||
TestAgent.options.tools = tools;
|
||||
} else {
|
||||
const tools = TestAgent.options.tools.filter(
|
||||
(plugin) => !filteredTools.includes(plugin.name),
|
||||
);
|
||||
TestAgent.options.tools = tools;
|
||||
}
|
||||
|
||||
return {
|
||||
text: 'Mocked response',
|
||||
tools: TestAgent.options.tools,
|
||||
};
|
||||
});
|
||||
});
|
||||
|
||||
test('should filter out tools when filteredTools is provided', async () => {
|
||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
expect.objectContaining({ name: 'tool4' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should only include specified tools when includedTools is provided', async () => {
|
||||
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
expect.objectContaining({ name: 'tool4' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should prioritize includedTools over filteredTools', async () => {
|
||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool1' }),
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should not modify tools when no filters are provided', async () => {
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(4);
|
||||
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -3,8 +3,8 @@ const path = require('path');
|
||||
const OpenAI = require('openai');
|
||||
const fetch = require('node-fetch');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { Tool } = require('@langchain/core/tools');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { FileContext, ContentTypes } = require('librechat-data-provider');
|
||||
const { getImageBasename } = require('~/server/services/Files/images');
|
||||
const extractBaseURL = require('~/utils/extractBaseURL');
|
||||
@@ -46,10 +46,7 @@ class DALLE3 extends Tool {
|
||||
}
|
||||
|
||||
if (process.env.PROXY) {
|
||||
const proxyAgent = new ProxyAgent(process.env.PROXY);
|
||||
config.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
config.httpAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
/** @type {OpenAI} */
|
||||
@@ -166,8 +163,7 @@ Error Message: ${error.message}`);
|
||||
if (this.isAgent) {
|
||||
let fetchOptions = {};
|
||||
if (process.env.PROXY) {
|
||||
const proxyAgent = new ProxyAgent(process.env.PROXY);
|
||||
fetchOptions.dispatcher = proxyAgent;
|
||||
fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
const imageResponse = await fetch(theImageUrl, fetchOptions);
|
||||
const arrayBuffer = await imageResponse.arrayBuffer();
|
||||
|
||||
@@ -3,10 +3,10 @@ const axios = require('axios');
|
||||
const { v4 } = require('uuid');
|
||||
const OpenAI = require('openai');
|
||||
const FormData = require('form-data');
|
||||
const { ProxyAgent } = require('undici');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logAxiosError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
@@ -107,12 +107,6 @@ const getImageEditPromptDescription = () => {
|
||||
return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
|
||||
};
|
||||
|
||||
function createAbortHandler() {
|
||||
return function () {
|
||||
logger.debug('[ImageGenOAI] Image generation aborted');
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates OpenAI Image tools (generation and editing)
|
||||
* @param {Object} fields - Configuration fields
|
||||
@@ -189,10 +183,7 @@ function createOpenAIImageTools(fields = {}) {
|
||||
}
|
||||
const clientConfig = { ...closureConfig };
|
||||
if (process.env.PROXY) {
|
||||
const proxyAgent = new ProxyAgent(process.env.PROXY);
|
||||
clientConfig.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
clientConfig.httpAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
/** @type {OpenAI} */
|
||||
@@ -210,18 +201,10 @@ function createOpenAIImageTools(fields = {}) {
|
||||
}
|
||||
|
||||
let resp;
|
||||
/** @type {AbortSignal} */
|
||||
let derivedSignal = null;
|
||||
/** @type {() => void} */
|
||||
let abortHandler = null;
|
||||
|
||||
try {
|
||||
if (runnableConfig?.signal) {
|
||||
derivedSignal = AbortSignal.any([runnableConfig.signal]);
|
||||
abortHandler = createAbortHandler();
|
||||
derivedSignal.addEventListener('abort', abortHandler, { once: true });
|
||||
}
|
||||
|
||||
const derivedSignal = runnableConfig?.signal
|
||||
? AbortSignal.any([runnableConfig.signal])
|
||||
: undefined;
|
||||
resp = await openai.images.generate(
|
||||
{
|
||||
model: 'gpt-image-1',
|
||||
@@ -245,10 +228,6 @@ function createOpenAIImageTools(fields = {}) {
|
||||
logAxiosError({ error, message });
|
||||
return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable:
|
||||
Error Message: ${error.message}`);
|
||||
} finally {
|
||||
if (abortHandler && derivedSignal) {
|
||||
derivedSignal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
}
|
||||
|
||||
if (!resp) {
|
||||
@@ -338,10 +317,7 @@ Error Message: ${error.message}`);
|
||||
|
||||
const clientConfig = { ...closureConfig };
|
||||
if (process.env.PROXY) {
|
||||
const proxyAgent = new ProxyAgent(process.env.PROXY);
|
||||
clientConfig.fetchOptions = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
clientConfig.httpAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
const formData = new FormData();
|
||||
@@ -433,17 +409,10 @@ Error Message: ${error.message}`);
|
||||
headers['Authorization'] = `Bearer ${apiKey}`;
|
||||
}
|
||||
|
||||
/** @type {AbortSignal} */
|
||||
let derivedSignal = null;
|
||||
/** @type {() => void} */
|
||||
let abortHandler = null;
|
||||
|
||||
try {
|
||||
if (runnableConfig?.signal) {
|
||||
derivedSignal = AbortSignal.any([runnableConfig.signal]);
|
||||
abortHandler = createAbortHandler();
|
||||
derivedSignal.addEventListener('abort', abortHandler, { once: true });
|
||||
}
|
||||
const derivedSignal = runnableConfig?.signal
|
||||
? AbortSignal.any([runnableConfig.signal])
|
||||
: undefined;
|
||||
|
||||
/** @type {import('axios').AxiosRequestConfig} */
|
||||
const axiosConfig = {
|
||||
@@ -453,19 +422,6 @@ Error Message: ${error.message}`);
|
||||
baseURL,
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
try {
|
||||
const url = new URL(process.env.PROXY);
|
||||
axiosConfig.proxy = {
|
||||
host: url.hostname.replace(/^\[|\]$/g, ''),
|
||||
port: url.port ? parseInt(url.port, 10) : undefined,
|
||||
protocol: url.protocol.replace(':', ''),
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error parsing proxy URL:', error);
|
||||
}
|
||||
}
|
||||
|
||||
if (process.env.IMAGE_GEN_OAI_AZURE_API_VERSION && process.env.IMAGE_GEN_OAI_BASEURL) {
|
||||
axiosConfig.params = {
|
||||
'api-version': process.env.IMAGE_GEN_OAI_AZURE_API_VERSION,
|
||||
@@ -511,10 +467,6 @@ Error Message: ${error.message}`);
|
||||
logAxiosError({ error, message });
|
||||
return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable:
|
||||
Error Message: ${error.message || 'Unknown error'}`);
|
||||
} finally {
|
||||
if (abortHandler && derivedSignal) {
|
||||
derivedSignal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
const DALLE3 = require('../DALLE3');
|
||||
const { ProxyAgent } = require('undici');
|
||||
|
||||
const processFileURL = jest.fn();
|
||||
|
||||
jest.mock('~/server/services/Files/images', () => ({
|
||||
getImageBasename: jest.fn().mockImplementation((url) => {
|
||||
const parts = url.split('/');
|
||||
const lastPart = parts.pop();
|
||||
const imageExtensionRegex = /\.(jpg|jpeg|png|gif|bmp|tiff|svg)$/i;
|
||||
if (imageExtensionRegex.test(lastPart)) {
|
||||
return lastPart;
|
||||
}
|
||||
return '';
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('fs', () => {
|
||||
return {
|
||||
existsSync: jest.fn(),
|
||||
mkdirSync: jest.fn(),
|
||||
promises: {
|
||||
writeFile: jest.fn(),
|
||||
readFile: jest.fn(),
|
||||
unlink: jest.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('path', () => {
|
||||
return {
|
||||
resolve: jest.fn(),
|
||||
join: jest.fn(),
|
||||
relative: jest.fn(),
|
||||
extname: jest.fn().mockImplementation((filename) => {
|
||||
return filename.slice(filename.lastIndexOf('.'));
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
describe('DALLE3 Proxy Configuration', () => {
|
||||
let originalEnv;
|
||||
|
||||
beforeAll(() => {
|
||||
originalEnv = { ...process.env };
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetModules();
|
||||
process.env = { ...originalEnv };
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should configure ProxyAgent in fetchOptions.dispatcher when PROXY env is set', () => {
|
||||
// Set proxy environment variable
|
||||
process.env.PROXY = 'http://proxy.example.com:8080';
|
||||
process.env.DALLE_API_KEY = 'test-api-key';
|
||||
|
||||
// Create instance
|
||||
const dalleWithProxy = new DALLE3({ processFileURL });
|
||||
|
||||
// Check that the openai client exists
|
||||
expect(dalleWithProxy.openai).toBeDefined();
|
||||
|
||||
// Check that _options exists and has fetchOptions with a dispatcher
|
||||
expect(dalleWithProxy.openai._options).toBeDefined();
|
||||
expect(dalleWithProxy.openai._options.fetchOptions).toBeDefined();
|
||||
expect(dalleWithProxy.openai._options.fetchOptions.dispatcher).toBeDefined();
|
||||
expect(dalleWithProxy.openai._options.fetchOptions.dispatcher).toBeInstanceOf(ProxyAgent);
|
||||
});
|
||||
|
||||
it('should not configure ProxyAgent when PROXY env is not set', () => {
|
||||
// Ensure PROXY is not set
|
||||
delete process.env.PROXY;
|
||||
process.env.DALLE_API_KEY = 'test-api-key';
|
||||
|
||||
// Create instance
|
||||
const dalleWithoutProxy = new DALLE3({ processFileURL });
|
||||
|
||||
// Check that the openai client exists
|
||||
expect(dalleWithoutProxy.openai).toBeDefined();
|
||||
|
||||
// Check that _options exists but fetchOptions either doesn't exist or doesn't have a dispatcher
|
||||
expect(dalleWithoutProxy.openai._options).toBeDefined();
|
||||
|
||||
// fetchOptions should either not exist or not have a dispatcher
|
||||
if (dalleWithoutProxy.openai._options.fetchOptions) {
|
||||
expect(dalleWithoutProxy.openai._options.fetchOptions.dispatcher).toBeUndefined();
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1,35 +1,26 @@
|
||||
const { z } = require('zod');
|
||||
const axios = require('axios');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Tools, EToolResources } = require('librechat-data-provider');
|
||||
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {Object} options
|
||||
* @param {ServerRequest} options.req
|
||||
* @param {Agent['tool_resources']} options.tool_resources
|
||||
* @param {string} [options.agentId] - The agent ID for file access control
|
||||
* @returns {Promise<{
|
||||
* files: Array<{ file_id: string; filename: string }>,
|
||||
* toolContext: string
|
||||
* }>}
|
||||
*/
|
||||
const primeFiles = async (options) => {
|
||||
const { tool_resources, req, agentId } = options;
|
||||
const { tool_resources } = options;
|
||||
const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
|
||||
const agentResourceIds = new Set(file_ids);
|
||||
const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
|
||||
const dbFiles = (
|
||||
(await getFiles(
|
||||
{ file_id: { $in: file_ids } },
|
||||
null,
|
||||
{ text: 0 },
|
||||
{ userId: req?.user?.id, agentId },
|
||||
)) ?? []
|
||||
).concat(resourceFiles);
|
||||
const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
|
||||
|
||||
let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;
|
||||
|
||||
@@ -68,7 +59,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
if (files.length === 0) {
|
||||
return 'No files to search. Instruct the user to add files for the search.';
|
||||
}
|
||||
const jwtToken = generateShortLivedToken(req.user.id);
|
||||
const jwtToken = req.headers.authorization.split(' ')[1];
|
||||
if (!jwtToken) {
|
||||
return 'There was an error authenticating the file search request.';
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
const { mcpToolPattern } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SerpAPI } = require('@langchain/community/tools/serpapi');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
|
||||
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
||||
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
|
||||
const {
|
||||
Tools,
|
||||
EToolResources,
|
||||
loadWebSearchAuth,
|
||||
replaceSpecialVars,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
availableTools,
|
||||
manifestToolMap,
|
||||
@@ -230,7 +235,7 @@ const loadTools = async ({
|
||||
|
||||
/** @type {Record<string, string>} */
|
||||
const toolContextMap = {};
|
||||
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
|
||||
const appTools = (await getCachedTools({ includeGlobal: true })) ?? {};
|
||||
|
||||
for (const tool of tools) {
|
||||
if (tool === Tools.execute_code) {
|
||||
@@ -240,13 +245,7 @@ const loadTools = async ({
|
||||
authFields: [EnvVar.CODE_API_KEY],
|
||||
});
|
||||
const codeApiKey = authValues[EnvVar.CODE_API_KEY];
|
||||
const { files, toolContext } = await primeCodeFiles(
|
||||
{
|
||||
...options,
|
||||
agentId: agent?.id,
|
||||
},
|
||||
codeApiKey,
|
||||
);
|
||||
const { files, toolContext } = await primeCodeFiles(options, codeApiKey);
|
||||
if (toolContext) {
|
||||
toolContextMap[tool] = toolContext;
|
||||
}
|
||||
@@ -261,10 +260,7 @@ const loadTools = async ({
|
||||
continue;
|
||||
} else if (tool === Tools.file_search) {
|
||||
requestedTools[tool] = async () => {
|
||||
const { files, toolContext } = await primeSearchFiles({
|
||||
...options,
|
||||
agentId: agent?.id,
|
||||
});
|
||||
const { files, toolContext } = await primeSearchFiles(options);
|
||||
if (toolContext) {
|
||||
toolContextMap[tool] = toolContext;
|
||||
}
|
||||
@@ -298,7 +294,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
});
|
||||
};
|
||||
continue;
|
||||
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
|
||||
} else if (tool && appTools[tool] && mcpToolPattern.test(tool)) {
|
||||
requestedTools[tool] = async () =>
|
||||
createMCPTool({
|
||||
req: options.req,
|
||||
|
||||
3
api/cache/banViolation.js
vendored
3
api/cache/banViolation.js
vendored
@@ -1,8 +1,7 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, math } = require('@librechat/api');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, math, removePorts } = require('~/server/utils');
|
||||
const { deleteAllUserSessions } = require('~/models');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const getLogStores = require('./getLogStores');
|
||||
|
||||
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
|
||||
|
||||
62
api/cache/cacheConfig.js
vendored
62
api/cache/cacheConfig.js
vendored
@@ -1,62 +0,0 @@
|
||||
const fs = require('fs');
|
||||
const { math, isEnabled } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
|
||||
// To ensure that different deployments do not interfere with each other's cache, we use a prefix for the Redis keys.
|
||||
// This prefix is usually the deployment ID, which is often passed to the container or pod as an env var.
|
||||
// Set REDIS_KEY_PREFIX_VAR to the env var that contains the deployment ID.
|
||||
const REDIS_KEY_PREFIX_VAR = process.env.REDIS_KEY_PREFIX_VAR;
|
||||
const REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX;
|
||||
if (REDIS_KEY_PREFIX_VAR && REDIS_KEY_PREFIX) {
|
||||
throw new Error('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
|
||||
}
|
||||
|
||||
const USE_REDIS = isEnabled(process.env.USE_REDIS);
|
||||
if (USE_REDIS && !process.env.REDIS_URI) {
|
||||
throw new Error('USE_REDIS is enabled but REDIS_URI is not set.');
|
||||
}
|
||||
|
||||
// Comma-separated list of cache namespaces that should be forced to use in-memory storage
|
||||
// even when Redis is enabled. This allows selective performance optimization for specific caches.
|
||||
const FORCED_IN_MEMORY_CACHE_NAMESPACES = process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES
|
||||
? process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES.split(',').map((key) => key.trim())
|
||||
: [];
|
||||
|
||||
// Validate against CacheKeys enum
|
||||
if (FORCED_IN_MEMORY_CACHE_NAMESPACES.length > 0) {
|
||||
const validKeys = Object.values(CacheKeys);
|
||||
const invalidKeys = FORCED_IN_MEMORY_CACHE_NAMESPACES.filter((key) => !validKeys.includes(key));
|
||||
|
||||
if (invalidKeys.length > 0) {
|
||||
throw new Error(
|
||||
`Invalid cache keys in FORCED_IN_MEMORY_CACHE_NAMESPACES: ${invalidKeys.join(', ')}. Valid keys: ${validKeys.join(', ')}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const cacheConfig = {
|
||||
FORCED_IN_MEMORY_CACHE_NAMESPACES,
|
||||
USE_REDIS,
|
||||
REDIS_URI: process.env.REDIS_URI,
|
||||
REDIS_USERNAME: process.env.REDIS_USERNAME,
|
||||
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
|
||||
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
|
||||
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
|
||||
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
|
||||
REDIS_PING_INTERVAL: math(process.env.REDIS_PING_INTERVAL, 0),
|
||||
/** Max delay between reconnection attempts in ms */
|
||||
REDIS_RETRY_MAX_DELAY: math(process.env.REDIS_RETRY_MAX_DELAY, 3000),
|
||||
/** Max number of reconnection attempts (0 = infinite) */
|
||||
REDIS_RETRY_MAX_ATTEMPTS: math(process.env.REDIS_RETRY_MAX_ATTEMPTS, 10),
|
||||
/** Connection timeout in ms */
|
||||
REDIS_CONNECT_TIMEOUT: math(process.env.REDIS_CONNECT_TIMEOUT, 10000),
|
||||
/** Queue commands when disconnected */
|
||||
REDIS_ENABLE_OFFLINE_QUEUE: isEnabled(process.env.REDIS_ENABLE_OFFLINE_QUEUE ?? 'true'),
|
||||
|
||||
CI: isEnabled(process.env.CI),
|
||||
DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
|
||||
|
||||
BAN_DURATION: math(process.env.BAN_DURATION, 7200000), // 2 hours
|
||||
};
|
||||
|
||||
module.exports = { cacheConfig };
|
||||
157
api/cache/cacheConfig.spec.js
vendored
157
api/cache/cacheConfig.spec.js
vendored
@@ -1,157 +0,0 @@
|
||||
const fs = require('fs');
|
||||
|
||||
describe('cacheConfig', () => {
|
||||
let originalEnv;
|
||||
let originalReadFileSync;
|
||||
|
||||
beforeEach(() => {
|
||||
originalEnv = { ...process.env };
|
||||
originalReadFileSync = fs.readFileSync;
|
||||
|
||||
// Clear all related env vars first
|
||||
delete process.env.REDIS_URI;
|
||||
delete process.env.REDIS_CA;
|
||||
delete process.env.REDIS_KEY_PREFIX_VAR;
|
||||
delete process.env.REDIS_KEY_PREFIX;
|
||||
delete process.env.USE_REDIS;
|
||||
delete process.env.REDIS_PING_INTERVAL;
|
||||
delete process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES;
|
||||
|
||||
// Clear require cache
|
||||
jest.resetModules();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
fs.readFileSync = originalReadFileSync;
|
||||
jest.resetModules();
|
||||
});
|
||||
|
||||
describe('REDIS_KEY_PREFIX validation and resolution', () => {
|
||||
test('should throw error when both REDIS_KEY_PREFIX_VAR and REDIS_KEY_PREFIX are set', () => {
|
||||
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
|
||||
process.env.REDIS_KEY_PREFIX = 'manual-prefix';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).toThrow('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
|
||||
});
|
||||
|
||||
test('should resolve REDIS_KEY_PREFIX from variable reference', () => {
|
||||
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
|
||||
process.env.DEPLOYMENT_ID = 'test-deployment-123';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('test-deployment-123');
|
||||
});
|
||||
|
||||
test('should use direct REDIS_KEY_PREFIX value', () => {
|
||||
process.env.REDIS_KEY_PREFIX = 'direct-prefix';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('direct-prefix');
|
||||
});
|
||||
|
||||
test('should default to empty string when no prefix is configured', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
|
||||
});
|
||||
|
||||
test('should handle empty variable reference', () => {
|
||||
process.env.REDIS_KEY_PREFIX_VAR = 'EMPTY_VAR';
|
||||
process.env.EMPTY_VAR = '';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
|
||||
});
|
||||
|
||||
test('should handle undefined variable reference', () => {
|
||||
process.env.REDIS_KEY_PREFIX_VAR = 'UNDEFINED_VAR';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('USE_REDIS and REDIS_URI validation', () => {
|
||||
test('should throw error when USE_REDIS is enabled but REDIS_URI is not set', () => {
|
||||
process.env.USE_REDIS = 'true';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
|
||||
});
|
||||
|
||||
test('should not throw error when USE_REDIS is enabled and REDIS_URI is set', () => {
|
||||
process.env.USE_REDIS = 'true';
|
||||
process.env.REDIS_URI = 'redis://localhost:6379';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
test('should handle empty REDIS_URI when USE_REDIS is enabled', () => {
|
||||
process.env.USE_REDIS = 'true';
|
||||
process.env.REDIS_URI = '';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('REDIS_CA file reading', () => {
|
||||
test('should be null when REDIS_CA is not set', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_CA).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('REDIS_PING_INTERVAL configuration', () => {
|
||||
test('should default to 0 when REDIS_PING_INTERVAL is not set', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_PING_INTERVAL).toBe(0);
|
||||
});
|
||||
|
||||
test('should use provided REDIS_PING_INTERVAL value', () => {
|
||||
process.env.REDIS_PING_INTERVAL = '300';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.REDIS_PING_INTERVAL).toBe(300);
|
||||
});
|
||||
});
|
||||
|
||||
describe('FORCED_IN_MEMORY_CACHE_NAMESPACES validation', () => {
|
||||
test('should parse comma-separated cache keys correctly', () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = ' ROLES, STATIC_CONFIG ,MESSAGES ';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([
|
||||
'ROLES',
|
||||
'STATIC_CONFIG',
|
||||
'MESSAGES',
|
||||
]);
|
||||
});
|
||||
|
||||
test('should throw error for invalid cache keys', () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = 'INVALID_KEY,ROLES';
|
||||
|
||||
expect(() => {
|
||||
require('./cacheConfig');
|
||||
}).toThrow('Invalid cache keys in FORCED_IN_MEMORY_CACHE_NAMESPACES: INVALID_KEY');
|
||||
});
|
||||
|
||||
test('should handle empty string gracefully', () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = '';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([]);
|
||||
});
|
||||
|
||||
test('should handle undefined env var gracefully', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([]);
|
||||
});
|
||||
});
|
||||
});
|
||||
108
api/cache/cacheFactory.js
vendored
108
api/cache/cacheFactory.js
vendored
@@ -1,108 +0,0 @@
|
||||
const KeyvRedis = require('@keyv/redis').default;
|
||||
const { Keyv } = require('keyv');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { Time } = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { RedisStore: ConnectRedis } = require('connect-redis');
|
||||
const MemoryStore = require('memorystore')(require('express-session'));
|
||||
const { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } = require('./redisClients');
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
const { violationFile } = require('./keyvFiles');
|
||||
|
||||
/**
|
||||
* Creates a cache instance using Redis or a fallback store. Suitable for general caching needs.
|
||||
* @param {string} namespace - The cache namespace.
|
||||
* @param {number} [ttl] - Time to live for cache entries.
|
||||
* @param {object} [fallbackStore] - Optional fallback store if Redis is not used.
|
||||
* @returns {Keyv} Cache instance.
|
||||
*/
|
||||
const standardCache = (namespace, ttl = undefined, fallbackStore = undefined) => {
|
||||
if (
|
||||
cacheConfig.USE_REDIS &&
|
||||
!cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES?.includes(namespace)
|
||||
) {
|
||||
try {
|
||||
const keyvRedis = new KeyvRedis(keyvRedisClient);
|
||||
const cache = new Keyv(keyvRedis, { namespace, ttl });
|
||||
keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
|
||||
keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
|
||||
|
||||
cache.on('error', (err) => {
|
||||
logger.error(`Cache error in namespace ${namespace}:`, err);
|
||||
});
|
||||
|
||||
return cache;
|
||||
} catch (err) {
|
||||
logger.error(`Failed to create Redis cache for namespace ${namespace}:`, err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
if (fallbackStore) return new Keyv({ store: fallbackStore, namespace, ttl });
|
||||
return new Keyv({ namespace, ttl });
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a cache instance for storing violation data.
|
||||
* Uses a file-based fallback store if Redis is not enabled.
|
||||
* @param {string} namespace - The cache namespace for violations.
|
||||
* @param {number} [ttl] - Time to live for cache entries.
|
||||
* @returns {Keyv} Cache instance for violations.
|
||||
*/
|
||||
const violationCache = (namespace, ttl = undefined) => {
|
||||
return standardCache(`violations:${namespace}`, ttl, violationFile);
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a session cache instance using Redis or in-memory store.
|
||||
* @param {string} namespace - The session namespace.
|
||||
* @param {number} [ttl] - Time to live for session entries.
|
||||
* @returns {MemoryStore | ConnectRedis} Session store instance.
|
||||
*/
|
||||
const sessionCache = (namespace, ttl = undefined) => {
|
||||
namespace = namespace.endsWith(':') ? namespace : `${namespace}:`;
|
||||
if (!cacheConfig.USE_REDIS) return new MemoryStore({ ttl, checkPeriod: Time.ONE_DAY });
|
||||
const store = new ConnectRedis({ client: ioredisClient, ttl, prefix: namespace });
|
||||
if (ioredisClient) {
|
||||
ioredisClient.on('error', (err) => {
|
||||
logger.error(`Session store Redis error for namespace ${namespace}:`, err);
|
||||
});
|
||||
}
|
||||
return store;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a rate limiter cache using Redis.
|
||||
* @param {string} prefix - The key prefix for rate limiting.
|
||||
* @returns {RedisStore|undefined} RedisStore instance or undefined if Redis is not used.
|
||||
*/
|
||||
const limiterCache = (prefix) => {
|
||||
if (!prefix) throw new Error('prefix is required');
|
||||
if (!cacheConfig.USE_REDIS) return undefined;
|
||||
prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;
|
||||
|
||||
try {
|
||||
if (!ioredisClient) {
|
||||
logger.warn(`Redis client not available for rate limiter with prefix ${prefix}`);
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return new RedisStore({ sendCommand, prefix });
|
||||
} catch (err) {
|
||||
logger.error(`Failed to create Redis rate limiter for prefix ${prefix}:`, err);
|
||||
return undefined;
|
||||
}
|
||||
};
|
||||
|
||||
const sendCommand = (...args) => {
|
||||
if (!ioredisClient) {
|
||||
logger.warn('Redis client not available for command execution');
|
||||
return Promise.reject(new Error('Redis client not available'));
|
||||
}
|
||||
|
||||
return ioredisClient.call(...args).catch((err) => {
|
||||
logger.error('Redis command execution failed:', err);
|
||||
throw err;
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = { standardCache, sessionCache, violationCache, limiterCache };
|
||||
432
api/cache/cacheFactory.spec.js
vendored
432
api/cache/cacheFactory.spec.js
vendored
@@ -1,432 +0,0 @@
|
||||
const { Time } = require('librechat-data-provider');
|
||||
|
||||
// Mock dependencies first
|
||||
const mockKeyvRedis = {
|
||||
namespace: '',
|
||||
keyPrefixSeparator: '',
|
||||
};
|
||||
|
||||
const mockKeyv = jest.fn().mockReturnValue({
|
||||
mock: 'keyv',
|
||||
on: jest.fn(),
|
||||
});
|
||||
const mockConnectRedis = jest.fn().mockReturnValue({ mock: 'connectRedis' });
|
||||
const mockMemoryStore = jest.fn().mockReturnValue({ mock: 'memoryStore' });
|
||||
const mockRedisStore = jest.fn().mockReturnValue({ mock: 'redisStore' });
|
||||
|
||||
const mockIoredisClient = {
|
||||
call: jest.fn(),
|
||||
on: jest.fn(),
|
||||
};
|
||||
|
||||
const mockKeyvRedisClient = {};
|
||||
const mockViolationFile = {};
|
||||
|
||||
// Mock modules before requiring the main module
|
||||
jest.mock('@keyv/redis', () => ({
|
||||
default: jest.fn().mockImplementation(() => mockKeyvRedis),
|
||||
}));
|
||||
|
||||
jest.mock('keyv', () => ({
|
||||
Keyv: mockKeyv,
|
||||
}));
|
||||
|
||||
jest.mock('./cacheConfig', () => ({
|
||||
cacheConfig: {
|
||||
USE_REDIS: false,
|
||||
REDIS_KEY_PREFIX: 'test',
|
||||
FORCED_IN_MEMORY_CACHE_NAMESPACES: [],
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('./redisClients', () => ({
|
||||
keyvRedisClient: mockKeyvRedisClient,
|
||||
ioredisClient: mockIoredisClient,
|
||||
GLOBAL_PREFIX_SEPARATOR: '::',
|
||||
}));
|
||||
|
||||
jest.mock('./keyvFiles', () => ({
|
||||
violationFile: mockViolationFile,
|
||||
}));
|
||||
|
||||
jest.mock('connect-redis', () => ({ RedisStore: mockConnectRedis }));
|
||||
|
||||
jest.mock('memorystore', () => jest.fn(() => mockMemoryStore));
|
||||
|
||||
jest.mock('rate-limit-redis', () => ({
|
||||
RedisStore: mockRedisStore,
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
info: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Import after mocking
|
||||
const { standardCache, sessionCache, violationCache, limiterCache } = require('./cacheFactory');
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
|
||||
describe('cacheFactory', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Reset cache config mock
|
||||
cacheConfig.USE_REDIS = false;
|
||||
cacheConfig.REDIS_KEY_PREFIX = 'test';
|
||||
cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES = [];
|
||||
});
|
||||
|
||||
describe('redisCache', () => {
|
||||
it('should create Redis cache when USE_REDIS is true', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
|
||||
standardCache(namespace, ttl);
|
||||
|
||||
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
|
||||
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
|
||||
expect(mockKeyvRedis.namespace).toBe(cacheConfig.REDIS_KEY_PREFIX);
|
||||
expect(mockKeyvRedis.keyPrefixSeparator).toBe('::');
|
||||
});
|
||||
|
||||
it('should create Redis cache with undefined ttl when not provided', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'test-namespace';
|
||||
|
||||
standardCache(namespace);
|
||||
|
||||
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl: undefined });
|
||||
});
|
||||
|
||||
it('should use fallback store when USE_REDIS is false and fallbackStore is provided', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
const fallbackStore = { some: 'store' };
|
||||
|
||||
standardCache(namespace, ttl, fallbackStore);
|
||||
|
||||
expect(mockKeyv).toHaveBeenCalledWith({ store: fallbackStore, namespace, ttl });
|
||||
});
|
||||
|
||||
it('should create default Keyv instance when USE_REDIS is false and no fallbackStore', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
|
||||
standardCache(namespace, ttl);
|
||||
|
||||
expect(mockKeyv).toHaveBeenCalledWith({ namespace, ttl });
|
||||
});
|
||||
|
||||
it('should handle namespace and ttl as undefined', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
|
||||
standardCache();
|
||||
|
||||
expect(mockKeyv).toHaveBeenCalledWith({ namespace: undefined, ttl: undefined });
|
||||
});
|
||||
|
||||
it('should use fallback when namespace is in FORCED_IN_MEMORY_CACHE_NAMESPACES', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES = ['forced-memory'];
|
||||
const namespace = 'forced-memory';
|
||||
const ttl = 3600;
|
||||
|
||||
standardCache(namespace, ttl);
|
||||
|
||||
expect(require('@keyv/redis').default).not.toHaveBeenCalled();
|
||||
expect(mockKeyv).toHaveBeenCalledWith({ namespace, ttl });
|
||||
});
|
||||
|
||||
it('should use Redis when namespace is not in FORCED_IN_MEMORY_CACHE_NAMESPACES', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES = ['other-namespace'];
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
|
||||
standardCache(namespace, ttl);
|
||||
|
||||
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
|
||||
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
|
||||
});
|
||||
|
||||
it('should throw error when Redis cache creation fails', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'test-namespace';
|
||||
const ttl = 3600;
|
||||
const testError = new Error('Redis connection failed');
|
||||
|
||||
const KeyvRedis = require('@keyv/redis').default;
|
||||
KeyvRedis.mockImplementationOnce(() => {
|
||||
throw testError;
|
||||
});
|
||||
|
||||
expect(() => standardCache(namespace, ttl)).toThrow('Redis connection failed');
|
||||
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
`Failed to create Redis cache for namespace ${namespace}:`,
|
||||
testError,
|
||||
);
|
||||
|
||||
expect(mockKeyv).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('violationCache', () => {
|
||||
it('should create violation cache with prefixed namespace', () => {
|
||||
const namespace = 'test-violations';
|
||||
const ttl = 7200;
|
||||
|
||||
// We can't easily mock the internal redisCache call since it's in the same module
|
||||
// But we can test that the function executes without throwing
|
||||
expect(() => violationCache(namespace, ttl)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should create violation cache with undefined ttl', () => {
|
||||
const namespace = 'test-violations';
|
||||
|
||||
violationCache(namespace);
|
||||
|
||||
// The function should call redisCache with violations: prefixed namespace
|
||||
// Since we can't easily mock the internal redisCache call, we test the behavior
|
||||
expect(() => violationCache(namespace)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should handle undefined namespace', () => {
|
||||
expect(() => violationCache(undefined)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sessionCache', () => {
|
||||
it('should return MemoryStore when USE_REDIS is false', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const namespace = 'sessions';
|
||||
const ttl = 86400;
|
||||
|
||||
const result = sessionCache(namespace, ttl);
|
||||
|
||||
expect(mockMemoryStore).toHaveBeenCalledWith({ ttl, checkPeriod: Time.ONE_DAY });
|
||||
expect(result).toBe(mockMemoryStore());
|
||||
});
|
||||
|
||||
it('should return ConnectRedis when USE_REDIS is true', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
const ttl = 86400;
|
||||
|
||||
const result = sessionCache(namespace, ttl);
|
||||
|
||||
expect(mockConnectRedis).toHaveBeenCalledWith({
|
||||
client: mockIoredisClient,
|
||||
ttl,
|
||||
prefix: `${namespace}:`,
|
||||
});
|
||||
expect(result).toBe(mockConnectRedis());
|
||||
});
|
||||
|
||||
it('should add colon to namespace if not present', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
|
||||
sessionCache(namespace);
|
||||
|
||||
expect(mockConnectRedis).toHaveBeenCalledWith({
|
||||
client: mockIoredisClient,
|
||||
ttl: undefined,
|
||||
prefix: 'sessions:',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not add colon to namespace if already present', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions:';
|
||||
|
||||
sessionCache(namespace);
|
||||
|
||||
expect(mockConnectRedis).toHaveBeenCalledWith({
|
||||
client: mockIoredisClient,
|
||||
ttl: undefined,
|
||||
prefix: 'sessions:',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle undefined ttl', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const namespace = 'sessions';
|
||||
|
||||
sessionCache(namespace);
|
||||
|
||||
expect(mockMemoryStore).toHaveBeenCalledWith({
|
||||
ttl: undefined,
|
||||
checkPeriod: Time.ONE_DAY,
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw error when ConnectRedis constructor fails', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
const ttl = 86400;
|
||||
|
||||
// Mock ConnectRedis to throw an error during construction
|
||||
const redisError = new Error('Redis connection failed');
|
||||
mockConnectRedis.mockImplementationOnce(() => {
|
||||
throw redisError;
|
||||
});
|
||||
|
||||
// The error should propagate up, not be caught
|
||||
expect(() => sessionCache(namespace, ttl)).toThrow('Redis connection failed');
|
||||
|
||||
// Verify that MemoryStore was NOT used as fallback
|
||||
expect(mockMemoryStore).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should register error handler but let errors propagate to Express', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
|
||||
// Create a mock session store with middleware methods
|
||||
const mockSessionStore = {
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
};
|
||||
mockConnectRedis.mockReturnValue(mockSessionStore);
|
||||
|
||||
const store = sessionCache(namespace);
|
||||
|
||||
// Verify error handler was registered
|
||||
expect(mockIoredisClient.on).toHaveBeenCalledWith('error', expect.any(Function));
|
||||
|
||||
// Get the error handler
|
||||
const errorHandler = mockIoredisClient.on.mock.calls.find((call) => call[0] === 'error')[1];
|
||||
|
||||
// Simulate an error from Redis during a session operation
|
||||
const redisError = new Error('Socket closed unexpectedly');
|
||||
|
||||
// The error handler should log but not swallow the error
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
errorHandler(redisError);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
`Session store Redis error for namespace ${namespace}::`,
|
||||
redisError,
|
||||
);
|
||||
|
||||
// Now simulate what happens when session middleware tries to use the store
|
||||
const callback = jest.fn();
|
||||
mockSessionStore.get.mockImplementation((sid, cb) => {
|
||||
cb(new Error('Redis connection lost'));
|
||||
});
|
||||
|
||||
// Call the store's get method (as Express session would)
|
||||
store.get('test-session-id', callback);
|
||||
|
||||
// The error should be passed to the callback, not swallowed
|
||||
expect(callback).toHaveBeenCalledWith(new Error('Redis connection lost'));
|
||||
});
|
||||
|
||||
it('should handle null ioredisClient gracefully', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const namespace = 'sessions';
|
||||
|
||||
// Temporarily set ioredisClient to null (simulating connection not established)
|
||||
const originalClient = require('./redisClients').ioredisClient;
|
||||
require('./redisClients').ioredisClient = null;
|
||||
|
||||
// ConnectRedis might accept null client but would fail on first use
|
||||
// The important thing is it doesn't throw uncaught exceptions during construction
|
||||
const store = sessionCache(namespace);
|
||||
expect(store).toBeDefined();
|
||||
|
||||
// Restore original client
|
||||
require('./redisClients').ioredisClient = originalClient;
|
||||
});
|
||||
});
|
||||
|
||||
describe('limiterCache', () => {
|
||||
it('should return undefined when USE_REDIS is false', () => {
|
||||
cacheConfig.USE_REDIS = false;
|
||||
const result = limiterCache('prefix');
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return RedisStore when USE_REDIS is true', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
const result = limiterCache('rate-limit');
|
||||
|
||||
expect(mockRedisStore).toHaveBeenCalledWith({
|
||||
sendCommand: expect.any(Function),
|
||||
prefix: `rate-limit:`,
|
||||
});
|
||||
expect(result).toBe(mockRedisStore());
|
||||
});
|
||||
|
||||
it('should add colon to prefix if not present', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
limiterCache('rate-limit');
|
||||
|
||||
expect(mockRedisStore).toHaveBeenCalledWith({
|
||||
sendCommand: expect.any(Function),
|
||||
prefix: 'rate-limit:',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not add colon to prefix if already present', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
limiterCache('rate-limit:');
|
||||
|
||||
expect(mockRedisStore).toHaveBeenCalledWith({
|
||||
sendCommand: expect.any(Function),
|
||||
prefix: 'rate-limit:',
|
||||
});
|
||||
});
|
||||
|
||||
it('should pass sendCommand function that calls ioredisClient.call', async () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
mockIoredisClient.call.mockResolvedValue('test-value');
|
||||
|
||||
limiterCache('rate-limit');
|
||||
|
||||
const sendCommandCall = mockRedisStore.mock.calls[0][0];
|
||||
const sendCommand = sendCommandCall.sendCommand;
|
||||
|
||||
// Test that sendCommand properly delegates to ioredisClient.call
|
||||
const args = ['GET', 'test-key'];
|
||||
const result = await sendCommand(...args);
|
||||
|
||||
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
|
||||
expect(result).toBe('test-value');
|
||||
});
|
||||
|
||||
it('should handle sendCommand errors properly', async () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
|
||||
// Mock the call method to reject with an error
|
||||
const testError = new Error('Redis error');
|
||||
mockIoredisClient.call.mockRejectedValue(testError);
|
||||
|
||||
limiterCache('rate-limit');
|
||||
|
||||
const sendCommandCall = mockRedisStore.mock.calls[0][0];
|
||||
const sendCommand = sendCommandCall.sendCommand;
|
||||
|
||||
// Test that sendCommand properly handles errors
|
||||
const args = ['GET', 'test-key'];
|
||||
|
||||
await expect(sendCommand(...args)).rejects.toThrow('Redis error');
|
||||
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
|
||||
});
|
||||
|
||||
it('should handle undefined prefix', () => {
|
||||
cacheConfig.USE_REDIS = true;
|
||||
expect(() => limiterCache()).toThrow('prefix is required');
|
||||
});
|
||||
});
|
||||
});
|
||||
165
api/cache/getLogStores.js
vendored
165
api/cache/getLogStores.js
vendored
@@ -1,53 +1,113 @@
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
const { Keyv } = require('keyv');
|
||||
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
|
||||
const { logFile } = require('./keyvFiles');
|
||||
const { logFile, violationFile } = require('./keyvFiles');
|
||||
const { isEnabled, math } = require('~/server/utils');
|
||||
const keyvRedis = require('./keyvRedis');
|
||||
const keyvMongo = require('./keyvMongo');
|
||||
const { standardCache, sessionCache, violationCache } = require('./cacheFactory');
|
||||
|
||||
const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE, CI } = process.env ?? {};
|
||||
|
||||
const duration = math(BAN_DURATION, 7200000);
|
||||
const isRedisEnabled = isEnabled(USE_REDIS);
|
||||
const debugMemoryCache = isEnabled(DEBUG_MEMORY_CACHE);
|
||||
|
||||
const createViolationInstance = (namespace) => {
|
||||
const config = isRedisEnabled ? { store: keyvRedis } : { store: violationFile, namespace };
|
||||
return new Keyv(config);
|
||||
};
|
||||
|
||||
// Serve cache from memory so no need to clear it on startup/exit
|
||||
const pending_req = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.PENDING_REQ });
|
||||
|
||||
const config = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
|
||||
|
||||
const roles = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ROLES });
|
||||
|
||||
const mcpTools = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.MCP_TOOLS });
|
||||
|
||||
const audioRuns = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const messages = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.ONE_MINUTE })
|
||||
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.ONE_MINUTE });
|
||||
|
||||
const flows = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.FLOWS, ttl: Time.ONE_MINUTE * 3 });
|
||||
|
||||
const tokenConfig = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
|
||||
|
||||
const genTitle = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
|
||||
|
||||
const s3ExpiryInterval = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });
|
||||
|
||||
const modelQueries = isEnabled(process.env.USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
|
||||
|
||||
const abortKeys = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const openIdExchangedTokensCache = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.OPENID_EXCHANGED_TOKENS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const namespaces = {
|
||||
[ViolationTypes.GENERAL]: new Keyv({ store: logFile, namespace: 'violations' }),
|
||||
[ViolationTypes.LOGINS]: violationCache(ViolationTypes.LOGINS),
|
||||
[ViolationTypes.CONCURRENT]: violationCache(ViolationTypes.CONCURRENT),
|
||||
[ViolationTypes.NON_BROWSER]: violationCache(ViolationTypes.NON_BROWSER),
|
||||
[ViolationTypes.MESSAGE_LIMIT]: violationCache(ViolationTypes.MESSAGE_LIMIT),
|
||||
[ViolationTypes.REGISTRATIONS]: violationCache(ViolationTypes.REGISTRATIONS),
|
||||
[ViolationTypes.TOKEN_BALANCE]: violationCache(ViolationTypes.TOKEN_BALANCE),
|
||||
[ViolationTypes.TTS_LIMIT]: violationCache(ViolationTypes.TTS_LIMIT),
|
||||
[ViolationTypes.STT_LIMIT]: violationCache(ViolationTypes.STT_LIMIT),
|
||||
[ViolationTypes.CONVO_ACCESS]: violationCache(ViolationTypes.CONVO_ACCESS),
|
||||
[ViolationTypes.TOOL_CALL_LIMIT]: violationCache(ViolationTypes.TOOL_CALL_LIMIT),
|
||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: violationCache(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||
[ViolationTypes.VERIFY_EMAIL_LIMIT]: violationCache(ViolationTypes.VERIFY_EMAIL_LIMIT),
|
||||
[ViolationTypes.RESET_PASSWORD_LIMIT]: violationCache(ViolationTypes.RESET_PASSWORD_LIMIT),
|
||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: violationCache(ViolationTypes.ILLEGAL_MODEL_REQUEST),
|
||||
[ViolationTypes.BAN]: new Keyv({
|
||||
[CacheKeys.ROLES]: roles,
|
||||
[CacheKeys.MCP_TOOLS]: mcpTools,
|
||||
[CacheKeys.CONFIG_STORE]: config,
|
||||
[CacheKeys.PENDING_REQ]: pending_req,
|
||||
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
|
||||
[CacheKeys.ENCODED_DOMAINS]: new Keyv({
|
||||
store: keyvMongo,
|
||||
namespace: CacheKeys.BANS,
|
||||
ttl: cacheConfig.BAN_DURATION,
|
||||
namespace: CacheKeys.ENCODED_DOMAINS,
|
||||
ttl: 0,
|
||||
}),
|
||||
|
||||
[CacheKeys.OPENID_SESSION]: sessionCache(CacheKeys.OPENID_SESSION),
|
||||
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
|
||||
|
||||
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
|
||||
[CacheKeys.MCP_TOOLS]: standardCache(CacheKeys.MCP_TOOLS),
|
||||
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
|
||||
[CacheKeys.STATIC_CONFIG]: standardCache(CacheKeys.STATIC_CONFIG),
|
||||
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
|
||||
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
|
||||
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),
|
||||
[CacheKeys.TOKEN_CONFIG]: standardCache(CacheKeys.TOKEN_CONFIG, Time.THIRTY_MINUTES),
|
||||
[CacheKeys.GEN_TITLE]: standardCache(CacheKeys.GEN_TITLE, Time.TWO_MINUTES),
|
||||
[CacheKeys.S3_EXPIRY_INTERVAL]: standardCache(CacheKeys.S3_EXPIRY_INTERVAL, Time.THIRTY_MINUTES),
|
||||
[CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES),
|
||||
[CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES),
|
||||
[CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE),
|
||||
[CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 3),
|
||||
[CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache(
|
||||
CacheKeys.OPENID_EXCHANGED_TOKENS,
|
||||
Time.TEN_MINUTES,
|
||||
general: new Keyv({ store: logFile, namespace: 'violations' }),
|
||||
concurrent: createViolationInstance('concurrent'),
|
||||
non_browser: createViolationInstance('non_browser'),
|
||||
message_limit: createViolationInstance('message_limit'),
|
||||
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
|
||||
registrations: createViolationInstance('registrations'),
|
||||
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
|
||||
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
|
||||
[ViolationTypes.CONVO_ACCESS]: createViolationInstance(ViolationTypes.CONVO_ACCESS),
|
||||
[ViolationTypes.TOOL_CALL_LIMIT]: createViolationInstance(ViolationTypes.TOOL_CALL_LIMIT),
|
||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||
[ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
|
||||
[ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
|
||||
ViolationTypes.RESET_PASSWORD_LIMIT,
|
||||
),
|
||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
||||
ViolationTypes.ILLEGAL_MODEL_REQUEST,
|
||||
),
|
||||
logins: createViolationInstance('logins'),
|
||||
[CacheKeys.ABORT_KEYS]: abortKeys,
|
||||
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
|
||||
[CacheKeys.GEN_TITLE]: genTitle,
|
||||
[CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
|
||||
[CacheKeys.MODEL_QUERIES]: modelQueries,
|
||||
[CacheKeys.AUDIO_RUNS]: audioRuns,
|
||||
[CacheKeys.MESSAGES]: messages,
|
||||
[CacheKeys.FLOWS]: flows,
|
||||
[CacheKeys.OPENID_EXCHANGED_TOKENS]: openIdExchangedTokensCache,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -56,10 +116,7 @@ const namespaces = {
|
||||
*/
|
||||
function getTTLStores() {
|
||||
return Object.values(namespaces).filter(
|
||||
(store) =>
|
||||
store instanceof Keyv &&
|
||||
parseInt(store.opts?.ttl ?? '0') > 0 &&
|
||||
!store.opts?.store?.constructor?.name?.includes('Redis'), // Only include non-Redis stores
|
||||
(store) => store instanceof Keyv && typeof store.opts?.ttl === 'number' && store.opts.ttl > 0,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -95,18 +152,18 @@ async function clearExpiredFromCache(cache) {
|
||||
if (data?.expires && data.expires <= expiryTime) {
|
||||
const deleted = await cache.opts.store.delete(key);
|
||||
if (!deleted) {
|
||||
cacheConfig.DEBUG_MEMORY_CACHE &&
|
||||
debugMemoryCache &&
|
||||
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
|
||||
continue;
|
||||
}
|
||||
cleared++;
|
||||
}
|
||||
} catch (error) {
|
||||
cacheConfig.DEBUG_MEMORY_CACHE &&
|
||||
debugMemoryCache &&
|
||||
console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
|
||||
const deleted = await cache.opts.store.delete(key);
|
||||
if (!deleted) {
|
||||
cacheConfig.DEBUG_MEMORY_CACHE &&
|
||||
debugMemoryCache &&
|
||||
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
|
||||
continue;
|
||||
}
|
||||
@@ -115,7 +172,7 @@ async function clearExpiredFromCache(cache) {
|
||||
}
|
||||
|
||||
if (cleared > 0) {
|
||||
cacheConfig.DEBUG_MEMORY_CACHE &&
|
||||
debugMemoryCache &&
|
||||
console.log(
|
||||
`[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`,
|
||||
);
|
||||
@@ -156,7 +213,7 @@ async function clearAllExpiredFromCache() {
|
||||
}
|
||||
}
|
||||
|
||||
if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
|
||||
if (!isRedisEnabled && !isEnabled(CI)) {
|
||||
/** @type {Set<NodeJS.Timeout>} */
|
||||
const cleanupIntervals = new Set();
|
||||
|
||||
@@ -167,7 +224,7 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
|
||||
|
||||
cleanupIntervals.add(cleanup);
|
||||
|
||||
if (cacheConfig.DEBUG_MEMORY_CACHE) {
|
||||
if (debugMemoryCache) {
|
||||
const monitor = setInterval(() => {
|
||||
const ttlStores = getTTLStores();
|
||||
const memory = process.memoryUsage();
|
||||
@@ -188,13 +245,13 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
|
||||
}
|
||||
|
||||
const dispose = () => {
|
||||
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Cleaning up and shutting down...');
|
||||
debugMemoryCache && console.log('[Cache] Cleaning up and shutting down...');
|
||||
cleanupIntervals.forEach((interval) => clearInterval(interval));
|
||||
cleanupIntervals.clear();
|
||||
|
||||
// One final cleanup before exit
|
||||
clearAllExpiredFromCache().then(() => {
|
||||
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Final cleanup completed');
|
||||
debugMemoryCache && console.log('[Cache] Final cleanup completed');
|
||||
process.exit(0);
|
||||
});
|
||||
};
|
||||
|
||||
92
api/cache/ioredisClient.js
vendored
Normal file
92
api/cache/ioredisClient.js
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
const fs = require('fs');
|
||||
const Redis = require('ioredis');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_MAX_LISTENERS } = process.env;
|
||||
|
||||
/** @type {import('ioredis').Redis | import('ioredis').Cluster} */
|
||||
let ioredisClient;
|
||||
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
|
||||
|
||||
function mapURI(uri) {
|
||||
const regex =
|
||||
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
|
||||
const match = uri.match(regex);
|
||||
|
||||
if (match) {
|
||||
const { scheme, user, password, host, port } = match.groups;
|
||||
|
||||
return {
|
||||
scheme: scheme || 'none',
|
||||
user: user || null,
|
||||
password: password || null,
|
||||
host: host || null,
|
||||
port: port || null,
|
||||
};
|
||||
} else {
|
||||
const parts = uri.split(':');
|
||||
if (parts.length === 2) {
|
||||
return {
|
||||
scheme: 'none',
|
||||
user: null,
|
||||
password: null,
|
||||
host: parts[0],
|
||||
port: parts[1],
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
scheme: 'none',
|
||||
user: null,
|
||||
password: null,
|
||||
host: uri,
|
||||
port: null,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (REDIS_URI && isEnabled(USE_REDIS)) {
|
||||
let redisOptions = null;
|
||||
|
||||
if (REDIS_CA) {
|
||||
const ca = fs.readFileSync(REDIS_CA);
|
||||
redisOptions = { tls: { ca } };
|
||||
}
|
||||
|
||||
if (isEnabled(USE_REDIS_CLUSTER)) {
|
||||
const hosts = REDIS_URI.split(',').map((item) => {
|
||||
var value = mapURI(item);
|
||||
|
||||
return {
|
||||
host: value.host,
|
||||
port: value.port,
|
||||
};
|
||||
});
|
||||
ioredisClient = new Redis.Cluster(hosts, { redisOptions });
|
||||
} else {
|
||||
ioredisClient = new Redis(REDIS_URI, redisOptions);
|
||||
}
|
||||
|
||||
ioredisClient.on('ready', () => {
|
||||
logger.info('IoRedis connection ready');
|
||||
});
|
||||
ioredisClient.on('reconnecting', () => {
|
||||
logger.info('IoRedis connection reconnecting');
|
||||
});
|
||||
ioredisClient.on('end', () => {
|
||||
logger.info('IoRedis connection ended');
|
||||
});
|
||||
ioredisClient.on('close', () => {
|
||||
logger.info('IoRedis connection closed');
|
||||
});
|
||||
ioredisClient.on('error', (err) => logger.error('IoRedis connection error:', err));
|
||||
ioredisClient.setMaxListeners(redis_max_listeners);
|
||||
logger.info(
|
||||
'[Optional] IoRedis initialized for rate limiters. If you have issues, disable Redis or restart the server.',
|
||||
);
|
||||
} else {
|
||||
logger.info('[Optional] IoRedis not initialized for rate limiters.');
|
||||
}
|
||||
|
||||
module.exports = ioredisClient;
|
||||
109
api/cache/keyvRedis.js
vendored
Normal file
109
api/cache/keyvRedis.js
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
const fs = require('fs');
|
||||
const ioredis = require('ioredis');
|
||||
const KeyvRedis = require('@keyv/redis').default;
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } =
|
||||
process.env;
|
||||
|
||||
let keyvRedis;
|
||||
const redis_prefix = REDIS_KEY_PREFIX || '';
|
||||
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
|
||||
|
||||
function mapURI(uri) {
|
||||
const regex =
|
||||
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
|
||||
const match = uri.match(regex);
|
||||
|
||||
if (match) {
|
||||
const { scheme, user, password, host, port } = match.groups;
|
||||
|
||||
return {
|
||||
scheme: scheme || 'none',
|
||||
user: user || null,
|
||||
password: password || null,
|
||||
host: host || null,
|
||||
port: port || null,
|
||||
};
|
||||
} else {
|
||||
const parts = uri.split(':');
|
||||
if (parts.length === 2) {
|
||||
return {
|
||||
scheme: 'none',
|
||||
user: null,
|
||||
password: null,
|
||||
host: parts[0],
|
||||
port: parts[1],
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
scheme: 'none',
|
||||
user: null,
|
||||
password: null,
|
||||
host: uri,
|
||||
port: null,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (REDIS_URI && isEnabled(USE_REDIS)) {
|
||||
let redisOptions = null;
|
||||
/** @type {import('@keyv/redis').KeyvRedisOptions} */
|
||||
let keyvOpts = {
|
||||
useRedisSets: false,
|
||||
keyPrefix: redis_prefix,
|
||||
};
|
||||
|
||||
if (REDIS_CA) {
|
||||
const ca = fs.readFileSync(REDIS_CA);
|
||||
redisOptions = { tls: { ca } };
|
||||
}
|
||||
|
||||
if (isEnabled(USE_REDIS_CLUSTER)) {
|
||||
const hosts = REDIS_URI.split(',').map((item) => {
|
||||
var value = mapURI(item);
|
||||
|
||||
return {
|
||||
host: value.host,
|
||||
port: value.port,
|
||||
};
|
||||
});
|
||||
const cluster = new ioredis.Cluster(hosts, { redisOptions });
|
||||
keyvRedis = new KeyvRedis(cluster, keyvOpts);
|
||||
} else {
|
||||
keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts);
|
||||
}
|
||||
|
||||
const pingInterval = setInterval(
|
||||
() => {
|
||||
logger.debug('KeyvRedis ping');
|
||||
keyvRedis.client.ping().catch((err) => logger.error('Redis keep-alive ping failed:', err));
|
||||
},
|
||||
5 * 60 * 1000,
|
||||
);
|
||||
|
||||
keyvRedis.on('ready', () => {
|
||||
logger.info('KeyvRedis connection ready');
|
||||
});
|
||||
keyvRedis.on('reconnecting', () => {
|
||||
logger.info('KeyvRedis connection reconnecting');
|
||||
});
|
||||
keyvRedis.on('end', () => {
|
||||
logger.info('KeyvRedis connection ended');
|
||||
});
|
||||
keyvRedis.on('close', () => {
|
||||
clearInterval(pingInterval);
|
||||
logger.info('KeyvRedis connection closed');
|
||||
});
|
||||
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
|
||||
keyvRedis.setMaxListeners(redis_max_listeners);
|
||||
logger.info(
|
||||
'[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.',
|
||||
);
|
||||
} else {
|
||||
logger.info('[Optional] Redis not initialized.');
|
||||
}
|
||||
|
||||
module.exports = keyvRedis;
|
||||
5
api/cache/logViolation.js
vendored
5
api/cache/logViolation.js
vendored
@@ -1,5 +1,4 @@
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const getLogStores = require('./getLogStores');
|
||||
const banViolation = require('./banViolation');
|
||||
|
||||
@@ -10,14 +9,14 @@ const banViolation = require('./banViolation');
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {string} type - The type of violation.
|
||||
* @param {Object} errorMessage - The error message to log.
|
||||
* @param {number | string} [score=1] - The severity of the violation. Defaults to 1
|
||||
* @param {number} [score=1] - The severity of the violation. Defaults to 1
|
||||
*/
|
||||
const logViolation = async (req, res, type, errorMessage, score = 1) => {
|
||||
const userId = req.user?.id ?? req.user?._id;
|
||||
if (!userId) {
|
||||
return;
|
||||
}
|
||||
const logs = getLogStores(ViolationTypes.GENERAL);
|
||||
const logs = getLogStores('general');
|
||||
const violationLogs = getLogStores(type);
|
||||
const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;
|
||||
|
||||
|
||||
204
api/cache/redisClients.js
vendored
204
api/cache/redisClients.js
vendored
@@ -1,204 +0,0 @@
|
||||
const IoRedis = require('ioredis');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createClient, createCluster } = require('@keyv/redis');
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
|
||||
const GLOBAL_PREFIX_SEPARATOR = '::';
|
||||
|
||||
const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri));
|
||||
const username = urls?.[0].username || cacheConfig.REDIS_USERNAME;
|
||||
const password = urls?.[0].password || cacheConfig.REDIS_PASSWORD;
|
||||
const ca = cacheConfig.REDIS_CA;
|
||||
|
||||
/** @type {import('ioredis').Redis | import('ioredis').Cluster | null} */
|
||||
let ioredisClient = null;
|
||||
if (cacheConfig.USE_REDIS) {
|
||||
/** @type {import('ioredis').RedisOptions | import('ioredis').ClusterOptions} */
|
||||
const redisOptions = {
|
||||
username: username,
|
||||
password: password,
|
||||
tls: ca ? { ca } : undefined,
|
||||
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
|
||||
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
|
||||
retryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 50, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
reconnectOnError: (err) => {
|
||||
const targetError = 'READONLY';
|
||||
if (err.message.includes(targetError)) {
|
||||
logger.warn('ioredis reconnecting due to READONLY error');
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
connectTimeout: cacheConfig.REDIS_CONNECT_TIMEOUT,
|
||||
maxRetriesPerRequest: 3,
|
||||
};
|
||||
|
||||
ioredisClient =
|
||||
urls.length === 1
|
||||
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
|
||||
: new IoRedis.Cluster(cacheConfig.REDIS_URI, {
|
||||
redisOptions,
|
||||
clusterRetryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis cluster giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis cluster reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
});
|
||||
|
||||
ioredisClient.on('error', (err) => {
|
||||
logger.error('ioredis client error:', err);
|
||||
});
|
||||
|
||||
ioredisClient.on('connect', () => {
|
||||
logger.info('ioredis client connected');
|
||||
});
|
||||
|
||||
ioredisClient.on('ready', () => {
|
||||
logger.info('ioredis client ready');
|
||||
});
|
||||
|
||||
ioredisClient.on('reconnecting', (delay) => {
|
||||
logger.info(`ioredis client reconnecting in ${delay}ms`);
|
||||
});
|
||||
|
||||
ioredisClient.on('close', () => {
|
||||
logger.warn('ioredis client connection closed');
|
||||
});
|
||||
|
||||
/** Ping Interval to keep the Redis server connection alive (if enabled) */
|
||||
let pingInterval = null;
|
||||
const clearPingInterval = () => {
|
||||
if (pingInterval) {
|
||||
clearInterval(pingInterval);
|
||||
pingInterval = null;
|
||||
}
|
||||
};
|
||||
|
||||
if (cacheConfig.REDIS_PING_INTERVAL > 0) {
|
||||
pingInterval = setInterval(() => {
|
||||
if (ioredisClient && ioredisClient.status === 'ready') {
|
||||
ioredisClient.ping().catch((err) => {
|
||||
logger.error('ioredis ping failed:', err);
|
||||
});
|
||||
}
|
||||
}, cacheConfig.REDIS_PING_INTERVAL * 1000);
|
||||
ioredisClient.on('close', clearPingInterval);
|
||||
ioredisClient.on('end', clearPingInterval);
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {import('@keyv/redis').RedisClient | import('@keyv/redis').RedisCluster | null} */
|
||||
let keyvRedisClient = null;
|
||||
if (cacheConfig.USE_REDIS) {
|
||||
/**
|
||||
* ** WARNING ** Keyv Redis client does not support Prefix like ioredis above.
|
||||
* The prefix feature will be handled by the Keyv-Redis store in cacheFactory.js
|
||||
* @type {import('@keyv/redis').RedisClientOptions | import('@keyv/redis').RedisClusterOptions}
|
||||
*/
|
||||
const redisOptions = {
|
||||
username,
|
||||
password,
|
||||
socket: {
|
||||
tls: ca != null,
|
||||
ca,
|
||||
connectTimeout: cacheConfig.REDIS_CONNECT_TIMEOUT,
|
||||
reconnectStrategy: (retries) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
retries > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`@keyv/redis client giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return new Error('Max reconnection attempts reached');
|
||||
}
|
||||
const delay = Math.min(retries * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`@keyv/redis reconnecting... attempt ${retries}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
},
|
||||
disableOfflineQueue: !cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
};
|
||||
|
||||
keyvRedisClient =
|
||||
urls.length === 1
|
||||
? createClient({ url: cacheConfig.REDIS_URI, ...redisOptions })
|
||||
: createCluster({
|
||||
rootNodes: cacheConfig.REDIS_URI.split(',').map((url) => ({ url })),
|
||||
defaults: redisOptions,
|
||||
});
|
||||
|
||||
keyvRedisClient.setMaxListeners(cacheConfig.REDIS_MAX_LISTENERS);
|
||||
|
||||
keyvRedisClient.on('error', (err) => {
|
||||
logger.error('@keyv/redis client error:', err);
|
||||
});
|
||||
|
||||
keyvRedisClient.on('connect', () => {
|
||||
logger.info('@keyv/redis client connected');
|
||||
});
|
||||
|
||||
keyvRedisClient.on('ready', () => {
|
||||
logger.info('@keyv/redis client ready');
|
||||
});
|
||||
|
||||
keyvRedisClient.on('reconnecting', () => {
|
||||
logger.info('@keyv/redis client reconnecting...');
|
||||
});
|
||||
|
||||
keyvRedisClient.on('disconnect', () => {
|
||||
logger.warn('@keyv/redis client disconnected');
|
||||
});
|
||||
|
||||
keyvRedisClient.connect().catch((err) => {
|
||||
logger.error('@keyv/redis initial connection failed:', err);
|
||||
throw err;
|
||||
});
|
||||
|
||||
/** Ping Interval to keep the Redis server connection alive (if enabled) */
|
||||
let pingInterval = null;
|
||||
const clearPingInterval = () => {
|
||||
if (pingInterval) {
|
||||
clearInterval(pingInterval);
|
||||
pingInterval = null;
|
||||
}
|
||||
};
|
||||
|
||||
if (cacheConfig.REDIS_PING_INTERVAL > 0) {
|
||||
pingInterval = setInterval(() => {
|
||||
if (keyvRedisClient && keyvRedisClient.isReady) {
|
||||
keyvRedisClient.ping().catch((err) => {
|
||||
logger.error('@keyv/redis ping failed:', err);
|
||||
});
|
||||
}
|
||||
}, cacheConfig.REDIS_PING_INTERVAL * 1000);
|
||||
keyvRedisClient.on('disconnect', clearPingInterval);
|
||||
keyvRedisClient.on('end', clearPingInterval);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR };
|
||||
@@ -61,7 +61,7 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter)
|
||||
const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => {
|
||||
const { model, ...model_parameters } = _m;
|
||||
/** @type {Record<string, FunctionTool>} */
|
||||
const availableTools = await getCachedTools({ userId: req.user.id, includeGlobal: true });
|
||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||
/** @type {TEphemeralAgent | null} */
|
||||
const ephemeralAgent = req.body.ephemeralAgent;
|
||||
const mcpServers = new Set(ephemeralAgent?.mcp);
|
||||
@@ -70,9 +70,6 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
if (ephemeralAgent?.execute_code === true) {
|
||||
tools.push(Tools.execute_code);
|
||||
}
|
||||
if (ephemeralAgent?.file_search === true) {
|
||||
tools.push(Tools.file_search);
|
||||
}
|
||||
if (ephemeralAgent?.web_search === true) {
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
@@ -90,7 +87,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
}
|
||||
|
||||
const instructions = req.body.promptPrefix;
|
||||
const result = {
|
||||
return {
|
||||
id: agent_id,
|
||||
instructions,
|
||||
provider: endpoint,
|
||||
@@ -98,11 +95,6 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
model,
|
||||
tools,
|
||||
};
|
||||
|
||||
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
|
||||
result.artifacts = ephemeralAgent.artifacts;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -316,10 +308,17 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
|
||||
if (shouldCreateVersion) {
|
||||
const duplicateVersion = isDuplicateVersion(updateData, versionData, versions, actionsHash);
|
||||
if (duplicateVersion && !forceVersion) {
|
||||
// No changes detected, return the current agent without creating a new version
|
||||
const agentObj = currentAgent.toObject();
|
||||
agentObj.version = versions.length;
|
||||
return agentObj;
|
||||
const error = new Error(
|
||||
'Duplicate version: This would create a version identical to an existing one',
|
||||
);
|
||||
error.statusCode = 409;
|
||||
error.details = {
|
||||
duplicateVersion,
|
||||
versionIndex: versions.findIndex(
|
||||
(v) => JSON.stringify(duplicateVersion) === JSON.stringify(v),
|
||||
),
|
||||
};
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
}, 20000);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -413,7 +413,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
}, 20000);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -670,7 +670,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
}, 20000);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -879,31 +879,45 @@ describe('models/Agent', () => {
|
||||
expect(emptyParamsAgent.model_parameters).toEqual({});
|
||||
});
|
||||
|
||||
test('should not create new version for duplicate updates', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const testCases = generateVersionTestCases();
|
||||
test('should detect duplicate versions and reject updates', async () => {
|
||||
const originalConsoleError = console.error;
|
||||
console.error = jest.fn();
|
||||
|
||||
for (const testCase of testCases) {
|
||||
const testAgentId = `agent_${uuidv4()}`;
|
||||
try {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const testCases = generateVersionTestCases();
|
||||
|
||||
await createAgent({
|
||||
id: testAgentId,
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
...testCase.initial,
|
||||
});
|
||||
for (const testCase of testCases) {
|
||||
const testAgentId = `agent_${uuidv4()}`;
|
||||
|
||||
const updatedAgent = await updateAgent({ id: testAgentId }, testCase.update);
|
||||
expect(updatedAgent.versions).toHaveLength(2); // No new version created
|
||||
await createAgent({
|
||||
id: testAgentId,
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
...testCase.initial,
|
||||
});
|
||||
|
||||
// Update with duplicate data should succeed but not create a new version
|
||||
const duplicateUpdate = await updateAgent({ id: testAgentId }, testCase.duplicate);
|
||||
await updateAgent({ id: testAgentId }, testCase.update);
|
||||
|
||||
expect(duplicateUpdate.versions).toHaveLength(2); // No new version created
|
||||
let error;
|
||||
try {
|
||||
await updateAgent({ id: testAgentId }, testCase.duplicate);
|
||||
} catch (e) {
|
||||
error = e;
|
||||
}
|
||||
|
||||
const agent = await getAgent({ id: testAgentId });
|
||||
expect(agent.versions).toHaveLength(2);
|
||||
expect(error).toBeDefined();
|
||||
expect(error.message).toContain('Duplicate version');
|
||||
expect(error.statusCode).toBe(409);
|
||||
expect(error.details).toBeDefined();
|
||||
expect(error.details.duplicateVersion).toBeDefined();
|
||||
|
||||
const agent = await getAgent({ id: testAgentId });
|
||||
expect(agent.versions).toHaveLength(2);
|
||||
}
|
||||
} finally {
|
||||
console.error = originalConsoleError;
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1079,13 +1093,20 @@ describe('models/Agent', () => {
|
||||
expect(secondUpdate.versions).toHaveLength(3);
|
||||
|
||||
// Update without forceVersion and no changes should not create a version
|
||||
const duplicateUpdate = await updateAgent(
|
||||
{ id: agentId },
|
||||
{ tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] },
|
||||
{ updatingUserId: authorId.toString(), forceVersion: false },
|
||||
);
|
||||
let error;
|
||||
try {
|
||||
await updateAgent(
|
||||
{ id: agentId },
|
||||
{ tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] },
|
||||
{ updatingUserId: authorId.toString(), forceVersion: false },
|
||||
);
|
||||
} catch (e) {
|
||||
error = e;
|
||||
}
|
||||
|
||||
expect(duplicateUpdate.versions).toHaveLength(3); // No new version created
|
||||
expect(error).toBeDefined();
|
||||
expect(error.message).toContain('Duplicate version');
|
||||
expect(error.statusCode).toBe(409);
|
||||
});
|
||||
|
||||
test('should handle isDuplicateVersion with arrays containing null/undefined values', async () => {
|
||||
@@ -1311,7 +1332,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
}, 20000);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -1493,7 +1514,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
}, 20000);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -1777,7 +1798,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
}, 20000);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -2329,7 +2350,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
}, 20000);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -2379,18 +2400,11 @@ describe('models/Agent', () => {
|
||||
agent_ids: ['agent1', 'agent2'],
|
||||
});
|
||||
|
||||
const updatedAgent = await updateAgent(
|
||||
{ id: agentId },
|
||||
{ agent_ids: ['agent1', 'agent2', 'agent3'] },
|
||||
);
|
||||
expect(updatedAgent.versions).toHaveLength(2);
|
||||
await updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 'agent3'] });
|
||||
|
||||
// Update with same agent_ids should succeed but not create a new version
|
||||
const duplicateUpdate = await updateAgent(
|
||||
{ id: agentId },
|
||||
{ agent_ids: ['agent1', 'agent2', 'agent3'] },
|
||||
);
|
||||
expect(duplicateUpdate.versions).toHaveLength(2); // No new version created
|
||||
await expect(
|
||||
updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 'agent3'] }),
|
||||
).rejects.toThrow('Duplicate version');
|
||||
});
|
||||
|
||||
test('should handle agent_ids field alongside other fields', async () => {
|
||||
@@ -2529,10 +2543,9 @@ describe('models/Agent', () => {
|
||||
expect(updated.versions).toHaveLength(2);
|
||||
expect(updated.agent_ids).toEqual([]);
|
||||
|
||||
// Update with same empty agent_ids should succeed but not create a new version
|
||||
const duplicateUpdate = await updateAgent({ id: agentId }, { agent_ids: [] });
|
||||
expect(duplicateUpdate.versions).toHaveLength(2); // No new version created
|
||||
expect(duplicateUpdate.agent_ids).toEqual([]);
|
||||
await expect(updateAgent({ id: agentId }, { agent_ids: [] })).rejects.toThrow(
|
||||
'Duplicate version',
|
||||
);
|
||||
});
|
||||
|
||||
test('should handle agent without agent_ids field', async () => {
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
const { Conversation } = require('~/db/models');
|
||||
|
||||
@@ -100,15 +98,10 @@ module.exports = {
|
||||
update.conversationId = newConversationId;
|
||||
}
|
||||
|
||||
if (req?.body?.isTemporary) {
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
update.expiredAt = createTempChatExpirationDate(customConfig);
|
||||
} catch (err) {
|
||||
logger.error('Error creating temporary chat expiration date:', err);
|
||||
logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
|
||||
update.expiredAt = null;
|
||||
}
|
||||
if (req.body.isTemporary) {
|
||||
const expiredAt = new Date();
|
||||
expiredAt.setDate(expiredAt.getDate() + 30);
|
||||
update.expiredAt = expiredAt;
|
||||
} else {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
@@ -1,572 +0,0 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const {
|
||||
deleteNullOrEmptyConversations,
|
||||
searchConversation,
|
||||
getConvosByCursor,
|
||||
getConvosQueried,
|
||||
getConvoFiles,
|
||||
getConvoTitle,
|
||||
deleteConvos,
|
||||
saveConvo,
|
||||
getConvo,
|
||||
} = require('./Conversation');
|
||||
jest.mock('~/server/services/Config/getCustomConfig');
|
||||
jest.mock('./Message');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getMessages, deleteMessages } = require('./Message');
|
||||
|
||||
const { Conversation } = require('~/db/models');
|
||||
|
||||
describe('Conversation Operations', () => {
|
||||
let mongoServer;
|
||||
let mockReq;
|
||||
let mockConversationData;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clear database
|
||||
await Conversation.deleteMany({});
|
||||
|
||||
// Reset mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Default mock implementations
|
||||
getMessages.mockResolvedValue([]);
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 0 });
|
||||
|
||||
mockReq = {
|
||||
user: { id: 'user123' },
|
||||
body: {},
|
||||
};
|
||||
|
||||
mockConversationData = {
|
||||
conversationId: uuidv4(),
|
||||
title: 'Test Conversation',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
};
|
||||
});
|
||||
|
||||
describe('saveConvo', () => {
|
||||
it('should save a conversation for an authenticated user', async () => {
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.user).toBe('user123');
|
||||
expect(result.title).toBe('Test Conversation');
|
||||
expect(result.endpoint).toBe(EModelEndpoint.openAI);
|
||||
|
||||
// Verify the conversation was actually saved to the database
|
||||
const savedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
});
|
||||
expect(savedConvo).toBeTruthy();
|
||||
expect(savedConvo.title).toBe('Test Conversation');
|
||||
});
|
||||
|
||||
it('should query messages when saving a conversation', async () => {
|
||||
// Mock messages as ObjectIds
|
||||
const mongoose = require('mongoose');
|
||||
const mockMessages = [new mongoose.Types.ObjectId(), new mongoose.Types.ObjectId()];
|
||||
getMessages.mockResolvedValue(mockMessages);
|
||||
|
||||
await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
// Verify that getMessages was called with correct parameters
|
||||
expect(getMessages).toHaveBeenCalledWith(
|
||||
{ conversationId: mockConversationData.conversationId },
|
||||
'_id',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle newConversationId when provided', async () => {
|
||||
const newConversationId = uuidv4();
|
||||
const result = await saveConvo(mockReq, {
|
||||
...mockConversationData,
|
||||
newConversationId,
|
||||
});
|
||||
|
||||
expect(result.conversationId).toBe(newConversationId);
|
||||
});
|
||||
|
||||
it('should handle unsetFields metadata', async () => {
|
||||
const metadata = {
|
||||
unsetFields: { someField: 1 },
|
||||
};
|
||||
|
||||
await saveConvo(mockReq, mockConversationData, metadata);
|
||||
|
||||
const savedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
expect(savedConvo.someField).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTemporary conversation handling', () => {
|
||||
it('should save a conversation with expiredAt when isTemporary is true', async () => {
|
||||
// Mock custom config with 24 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
const afterSave = new Date();
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
expect(result.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 24 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 24 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should save a conversation without expiredAt when isTemporary is false', async () => {
|
||||
mockReq.body = { isTemporary: false };
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should save a conversation without expiredAt when isTemporary is not provided', async () => {
|
||||
// No isTemporary in body
|
||||
mockReq.body = {};
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
// Mock custom config with 48 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 48,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 48 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 48 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock custom config with less than minimum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 0.5, // Half hour - should be clamped to 1 hour
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 1 hour in the future (minimum)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 1 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock custom config with more than maximum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 10000, // Should be clamped to 8760 hours
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 8760 hours (1 year) in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 8760 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle getCustomConfig errors gracefully', async () => {
|
||||
// Mock getCustomConfig to throw an error
|
||||
getCustomConfig.mockRejectedValue(new Error('Config service unavailable'));
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
// Should still save the conversation but with expiredAt as null
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getCustomConfig to return empty config
|
||||
getCustomConfig.mockResolvedValue({});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Default retention is 30 days (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 30 * 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should update expiredAt when saving existing temporary conversation', async () => {
|
||||
// First save a temporary conversation
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const firstSave = await saveConvo(mockReq, mockConversationData);
|
||||
const originalExpiredAt = firstSave.expiredAt;
|
||||
|
||||
// Wait a bit to ensure time difference
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Save again with same conversationId but different title
|
||||
const updatedData = { ...mockConversationData, title: 'Updated Title' };
|
||||
const secondSave = await saveConvo(mockReq, updatedData);
|
||||
|
||||
// Should update title and create new expiredAt
|
||||
expect(secondSave.title).toBe('Updated Title');
|
||||
expect(secondSave.expiredAt).toBeDefined();
|
||||
expect(new Date(secondSave.expiredAt).getTime()).toBeGreaterThan(
|
||||
new Date(originalExpiredAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not set expiredAt when updating non-temporary conversation', async () => {
|
||||
// First save a non-temporary conversation
|
||||
mockReq.body = { isTemporary: false };
|
||||
const firstSave = await saveConvo(mockReq, mockConversationData);
|
||||
expect(firstSave.expiredAt).toBeNull();
|
||||
|
||||
// Update without isTemporary flag
|
||||
mockReq.body = {};
|
||||
const updatedData = { ...mockConversationData, title: 'Updated Title' };
|
||||
const secondSave = await saveConvo(mockReq, updatedData);
|
||||
|
||||
expect(secondSave.title).toBe('Updated Title');
|
||||
expect(secondSave.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should filter out expired conversations in getConvosByCursor', async () => {
|
||||
// Create some test conversations
|
||||
const nonExpiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Non-expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
updatedAt: new Date(),
|
||||
});
|
||||
|
||||
await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Future expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000), // 24 hours from now
|
||||
updatedAt: new Date(),
|
||||
});
|
||||
|
||||
// Mock Meili search
|
||||
Conversation.meiliSearch = jest.fn().mockResolvedValue({ hits: [] });
|
||||
|
||||
const result = await getConvosByCursor('user123');
|
||||
|
||||
// Should only return conversations with null or non-existent expiredAt
|
||||
expect(result.conversations).toHaveLength(1);
|
||||
expect(result.conversations[0].conversationId).toBe(nonExpiredConvo.conversationId);
|
||||
});
|
||||
|
||||
it('should filter out expired conversations in getConvosQueried', async () => {
|
||||
// Create test conversations
|
||||
const nonExpiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Non-expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
});
|
||||
|
||||
const expiredConvo = await Conversation.create({
|
||||
conversationId: uuidv4(),
|
||||
user: 'user123',
|
||||
title: 'Expired',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000),
|
||||
});
|
||||
|
||||
const convoIds = [
|
||||
{ conversationId: nonExpiredConvo.conversationId },
|
||||
{ conversationId: expiredConvo.conversationId },
|
||||
];
|
||||
|
||||
const result = await getConvosQueried('user123', convoIds);
|
||||
|
||||
// Should only return the non-expired conversation
|
||||
expect(result.conversations).toHaveLength(1);
|
||||
expect(result.conversations[0].conversationId).toBe(nonExpiredConvo.conversationId);
|
||||
expect(result.convoMap[nonExpiredConvo.conversationId]).toBeDefined();
|
||||
expect(result.convoMap[expiredConvo.conversationId]).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('searchConversation', () => {
|
||||
it('should find a conversation by conversationId', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await searchConversation(mockConversationData.conversationId);
|
||||
|
||||
expect(result).toBeTruthy();
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.user).toBe('user123');
|
||||
expect(result.title).toBeUndefined(); // Only returns conversationId and user
|
||||
});
|
||||
|
||||
it('should return null if conversation not found', async () => {
|
||||
const result = await searchConversation('non-existent-id');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvo', () => {
|
||||
it('should retrieve a conversation for a user', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test Conversation',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvo('user123', mockConversationData.conversationId);
|
||||
|
||||
expect(result.conversationId).toBe(mockConversationData.conversationId);
|
||||
expect(result.user).toBe('user123');
|
||||
expect(result.title).toBe('Test Conversation');
|
||||
});
|
||||
|
||||
it('should return null if conversation not found', async () => {
|
||||
const result = await getConvo('user123', 'non-existent-id');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvoTitle', () => {
|
||||
it('should return the conversation title', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'Test Title',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoTitle('user123', mockConversationData.conversationId);
|
||||
expect(result).toBe('Test Title');
|
||||
});
|
||||
|
||||
it('should return null if conversation has no title', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: null,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoTitle('user123', mockConversationData.conversationId);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return "New Chat" if conversation not found', async () => {
|
||||
const result = await getConvoTitle('user123', 'non-existent-id');
|
||||
expect(result).toBe('New Chat');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getConvoFiles', () => {
|
||||
it('should return conversation files', async () => {
|
||||
const files = ['file1', 'file2'];
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
files,
|
||||
});
|
||||
|
||||
const result = await getConvoFiles(mockConversationData.conversationId);
|
||||
expect(result).toEqual(files);
|
||||
});
|
||||
|
||||
it('should return empty array if no files', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
const result = await getConvoFiles(mockConversationData.conversationId);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return empty array if conversation not found', async () => {
|
||||
const result = await getConvoFiles('non-existent-id');
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteConvos', () => {
|
||||
it('should delete conversations and associated messages', async () => {
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user123',
|
||||
title: 'To Delete',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 5 });
|
||||
|
||||
const result = await deleteConvos('user123', {
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
|
||||
expect(result.deletedCount).toBe(1);
|
||||
expect(result.messages.deletedCount).toBe(5);
|
||||
expect(deleteMessages).toHaveBeenCalledWith({
|
||||
conversationId: { $in: [mockConversationData.conversationId] },
|
||||
});
|
||||
|
||||
// Verify conversation was deleted
|
||||
const deletedConvo = await Conversation.findOne({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
});
|
||||
expect(deletedConvo).toBeNull();
|
||||
});
|
||||
|
||||
it('should throw error if no conversations found', async () => {
|
||||
await expect(deleteConvos('user123', { conversationId: 'non-existent' })).rejects.toThrow(
|
||||
'Conversation not found or already deleted.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteNullOrEmptyConversations', () => {
|
||||
it('should delete conversations with null, empty, or missing conversationIds', async () => {
|
||||
// Since conversationId is required by the schema, we can't create documents with null/missing IDs
|
||||
// This test should verify the function works when such documents exist (e.g., from data corruption)
|
||||
|
||||
// For this test, let's create a valid conversation and verify the function doesn't delete it
|
||||
await Conversation.create({
|
||||
conversationId: mockConversationData.conversationId,
|
||||
user: 'user4',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
});
|
||||
|
||||
deleteMessages.mockResolvedValue({ deletedCount: 0 });
|
||||
|
||||
const result = await deleteNullOrEmptyConversations();
|
||||
|
||||
expect(result.conversations.deletedCount).toBe(0); // No invalid conversations to delete
|
||||
expect(result.messages.deletedCount).toBe(0);
|
||||
|
||||
// Verify valid conversation remains
|
||||
const remainingConvos = await Conversation.find({});
|
||||
expect(remainingConvos).toHaveLength(1);
|
||||
expect(remainingConvos[0].conversationId).toBe(mockConversationData.conversationId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle database errors in saveConvo', async () => {
|
||||
// Force a database error by disconnecting
|
||||
await mongoose.disconnect();
|
||||
|
||||
const result = await saveConvo(mockReq, mockConversationData);
|
||||
|
||||
expect(result).toEqual({ message: 'Error saving conversation' });
|
||||
|
||||
// Reconnect for other tests
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,5 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EToolResources, FileContext, Constants } = require('librechat-data-provider');
|
||||
const { getProjectByName } = require('./Project');
|
||||
const { getAgent } = require('./Agent');
|
||||
const { EToolResources } = require('librechat-data-provider');
|
||||
const { File } = require('~/db/models');
|
||||
|
||||
/**
|
||||
@@ -14,124 +12,17 @@ const findFileById = async (file_id, options = {}) => {
|
||||
return await File.findOne({ file_id, ...options }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Checks if a user has access to multiple files through a shared agent (batch operation)
|
||||
* @param {string} userId - The user ID to check access for
|
||||
* @param {string[]} fileIds - Array of file IDs to check
|
||||
* @param {string} agentId - The agent ID that might grant access
|
||||
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
|
||||
*/
|
||||
const hasAccessToFilesViaAgent = async (userId, fileIds, agentId, checkCollaborative = true) => {
|
||||
const accessMap = new Map();
|
||||
|
||||
// Initialize all files as no access
|
||||
fileIds.forEach((fileId) => accessMap.set(fileId, false));
|
||||
|
||||
try {
|
||||
const agent = await getAgent({ id: agentId });
|
||||
|
||||
if (!agent) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if user is the author - if so, grant access to all files
|
||||
if (agent.author.toString() === userId) {
|
||||
fileIds.forEach((fileId) => accessMap.set(fileId, true));
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if agent is shared with the user via projects
|
||||
if (!agent.projectIds || agent.projectIds.length === 0) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if agent is in global project
|
||||
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
|
||||
if (
|
||||
!globalProject ||
|
||||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
|
||||
) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Agent is globally shared - check if it's collaborative
|
||||
if (checkCollaborative && !agent.isCollaborative) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check which files are actually attached
|
||||
const attachedFileIds = new Set();
|
||||
if (agent.tool_resources) {
|
||||
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
|
||||
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
|
||||
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Grant access only to files that are attached to this agent
|
||||
fileIds.forEach((fileId) => {
|
||||
if (attachedFileIds.has(fileId)) {
|
||||
accessMap.set(fileId, true);
|
||||
}
|
||||
});
|
||||
|
||||
return accessMap;
|
||||
} catch (error) {
|
||||
logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
|
||||
return accessMap;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves files matching a given filter, sorted by the most recently updated.
|
||||
* @param {Object} filter - The filter criteria to apply.
|
||||
* @param {Object} [_sortOptions] - Optional sort parameters.
|
||||
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
|
||||
* Default excludes the 'text' field.
|
||||
* @param {Object} [options] - Additional options
|
||||
* @param {string} [options.userId] - User ID for access control
|
||||
* @param {string} [options.agentId] - Agent ID that might grant access to files
|
||||
* @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
|
||||
*/
|
||||
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => {
|
||||
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
|
||||
const sortOptions = { updatedAt: -1, ..._sortOptions };
|
||||
const files = await File.find(filter).select(selectFields).sort(sortOptions).lean();
|
||||
|
||||
// If userId and agentId are provided, filter files based on access
|
||||
if (options.userId && options.agentId) {
|
||||
// Collect file IDs that need access check
|
||||
const filesToCheck = [];
|
||||
const ownedFiles = [];
|
||||
|
||||
for (const file of files) {
|
||||
if (file.user && file.user.toString() === options.userId) {
|
||||
ownedFiles.push(file);
|
||||
} else {
|
||||
filesToCheck.push(file);
|
||||
}
|
||||
}
|
||||
|
||||
if (filesToCheck.length === 0) {
|
||||
return ownedFiles;
|
||||
}
|
||||
|
||||
// Batch check access for all non-owned files
|
||||
const fileIds = filesToCheck.map((f) => f.file_id);
|
||||
const accessMap = await hasAccessToFilesViaAgent(
|
||||
options.userId,
|
||||
fileIds,
|
||||
options.agentId,
|
||||
false,
|
||||
);
|
||||
|
||||
// Filter files based on access
|
||||
const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
|
||||
|
||||
return [...ownedFiles, ...accessibleFiles];
|
||||
}
|
||||
|
||||
return files;
|
||||
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -141,19 +32,19 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, option
|
||||
* @returns {Promise<Array<MongoFile>>} Files that match the criteria
|
||||
*/
|
||||
const getToolFilesByIds = async (fileIds, toolResourceSet) => {
|
||||
if (!fileIds || !fileIds.length || !toolResourceSet?.size) {
|
||||
if (!fileIds || !fileIds.length) {
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
const filter = {
|
||||
file_id: { $in: fileIds },
|
||||
$or: [],
|
||||
};
|
||||
|
||||
if (toolResourceSet.has(EToolResources.ocr)) {
|
||||
filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
|
||||
if (toolResourceSet.size) {
|
||||
filter.$or = [];
|
||||
}
|
||||
|
||||
if (toolResourceSet.has(EToolResources.file_search)) {
|
||||
filter.$or.push({ embedded: true });
|
||||
}
|
||||
@@ -285,5 +176,4 @@ module.exports = {
|
||||
deleteFiles,
|
||||
deleteFileByFilter,
|
||||
batchUpdateFiles,
|
||||
hasAccessToFilesViaAgent,
|
||||
};
|
||||
|
||||
@@ -1,264 +0,0 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { fileSchema } = require('@librechat/data-schemas');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
const { projectSchema } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
const { getFiles, createFile } = require('./File');
|
||||
const { getProjectByName } = require('./Project');
|
||||
const { createAgent } = require('./Agent');
|
||||
|
||||
let File;
|
||||
let Agent;
|
||||
let Project;
|
||||
|
||||
describe('File Access Control', () => {
|
||||
let mongoServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
File = mongoose.models.File || mongoose.model('File', fileSchema);
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await File.deleteMany({});
|
||||
await Agent.deleteMany({});
|
||||
await Project.deleteMany({});
|
||||
});
|
||||
|
||||
describe('hasAccessToFilesViaAgent', () => {
|
||||
it('should efficiently check access for multiple files at once', async () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
const authorId = new mongoose.Types.ObjectId().toString();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
|
||||
|
||||
// Create files
|
||||
for (const fileId of fileIds) {
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: fileId,
|
||||
filename: `file-${fileId}.txt`,
|
||||
filepath: `/uploads/${fileId}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Create agent with only first two files attached
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileIds[0], fileIds[1]],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Get or create global project
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
|
||||
// Share agent globally
|
||||
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
|
||||
|
||||
// Check access for all files
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
|
||||
|
||||
// Should have access only to the first two files
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
expect(accessMap.get(fileIds[2])).toBe(false);
|
||||
expect(accessMap.get(fileIds[3])).toBe(false);
|
||||
});
|
||||
|
||||
it('should grant access to all files when user is the agent author', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId().toString();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
|
||||
|
||||
// Create agent
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileIds[0]], // Only one file attached
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Check access as the author
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);
|
||||
|
||||
// Author should have access to all files
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
expect(accessMap.get(fileIds[1])).toBe(true);
|
||||
expect(accessMap.get(fileIds[2])).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle non-existent agent gracefully', async () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');
|
||||
|
||||
// Should have no access to any files
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should deny access when agent is not collaborative', async () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
const authorId = new mongoose.Types.ObjectId().toString();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create agent with files but isCollaborative: false
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Non-Collaborative Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: false,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Get or create global project
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
|
||||
// Share agent globally
|
||||
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
|
||||
|
||||
// Check access for files
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
|
||||
|
||||
// Should have no access to any files when isCollaborative is false
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getFiles with agent access control', () => {
|
||||
test('should return files owned by user and files accessible through agent', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
const ownedFileId = `file_${uuidv4()}`;
|
||||
const sharedFileId = `file_${uuidv4()}`;
|
||||
const inaccessibleFileId = `file_${uuidv4()}`;
|
||||
|
||||
// Create/get global project using getProjectByName which will upsert
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);
|
||||
|
||||
// Create agent with shared file
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Shared Agent',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
projectIds: [globalProject._id],
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [sharedFileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Create files
|
||||
await createFile({
|
||||
file_id: ownedFileId,
|
||||
user: userId,
|
||||
filename: 'owned.txt',
|
||||
filepath: '/uploads/owned.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
|
||||
await createFile({
|
||||
file_id: sharedFileId,
|
||||
user: authorId,
|
||||
filename: 'shared.txt',
|
||||
filepath: '/uploads/shared.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 200,
|
||||
embedded: true,
|
||||
});
|
||||
|
||||
await createFile({
|
||||
file_id: inaccessibleFileId,
|
||||
user: authorId,
|
||||
filename: 'inaccessible.txt',
|
||||
filepath: '/uploads/inaccessible.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 300,
|
||||
});
|
||||
|
||||
// Get files with access control
|
||||
const files = await getFiles(
|
||||
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
|
||||
null,
|
||||
{ text: 0 },
|
||||
{ userId: userId.toString(), agentId },
|
||||
);
|
||||
|
||||
expect(files).toHaveLength(2);
|
||||
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
|
||||
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
|
||||
expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
|
||||
});
|
||||
|
||||
test('should return all files when no userId/agentId provided', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const fileId1 = `file_${uuidv4()}`;
|
||||
const fileId2 = `file_${uuidv4()}`;
|
||||
|
||||
await createFile({
|
||||
file_id: fileId1,
|
||||
user: userId,
|
||||
filename: 'file1.txt',
|
||||
filepath: '/uploads/file1.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
|
||||
await createFile({
|
||||
file_id: fileId2,
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
filename: 'file2.txt',
|
||||
filepath: '/uploads/file2.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 200,
|
||||
});
|
||||
|
||||
const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
|
||||
expect(files).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,5 @@
|
||||
const { z } = require('zod');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createTempChatExpirationDate } = require('@librechat/api');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { Message } = require('~/db/models');
|
||||
|
||||
const idSchema = z.string().uuid();
|
||||
@@ -56,14 +54,9 @@ async function saveMessage(req, params, metadata) {
|
||||
};
|
||||
|
||||
if (req?.body?.isTemporary) {
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
update.expiredAt = createTempChatExpirationDate(customConfig);
|
||||
} catch (err) {
|
||||
logger.error('Error creating temporary chat expiration date:', err);
|
||||
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
|
||||
update.expiredAt = null;
|
||||
}
|
||||
const expiredAt = new Date();
|
||||
expiredAt.setDate(expiredAt.getDate() + 30);
|
||||
update.expiredAt = expiredAt;
|
||||
} else {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { messageSchema } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
const {
|
||||
saveMessage,
|
||||
getMessages,
|
||||
updateMessage,
|
||||
deleteMessages,
|
||||
bulkSaveMessages,
|
||||
updateMessageText,
|
||||
deleteMessagesSince,
|
||||
} = require('./Message');
|
||||
|
||||
jest.mock('~/server/services/Config/getCustomConfig');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IMessage>}
|
||||
*/
|
||||
@@ -121,21 +117,21 @@ describe('Message Operations', () => {
|
||||
const conversationId = uuidv4();
|
||||
|
||||
// Create multiple messages in the same conversation
|
||||
await saveMessage(mockReq, {
|
||||
const message1 = await saveMessage(mockReq, {
|
||||
messageId: 'msg1',
|
||||
conversationId,
|
||||
text: 'First message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
await saveMessage(mockReq, {
|
||||
const message2 = await saveMessage(mockReq, {
|
||||
messageId: 'msg2',
|
||||
conversationId,
|
||||
text: 'Second message',
|
||||
user: 'user123',
|
||||
});
|
||||
|
||||
await saveMessage(mockReq, {
|
||||
const message3 = await saveMessage(mockReq, {
|
||||
messageId: 'msg3',
|
||||
conversationId,
|
||||
text: 'Third message',
|
||||
@@ -318,265 +314,4 @@ describe('Message Operations', () => {
|
||||
expect(messages[0].text).toBe('Victim message');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTemporary message handling', () => {
|
||||
beforeEach(() => {
|
||||
// Reset mocks before each test
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should save a message with expiredAt when isTemporary is true', async () => {
|
||||
// Mock custom config with 24 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
const afterSave = new Date();
|
||||
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
expect(result.expiredAt).toBeInstanceOf(Date);
|
||||
|
||||
// Verify expiredAt is approximately 24 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
new Date(afterSave.getTime() + 24 * 60 * 60 * 1000 + 1000).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should save a message without expiredAt when isTemporary is false', async () => {
|
||||
mockReq.body = { isTemporary: false };
|
||||
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should save a message without expiredAt when isTemporary is not provided', async () => {
|
||||
// No isTemporary in body
|
||||
mockReq.body = {};
|
||||
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use custom retention period from config', async () => {
|
||||
// Mock custom config with 48 hour retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 48,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 48 hours in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 48 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle minimum retention period (1 hour)', async () => {
|
||||
// Mock custom config with less than minimum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 0.5, // Half hour - should be clamped to 1 hour
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 1 hour in the future (minimum)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 1 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle maximum retention period (8760 hours)', async () => {
|
||||
// Mock custom config with more than maximum retention
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 10000, // Should be clamped to 8760 hours
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Verify expiredAt is approximately 8760 hours (1 year) in the future
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 8760 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle getCustomConfig errors gracefully', async () => {
|
||||
// Mock getCustomConfig to throw an error
|
||||
getCustomConfig.mockRejectedValue(new Error('Config service unavailable'));
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
// Should still save the message but with expiredAt as null
|
||||
expect(result.messageId).toBe('msg123');
|
||||
expect(result.expiredAt).toBeNull();
|
||||
});
|
||||
|
||||
it('should use default retention when config is not provided', async () => {
|
||||
// Mock getCustomConfig to return empty config
|
||||
getCustomConfig.mockResolvedValue({});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
|
||||
const beforeSave = new Date();
|
||||
const result = await saveMessage(mockReq, mockMessageData);
|
||||
|
||||
expect(result.expiredAt).toBeDefined();
|
||||
|
||||
// Default retention is 30 days (720 hours)
|
||||
const expectedExpirationTime = new Date(beforeSave.getTime() + 30 * 24 * 60 * 60 * 1000);
|
||||
const actualExpirationTime = new Date(result.expiredAt);
|
||||
|
||||
expect(actualExpirationTime.getTime()).toBeGreaterThanOrEqual(
|
||||
expectedExpirationTime.getTime() - 1000,
|
||||
);
|
||||
expect(actualExpirationTime.getTime()).toBeLessThanOrEqual(
|
||||
expectedExpirationTime.getTime() + 1000,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not update expiredAt on message update', async () => {
|
||||
// First save a temporary message
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const savedMessage = await saveMessage(mockReq, mockMessageData);
|
||||
const originalExpiredAt = savedMessage.expiredAt;
|
||||
|
||||
// Now update the message without isTemporary flag
|
||||
mockReq.body = {};
|
||||
const updatedMessage = await updateMessage(mockReq, {
|
||||
messageId: 'msg123',
|
||||
text: 'Updated text',
|
||||
});
|
||||
|
||||
// expiredAt should not be in the returned updated message object
|
||||
expect(updatedMessage.expiredAt).toBeUndefined();
|
||||
|
||||
// Verify in database that expiredAt wasn't changed
|
||||
const dbMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
|
||||
expect(dbMessage.expiredAt).toEqual(originalExpiredAt);
|
||||
});
|
||||
|
||||
it('should preserve expiredAt when saving existing temporary message', async () => {
|
||||
// First save a temporary message
|
||||
getCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
temporaryChatRetention: 24,
|
||||
},
|
||||
});
|
||||
|
||||
mockReq.body = { isTemporary: true };
|
||||
const firstSave = await saveMessage(mockReq, mockMessageData);
|
||||
const originalExpiredAt = firstSave.expiredAt;
|
||||
|
||||
// Wait a bit to ensure time difference
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Save again with same messageId but different text
|
||||
const updatedData = { ...mockMessageData, text: 'Updated text' };
|
||||
const secondSave = await saveMessage(mockReq, updatedData);
|
||||
|
||||
// Should update text but create new expiredAt
|
||||
expect(secondSave.text).toBe('Updated text');
|
||||
expect(secondSave.expiredAt).toBeDefined();
|
||||
expect(new Date(secondSave.expiredAt).getTime()).toBeGreaterThan(
|
||||
new Date(originalExpiredAt).getTime(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle bulk operations with temporary messages', async () => {
|
||||
// This test verifies bulkSaveMessages doesn't interfere with expiredAt
|
||||
const messages = [
|
||||
{
|
||||
messageId: 'bulk1',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Bulk message 1',
|
||||
user: 'user123',
|
||||
expiredAt: new Date(Date.now() + 24 * 60 * 60 * 1000),
|
||||
},
|
||||
{
|
||||
messageId: 'bulk2',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Bulk message 2',
|
||||
user: 'user123',
|
||||
expiredAt: null,
|
||||
},
|
||||
];
|
||||
|
||||
await bulkSaveMessages(messages);
|
||||
|
||||
const savedMessages = await Message.find({
|
||||
messageId: { $in: ['bulk1', 'bulk2'] },
|
||||
}).lean();
|
||||
|
||||
expect(savedMessages).toHaveLength(2);
|
||||
|
||||
const bulk1 = savedMessages.find((m) => m.messageId === 'bulk1');
|
||||
const bulk2 = savedMessages.find((m) => m.messageId === 'bulk2');
|
||||
|
||||
expect(bulk1.expiredAt).toBeDefined();
|
||||
expect(bulk2.expiredAt).toBeNull();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { matchModelName } = require('../utils/tokens');
|
||||
const { matchModelName } = require('../utils');
|
||||
const defaultRate = 6;
|
||||
|
||||
/**
|
||||
@@ -87,9 +87,6 @@ const tokenValues = Object.assign(
|
||||
'gpt-4.1': { prompt: 2, completion: 8 },
|
||||
'gpt-4.5': { prompt: 75, completion: 150 },
|
||||
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
|
||||
'gpt-5': { prompt: 1.25, completion: 10 },
|
||||
'gpt-5-mini': { prompt: 0.25, completion: 2 },
|
||||
'gpt-5-nano': { prompt: 0.05, completion: 0.4 },
|
||||
'gpt-4o': { prompt: 2.5, completion: 10 },
|
||||
'gpt-4o-2024-05-13': { prompt: 5, completion: 15 },
|
||||
'gpt-4-1106': { prompt: 10, completion: 30 },
|
||||
@@ -138,11 +135,10 @@ const tokenValues = Object.assign(
|
||||
'grok-2-1212': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-2': { prompt: 2.0, completion: 10.0 },
|
||||
'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
|
||||
'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
|
||||
'grok-3-mini': { prompt: 0.3, completion: 0.5 },
|
||||
'grok-3-fast': { prompt: 5.0, completion: 25.0 },
|
||||
'grok-3': { prompt: 3.0, completion: 15.0 },
|
||||
'grok-4': { prompt: 3.0, completion: 15.0 },
|
||||
'grok-beta': { prompt: 5.0, completion: 15.0 },
|
||||
'mistral-large': { prompt: 2.0, completion: 6.0 },
|
||||
'pixtral-large': { prompt: 2.0, completion: 6.0 },
|
||||
@@ -150,9 +146,6 @@ const tokenValues = Object.assign(
|
||||
codestral: { prompt: 0.3, completion: 0.9 },
|
||||
'ministral-8b': { prompt: 0.1, completion: 0.1 },
|
||||
'ministral-3b': { prompt: 0.04, completion: 0.04 },
|
||||
// GPT-OSS models
|
||||
'gpt-oss-20b': { prompt: 0.05, completion: 0.2 },
|
||||
'gpt-oss-120b': { prompt: 0.15, completion: 0.6 },
|
||||
},
|
||||
bedrockValues,
|
||||
);
|
||||
@@ -220,12 +213,6 @@ const getValueKey = (model, endpoint) => {
|
||||
return 'gpt-4.1';
|
||||
} else if (modelName.includes('gpt-4o-2024-05-13')) {
|
||||
return 'gpt-4o-2024-05-13';
|
||||
} else if (modelName.includes('gpt-5-nano')) {
|
||||
return 'gpt-5-nano';
|
||||
} else if (modelName.includes('gpt-5-mini')) {
|
||||
return 'gpt-5-mini';
|
||||
} else if (modelName.includes('gpt-5')) {
|
||||
return 'gpt-5';
|
||||
} else if (modelName.includes('gpt-4o-mini')) {
|
||||
return 'gpt-4o-mini';
|
||||
} else if (modelName.includes('gpt-4o')) {
|
||||
|
||||
@@ -25,14 +25,8 @@ describe('getValueKey', () => {
|
||||
expect(getValueKey('gpt-4-some-other-info')).toBe('8k');
|
||||
});
|
||||
|
||||
it('should return "gpt-5" for model name containing "gpt-5"', () => {
|
||||
expect(getValueKey('gpt-5-some-other-info')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-2025-01-30')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-2025-01-30-0130')).toBe('gpt-5');
|
||||
expect(getValueKey('openai/gpt-5')).toBe('gpt-5');
|
||||
expect(getValueKey('openai/gpt-5-2025-01-30')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-turbo')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-0130')).toBe('gpt-5');
|
||||
it('should return undefined for model names that do not match any known patterns', () => {
|
||||
expect(getValueKey('gpt-5-some-other-info')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return "gpt-3.5-turbo-1106" for model name containing "gpt-3.5-turbo-1106"', () => {
|
||||
@@ -90,29 +84,6 @@ describe('getValueKey', () => {
|
||||
expect(getValueKey('gpt-4.1-nano-0125')).toBe('gpt-4.1-nano');
|
||||
});
|
||||
|
||||
it('should return "gpt-5" for model type of "gpt-5"', () => {
|
||||
expect(getValueKey('gpt-5-2025-01-30')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-2025-01-30-0130')).toBe('gpt-5');
|
||||
expect(getValueKey('openai/gpt-5')).toBe('gpt-5');
|
||||
expect(getValueKey('openai/gpt-5-2025-01-30')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-turbo')).toBe('gpt-5');
|
||||
expect(getValueKey('gpt-5-0130')).toBe('gpt-5');
|
||||
});
|
||||
|
||||
it('should return "gpt-5-mini" for model type of "gpt-5-mini"', () => {
|
||||
expect(getValueKey('gpt-5-mini-2025-01-30')).toBe('gpt-5-mini');
|
||||
expect(getValueKey('openai/gpt-5-mini')).toBe('gpt-5-mini');
|
||||
expect(getValueKey('gpt-5-mini-0130')).toBe('gpt-5-mini');
|
||||
expect(getValueKey('gpt-5-mini-2025-01-30-0130')).toBe('gpt-5-mini');
|
||||
});
|
||||
|
||||
it('should return "gpt-5-nano" for model type of "gpt-5-nano"', () => {
|
||||
expect(getValueKey('gpt-5-nano-2025-01-30')).toBe('gpt-5-nano');
|
||||
expect(getValueKey('openai/gpt-5-nano')).toBe('gpt-5-nano');
|
||||
expect(getValueKey('gpt-5-nano-0130')).toBe('gpt-5-nano');
|
||||
expect(getValueKey('gpt-5-nano-2025-01-30-0130')).toBe('gpt-5-nano');
|
||||
});
|
||||
|
||||
it('should return "gpt-4o" for model type of "gpt-4o"', () => {
|
||||
expect(getValueKey('gpt-4o-2024-08-06')).toBe('gpt-4o');
|
||||
expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o');
|
||||
@@ -236,48 +207,6 @@ describe('getMultiplier', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5', () => {
|
||||
const valueKey = getValueKey('gpt-5-2025-01-30');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-5'].prompt);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5-preview', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5-mini', () => {
|
||||
const valueKey = getValueKey('gpt-5-mini-2025-01-30');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-5-mini'].prompt);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5-mini'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5-mini-preview', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5-mini'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5-mini', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5-mini'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-5-nano', () => {
|
||||
const valueKey = getValueKey('gpt-5-nano-2025-01-30');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-5-nano'].prompt);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5-nano'].completion,
|
||||
);
|
||||
expect(getMultiplier({ model: 'gpt-5-nano-preview', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-5-nano'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'openai/gpt-5-nano', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-5-nano'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4o', () => {
|
||||
const valueKey = getValueKey('gpt-4o-2024-08-06');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
|
||||
@@ -378,22 +307,10 @@ describe('getMultiplier', () => {
|
||||
});
|
||||
|
||||
it('should return defaultRate if derived valueKey does not match any known patterns', () => {
|
||||
expect(getMultiplier({ tokenType: 'prompt', model: 'gpt-10-some-other-info' })).toBe(
|
||||
expect(getMultiplier({ tokenType: 'prompt', model: 'gpt-5-some-other-info' })).toBe(
|
||||
defaultRate,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return correct multipliers for GPT-OSS models', () => {
|
||||
const models = ['gpt-oss-20b', 'gpt-oss-120b'];
|
||||
models.forEach((key) => {
|
||||
const expectedPrompt = tokenValues[key].prompt;
|
||||
const expectedCompletion = tokenValues[key].completion;
|
||||
expect(getMultiplier({ valueKey: key, tokenType: 'prompt' })).toBe(expectedPrompt);
|
||||
expect(getMultiplier({ valueKey: key, tokenType: 'completion' })).toBe(expectedCompletion);
|
||||
expect(getMultiplier({ model: key, tokenType: 'prompt' })).toBe(expectedPrompt);
|
||||
expect(getMultiplier({ model: key, tokenType: 'completion' })).toBe(expectedCompletion);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('AWS Bedrock Model Tests', () => {
|
||||
@@ -719,15 +636,6 @@ describe('Grok Model Tests - Pricing', () => {
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct prompt and completion rates for Grok 4 model', () => {
|
||||
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['grok-4'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'completion' })).toBe(
|
||||
tokenValues['grok-4'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
|
||||
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['grok-3'].prompt,
|
||||
@@ -754,15 +662,6 @@ describe('Grok Model Tests - Pricing', () => {
|
||||
tokenValues['grok-3-mini-fast'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
test('should return correct prompt and completion rates for Grok 4 model with prefixes', () => {
|
||||
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['grok-4'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'completion' })).toBe(
|
||||
tokenValues['grok-4'].completion,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "v0.8.0-rc1",
|
||||
"version": "v0.7.8",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
@@ -44,21 +44,20 @@
|
||||
"@googleapis/youtube": "^20.0.0",
|
||||
"@keyv/redis": "^4.3.3",
|
||||
"@langchain/community": "^0.3.47",
|
||||
"@langchain/core": "^0.3.62",
|
||||
"@langchain/core": "^0.3.60",
|
||||
"@langchain/google-genai": "^0.2.13",
|
||||
"@langchain/google-vertexai": "^0.2.13",
|
||||
"@langchain/openai": "^0.5.18",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.4.75",
|
||||
"@librechat/agents": "^2.4.41",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@modelcontextprotocol/sdk": "^1.17.1",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@node-saml/passport-saml": "^5.0.0",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
"axios": "^1.8.2",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"compression": "^1.8.1",
|
||||
"connect-redis": "^8.1.0",
|
||||
"cohere-ai": "^7.9.1",
|
||||
"compression": "^1.7.4",
|
||||
"connect-redis": "^7.1.0",
|
||||
"cookie": "^0.7.2",
|
||||
"cookie-parser": "^1.4.7",
|
||||
"cors": "^2.8.5",
|
||||
@@ -68,11 +67,10 @@
|
||||
"express": "^4.21.2",
|
||||
"express-mongo-sanitize": "^2.2.0",
|
||||
"express-rate-limit": "^7.4.1",
|
||||
"express-session": "^1.18.2",
|
||||
"express-session": "^1.18.1",
|
||||
"express-static-gzip": "^2.2.0",
|
||||
"file-type": "^18.7.0",
|
||||
"firebase": "^11.0.2",
|
||||
"form-data": "^4.0.4",
|
||||
"googleapis": "^126.0.1",
|
||||
"handlebars": "^4.7.7",
|
||||
"https-proxy-agent": "^7.0.6",
|
||||
@@ -90,12 +88,12 @@
|
||||
"mime": "^3.0.0",
|
||||
"module-alias": "^2.2.3",
|
||||
"mongoose": "^8.12.1",
|
||||
"multer": "^2.0.2",
|
||||
"multer": "^2.0.1",
|
||||
"nanoid": "^3.3.7",
|
||||
"node-fetch": "^2.7.0",
|
||||
"nodemailer": "^6.9.15",
|
||||
"ollama": "^0.5.0",
|
||||
"openai": "^5.10.1",
|
||||
"openai": "^4.96.2",
|
||||
"openai-chat-tokens": "^0.2.8",
|
||||
"openid-client": "^6.5.0",
|
||||
"passport": "^0.6.0",
|
||||
@@ -120,7 +118,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"jest": "^29.7.0",
|
||||
"mongodb-memory-server": "^10.1.4",
|
||||
"mongodb-memory-server": "^10.1.3",
|
||||
"nodemon": "^3.0.3",
|
||||
"supertest": "^7.1.0"
|
||||
}
|
||||
|
||||
@@ -169,6 +169,9 @@ function disposeClient(client) {
|
||||
client.isGenerativeModel = null;
|
||||
}
|
||||
// Properties specific to OpenAIClient
|
||||
if (client.ChatGPTClient) {
|
||||
client.ChatGPTClient = null;
|
||||
}
|
||||
if (client.completionsUrl) {
|
||||
client.completionsUrl = null;
|
||||
}
|
||||
|
||||
282
api/server/controllers/AskController.js
Normal file
282
api/server/controllers/AskController.js
Normal file
@@ -0,0 +1,282 @@
|
||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
handleAbortError,
|
||||
createAbortController,
|
||||
cleanupAbortController,
|
||||
} = require('~/server/middleware');
|
||||
const {
|
||||
disposeClient,
|
||||
processReqData,
|
||||
clientRegistry,
|
||||
requestDataMap,
|
||||
} = require('~/server/cleanup');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
text,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
modelDisplayLabel,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
let client = null;
|
||||
let abortKey = null;
|
||||
let cleanupHandlers = [];
|
||||
let clientRef = null;
|
||||
|
||||
logger.debug('[AskController]', {
|
||||
text,
|
||||
conversationId,
|
||||
...endpointOption,
|
||||
modelsConfig: endpointOption?.modelsConfig ? 'exists' : '',
|
||||
});
|
||||
|
||||
let userMessage = null;
|
||||
let userMessagePromise = null;
|
||||
let promptTokens = null;
|
||||
let userMessageId = null;
|
||||
let responseMessageId = null;
|
||||
let getAbortData = null;
|
||||
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
model: endpointOption.modelOptions.model,
|
||||
modelDisplayLabel,
|
||||
});
|
||||
const initialConversationId = conversationId;
|
||||
const newConvo = !initialConversationId;
|
||||
const userId = req.user.id;
|
||||
|
||||
let reqDataContext = {
|
||||
userMessage,
|
||||
userMessagePromise,
|
||||
responseMessageId,
|
||||
promptTokens,
|
||||
conversationId,
|
||||
userMessageId,
|
||||
};
|
||||
|
||||
const updateReqData = (data = {}) => {
|
||||
reqDataContext = processReqData(data, reqDataContext);
|
||||
abortKey = reqDataContext.abortKey;
|
||||
userMessage = reqDataContext.userMessage;
|
||||
userMessagePromise = reqDataContext.userMessagePromise;
|
||||
responseMessageId = reqDataContext.responseMessageId;
|
||||
promptTokens = reqDataContext.promptTokens;
|
||||
conversationId = reqDataContext.conversationId;
|
||||
userMessageId = reqDataContext.userMessageId;
|
||||
};
|
||||
|
||||
let { onProgress: progressCallback, getPartialText } = createOnProgress();
|
||||
|
||||
const performCleanup = () => {
|
||||
logger.debug('[AskController] Performing cleanup');
|
||||
if (Array.isArray(cleanupHandlers)) {
|
||||
for (const handler of cleanupHandlers) {
|
||||
try {
|
||||
if (typeof handler === 'function') {
|
||||
handler();
|
||||
}
|
||||
} catch (e) {
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (abortKey) {
|
||||
logger.debug('[AskController] Cleaning up abort controller');
|
||||
cleanupAbortController(abortKey);
|
||||
abortKey = null;
|
||||
}
|
||||
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
client = null;
|
||||
}
|
||||
|
||||
reqDataContext = null;
|
||||
userMessage = null;
|
||||
userMessagePromise = null;
|
||||
promptTokens = null;
|
||||
getAbortData = null;
|
||||
progressCallback = null;
|
||||
endpointOption = null;
|
||||
cleanupHandlers = null;
|
||||
addTitle = null;
|
||||
|
||||
if (requestDataMap.has(req)) {
|
||||
requestDataMap.delete(req);
|
||||
}
|
||||
logger.debug('[AskController] Cleanup completed');
|
||||
};
|
||||
|
||||
try {
|
||||
({ client } = await initializeClient({ req, res, endpointOption }));
|
||||
if (clientRegistry && client) {
|
||||
clientRegistry.register(client, { userId }, client);
|
||||
}
|
||||
|
||||
if (client) {
|
||||
requestDataMap.set(req, { client });
|
||||
}
|
||||
|
||||
clientRef = new WeakRef(client);
|
||||
|
||||
getAbortData = () => {
|
||||
const currentClient = clientRef?.deref();
|
||||
const currentText =
|
||||
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
|
||||
|
||||
return {
|
||||
sender,
|
||||
conversationId,
|
||||
messageId: reqDataContext.responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: currentText,
|
||||
userMessage: userMessage,
|
||||
userMessagePromise: userMessagePromise,
|
||||
promptTokens: reqDataContext.promptTokens,
|
||||
};
|
||||
};
|
||||
|
||||
const { onStart, abortController } = createAbortController(
|
||||
req,
|
||||
res,
|
||||
getAbortData,
|
||||
updateReqData,
|
||||
);
|
||||
|
||||
const closeHandler = () => {
|
||||
logger.debug('[AskController] Request closed');
|
||||
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
|
||||
return;
|
||||
}
|
||||
abortController.abort();
|
||||
logger.debug('[AskController] Request aborted on close');
|
||||
};
|
||||
|
||||
res.on('close', closeHandler);
|
||||
cleanupHandlers.push(() => {
|
||||
try {
|
||||
res.removeListener('close', closeHandler);
|
||||
} catch (e) {
|
||||
// Ignore
|
||||
}
|
||||
});
|
||||
|
||||
const messageOptions = {
|
||||
user: userId,
|
||||
parentMessageId,
|
||||
conversationId: reqDataContext.conversationId,
|
||||
overrideParentMessageId,
|
||||
getReqData: updateReqData,
|
||||
onStart,
|
||||
abortController,
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
},
|
||||
};
|
||||
|
||||
/** @type {TMessage} */
|
||||
let response = await client.sendMessage(text, messageOptions);
|
||||
response.endpoint = endpointOption.endpoint;
|
||||
|
||||
const databasePromise = response.databasePromise;
|
||||
delete response.databasePromise;
|
||||
|
||||
const { conversation: convoData = {} } = await databasePromise;
|
||||
const conversation = { ...convoData };
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
const latestUserMessage = reqDataContext.userMessage;
|
||||
|
||||
if (client?.options?.attachments && latestUserMessage) {
|
||||
latestUserMessage.files = client.options.attachments;
|
||||
if (endpointOption?.modelOptions?.model) {
|
||||
conversation.model = endpointOption.modelOptions.model;
|
||||
}
|
||||
delete latestUserMessage.image_urls;
|
||||
}
|
||||
|
||||
if (!abortController.signal.aborted) {
|
||||
const finalResponseMessage = { ...response };
|
||||
|
||||
sendMessage(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: latestUserMessage,
|
||||
responseMessage: finalResponseMessage,
|
||||
});
|
||||
res.end();
|
||||
|
||||
if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) {
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...finalResponseMessage, user: userId },
|
||||
{ context: 'api/server/controllers/AskController.js - response end' },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (!client?.skipSaveUserMessage && latestUserMessage) {
|
||||
await saveMessage(req, latestUserMessage, {
|
||||
context: "api/server/controllers/AskController.js - don't skip saving user message",
|
||||
});
|
||||
}
|
||||
|
||||
if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response: { ...response },
|
||||
client,
|
||||
})
|
||||
.then(() => {
|
||||
logger.debug('[AskController] Title generation started');
|
||||
})
|
||||
.catch((err) => {
|
||||
logger.error('[AskController] Error in title generation', err);
|
||||
})
|
||||
.finally(() => {
|
||||
logger.debug('[AskController] Title generation completed');
|
||||
performCleanup();
|
||||
});
|
||||
} else {
|
||||
performCleanup();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[AskController] Error handling request', error);
|
||||
let partialText = '';
|
||||
try {
|
||||
const currentClient = clientRef?.deref();
|
||||
partialText =
|
||||
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
|
||||
} catch (getTextError) {
|
||||
logger.error('[AskController] Error calling getText() during error handling', getTextError);
|
||||
}
|
||||
|
||||
handleAbortError(res, req, error, {
|
||||
sender,
|
||||
partialText,
|
||||
conversationId: reqDataContext.conversationId,
|
||||
messageId: reqDataContext.responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId,
|
||||
userMessageId: reqDataContext.userMessageId,
|
||||
})
|
||||
.catch((err) => {
|
||||
logger.error('[AskController] Error in `handleAbortError` during catch block', err);
|
||||
})
|
||||
.finally(() => {
|
||||
performCleanup();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = AskController;
|
||||
@@ -1,17 +1,17 @@
|
||||
const cookies = require('cookie');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const openIdClient = require('openid-client');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
requestPasswordReset,
|
||||
setOpenIDAuthTokens,
|
||||
registerUser,
|
||||
resetPassword,
|
||||
setAuthTokens,
|
||||
registerUser,
|
||||
requestPasswordReset,
|
||||
setOpenIDAuthTokens,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getResponseSender } = require('librechat-data-provider');
|
||||
const {
|
||||
handleAbortError,
|
||||
@@ -12,8 +10,9 @@ const {
|
||||
clientRegistry,
|
||||
requestDataMap,
|
||||
} = require('~/server/cleanup');
|
||||
const { createOnProgress } = require('~/server/utils');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const EditController = async (req, res, next, initializeClient) => {
|
||||
let {
|
||||
@@ -85,7 +84,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
}
|
||||
|
||||
if (abortKey) {
|
||||
logger.debug('[EditController] Cleaning up abort controller');
|
||||
logger.debug('[AskController] Cleaning up abort controller');
|
||||
cleanupAbortController(abortKey);
|
||||
abortKey = null;
|
||||
}
|
||||
@@ -199,7 +198,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
const finalUserMessage = reqDataContext.userMessage;
|
||||
const finalResponseMessage = { ...response };
|
||||
|
||||
sendEvent(res, {
|
||||
sendMessage(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
|
||||
@@ -24,23 +24,17 @@ const handleValidationError = (err, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = (err, _req, res, _next) => {
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
module.exports = (err, req, res, next) => {
|
||||
try {
|
||||
if (err.name === 'ValidationError') {
|
||||
return handleValidationError(err, res);
|
||||
return (err = handleValidationError(err, res));
|
||||
}
|
||||
if (err.code && err.code == 11000) {
|
||||
return handleDuplicateKeyError(err, res);
|
||||
return (err = handleDuplicateKeyError(err, res));
|
||||
}
|
||||
// Special handling for errors like SyntaxError
|
||||
if (err.statusCode && err.body) {
|
||||
return res.status(err.statusCode).send(err.body);
|
||||
}
|
||||
|
||||
logger.error('ErrorController => error', err);
|
||||
return res.status(500).send('An unknown error occurred.');
|
||||
} catch (err) {
|
||||
logger.error('ErrorController => processing error', err);
|
||||
return res.status(500).send('Processing error in ErrorController.');
|
||||
logger.error('ErrorController => error', err);
|
||||
res.status(500).send('An unknown error occurred.');
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,241 +0,0 @@
|
||||
const errorController = require('./ErrorController');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
// Mock the logger
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('ErrorController', () => {
|
||||
let mockReq, mockRes, mockNext;
|
||||
|
||||
beforeEach(() => {
|
||||
mockReq = {};
|
||||
mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
send: jest.fn(),
|
||||
};
|
||||
mockNext = jest.fn();
|
||||
logger.error.mockClear();
|
||||
});
|
||||
|
||||
describe('ValidationError handling', () => {
|
||||
it('should handle ValidationError with single error', () => {
|
||||
const validationError = {
|
||||
name: 'ValidationError',
|
||||
errors: {
|
||||
email: { message: 'Email is required', path: 'email' },
|
||||
},
|
||||
};
|
||||
|
||||
errorController(validationError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: '["Email is required"]',
|
||||
fields: '["email"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
|
||||
});
|
||||
|
||||
it('should handle ValidationError with multiple errors', () => {
|
||||
const validationError = {
|
||||
name: 'ValidationError',
|
||||
errors: {
|
||||
email: { message: 'Email is required', path: 'email' },
|
||||
password: { message: 'Password is required', path: 'password' },
|
||||
},
|
||||
};
|
||||
|
||||
errorController(validationError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: '"Email is required Password is required"',
|
||||
fields: '["email","password"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
|
||||
});
|
||||
|
||||
it('should handle ValidationError with empty errors object', () => {
|
||||
const validationError = {
|
||||
name: 'ValidationError',
|
||||
errors: {},
|
||||
};
|
||||
|
||||
errorController(validationError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: '[]',
|
||||
fields: '[]',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Duplicate key error handling', () => {
|
||||
it('should handle duplicate key error (code 11000)', () => {
|
||||
const duplicateKeyError = {
|
||||
code: 11000,
|
||||
keyValue: { email: 'test@example.com' },
|
||||
};
|
||||
|
||||
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: 'An document with that ["email"] already exists.',
|
||||
fields: '["email"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
|
||||
});
|
||||
|
||||
it('should handle duplicate key error with multiple fields', () => {
|
||||
const duplicateKeyError = {
|
||||
code: 11000,
|
||||
keyValue: { email: 'test@example.com', username: 'testuser' },
|
||||
};
|
||||
|
||||
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: 'An document with that ["email","username"] already exists.',
|
||||
fields: '["email","username"]',
|
||||
});
|
||||
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
|
||||
});
|
||||
|
||||
it('should handle error with code 11000 as string', () => {
|
||||
const duplicateKeyError = {
|
||||
code: '11000',
|
||||
keyValue: { email: 'test@example.com' },
|
||||
};
|
||||
|
||||
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(409);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({
|
||||
messages: 'An document with that ["email"] already exists.',
|
||||
fields: '["email"]',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('SyntaxError handling', () => {
|
||||
it('should handle errors with statusCode and body', () => {
|
||||
const syntaxError = {
|
||||
statusCode: 400,
|
||||
body: 'Invalid JSON syntax',
|
||||
};
|
||||
|
||||
errorController(syntaxError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
|
||||
});
|
||||
|
||||
it('should handle errors with different statusCode and body', () => {
|
||||
const customError = {
|
||||
statusCode: 422,
|
||||
body: { error: 'Unprocessable entity' },
|
||||
};
|
||||
|
||||
errorController(customError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(422);
|
||||
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
|
||||
});
|
||||
|
||||
it('should handle error with statusCode but no body', () => {
|
||||
const partialError = {
|
||||
statusCode: 400,
|
||||
};
|
||||
|
||||
errorController(partialError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
});
|
||||
|
||||
it('should handle error with body but no statusCode', () => {
|
||||
const partialError = {
|
||||
body: 'Some error message',
|
||||
};
|
||||
|
||||
errorController(partialError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Unknown error handling', () => {
|
||||
it('should handle unknown errors', () => {
|
||||
const unknownError = new Error('Some unknown error');
|
||||
|
||||
errorController(unknownError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError);
|
||||
});
|
||||
|
||||
it('should handle errors with code other than 11000', () => {
|
||||
const mongoError = {
|
||||
code: 11100,
|
||||
message: 'Some MongoDB error',
|
||||
};
|
||||
|
||||
errorController(mongoError, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
|
||||
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError);
|
||||
});
|
||||
|
||||
it('should handle null/undefined errors', () => {
|
||||
errorController(null, mockReq, mockRes, mockNext);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'ErrorController => processing error',
|
||||
expect.any(Error),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Catch block handling', () => {
|
||||
beforeEach(() => {
|
||||
// Restore logger mock to normal behavior for these tests
|
||||
logger.error.mockRestore();
|
||||
logger.error = jest.fn();
|
||||
});
|
||||
|
||||
it('should handle errors when logger.error throws', () => {
|
||||
// Create fresh mocks for this test
|
||||
const freshMockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
send: jest.fn(),
|
||||
};
|
||||
|
||||
// Mock logger to throw on the first call, succeed on the second
|
||||
logger.error
|
||||
.mockImplementationOnce(() => {
|
||||
throw new Error('Logger error');
|
||||
})
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const testError = new Error('Test error');
|
||||
|
||||
errorController(testError, mockReq, freshMockRes, mockNext);
|
||||
|
||||
expect(freshMockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
|
||||
expect(logger.error).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,15 +1,54 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
getToolkitKey,
|
||||
checkPluginAuth,
|
||||
filterUniquePlugins,
|
||||
convertMCPToolsToPlugins,
|
||||
} = require('@librechat/api');
|
||||
const { CacheKeys, AuthType } = require('librechat-data-provider');
|
||||
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const { availableTools, toolkits } = require('~/app/clients/tools');
|
||||
const { getToolkitKey } = require('~/server/services/ToolService');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { availableTools } = require('~/app/clients/tools');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Filters out duplicate plugins from the list of plugins.
|
||||
*
|
||||
* @param {TPlugin[]} plugins The list of plugins to filter.
|
||||
* @returns {TPlugin[]} The list of plugins with duplicates removed.
|
||||
*/
|
||||
const filterUniquePlugins = (plugins) => {
|
||||
const seen = new Set();
|
||||
return plugins.filter((plugin) => {
|
||||
const duplicate = seen.has(plugin.pluginKey);
|
||||
seen.add(plugin.pluginKey);
|
||||
return !duplicate;
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Determines if a plugin is authenticated by checking if all required authentication fields have non-empty values.
|
||||
* Supports alternate authentication fields, allowing validation against multiple possible environment variables.
|
||||
*
|
||||
* @param {TPlugin} plugin The plugin object containing the authentication configuration.
|
||||
* @returns {boolean} True if the plugin is authenticated for all required fields, false otherwise.
|
||||
*/
|
||||
const checkPluginAuth = (plugin) => {
|
||||
if (!plugin.authConfig || plugin.authConfig.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return plugin.authConfig.every((authFieldObj) => {
|
||||
const authFieldOptions = authFieldObj.authField.split('||');
|
||||
let isFieldAuthenticated = false;
|
||||
|
||||
for (const fieldOption of authFieldOptions) {
|
||||
const envValue = process.env[fieldOption];
|
||||
if (envValue && envValue.trim() !== '' && envValue !== AuthType.USER_PROVIDED) {
|
||||
isFieldAuthenticated = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return isFieldAuthenticated;
|
||||
});
|
||||
};
|
||||
|
||||
const getAvailablePluginsController = async (req, res) => {
|
||||
try {
|
||||
@@ -100,21 +139,15 @@ function createGetServerTools() {
|
||||
*/
|
||||
const getAvailableTools = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user?.id;
|
||||
const customConfig = await getCustomConfig();
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
const cachedUserTools = await getCachedTools({ userId });
|
||||
const userPlugins = convertMCPToolsToPlugins({ functionTools: cachedUserTools, customConfig });
|
||||
|
||||
if (cachedToolsArray != null && userPlugins != null) {
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
|
||||
res.status(200).json(dedupedTools);
|
||||
const cachedTools = await cache.get(CacheKeys.TOOLS);
|
||||
if (cachedTools) {
|
||||
res.status(200).json(cachedTools);
|
||||
return;
|
||||
}
|
||||
|
||||
// If not in cache, build from manifest
|
||||
let pluginManifest = availableTools;
|
||||
const customConfig = await getCustomConfig();
|
||||
if (customConfig?.mcpServers != null) {
|
||||
const mcpManager = getMCPManager();
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
@@ -140,16 +173,14 @@ const getAvailableTools = async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {};
|
||||
const toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
|
||||
const toolsOutput = [];
|
||||
for (const plugin of authenticatedPlugins) {
|
||||
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
|
||||
const isToolkit =
|
||||
plugin.toolkit === true &&
|
||||
Object.keys(toolDefinitions).some(
|
||||
(key) => getToolkitKey({ toolkits, toolName: key }) === plugin.pluginKey,
|
||||
);
|
||||
Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey);
|
||||
|
||||
if (!isToolDefined && !isToolkit) {
|
||||
continue;
|
||||
@@ -187,12 +218,10 @@ const getAvailableTools = async (req, res) => {
|
||||
|
||||
toolsOutput.push(toolToAdd);
|
||||
}
|
||||
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...finalTools]);
|
||||
|
||||
res.status(200).json(dedupedTools);
|
||||
res.status(200).json(finalTools);
|
||||
} catch (error) {
|
||||
logger.error('[getAvailableTools]', error);
|
||||
res.status(500).json({ message: error.message });
|
||||
|
||||
@@ -1,520 +0,0 @@
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
// Mock the dependencies
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCustomConfig: jest.fn(),
|
||||
getCachedTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/ToolService', () => ({
|
||||
getToolkitKey: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(() => ({
|
||||
loadManifestTools: jest.fn().mockResolvedValue([]),
|
||||
})),
|
||||
getFlowStateManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/app/clients/tools', () => ({
|
||||
availableTools: [],
|
||||
toolkits: [],
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
getToolkitKey: jest.fn(),
|
||||
checkPluginAuth: jest.fn(),
|
||||
filterUniquePlugins: jest.fn(),
|
||||
convertMCPToolsToPlugins: jest.fn(),
|
||||
}));
|
||||
|
||||
// Import the actual module with the function we want to test
|
||||
const { getAvailableTools, getAvailablePluginsController } = require('./PluginController');
|
||||
const {
|
||||
filterUniquePlugins,
|
||||
checkPluginAuth,
|
||||
convertMCPToolsToPlugins,
|
||||
getToolkitKey,
|
||||
} = require('@librechat/api');
|
||||
|
||||
describe('PluginController', () => {
|
||||
let mockReq, mockRes, mockCache;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockReq = { user: { id: 'test-user-id' } };
|
||||
mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() };
|
||||
mockCache = { get: jest.fn(), set: jest.fn() };
|
||||
getLogStores.mockReturnValue(mockCache);
|
||||
});
|
||||
|
||||
describe('getAvailablePluginsController', () => {
|
||||
beforeEach(() => {
|
||||
mockReq.app = { locals: { filteredTools: [], includedTools: [] } };
|
||||
});
|
||||
|
||||
it('should use filterUniquePlugins to remove duplicate plugins', async () => {
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue(mockPlugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(filterUniquePlugins).toHaveBeenCalled();
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
// The response includes authenticated: true for each plugin when checkPluginAuth returns true
|
||||
expect(mockRes.json).toHaveBeenCalledWith([
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First', authenticated: true },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second', authenticated: true },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify plugin authentication', async () => {
|
||||
const mockPlugin = { name: 'Plugin1', pluginKey: 'key1', description: 'First' };
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue([mockPlugin]);
|
||||
checkPluginAuth.mockReturnValueOnce(true);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(checkPluginAuth).toHaveBeenCalledWith(mockPlugin);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData[0].authenticated).toBe(true);
|
||||
});
|
||||
|
||||
it('should return cached plugins when available', async () => {
|
||||
const cachedPlugins = [
|
||||
{ name: 'CachedPlugin', pluginKey: 'cached', description: 'Cached plugin' },
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(cachedPlugins);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(filterUniquePlugins).not.toHaveBeenCalled();
|
||||
expect(checkPluginAuth).not.toHaveBeenCalled();
|
||||
expect(mockRes.json).toHaveBeenCalledWith(cachedPlugins);
|
||||
});
|
||||
|
||||
it('should filter plugins based on includedTools', async () => {
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
];
|
||||
|
||||
mockReq.app.locals.includedTools = ['key1'];
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue(mockPlugins);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData).toHaveLength(1);
|
||||
expect(responseData[0].pluginKey).toBe('key1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAvailableTools', () => {
|
||||
it('should use convertMCPToolsToPlugins for user-specific MCP tools', async () => {
|
||||
const mockUserTools = {
|
||||
[`tool1${Constants.mcp_delimiter}server1`]: {
|
||||
function: { name: 'tool1', description: 'Tool 1' },
|
||||
},
|
||||
};
|
||||
const mockConvertedPlugins = [
|
||||
{
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}server1`,
|
||||
description: 'Tool 1',
|
||||
},
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue(mockConvertedPlugins);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: mockUserTools,
|
||||
customConfig: null,
|
||||
});
|
||||
});
|
||||
|
||||
it('should use filterUniquePlugins to deduplicate combined tools', async () => {
|
||||
const mockUserPlugins = [
|
||||
{ name: 'UserTool', pluginKey: 'user-tool', description: 'User tool' },
|
||||
];
|
||||
const mockManifestPlugins = [
|
||||
{ name: 'ManifestTool', pluginKey: 'manifest-tool', description: 'Manifest tool' },
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(mockManifestPlugins);
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
convertMCPToolsToPlugins.mockReturnValue(mockUserPlugins);
|
||||
filterUniquePlugins.mockReturnValue([...mockUserPlugins, ...mockManifestPlugins]);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should be called to deduplicate the combined array
|
||||
expect(filterUniquePlugins).toHaveBeenLastCalledWith([
|
||||
...mockUserPlugins,
|
||||
...mockManifestPlugins,
|
||||
]);
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify authentication status', async () => {
|
||||
const mockPlugin = { name: 'Tool1', pluginKey: 'tool1', description: 'Tool 1' };
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockPlugin]);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock getCachedTools second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({ tool1: true });
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(checkPluginAuth).toHaveBeenCalledWith(mockPlugin);
|
||||
});
|
||||
|
||||
it('should use getToolkitKey for toolkit validation', async () => {
|
||||
const mockToolkit = {
|
||||
name: 'Toolkit1',
|
||||
pluginKey: 'toolkit1',
|
||||
description: 'Toolkit 1',
|
||||
toolkit: true,
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockToolkit]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
getToolkitKey.mockReturnValue('toolkit1');
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock getCachedTools second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({
|
||||
toolkit1_function: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(getToolkitKey).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('plugin.icon behavior', () => {
|
||||
const callGetAvailableToolsWithMCPServer = async (mcpServers) => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue({ mcpServers });
|
||||
|
||||
const functionTools = {
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: {
|
||||
function: { name: 'test-tool', description: 'A test tool' },
|
||||
},
|
||||
};
|
||||
|
||||
const mockConvertedPlugin = {
|
||||
name: 'test-tool',
|
||||
pluginKey: `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
description: 'A test tool',
|
||||
icon: mcpServers['test-server']?.iconPath,
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
};
|
||||
|
||||
getCachedTools.mockResolvedValueOnce(functionTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue([mockConvertedPlugin]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
getToolkitKey.mockReturnValue(undefined);
|
||||
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
return responseData.find((tool) => tool.name === 'test-tool');
|
||||
};
|
||||
|
||||
it('should set plugin.icon when iconPath is defined', async () => {
|
||||
const mcpServers = {
|
||||
'test-server': {
|
||||
iconPath: '/path/to/icon.png',
|
||||
},
|
||||
};
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(mcpServers);
|
||||
expect(testTool.icon).toBe('/path/to/icon.png');
|
||||
});
|
||||
|
||||
it('should set plugin.icon to undefined when iconPath is not defined', async () => {
|
||||
const mcpServers = {
|
||||
'test-server': {},
|
||||
};
|
||||
const testTool = await callGetAvailableToolsWithMCPServer(mcpServers);
|
||||
expect(testTool.icon).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('helper function integration', () => {
|
||||
it('should properly handle MCP tools with custom user variables', async () => {
|
||||
const customConfig = {
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: { title: 'API Key', description: 'Your API key' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// We need to test the actual flow where MCP manager tools are included
|
||||
const mcpManagerTools = [
|
||||
{
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
authenticated: true,
|
||||
},
|
||||
];
|
||||
|
||||
// Mock the MCP manager to return tools
|
||||
const mockMCPManager = {
|
||||
loadManifestTools: jest.fn().mockResolvedValue(mcpManagerTools),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue(customConfig);
|
||||
|
||||
// First call returns user tools (empty in this case)
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Mock convertMCPToolsToPlugins to return empty array for user tools
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
|
||||
// Mock filterUniquePlugins to pass through
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
|
||||
// Mock checkPluginAuth
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
// Second call returns tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Find the MCP tool in the response
|
||||
const mcpTool = responseData.find(
|
||||
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}test-server`,
|
||||
);
|
||||
|
||||
// The actual implementation adds authConfig and sets authenticated to false when customUserVars exist
|
||||
expect(mcpTool).toBeDefined();
|
||||
expect(mcpTool.authConfig).toEqual([
|
||||
{ authField: 'API_KEY', label: 'API Key', description: 'Your API key' },
|
||||
]);
|
||||
expect(mcpTool.authenticated).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle error cases gracefully', async () => {
|
||||
mockCache.get.mockRejectedValue(new Error('Cache error'));
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({ message: 'Cache error' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases with undefined/null values', () => {
|
||||
it('should handle undefined cache gracefully', async () => {
|
||||
getLogStores.mockReturnValue(undefined);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
});
|
||||
|
||||
it('should handle null cachedTools and cachedUserTools', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
convertMCPToolsToPlugins.mockReturnValue(undefined);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: null,
|
||||
customConfig: null,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle when getCachedTools returns undefined', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue(undefined);
|
||||
convertMCPToolsToPlugins.mockReturnValue(undefined);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
// Mock getCachedTools to return undefined for both calls
|
||||
getCachedTools.mockReset();
|
||||
getCachedTools.mockResolvedValueOnce(undefined).mockResolvedValueOnce(undefined);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: undefined,
|
||||
customConfig: null,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle cachedToolsArray and userPlugins both being defined', async () => {
|
||||
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
|
||||
const userTools = {
|
||||
'user-tool': { function: { name: 'user-tool', description: 'User tool' } },
|
||||
};
|
||||
const userPlugins = [{ name: 'UserTool', pluginKey: 'user-tool', description: 'User tool' }];
|
||||
|
||||
mockCache.get.mockResolvedValue(cachedTools);
|
||||
getCachedTools.mockResolvedValue(userTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue(userPlugins);
|
||||
filterUniquePlugins.mockReturnValue([...userPlugins, ...cachedTools]);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([...userPlugins, ...cachedTools]);
|
||||
});
|
||||
|
||||
it('should handle empty toolDefinitions object', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// With empty tool definitions, no tools should be in the final output
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle MCP tools without customUserVars', async () => {
|
||||
const customConfig = {
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
// No customUserVars defined
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockUserTools = {
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: {
|
||||
function: { name: 'tool1', description: 'Tool 1' },
|
||||
},
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue(customConfig);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
const mockPlugin = {
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
};
|
||||
|
||||
convertMCPToolsToPlugins.mockReturnValue([mockPlugin]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData[0].authenticated).toBe(true);
|
||||
// The actual implementation doesn't set authConfig on tools without customUserVars
|
||||
expect(responseData[0].authConfig).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle req.app.locals with undefined filteredTools and includedTools', async () => {
|
||||
mockReq.app = { locals: {} };
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue([]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([]);
|
||||
});
|
||||
|
||||
it('should handle toolkit with undefined toolDefinitions keys', async () => {
|
||||
const mockToolkit = {
|
||||
name: 'Toolkit1',
|
||||
pluginKey: 'toolkit1',
|
||||
description: 'Toolkit 1',
|
||||
toolkit: true,
|
||||
};
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockToolkit]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
getToolkitKey.mockReturnValue(undefined);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock getCachedTools second call to return null
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce(null);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should handle null toolDefinitions gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,11 @@
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
FileSources,
|
||||
webSearchKeys,
|
||||
extractWebSearchEnvVars,
|
||||
} = require('librechat-data-provider');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { webSearchKeys, extractWebSearchEnvVars, normalizeHttpError } = require('@librechat/api');
|
||||
const {
|
||||
getFiles,
|
||||
updateUser,
|
||||
@@ -14,7 +20,6 @@ const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/service
|
||||
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
|
||||
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
|
||||
const { Tools, Constants, FileSources } = require('librechat-data-provider');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { Transaction, Balance, User } = require('~/db/models');
|
||||
const { deleteToolCalls } = require('~/models/ToolCall');
|
||||
@@ -89,8 +94,8 @@ const updateUserPluginsController = async (req, res) => {
|
||||
|
||||
if (userPluginsService instanceof Error) {
|
||||
logger.error('[userPluginsService]', userPluginsService);
|
||||
const { status, message } = normalizeHttpError(userPluginsService);
|
||||
return res.status(status).send({ message });
|
||||
const { status, message } = userPluginsService;
|
||||
res.status(status).send({ message });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,7 +142,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
authService = await updateUserPluginAuth(user.id, keys[i], pluginKey, values[i]);
|
||||
if (authService instanceof Error) {
|
||||
logger.error('[authService]', authService);
|
||||
({ status, message } = normalizeHttpError(authService));
|
||||
({ status, message } = authService);
|
||||
}
|
||||
}
|
||||
} else if (action === 'uninstall') {
|
||||
@@ -151,7 +156,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
`[authService] Error deleting all auth for MCP tool ${pluginKey}:`,
|
||||
authService,
|
||||
);
|
||||
({ status, message } = normalizeHttpError(authService));
|
||||
({ status, message } = authService);
|
||||
}
|
||||
} else {
|
||||
// This handles:
|
||||
@@ -163,7 +168,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
authService = await deleteUserPluginAuth(user.id, keys[i]); // Deletes by authField name
|
||||
if (authService instanceof Error) {
|
||||
logger.error('[authService] Error deleting specific auth key:', authService);
|
||||
({ status, message } = normalizeHttpError(authService));
|
||||
({ status, message } = authService);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -175,16 +180,14 @@ const updateUserPluginsController = async (req, res) => {
|
||||
try {
|
||||
const mcpManager = getMCPManager(user.id);
|
||||
if (mcpManager) {
|
||||
// Extract server name from pluginKey (format: "mcp_<serverName>")
|
||||
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
|
||||
logger.info(
|
||||
`[updateUserPluginsController] Disconnecting MCP server ${serverName} for user ${user.id} after plugin auth update for ${pluginKey}.`,
|
||||
`[updateUserPluginsController] Disconnecting MCP connections for user ${user.id} after plugin auth update for ${pluginKey}.`,
|
||||
);
|
||||
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||
await mcpManager.disconnectUserConnections(user.id);
|
||||
}
|
||||
} catch (disconnectError) {
|
||||
logger.error(
|
||||
`[updateUserPluginsController] Error disconnecting MCP connection for user ${user.id} after plugin auth update:`,
|
||||
`[updateUserPluginsController] Error disconnecting MCP connections for user ${user.id} after plugin auth update:`,
|
||||
disconnectError,
|
||||
);
|
||||
// Do not fail the request for this, but log it.
|
||||
@@ -193,8 +196,7 @@ const updateUserPluginsController = async (req, res) => {
|
||||
return res.status(status).send();
|
||||
}
|
||||
|
||||
const normalized = normalizeHttpError({ status, message });
|
||||
return res.status(normalized.status).send({ message: normalized.message });
|
||||
res.status(status).send({ message });
|
||||
} catch (err) {
|
||||
logger.error('[updateUserPluginsController]', err);
|
||||
return res.status(500).json({ message: 'Something went wrong.' });
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
const { duplicateAgent } = require('../v1');
|
||||
const { getAgent, createAgent } = require('~/models/Agent');
|
||||
const { getActions } = require('~/models/Action');
|
||||
const { nanoid } = require('nanoid');
|
||||
|
||||
jest.mock('~/models/Agent');
|
||||
jest.mock('~/models/Action');
|
||||
jest.mock('nanoid');
|
||||
|
||||
describe('duplicateAgent', () => {
|
||||
let req, res;
|
||||
|
||||
beforeEach(() => {
|
||||
req = {
|
||||
params: { id: 'agent_123' },
|
||||
user: { id: 'user_456' },
|
||||
};
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should duplicate an agent successfully', async () => {
|
||||
const mockAgent = {
|
||||
id: 'agent_123',
|
||||
name: 'Test Agent',
|
||||
description: 'Test Description',
|
||||
instructions: 'Test Instructions',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['file_search'],
|
||||
actions: [],
|
||||
author: 'user_789',
|
||||
versions: [{ name: 'Test Agent', version: 1 }],
|
||||
__v: 0,
|
||||
};
|
||||
|
||||
const mockNewAgent = {
|
||||
id: 'agent_new_123',
|
||||
name: 'Test Agent (1/2/23, 12:34)',
|
||||
description: 'Test Description',
|
||||
instructions: 'Test Instructions',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['file_search'],
|
||||
actions: [],
|
||||
author: 'user_456',
|
||||
versions: [
|
||||
{
|
||||
name: 'Test Agent (1/2/23, 12:34)',
|
||||
description: 'Test Description',
|
||||
instructions: 'Test Instructions',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['file_search'],
|
||||
actions: [],
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
getAgent.mockResolvedValue(mockAgent);
|
||||
getActions.mockResolvedValue([]);
|
||||
nanoid.mockReturnValue('new_123');
|
||||
createAgent.mockResolvedValue(mockNewAgent);
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(getAgent).toHaveBeenCalledWith({ id: 'agent_123' });
|
||||
expect(getActions).toHaveBeenCalledWith({ agent_id: 'agent_123' }, true);
|
||||
expect(createAgent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
id: 'agent_new_123',
|
||||
author: 'user_456',
|
||||
name: expect.stringContaining('Test Agent ('),
|
||||
description: 'Test Description',
|
||||
instructions: 'Test Instructions',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['file_search'],
|
||||
actions: [],
|
||||
}),
|
||||
);
|
||||
|
||||
expect(createAgent).toHaveBeenCalledWith(
|
||||
expect.not.objectContaining({
|
||||
versions: expect.anything(),
|
||||
__v: expect.anything(),
|
||||
}),
|
||||
);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(201);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
agent: mockNewAgent,
|
||||
actions: [],
|
||||
});
|
||||
});
|
||||
|
||||
it('should ensure duplicated agent has clean versions array without nested fields', async () => {
|
||||
const mockAgent = {
|
||||
id: 'agent_123',
|
||||
name: 'Test Agent',
|
||||
description: 'Test Description',
|
||||
versions: [
|
||||
{
|
||||
name: 'Test Agent',
|
||||
versions: [{ name: 'Nested' }],
|
||||
__v: 1,
|
||||
},
|
||||
],
|
||||
__v: 2,
|
||||
};
|
||||
|
||||
const mockNewAgent = {
|
||||
id: 'agent_new_123',
|
||||
name: 'Test Agent (1/2/23, 12:34)',
|
||||
description: 'Test Description',
|
||||
versions: [
|
||||
{
|
||||
name: 'Test Agent (1/2/23, 12:34)',
|
||||
description: 'Test Description',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
getAgent.mockResolvedValue(mockAgent);
|
||||
getActions.mockResolvedValue([]);
|
||||
nanoid.mockReturnValue('new_123');
|
||||
createAgent.mockResolvedValue(mockNewAgent);
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(mockNewAgent.versions).toHaveLength(1);
|
||||
|
||||
const firstVersion = mockNewAgent.versions[0];
|
||||
expect(firstVersion).not.toHaveProperty('versions');
|
||||
expect(firstVersion).not.toHaveProperty('__v');
|
||||
|
||||
expect(mockNewAgent).not.toHaveProperty('__v');
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(201);
|
||||
});
|
||||
|
||||
it('should return 404 if agent not found', async () => {
|
||||
getAgent.mockResolvedValue(null);
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(404);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Agent not found',
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle tool_resources.ocr correctly', async () => {
|
||||
const mockAgent = {
|
||||
id: 'agent_123',
|
||||
name: 'Test Agent',
|
||||
tool_resources: {
|
||||
ocr: { enabled: true, config: 'test' },
|
||||
other: { should: 'not be copied' },
|
||||
},
|
||||
};
|
||||
|
||||
getAgent.mockResolvedValue(mockAgent);
|
||||
getActions.mockResolvedValue([]);
|
||||
nanoid.mockReturnValue('new_123');
|
||||
createAgent.mockResolvedValue({ id: 'agent_new_123' });
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(createAgent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tool_resources: {
|
||||
ocr: { enabled: true, config: 'test' },
|
||||
},
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
getAgent.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
await duplicateAgent(req, res);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(500);
|
||||
expect(res.json).toHaveBeenCalledWith({ error: 'Database error' });
|
||||
});
|
||||
});
|
||||
@@ -1,23 +1,18 @@
|
||||
require('events').EventEmitter.defaultMaxListeners = 100;
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { DynamicStructuredTool } = require('@langchain/core/tools');
|
||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const {
|
||||
sendEvent,
|
||||
createRun,
|
||||
Tokenizer,
|
||||
checkAccess,
|
||||
memoryInstructions,
|
||||
formatContentStrings,
|
||||
createMemoryProcessor,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Callback,
|
||||
Providers,
|
||||
GraphEvents,
|
||||
TitleMethod,
|
||||
formatMessage,
|
||||
formatAgentMessages,
|
||||
formatContentStrings,
|
||||
getTokenCountForMessage,
|
||||
createMetadataAggregator,
|
||||
} = require('@librechat/agents');
|
||||
@@ -27,41 +22,31 @@ const {
|
||||
VisionModes,
|
||||
ContentTypes,
|
||||
EModelEndpoint,
|
||||
KnownEndpoints,
|
||||
PermissionTypes,
|
||||
isAgentsEndpoint,
|
||||
AgentCapabilities,
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { DynamicStructuredTool } = require('@langchain/core/tools');
|
||||
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
|
||||
const {
|
||||
findPluginAuthsByKeys,
|
||||
getFormattedMemories,
|
||||
deleteMemory,
|
||||
setMemory,
|
||||
} = require('~/models');
|
||||
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
|
||||
getCustomEndpointConfig,
|
||||
createGetMCPAuthMap,
|
||||
checkCapability,
|
||||
} = require('~/server/services/Config');
|
||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
|
||||
const { checkAccess } = require('~/server/middleware/roles/access');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { loadAgent } = require('~/models/Agent');
|
||||
const { getMCPManager } = require('~/config');
|
||||
|
||||
const omitTitleOptions = new Set([
|
||||
'stream',
|
||||
'thinking',
|
||||
'streaming',
|
||||
'clientOptions',
|
||||
'thinkingConfig',
|
||||
'thinkingBudget',
|
||||
'includeThoughts',
|
||||
'maxOutputTokens',
|
||||
'additionalModelRequestFields',
|
||||
]);
|
||||
|
||||
/**
|
||||
* @param {ServerRequest} req
|
||||
* @param {Agent} agent
|
||||
@@ -71,15 +56,13 @@ const payloadParser = ({ req, agent, endpoint }) => {
|
||||
if (isAgentsEndpoint(endpoint)) {
|
||||
return { model: undefined };
|
||||
} else if (endpoint === EModelEndpoint.bedrock) {
|
||||
const parsedValues = bedrockInputSchema.parse(agent.model_parameters);
|
||||
if (parsedValues.thinking == null) {
|
||||
parsedValues.thinking = false;
|
||||
}
|
||||
return parsedValues;
|
||||
return bedrockInputSchema.parse(agent.model_parameters);
|
||||
}
|
||||
return req.body.endpointOption.model_parameters;
|
||||
};
|
||||
|
||||
const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
|
||||
|
||||
const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
|
||||
|
||||
function createTokenCounter(encoding) {
|
||||
@@ -402,34 +385,6 @@ class AgentClient extends BaseClient {
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a promise that resolves with the memory promise result or undefined after a timeout
|
||||
* @param {Promise<(TAttachment | null)[] | undefined>} memoryPromise - The memory promise to await
|
||||
* @param {number} timeoutMs - Timeout in milliseconds (default: 3000)
|
||||
* @returns {Promise<(TAttachment | null)[] | undefined>}
|
||||
*/
|
||||
async awaitMemoryWithTimeout(memoryPromise, timeoutMs = 3000) {
|
||||
if (!memoryPromise) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const timeoutPromise = new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Memory processing timeout')), timeoutMs),
|
||||
);
|
||||
|
||||
const attachments = await Promise.race([memoryPromise, timeoutPromise]);
|
||||
return attachments;
|
||||
} catch (error) {
|
||||
if (error.message === 'Memory processing timeout') {
|
||||
logger.warn('[AgentClient] Memory processing timed out after 3 seconds');
|
||||
} else {
|
||||
logger.error('[AgentClient] Error processing memory:', error);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {Promise<string | undefined>}
|
||||
*/
|
||||
@@ -438,12 +393,7 @@ class AgentClient extends BaseClient {
|
||||
if (user.personalization?.memories === false) {
|
||||
return;
|
||||
}
|
||||
const hasAccess = await checkAccess({
|
||||
user,
|
||||
permissionType: PermissionTypes.MEMORIES,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]);
|
||||
|
||||
if (!hasAccess) {
|
||||
logger.debug(
|
||||
@@ -488,12 +438,6 @@ class AgentClient extends BaseClient {
|
||||
res: this.options.res,
|
||||
agent: prelimAgent,
|
||||
allowedProviders,
|
||||
endpointOption: {
|
||||
endpoint:
|
||||
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
|
||||
? EModelEndpoint.agents
|
||||
: memoryConfig.agent?.provider,
|
||||
},
|
||||
});
|
||||
|
||||
if (!agent) {
|
||||
@@ -540,39 +484,6 @@ class AgentClient extends BaseClient {
|
||||
return withoutKeys;
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters out image URLs from message content
|
||||
* @param {BaseMessage} message - The message to filter
|
||||
* @returns {BaseMessage} - A new message with image URLs removed
|
||||
*/
|
||||
filterImageUrls(message) {
|
||||
if (!message.content || typeof message.content === 'string') {
|
||||
return message;
|
||||
}
|
||||
|
||||
if (Array.isArray(message.content)) {
|
||||
const filteredContent = message.content.filter(
|
||||
(part) => part.type !== ContentTypes.IMAGE_URL,
|
||||
);
|
||||
|
||||
if (filteredContent.length === 1 && filteredContent[0].type === ContentTypes.TEXT) {
|
||||
const MessageClass = message.constructor;
|
||||
return new MessageClass({
|
||||
content: filteredContent[0].text,
|
||||
additional_kwargs: message.additional_kwargs,
|
||||
});
|
||||
}
|
||||
|
||||
const MessageClass = message.constructor;
|
||||
return new MessageClass({
|
||||
content: filteredContent,
|
||||
additional_kwargs: message.additional_kwargs,
|
||||
});
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {BaseMessage[]} messages
|
||||
* @returns {Promise<void | (TAttachment | null)[]>}
|
||||
@@ -600,11 +511,7 @@ class AgentClient extends BaseClient {
|
||||
messagesToProcess = [...messages.slice(-messageWindowSize)];
|
||||
}
|
||||
}
|
||||
|
||||
const filteredMessages = messagesToProcess.map((msg) => this.filterImageUrls(msg));
|
||||
const bufferString = getBufferString(filteredMessages);
|
||||
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
|
||||
return await this.processMemory([bufferMessage]);
|
||||
return await this.processMemory(messagesToProcess);
|
||||
} catch (error) {
|
||||
logger.error('Memory Agent failed to process memory', error);
|
||||
}
|
||||
@@ -770,18 +677,23 @@ class AgentClient extends BaseClient {
|
||||
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
||||
user: this.options.req.user,
|
||||
},
|
||||
recursionLimit: agentsEConfig?.recursionLimit ?? 25,
|
||||
recursionLimit: agentsEConfig?.recursionLimit,
|
||||
signal: abortController.signal,
|
||||
streamMode: 'values',
|
||||
version: 'v2',
|
||||
};
|
||||
|
||||
const getUserMCPAuthMap = await createGetMCPAuthMap();
|
||||
|
||||
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
|
||||
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
|
||||
payload,
|
||||
this.indexTokenCountMap,
|
||||
toolSet,
|
||||
);
|
||||
if (legacyContentEndpoints.has(this.options.agent.endpoint?.toLowerCase())) {
|
||||
initialMessages = formatContentStrings(initialMessages);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
@@ -797,9 +709,6 @@ class AgentClient extends BaseClient {
|
||||
if (i > 0) {
|
||||
this.model = agent.model_parameters.model;
|
||||
}
|
||||
if (i > 0 && config.signal == null) {
|
||||
config.signal = abortController.signal;
|
||||
}
|
||||
if (agent.recursion_limit && typeof agent.recursion_limit === 'number') {
|
||||
config.recursionLimit = agent.recursion_limit;
|
||||
}
|
||||
@@ -848,9 +757,6 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
|
||||
let messages = _messages;
|
||||
if (agent.useLegacyContent === true) {
|
||||
messages = formatContentStrings(messages);
|
||||
}
|
||||
if (
|
||||
agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
|
||||
'prompt-caching',
|
||||
@@ -899,11 +805,10 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
|
||||
try {
|
||||
if (await hasCustomUserVars()) {
|
||||
config.configurable.userMCPAuthMap = await getMCPAuthMap({
|
||||
if (getUserMCPAuthMap) {
|
||||
config.configurable.userMCPAuthMap = await getUserMCPAuthMap({
|
||||
tools: agent.tools,
|
||||
userId: this.options.req.user.id,
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
@@ -1030,9 +935,11 @@ class AgentClient extends BaseClient {
|
||||
});
|
||||
|
||||
try {
|
||||
const attachments = await this.awaitMemoryWithTimeout(memoryPromise);
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
if (memoryPromise) {
|
||||
const attachments = await memoryPromise;
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
}
|
||||
await this.recordCollectedUsage({ context: 'message' });
|
||||
} catch (err) {
|
||||
@@ -1042,9 +949,11 @@ class AgentClient extends BaseClient {
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
const attachments = await this.awaitMemoryWithTimeout(memoryPromise);
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
if (memoryPromise) {
|
||||
const attachments = await memoryPromise;
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
}
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
|
||||
@@ -1074,41 +983,23 @@ class AgentClient extends BaseClient {
|
||||
throw new Error('Run not initialized');
|
||||
}
|
||||
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
|
||||
const { req, res, agent } = this.options;
|
||||
let endpoint = agent.endpoint;
|
||||
|
||||
const endpoint = this.options.agent.endpoint;
|
||||
const { req, res } = this.options;
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
let clientOptions = {
|
||||
maxTokens: 75,
|
||||
model: agent.model || agent.model_parameters.model,
|
||||
};
|
||||
|
||||
let titleProviderConfig = await getProviderConfig(endpoint);
|
||||
|
||||
/** @type {TEndpoint | undefined} */
|
||||
const endpointConfig =
|
||||
req.app.locals.all ?? req.app.locals[endpoint] ?? titleProviderConfig.customEndpointConfig;
|
||||
let endpointConfig = req.app.locals[endpoint];
|
||||
if (!endpointConfig) {
|
||||
logger.warn(
|
||||
'[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config',
|
||||
);
|
||||
}
|
||||
|
||||
if (endpointConfig?.titleEndpoint && endpointConfig.titleEndpoint !== endpoint) {
|
||||
try {
|
||||
titleProviderConfig = await getProviderConfig(endpointConfig.titleEndpoint);
|
||||
endpoint = endpointConfig.titleEndpoint;
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`[api/server/controllers/agents/client.js #titleConvo] Error getting title endpoint config for ${endpointConfig.titleEndpoint}, falling back to default`,
|
||||
error,
|
||||
endpointConfig = await getCustomEndpointConfig(endpoint);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
|
||||
err,
|
||||
);
|
||||
// Fall back to original provider config
|
||||
endpoint = agent.endpoint;
|
||||
titleProviderConfig = await getProviderConfig(endpoint);
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
endpointConfig &&
|
||||
endpointConfig.titleModel &&
|
||||
@@ -1116,71 +1007,33 @@ class AgentClient extends BaseClient {
|
||||
) {
|
||||
clientOptions.model = endpointConfig.titleModel;
|
||||
}
|
||||
|
||||
const options = await titleProviderConfig.getOptions({
|
||||
req,
|
||||
res,
|
||||
optionsOnly: true,
|
||||
overrideEndpoint: endpoint,
|
||||
overrideModel: clientOptions.model,
|
||||
endpointOption: { model_parameters: clientOptions },
|
||||
});
|
||||
|
||||
let provider = options.provider ?? titleProviderConfig.overrideProvider ?? agent.provider;
|
||||
if (
|
||||
endpoint === EModelEndpoint.azureOpenAI &&
|
||||
options.llmConfig?.azureOpenAIApiInstanceName == null
|
||||
clientOptions.model &&
|
||||
this.options.agent.model_parameters.model !== clientOptions.model
|
||||
) {
|
||||
provider = Providers.OPENAI;
|
||||
} else if (
|
||||
endpoint === EModelEndpoint.azureOpenAI &&
|
||||
options.llmConfig?.azureOpenAIApiInstanceName != null &&
|
||||
provider !== Providers.AZURE
|
||||
) {
|
||||
provider = Providers.AZURE;
|
||||
clientOptions =
|
||||
(
|
||||
await initOpenAI({
|
||||
req,
|
||||
res,
|
||||
optionsOnly: true,
|
||||
overrideModel: clientOptions.model,
|
||||
overrideEndpoint: endpoint,
|
||||
endpointOption: {
|
||||
model_parameters: clientOptions,
|
||||
},
|
||||
})
|
||||
)?.llmConfig ?? clientOptions;
|
||||
}
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
clientOptions = { ...options.llmConfig };
|
||||
if (options.configOptions) {
|
||||
clientOptions.configuration = options.configOptions;
|
||||
}
|
||||
|
||||
const shouldRemoveMaxTokens = /\b(o\d|gpt-[5-9])\b/i.test(clientOptions.model);
|
||||
if (shouldRemoveMaxTokens && clientOptions.maxTokens != null) {
|
||||
if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
|
||||
delete clientOptions.maxTokens;
|
||||
} else if (!shouldRemoveMaxTokens && !clientOptions.maxTokens) {
|
||||
clientOptions.maxTokens = 75;
|
||||
}
|
||||
if (shouldRemoveMaxTokens && clientOptions?.modelKwargs?.max_completion_tokens != null) {
|
||||
delete clientOptions.modelKwargs.max_completion_tokens;
|
||||
} else if (shouldRemoveMaxTokens && clientOptions?.modelKwargs?.max_output_tokens != null) {
|
||||
delete clientOptions.modelKwargs.max_output_tokens;
|
||||
}
|
||||
|
||||
clientOptions = Object.assign(
|
||||
Object.fromEntries(
|
||||
Object.entries(clientOptions).filter(([key]) => !omitTitleOptions.has(key)),
|
||||
),
|
||||
);
|
||||
|
||||
if (
|
||||
provider === Providers.GOOGLE &&
|
||||
(endpointConfig?.titleMethod === TitleMethod.FUNCTIONS ||
|
||||
endpointConfig?.titleMethod === TitleMethod.STRUCTURED)
|
||||
) {
|
||||
clientOptions.json = true;
|
||||
}
|
||||
|
||||
try {
|
||||
const titleResult = await this.run.generateTitle({
|
||||
provider,
|
||||
clientOptions,
|
||||
inputText: text,
|
||||
contentParts: this.contentParts,
|
||||
titleMethod: endpointConfig?.titleMethod,
|
||||
titlePrompt: endpointConfig?.titlePrompt,
|
||||
titlePromptTemplate: endpointConfig?.titlePromptTemplate,
|
||||
clientOptions,
|
||||
chainOptions: {
|
||||
signal: abortController.signal,
|
||||
callbacks: [
|
||||
@@ -1195,10 +1048,8 @@ class AgentClient extends BaseClient {
|
||||
let input_tokens, output_tokens;
|
||||
|
||||
if (item.usage) {
|
||||
input_tokens =
|
||||
item.usage.prompt_tokens || item.usage.input_tokens || item.usage.inputTokens;
|
||||
output_tokens =
|
||||
item.usage.completion_tokens || item.usage.output_tokens || item.usage.outputTokens;
|
||||
input_tokens = item.usage.input_tokens || item.usage.inputTokens;
|
||||
output_tokens = item.usage.output_tokens || item.usage.outputTokens;
|
||||
} else if (item.tokenUsage) {
|
||||
input_tokens = item.tokenUsage.promptTokens;
|
||||
output_tokens = item.tokenUsage.completionTokens;
|
||||
@@ -1228,52 +1079,8 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {number} params.promptTokens
|
||||
* @param {number} params.completionTokens
|
||||
* @param {OpenAIUsageMetadata} [params.usage]
|
||||
* @param {string} [params.model]
|
||||
* @param {string} [params.context='message']
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async recordTokenUsage({ model, promptTokens, completionTokens, usage, context = 'message' }) {
|
||||
try {
|
||||
await spendTokens(
|
||||
{
|
||||
model,
|
||||
context,
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
},
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
if (
|
||||
usage &&
|
||||
typeof usage === 'object' &&
|
||||
'reasoning_tokens' in usage &&
|
||||
typeof usage.reasoning_tokens === 'number'
|
||||
) {
|
||||
await spendTokens(
|
||||
{
|
||||
model,
|
||||
context: 'reasoning',
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
},
|
||||
{ completionTokens: usage.reasoning_tokens },
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #recordTokenUsage] Error recording token usage',
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
/** Silent method, as `recordCollectedUsage` is used instead */
|
||||
async recordTokenUsage() {}
|
||||
|
||||
getEncoding() {
|
||||
return 'o200k_base';
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,10 @@
|
||||
// errorHandler.js
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logger } = require('~/config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const { sendResponse } = require('~/server/middleware/error');
|
||||
const { recordUsage } = require('~/server/services/Threads');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { sendResponse } = require('~/server/utils');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerContext
|
||||
@@ -75,7 +75,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
} else if (/Files.*are invalid/.test(error.message)) {
|
||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||
endpoint === 'azureAssistants'
|
||||
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
|
||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
@@ -105,6 +105,8 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
return res.end();
|
||||
}
|
||||
await cache.delete(cacheKey);
|
||||
// const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
// logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error cancelling run`, error);
|
||||
}
|
||||
@@ -113,6 +115,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
|
||||
let run;
|
||||
try {
|
||||
// run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
@@ -125,9 +128,18 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
|
||||
let finalEvent;
|
||||
try {
|
||||
// const errorContentPart = {
|
||||
// text: {
|
||||
// value:
|
||||
// error?.message ?? 'There was an error processing your request. Please try again later.',
|
||||
// },
|
||||
// type: ContentTypes.ERROR,
|
||||
// };
|
||||
|
||||
finalEvent = {
|
||||
final: true,
|
||||
conversation: await getConvo(req.user.id, conversationId),
|
||||
// runMessages,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error finalizing error process`, error);
|
||||
|
||||
106
api/server/controllers/agents/llm.js
Normal file
106
api/server/controllers/agents/llm.js
Normal file
@@ -0,0 +1,106 @@
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { resolveHeaders } = require('librechat-data-provider');
|
||||
const { createLLM } = require('~/app/clients/llm');
|
||||
|
||||
/**
|
||||
* Initializes and returns a Language Learning Model (LLM) instance.
|
||||
*
|
||||
* @param {Object} options - Configuration options for the LLM.
|
||||
* @param {string} options.model - The model identifier.
|
||||
* @param {string} options.modelName - The specific name of the model.
|
||||
* @param {number} options.temperature - The temperature setting for the model.
|
||||
* @param {number} options.presence_penalty - The presence penalty for the model.
|
||||
* @param {number} options.frequency_penalty - The frequency penalty for the model.
|
||||
* @param {number} options.max_tokens - The maximum number of tokens for the model output.
|
||||
* @param {boolean} options.streaming - Whether to use streaming for the model output.
|
||||
* @param {Object} options.context - The context for the conversation.
|
||||
* @param {number} options.tokenBuffer - The token buffer size.
|
||||
* @param {number} options.initialMessageCount - The initial message count.
|
||||
* @param {string} options.conversationId - The ID of the conversation.
|
||||
* @param {string} options.user - The user identifier.
|
||||
* @param {string} options.langchainProxy - The langchain proxy URL.
|
||||
* @param {boolean} options.useOpenRouter - Whether to use OpenRouter.
|
||||
* @param {Object} options.options - Additional options.
|
||||
* @param {Object} options.options.headers - Custom headers for the request.
|
||||
* @param {string} options.options.proxy - Proxy URL.
|
||||
* @param {Object} options.options.req - The request object.
|
||||
* @param {Object} options.options.res - The response object.
|
||||
* @param {boolean} options.options.debug - Whether to enable debug mode.
|
||||
* @param {string} options.apiKey - The API key for authentication.
|
||||
* @param {Object} options.azure - Azure-specific configuration.
|
||||
* @param {Object} options.abortController - The AbortController instance.
|
||||
* @returns {Object} The initialized LLM instance.
|
||||
*/
|
||||
function initializeLLM(options) {
|
||||
const {
|
||||
model,
|
||||
modelName,
|
||||
temperature,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
max_tokens,
|
||||
streaming,
|
||||
user,
|
||||
langchainProxy,
|
||||
useOpenRouter,
|
||||
options: { headers, proxy },
|
||||
apiKey,
|
||||
azure,
|
||||
} = options;
|
||||
|
||||
const modelOptions = {
|
||||
modelName: modelName || model,
|
||||
temperature,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
user,
|
||||
};
|
||||
|
||||
if (max_tokens) {
|
||||
modelOptions.max_tokens = max_tokens;
|
||||
}
|
||||
|
||||
const configOptions = {};
|
||||
|
||||
if (langchainProxy) {
|
||||
configOptions.basePath = langchainProxy;
|
||||
}
|
||||
|
||||
if (useOpenRouter) {
|
||||
configOptions.basePath = 'https://openrouter.ai/api/v1';
|
||||
configOptions.baseOptions = {
|
||||
headers: {
|
||||
'HTTP-Referer': 'https://librechat.ai',
|
||||
'X-Title': 'LibreChat',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
|
||||
configOptions.baseOptions = {
|
||||
headers: resolveHeaders({
|
||||
...headers,
|
||||
...configOptions?.baseOptions?.headers,
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
if (proxy) {
|
||||
configOptions.httpAgent = new HttpsProxyAgent(proxy);
|
||||
configOptions.httpsAgent = new HttpsProxyAgent(proxy);
|
||||
}
|
||||
|
||||
const llm = createLLM({
|
||||
modelOptions,
|
||||
configOptions,
|
||||
openAIApiKey: apiKey,
|
||||
azure,
|
||||
streaming,
|
||||
});
|
||||
|
||||
return llm;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
initializeLLM,
|
||||
};
|
||||
@@ -1,5 +1,3 @@
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
handleAbortError,
|
||||
@@ -7,19 +5,17 @@ const {
|
||||
cleanupAbortController,
|
||||
} = require('~/server/middleware');
|
||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||
const { sendMessage } = require('~/server/utils');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
text,
|
||||
isRegenerate,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
isContinued = false,
|
||||
editedContent = null,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
responseMessageId: editedResponseMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
let sender;
|
||||
@@ -71,7 +67,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
handler();
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error('[AgentController] Error in cleanup handler', e);
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -159,7 +155,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
try {
|
||||
res.removeListener('close', closeHandler);
|
||||
} catch (e) {
|
||||
logger.error('[AgentController] Error removing close listener', e);
|
||||
// Ignore
|
||||
}
|
||||
});
|
||||
|
||||
@@ -167,15 +163,10 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
user: userId,
|
||||
onStart,
|
||||
getReqData,
|
||||
isContinued,
|
||||
isRegenerate,
|
||||
editedContent,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
responseMessageId: editedResponseMessageId,
|
||||
progressOptions: {
|
||||
res,
|
||||
},
|
||||
@@ -215,7 +206,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
// Create a new response object with minimal copies
|
||||
const finalResponse = { ...response };
|
||||
|
||||
sendEvent(res, {
|
||||
sendMessage(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
@@ -233,26 +224,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
);
|
||||
}
|
||||
}
|
||||
// Edge case: sendMessage completed but abort happened during sendCompletion
|
||||
// We need to ensure a final event is sent
|
||||
else if (!res.headersSent && !res.finished) {
|
||||
logger.debug(
|
||||
'[AgentController] Handling edge case: `sendMessage` completed but aborted during `sendCompletion`',
|
||||
);
|
||||
|
||||
const finalResponse = { ...response };
|
||||
finalResponse.error = true;
|
||||
|
||||
sendEvent(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: finalResponse,
|
||||
error: { message: 'Request was aborted during completion' },
|
||||
});
|
||||
res.end();
|
||||
}
|
||||
|
||||
// Save user message if needed
|
||||
if (!client.skipSaveUserMessage) {
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
const { z } = require('zod');
|
||||
const fs = require('fs').promises;
|
||||
const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
@@ -10,7 +8,6 @@ const {
|
||||
SystemRoles,
|
||||
EToolResources,
|
||||
actionDelimiter,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getAgent,
|
||||
@@ -33,7 +30,6 @@ const { deleteFileByFilter } = require('~/models/File');
|
||||
const systemTools = {
|
||||
[Tools.execute_code]: true,
|
||||
[Tools.file_search]: true,
|
||||
[Tools.web_search]: true,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -46,13 +42,9 @@ const systemTools = {
|
||||
*/
|
||||
const createAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const validatedData = agentCreateSchema.parse(req.body);
|
||||
const { tools = [], ...agentData } = removeNullishValues(validatedData);
|
||||
|
||||
const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body;
|
||||
const { id: userId } = req.user;
|
||||
|
||||
agentData.id = `agent_${nanoid()}`;
|
||||
agentData.author = userId;
|
||||
agentData.tools = [];
|
||||
|
||||
const availableTools = await getCachedTools({ includeGlobal: true });
|
||||
@@ -66,13 +58,19 @@ const createAgentHandler = async (req, res) => {
|
||||
}
|
||||
}
|
||||
|
||||
Object.assign(agentData, {
|
||||
author: userId,
|
||||
name,
|
||||
description,
|
||||
instructions,
|
||||
provider,
|
||||
model,
|
||||
});
|
||||
|
||||
agentData.id = `agent_${nanoid()}`;
|
||||
const agent = await createAgent(agentData);
|
||||
res.status(201).json(agent);
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.error('[/Agents] Validation error', error.errors);
|
||||
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
|
||||
}
|
||||
logger.error('[/Agents] Error creating agent', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
@@ -156,16 +154,14 @@ const getAgentHandler = async (req, res) => {
|
||||
const updateAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const id = req.params.id;
|
||||
const validatedData = agentUpdateSchema.parse(req.body);
|
||||
const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
|
||||
const { projectIds, removeProjectIds, ...updateData } = req.body;
|
||||
const isAdmin = req.user.role === SystemRoles.ADMIN;
|
||||
const existingAgent = await getAgent({ id });
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||
|
||||
if (!hasEditPermission) {
|
||||
@@ -194,9 +190,6 @@ const updateAgentHandler = async (req, res) => {
|
||||
});
|
||||
}
|
||||
|
||||
// Add version count to the response
|
||||
updatedAgent.version = updatedAgent.versions ? updatedAgent.versions.length : 0;
|
||||
|
||||
if (updatedAgent.author) {
|
||||
updatedAgent.author = updatedAgent.author.toString();
|
||||
}
|
||||
@@ -207,11 +200,6 @@ const updateAgentHandler = async (req, res) => {
|
||||
|
||||
return res.json(updatedAgent);
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.error('[/Agents/:id] Validation error', error.errors);
|
||||
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
|
||||
}
|
||||
|
||||
logger.error('[/Agents/:id] Error updating Agent', error);
|
||||
|
||||
if (error.statusCode === 409) {
|
||||
@@ -254,8 +242,6 @@ const duplicateAgentHandler = async (req, res) => {
|
||||
createdAt: _createdAt,
|
||||
updatedAt: _updatedAt,
|
||||
tool_resources: _tool_resources = {},
|
||||
versions: _versions,
|
||||
__v: _v,
|
||||
...cloneData
|
||||
} = agent;
|
||||
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
|
||||
@@ -394,22 +380,6 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
return res.status(400).json({ message: 'Agent ID is required' });
|
||||
}
|
||||
|
||||
const isAdmin = req.user.role === SystemRoles.ADMIN;
|
||||
const existingAgent = await getAgent({ id: agent_id });
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||
|
||||
if (!hasEditPermission) {
|
||||
return res.status(403).json({
|
||||
error: 'You do not have permission to modify this non-collaborative agent',
|
||||
});
|
||||
}
|
||||
|
||||
const buffer = await fs.readFile(req.file.path);
|
||||
|
||||
const fileStrategy = req.app.locals.fileStrategy;
|
||||
@@ -432,7 +402,14 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
source: fileStrategy,
|
||||
};
|
||||
|
||||
let _avatar = existingAgent.avatar;
|
||||
let _avatar;
|
||||
try {
|
||||
const agent = await getAgent({ id: agent_id });
|
||||
_avatar = agent.avatar;
|
||||
} catch (error) {
|
||||
logger.error('[/:agent_id/avatar] Error fetching agent', error);
|
||||
_avatar = {};
|
||||
}
|
||||
|
||||
if (_avatar && _avatar.source) {
|
||||
const { deleteFile } = getStrategyFunctions(_avatar.source);
|
||||
@@ -454,7 +431,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
};
|
||||
|
||||
promises.push(
|
||||
await updateAgent({ id: agent_id }, data, {
|
||||
await updateAgent({ id: agent_id, author: req.user.id }, data, {
|
||||
updatingUserId: req.user.id,
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -1,681 +0,0 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
|
||||
// Only mock the dependencies that are not database-related
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCachedTools: jest.fn().mockResolvedValue({
|
||||
web_search: true,
|
||||
execute_code: true,
|
||||
file_search: true,
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Project', () => ({
|
||||
getProjectByName: jest.fn().mockResolvedValue(null),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/images/avatar', () => ({
|
||||
resizeAvatar: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||
refreshS3Url: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
filterFile: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Action', () => ({
|
||||
updateAction: jest.fn(),
|
||||
getActions: jest.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/File', () => ({
|
||||
deleteFileByFilter: jest.fn(),
|
||||
}));
|
||||
|
||||
const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
|
||||
*/
|
||||
let Agent;
|
||||
|
||||
describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
let mongoServer;
|
||||
let mockReq;
|
||||
let mockRes;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Agent.deleteMany({});
|
||||
|
||||
// Reset all mocks
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Setup mock request and response objects
|
||||
mockReq = {
|
||||
user: {
|
||||
id: new mongoose.Types.ObjectId().toString(),
|
||||
role: 'USER',
|
||||
},
|
||||
body: {},
|
||||
params: {},
|
||||
app: {
|
||||
locals: {
|
||||
fileStrategy: 'local',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockRes = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('createAgentHandler', () => {
|
||||
test('should create agent with allowed fields only', async () => {
|
||||
const validData = {
|
||||
name: 'Test Agent',
|
||||
description: 'A test agent',
|
||||
instructions: 'Be helpful',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
tools: ['web_search'],
|
||||
model_parameters: { temperature: 0.7 },
|
||||
tool_resources: {
|
||||
file_search: { file_ids: ['file1', 'file2'] },
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = validData;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.name).toBe('Test Agent');
|
||||
expect(createdAgent.description).toBe('A test agent');
|
||||
expect(createdAgent.provider).toBe('openai');
|
||||
expect(createdAgent.model).toBe('gpt-4');
|
||||
expect(createdAgent.author.toString()).toBe(mockReq.user.id);
|
||||
expect(createdAgent.tools).toContain('web_search');
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb).toBeDefined();
|
||||
expect(agentInDb.name).toBe('Test Agent');
|
||||
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
|
||||
});
|
||||
|
||||
test('should reject creation with unauthorized fields (mass assignment protection)', async () => {
|
||||
const maliciousData = {
|
||||
// Required fields
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Malicious Agent',
|
||||
|
||||
// Unauthorized fields that should be stripped
|
||||
author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
|
||||
authorName: 'Hacker', // Should be stripped
|
||||
isCollaborative: true, // Should be stripped on creation
|
||||
versions: [], // Should be stripped
|
||||
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||
id: 'custom_agent_id', // Should be overridden
|
||||
createdAt: new Date('2020-01-01'), // Should be stripped
|
||||
updatedAt: new Date('2020-01-01'), // Should be stripped
|
||||
};
|
||||
|
||||
mockReq.body = maliciousData;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unauthorized fields were not set
|
||||
expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
|
||||
expect(createdAgent.authorName).toBeUndefined();
|
||||
expect(createdAgent.isCollaborative).toBeFalsy();
|
||||
expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
|
||||
expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
|
||||
expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix
|
||||
|
||||
// Verify timestamps are recent (not the malicious dates)
|
||||
const createdTime = new Date(createdAgent.createdAt).getTime();
|
||||
const now = Date.now();
|
||||
expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
|
||||
expect(agentInDb.authorName).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should validate required fields', async () => {
|
||||
const invalidData = {
|
||||
name: 'Missing Required Fields',
|
||||
// Missing provider and model
|
||||
};
|
||||
|
||||
mockReq.body = invalidData;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
details: expect.any(Array),
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify nothing was created in database
|
||||
const count = await Agent.countDocuments();
|
||||
expect(count).toBe(0);
|
||||
});
|
||||
|
||||
test('should handle tool_resources validation', async () => {
|
||||
const dataWithInvalidToolResources = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Tool Resources',
|
||||
tool_resources: {
|
||||
// Valid resources
|
||||
file_search: {
|
||||
file_ids: ['file1', 'file2'],
|
||||
vector_store_ids: ['vs1'],
|
||||
},
|
||||
execute_code: {
|
||||
file_ids: ['file3'],
|
||||
},
|
||||
// Invalid resource (should be stripped by schema)
|
||||
invalid_resource: {
|
||||
file_ids: ['file4'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithInvalidToolResources;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.tool_resources).toBeDefined();
|
||||
expect(createdAgent.tool_resources.file_search).toBeDefined();
|
||||
expect(createdAgent.tool_resources.execute_code).toBeDefined();
|
||||
expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should handle avatar validation', async () => {
|
||||
const dataWithAvatar = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Avatar',
|
||||
avatar: {
|
||||
filepath: 'https://example.com/avatar.png',
|
||||
source: 's3',
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithAvatar;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.avatar).toEqual({
|
||||
filepath: 'https://example.com/avatar.png',
|
||||
source: 's3',
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle invalid avatar format', async () => {
|
||||
const dataWithInvalidAvatar = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Invalid Avatar',
|
||||
avatar: 'just-a-string', // Invalid format
|
||||
};
|
||||
|
||||
mockReq.body = dataWithInvalidAvatar;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateAgentHandler', () => {
|
||||
let existingAgentId;
|
||||
let existingAgentAuthorId;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create an existing agent for update tests
|
||||
existingAgentAuthorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
author: existingAgentAuthorId,
|
||||
description: 'Original description',
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Original Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-3.5-turbo',
|
||||
description: 'Original description',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
existingAgentId = agent.id;
|
||||
});
|
||||
|
||||
test('should update agent with allowed fields only', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString(); // Set as author
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated Agent',
|
||||
description: 'Updated description',
|
||||
model: 'gpt-4',
|
||||
isCollaborative: true, // This IS allowed in updates
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(400);
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Updated Agent');
|
||||
expect(updatedAgent.description).toBe('Updated description');
|
||||
expect(updatedAgent.model).toBe('gpt-4');
|
||||
expect(updatedAgent.isCollaborative).toBe(true);
|
||||
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Updated Agent');
|
||||
expect(agentInDb.isCollaborative).toBe(true);
|
||||
});
|
||||
|
||||
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated Name',
|
||||
|
||||
// Unauthorized fields that should be stripped
|
||||
author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author
|
||||
authorName: 'Hacker', // Should be stripped
|
||||
id: 'different_agent_id', // Should be stripped
|
||||
_id: new mongoose.Types.ObjectId(), // Should be stripped
|
||||
versions: [], // Should be stripped
|
||||
createdAt: new Date('2020-01-01'), // Should be stripped
|
||||
updatedAt: new Date('2020-01-01'), // Should be stripped
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unauthorized fields were not changed
|
||||
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed
|
||||
expect(updatedAgent.authorName).toBeUndefined();
|
||||
expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed
|
||||
expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString());
|
||||
expect(agentInDb.id).toBe(existingAgentId);
|
||||
});
|
||||
|
||||
test('should reject update from non-author when not collaborative', async () => {
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Unauthorized Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'You do not have permission to modify this non-collaborative agent',
|
||||
});
|
||||
|
||||
// Verify agent was not modified in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Original Agent');
|
||||
});
|
||||
|
||||
test('should allow update from non-author when collaborative', async () => {
|
||||
// First make the agent collaborative
|
||||
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
|
||||
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Collaborative Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Collaborative Update');
|
||||
// Author field should be removed for non-author
|
||||
expect(updatedAgent.author).toBeUndefined();
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Collaborative Update');
|
||||
});
|
||||
|
||||
test('should allow admin to update any agent', async () => {
|
||||
const adminUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = adminUserId;
|
||||
mockReq.user.role = 'ADMIN'; // Set as admin
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Admin Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Admin Update');
|
||||
});
|
||||
|
||||
test('should handle projectIds updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
|
||||
const projectId1 = new mongoose.Types.ObjectId().toString();
|
||||
const projectId2 = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
mockReq.body = {
|
||||
projectIds: [projectId1, projectId2],
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent).toBeDefined();
|
||||
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
|
||||
});
|
||||
|
||||
test('should validate tool_resources in updates', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
tool_resources: {
|
||||
ocr: {
|
||||
file_ids: ['ocr1', 'ocr2'],
|
||||
},
|
||||
execute_code: {
|
||||
file_ids: ['img1'],
|
||||
},
|
||||
// Invalid tool resource
|
||||
invalid_tool: {
|
||||
file_ids: ['invalid'],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.tool_resources).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.ocr).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
|
||||
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should return 404 for non-existent agent', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
|
||||
mockReq.body = {
|
||||
name: 'Update Non-existent',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(404);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
|
||||
});
|
||||
|
||||
test('should include version field in update response', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Updated with Version Check',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify version field is included and is a number
|
||||
expect(updatedAgent).toHaveProperty('version');
|
||||
expect(typeof updatedAgent.version).toBe('number');
|
||||
expect(updatedAgent.version).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(updatedAgent.version).toBe(agentInDb.versions.length);
|
||||
});
|
||||
|
||||
test('should handle validation errors properly', async () => {
|
||||
mockReq.user.id = existingAgentAuthorId.toString();
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
model_parameters: 'invalid-not-an-object', // Should be an object
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
details: expect.any(Array),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Mass Assignment Attack Scenarios', () => {
|
||||
test('should prevent setting system fields during creation', async () => {
|
||||
const systemFields = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'System Fields Test',
|
||||
|
||||
// System fields that should never be settable by users
|
||||
__v: 99,
|
||||
_id: new mongoose.Types.ObjectId(),
|
||||
versions: [
|
||||
{
|
||||
name: 'Fake Version',
|
||||
provider: 'fake',
|
||||
model: 'fake-model',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockReq.body = systemFields;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify system fields were not affected
|
||||
expect(createdAgent.__v).not.toBe(99);
|
||||
expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version
|
||||
expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation
|
||||
expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.__v).not.toBe(99);
|
||||
});
|
||||
|
||||
test('should prevent privilege escalation through isCollaborative', async () => {
|
||||
// Create a non-collaborative agent
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// Try to make it collaborative as a different user
|
||||
const attackerId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = attackerId;
|
||||
mockReq.params.id = agent.id;
|
||||
mockReq.body = {
|
||||
isCollaborative: true, // Trying to escalate privileges
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
// Should be rejected
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
|
||||
// Verify in database that it's still not collaborative
|
||||
const agentInDb = await Agent.findOne({ id: agent.id });
|
||||
expect(agentInDb.isCollaborative).toBe(false);
|
||||
});
|
||||
|
||||
test('should prevent author hijacking', async () => {
|
||||
const originalAuthorId = new mongoose.Types.ObjectId();
|
||||
const attackerId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Admin creates an agent
|
||||
mockReq.user.id = originalAuthorId.toString();
|
||||
mockReq.user.role = 'ADMIN';
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Admin Agent',
|
||||
author: attackerId.toString(), // Trying to set different author
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Author should be the actual user, not the attempted value
|
||||
expect(createdAgent.author.toString()).toBe(originalAuthorId.toString());
|
||||
expect(createdAgent.author.toString()).not.toBe(attackerId.toString());
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id });
|
||||
expect(agentInDb.author.toString()).toBe(originalAuthorId.toString());
|
||||
});
|
||||
|
||||
test('should strip unknown fields to prevent future vulnerabilities', async () => {
|
||||
mockReq.body = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Future Proof Test',
|
||||
|
||||
// Unknown fields that might be added in future
|
||||
superAdminAccess: true,
|
||||
bypassAllChecks: true,
|
||||
internalFlag: 'secret',
|
||||
futureFeature: 'exploit',
|
||||
};
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
|
||||
// Verify unknown fields were stripped
|
||||
expect(createdAgent.superAdminAccess).toBeUndefined();
|
||||
expect(createdAgent.bypassAllChecks).toBeUndefined();
|
||||
expect(createdAgent.internalFlag).toBeUndefined();
|
||||
expect(createdAgent.futureFeature).toBeUndefined();
|
||||
|
||||
// Also check in database
|
||||
const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean();
|
||||
expect(agentInDb.superAdminAccess).toBeUndefined();
|
||||
expect(agentInDb.bypassAllChecks).toBeUndefined();
|
||||
expect(agentInDb.internalFlag).toBeUndefined();
|
||||
expect(agentInDb.futureFeature).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,4 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -22,20 +19,20 @@ const {
|
||||
addThreadMetadata,
|
||||
saveAssistantMessage,
|
||||
} = require('~/server/services/Threads');
|
||||
const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { createRunBody } = require('~/server/services/createRunBody');
|
||||
const { sendResponse } = require('~/server/middleware/error');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
@@ -152,7 +149,7 @@ const chatV1 = async (req, res) => {
|
||||
return res.end();
|
||||
}
|
||||
await cache.delete(cacheKey);
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(run_id, { thread_id });
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/chat/] Error cancelling run', error);
|
||||
@@ -162,7 +159,7 @@ const chatV1 = async (req, res) => {
|
||||
|
||||
let run;
|
||||
try {
|
||||
run = await openai.beta.threads.runs.retrieve(run_id, { thread_id });
|
||||
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
@@ -474,7 +471,7 @@ const chatV1 = async (req, res) => {
|
||||
await Promise.all(promises);
|
||||
|
||||
const sendInitialResponse = () => {
|
||||
sendEvent(res, {
|
||||
sendMessage(res, {
|
||||
sync: true,
|
||||
conversationId,
|
||||
// messages: previousMessages,
|
||||
@@ -590,7 +587,7 @@ const chatV1 = async (req, res) => {
|
||||
iconURL: endpointOption.iconURL,
|
||||
};
|
||||
|
||||
sendEvent(res, {
|
||||
sendMessage(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: {
|
||||
@@ -623,7 +620,7 @@ const chatV1 = async (req, res) => {
|
||||
|
||||
if (!response.run.usage) {
|
||||
await sleep(3000);
|
||||
completedRun = await openai.beta.threads.runs.retrieve(response.run.id, { thread_id });
|
||||
completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id);
|
||||
if (completedRun.usage) {
|
||||
await recordUsage({
|
||||
...completedRun.usage,
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { sleep } = require('@librechat/agents');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
@@ -25,14 +22,15 @@ const { createErrorHandler } = require('~/server/controllers/assistants/errors')
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { sendMessage, sleep, countTokens } = require('~/server/utils');
|
||||
const { createRunBody } = require('~/server/services/createRunBody');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const { checkBalance } = require('~/models/balanceMethods');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
@@ -311,7 +309,7 @@ const chatV2 = async (req, res) => {
|
||||
await Promise.all(promises);
|
||||
|
||||
const sendInitialResponse = () => {
|
||||
sendEvent(res, {
|
||||
sendMessage(res, {
|
||||
sync: true,
|
||||
conversationId,
|
||||
// messages: previousMessages,
|
||||
@@ -434,7 +432,7 @@ const chatV2 = async (req, res) => {
|
||||
iconURL: endpointOption.iconURL,
|
||||
};
|
||||
|
||||
sendEvent(res, {
|
||||
sendMessage(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: {
|
||||
@@ -467,7 +465,7 @@ const chatV2 = async (req, res) => {
|
||||
|
||||
if (!response.run.usage) {
|
||||
await sleep(3000);
|
||||
completedRun = await openai.beta.threads.runs.retrieve(response.run.id, { thread_id });
|
||||
completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id);
|
||||
if (completedRun.usage) {
|
||||
await recordUsage({
|
||||
...completedRun.usage,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// errorHandler.js
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
|
||||
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
|
||||
const { sendResponse } = require('~/server/middleware/error');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const { sendResponse } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerContext
|
||||
@@ -78,7 +78,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
} else if (/Files.*are invalid/.test(error.message)) {
|
||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||
endpoint === 'azureAssistants'
|
||||
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
|
||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
@@ -108,7 +108,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
return res.end();
|
||||
}
|
||||
await cache.delete(cacheKey);
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(run_id, { thread_id });
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error cancelling run`, error);
|
||||
@@ -118,7 +118,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
|
||||
|
||||
let run;
|
||||
try {
|
||||
run = await openai.beta.threads.runs.retrieve(run_id, { thread_id });
|
||||
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
|
||||
@@ -173,16 +173,6 @@ const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, que
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Initializes the OpenAI client.
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {ServerRequest} params.req - The request object.
|
||||
* @param {ServerResponse} params.res - The response object.
|
||||
* @param {TEndpointOption} params.endpointOption - The endpoint options.
|
||||
* @param {boolean} params.initAppClient - Whether to initialize the app client.
|
||||
* @param {string} params.overrideEndpoint - The endpoint to override.
|
||||
* @returns {Promise<{ openai: OpenAIClient, openAIApiKey: string; client: import('~/app/clients/OpenAIClient') }>} - The initialized OpenAI client.
|
||||
*/
|
||||
async function getOpenAIClient({ req, res, endpointOption, initAppClient, overrideEndpoint }) {
|
||||
let endpoint = overrideEndpoint ?? req.body.endpoint ?? req.query.endpoint;
|
||||
const version = await getCurrentVersion(req, endpoint);
|
||||
|
||||
@@ -197,7 +197,7 @@ const deleteAssistant = async (req, res) => {
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const deletionStatus = await openai.beta.assistants.delete(assistant_id);
|
||||
const deletionStatus = await openai.beta.assistants.del(assistant_id);
|
||||
if (deletionStatus?.deleted) {
|
||||
await deleteAssistantActions({ req, assistant_id });
|
||||
}
|
||||
@@ -365,7 +365,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
||||
try {
|
||||
await fs.unlink(req.file.path);
|
||||
logger.debug('[/:agent_id/avatar] Temp. image upload file deleted');
|
||||
} catch {
|
||||
} catch (error) {
|
||||
logger.debug('[/:agent_id/avatar] Temp. image upload file already deleted');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
const { nanoid } = require('nanoid');
|
||||
const { EnvVar } = require('@librechat/agents');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { checkAccess, loadWebSearchAuth } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
AuthType,
|
||||
Permissions,
|
||||
ToolCallTypes,
|
||||
PermissionTypes,
|
||||
loadWebSearchAuth,
|
||||
} = require('librechat-data-provider');
|
||||
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { loadTools } = require('~/app/clients/tools/util');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { checkAccess } = require('~/server/middleware');
|
||||
const { getMessage } = require('~/models/Message');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const fieldsMap = {
|
||||
[Tools.execute_code]: [EnvVar.CODE_API_KEY],
|
||||
@@ -79,7 +79,6 @@ const verifyToolAuth = async (req, res) => {
|
||||
throwError: false,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error loading auth values', error);
|
||||
res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED });
|
||||
return;
|
||||
}
|
||||
@@ -133,12 +132,7 @@ const callTool = async (req, res) => {
|
||||
logger.debug(`[${toolId}/call] User: ${req.user.id}`);
|
||||
let hasAccess = true;
|
||||
if (toolAccessPermType[toolId]) {
|
||||
hasAccess = await checkAccess({
|
||||
user: req.user,
|
||||
permissionType: toolAccessPermType[toolId],
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]);
|
||||
}
|
||||
if (!hasAccess) {
|
||||
logger.warn(
|
||||
|
||||
@@ -16,7 +16,7 @@ const { connectDb, indexSync } = require('~/db');
|
||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const errorController = require('./controllers/ErrorController');
|
||||
const initializeMCPs = require('./services/initializeMCPs');
|
||||
const initializeMCP = require('./services/initializeMCP');
|
||||
const configureSocialLogins = require('./socialLogins');
|
||||
const AppService = require('./services/AppService');
|
||||
const staticCache = require('./utils/staticCache');
|
||||
@@ -55,6 +55,7 @@ const startServer = async () => {
|
||||
|
||||
/* Middleware */
|
||||
app.use(noIndex);
|
||||
app.use(errorController);
|
||||
app.use(express.json({ limit: '3mb' }));
|
||||
app.use(express.urlencoded({ extended: true, limit: '3mb' }));
|
||||
app.use(mongoSanitize());
|
||||
@@ -96,6 +97,7 @@ const startServer = async () => {
|
||||
app.use('/api/actions', routes.actions);
|
||||
app.use('/api/keys', routes.keys);
|
||||
app.use('/api/user', routes.user);
|
||||
app.use('/api/ask', routes.ask);
|
||||
app.use('/api/search', routes.search);
|
||||
app.use('/api/edit', routes.edit);
|
||||
app.use('/api/messages', routes.messages);
|
||||
@@ -116,13 +118,11 @@ const startServer = async () => {
|
||||
app.use('/api/roles', routes.roles);
|
||||
app.use('/api/agents', routes.agents);
|
||||
app.use('/api/banner', routes.banner);
|
||||
app.use('/api/bedrock', routes.bedrock);
|
||||
app.use('/api/memories', routes.memories);
|
||||
app.use('/api/tags', routes.tags);
|
||||
app.use('/api/mcp', routes.mcp);
|
||||
|
||||
// Add the error controller one more time after all routes
|
||||
app.use(errorController);
|
||||
|
||||
app.use((req, res) => {
|
||||
res.set({
|
||||
'Cache-Control': process.env.INDEX_CACHE_CONTROL || 'no-cache, no-store, must-revalidate',
|
||||
@@ -146,7 +146,7 @@ const startServer = async () => {
|
||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
initializeMCPs(app);
|
||||
initializeMCP(app);
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const request = require('supertest');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const mongoose = require('mongoose');
|
||||
@@ -58,30 +59,6 @@ describe('Server Configuration', () => {
|
||||
expect(response.headers['pragma']).toBe('no-cache');
|
||||
expect(response.headers['expires']).toBe('0');
|
||||
});
|
||||
|
||||
it('should return 500 for unknown errors via ErrorController', async () => {
|
||||
// Testing the error handling here on top of unit tests to ensure the middleware is correctly integrated
|
||||
|
||||
// Mock MongoDB operations to fail
|
||||
const originalFindOne = mongoose.models.User.findOne;
|
||||
const mockError = new Error('MongoDB operation failed');
|
||||
mongoose.models.User.findOne = jest.fn().mockImplementation(() => {
|
||||
throw mockError;
|
||||
});
|
||||
|
||||
try {
|
||||
const response = await request(app).post('/api/auth/login').send({
|
||||
email: 'test@example.com',
|
||||
password: 'password123',
|
||||
});
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
expect(response.text).toBe('An unknown error occurred.');
|
||||
} finally {
|
||||
// Restore original function
|
||||
mongoose.models.User.findOne = originalFindOne;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Polls the /health endpoint every 30ms for up to 10 seconds to wait for the server to start completely
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { countTokens, isEnabled, sendEvent } = require('@librechat/api');
|
||||
// abortMiddleware.js
|
||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const abortControllers = require('./abortControllers');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { abortRun } = require('./abortRun');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const abortDataMap = new WeakMap();
|
||||
|
||||
@@ -101,7 +101,7 @@ async function abortMessage(req, res) {
|
||||
cleanupAbortController(abortKey);
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
return sendEvent(res, finalEvent);
|
||||
return sendMessage(res, finalEvent);
|
||||
}
|
||||
|
||||
res.setHeader('Content-Type', 'application/json');
|
||||
@@ -174,7 +174,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
* @param {string} responseMessageId
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId) => {
|
||||
sendEvent(res, { message: userMessage, created: true });
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
|
||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||
getReqData({ abortKey });
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
|
||||
const { deleteMessages } = require('~/models/Message');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { sendMessage } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const three_minutes = 1000 * 60 * 3;
|
||||
|
||||
@@ -34,7 +34,7 @@ async function abortRun(req, res) {
|
||||
const [thread_id, run_id] = runValues.split(':');
|
||||
|
||||
if (!run_id) {
|
||||
logger.warn("[abortRun] Couldn't find run for cancel request", { thread_id });
|
||||
logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id });
|
||||
return res.status(204).send({ message: 'Run not found' });
|
||||
} else if (run_id === 'cancelled') {
|
||||
logger.warn('[abortRun] Run already cancelled', { thread_id });
|
||||
@@ -47,7 +47,7 @@ async function abortRun(req, res) {
|
||||
|
||||
try {
|
||||
await cache.set(cacheKey, 'cancelled', three_minutes);
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(run_id, { thread_id });
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
logger.debug('[abortRun] Cancelled run:', cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error('[abortRun] Error cancelling run', error);
|
||||
@@ -60,7 +60,7 @@ async function abortRun(req, res) {
|
||||
}
|
||||
|
||||
try {
|
||||
const run = await openai.beta.threads.runs.retrieve(run_id, { thread_id });
|
||||
const run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
@@ -93,7 +93,7 @@ async function abortRun(req, res) {
|
||||
};
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
return sendEvent(res, finalEvent);
|
||||
return sendMessage(res, finalEvent);
|
||||
}
|
||||
|
||||
res.json(finalEvent);
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
const { handleError } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
EndpointURLs,
|
||||
parseCompactConvo,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
parseCompactConvo,
|
||||
EndpointURLs,
|
||||
} = require('librechat-data-provider');
|
||||
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
|
||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
||||
const assistants = require('~/server/services/Endpoints/assistants');
|
||||
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const anthropic = require('~/server/services/Endpoints/anthropic');
|
||||
const bedrock = require('~/server/services/Endpoints/bedrock');
|
||||
@@ -15,6 +15,7 @@ const openAI = require('~/server/services/Endpoints/openAI');
|
||||
const agents = require('~/server/services/Endpoints/agents');
|
||||
const custom = require('~/server/services/Endpoints/custom');
|
||||
const google = require('~/server/services/Endpoints/google');
|
||||
const { handleError } = require('~/server/utils');
|
||||
|
||||
const buildFunction = {
|
||||
[EModelEndpoint.openAI]: openAI.buildOptions,
|
||||
@@ -24,6 +25,7 @@ const buildFunction = {
|
||||
[EModelEndpoint.bedrock]: bedrock.buildOptions,
|
||||
[EModelEndpoint.azureOpenAI]: openAI.buildOptions,
|
||||
[EModelEndpoint.anthropic]: anthropic.buildOptions,
|
||||
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
|
||||
[EModelEndpoint.assistants]: assistants.buildOptions,
|
||||
[EModelEndpoint.azureAssistants]: azureAssistants.buildOptions,
|
||||
};
|
||||
@@ -34,9 +36,6 @@ async function buildEndpointOption(req, res, next) {
|
||||
try {
|
||||
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`,
|
||||
);
|
||||
return handleError(res, { text: 'Error parsing conversation' });
|
||||
}
|
||||
|
||||
@@ -58,6 +57,15 @@ async function buildEndpointOption(req, res, next) {
|
||||
return handleError(res, { text: 'Model spec mismatch' });
|
||||
}
|
||||
|
||||
if (
|
||||
currentModelSpec.preset.endpoint !== EModelEndpoint.gptPlugins &&
|
||||
currentModelSpec.preset.tools
|
||||
) {
|
||||
return handleError(res, {
|
||||
text: `Only the "${EModelEndpoint.gptPlugins}" endpoint can have tools defined in the preset`,
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
currentModelSpec.preset.spec = spec;
|
||||
if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') {
|
||||
@@ -69,7 +77,6 @@ async function buildEndpointOption(req, res, next) {
|
||||
conversation: currentModelSpec.preset,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`Error parsing model spec for endpoint ${endpoint}`, error);
|
||||
return handleError(res, { text: 'Error parsing model spec' });
|
||||
}
|
||||
}
|
||||
@@ -77,23 +84,20 @@ async function buildEndpointOption(req, res, next) {
|
||||
try {
|
||||
const isAgents =
|
||||
isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]);
|
||||
const builder = isAgents
|
||||
? (...args) => buildFunction[EModelEndpoint.agents](req, ...args)
|
||||
: buildFunction[endpointType ?? endpoint];
|
||||
const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)];
|
||||
const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn;
|
||||
|
||||
// TODO: use object params
|
||||
req.body.endpointOption = await builder(endpoint, parsedBody, endpointType);
|
||||
|
||||
// TODO: use `getModelsConfig` only when necessary
|
||||
const modelsConfig = await getModelsConfig(req);
|
||||
req.body.endpointOption.modelsConfig = modelsConfig;
|
||||
if (req.body.files && !isAgents) {
|
||||
req.body.endpointOption.attachments = processFiles(req.body.files);
|
||||
}
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error building endpoint option for endpoint ${endpoint} with type ${endpointType}`,
|
||||
error,
|
||||
);
|
||||
return handleError(res, { text: 'Error building endpoint option' });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ const message = 'Your account has been temporarily banned due to violations of o
|
||||
* @function
|
||||
* @param {Object} req - Express Request object.
|
||||
* @param {Object} res - Express Response object.
|
||||
* @param {String} errorMessage - Error message to be displayed in case of /api/ask or /api/edit request.
|
||||
*
|
||||
* @returns {Promise<Object>} - Returns a Promise which when resolved sends a response status of 403 with a specific message if request is not of api/ask or api/edit types. If it is, calls `denyRequest()` function.
|
||||
*/
|
||||
@@ -134,7 +135,6 @@ const checkBan = async (req, res, next = () => {}) => {
|
||||
return await banResponse(req, res);
|
||||
} catch (error) {
|
||||
logger.error('Error in checkBan middleware:', error);
|
||||
return next(error);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { Time, CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const { Time, CacheKeys } = require('librechat-data-provider');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { logViolation, getLogStores } = require('~/cache');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
@@ -37,7 +37,7 @@ const concurrentLimiter = async (req, res, next) => {
|
||||
|
||||
const userId = req.user?.id ?? req.user?._id ?? '';
|
||||
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
|
||||
const type = ViolationTypes.CONCURRENT;
|
||||
const type = 'concurrent';
|
||||
|
||||
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
|
||||
const pendingRequests = +((await cache.get(key)) ?? 0);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
const crypto = require('crypto');
|
||||
const { sendEvent } = require('@librechat/api');
|
||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
const { sendMessage, sendError } = require('~/server/utils');
|
||||
const { saveMessage } = require('~/models');
|
||||
|
||||
/**
|
||||
@@ -37,7 +36,7 @@ const denyRequest = async (req, res, errorMessage) => {
|
||||
isCreatedByUser: true,
|
||||
text,
|
||||
};
|
||||
sendEvent(res, { message: userMessage, created: true });
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
|
||||
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;
|
||||
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const FORK_IP_MAX = parseInt(process.env.FORK_IP_MAX) || 30;
|
||||
const FORK_IP_WINDOW = parseInt(process.env.FORK_IP_WINDOW) || 1;
|
||||
const FORK_USER_MAX = parseInt(process.env.FORK_USER_MAX) || 7;
|
||||
const FORK_USER_WINDOW = parseInt(process.env.FORK_USER_WINDOW) || 1;
|
||||
const FORK_VIOLATION_SCORE = process.env.FORK_VIOLATION_SCORE;
|
||||
|
||||
const forkIpWindowMs = FORK_IP_WINDOW * 60 * 1000;
|
||||
const forkIpMax = FORK_IP_MAX;
|
||||
const forkIpWindowInMinutes = forkIpWindowMs / 60000;
|
||||
|
||||
const forkUserWindowMs = FORK_USER_WINDOW * 60 * 1000;
|
||||
const forkUserMax = FORK_USER_MAX;
|
||||
const forkUserWindowInMinutes = forkUserWindowMs / 60000;
|
||||
|
||||
return {
|
||||
forkIpWindowMs,
|
||||
forkIpMax,
|
||||
forkIpWindowInMinutes,
|
||||
forkUserWindowMs,
|
||||
forkUserMax,
|
||||
forkUserWindowInMinutes,
|
||||
forkViolationScore: FORK_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createForkHandler = (ip = true) => {
|
||||
const {
|
||||
forkIpMax,
|
||||
forkUserMax,
|
||||
forkViolationScore,
|
||||
forkIpWindowInMinutes,
|
||||
forkUserWindowInMinutes,
|
||||
} = getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: ip ? forkIpMax : forkUserMax,
|
||||
limiter: ip ? 'ip' : 'user',
|
||||
windowInMinutes: ip ? forkIpWindowInMinutes : forkUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, forkViolationScore);
|
||||
res.status(429).json({ message: 'Too many conversation fork requests. Try again later' });
|
||||
};
|
||||
};
|
||||
|
||||
const createForkLimiters = () => {
|
||||
const { forkIpWindowMs, forkIpMax, forkUserWindowMs, forkUserMax } = getEnvironmentVariables();
|
||||
|
||||
const ipLimiterOptions = {
|
||||
windowMs: forkIpWindowMs,
|
||||
max: forkIpMax,
|
||||
handler: createForkHandler(),
|
||||
store: limiterCache('fork_ip_limiter'),
|
||||
};
|
||||
const userLimiterOptions = {
|
||||
windowMs: forkUserWindowMs,
|
||||
max: forkUserMax,
|
||||
handler: createForkHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id;
|
||||
},
|
||||
store: limiterCache('fork_user_limiter'),
|
||||
};
|
||||
|
||||
const forkIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const forkUserLimiter = rateLimit(userLimiterOptions);
|
||||
return { forkIpLimiter, forkUserLimiter };
|
||||
};
|
||||
|
||||
module.exports = { createForkLimiters };
|
||||
@@ -1,14 +1,16 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
|
||||
const IMPORT_IP_WINDOW = parseInt(process.env.IMPORT_IP_WINDOW) || 15;
|
||||
const IMPORT_USER_MAX = parseInt(process.env.IMPORT_USER_MAX) || 50;
|
||||
const IMPORT_USER_WINDOW = parseInt(process.env.IMPORT_USER_WINDOW) || 15;
|
||||
const IMPORT_VIOLATION_SCORE = process.env.IMPORT_VIOLATION_SCORE;
|
||||
|
||||
const importIpWindowMs = IMPORT_IP_WINDOW * 60 * 1000;
|
||||
const importIpMax = IMPORT_IP_MAX;
|
||||
@@ -25,18 +27,12 @@ const getEnvironmentVariables = () => {
|
||||
importUserWindowMs,
|
||||
importUserMax,
|
||||
importUserWindowInMinutes,
|
||||
importViolationScore: IMPORT_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createImportHandler = (ip = true) => {
|
||||
const {
|
||||
importIpMax,
|
||||
importUserMax,
|
||||
importViolationScore,
|
||||
importIpWindowInMinutes,
|
||||
importUserWindowInMinutes,
|
||||
} = getEnvironmentVariables();
|
||||
const { importIpMax, importIpWindowInMinutes, importUserMax, importUserWindowInMinutes } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
|
||||
@@ -47,7 +43,7 @@ const createImportHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? importIpWindowInMinutes : importUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, importViolationScore);
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
res.status(429).json({ message: 'Too many conversation import requests. Try again later' });
|
||||
};
|
||||
};
|
||||
@@ -60,7 +56,6 @@ const createImportLimiters = () => {
|
||||
windowMs: importIpWindowMs,
|
||||
max: importIpMax,
|
||||
handler: createImportHandler(),
|
||||
store: limiterCache('import_ip_limiter'),
|
||||
};
|
||||
const userLimiterOptions = {
|
||||
windowMs: importUserWindowMs,
|
||||
@@ -69,9 +64,23 @@ const createImportLimiters = () => {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
store: limiterCache('import_user_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for import rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'import_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'import_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const importIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const importUserLimiter = rateLimit(userLimiterOptions);
|
||||
return { importIpLimiter, importUserLimiter };
|
||||
|
||||
@@ -4,7 +4,6 @@ const createSTTLimiters = require('./sttLimiters');
|
||||
const loginLimiter = require('./loginLimiter');
|
||||
const importLimiters = require('./importLimiters');
|
||||
const uploadLimiters = require('./uploadLimiters');
|
||||
const forkLimiters = require('./forkLimiters');
|
||||
const registerLimiter = require('./registerLimiter');
|
||||
const toolCallLimiter = require('./toolCallLimiter');
|
||||
const messageLimiters = require('./messageLimiters');
|
||||
@@ -15,7 +14,6 @@ module.exports = {
|
||||
...uploadLimiters,
|
||||
...importLimiters,
|
||||
...messageLimiters,
|
||||
...forkLimiters,
|
||||
loginLimiter,
|
||||
registerLimiter,
|
||||
toolCallLimiter,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
|
||||
const windowMs = LOGIN_WINDOW * 60 * 1000;
|
||||
@@ -11,7 +12,7 @@ const windowInMinutes = windowMs / 60000;
|
||||
const message = `Too many login attempts, please try again after ${windowInMinutes} minutes.`;
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const type = ViolationTypes.LOGINS;
|
||||
const type = 'logins';
|
||||
const errorMessage = {
|
||||
type,
|
||||
max,
|
||||
@@ -27,9 +28,17 @@ const limiterOptions = {
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
store: limiterCache('login_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for login rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'login_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const loginLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = loginLimiter;
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const denyRequest = require('~/server/middleware/denyRequest');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
MESSAGE_IP_MAX = 40,
|
||||
MESSAGE_IP_WINDOW = 1,
|
||||
MESSAGE_USER_MAX = 40,
|
||||
MESSAGE_USER_WINDOW = 1,
|
||||
MESSAGE_VIOLATION_SCORE: score,
|
||||
} = process.env;
|
||||
|
||||
const ipWindowMs = MESSAGE_IP_WINDOW * 60 * 1000;
|
||||
@@ -30,7 +31,7 @@ const userWindowInMinutes = userWindowMs / 60000;
|
||||
*/
|
||||
const createHandler = (ip = true) => {
|
||||
return async (req, res) => {
|
||||
const type = ViolationTypes.MESSAGE_LIMIT;
|
||||
const type = 'message_limit';
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: ip ? ipMax : userMax,
|
||||
@@ -38,7 +39,7 @@ const createHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? ipWindowInMinutes : userWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
return await denyRequest(req, res, errorMessage);
|
||||
};
|
||||
};
|
||||
@@ -50,7 +51,6 @@ const ipLimiterOptions = {
|
||||
windowMs: ipWindowMs,
|
||||
max: ipMax,
|
||||
handler: createHandler(),
|
||||
store: limiterCache('message_ip_limiter'),
|
||||
};
|
||||
|
||||
const userLimiterOptions = {
|
||||
@@ -60,9 +60,23 @@ const userLimiterOptions = {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
store: limiterCache('message_user_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for message rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'message_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'message_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Message request rate limiter by IP
|
||||
*/
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
|
||||
const windowMs = REGISTER_WINDOW * 60 * 1000;
|
||||
@@ -11,7 +12,7 @@ const windowInMinutes = windowMs / 60000;
|
||||
const message = `Too many accounts created, please try again after ${windowInMinutes} minutes`;
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const type = ViolationTypes.REGISTRATIONS;
|
||||
const type = 'registrations';
|
||||
const errorMessage = {
|
||||
type,
|
||||
max,
|
||||
@@ -27,9 +28,17 @@ const limiterOptions = {
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
store: limiterCache('register_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for register rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'register_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const registerLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = registerLimiter;
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
RESET_PASSWORD_WINDOW = 2,
|
||||
@@ -31,9 +33,17 @@ const limiterOptions = {
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
store: limiterCache('reset_password_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for reset password rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'reset_password_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const resetPasswordLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = resetPasswordLimiter;
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
|
||||
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
|
||||
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
|
||||
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
|
||||
const STT_VIOLATION_SCORE = process.env.STT_VIOLATION_SCORE;
|
||||
|
||||
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
|
||||
const sttIpMax = STT_IP_MAX;
|
||||
@@ -25,12 +27,11 @@ const getEnvironmentVariables = () => {
|
||||
sttUserWindowMs,
|
||||
sttUserMax,
|
||||
sttUserWindowInMinutes,
|
||||
sttViolationScore: STT_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createSTTHandler = (ip = true) => {
|
||||
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes, sttViolationScore } =
|
||||
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
@@ -42,7 +43,7 @@ const createSTTHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, sttViolationScore);
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
res.status(429).json({ message: 'Too many STT requests. Try again later' });
|
||||
};
|
||||
};
|
||||
@@ -54,7 +55,6 @@ const createSTTLimiters = () => {
|
||||
windowMs: sttIpWindowMs,
|
||||
max: sttIpMax,
|
||||
handler: createSTTHandler(),
|
||||
store: limiterCache('stt_ip_limiter'),
|
||||
};
|
||||
|
||||
const userLimiterOptions = {
|
||||
@@ -64,9 +64,23 @@ const createSTTLimiters = () => {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
store: limiterCache('stt_user_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for STT rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'stt_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'stt_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const sttIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const sttUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
|
||||
const { TOOL_CALL_VIOLATION_SCORE: score } = process.env;
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const type = ViolationTypes.TOOL_CALL_LIMIT;
|
||||
@@ -14,7 +15,7 @@ const handler = async (req, res) => {
|
||||
windowInMinutes: 1,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
await logViolation(req, res, type, errorMessage, 0);
|
||||
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
|
||||
};
|
||||
|
||||
@@ -25,9 +26,17 @@ const limiterOptions = {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id;
|
||||
},
|
||||
store: limiterCache('tool_call_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for tool call rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'tool_call_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const toolCallLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = toolCallLimiter;
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
|
||||
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
|
||||
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
|
||||
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
|
||||
const TTS_VIOLATION_SCORE = process.env.TTS_VIOLATION_SCORE;
|
||||
|
||||
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
|
||||
const ttsIpMax = TTS_IP_MAX;
|
||||
@@ -25,12 +27,11 @@ const getEnvironmentVariables = () => {
|
||||
ttsUserWindowMs,
|
||||
ttsUserMax,
|
||||
ttsUserWindowInMinutes,
|
||||
ttsViolationScore: TTS_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
const createTTSHandler = (ip = true) => {
|
||||
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes, ttsViolationScore } =
|
||||
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
@@ -42,7 +43,7 @@ const createTTSHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, ttsViolationScore);
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
|
||||
};
|
||||
};
|
||||
@@ -54,19 +55,32 @@ const createTTSLimiters = () => {
|
||||
windowMs: ttsIpWindowMs,
|
||||
max: ttsIpMax,
|
||||
handler: createTTSHandler(),
|
||||
store: limiterCache('tts_ip_limiter'),
|
||||
};
|
||||
|
||||
const userLimiterOptions = {
|
||||
windowMs: ttsUserWindowMs,
|
||||
max: ttsUserMax,
|
||||
handler: createTTSHandler(false),
|
||||
store: limiterCache('tts_user_limiter'),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for TTS rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'tts_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'tts_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const ttsIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const ttsUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100;
|
||||
const FILE_UPLOAD_IP_WINDOW = parseInt(process.env.FILE_UPLOAD_IP_WINDOW) || 15;
|
||||
const FILE_UPLOAD_USER_MAX = parseInt(process.env.FILE_UPLOAD_USER_MAX) || 50;
|
||||
const FILE_UPLOAD_USER_WINDOW = parseInt(process.env.FILE_UPLOAD_USER_WINDOW) || 15;
|
||||
const FILE_UPLOAD_VIOLATION_SCORE = process.env.FILE_UPLOAD_VIOLATION_SCORE;
|
||||
|
||||
const fileUploadIpWindowMs = FILE_UPLOAD_IP_WINDOW * 60 * 1000;
|
||||
const fileUploadIpMax = FILE_UPLOAD_IP_MAX;
|
||||
@@ -25,7 +27,6 @@ const getEnvironmentVariables = () => {
|
||||
fileUploadUserWindowMs,
|
||||
fileUploadUserMax,
|
||||
fileUploadUserWindowInMinutes,
|
||||
fileUploadViolationScore: FILE_UPLOAD_VIOLATION_SCORE,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -35,7 +36,6 @@ const createFileUploadHandler = (ip = true) => {
|
||||
fileUploadIpWindowInMinutes,
|
||||
fileUploadUserMax,
|
||||
fileUploadUserWindowInMinutes,
|
||||
fileUploadViolationScore,
|
||||
} = getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
@@ -47,7 +47,7 @@ const createFileUploadHandler = (ip = true) => {
|
||||
windowInMinutes: ip ? fileUploadIpWindowInMinutes : fileUploadUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage, fileUploadViolationScore);
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
res.status(429).json({ message: 'Too many file upload requests. Try again later' });
|
||||
};
|
||||
};
|
||||
@@ -60,7 +60,6 @@ const createFileLimiters = () => {
|
||||
windowMs: fileUploadIpWindowMs,
|
||||
max: fileUploadIpMax,
|
||||
handler: createFileUploadHandler(),
|
||||
store: limiterCache('file_upload_ip_limiter'),
|
||||
};
|
||||
|
||||
const userLimiterOptions = {
|
||||
@@ -70,9 +69,23 @@ const createFileLimiters = () => {
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
store: limiterCache('file_upload_user_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for file upload rate limiters.');
|
||||
const sendCommand = (...args) => ioredisClient.call(...args);
|
||||
const ipStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'file_upload_ip_limiter:',
|
||||
});
|
||||
const userStore = new RedisStore({
|
||||
sendCommand,
|
||||
prefix: 'file_upload_user_limiter:',
|
||||
});
|
||||
ipLimiterOptions.store = ipStore;
|
||||
userLimiterOptions.store = userStore;
|
||||
}
|
||||
|
||||
const fileUploadIpLimiter = rateLimit(ipLimiterOptions);
|
||||
const fileUploadUserLimiter = rateLimit(userLimiterOptions);
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { RedisStore } = require('rate-limit-redis');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { removePorts } = require('~/server/utils');
|
||||
const { limiterCache } = require('~/cache/cacheFactory');
|
||||
const { removePorts, isEnabled } = require('~/server/utils');
|
||||
const ioredisClient = require('~/cache/ioredisClient');
|
||||
const { logViolation } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const {
|
||||
VERIFY_EMAIL_WINDOW = 2,
|
||||
@@ -31,9 +33,17 @@ const limiterOptions = {
|
||||
max,
|
||||
handler,
|
||||
keyGenerator: removePorts,
|
||||
store: limiterCache('verify_email_limiter'),
|
||||
};
|
||||
|
||||
if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
|
||||
logger.debug('Using Redis for verify email rate limiter.');
|
||||
const store = new RedisStore({
|
||||
sendCommand: (...args) => ioredisClient.call(...args),
|
||||
prefix: 'verify_email_limiter:',
|
||||
});
|
||||
limiterOptions.store = store;
|
||||
}
|
||||
|
||||
const verifyEmailLimiter = rateLimit(limiterOptions);
|
||||
|
||||
module.exports = verifyEmailLimiter;
|
||||
|
||||
78
api/server/middleware/roles/access.js
Normal file
78
api/server/middleware/roles/access.js
Normal file
@@ -0,0 +1,78 @@
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Core function to check if a user has one or more required permissions
|
||||
*
|
||||
* @param {object} user - The user object
|
||||
* @param {PermissionTypes} permissionType - The type of permission to check
|
||||
* @param {Permissions[]} permissions - The list of specific permissions to check
|
||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
|
||||
* @param {object} [checkObject] - The object to check properties against
|
||||
* @returns {Promise<boolean>} Whether the user has the required permissions
|
||||
*/
|
||||
const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
|
||||
if (!user) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const role = await getRoleByName(user.role);
|
||||
if (role && role.permissions && role.permissions[permissionType]) {
|
||||
const hasAnyPermission = permissions.some((permission) => {
|
||||
if (role.permissions[permissionType][permission]) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (bodyProps[permission] && checkObject) {
|
||||
return bodyProps[permission].some((prop) =>
|
||||
Object.prototype.hasOwnProperty.call(checkObject, prop),
|
||||
);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
return hasAnyPermission;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
|
||||
*
|
||||
* @param {PermissionTypes} permissionType - The type of permission to check.
|
||||
* @param {Permissions[]} permissions - The list of specific permissions to check.
|
||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
|
||||
* @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
|
||||
*/
|
||||
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
const hasAccess = await checkAccess(
|
||||
req.user,
|
||||
permissionType,
|
||||
permissions,
|
||||
bodyProps,
|
||||
req.body,
|
||||
);
|
||||
|
||||
if (hasAccess) {
|
||||
return next();
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`,
|
||||
);
|
||||
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return res.status(500).json({ message: `Server error: ${error.message}` });
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
checkAccess,
|
||||
generateCheckAccess,
|
||||
};
|
||||
@@ -1,5 +1,8 @@
|
||||
const checkAdmin = require('./admin');
|
||||
const { checkAccess, generateCheckAccess } = require('./access');
|
||||
|
||||
module.exports = {
|
||||
checkAdmin,
|
||||
checkAccess,
|
||||
generateCheckAccess,
|
||||
};
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
const uap = require('ua-parser-js');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { handleError } = require('@librechat/api');
|
||||
const { handleError } = require('../utils');
|
||||
const { logViolation } = require('../../cache');
|
||||
|
||||
/**
|
||||
@@ -22,7 +21,7 @@ async function uaParser(req, res, next) {
|
||||
const ua = uap(req.headers['user-agent']);
|
||||
|
||||
if (!ua.browser.name) {
|
||||
const type = ViolationTypes.NON_BROWSER;
|
||||
const type = 'non_browser';
|
||||
await logViolation(req, res, type, { type }, score);
|
||||
return handleError(res, { message: 'Illegal request' });
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { handleError } = require('@librechat/api');
|
||||
const { handleError } = require('../utils');
|
||||
|
||||
function validateEndpoint(req, res, next) {
|
||||
const { endpoint: _endpoint, endpointType } = req.body;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { handleError } = require('@librechat/api');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
||||
const { handleError } = require('~/server/utils');
|
||||
const { logViolation } = require('~/cache');
|
||||
/**
|
||||
* Validates the model of the request.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user