Compare commits

..

6 Commits

950 changed files with 21851 additions and 72536 deletions

View File

@@ -20,8 +20,8 @@ DOMAIN_CLIENT=http://localhost:3080
DOMAIN_SERVER=http://localhost:3080
NO_INDEX=true
# Use the address that is at most n number of hops away from the Express application.
# req.socket.remoteAddress is the first hop, and the rest are looked for in the X-Forwarded-For header from right to left.
# A value of 0 means that the first untrusted address would be req.socket.remoteAddress, i.e. there is no reverse proxy.
# Defaulted to 1.
TRUST_PROXY=1
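The TRUST_PROXY comments above describe Express's standard `trust proxy` setting. A minimal sketch, assuming a plain Express app (this code is not part of the diff), of how the variable can be consumed:

```js
// Minimal sketch, assuming Express; TRUST_PROXY as documented above.
const express = require('express');

const app = express();

// Express counts hops starting from req.socket.remoteAddress, then walks
// the X-Forwarded-For header from right to left; 0 trusts no reverse proxy.
app.set('trust proxy', Number(process.env.TRUST_PROXY ?? 1));

app.get('/ip', (req, res) => {
  // req.ip is the first untrusted address after hop evaluation
  res.send(req.ip);
});
```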
@@ -58,7 +58,7 @@ DEBUG_CONSOLE=false
# Endpoints #
#===================================================#
-# ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic
+# ENDPOINTS=openAI,assistants,azureOpenAI,google,gptPlugins,anthropic
PROXY=
@@ -88,7 +88,7 @@ PROXY=
#============#
ANTHROPIC_API_KEY=user_provided
-# ANTHROPIC_MODELS=claude-opus-4-20250514,claude-sonnet-4-20250514,claude-3-7-sonnet-20250219,claude-3-5-sonnet-20241022,claude-3-5-haiku-20241022,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307
+# ANTHROPIC_MODELS=claude-3-7-sonnet-latest,claude-3-7-sonnet-20250219,claude-3-5-haiku-20241022,claude-3-5-sonnet-20241022,claude-3-5-sonnet-latest,claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
# ANTHROPIC_REVERSE_PROXY=
#============#
@@ -142,10 +142,10 @@ GOOGLE_KEY=user_provided
# GOOGLE_AUTH_HEADER=true
# Gemini API (AI Studio)
-# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite
+# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# Vertex AI
-# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
+# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
@@ -349,11 +349,6 @@ REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20
-TTS_VIOLATION_SCORE=0
-STT_VIOLATION_SCORE=0
-FORK_VIOLATION_SCORE=0
-IMPORT_VIOLATION_SCORE=0
-FILE_UPLOAD_VIOLATION_SCORE=0
LOGIN_MAX=7
LOGIN_WINDOW=5
@@ -448,47 +443,6 @@ OPENID_IMAGE_URL=
# Set to true to automatically redirect to the OpenID provider when a user visits the login page
# This will bypass the login form completely for users, only use this if OpenID is your only authentication method
OPENID_AUTO_REDIRECT=false
-# Set to true to use PKCE (Proof Key for Code Exchange) for OpenID authentication
-OPENID_USE_PKCE=false
-# Set to true to reuse OpenID tokens for authentication management instead of using the MongoDB session and the custom refresh token.
-OPENID_REUSE_TOKENS=
-# By default, signing key verification results are cached in order to prevent excessive HTTP requests to the JWKS endpoint.
-# If a signing key matching the kid is found, this will be cached and the next time this kid is requested the signing key will be served from the cache.
-# Default is true.
-OPENID_JWKS_URL_CACHE_ENABLED=
-OPENID_JWKS_URL_CACHE_TIME= # 600000 ms equals 10 minutes; leave empty to disable caching
-# Set to true to trigger the token exchange flow to acquire an access token for the userinfo endpoint.
-OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED=
-OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example scope needed for Microsoft Graph API
-# Set to true to use the OpenID Connect end session endpoint for logout
-OPENID_USE_END_SESSION_ENDPOINT=
-# SAML
-# Note: If OpenID is enabled, SAML authentication will be automatically disabled.
-SAML_ENTRY_POINT=
-SAML_ISSUER=
-SAML_CERT=
-SAML_CALLBACK_URL=/oauth/saml/callback
-SAML_SESSION_SECRET=
-# Attribute mappings (optional)
-SAML_EMAIL_CLAIM=
-SAML_USERNAME_CLAIM=
-SAML_GIVEN_NAME_CLAIM=
-SAML_FAMILY_NAME_CLAIM=
-SAML_PICTURE_CLAIM=
-SAML_NAME_CLAIM=
-# Login button settings (optional)
-SAML_BUTTON_LABEL=
-SAML_IMAGE_URL=
-# Whether the SAML Response should be signed.
-# - If "true", the entire `SAML Response` will be signed.
-# - If "false" or unset, only the `SAML Assertion` will be signed (default behavior).
-# SAML_USE_AUTHN_RESPONSE_SIGNED=
# LDAP
LDAP_URL=
@@ -520,18 +474,6 @@ EMAIL_PASSWORD=
EMAIL_FROM_NAME=
EMAIL_FROM=noreply@librechat.ai
-#========================#
-# Mailgun API #
-#========================#
-# MAILGUN_API_KEY=your-mailgun-api-key
-# MAILGUN_DOMAIN=mg.yourdomain.com
-# EMAIL_FROM=noreply@yourdomain.com
-# EMAIL_FROM_NAME="LibreChat"
-# # Optional: For EU region
-# MAILGUN_HOST=https://api.eu.mailgun.net
#========================#
# Firebase CDN #
#========================#
@@ -580,10 +522,6 @@ ALLOW_SHARED_LINKS_PUBLIC=true
# If you have another service in front of your LibreChat doing compression, disable express based compression here
# DISABLE_COMPRESSION=true
-# If you have gzipped versions of uploaded images in the same folder, this will enable gzip scan and serving of these images
-# Note: The images folder will be scanned on startup and a map kept in memory. Be careful with a large number of images.
-# ENABLE_IMAGE_OUTPUT_GZIP_SCAN=true
#===================================================#
# UI #
#===================================================#
@@ -601,31 +539,11 @@ HELP_AND_FAQ_URL=https://librechat.ai
# REDIS Options #
#===============#
-# Enable Redis for caching and session storage
+# REDIS_URI=10.10.10.10:6379
# USE_REDIS=true
-# Single Redis instance
-# REDIS_URI=redis://127.0.0.1:6379
-# Redis cluster (multiple nodes)
-# REDIS_URI=redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003
-# Redis with TLS/SSL encryption and CA certificate
-# REDIS_URI=rediss://127.0.0.1:6380
-# REDIS_CA=/path/to/ca-cert.pem
-# Redis authentication (if required)
-# REDIS_USERNAME=your_redis_username
-# REDIS_PASSWORD=your_redis_password
-# Redis key prefix configuration
-# Use environment variable name for dynamic prefix (recommended for cloud deployments)
-# REDIS_KEY_PREFIX_VAR=K_REVISION
-# Or use static prefix directly
-# REDIS_KEY_PREFIX=librechat
-# Redis connection limits
-# REDIS_MAX_LISTENERS=40
+# USE_REDIS_CLUSTER=true
+# REDIS_CA=/path/to/ca.crt
#==================================================#
# Others #
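A minimal sketch, assuming the ioredis client (the helper below is hypothetical, not LibreChat's implementation), of how the single-instance, cluster, TLS, and auth variables above might be combined:

```js
// Hypothetical helper using ioredis; illustrates the REDIS_* variables above.
const fs = require('fs');
const Redis = require('ioredis');

function createRedisClient() {
  const uri = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
  const nodes = uri.split(','); // multiple comma-separated URIs => cluster
  const tls =
    uri.startsWith('rediss://') && process.env.REDIS_CA
      ? { ca: fs.readFileSync(process.env.REDIS_CA) } // CA cert for TLS
      : undefined;
  const auth = {
    username: process.env.REDIS_USERNAME,
    password: process.env.REDIS_PASSWORD,
  };
  if (nodes.length > 1) {
    return new Redis.Cluster(nodes, { redisOptions: { ...auth, tls } });
  }
  return new Redis(uri, { ...auth, tls });
}
```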
@@ -645,9 +563,9 @@ HELP_AND_FAQ_URL=https://librechat.ai
# users always get the latest version. Customize #
# only if you understand caching implications. #
-# INDEX_CACHE_CONTROL=no-cache, no-store, must-revalidate
-# INDEX_PRAGMA=no-cache
-# INDEX_EXPIRES=0
+# INDEX_HTML_CACHE_CONTROL=no-cache, no-store, must-revalidate
+# INDEX_HTML_PRAGMA=no-cache
+# INDEX_HTML_EXPIRES=0
# no-cache: Forces validation with server before using cached version
# no-store: Prevents storing the response entirely
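A minimal sketch, assuming Express and using the newer INDEX_* names from the left-hand side (`clientDistPath` is a hypothetical variable), of how the three index cache headers translate into a response:

```js
// Sketch: apply the INDEX_* cache headers when serving index.html.
app.get('/', (req, res) => {
  res.set({
    'Cache-Control': process.env.INDEX_CACHE_CONTROL ?? 'no-cache, no-store, must-revalidate',
    Pragma: process.env.INDEX_PRAGMA ?? 'no-cache',
    Expires: process.env.INDEX_EXPIRES ?? '0',
  });
  res.sendFile('index.html', { root: clientDistPath }); // hypothetical path
});
```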
@@ -657,33 +575,3 @@ HELP_AND_FAQ_URL=https://librechat.ai
# OpenWeather #
#=====================================================#
OPENWEATHER_API_KEY=
-#====================================#
-# LibreChat Code Interpreter API #
-#====================================#
-# https://code.librechat.ai
-# LIBRECHAT_CODE_API_KEY=your-key
-#======================#
-# Web Search #
-#======================#
-# Note: All of the following variable names can be customized.
-# Omit values to allow user to provide them.
-# For more information on configuration values, see:
-# https://librechat.ai/docs/features/web_search
-# Search Provider (Required)
-# SERPER_API_KEY=your_serper_api_key
-# Scraper (Required)
-# FIRECRAWL_API_KEY=your_firecrawl_api_key
-# Optional: Custom Firecrawl API URL
-# FIRECRAWL_API_URL=your_firecrawl_api_url
-# Reranker (Required)
-# JINA_API_KEY=your_jina_api_key
-# or
-# COHERE_API_KEY=your_cohere_api_key

View File

@@ -30,8 +30,8 @@ Project maintainers have the right and responsibility to remove, edit, or reject
2. Install typescript globally: `npm i -g typescript`.
3. Run `npm ci` to install dependencies.
4. Build the data provider: `npm run build:data-provider`.
-5. Build data schemas: `npm run build:data-schemas`.
-6. Build API methods: `npm run build:api`.
+5. Build MCP: `npm run build:mcp`.
+6. Build data schemas: `npm run build:data-schemas`.
7. Setup and run unit tests:
- Copy `.env.test`: `cp api/test/.env.test.example api/test/.env.test`.
- Run backend unit tests: `npm run test:api`.

View File

@@ -7,7 +7,6 @@ on:
- release/*
paths:
- 'api/**'
-- 'packages/**'
jobs:
tests_Backend:
name: Run Backend unit tests
@@ -37,12 +36,12 @@ jobs:
- name: Install Data Provider Package
run: npm run build:data-provider
+- name: Install MCP Package
+run: npm run build:mcp
- name: Install Data Schemas Package
run: npm run build:data-schemas
-- name: Install API Package
-run: npm run build:api
- name: Create empty auth.json file
run: |
mkdir -p api/data
@@ -67,8 +66,5 @@ jobs:
- name: Run librechat-data-provider unit tests
run: cd packages/data-provider && npm run test:ci
-- name: Run @librechat/data-schemas unit tests
-run: cd packages/data-schemas && npm run test:ci
-- name: Run @librechat/api unit tests
-run: cd packages/api && npm run test:ci
+- name: Run librechat-mcp unit tests
+run: cd packages/mcp && npm run test:ci

View File

@@ -1,32 +0,0 @@
name: Publish `@librechat/client` to NPM
on:
workflow_dispatch:
inputs:
reason:
description: 'Reason for manual trigger'
required: false
default: 'Manual publish requested'
jobs:
build-and-publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Use Node.js
uses: actions/setup-node@v4
with:
node-version: '18.x'
- name: Check if client package exists
run: |
if [ -d "packages/client" ]; then
echo "Client package directory found"
else
echo "Client package directory not found - workflow ready for future use"
exit 0
fi
- name: Placeholder for future publishing
run: echo "Client package publishing workflow is ready"

View File

@@ -2,7 +2,7 @@ name: Update Test Server
on:
workflow_run:
-workflows: ["Docker Dev Branch Images Build"]
+workflows: ["Docker Dev Images Build"]
types:
- completed
workflow_dispatch:
@@ -12,8 +12,7 @@ jobs:
runs-on: ubuntu-latest
if: |
github.repository == 'danny-avila/LibreChat' &&
-(github.event_name == 'workflow_dispatch' ||
-(github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'dev'))
+(github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success')
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -30,17 +29,13 @@ jobs:
DO_USER: ${{ secrets.DO_USER }}
run: |
ssh -o StrictHostKeyChecking=no ${DO_USER}@${DO_HOST} << EOF
-sudo -i -u danny bash << 'EEOF'
+sudo -i -u danny bash << EEOF
cd ~/LibreChat && \
git fetch origin main && \
-sudo npm run stop:deployed && \
-sudo docker images --format "{{.Repository}}:{{.ID}}" | grep -E "lc-dev|librechat" | cut -d: -f2 | xargs -r sudo docker rmi -f || true && \
-sudo npm run update:deployed && \
-git checkout dev && \
-git pull origin dev && \
+npm run update:deployed && \
git checkout do-deploy && \
-git rebase dev && \
+git rebase main && \
-sudo npm run start:deployed && \
+npm run start:deployed && \
echo "Update completed. Application should be running now."
EEOF
EOF

View File

@@ -1,72 +0,0 @@
name: Docker Dev Branch Images Build
on:
workflow_dispatch:
push:
branches:
- dev
paths:
- 'api/**'
- 'client/**'
- 'packages/**'
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
include:
- target: api-build
file: Dockerfile.multi
image_name: lc-dev-api
- target: node
file: Dockerfile
image_name: lc-dev
steps:
# Check out the repository
- name: Checkout
uses: actions/checkout@v4
# Set up QEMU
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
# Set up Docker Buildx
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Log in to GitHub Container Registry
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# Login to Docker Hub
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Prepare the environment
- name: Prepare environment
run: |
cp .env.example .env
# Build and push Docker images for each target
- name: Build and push Docker images
uses: docker/build-push-action@v5
with:
context: .
file: ${{ matrix.file }}
push: true
tags: |
ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:${{ github.sha }}
ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:latest
${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:${{ github.sha }}
${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:latest
platforms: linux/amd64,linux/arm64
target: ${{ matrix.target }}

View File

@@ -8,7 +8,7 @@ on:
- release/*
paths:
- 'client/**'
-- 'packages/data-provider/**'
+- 'packages/**'
jobs:
tests_frontend_ubuntu:

View File

@@ -26,15 +26,8 @@ jobs:
uses: azure/setup-helm@v4
env:
GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}"
-- name: Build Subchart Deps
-run: |
-cd helm/librechat-rag-api
-helm dependency build
- name: Run chart-releaser
uses: helm/chart-releaser-action@v1.6.0
-with:
-charts_dir: helm
-skip_existing: true
env:
CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}"

View File

@@ -5,13 +5,12 @@ on:
paths:
- "client/src/**"
- "api/**"
-- "packages/data-provider/src/**"
jobs:
detect-unused-i18n-keys:
runs-on: ubuntu-latest
permissions:
-pull-requests: write
+pull-requests: write # Required for posting PR comments
steps:
- name: Checkout repository
uses: actions/checkout@v3
@@ -23,7 +22,7 @@ jobs:
# Define paths
I18N_FILE="client/src/locales/en/translation.json"
-SOURCE_DIRS=("client/src" "api" "packages/data-provider/src")
+SOURCE_DIRS=("client/src" "api")
# Check if translation file exists
if [[ ! -f "$I18N_FILE" ]]; then

View File

@@ -98,8 +98,6 @@ jobs:
cd client
UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "")
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../client_used_deps.txt ../client_used_code.txt | sort) || echo "")
-# Filter out false positives
-UNUSED=$(echo "$UNUSED" | grep -v "^micromark-extension-llm-math$" || echo "")
echo "CLIENT_UNUSED<<EOF" >> $GITHUB_ENV
echo "$UNUSED" >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV

.gitignore
View File

@@ -52,10 +52,8 @@ bower_components/
*.d.ts
!vite-env.d.ts
-# AI
+# Cline
.clineignore
-.cursor
-.aider*
# Floobits
.floo
@@ -115,22 +113,4 @@ uploads/
# owner
release/
-# Helm
-helm/librechat/Chart.lock
-helm/**/charts/
-helm/**/.values.yaml
!/client/src/@types/i18next.d.ts
-# SAML Idp cert
-*.cert
-# AI Assistants
-/.claude/
-/.cursor/
-/.copilot/
-/.aider/
-/.openai/
-/.tabnine/
-/.codeium

View File

@@ -5,47 +5,23 @@ All notable changes to this project will be documented in this file.
## [Unreleased]
### ✨ New Features
- ✨ feat: implement search parameter updates by **@mawburn** in [#7151](https://github.com/danny-avila/LibreChat/pull/7151)
- 🎏 feat: Add MCP support for Streamable HTTP Transport by **@benverhees** in [#7353](https://github.com/danny-avila/LibreChat/pull/7353)
-- 🔒 feat: Add Content Security Policy using Helmet middleware by **@rubentalstra** in [#7377](https://github.com/danny-avila/LibreChat/pull/7377)
-- ✨ feat: Add Normalization for MCP Server Names by **@danny-avila** in [#7421](https://github.com/danny-avila/LibreChat/pull/7421)
-- 📊 feat: Improve Helm Chart by **@hofq** in [#3638](https://github.com/danny-avila/LibreChat/pull/3638)
-- 🦾 feat: Claude-4 Support by **@danny-avila** in [#7509](https://github.com/danny-avila/LibreChat/pull/7509)
-- 🪨 feat: Bedrock Support for Claude-4 Reasoning by **@danny-avila** in [#7517](https://github.com/danny-avila/LibreChat/pull/7517)
-### 🌍 Internationalization
-- 🌍 i18n: Add `Danish` and `Czech` and `Catalan` localization support by **@rubentalstra** in [#7373](https://github.com/danny-avila/LibreChat/pull/7373)
-- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7375](https://github.com/danny-avila/LibreChat/pull/7375)
-- 🌍 i18n: Update translation.json with latest translations by **@github-actions[bot]** in [#7468](https://github.com/danny-avila/LibreChat/pull/7468)
### 🔧 Fixes
- 💬 fix: update aria-label for accessibility in ConvoLink component by **@berry-13** in [#7320](https://github.com/danny-avila/LibreChat/pull/7320)
- 🔑 fix: use `apiKey` instead of `openAIApiKey` in OpenAI-like Config by **@danny-avila** in [#7337](https://github.com/danny-avila/LibreChat/pull/7337)
- 🔄 fix: update navigation logic in `useFocusChatEffect` to ensure correct search parameters are used by **@mawburn** in [#7340](https://github.com/danny-avila/LibreChat/pull/7340)
-- 🔄 fix: Improve MCP Connection Cleanup by **@danny-avila** in [#7400](https://github.com/danny-avila/LibreChat/pull/7400)
-- 🛡️ fix: Preset and Validation Logic for URL Query Params by **@danny-avila** in [#7407](https://github.com/danny-avila/LibreChat/pull/7407)
-- 🌘 fix: artifact of preview text is illegible in dark mode by **@nhtruong** in [#7405](https://github.com/danny-avila/LibreChat/pull/7405)
-- 🛡️ fix: Temporarily Remove CSP until Configurable by **@danny-avila** in [#7419](https://github.com/danny-avila/LibreChat/pull/7419)
-- 💽 fix: Exclude index page `/` from static cache settings by **@sbruel** in [#7382](https://github.com/danny-avila/LibreChat/pull/7382)
### ⚙️ Other Changes
- 📜 docs: CHANGELOG for release v0.7.8 by **@github-actions[bot]** in [#7290](https://github.com/danny-avila/LibreChat/pull/7290)
- 📦 chore: Update API Package Dependencies by **@danny-avila** in [#7359](https://github.com/danny-avila/LibreChat/pull/7359)
-- 📜 docs: Unreleased Changelog by **@github-actions[bot]** in [#7321](https://github.com/danny-avila/LibreChat/pull/7321)
-- 📜 docs: Unreleased Changelog by **@github-actions[bot]** in [#7434](https://github.com/danny-avila/LibreChat/pull/7434)
-- 🛡️ chore: `multer` v2.0.0 for CVE-2025-47935 and CVE-2025-47944 by **@danny-avila** in [#7454](https://github.com/danny-avila/LibreChat/pull/7454)
-- 📂 refactor: Improve `FileAttachment` & File Form Deletion by **@danny-avila** in [#7471](https://github.com/danny-avila/LibreChat/pull/7471)
-- 📊 chore: Remove Old Helm Chart by **@hofq** in [#7512](https://github.com/danny-avila/LibreChat/pull/7512)
-- 🪖 chore: bump helm app version to v0.7.8 by **@austin-barrington** in [#7524](https://github.com/danny-avila/LibreChat/pull/7524)
@@ -91,6 +67,7 @@ Changes from v0.7.8-rc1 to v0.7.8.
---
## [v0.7.8-rc1] -
+## [v0.7.8-rc1] -
Changes from v0.7.7 to v0.7.8-rc1.

View File

@@ -1,4 +1,4 @@
-# v0.7.9
+# v0.7.8
# Base node image
FROM node:20-alpine AS node

View File

@@ -1,5 +1,5 @@
# Dockerfile.multi
-# v0.7.9
+# v0.7.8
# Base for all builds
FROM node:20-alpine AS base-min
@@ -14,7 +14,7 @@ RUN npm config set fetch-retry-maxtimeout 600000 && \
npm config set fetch-retry-mintimeout 15000
COPY package*.json ./
COPY packages/data-provider/package*.json ./packages/data-provider/
-COPY packages/api/package*.json ./packages/api/
+COPY packages/mcp/package*.json ./packages/mcp/
COPY packages/data-schemas/package*.json ./packages/data-schemas/
COPY client/package*.json ./client/
COPY api/package*.json ./api/
@@ -24,27 +24,26 @@ FROM base-min AS base
WORKDIR /app
RUN npm ci
-# Build `data-provider` package
+# Build data-provider
FROM base AS data-provider-build
WORKDIR /app/packages/data-provider
COPY packages/data-provider ./
RUN npm run build
-# Build `data-schemas` package
+# Build mcp package
+FROM base AS mcp-build
+WORKDIR /app/packages/mcp
+COPY packages/mcp ./
+COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
+RUN npm run build
+# Build data-schemas
FROM base AS data-schemas-build
WORKDIR /app/packages/data-schemas
COPY packages/data-schemas ./
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
RUN npm run build
-# Build `api` package
-FROM base AS api-package-build
-WORKDIR /app/packages/api
-COPY packages/api ./
-COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
-COPY --from=data-schemas-build /app/packages/data-schemas/dist /app/packages/data-schemas/dist
-RUN npm run build
# Client build
FROM base AS client-build
WORKDIR /app/client
@@ -64,8 +63,8 @@ RUN npm ci --omit=dev
COPY api ./api
COPY config ./config
COPY --from=data-provider-build /app/packages/data-provider/dist ./packages/data-provider/dist
+COPY --from=mcp-build /app/packages/mcp/dist ./packages/mcp/dist
COPY --from=data-schemas-build /app/packages/data-schemas/dist ./packages/data-schemas/dist
-COPY --from=api-package-build /app/packages/api/dist ./packages/api/dist
COPY --from=client-build /app/client/dist ./client/dist
WORKDIR /app/api
EXPOSE 3080

View File

@@ -52,7 +52,7 @@
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
- 🤖 **AI Model Selection**:
-- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure)
+- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
@@ -66,14 +66,10 @@
- 🔦 **Agents & Tools Integration**:
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
-- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
-- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
+- Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
+- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
+- Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions
-- 🔍 **Web Search**:
-- Search the internet and retrieve relevant information to enhance your AI context
-- Combines search providers, content scrapers, and result rerankers for optimal results
-- **[Learn More →](https://www.librechat.ai/docs/features/web_search)**
- 🪄 **Generative UI with Code Artifacts**:
- [Code Artifacts](https://youtu.be/GfTj7O4gmd0?si=WJbdnemZpJzBrJo3) allow creation of React, HTML, and Mermaid diagrams directly in chat
@@ -149,8 +145,8 @@ Click on the thumbnail to open the video☝
**Other:**
- **Website:** [librechat.ai](https://librechat.ai)
-- **Documentation:** [librechat.ai/docs](https://librechat.ai/docs)
-- **Blog:** [librechat.ai/blog](https://librechat.ai/blog)
+- **Documentation:** [docs.librechat.ai](https://docs.librechat.ai)
+- **Blog:** [blog.librechat.ai](https://blog.librechat.ai)
---

View File

@@ -10,7 +10,6 @@ const {
validateVisionModel,
} = require('librechat-data-provider');
const { SplitStreamHandler: _Handler } = require('@librechat/agents');
-const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api');
const {
truncateText,
formatMessage,
@@ -27,6 +26,8 @@ const {
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+const { createFetch, createStreamEventHandlers } = require('./generators');
+const Tokenizer = require('~/server/services/Tokenizer');
const { sleep } = require('~/server/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
@@ -69,10 +70,13 @@ class AnthropicClient extends BaseClient {
this.message_delta;
/** Whether the model is part of the Claude 3 Family
* @type {boolean} */
-this.isClaudeLatest;
+this.isClaude3;
/** Whether to use Messages API or Completions API
* @type {boolean} */
this.useMessages;
+/** Whether or not the model is limited to the legacy amount of output tokens
+* @type {boolean} */
+this.isLegacyOutput;
/** Whether or not the model supports Prompt Caching
* @type {boolean} */
this.supportsCacheControl;
@@ -112,25 +116,21 @@ class AnthropicClient extends BaseClient {
);
const modelMatch = matchModelName(this.modelOptions.model, EModelEndpoint.anthropic);
-this.isClaudeLatest =
-/claude-[3-9]/.test(modelMatch) || /claude-(?:sonnet|opus|haiku)-[4-9]/.test(modelMatch);
-const isLegacyOutput = !(
-/claude-3[-.]5-sonnet/.test(modelMatch) ||
-/claude-3[-.]7/.test(modelMatch) ||
-/claude-(?:sonnet|opus|haiku)-[4-9]/.test(modelMatch) ||
-/claude-[4-9]/.test(modelMatch)
+this.isClaude3 = modelMatch.includes('claude-3');
+this.isLegacyOutput = !(
+/claude-3[-.]5-sonnet/.test(modelMatch) || /claude-3[-.]7/.test(modelMatch)
);
this.supportsCacheControl = this.options.promptCache && checkPromptCacheSupport(modelMatch);
if (
-isLegacyOutput &&
+this.isLegacyOutput &&
this.modelOptions.maxOutputTokens &&
this.modelOptions.maxOutputTokens > legacy.maxOutputTokens.default
) {
this.modelOptions.maxOutputTokens = legacy.maxOutputTokens.default;
}
-this.useMessages = this.isClaudeLatest || !!this.options.attachments;
+this.useMessages = this.isClaude3 || !!this.options.attachments;
this.defaultVisionModel = this.options.visionModel ?? 'claude-3-sonnet-20240229';
this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));
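To illustrate the difference between the removed `isClaudeLatest` check and the restored `claude-3` substring check, using model names that appear elsewhere in this diff:

```js
const isClaudeLatest = (m) =>
  /claude-[3-9]/.test(m) || /claude-(?:sonnet|opus|haiku)-[4-9]/.test(m);

isClaudeLatest('claude-3-5-haiku-20241022'); // true; includes('claude-3') also matches
isClaudeLatest('claude-sonnet-4-20250514');  // true; includes('claude-3') misses this
isClaudeLatest('claude-2.1');                // false for both checks
```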
@@ -190,11 +190,10 @@ class AnthropicClient extends BaseClient {
reverseProxyUrl: this.options.reverseProxyUrl,
}),
apiKey: this.apiKey,
-fetchOptions: {},
};
if (this.options.proxy) {
-options.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy);
+options.httpAgent = new HttpsProxyAgent(this.options.proxy);
}
if (this.options.reverseProxyUrl) {
@@ -655,10 +654,7 @@ class AnthropicClient extends BaseClient {
);
};
-if (
-/claude-[3-9]/.test(this.modelOptions.model) ||
-/claude-(?:sonnet|opus|haiku)-[4-9]/.test(this.modelOptions.model)
-) {
+if (this.modelOptions.model.includes('claude-3')) {
await buildMessagesPayload();
processTokens();
return {

View File

@@ -13,6 +13,7 @@ const {
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
const { checkBalance } = require('~/models/balanceMethods');
const { truncateToolCallOutputs } = require('./prompts');
+const { addSpaceIfNeeded } = require('~/server/utils');
const { getFiles } = require('~/models/File');
const TextStream = require('./TextStream');
const { logger } = require('~/config');
@@ -108,15 +109,12 @@ class BaseClient {
/**
* Abstract method to record token usage. Subclasses must implement this method.
* If a correction to the token usage is needed, the method should return an object with the corrected token counts.
-* Should only be used if `recordCollectedUsage` was not used instead.
-* @param {string} [model]
* @param {number} promptTokens
* @param {number} completionTokens
* @returns {Promise<void>}
*/
-async recordTokenUsage({ model, promptTokens, completionTokens }) {
+async recordTokenUsage({ promptTokens, completionTokens }) {
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
-model,
promptTokens,
completionTokens,
});
@@ -200,10 +198,6 @@ class BaseClient {
this.currentMessages[this.currentMessages.length - 1].messageId = head;
}
-if (opts.isRegenerate && responseMessageId.endsWith('_')) {
-responseMessageId = crypto.randomUUID();
-}
this.responseMessageId = responseMessageId;
return {
@@ -578,7 +572,7 @@ class BaseClient {
});
}
-const { editedContent } = opts;
+const { generation = '' } = opts;
// It's not necessary to push to currentMessages
// depending on subclass implementation of handling messages
@@ -593,21 +587,11 @@ class BaseClient {
isCreatedByUser: false,
model: this.modelOptions?.model ?? this.model,
sender: this.sender,
+text: generation,
};
this.currentMessages.push(userMessage, latestMessage);
-} else if (editedContent != null) {
-// Handle editedContent for content parts
-if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
-const { index, text, type } = editedContent;
-if (index >= 0 && index < latestMessage.content.length) {
-const contentPart = latestMessage.content[index];
-if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
-contentPart[ContentTypes.THINK] = text;
-} else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
-contentPart[ContentTypes.TEXT] = text;
-}
-}
-}
+} else {
+latestMessage.text = generation;
}
this.continued = true;
} else {
@@ -688,32 +672,16 @@ class BaseClient {
};
if (typeof completion === 'string') {
-responseMessage.text = completion;
+responseMessage.text = addSpaceIfNeeded(generation) + completion;
} else if (
Array.isArray(completion) &&
(this.clientName === EModelEndpoint.agents ||
isParamEndpoint(this.options.endpoint, this.options.endpointType))
) {
responseMessage.text = '';
-if (!opts.editedContent || this.currentMessages.length === 0) {
-responseMessage.content = completion;
-} else {
-const latestMessage = this.currentMessages[this.currentMessages.length - 1];
-if (!latestMessage?.content) {
-responseMessage.content = completion;
-} else {
-const existingContent = [...latestMessage.content];
-const { type: editedType } = opts.editedContent;
-responseMessage.content = this.mergeEditedContent(
-existingContent,
-completion,
-editedType,
-);
-}
-}
+responseMessage.content = completion;
} else if (Array.isArray(completion)) {
-responseMessage.text = completion.join('');
+responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
}
if ( if (
@@ -744,13 +712,9 @@ class BaseClient {
} else {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
completionTokens = responseMessage.tokenCount;
-await this.recordTokenUsage({
-usage,
-promptTokens,
-completionTokens,
-model: responseMessage.model,
-});
}
+await this.recordTokenUsage({ promptTokens, completionTokens, usage });
}
if (userMessagePromise) {
@@ -828,8 +792,7 @@ class BaseClient {
userMessage.tokenCount = userMessageTokenCount;
/*
-Note: `AgentController` saves the user message if not saved here
-(noted by `savedMessageIds`), so we update the count of its `userMessage` reference
+Note: `AskController` saves the user message, so we update the count of its `userMessage` reference
*/
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
@@ -838,8 +801,7 @@ class BaseClient {
}
/*
Note: we update the user message to be sure it gets the calculated token count;
-though `AgentController` saves the user message if not saved here
-(noted by `savedMessageIds`), EditController does not
+though `AskController` saves the user message, EditController does not
*/
await userMessagePromise;
await this.updateMessageInDatabase({
@@ -1131,50 +1093,6 @@ class BaseClient {
return numTokens;
}
-/**
-* Merges completion content with existing content when editing TEXT or THINK types
-* @param {Array} existingContent - The existing content array
-* @param {Array} newCompletion - The new completion content
-* @param {string} editedType - The type of content being edited
-* @returns {Array} The merged content array
-*/
-mergeEditedContent(existingContent, newCompletion, editedType) {
-if (!newCompletion.length) {
-return existingContent.concat(newCompletion);
-}
-if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
-return existingContent.concat(newCompletion);
-}
-const lastIndex = existingContent.length - 1;
-const lastExisting = existingContent[lastIndex];
-const firstNew = newCompletion[0];
-if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
-return existingContent.concat(newCompletion);
-}
-const mergedContent = [...existingContent];
-if (editedType === ContentTypes.TEXT) {
-mergedContent[lastIndex] = {
-...mergedContent[lastIndex],
-[ContentTypes.TEXT]:
-(mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
-};
-} else {
-mergedContent[lastIndex] = {
-...mergedContent[lastIndex],
-[ContentTypes.THINK]:
-(mergedContent[lastIndex][ContentTypes.THINK] || '') +
-(firstNew[ContentTypes.THINK] || ''),
-};
-}
-// Add remaining completion items
-return mergedContent.concat(newCompletion.slice(1));
-}
async sendPayload(payload, opts = {}) {
if (opts && typeof opts === 'object') {
this.setOptions(opts);

View File

@@ -0,0 +1,804 @@
const { Keyv } = require('keyv');
const crypto = require('crypto');
const { CohereClient } = require('cohere-ai');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
ImageDetail,
EModelEndpoint,
resolveHeaders,
CohereConstants,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils');
const { createContextHandlers } = require('./prompts');
const { createCoherePayload } = require('./llm');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const CHATGPT_MODEL = 'gpt-3.5-turbo';
const tokenizersCache = {};
class ChatGPTClient extends BaseClient {
constructor(apiKey, options = {}, cacheOptions = {}) {
super(apiKey, options, cacheOptions);
cacheOptions.namespace = cacheOptions.namespace || 'chatgpt';
this.conversationsCache = new Keyv(cacheOptions);
this.setOptions(options);
}
setOptions(options) {
if (this.options && !this.options.replaceOptions) {
// nested options aren't spread properly, so we need to do this manually
this.options.modelOptions = {
...this.options.modelOptions,
...options.modelOptions,
};
delete options.modelOptions;
// now we can merge options
this.options = {
...this.options,
...options,
};
} else {
this.options = options;
}
if (this.options.openaiApiKey) {
this.apiKey = this.options.openaiApiKey;
}
const modelOptions = this.options.modelOptions || {};
this.modelOptions = {
...modelOptions,
// set some good defaults (check for undefined in some cases because they may be 0)
model: modelOptions.model || CHATGPT_MODEL,
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
presence_penalty:
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
stop: modelOptions.stop,
};
this.isChatGptModel = this.modelOptions.model.includes('gpt-');
const { isChatGptModel } = this;
this.isUnofficialChatGptModel =
this.modelOptions.model.startsWith('text-chat') ||
this.modelOptions.model.startsWith('text-davinci-002-render');
const { isUnofficialChatGptModel } = this;
// Davinci models have a max context length of 4097 tokens.
this.maxContextTokens = this.options.maxContextTokens || (isChatGptModel ? 4095 : 4097);
// I decided to reserve 1024 tokens for the response.
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
// Earlier messages will be dropped until the prompt is within the limit.
this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
throw new Error(
`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
);
}
this.userLabel = this.options.userLabel || 'User';
this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
if (isChatGptModel) {
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
// without tripping the stop sequences, so I'm using "||>" instead.
this.startToken = '||>';
this.endToken = '';
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
} else if (isUnofficialChatGptModel) {
this.startToken = '<|im_start|>';
this.endToken = '<|im_end|>';
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
});
} else {
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
// as a single token. So we're using this instead.
this.startToken = '||>';
this.endToken = '';
try {
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
} catch {
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
}
}
if (!this.modelOptions.stop) {
const stopTokens = [this.startToken];
if (this.endToken && this.endToken !== this.startToken) {
stopTokens.push(this.endToken);
}
stopTokens.push(`\n${this.userLabel}:`);
stopTokens.push('<|diff_marker|>');
// I chose not to do one for `chatGptLabel` because I've never seen it happen
this.modelOptions.stop = stopTokens;
}
if (this.options.reverseProxyUrl) {
this.completionsUrl = this.options.reverseProxyUrl;
} else if (isChatGptModel) {
this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
} else {
this.completionsUrl = 'https://api.openai.com/v1/completions';
}
return this;
}
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
if (tokenizersCache[encoding]) {
return tokenizersCache[encoding];
}
let tokenizer;
if (isModelName) {
tokenizer = encodingForModel(encoding, extendSpecialTokens);
} else {
tokenizer = getEncoding(encoding, extendSpecialTokens);
}
tokenizersCache[encoding] = tokenizer;
return tokenizer;
}
/** @type {getCompletion} */
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
let modelOptions = { ...this.modelOptions };
if (typeof onProgress === 'function') {
modelOptions.stream = true;
}
if (this.isChatGptModel) {
modelOptions.messages = input;
} else {
modelOptions.prompt = input;
}
if (this.useOpenRouter && modelOptions.prompt) {
delete modelOptions.stop;
}
const { debug } = this.options;
let baseURL = this.completionsUrl;
if (debug) {
console.debug();
console.debug(baseURL);
console.debug(modelOptions);
console.debug();
}
const opts = {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
};
if (this.isVisionModel) {
modelOptions.max_tokens = 4000;
}
/** @type {TAzureConfig | undefined} */
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
const isAzure = this.azure || this.options.azure;
if (
(isAzure && this.isVisionModel && azureConfig) ||
(azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
) {
const { modelGroupMap, groupMap } = azureConfig;
const {
azureOptions,
baseURL,
headers = {},
serverless,
} = mapModelToAzureConfig({
modelName: modelOptions.model,
modelGroupMap,
groupMap,
});
opts.headers = resolveHeaders(headers);
this.langchainProxy = extractBaseURL(baseURL);
this.apiKey = azureOptions.azureOpenAIApiKey;
const groupName = modelGroupMap[modelOptions.model].group;
this.options.addParams = azureConfig.groupMap[groupName].addParams;
this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
// Note: `forcePrompt` not re-assigned as only chat models are vision models
this.azure = !serverless && azureOptions;
this.azureEndpoint =
!serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
if (serverless === true) {
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
this.options.headers['api-key'] = this.apiKey;
}
}
if (this.options.defaultQuery) {
opts.defaultQuery = this.options.defaultQuery;
}
if (this.options.headers) {
opts.headers = { ...opts.headers, ...this.options.headers };
}
if (isAzure) {
// Azure does not accept `model` in the body, so we need to remove it.
delete modelOptions.model;
baseURL = this.langchainProxy
? constructAzureURL({
baseURL: this.langchainProxy,
azureOptions: this.azure,
})
: this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
if (this.options.forcePrompt) {
baseURL += '/completions';
} else {
baseURL += '/chat/completions';
}
opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
opts.headers = { ...opts.headers, 'api-key': this.apiKey };
} else if (this.apiKey) {
opts.headers.Authorization = `Bearer ${this.apiKey}`;
}
if (process.env.OPENAI_ORGANIZATION) {
opts.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION;
}
if (this.useOpenRouter) {
opts.headers['HTTP-Referer'] = 'https://librechat.ai';
opts.headers['X-Title'] = 'LibreChat';
}
/* hacky fixes for Mistral AI API:
- Re-orders system message to the top of the messages payload, as not allowed anywhere else
- If there is only one message and it's a system message, change the role to user
*/
if (baseURL.includes('https://api.mistral.ai/v1') && modelOptions.messages) {
const { messages } = modelOptions;
const systemMessageIndex = messages.findIndex((msg) => msg.role === 'system');
if (systemMessageIndex > 0) {
const [systemMessage] = messages.splice(systemMessageIndex, 1);
messages.unshift(systemMessage);
}
modelOptions.messages = messages;
if (messages.length === 1 && messages[0].role === 'system') {
modelOptions.messages[0].role = 'user';
}
}
if (this.options.addParams && typeof this.options.addParams === 'object') {
modelOptions = {
...modelOptions,
...this.options.addParams,
};
logger.debug('[ChatGPTClient] chatCompletion: added params', {
addParams: this.options.addParams,
modelOptions,
});
}
if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
this.options.dropParams.forEach((param) => {
delete modelOptions[param];
});
logger.debug('[ChatGPTClient] chatCompletion: dropped params', {
dropParams: this.options.dropParams,
modelOptions,
});
}
if (baseURL.startsWith(CohereConstants.API_URL)) {
const payload = createCoherePayload({ modelOptions });
return await this.cohereChatCompletion({ payload, onTokenProgress });
}
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
baseURL = baseURL.split('v1')[0] + 'v1/completions';
} else if (
baseURL.includes('v1') &&
!baseURL.includes('/chat/completions') &&
this.isChatCompletion
) {
baseURL = baseURL.split('v1')[0] + 'v1/chat/completions';
}
const BASE_URL = new URL(baseURL);
if (opts.defaultQuery) {
Object.entries(opts.defaultQuery).forEach(([key, value]) => {
BASE_URL.searchParams.append(key, value);
});
delete opts.defaultQuery;
}
const completionsURL = BASE_URL.toString();
opts.body = JSON.stringify(modelOptions);
if (modelOptions.stream) {
return new Promise(async (resolve, reject) => {
try {
let done = false;
await fetchEventSource(completionsURL, {
...opts,
signal: abortController.signal,
async onopen(response) {
if (response.status === 200) {
return;
}
if (debug) {
console.debug(response);
}
let error;
try {
const body = await response.text();
error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
error.json = JSON.parse(body);
} catch {
error = error || new Error(`Failed to send message. HTTP ${response.status}`);
}
throw error;
},
onclose() {
if (debug) {
console.debug('Server closed the connection unexpectedly, returning...');
}
// workaround for private API not sending [DONE] event
if (!done) {
onProgress('[DONE]');
resolve();
}
},
onerror(err) {
if (debug) {
console.debug(err);
}
// rethrow to stop the operation
throw err;
},
onmessage(message) {
if (debug) {
console.debug(message);
}
if (!message.data || message.event === 'ping') {
return;
}
if (message.data === '[DONE]') {
onProgress('[DONE]');
resolve();
done = true;
return;
}
onProgress(JSON.parse(message.data));
},
});
} catch (err) {
reject(err);
}
});
}
const response = await fetch(completionsURL, {
...opts,
signal: abortController.signal,
});
if (response.status !== 200) {
const body = await response.text();
const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
try {
error.json = JSON.parse(body);
} catch {
error.body = body;
}
throw error;
}
return response.json();
}
/** @type {cohereChatCompletion} */
async cohereChatCompletion({ payload, onTokenProgress }) {
const cohere = new CohereClient({
token: this.apiKey,
environment: this.completionsUrl,
});
if (!payload.stream) {
const chatResponse = await cohere.chat(payload);
return chatResponse.text;
}
const chatStream = await cohere.chatStream(payload);
let reply = '';
for await (const message of chatStream) {
if (!message) {
continue;
}
if (message.eventType === 'text-generation' && message.text) {
onTokenProgress(message.text);
reply += message.text;
}
/*
Cohere API Chinese Unicode character replacement hotfix.
Should be un-commented when the following issue is resolved:
https://github.com/cohere-ai/cohere-typescript/issues/151
else if (message.eventType === 'stream-end' && message.response) {
reply = message.response.text;
}
*/
}
return reply;
}
async generateTitle(userMessage, botMessage) {
const instructionsPayload = {
role: 'system',
content: `Write an extremely concise subtitle for this conversation with no more than a few words. All words should be capitalized. Exclude punctuation.
||>Message:
${userMessage.message}
||>Response:
${botMessage.message}
||>Title:`,
};
const titleGenClientOptions = JSON.parse(JSON.stringify(this.options));
titleGenClientOptions.modelOptions = {
model: 'gpt-3.5-turbo',
temperature: 0,
presence_penalty: 0,
frequency_penalty: 0,
};
const titleGenClient = new ChatGPTClient(this.apiKey, titleGenClientOptions);
const result = await titleGenClient.getCompletion([instructionsPayload], null);
// Strip everything except letters, digits, apostrophes, and spaces; collapse repeated whitespace; then trim.
return result.choices[0].message.content
.replace(/[^a-zA-Z0-9' ]/g, '')
.replace(/\s+/g, ' ')
.trim();
}
async sendMessage(message, opts = {}) {
if (opts.clientOptions && typeof opts.clientOptions === 'object') {
this.setOptions(opts.clientOptions);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || crypto.randomUUID();
let conversation =
typeof opts.conversation === 'object'
? opts.conversation
: await this.conversationsCache.get(conversationId);
let isNewConversation = false;
if (!conversation) {
conversation = {
messages: [],
createdAt: Date.now(),
};
isNewConversation = true;
}
const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation;
const userMessage = {
id: crypto.randomUUID(),
parentMessageId,
role: 'User',
message,
};
conversation.messages.push(userMessage);
// Doing it this way instead of having each message be a separate element in the array seems to be more reliable,
// especially when it comes to keeping the AI in character. It also seems to improve coherency and context retention.
const { prompt: payload, context } = await this.buildPrompt(conversation.messages, {
  isChatGptModel: this.isChatGptModel,
  promptPrefix: opts.promptPrefix,
});
if (this.options.keepNecessaryMessagesOnly) {
conversation.messages = context;
}
let reply = '';
let result = null;
if (typeof opts.onProgress === 'function') {
await this.getCompletion(
payload,
(progressMessage) => {
if (progressMessage === '[DONE]') {
return;
}
const token = this.isChatGptModel
? progressMessage.choices[0].delta.content
: progressMessage.choices[0].text;
// first event's delta content is always undefined
if (!token) {
return;
}
if (this.options.debug) {
console.debug(token);
}
if (token === this.endToken) {
return;
}
opts.onProgress(token);
reply += token;
},
opts.abortController || new AbortController(),
);
} else {
result = await this.getCompletion(
payload,
null,
opts.abortController || new AbortController(),
);
if (this.options.debug) {
console.debug(JSON.stringify(result));
}
if (this.isChatGptModel) {
reply = result.choices[0].message.content;
} else {
reply = result.choices[0].text.replace(this.endToken, '');
}
}
// avoids some rendering issues when using the CLI app
if (this.options.debug) {
console.debug();
}
reply = reply.trim();
const replyMessage = {
id: crypto.randomUUID(),
parentMessageId: userMessage.id,
role: 'ChatGPT',
message: reply,
};
conversation.messages.push(replyMessage);
const returnData = {
response: replyMessage.message,
conversationId,
parentMessageId: replyMessage.parentMessageId,
messageId: replyMessage.id,
details: result || {},
};
if (shouldGenerateTitle) {
conversation.title = await this.generateTitle(userMessage, replyMessage);
returnData.title = conversation.title;
}
await this.conversationsCache.set(conversationId, conversation);
if (this.options.returnConversation) {
returnData.conversation = conversation;
}
return returnData;
}
async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null } = {}) {
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
// Handle attachments and create augmentedPrompt
if (this.options.attachments) {
const attachments = await this.options.attachments;
const lastMessage = messages[messages.length - 1];
if (this.message_file_map) {
this.message_file_map[lastMessage.messageId] = attachments;
} else {
this.message_file_map = {
[lastMessage.messageId]: attachments,
};
}
const files = await this.addImageURLs(lastMessage, attachments);
this.options.attachments = files;
this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text);
}
if (this.message_file_map) {
this.contextHandlers = createContextHandlers(
this.options.req,
messages[messages.length - 1].text,
);
}
// Calculate image token cost and process embedded files
messages.forEach((message, i) => {
if (this.message_file_map && this.message_file_map[message.messageId]) {
const attachments = this.message_file_map[message.messageId];
for (const file of attachments) {
if (file.embedded) {
this.contextHandlers?.processFile(file);
continue;
}
messages[i].tokenCount =
(messages[i].tokenCount || 0) +
this.calculateImageTokenCost({
width: file.width,
height: file.height,
detail: this.options.imageDetail ?? ImageDetail.auto,
});
}
}
});
if (this.contextHandlers) {
this.augmentedPrompt = await this.contextHandlers.createContext();
promptPrefix = this.augmentedPrompt + promptPrefix;
}
if (promptPrefix) {
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
}
const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond.
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
let currentTokenCount;
if (isChatGptModel) {
currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
} else {
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
}
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
const context = [];
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && messages.length > 0) {
const message = messages.pop();
const roleLabel =
message?.isCreatedByUser || message?.role?.toLowerCase() === 'user'
? this.userLabel
: this.chatGptLabel;
const messageString = `${this.startToken}${roleLabel}:\n${
message?.text ?? message?.message
}${this.endToken}\n`;
let newPromptBody;
if (promptBody || isChatGptModel) {
newPromptBody = `${messageString}${promptBody}`;
} else {
// Always insert prompt prefix before the last user message, if not gpt-3.5-turbo.
// This makes the AI obey the prompt instructions better, which is important for custom instructions.
// After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things
// like "what's the last thing I wrote?".
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
}
context.unshift(message);
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setImmediate(resolve));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = `${promptBody}${promptSuffix}`;
if (isChatGptModel) {
messagePayload.content = prompt;
// Add 3 tokens for Assistant Label priming after all messages have been counted.
currentTokenCount += 3;
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (isChatGptModel) {
return { prompt: [instructionsPayload, messagePayload], context };
}
return { prompt, context, promptTokens: currentTokenCount };
}
getTokenCount(text) {
return this.gptEncoder.encode(text, 'all').length;
}
/**
* Algorithm adapted from "6. Counting tokens for chat API calls" of
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
*
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
*
* @param {Object} message
*/
getTokenCountForMessage(message) {
// Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
let tokensPerMessage = 3;
let tokensPerName = 1;
if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
tokensPerMessage = 4;
tokensPerName = -1;
}
let numTokens = tokensPerMessage;
for (let [key, value] of Object.entries(message)) {
numTokens += this.getTokenCount(value);
if (key === 'name') {
numTokens += tokensPerName;
}
}
return numTokens;
}
}
module.exports = ChatGPTClient;
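
The `getTokenCountForMessage` method above follows the counting convention from the OpenAI cookbook: a fixed per-message overhead, a per-name adjustment, and a one-time 3-token addition for assistant-label priming. A minimal standalone sketch of the same arithmetic (the `encode` function here is a whitespace stand-in for illustration only, not the client's actual `gptEncoder`):

// Stand-in tokenizer; a real implementation would use tiktoken's encoder.
const encode = (text) => (text ? String(text).split(/\s+/) : []);

function countChatTokens(messages, { tokensPerMessage = 3, tokensPerName = 1 } = {}) {
  let numTokens = 0;
  for (const message of messages) {
    numTokens += tokensPerMessage;
    for (const [key, value] of Object.entries(message)) {
      numTokens += encode(value).length;
      if (key === 'name') {
        numTokens += tokensPerName;
      }
    }
  }
  return numTokens + 3; // assistant label priming
}

// Example: countChatTokens([{ role: 'user', content: 'Hi mom!' }]) -> 9 with the stand-in tokenizer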

View File

@@ -1,7 +1,6 @@
 const { google } = require('googleapis');
 const { concat } = require('@langchain/core/utils/stream');
 const { ChatVertexAI } = require('@langchain/google-vertexai');
-const { Tokenizer, getSafetySettings } = require('@librechat/api');
 const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
 const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
 const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
@@ -12,14 +11,15 @@ const {
   endpointSettings,
   parseTextParts,
   EModelEndpoint,
-  googleSettings,
   ContentTypes,
   VisionModes,
   ErrorTypes,
   Constants,
   AuthKeys,
 } = require('librechat-data-provider');
+const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
 const { encodeAndFormat } = require('~/server/services/Files/images');
+const Tokenizer = require('~/server/services/Tokenizer');
 const { spendTokens } = require('~/models/spendTokens');
 const { getModelMaxTokens } = require('~/utils');
 const { sleep } = require('~/server/utils');
@@ -34,8 +34,7 @@ const BaseClient = require('./BaseClient');
 const loc = process.env.GOOGLE_LOC || 'us-central1';
 const publisher = 'google';
-const endpointPrefix =
-  loc === 'global' ? 'aiplatform.googleapis.com' : `${loc}-aiplatform.googleapis.com`;
+const endpointPrefix = `${loc}-aiplatform.googleapis.com`;
 const settings = endpointSettings[EModelEndpoint.google];
 const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
@@ -166,16 +165,6 @@ class GoogleClient extends BaseClient {
     );
   }
-    // Add thinking configuration
-    this.modelOptions.thinkingConfig = {
-      thinkingBudget:
-        (this.modelOptions.thinking ?? googleSettings.thinking.default)
-          ? this.modelOptions.thinkingBudget
-          : 0,
-    };
-    delete this.modelOptions.thinking;
-    delete this.modelOptions.thinkingBudget;
   this.sender =
     this.options.sender ??
     getResponseSender({
@@ -247,11 +236,11 @@ class GoogleClient extends BaseClient {
   msg.content = (
     !Array.isArray(msg.content)
       ? [
           {
             type: ContentTypes.TEXT,
             [ContentTypes.TEXT]: msg.content,
           },
         ]
       : msg.content
   ).concat(message.image_urls);

View File

@@ -1,11 +1,10 @@
 const { z } = require('zod');
 const axios = require('axios');
 const { Ollama } = require('ollama');
-const { sleep } = require('@librechat/agents');
-const { logAxiosError } = require('@librechat/api');
-const { logger } = require('@librechat/data-schemas');
 const { Constants } = require('librechat-data-provider');
-const { deriveBaseURL } = require('~/utils');
+const { deriveBaseURL, logAxiosError } = require('~/utils');
+const { sleep } = require('~/server/utils');
+const { logger } = require('~/config');

 const ollamaPayloadSchema = z.object({
   mirostat: z.number().optional(),
@@ -68,7 +67,7 @@ class OllamaClient {
       return models;
     } catch (error) {
       const logMessage =
-        "Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn't start with `ollama` (case-insensitive).";
+        'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).';
       logAxiosError({ message: logMessage, error });
       return [];
} }

View File

@@ -1,21 +1,13 @@
 const { OllamaClient } = require('./OllamaClient');
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const { SplitStreamHandler, CustomOpenAIClient: OpenAI } = require('@librechat/agents');
-const {
-  isEnabled,
-  Tokenizer,
-  createFetch,
-  resolveHeaders,
-  constructAzureURL,
-  genAzureChatCompletion,
-  createStreamEventHandlers,
-} = require('@librechat/api');
 const {
   Constants,
   ImageDetail,
   ContentTypes,
   parseTextParts,
   EModelEndpoint,
+  resolveHeaders,
   KnownEndpoints,
   openAISettings,
   ImageDetailCost,
@@ -24,6 +16,13 @@ const {
   validateVisionModel,
   mapModelToAzureConfig,
 } = require('librechat-data-provider');
+const {
+  extractBaseURL,
+  constructAzureURL,
+  getModelMaxTokens,
+  genAzureChatCompletion,
+  getModelMaxOutputTokens,
+} = require('~/utils');
 const {
   truncateText,
   formatMessage,
@@ -31,12 +30,14 @@ const {
   titleInstruction,
   createContextHandlers,
 } = require('./prompts');
-const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils');
 const { encodeAndFormat } = require('~/server/services/Files/images/encode');
-const { addSpaceIfNeeded, sleep } = require('~/server/utils');
+const { createFetch, createStreamEventHandlers } = require('./generators');
+const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils');
+const Tokenizer = require('~/server/services/Tokenizer');
 const { spendTokens } = require('~/models/spendTokens');
 const { handleOpenAIErrors } = require('./tools/util');
 const { createLLM, RunManager } = require('./llm');
+const ChatGPTClient = require('./ChatGPTClient');
 const { summaryBuffer } = require('./memory');
 const { runTitleChain } = require('./chains');
 const { tokenSplit } = require('./document');
@@ -46,6 +47,12 @@ const { logger } = require('~/config');
 class OpenAIClient extends BaseClient {
   constructor(apiKey, options = {}) {
     super(apiKey, options);
+    this.ChatGPTClient = new ChatGPTClient();
+    this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
+    /** @type {getCompletion} */
+    this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
+    /** @type {cohereChatCompletion} */
+    this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
     this.contextStrategy = options.contextStrategy
       ? options.contextStrategy.toLowerCase()
       : 'discard';
@@ -372,12 +379,23 @@ class OpenAIClient extends BaseClient {
     return files;
   }

-  async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) {
+  async buildMessages(
+    messages,
+    parentMessageId,
+    { isChatCompletion = false, promptPrefix = null },
+    opts,
+  ) {
     let orderedMessages = this.constructor.getMessagesForConversation({
       messages,
       parentMessageId,
       summary: this.shouldSummarize,
     });
+    if (!isChatCompletion) {
+      return await this.buildPrompt(orderedMessages, {
+        isChatGptModel: isChatCompletion,
+        promptPrefix,
+      });
+    }

     let payload;
     let instructions;
@@ -1141,7 +1159,6 @@ ${convo}
     logger.debug('[OpenAIClient] chatCompletion', { baseURL, modelOptions });
     const opts = {
       baseURL,
-      fetchOptions: {},
     };

     if (this.useOpenRouter) {
@@ -1160,7 +1177,7 @@ ${convo}
     }

     if (this.options.proxy) {
-      opts.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy);
+      opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
     }

     /** @type {TAzureConfig | undefined} */
@@ -1378,7 +1395,7 @@ ${convo}
       ...modelOptions,
       stream: true,
     };
-    const stream = await openai.chat.completions
+    const stream = await openai.beta.chat.completions
       .stream(params)
       .on('abort', () => {
         /* Do nothing here */

View File

@@ -0,0 +1,542 @@
const OpenAIClient = require('./OpenAIClient');
const { CallbackManager } = require('@langchain/core/callbacks/manager');
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
const { processFileURL } = require('~/server/services/Files/process');
const { EModelEndpoint } = require('librechat-data-provider');
const { checkBalance } = require('~/models/balanceMethods');
const { formatLangChainMessages } = require('./prompts');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
const { logger } = require('~/config');
class PluginsClient extends OpenAIClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
this.sender = options.sender ?? 'Assistant';
this.tools = [];
this.actions = [];
this.setOptions(options);
this.openAIApiKey = this.apiKey;
this.executor = null;
}
setOptions(options) {
this.agentOptions = { ...options.agentOptions };
this.functionsAgent = this.agentOptions?.agent === 'functions';
this.agentIsGpt3 = this.agentOptions?.model?.includes('gpt-3');
super.setOptions(options);
this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');
if (this.options.reverseProxyUrl) {
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
}
}
getSaveOptions() {
return {
artifacts: this.options.artifacts,
chatGptLabel: this.options.chatGptLabel,
modelLabel: this.options.modelLabel,
promptPrefix: this.options.promptPrefix,
tools: this.options.tools,
...this.modelOptions,
agentOptions: this.agentOptions,
iconURL: this.options.iconURL,
greeting: this.options.greeting,
spec: this.options.spec,
};
}
saveLatestAction(action) {
this.actions.push(action);
}
/** Resolves a model name for function calling: keeps exact names that carry a
 *  4-digit version suffix (other than -0314); otherwise falls back to a base model. */
getFunctionModelName(input) {
  if (/-(?!0314)\d{4}/.test(input)) {
    return input;
  } else if (input.includes('gpt-3.5-turbo')) {
    return 'gpt-3.5-turbo';
  } else if (input.includes('gpt-4')) {
    return 'gpt-4';
  } else {
    return 'gpt-3.5-turbo';
  }
}
getBuildMessagesOptions(opts) {
return {
isChatCompletion: true,
promptPrefix: opts.promptPrefix,
abortController: opts.abortController,
};
}
async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
const modelOptions = {
modelName: this.agentOptions.model,
temperature: this.agentOptions.temperature,
};
const model = this.initializeLLM({
...modelOptions,
context: 'plugins',
initialMessageCount: this.currentMessages.length + 1,
});
logger.debug(
`[PluginsClient] Agent Model: ${model.modelName} | Temp: ${model.temperature} | Functions: ${this.functionsAgent}`,
);
// Map Messages to Langchain format
const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), {
userName: this.options?.name,
});
logger.debug('[PluginsClient] pastMessages: ' + pastMessages.length);
// TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS)
const memory = new BufferMemory({
llm: model,
chatHistory: new ChatMessageHistory(pastMessages),
});
const { loadedTools } = await loadTools({
user,
model,
tools: this.options.tools,
functions: this.functionsAgent,
options: {
memory,
signal: this.abortController.signal,
openAIApiKey: this.openAIApiKey,
conversationId: this.conversationId,
fileStrategy: this.options.req.app.locals.fileStrategy,
processFileURL,
message,
},
useSpecs: true,
});
if (loadedTools.length === 0) {
return;
}
this.tools = loadedTools;
logger.debug('[PluginsClient] Requested Tools', this.options.tools);
logger.debug(
'[PluginsClient] Loaded Tools',
this.tools.map((tool) => tool.name),
);
const handleAction = (action, runId, callback = null) => {
this.saveLatestAction(action);
logger.debug('[PluginsClient] Latest Agent Action ', this.actions[this.actions.length - 1]);
if (typeof callback === 'function') {
callback(action, runId);
}
};
// initialize agent
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
let customInstructions = (this.options.promptPrefix ?? '').trim();
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
customInstructions = `${customInstructions ?? ''}\n${this.options.artifactsPrompt}`.trim();
}
this.executor = await initializer({
model,
signal,
pastMessages,
tools: this.tools,
customInstructions,
verbose: this.options.debug,
returnIntermediateSteps: true,
customName: this.options.chatGptLabel,
currentDateString: this.currentDateString,
callbackManager: CallbackManager.fromHandlers({
async handleAgentAction(action, runId) {
handleAction(action, runId, onAgentAction);
},
async handleChainEnd(action) {
if (typeof onChainEnd === 'function') {
onChainEnd(action);
}
},
}),
});
logger.debug('[PluginsClient] Loaded agent.');
}
async executorCall(message, { signal, stream, onToolStart, onToolEnd }) {
let errorMessage = '';
const maxAttempts = 1;
for (let attempts = 1; attempts <= maxAttempts; attempts++) {
const errorInput = buildErrorInput({
message,
errorMessage,
actions: this.actions,
functionsAgent: this.functionsAgent,
});
const input = attempts > 1 ? errorInput : message;
logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`);
if (errorMessage.length > 0) {
logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));
}
try {
this.result = await this.executor.call({ input, signal }, [
{
async handleToolStart(...args) {
await onToolStart(...args);
},
async handleToolEnd(...args) {
await onToolEnd(...args);
},
async handleLLMEnd(output) {
const { generations } = output;
const { text } = generations[0][0];
if (text && typeof stream === 'function') {
await stream(text);
}
},
},
]);
break; // Exit the loop if the function call is successful
} catch (err) {
logger.error('[PluginsClient] executorCall error:', err);
if (attempts === maxAttempts) {
const { run } = this.runManager.getRunByConversationId(this.conversationId);
const defaultOutput = `Encountered an error while attempting to respond: ${err.message}`;
this.result.output = run && run.error ? run.error : defaultOutput;
this.result.errorMessage = run && run.error ? run.error : err.message;
this.result.intermediateSteps = this.actions;
break;
}
}
}
}
/**
*
* @param {TMessage} responseMessage
* @param {Partial<TMessage>} saveOptions
* @param {string} user
* @returns
*/
async handleResponseMessage(responseMessage, saveOptions, user) {
const { output, errorMessage, ...result } = this.result;
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
output,
errorMessage,
...result,
});
const { error } = responseMessage;
if (!error) {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
responseMessage.completionTokens = this.getTokenCount(responseMessage.text);
}
// Only record usage here when the completion phase runs; in skipCompletion mode it was already recorded during the agent phase.
if (!this.agentOptions.skipCompletion && !error) {
await this.recordTokenUsage(responseMessage);
}
const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return { ...responseMessage, ...result, databasePromise };
}
async sendMessage(message, opts = {}) {
/** @type {Promise<TMessage>} */
let userMessagePromise;
/** @type {{ filteredTools: string[], includedTools: string[] }} */
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
if (includedTools.length > 0) {
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
this.options.tools = tools;
} else {
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
this.options.tools = tools;
}
// If a message is edited, no tools can be used.
const completionMode = this.options.tools.length === 0 || opts.isEdited;
if (completionMode) {
this.setOptions(opts);
return super.sendMessage(message, opts);
}
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
const {
user,
conversationId,
responseMessageId,
saveOptions,
userMessage,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
} = await this.handleStartMethods(message, opts);
if (opts.progressCallback) {
opts.onProgress = opts.progressCallback.call(null, {
...(opts.progressOptions ?? {}),
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
}
this.currentMessages.push(userMessage);
let {
prompt: payload,
tokenCountMap,
promptTokens,
} = await this.buildMessages(
this.currentMessages,
userMessage.messageId,
this.getBuildMessagesOptions({
promptPrefix: null,
abortController: this.abortController,
}),
);
if (tokenCountMap) {
logger.debug('[PluginsClient] tokenCountMap', { tokenCountMap });
if (tokenCountMap[userMessage.messageId]) {
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
logger.debug('[PluginsClient] userMessage.tokenCount', userMessage.tokenCount);
}
this.handleTokenCountMap(tokenCountMap);
}
this.result = {};
if (payload) {
this.currentMessages = payload;
}
if (!this.skipSaveUserMessage) {
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise,
});
}
}
const balance = this.options.req?.app?.locals?.balance;
if (balance?.enabled) {
await checkBalance({
req: this.options.req,
res: this.options.res,
txData: {
user: this.user,
tokenType: 'prompt',
amount: promptTokens,
debug: this.options.debug,
model: this.modelOptions.model,
endpoint: EModelEndpoint.openAI,
},
});
}
const responseMessage = {
endpoint: EModelEndpoint.gptPlugins,
iconURL: this.options.iconURL,
messageId: responseMessageId,
conversationId,
parentMessageId: userMessage.messageId,
isCreatedByUser: false,
model: this.modelOptions.model,
sender: this.sender,
promptTokens,
};
await this.initialize({
user,
message,
onAgentAction,
onChainEnd,
signal: this.abortController.signal,
onProgress: opts.onProgress,
});
// const stream = async (text) => {
// await this.generateTextStream.call(this, text, opts.onProgress, { delay: 1 });
// };
await this.executorCall(message, {
signal: this.abortController.signal,
// stream,
onToolStart,
onToolEnd,
});
// If message was aborted mid-generation
if (this.result?.errorMessage?.length > 0 && this.result?.errorMessage?.includes('cancel')) {
responseMessage.text = 'Cancelled.';
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
// If error occurred during generation (likely token_balance)
if (this.result?.errorMessage?.length > 0) {
responseMessage.error = true;
responseMessage.text = this.result.output;
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) {
const partialText = opts.getPartialText();
const trimmedPartial = partialText.replaceAll(':::plugin:::\n', '');
// If only plugin markers were streamed, append the agent output to the partial text.
responseMessage.text =
trimmedPartial.length === 0 ? `${partialText}${this.result.output}` : partialText;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output) {
responseMessage.text = this.result.output;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
logger.debug('[PluginsClient] Completion phase: this.result', this.result);
const promptPrefix = buildPromptPrefix({
result: this.result,
message,
functionsAgent: this.functionsAgent,
});
logger.debug('[PluginsClient]', { promptPrefix });
payload = await this.buildCompletionPrompt({
messages: this.currentMessages,
promptPrefix,
});
logger.debug('[PluginsClient] buildCompletionPrompt Payload', payload);
responseMessage.text = await this.sendCompletion(payload, opts);
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
async buildCompletionPrompt({ messages, promptPrefix: _promptPrefix }) {
logger.debug('[PluginsClient] buildCompletionPrompt messages', messages);
const orderedMessages = messages;
let promptPrefix = _promptPrefix.trim();
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
const promptSuffix = `${this.startToken}${this.chatGptLabel ?? 'Assistant'}:\n`;
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
if (this.isGpt3) {
instructionsPayload.role = 'user';
messagePayload.role = 'user';
instructionsPayload.content += `\n${promptSuffix}`;
}
// testing if this works with browser endpoint
if (!this.isGpt3 && this.options.reverseProxyUrl) {
instructionsPayload.role = 'user';
}
let currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
const message = orderedMessages.pop();
const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
let messageString = `${this.startToken}${roleLabel}:\n${
message.text ?? message.content ?? ''
}${this.endToken}\n`;
let newPromptBody = `${messageString}${promptBody}`;
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setTimeout(resolve, 0));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = promptBody;
messagePayload.content = prompt;
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;
if (this.isGpt3 && messagePayload.content.length > 0) {
const context = 'Chat History:\n';
messagePayload.content = `${context}${prompt}`;
currentTokenCount += this.getTokenCount(context);
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (this.isGpt3) {
messagePayload.content += promptSuffix;
return [instructionsPayload, messagePayload];
}
const result = [messagePayload, instructionsPayload];
if (this.functionsAgent && !this.isGpt3) {
result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`;
}
return result.filter((message) => message.content.length > 0);
}
}
module.exports = PluginsClient;
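
Both `buildCompletionPrompt` above and `ChatGPTClient.buildPrompt` use the same backward walk: the newest messages are added first until the token budget is exhausted, and a single over-budget message is an error. A compact iterative restatement of that idea, assuming a caller-supplied `countTokens` rather than the class tokenizer (the class versions additionally yield to the event loop between iterations):

function fitMessagesToBudget(messageStrings, maxTokenCount, countTokens) {
  let promptBody = '';
  let currentTokenCount = 0;
  // Walk backwards so the most recent messages survive when the budget runs out.
  for (let i = messageStrings.length - 1; i >= 0; i--) {
    const tokenCountForMessage = countTokens(messageStrings[i]);
    const newTokenCount = currentTokenCount + tokenCountForMessage;
    if (newTokenCount > maxTokenCount) {
      if (promptBody) {
        break; // This message would exceed the budget; stop here.
      }
      throw new Error(
        `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
      );
    }
    promptBody = `${messageStrings[i]}${promptBody}`;
    currentTokenCount = newTokenCount;
  }
  return { promptBody, currentTokenCount };
}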

View File

@@ -0,0 +1,71 @@
const fetch = require('node-fetch');
const { GraphEvents } = require('@librechat/agents');
const { logger, sendEvent } = require('~/config');
const { sleep } = require('~/server/utils');
/**
* Creates a fetch wrapper that makes an HTTP request and logs the process.
* @param {Object} params
* @param {boolean} [params.directEndpoint] - Whether to use a direct endpoint.
* @param {string} [params.reverseProxyUrl] - The reverse proxy URL to use for the request.
* @returns {(url: RequestInfo, init?: RequestInit) => Promise<Response>} - A fetch-compatible function.
*/
function createFetch({ directEndpoint = false, reverseProxyUrl = '' }) {
/**
* Makes an HTTP request and logs the process.
* @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
* @param {RequestInit} [init] - Optional init options for the request.
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
*/
return async (_url, init) => {
let url = _url;
if (directEndpoint) {
url = reverseProxyUrl;
}
logger.debug(`Making request to ${url}`);
if (typeof Bun !== 'undefined') {
  // Both branches currently resolve to the same node-fetch import; the Bun
  // check is presumably a hook for runtime-specific handling.
  return await fetch(url, init);
}
return await fetch(url, init);
};
}
/**
* Creates event handlers for stream events that don't capture client references
* @param {Object} res - The response object to send events to
* @returns {Object} Object containing handler functions
*/
function createStreamEventHandlers(res) {
return {
[GraphEvents.ON_RUN_STEP]: (event) => {
if (res) {
sendEvent(res, event);
}
},
[GraphEvents.ON_MESSAGE_DELTA]: (event) => {
if (res) {
sendEvent(res, event);
}
},
[GraphEvents.ON_REASONING_DELTA]: (event) => {
if (res) {
sendEvent(res, event);
}
},
};
}
function createHandleLLMNewToken(streamRate) {
return async () => {
if (streamRate) {
await sleep(streamRate);
}
};
}
module.exports = {
createFetch,
createHandleLLMNewToken,
createStreamEventHandlers,
};
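
A brief usage sketch of these helpers; the endpoint URL and stream rate are illustrative, and `res` is stubbed in place of a real Express response:

const { createFetch, createStreamEventHandlers, createHandleLLMNewToken } = require('./generators');

// Stub standing in for an Express response object.
const res = { write: () => {}, end: () => {} };

// Route every request to a fixed upstream, regardless of the URL the SDK passes in.
const fetchFn = createFetch({ directEndpoint: true, reverseProxyUrl: 'http://localhost:8080/v1' });

// Handlers that forward run-step, message-delta, and reasoning-delta events to the client.
const handlers = createStreamEventHandlers(res);

// Throttle token streaming by ~25 ms per token.
const onNewToken = createHandleLLMNewToken(25);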

View File

@@ -1,11 +1,15 @@
+const ChatGPTClient = require('./ChatGPTClient');
 const OpenAIClient = require('./OpenAIClient');
+const PluginsClient = require('./PluginsClient');
 const GoogleClient = require('./GoogleClient');
 const TextStream = require('./TextStream');
 const AnthropicClient = require('./AnthropicClient');
 const toolUtils = require('./tools/util');

 module.exports = {
+  ChatGPTClient,
   OpenAIClient,
+  PluginsClient,
   GoogleClient,
   TextStream,
   AnthropicClient,

View File

@@ -1,5 +1,6 @@
 const { ChatOpenAI } = require('@langchain/openai');
-const { isEnabled, sanitizeModelName, constructAzureURL } = require('@librechat/api');
+const { sanitizeModelName, constructAzureURL } = require('~/utils');
+const { isEnabled } = require('~/server/utils');

 /**
  * Creates a new instance of a language model (LLM) for chat interactions.

View File

@@ -1,7 +1,6 @@
 const axios = require('axios');
-const { isEnabled } = require('@librechat/api');
-const { logger } = require('@librechat/data-schemas');
-const { generateShortLivedToken } = require('~/server/services/AuthService');
+const { isEnabled } = require('~/server/utils');
+const { logger } = require('~/config');

 const footer = `Use the context as your learned knowledge to better answer the user.
@@ -19,7 +18,7 @@ function createContextHandlers(req, userMessageContent) {
   const queryPromises = [];
   const processedFiles = [];
   const processedIds = new Set();
-  const jwtToken = generateShortLivedToken(req.user.id);
+  const jwtToken = req.headers.authorization.split(' ')[1];
   const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);

   const query = async (file) => {
@@ -97,35 +96,35 @@ function createContextHandlers(req, userMessageContent) {
     resolvedQueries.length === 0
       ? '\n\tThe semantic search did not return any results.'
       : resolvedQueries
           .map((queryResult, index) => {
             const file = processedFiles[index];
             let contextItems = queryResult.data;

             const generateContext = (currentContext) =>
               `
         <file>
           <filename>${file.filename}</filename>
           <context>${currentContext}
           </context>
         </file>`;

             if (useFullContext) {
               return generateContext(`\n${contextItems}`);
             }

             contextItems = queryResult.data
               .map((item) => {
                 const pageContent = item[0].page_content;
                 return `
             <contextItem>
               <![CDATA[${pageContent?.trim()}]]>
             </contextItem>`;
               })
               .join('');

             return generateContext(contextItems);
           })
           .join('');

     if (useFullContext) {
       const prompt = `${header}

View File

@@ -237,9 +237,41 @@ const formatAgentMessages = (payload) => {
   return messages;
 };

/**
 * Formats an array of messages for LangChain, making sure all content fields are strings
 * @param {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} payload - The array of messages to format.
 * @returns {Array<(HumanMessage|AIMessage|SystemMessage|ToolMessage)>} - The array of formatted LangChain messages, including ToolMessages for tool calls.
 */
const formatContentStrings = (payload) => {
  const messages = [];

  for (const message of payload) {
    if (typeof message.content === 'string') {
      continue;
    }

    if (!Array.isArray(message.content)) {
      continue;
    }

    // Reduce text types to a single string, ignore all other types
    const content = message.content.reduce((acc, curr) => {
      if (curr.type === ContentTypes.TEXT) {
        return `${acc}${curr[ContentTypes.TEXT]}\n`;
      }
      return acc;
    }, '');

    message.content = content.trim();
  }

  return messages;
};

 module.exports = {
   formatMessage,
   formatFromLangChain,
   formatAgentMessages,
+  formatContentStrings,
   formatLangChainMessages,
 };
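
A usage sketch of `formatContentStrings`, assuming `ContentTypes.TEXT` resolves to the string 'text' as in librechat-data-provider. Note that the function mutates the payload in place; the `messages` array it returns stays empty in the version shown:

const message = {
  role: 'assistant',
  content: [
    { type: 'text', text: 'First part.' },
    { type: 'image_url', image_url: { url: 'https://example.com/img.png' } },
    { type: 'text', text: 'Second part.' },
  ],
};

formatContentStrings([message]);
// message.content is now the string 'First part.\nSecond part.'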

View File

@@ -15,7 +15,7 @@ describe('AnthropicClient', () => {
 {
   role: 'user',
   isCreatedByUser: true,
-  text: "What's up",
+  text: 'What\'s up',
   messageId: '3',
   parentMessageId: '2',
 },
@@ -170,7 +170,7 @@ describe('AnthropicClient', () => {
   client.options.modelLabel = 'Claude-2';
   const result = await client.buildMessages(messages, parentMessageId);
   const { prompt } = result;
-  expect(prompt).toContain("Human's name: John");
+  expect(prompt).toContain('Human\'s name: John');
   expect(prompt).toContain('You are Claude-2');
 });
}); });
@@ -244,64 +244,6 @@ describe('AnthropicClient', () => {
   );
 });
describe('Claude 4 model headers', () => {
it('should add "prompt-caching" beta header for claude-sonnet-4 model', () => {
const client = new AnthropicClient('test-api-key');
const modelOptions = {
model: 'claude-sonnet-4-20250514',
};
client.setOptions({ modelOptions, promptCache: true });
const anthropicClient = client.getClient(modelOptions);
expect(anthropicClient._options.defaultHeaders).toBeDefined();
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
'prompt-caching-2024-07-31',
);
});
it('should add "prompt-caching" beta header for claude-opus-4 model', () => {
const client = new AnthropicClient('test-api-key');
const modelOptions = {
model: 'claude-opus-4-20250514',
};
client.setOptions({ modelOptions, promptCache: true });
const anthropicClient = client.getClient(modelOptions);
expect(anthropicClient._options.defaultHeaders).toBeDefined();
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
'prompt-caching-2024-07-31',
);
});
it('should add "prompt-caching" beta header for claude-4-sonnet model', () => {
const client = new AnthropicClient('test-api-key');
const modelOptions = {
model: 'claude-4-sonnet-20250514',
};
client.setOptions({ modelOptions, promptCache: true });
const anthropicClient = client.getClient(modelOptions);
expect(anthropicClient._options.defaultHeaders).toBeDefined();
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
'prompt-caching-2024-07-31',
);
});
it('should add "prompt-caching" beta header for claude-4-opus model', () => {
const client = new AnthropicClient('test-api-key');
const modelOptions = {
model: 'claude-4-opus-20250514',
};
client.setOptions({ modelOptions, promptCache: true });
const anthropicClient = client.getClient(modelOptions);
expect(anthropicClient._options.defaultHeaders).toBeDefined();
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
'prompt-caching-2024-07-31',
);
});
});
 it('should not add beta header for claude-3-5-sonnet-latest model', () => {
   const client = new AnthropicClient('test-api-key');
   const modelOptions = {
@@ -309,7 +251,7 @@ describe('AnthropicClient', () => {
   };
   client.setOptions({ modelOptions, promptCache: true });
   const anthropicClient = client.getClient(modelOptions);
-  expect(anthropicClient._options.defaultHeaders).toBeUndefined();
+  expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta');
}); });
 it('should not add beta header for other models', () => {
@@ -320,7 +262,7 @@ describe('AnthropicClient', () => {
   },
 });
 const anthropicClient = client.getClient();
-expect(anthropicClient._options.defaultHeaders).toBeUndefined();
+expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta');
 });
}); });
@@ -514,34 +456,6 @@ describe('AnthropicClient', () => {
   expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue);
 });
it('should not cap maxOutputTokens for Claude 4 Sonnet models', () => {
const client = new AnthropicClient('test-api-key');
const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 10; // 40,960 tokens
client.setOptions({
modelOptions: {
model: 'claude-sonnet-4-20250514',
maxOutputTokens: highTokenValue,
},
});
expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue);
});
it('should not cap maxOutputTokens for Claude 4 Opus models', () => {
const client = new AnthropicClient('test-api-key');
const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 6; // 24,576 tokens (under 32K limit)
client.setOptions({
modelOptions: {
model: 'claude-opus-4-20250514',
maxOutputTokens: highTokenValue,
},
});
expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue);
});
 it('should cap maxOutputTokens for Claude 3.5 Haiku models', () => {
   const client = new AnthropicClient('test-api-key');
   const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2;
@@ -815,223 +729,4 @@ describe('AnthropicClient', () => {
   expect(capturedOptions).toHaveProperty('topK', 10);
   expect(capturedOptions).toHaveProperty('topP', 0.9);
 });
describe('isClaudeLatest', () => {
it('should set isClaudeLatest to true for claude-3 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-3-sonnet-20240229',
},
});
expect(client.isClaudeLatest).toBe(true);
});
it('should set isClaudeLatest to true for claude-3.5 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-3.5-sonnet-20240229',
},
});
expect(client.isClaudeLatest).toBe(true);
});
it('should set isClaudeLatest to true for claude-sonnet-4 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-sonnet-4-20240229',
},
});
expect(client.isClaudeLatest).toBe(true);
});
it('should set isClaudeLatest to true for claude-opus-4 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-opus-4-20240229',
},
});
expect(client.isClaudeLatest).toBe(true);
});
it('should set isClaudeLatest to true for claude-3.5-haiku models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-3.5-haiku-20240229',
},
});
expect(client.isClaudeLatest).toBe(true);
});
it('should set isClaudeLatest to false for claude-2 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-2',
},
});
expect(client.isClaudeLatest).toBe(false);
});
it('should set isClaudeLatest to false for claude-instant models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-instant',
},
});
expect(client.isClaudeLatest).toBe(false);
});
it('should set isClaudeLatest to false for claude-sonnet-3 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-sonnet-3-20240229',
},
});
expect(client.isClaudeLatest).toBe(false);
});
it('should set isClaudeLatest to false for claude-opus-3 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-opus-3-20240229',
},
});
expect(client.isClaudeLatest).toBe(false);
});
it('should set isClaudeLatest to false for claude-haiku-3 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-haiku-3-20240229',
},
});
expect(client.isClaudeLatest).toBe(false);
});
});
describe('configureReasoning', () => {
it('should enable thinking for claude-opus-4 and claude-sonnet-4 models', async () => {
const client = new AnthropicClient('test-api-key');
// Create a mock async generator function
async function* mockAsyncGenerator() {
yield { type: 'message_start', message: { usage: {} } };
yield { delta: { text: 'Test response' } };
yield { type: 'message_delta', usage: {} };
}
// Mock createResponse to return the async generator
jest.spyOn(client, 'createResponse').mockImplementation(() => {
return mockAsyncGenerator();
});
// Test claude-opus-4
client.setOptions({
modelOptions: {
model: 'claude-opus-4-20250514',
},
thinking: true,
thinkingBudget: 2000,
});
let capturedOptions = null;
jest.spyOn(client, 'getClient').mockImplementation((options) => {
capturedOptions = options;
return {};
});
const payload = [{ role: 'user', content: 'Test message' }];
await client.sendCompletion(payload, {});
expect(capturedOptions).toHaveProperty('thinking');
expect(capturedOptions.thinking).toEqual({
type: 'enabled',
budget_tokens: 2000,
});
// Test claude-sonnet-4
client.setOptions({
modelOptions: {
model: 'claude-sonnet-4-20250514',
},
thinking: true,
thinkingBudget: 2000,
});
await client.sendCompletion(payload, {});
expect(capturedOptions).toHaveProperty('thinking');
expect(capturedOptions.thinking).toEqual({
type: 'enabled',
budget_tokens: 2000,
});
});
});
});
describe('Claude Model Tests', () => {
it('should handle Claude 3 and 4 series models correctly', () => {
const client = new AnthropicClient('test-key');
// Claude 3 series models
const claude3Models = [
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307',
'claude-3-5-sonnet-20240620',
'claude-3-5-haiku-20240620',
'claude-3.5-sonnet-20240620',
'claude-3.5-haiku-20240620',
'claude-3.7-sonnet-20240620',
'claude-3.7-haiku-20240620',
'anthropic/claude-3-opus-20240229',
'claude-3-opus-20240229/anthropic',
];
// Claude 4 series models
const claude4Models = [
'claude-sonnet-4-20250514',
'claude-opus-4-20250514',
'claude-4-sonnet-20250514',
'claude-4-opus-20250514',
'anthropic/claude-sonnet-4-20250514',
'claude-sonnet-4-20250514/anthropic',
];
// Test Claude 3 series
claude3Models.forEach((model) => {
client.setOptions({ modelOptions: { model } });
expect(
/claude-[3-9]/.test(client.modelOptions.model) ||
/claude-(?:sonnet|opus|haiku)-[4-9]/.test(client.modelOptions.model),
).toBe(true);
});
// Test Claude 4 series
claude4Models.forEach((model) => {
client.setOptions({ modelOptions: { model } });
expect(
/claude-[3-9]/.test(client.modelOptions.model) ||
/claude-(?:sonnet|opus|haiku)-[4-9]/.test(client.modelOptions.model),
).toBe(true);
});
// Test non-Claude 3/4 models
const nonClaudeModels = ['claude-2', 'claude-instant', 'gpt-4', 'gpt-3.5-turbo'];
nonClaudeModels.forEach((model) => {
client.setOptions({ modelOptions: { model } });
expect(
/claude-[3-9]/.test(client.modelOptions.model) ||
/claude-(?:sonnet|opus|haiku)-[4-9]/.test(client.modelOptions.model),
).toBe(false);
});
});
 });

View File

@@ -1,7 +1,7 @@
 const { Constants } = require('librechat-data-provider');
 const { initializeFakeClient } = require('./FakeClient');

-jest.mock('~/db/connect');
+jest.mock('~/lib/db/connectDb');
 jest.mock('~/models', () => ({
   User: jest.fn(),
   Key: jest.fn(),
@@ -33,9 +33,7 @@ jest.mock('~/models', () => ({
 const { getConvo, saveConvo } = require('~/models');

 jest.mock('@librechat/agents', () => {
-  const { Providers } = jest.requireActual('@librechat/agents');
   return {
-    Providers,
     ChatOpenAI: jest.fn().mockImplementation(() => {
       return {};
     }),
@@ -54,7 +52,7 @@ const messageHistory = [
 {
   role: 'user',
   isCreatedByUser: true,
-  text: "What's up",
+  text: 'What\'s up',
   messageId: '3',
   parentMessageId: '2',
 },
@@ -422,46 +420,6 @@ describe('BaseClient', () => {
   expect(response).toEqual(expectedResult);
 });
test('should replace responseMessageId with new UUID when isRegenerate is true and messageId ends with underscore', async () => {
const mockCrypto = require('crypto');
const newUUID = 'new-uuid-1234';
jest.spyOn(mockCrypto, 'randomUUID').mockReturnValue(newUUID);
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe(newUUID);
expect(TestClient.responseMessageId).not.toBe('existing-message-id_');
mockCrypto.randomUUID.mockRestore();
});
test('should not replace responseMessageId when isRegenerate is false', async () => {
const opts = {
isRegenerate: false,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id_');
});
test('should not replace responseMessageId when it does not end with underscore', async () => {
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id');
});
 test('sendMessage should work with provided conversationId and parentMessageId', async () => {
   const userMessage = 'Second message in the conversation';
   const opts = {
@@ -498,7 +456,7 @@ describe('BaseClient', () => {
   const chatMessages2 = await TestClient.loadHistory(conversationId, '3');
   expect(TestClient.currentMessages).toHaveLength(3);
-  expect(chatMessages2[chatMessages2.length - 1].text).toEqual("What's up");
+  expect(chatMessages2[chatMessages2.length - 1].text).toEqual('What\'s up');
 });

 /* Most of the new sendMessage logic revolving around edited/continued AI messages

View File

@@ -5,7 +5,7 @@ const getLogStores = require('~/cache/getLogStores');
 const OpenAIClient = require('../OpenAIClient');

 jest.mock('meilisearch');
-jest.mock('~/db/connect');
+jest.mock('~/lib/db/connectDb');
 jest.mock('~/models', () => ({
   User: jest.fn(),
   Key: jest.fn(),
@@ -462,17 +462,17 @@ describe('OpenAIClient', () => {
   role: 'system',
   name: 'example_user',
   content:
-    "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.",
+    'Let\'s circle back when we have more bandwidth to touch base on opportunities for increased leverage.',
 },
 {
   role: 'system',
   name: 'example_assistant',
-  content: "Let's talk later when we're less busy about how to do better.",
+  content: 'Let\'s talk later when we\'re less busy about how to do better.',
 },
 {
   role: 'user',
   content:
-    "This late pivot means we don't have time to boil the ocean for the client deliverable.",
+    'This late pivot means we don\'t have time to boil the ocean for the client deliverable.',
 },
 ];
@@ -531,6 +531,44 @@ describe('OpenAIClient', () => {
}); });
}); });
describe('sendMessage/getCompletion/chatCompletion', () => {
afterEach(() => {
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
});
it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => {
const model = 'text-davinci-003';
const onProgress = jest.fn().mockImplementation(() => ({}));
const testClient = new OpenAIClient('test-api-key', {
...defaultOptions,
modelOptions: { model },
});
const getCompletion = jest.spyOn(testClient, 'getCompletion');
await testClient.sendMessage('Hi mom!', { onProgress });
expect(getCompletion).toHaveBeenCalled();
expect(getCompletion.mock.calls.length).toBe(1);
expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n');
expect(fetchEventSource).toHaveBeenCalled();
expect(fetchEventSource.mock.calls.length).toBe(1);
// Check if the first argument (url) is correct
const firstCallArgs = fetchEventSource.mock.calls[0];
const expectedURL = 'https://api.openai.com/v1/completions';
expect(firstCallArgs[0]).toBe(expectedURL);
const requestBody = JSON.parse(firstCallArgs[1].body);
expect(requestBody).toHaveProperty('model');
expect(requestBody.model).toBe(model);
});
});
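The assertions above imply that, for text/instruct models, the client flattens the conversation into a single completion prompt and posts it to the completions endpoint instead of chat/completions. A rough sketch of that formatting, where buildInstructPrompt is an illustrative name rather than the client's actual helper:

// Flatten role/text pairs into the '||>Role:\n...' prompt shape the test expects.
const buildInstructPrompt = (messages) =>
  messages.map(({ role, text }) => `||>${role}:\n${text}\n`).join('') + '||>Assistant:\n';

// buildInstructPrompt([{ role: 'User', text: 'Hi mom!' }])
//   === '||>User:\nHi mom!\n||>Assistant:\n'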
describe('checkVisionRequest functionality', () => { describe('checkVisionRequest functionality', () => {
let client; let client;
const attachments = [{ type: 'image/png' }]; const attachments = [{ type: 'image/png' }];

View File

@@ -0,0 +1,314 @@
const crypto = require('crypto');
const { Constants } = require('librechat-data-provider');
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
const PluginsClient = require('../PluginsClient');
jest.mock('~/lib/db/connectDb');
jest.mock('~/models/Conversation', () => {
return function () {
return {
save: jest.fn(),
deleteConvos: jest.fn(),
};
};
});
const defaultAzureOptions = {
azureOpenAIApiInstanceName: 'your-instance-name',
azureOpenAIApiDeploymentName: 'your-deployment-name',
azureOpenAIApiVersion: '2020-07-01-preview',
};
describe('PluginsClient', () => {
let TestAgent;
let options = {
tools: [],
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
};
let parentMessageId;
let conversationId;
const fakeMessages = [];
const userMessage = 'Hello, ChatGPT!';
const apiKey = 'fake-api-key';
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, options);
TestAgent.loadHistory = jest
.fn()
.mockImplementation((conversationId, parentMessageId = null) => {
if (!conversationId) {
TestAgent.currentMessages = [];
return Promise.resolve([]);
}
const orderedMessages = TestAgent.constructor.getMessagesForConversation({
messages: fakeMessages,
parentMessageId,
});
const chatMessages = orderedMessages.map((msg) =>
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
? new HumanMessage(msg.text)
: new AIMessage(msg.text),
);
TestAgent.currentMessages = orderedMessages;
return Promise.resolve(chatMessages);
});
TestAgent.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => {
if (opts && typeof opts === 'object') {
TestAgent.setOptions(opts);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || Constants.NO_PARENT;
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
this.pastMessages = await TestAgent.loadHistory(
conversationId,
TestAgent.options?.parentMessageId,
);
const userMessage = {
text: message,
sender: 'ChatGPT',
isCreatedByUser: true,
messageId: userMessageId,
parentMessageId,
conversationId,
};
const response = {
sender: 'ChatGPT',
text: 'Hello, User!',
isCreatedByUser: false,
messageId: crypto.randomUUID(),
parentMessageId: userMessage.messageId,
conversationId,
};
fakeMessages.push(userMessage);
fakeMessages.push(response);
return response;
});
});
test('initializes PluginsClient without crashing', () => {
expect(TestAgent).toBeInstanceOf(PluginsClient);
});
test('check setOptions function', () => {
expect(TestAgent.agentIsGpt3).toBe(true);
});
describe('sendMessage', () => {
test('sendMessage should return a response message', async () => {
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: expect.any(String),
});
const response = await TestAgent.sendMessage(userMessage);
parentMessageId = response.messageId;
conversationId = response.conversationId;
expect(response).toEqual(expectedResult);
});
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
const userMessage = 'Second message in the conversation';
const opts = {
conversationId,
parentMessageId,
};
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: opts.conversationId,
});
const response = await TestAgent.sendMessage(userMessage, opts);
parentMessageId = response.messageId;
expect(response.conversationId).toEqual(conversationId);
expect(response).toEqual(expectedResult);
});
test('should return chat history', async () => {
const chatMessages = await TestAgent.loadHistory(conversationId, parentMessageId);
expect(TestAgent.currentMessages).toHaveLength(4);
expect(chatMessages[0].text).toEqual(userMessage);
});
});
describe('getFunctionModelName', () => {
let client;
beforeEach(() => {
client = new PluginsClient('dummy_api_key');
});
test('should return the input when it includes a dash followed by four digits', () => {
expect(client.getFunctionModelName('-1234')).toBe('-1234');
expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview');
});
test('should return the input for all function-capable models (`0613` models and above)', () => {
expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613');
expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview');
expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106');
});
test('should return the corresponding model if input is non-function capable (`0314` models)', () => {
expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => {
expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-4" when the input includes "gpt-4"', () => {
expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4');
});
test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => {
expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
});
});
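Taken together, these tests encode a small decision table. A standalone sketch consistent with them (the real method may differ in detail):

function getFunctionModelName(input) {
  if (input.includes('0314')) {
    // 0314 snapshots predate function calling; fall back to the base model.
    return input.includes('gpt-4') ? 'gpt-4' : 'gpt-3.5-turbo';
  }
  if (/-\d{4}/.test(input)) {
    // Any other date-stamped release (0613 and later) supports functions.
    return input;
  }
  if (input.includes('gpt-4')) {
    return 'gpt-4';
  }
  return 'gpt-3.5-turbo';
}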
describe('Azure OpenAI tests specific to Plugins', () => {
// TODO: add more tests for Azure OpenAI integration with Plugins
// let client;
// beforeEach(() => {
// client = new PluginsClient('dummy_api_key');
// });
test('should not call getFunctionModelName when azure options are set', () => {
const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName');
const model = 'gpt-4-turbo';
// Note: without the Azure change in PR #1766, `getFunctionModelName` would be called twice
const testClient = new PluginsClient('dummy_api_key', {
agentOptions: {
model,
agent: 'functions',
},
azure: defaultAzureOptions,
});
expect(spy).not.toHaveBeenCalled();
expect(testClient.agentOptions.model).toBe(model);
spy.mockRestore();
});
});
describe('sendMessage with filtered tools', () => {
let TestAgent;
const apiKey = 'fake-api-key';
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, {
tools: mockTools,
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
});
TestAgent.options.req = {
app: {
locals: {},
},
};
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
if (includedTools.length > 0) {
const tools = TestAgent.options.tools.filter((plugin) =>
includedTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
} else {
const tools = TestAgent.options.tools.filter(
(plugin) => !filteredTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
}
return {
text: 'Mocked response',
tools: TestAgent.options.tools,
};
});
});
test('should filter out tools when filteredTools is provided', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should only include specified tools when includedTools is provided', async () => {
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should prioritize includedTools over filteredTools', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool1' }),
expect.objectContaining({ name: 'tool2' }),
]),
);
});
test('should not modify tools when no filters are provided', async () => {
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(4);
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
});
});
});
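The precedence encoded by the filtered-tools mock above, reduced to a pure function for clarity (illustrative only; the client under test mutates options.tools in place):

// An allow-list (includedTools) wins outright; the deny-list (filteredTools)
// only applies when no allow-list is present.
function applyToolFilters(tools, { filteredTools = [], includedTools = [] } = {}) {
  if (includedTools.length > 0) {
    return tools.filter((t) => includedTools.includes(t.name));
  }
  return tools.filter((t) => !filteredTools.includes(t.name));
}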

View File

@@ -0,0 +1,184 @@
require('dotenv').config();
const fs = require('fs');
const { z } = require('zod');
const path = require('path');
const yaml = require('js-yaml');
const { createOpenAPIChain } = require('langchain/chains');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { ChatPromptTemplate, HumanMessagePromptTemplate } = require('@langchain/core/prompts');
const { logger } = require('~/config');
function addLinePrefix(text, prefix = '// ') {
return text
.split('\n')
.map((line) => prefix + line)
.join('\n');
}
function createPrompt(name, functions) {
const prefix = `// The ${name} tool has the following functions. Determine the desired or most optimal function for the user's query:`;
const functionDescriptions = functions
.map((func) => `// - ${func.name}: ${func.description}`)
.join('\n');
return `${prefix}\n${functionDescriptions}
// You are an expert manager and scrum master. You must provide a detailed intent to better execute the function.
// Always format as such: {{"func": "function_name", "intent": "intent and expected result"}}`;
}
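For a hypothetical "weather" tool with two functions, createPrompt would emit something like:

// The weather tool has the following functions. Determine the desired or most optimal function for the user's query:
// - getForecast: Returns a five-day forecast for a location
// - getAlerts: Returns active weather alerts for a region
// You are an expert manager and scrum master. You must provide a detailed intent to better execute the function.
// Always format as such: {{"func": "function_name", "intent": "intent and expected result"}}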
const AuthBearer = z
.object({
type: z.string().includes('service_http'),
authorization_type: z.string().includes('bearer'),
verification_tokens: z.object({
openai: z.string(),
}),
})
.catch(() => false);
const AuthDefinition = z
.object({
type: z.string(),
authorization_type: z.string(),
verification_tokens: z.object({
openai: z.string(),
}),
})
.catch(() => false);
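Because both schemas end in .catch(() => false), a failed parse returns false instead of throwing, so callers can branch on the result directly, as the auth checks further down do. A self-contained demonstration of the pattern (z.string().includes() assumes zod 3.21+):

const { z } = require('zod');

const BearerType = z
  .object({ authorization_type: z.string().includes('bearer') })
  .catch(() => false);

console.log(BearerType.parse({ authorization_type: 'bearer' })); // { authorization_type: 'bearer' }
console.log(BearerType.parse({ authorization_type: 'oauth' }));  // false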
async function readSpecFile(filePath) {
try {
const fileContents = await fs.promises.readFile(filePath, 'utf8');
if (path.extname(filePath) === '.json') {
return JSON.parse(fileContents);
}
return yaml.load(fileContents);
} catch (e) {
logger.error('[readSpecFile] error', e);
return false;
}
}
async function getSpec(url) {
const RegularUrl = z
.string()
.url()
.catch(() => false);
if (RegularUrl.parse(url) && path.extname(url) === '.json') {
const response = await fetch(url);
return await response.json();
}
const ValidSpecPath = z
.string()
.url()
.catch(async () => {
const spec = path.join(__dirname, '..', '.well-known', 'openapi', url);
if (!fs.existsSync(spec)) {
return false;
}
return await readSpecFile(spec);
});
return ValidSpecPath.parse(url);
}
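Usage sketch for getSpec: a remote .json manifest is fetched over HTTP, while any other string falls through to the local .well-known/openapi directory (both inputs below are hypothetical):

const { getSpec } = require('./OpenAPIPlugin');

(async () => {
  const remote = await getSpec('https://example.com/.well-known/ai-plugin.json');
  const local = await getSpec('myplugin.yaml'); // read from .well-known/openapi and YAML-parsed
  console.log(remote, local);
})();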
async function createOpenAPIPlugin({ data, llm, user, message, memory, signal }) {
let spec;
try {
spec = await getSpec(data.api.url);
} catch (error) {
logger.error('[createOpenAPIPlugin] getSpec error', error);
return null;
}
if (!spec) {
logger.warn('[createOpenAPIPlugin] No spec found');
return null;
}
const headers = {};
const { auth, name_for_model, description_for_model, description_for_human } = data;
if (auth && AuthDefinition.parse(auth)) {
logger.debug('[createOpenAPIPlugin] auth detected', auth);
const { openai } = auth.verification_tokens;
if (AuthBearer.parse(auth)) {
headers.authorization = `Bearer ${openai}`;
logger.debug('[createOpenAPIPlugin] added auth bearer', headers);
}
}
const chainOptions = { llm };
if (data.headers && data.headers['librechat_user_id']) {
logger.debug('[createOpenAPIPlugin] id detected', headers);
headers[data.headers['librechat_user_id']] = user;
}
if (Object.keys(headers).length > 0) {
logger.debug('[createOpenAPIPlugin] headers detected', headers);
chainOptions.headers = headers;
}
if (data.params) {
logger.debug('[createOpenAPIPlugin] params detected', data.params);
chainOptions.params = data.params;
}
let history = '';
if (memory) {
logger.debug('[createOpenAPIPlugin] openAPI chain: memory detected', memory);
const { history: chat_history } = await memory.loadMemoryVariables({});
history = chat_history?.length > 0 ? `\n\n## Chat History:\n${chat_history}\n` : '';
}
chainOptions.prompt = ChatPromptTemplate.fromMessages([
HumanMessagePromptTemplate.fromTemplate(
`# Use the provided API's to respond to this query:\n\n{query}\n\n## Instructions:\n${addLinePrefix(
description_for_model,
)}${history}`,
),
]);
const chain = await createOpenAPIChain(spec, chainOptions);
const { functions } = chain.chains[0].lc_kwargs.llmKwargs;
return new DynamicStructuredTool({
name: name_for_model,
description_for_model: `${addLinePrefix(description_for_human)}${createPrompt(
name_for_model,
functions,
)}`,
description: `${description_for_human}`,
schema: z.object({
func: z
.string()
.describe(
`The function to invoke. The functions available are: ${functions
.map((func) => func.name)
.join(', ')}`,
),
intent: z
.string()
.describe('Describe your intent with the function and your expected result'),
}),
func: async ({ func = '', intent = '' }) => {
const filteredFunctions = functions.filter((f) => f.name === func);
chain.chains[0].lc_kwargs.llmKwargs.functions = filteredFunctions;
const query = `${message}${func?.length > 0 ? `\n// Intent: ${intent}` : ''}`;
const result = await chain.call({
query,
signal,
});
return result.response;
},
});
}
module.exports = {
getSpec,
readSpecFile,
createOpenAPIPlugin,
};
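A hedged wiring sketch for createOpenAPIPlugin, with a minimal manifest shaped like the fields consumed above; the URL, model name, descriptions, and user ID are all placeholders, and llm can be any LangChain chat model:

const { ChatOpenAI } = require('@langchain/openai');
const { createOpenAPIPlugin } = require('./OpenAPIPlugin');

(async () => {
  const tool = await createOpenAPIPlugin({
    data: {
      api: { url: 'https://example.com/openapi.json' },
      name_for_model: 'example',
      description_for_model: 'Answers example queries.',
      description_for_human: 'Example plugin.',
    },
    llm: new ChatOpenAI({ model: 'gpt-4o-mini' }),
    user: 'user-123',
    message: 'What can you do?',
    signal: new AbortController().signal,
  });
  // `tool` is a DynamicStructuredTool, or null if the spec could not be loaded.
})();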

View File

@@ -0,0 +1,72 @@
const fs = require('fs');
const { createOpenAPIPlugin, getSpec, readSpecFile } = require('./OpenAPIPlugin');
global.fetch = jest.fn().mockImplementationOnce(() => {
return new Promise((resolve) => {
resolve({
ok: true,
json: () => Promise.resolve({ key: 'value' }),
});
});
});
jest.mock('fs', () => ({
promises: {
readFile: jest.fn(),
},
existsSync: jest.fn(),
}));
describe('readSpecFile', () => {
it('reads JSON file correctly', async () => {
fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' }));
const result = await readSpecFile('test.json');
expect(result).toEqual({ test: 'value' });
});
it('reads YAML file correctly', async () => {
fs.promises.readFile.mockResolvedValue('test: value');
const result = await readSpecFile('test.yaml');
expect(result).toEqual({ test: 'value' });
});
it('handles error correctly', async () => {
fs.promises.readFile.mockRejectedValue(new Error('test error'));
const result = await readSpecFile('test.json');
expect(result).toBe(false);
});
});
describe('getSpec', () => {
it('fetches spec from url correctly', async () => {
const parsedJson = await getSpec('https://www.instacart.com/.well-known/ai-plugin.json');
const isObject = typeof parsedJson === 'object';
expect(isObject).toEqual(true);
});
it('reads spec from file correctly', async () => {
fs.existsSync.mockReturnValue(true);
fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' }));
const result = await getSpec('test.json');
expect(result).toEqual({ test: 'value' });
});
it('returns false when file does not exist', async () => {
fs.existsSync.mockReturnValue(false);
const result = await getSpec('test.json');
expect(result).toBe(false);
});
});
describe('createOpenAPIPlugin', () => {
it('returns null when getSpec throws an error', async () => {
const result = await createOpenAPIPlugin({ data: { api: { url: 'invalid' } } });
expect(result).toBe(null);
});
it('returns null when no spec is found', async () => {
const result = await createOpenAPIPlugin({});
expect(result).toBe(null);
});
// Add more tests here for different scenarios
});

View File

@@ -8,10 +8,10 @@ const { HttpsProxyAgent } = require('https-proxy-agent');
const { FileContext, ContentTypes } = require('librechat-data-provider'); const { FileContext, ContentTypes } = require('librechat-data-provider');
const { getImageBasename } = require('~/server/services/Files/images'); const { getImageBasename } = require('~/server/services/Files/images');
const extractBaseURL = require('~/utils/extractBaseURL'); const extractBaseURL = require('~/utils/extractBaseURL');
const logger = require('~/config/winston'); const { logger } = require('~/config');
const displayMessage = const displayMessage =
"DALL-E displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user."; 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.';
class DALLE3 extends Tool { class DALLE3 extends Tool {
constructor(fields = {}) { constructor(fields = {}) {
super(); super();

View File

@@ -4,13 +4,12 @@ const { v4 } = require('uuid');
const OpenAI = require('openai'); const OpenAI = require('openai');
const FormData = require('form-data'); const FormData = require('form-data');
const { tool } = require('@langchain/core/tools'); const { tool } = require('@langchain/core/tools');
const { logAxiosError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { HttpsProxyAgent } = require('https-proxy-agent'); const { HttpsProxyAgent } = require('https-proxy-agent');
const { ContentTypes, EImageOutputType } = require('librechat-data-provider'); const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { extractBaseURL } = require('~/utils'); const { logAxiosError, extractBaseURL } = require('~/utils');
const { getFiles } = require('~/models/File'); const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
/** Default descriptions for image generation tool */ /** Default descriptions for image generation tool */
const DEFAULT_IMAGE_GEN_DESCRIPTION = ` const DEFAULT_IMAGE_GEN_DESCRIPTION = `
@@ -31,7 +30,7 @@ const DEFAULT_IMAGE_EDIT_DESCRIPTION =
When to use \`image_edit_oai\`: When to use \`image_edit_oai\`:
- The user wants to modify, extend, or remix one **or more** uploaded images, either: - The user wants to modify, extend, or remix one **or more** uploaded images, either:
- Previously generated, or in the current request (both to be included in the \`image_ids\` array). - Previously generated, or in the current request (both to be included in the \`image_ids\` array).
- Always when the user refers to uploaded images for editing, enhancement, remixing, style transfer, or combining elements. - Always when the user refers to uploaded images for editing, enhancement, remixing, style transfer, or combining elements.
- Any current or existing images are to be used as visual guides. - Any current or existing images are to be used as visual guides.
- If there are any files in the current request, they are more likely than not expected as references for image edit requests. - If there are any files in the current request, they are more likely than not expected as references for image edit requests.
@@ -65,7 +64,7 @@ const DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION = `Describe the changes, enhancement
Always base this prompt on the most recently uploaded reference images.`; Always base this prompt on the most recently uploaded reference images.`;
const displayMessage = const displayMessage =
"The tool displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user."; 'The tool displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.';
/** /**
* Replaces unwanted characters from the input string * Replaces unwanted characters from the input string
@@ -107,12 +106,6 @@ const getImageEditPromptDescription = () => {
return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION; return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
}; };
function createAbortHandler() {
return function () {
logger.debug('[ImageGenOAI] Image generation aborted');
};
}
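In isolation, the listener lifecycle this change adds around the OpenAI calls looks like the following sketch (AbortSignal.any requires Node 20+; withAbortLogging is an illustrative name):

async function withAbortLogging(signal, work) {
  const derived = AbortSignal.any([signal]); // derived signal tied to the caller's
  const onAbort = createAbortHandler();
  derived.addEventListener('abort', onAbort, { once: true });
  try {
    return await work(derived);
  } finally {
    // { once: true } detaches the listener if it fired; this removal covers
    // the success path, where it never did.
    derived.removeEventListener('abort', onAbort);
  }
}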
/** /**
* Creates OpenAI Image tools (generation and editing) * Creates OpenAI Image tools (generation and editing)
* @param {Object} fields - Configuration fields * @param {Object} fields - Configuration fields
@@ -207,18 +200,10 @@ function createOpenAIImageTools(fields = {}) {
} }
let resp; let resp;
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try { try {
if (runnableConfig?.signal) { const derivedSignal = runnableConfig?.signal
derivedSignal = AbortSignal.any([runnableConfig.signal]); ? AbortSignal.any([runnableConfig.signal])
abortHandler = createAbortHandler(); : undefined;
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
resp = await openai.images.generate( resp = await openai.images.generate(
{ {
model: 'gpt-image-1', model: 'gpt-image-1',
@@ -242,10 +227,6 @@ function createOpenAIImageTools(fields = {}) {
logAxiosError({ error, message }); logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable: return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable:
Error Message: ${error.message}`); Error Message: ${error.message}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
} }
if (!resp) { if (!resp) {
@@ -427,17 +408,10 @@ Error Message: ${error.message}`);
headers['Authorization'] = `Bearer ${apiKey}`; headers['Authorization'] = `Bearer ${apiKey}`;
} }
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try { try {
if (runnableConfig?.signal) { const derivedSignal = runnableConfig?.signal
derivedSignal = AbortSignal.any([runnableConfig.signal]); ? AbortSignal.any([runnableConfig.signal])
abortHandler = createAbortHandler(); : undefined;
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
/** @type {import('axios').AxiosRequestConfig} */ /** @type {import('axios').AxiosRequestConfig} */
const axiosConfig = { const axiosConfig = {
@@ -492,10 +466,6 @@ Error Message: ${error.message}`);
logAxiosError({ error, message }); logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable: return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable:
Error Message: ${error.message || 'Unknown error'}`); Error Message: ${error.message || 'Unknown error'}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
} }
}, },
{ {

View File

@@ -1,29 +1,10 @@
const OpenAI = require('openai'); const OpenAI = require('openai');
const DALLE3 = require('../DALLE3'); const DALLE3 = require('../DALLE3');
const logger = require('~/config/winston');
const { logger } = require('~/config');
jest.mock('openai'); jest.mock('openai');
jest.mock('@librechat/data-schemas', () => {
return {
logger: {
info: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
error: jest.fn(),
},
};
});
jest.mock('tiktoken', () => {
return {
encoding_for_model: jest.fn().mockReturnValue({
encode: jest.fn(),
decode: jest.fn(),
}),
};
});
const processFileURL = jest.fn(); const processFileURL = jest.fn();
jest.mock('~/server/services/Files/images', () => ({ jest.mock('~/server/services/Files/images', () => ({
@@ -56,11 +37,6 @@ jest.mock('fs', () => {
return { return {
existsSync: jest.fn(), existsSync: jest.fn(),
mkdirSync: jest.fn(), mkdirSync: jest.fn(),
promises: {
writeFile: jest.fn(),
readFile: jest.fn(),
unlink: jest.fn(),
},
}; };
}); });

View File

@@ -1,35 +1,26 @@
const { z } = require('zod'); const { z } = require('zod');
const axios = require('axios'); const axios = require('axios');
const { tool } = require('@langchain/core/tools'); const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Tools, EToolResources } = require('librechat-data-provider'); const { Tools, EToolResources } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { getFiles } = require('~/models/File'); const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
/** /**
* *
* @param {Object} options * @param {Object} options
* @param {ServerRequest} options.req * @param {ServerRequest} options.req
* @param {Agent['tool_resources']} options.tool_resources * @param {Agent['tool_resources']} options.tool_resources
* @param {string} [options.agentId] - The agent ID for file access control
* @returns {Promise<{ * @returns {Promise<{
* files: Array<{ file_id: string; filename: string }>, * files: Array<{ file_id: string; filename: string }>,
* toolContext: string * toolContext: string
* }>} * }>}
*/ */
const primeFiles = async (options) => { const primeFiles = async (options) => {
const { tool_resources, req, agentId } = options; const { tool_resources } = options;
const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? []; const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
const agentResourceIds = new Set(file_ids); const agentResourceIds = new Set(file_ids);
const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? []; const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
const dbFiles = ( const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
(await getFiles(
{ file_id: { $in: file_ids } },
null,
{ text: 0 },
{ userId: req?.user?.id, agentId },
)) ?? []
).concat(resourceFiles);
let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`; let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;
@@ -68,7 +59,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
if (files.length === 0) { if (files.length === 0) {
return 'No files to search. Instruct the user to add files for the search.'; return 'No files to search. Instruct the user to add files for the search.';
} }
const jwtToken = generateShortLivedToken(req.user.id); const jwtToken = req.headers.authorization.split(' ')[1];
if (!jwtToken) { if (!jwtToken) {
return 'There was an error authenticating the file search request.'; return 'There was an error authenticating the file search request.';
} }
@@ -144,7 +135,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
query: z query: z
.string() .string()
.describe( .describe(
"A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you're looking for. The query will be used for semantic similarity matching against the file contents.", 'A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you\'re looking for. The query will be used for semantic similarity matching against the file contents.',
), ),
}), }),
}, },

View File

@@ -1,9 +1,8 @@
const { logger } = require('@librechat/data-schemas');
const { SerpAPI } = require('@langchain/community/tools/serpapi'); const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator'); const { Calculator } = require('@langchain/community/tools/calculator');
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api'); const { createCodeExecutionTool, EnvVar } = require('@librechat/agents');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents'); const { Tools, Constants, EToolResources } = require('librechat-data-provider');
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider'); const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { const {
availableTools, availableTools,
manifestToolMap, manifestToolMap,
@@ -23,10 +22,11 @@ const {
} = require('../'); } = require('../');
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { getCachedTools } = require('~/server/services/Config');
const { createMCPTool } = require('~/server/services/MCP'); const { createMCPTool } = require('~/server/services/MCP');
const { logger } = require('~/config');
const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`);
/** /**
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values. * Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
@@ -87,7 +87,7 @@ const validateTools = async (user, tools = []) => {
return Array.from(validToolsSet.values()); return Array.from(validToolsSet.values());
} catch (err) { } catch (err) {
logger.error('[validateTools] There was a problem validating tools', err); logger.error('[validateTools] There was a problem validating tools', err);
throw new Error(err); throw new Error('There was a problem validating tools');
} }
}; };
@@ -138,6 +138,7 @@ const loadTools = async ({
agent, agent,
model, model,
endpoint, endpoint,
useSpecs,
tools = [], tools = [],
options = {}, options = {},
functions = true, functions = true,
@@ -230,7 +231,7 @@ const loadTools = async ({
/** @type {Record<string, string>} */ /** @type {Record<string, string>} */
const toolContextMap = {}; const toolContextMap = {};
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {}; const appTools = options.req?.app?.locals?.availableTools ?? {};
for (const tool of tools) { for (const tool of tools) {
if (tool === Tools.execute_code) { if (tool === Tools.execute_code) {
@@ -240,13 +241,7 @@ const loadTools = async ({
authFields: [EnvVar.CODE_API_KEY], authFields: [EnvVar.CODE_API_KEY],
}); });
const codeApiKey = authValues[EnvVar.CODE_API_KEY]; const codeApiKey = authValues[EnvVar.CODE_API_KEY];
const { files, toolContext } = await primeCodeFiles( const { files, toolContext } = await primeCodeFiles(options, codeApiKey);
{
...options,
agentId: agent?.id,
},
codeApiKey,
);
if (toolContext) { if (toolContext) {
toolContextMap[tool] = toolContext; toolContextMap[tool] = toolContext;
} }
@@ -261,48 +256,17 @@ const loadTools = async ({
continue; continue;
} else if (tool === Tools.file_search) { } else if (tool === Tools.file_search) {
requestedTools[tool] = async () => { requestedTools[tool] = async () => {
const { files, toolContext } = await primeSearchFiles({ const { files, toolContext } = await primeSearchFiles(options);
...options,
agentId: agent?.id,
});
if (toolContext) { if (toolContext) {
toolContextMap[tool] = toolContext; toolContextMap[tool] = toolContext;
} }
return createFileSearchTool({ req: options.req, files, entity_id: agent?.id }); return createFileSearchTool({ req: options.req, files, entity_id: agent?.id });
}; };
continue; continue;
} else if (tool === Tools.web_search) { } else if (tool && appTools[tool] && mcpToolPattern.test(tool)) {
const webSearchConfig = options?.req?.app?.locals?.webSearch;
const result = await loadWebSearchAuth({
userId: user,
loadAuthValues,
webSearchConfig,
});
const { onSearchResults, onGetHighlights } = options?.[Tools.web_search] ?? {};
requestedTools[tool] = async () => {
toolContextMap[tool] = `# \`${tool}\`:
Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
1. **Execute immediately without preface** when using \`${tool}\`.
2. **After the search, begin with a brief summary** that directly addresses the query without headers or explaining your process.
3. **Structure your response clearly** using Markdown formatting (Level 2 headers for sections, lists for multiple points, tables for comparisons).
4. **Cite sources properly** according to the citation anchor format, utilizing group anchors when appropriate.
5. **Tailor your approach to the query type** (academic, news, coding, etc.) while maintaining an expert, journalistic, unbiased tone.
6. **Provide comprehensive information** with specific details, examples, and as much relevant context as possible from search results.
7. **Avoid moralizing language.**
`.trim();
return createSearchTool({
...result.authResult,
onSearchResults,
onGetHighlights,
logger,
});
};
continue;
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
requestedTools[tool] = async () => requestedTools[tool] = async () =>
createMCPTool({ createMCPTool({
req: options.req, req: options.req,
res: options.res,
toolKey: tool, toolKey: tool,
model: agent?.model ?? model, model: agent?.model ?? model,
provider: agent?.provider ?? endpoint, provider: agent?.provider ?? endpoint,

View File

@@ -1,5 +1,8 @@
const mongoose = require('mongoose'); const mockUser = {
const { MongoMemoryServer } = require('mongodb-memory-server'); _id: 'fakeId',
save: jest.fn(),
findByIdAndDelete: jest.fn(),
};
const mockPluginService = { const mockPluginService = {
updateUserPluginAuth: jest.fn(), updateUserPluginAuth: jest.fn(),
@@ -7,18 +10,23 @@ const mockPluginService = {
getUserPluginAuthValue: jest.fn(), getUserPluginAuthValue: jest.fn(),
}; };
jest.mock('~/models/User', () => {
return function () {
return mockUser;
};
});
jest.mock('~/server/services/PluginService', () => mockPluginService); jest.mock('~/server/services/PluginService', () => mockPluginService);
const { BaseLLM } = require('@langchain/openai'); const { BaseLLM } = require('@langchain/openai');
const { Calculator } = require('@langchain/community/tools/calculator'); const { Calculator } = require('@langchain/community/tools/calculator');
const { User } = require('~/db/models'); const User = require('~/models/User');
const PluginService = require('~/server/services/PluginService'); const PluginService = require('~/server/services/PluginService');
const { validateTools, loadTools, loadToolWithAuth } = require('./handleTools'); const { validateTools, loadTools, loadToolWithAuth } = require('./handleTools');
const { StructuredSD, availableTools, DALLE3 } = require('../'); const { StructuredSD, availableTools, DALLE3 } = require('../');
describe('Tool Handlers', () => { describe('Tool Handlers', () => {
let mongoServer;
let fakeUser; let fakeUser;
const pluginKey = 'dalle'; const pluginKey = 'dalle';
const pluginKey2 = 'wolfram'; const pluginKey2 = 'wolfram';
@@ -29,9 +37,7 @@ describe('Tool Handlers', () => {
const authConfigs = mainPlugin.authConfig; const authConfigs = mainPlugin.authConfig;
beforeAll(async () => { beforeAll(async () => {
mongoServer = await MongoMemoryServer.create(); mockUser.save.mockResolvedValue(undefined);
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
const userAuthValues = {}; const userAuthValues = {};
mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => { mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => {
@@ -72,36 +78,9 @@ describe('Tool Handlers', () => {
}); });
afterAll(async () => { afterAll(async () => {
await mongoose.disconnect(); await mockUser.findByIdAndDelete(fakeUser._id);
await mongoServer.stop();
});
beforeEach(async () => {
// Clear mocks but not the database since we need the user to persist
jest.clearAllMocks();
// Reset the mock implementations
const userAuthValues = {};
mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => {
return userAuthValues[`${userId}-${authField}`];
});
mockPluginService.updateUserPluginAuth.mockImplementation(
(userId, authField, _pluginKey, credential) => {
const fields = authField.split('||');
fields.forEach((field) => {
userAuthValues[`${userId}-${field}`] = credential;
});
},
);
// Re-add the auth configs for the user
for (const authConfig of authConfigs) { for (const authConfig of authConfigs) {
await PluginService.updateUserPluginAuth( await PluginService.deleteUserPluginAuth(fakeUser._id, authConfig.authField);
fakeUser._id,
authConfig.authField,
pluginKey,
mockCredential,
);
} }
}); });
@@ -239,6 +218,7 @@ describe('Tool Handlers', () => {
try { try {
await loadTool2(); await loadTool2();
} catch (error) { } catch (error) {
// eslint-disable-next-line jest/no-conditional-expect
expect(error).toBeDefined(); expect(error).toBeDefined();
} }
}); });

View File

@@ -1,9 +1,8 @@
const { logger } = require('@librechat/data-schemas');
const { isEnabled, math } = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider'); const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, math, removePorts } = require('~/server/utils');
const { deleteAllUserSessions } = require('~/models'); const { deleteAllUserSessions } = require('~/models');
const { removePorts } = require('~/server/utils');
const getLogStores = require('./getLogStores'); const getLogStores = require('./getLogStores');
const { logger } = require('~/config');
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {}; const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
const interval = math(BAN_INTERVAL, 20); const interval = math(BAN_INTERVAL, 20);
@@ -33,6 +32,7 @@ const banViolation = async (req, res, errorMessage) => {
if (!isEnabled(BAN_VIOLATIONS)) { if (!isEnabled(BAN_VIOLATIONS)) {
return; return;
} }
if (!errorMessage) { if (!errorMessage) {
return; return;
} }
@@ -51,6 +51,7 @@ const banViolation = async (req, res, errorMessage) => {
const banLogs = getLogStores(ViolationTypes.BAN); const banLogs = getLogStores(ViolationTypes.BAN);
const duration = errorMessage.duration || banLogs.opts.ttl; const duration = errorMessage.duration || banLogs.opts.ttl;
if (duration <= 0) { if (duration <= 0) {
return; return;
} }

View File

@@ -1,28 +1,48 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const banViolation = require('./banViolation'); const banViolation = require('./banViolation');
// Mock deleteAllUserSessions since we're testing ban logic, not session deletion jest.mock('keyv');
jest.mock('~/models', () => ({ jest.mock('../models/Session');
...jest.requireActual('~/models'), // Mocking the getLogStores function
deleteAllUserSessions: jest.fn().mockResolvedValue(true), jest.mock('./getLogStores', () => {
})); return jest.fn().mockImplementation(() => {
const EventEmitter = require('events');
const { CacheKeys } = require('librechat-data-provider');
const math = require('../server/utils/math');
const mockGet = jest.fn();
const mockSet = jest.fn();
class KeyvMongo extends EventEmitter {
constructor(url = 'mongodb://127.0.0.1:27017', options) {
super();
this.ttlSupport = false;
url = url ?? {};
if (typeof url === 'string') {
url = { url };
}
if (url.uri) {
url = { url: url.uri, ...url };
}
this.opts = {
url,
collection: 'keyv',
...url,
...options,
};
}
get = mockGet;
set = mockSet;
}
return new KeyvMongo('', {
namespace: CacheKeys.BANS,
ttl: math(process.env.BAN_DURATION, 7200000),
});
});
});
describe('banViolation', () => { describe('banViolation', () => {
let mongoServer;
let req, res, errorMessage; let req, res, errorMessage;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(() => { beforeEach(() => {
req = { req = {
ip: '127.0.0.1', ip: '127.0.0.1',
@@ -35,7 +55,7 @@ describe('banViolation', () => {
}; };
errorMessage = { errorMessage = {
type: 'someViolation', type: 'someViolation',
user_id: new mongoose.Types.ObjectId().toString(), // Use valid ObjectId user_id: '12345',
prev_count: 0, prev_count: 0,
violation_count: 0, violation_count: 0,
}; };

View File

@@ -1,33 +0,0 @@
const fs = require('fs');
const { math, isEnabled } = require('@librechat/api');
// To ensure that different deployments do not interfere with each other's cache, we use a prefix for the Redis keys.
// This prefix is usually the deployment ID, which is often passed to the container or pod as an env var.
// Set REDIS_KEY_PREFIX_VAR to the env var that contains the deployment ID.
const REDIS_KEY_PREFIX_VAR = process.env.REDIS_KEY_PREFIX_VAR;
const REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX;
if (REDIS_KEY_PREFIX_VAR && REDIS_KEY_PREFIX) {
throw new Error('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
}
const USE_REDIS = isEnabled(process.env.USE_REDIS);
if (USE_REDIS && !process.env.REDIS_URI) {
throw new Error('USE_REDIS is enabled but REDIS_URI is not set.');
}
const cacheConfig = {
USE_REDIS,
REDIS_URI: process.env.REDIS_URI,
REDIS_USERNAME: process.env.REDIS_USERNAME,
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
CI: isEnabled(process.env.CI),
DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
BAN_DURATION: math(process.env.BAN_DURATION, 7200000), // 2 hours
};
module.exports = { cacheConfig };

View File

@@ -1,108 +0,0 @@
const fs = require('fs');
describe('cacheConfig', () => {
let originalEnv;
let originalReadFileSync;
beforeEach(() => {
originalEnv = { ...process.env };
originalReadFileSync = fs.readFileSync;
// Clear all related env vars first
delete process.env.REDIS_URI;
delete process.env.REDIS_CA;
delete process.env.REDIS_KEY_PREFIX_VAR;
delete process.env.REDIS_KEY_PREFIX;
delete process.env.USE_REDIS;
// Clear require cache
jest.resetModules();
});
afterEach(() => {
process.env = originalEnv;
fs.readFileSync = originalReadFileSync;
jest.resetModules();
});
describe('REDIS_KEY_PREFIX validation and resolution', () => {
test('should throw error when both REDIS_KEY_PREFIX_VAR and REDIS_KEY_PREFIX are set', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.REDIS_KEY_PREFIX = 'manual-prefix';
expect(() => {
require('./cacheConfig');
}).toThrow('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
});
test('should resolve REDIS_KEY_PREFIX from variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.DEPLOYMENT_ID = 'test-deployment-123';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('test-deployment-123');
});
test('should use direct REDIS_KEY_PREFIX value', () => {
process.env.REDIS_KEY_PREFIX = 'direct-prefix';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('direct-prefix');
});
test('should default to empty string when no prefix is configured', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle empty variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'EMPTY_VAR';
process.env.EMPTY_VAR = '';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle undefined variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'UNDEFINED_VAR';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
});
describe('USE_REDIS and REDIS_URI validation', () => {
test('should throw error when USE_REDIS is enabled but REDIS_URI is not set', () => {
process.env.USE_REDIS = 'true';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
test('should not throw error when USE_REDIS is enabled and REDIS_URI is set', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
expect(() => {
require('./cacheConfig');
}).not.toThrow();
});
test('should handle empty REDIS_URI when USE_REDIS is enabled', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = '';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
});
describe('REDIS_CA file reading', () => {
test('should be null when REDIS_CA is not set', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_CA).toBeNull();
});
});
});

View File

@@ -1,66 +0,0 @@
const KeyvRedis = require('@keyv/redis').default;
const { Keyv } = require('keyv');
const { cacheConfig } = require('./cacheConfig');
const { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } = require('./redisClients');
const { Time } = require('librechat-data-provider');
const { RedisStore: ConnectRedis } = require('connect-redis');
const MemoryStore = require('memorystore')(require('express-session'));
const { violationFile } = require('./keyvFiles');
const { RedisStore } = require('rate-limit-redis');
/**
* Creates a cache instance using Redis or a fallback store. Suitable for general caching needs.
* @param {string} namespace - The cache namespace.
* @param {number} [ttl] - Time to live for cache entries.
* @param {object} [fallbackStore] - Optional fallback store if Redis is not used.
* @returns {Keyv} Cache instance.
*/
const standardCache = (namespace, ttl = undefined, fallbackStore = undefined) => {
if (cacheConfig.USE_REDIS) {
const keyvRedis = new KeyvRedis(keyvRedisClient);
const cache = new Keyv(keyvRedis, { namespace, ttl });
keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
return cache;
}
if (fallbackStore) return new Keyv({ store: fallbackStore, namespace, ttl });
return new Keyv({ namespace, ttl });
};
/**
* Creates a cache instance for storing violation data.
* Uses a file-based fallback store if Redis is not enabled.
* @param {string} namespace - The cache namespace for violations.
* @param {number} [ttl] - Time to live for cache entries.
* @returns {Keyv} Cache instance for violations.
*/
const violationCache = (namespace, ttl = undefined) => {
return standardCache(`violations:${namespace}`, ttl, violationFile);
};
/**
* Creates a session cache instance using Redis or in-memory store.
* @param {string} namespace - The session namespace.
* @param {number} [ttl] - Time to live for session entries.
* @returns {MemoryStore | ConnectRedis} Session store instance.
*/
const sessionCache = (namespace, ttl = undefined) => {
namespace = namespace.endsWith(':') ? namespace : `${namespace}:`;
if (!cacheConfig.USE_REDIS) return new MemoryStore({ ttl, checkPeriod: Time.ONE_DAY });
return new ConnectRedis({ client: ioredisClient, ttl, prefix: namespace });
};
/**
* Creates a rate limiter cache using Redis.
* @param {string} prefix - The key prefix for rate limiting.
* @returns {RedisStore|undefined} RedisStore instance or undefined if Redis is not used.
*/
const limiterCache = (prefix) => {
if (!prefix) throw new Error('prefix is required');
if (!cacheConfig.USE_REDIS) return undefined;
prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;
return new RedisStore({ sendCommand, prefix });
};
const sendCommand = (...args) => ioredisClient?.call(...args);
module.exports = { standardCache, sessionCache, violationCache, limiterCache };
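Usage sketch for the factory (the cache names below are illustrative; the key constants mirror those used elsewhere in this diff):

const { CacheKeys, Time } = require('librechat-data-provider');
const { standardCache, violationCache, limiterCache } = require('./cacheFactory');

// Redis-backed when USE_REDIS is enabled, plain in-memory Keyv otherwise.
const config = standardCache(CacheKeys.CONFIG_STORE, Time.THIRTY_MINUTES);

// Namespaced as 'violations:logins', with a file-backed fallback store.
const loginViolations = violationCache('logins', Time.TEN_MINUTES);

// undefined without Redis, which tells express-rate-limit to use its memory store.
const loginLimiter = limiterCache('login_limit:');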

View File

@@ -1,270 +0,0 @@
const { Time } = require('librechat-data-provider');
// Mock dependencies first
const mockKeyvRedis = {
namespace: '',
keyPrefixSeparator: '',
};
const mockKeyv = jest.fn().mockReturnValue({ mock: 'keyv' });
const mockConnectRedis = jest.fn().mockReturnValue({ mock: 'connectRedis' });
const mockMemoryStore = jest.fn().mockReturnValue({ mock: 'memoryStore' });
const mockRedisStore = jest.fn().mockReturnValue({ mock: 'redisStore' });
const mockIoredisClient = {
call: jest.fn(),
};
const mockKeyvRedisClient = {};
const mockViolationFile = {};
// Mock modules before requiring the main module
jest.mock('@keyv/redis', () => ({
default: jest.fn().mockImplementation(() => mockKeyvRedis),
}));
jest.mock('keyv', () => ({
Keyv: mockKeyv,
}));
jest.mock('./cacheConfig', () => ({
cacheConfig: {
USE_REDIS: false,
REDIS_KEY_PREFIX: 'test',
},
}));
jest.mock('./redisClients', () => ({
keyvRedisClient: mockKeyvRedisClient,
ioredisClient: mockIoredisClient,
GLOBAL_PREFIX_SEPARATOR: '::',
}));
jest.mock('./keyvFiles', () => ({
violationFile: mockViolationFile,
}));
jest.mock('connect-redis', () => ({ RedisStore: mockConnectRedis }));
jest.mock('memorystore', () => jest.fn(() => mockMemoryStore));
jest.mock('rate-limit-redis', () => ({
RedisStore: mockRedisStore,
}));
// Import after mocking
const { standardCache, sessionCache, violationCache, limiterCache } = require('./cacheFactory');
const { cacheConfig } = require('./cacheConfig');
describe('cacheFactory', () => {
beforeEach(() => {
jest.clearAllMocks();
// Reset cache config mock
cacheConfig.USE_REDIS = false;
cacheConfig.REDIS_KEY_PREFIX = 'test';
});
describe('standardCache', () => {
it('should create Redis cache when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
expect(mockKeyvRedis.namespace).toBe(cacheConfig.REDIS_KEY_PREFIX);
expect(mockKeyvRedis.keyPrefixSeparator).toBe('::');
});
it('should create Redis cache with undefined ttl when not provided', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
standardCache(namespace);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl: undefined });
});
it('should use fallback store when USE_REDIS is false and fallbackStore is provided', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
const fallbackStore = { some: 'store' };
standardCache(namespace, ttl, fallbackStore);
expect(mockKeyv).toHaveBeenCalledWith({ store: fallbackStore, namespace, ttl });
});
it('should create default Keyv instance when USE_REDIS is false and no fallbackStore', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(mockKeyv).toHaveBeenCalledWith({ namespace, ttl });
});
it('should handle namespace and ttl as undefined', () => {
cacheConfig.USE_REDIS = false;
standardCache();
expect(mockKeyv).toHaveBeenCalledWith({ namespace: undefined, ttl: undefined });
});
});
describe('violationCache', () => {
it('should create violation cache with prefixed namespace', () => {
const namespace = 'test-violations';
const ttl = 7200;
// We can't easily mock the internal standardCache call since it's in the same module,
// but we can test that the function executes without throwing
expect(() => violationCache(namespace, ttl)).not.toThrow();
});
it('should create violation cache with undefined ttl', () => {
const namespace = 'test-violations';
violationCache(namespace);
// The function should call standardCache with a 'violations:'-prefixed namespace.
// Since we can't easily mock that internal call, we test the observable behavior
expect(() => violationCache(namespace)).not.toThrow();
});
it('should handle undefined namespace', () => {
expect(() => violationCache(undefined)).not.toThrow();
});
});
describe('sessionCache', () => {
it('should return MemoryStore when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockMemoryStore).toHaveBeenCalledWith({ ttl, checkPeriod: Time.ONE_DAY });
expect(result).toBe(mockMemoryStore());
});
it('should return ConnectRedis when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl,
prefix: `${namespace}:`,
});
expect(result).toBe(mockConnectRedis());
});
it('should add colon to namespace if not present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should not add colon to namespace if already present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions:';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should handle undefined ttl', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockMemoryStore).toHaveBeenCalledWith({
ttl: undefined,
checkPeriod: Time.ONE_DAY,
});
});
});
describe('limiterCache', () => {
it('should return undefined when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const result = limiterCache('prefix');
expect(result).toBeUndefined();
});
it('should return RedisStore when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const result = limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: `rate-limit:`,
});
expect(result).toBe(mockRedisStore());
});
it('should add colon to prefix if not present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should not add colon to prefix if already present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit:');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should pass sendCommand function that calls ioredisClient.call', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
const sendCommandCall = mockRedisStore.mock.calls[0][0];
const sendCommand = sendCommandCall.sendCommand;
// Test that sendCommand properly delegates to ioredisClient.call
const args = ['GET', 'test-key'];
sendCommand(...args);
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
});
it('should throw when prefix is not provided', () => {
cacheConfig.USE_REDIS = true;
expect(() => limiterCache()).toThrow('prefix is required');
});
});
});
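A sketch of the contract these last tests pin down, assuming the express-rate-limit package: when USE_REDIS is false the factory returns undefined, so the limiter silently falls back to its built-in memory store; otherwise it returns a RedisStore whose sendCommand delegates to ioredisClient.call.

const rateLimit = require('express-rate-limit');
const { limiterCache } = require('./cacheFactory');

const loginLimiter = rateLimit({
  windowMs: 5 * 60 * 1000,
  max: 7,
  // RedisStore backed by ioredisClient, or undefined (in-memory fallback)
  store: limiterCache('login_limiter'),
});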

View File

@@ -1,52 +1,103 @@
-const { cacheConfig } = require('./cacheConfig');
 const { Keyv } = require('keyv');
 const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
-const { logFile } = require('./keyvFiles');
+const { logFile, violationFile } = require('./keyvFiles');
+const { math, isEnabled } = require('~/server/utils');
+const keyvRedis = require('./keyvRedis');
 const keyvMongo = require('./keyvMongo');
-const { standardCache, sessionCache, violationCache } = require('./cacheFactory');
+
+const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE, CI } = process.env ?? {};
+
+const duration = math(BAN_DURATION, 7200000);
+const isRedisEnabled = isEnabled(USE_REDIS);
+const debugMemoryCache = isEnabled(DEBUG_MEMORY_CACHE);
+
+const createViolationInstance = (namespace) => {
+  const config = isRedisEnabled ? { store: keyvRedis } : { store: violationFile, namespace };
+  return new Keyv(config);
+};
+
+// Serve cache from memory so no need to clear it on startup/exit
+const pending_req = isRedisEnabled
+  ? new Keyv({ store: keyvRedis })
+  : new Keyv({ namespace: CacheKeys.PENDING_REQ });
+
+const config = isRedisEnabled
+  ? new Keyv({ store: keyvRedis })
+  : new Keyv({ namespace: CacheKeys.CONFIG_STORE });
+
+const roles = isRedisEnabled
+  ? new Keyv({ store: keyvRedis })
+  : new Keyv({ namespace: CacheKeys.ROLES });
+
+const audioRuns = isRedisEnabled
+  ? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
+  : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
+
+const messages = isRedisEnabled
+  ? new Keyv({ store: keyvRedis, ttl: Time.ONE_MINUTE })
+  : new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.ONE_MINUTE });
+
+const flows = isRedisEnabled
+  ? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
+  : new Keyv({ namespace: CacheKeys.FLOWS, ttl: Time.ONE_MINUTE * 3 });
+
+const tokenConfig = isRedisEnabled
+  ? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
+  : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
+
+const genTitle = isRedisEnabled
+  ? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
+  : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
+
+const s3ExpiryInterval = isRedisEnabled
+  ? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
+  : new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });
+
+const modelQueries = isEnabled(process.env.USE_REDIS)
+  ? new Keyv({ store: keyvRedis })
+  : new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
+
+const abortKeys = isRedisEnabled
+  ? new Keyv({ store: keyvRedis })
+  : new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
+
 const namespaces = {
-  [ViolationTypes.GENERAL]: new Keyv({ store: logFile, namespace: 'violations' }),
-  [ViolationTypes.LOGINS]: violationCache(ViolationTypes.LOGINS),
-  [ViolationTypes.CONCURRENT]: violationCache(ViolationTypes.CONCURRENT),
-  [ViolationTypes.NON_BROWSER]: violationCache(ViolationTypes.NON_BROWSER),
-  [ViolationTypes.MESSAGE_LIMIT]: violationCache(ViolationTypes.MESSAGE_LIMIT),
-  [ViolationTypes.REGISTRATIONS]: violationCache(ViolationTypes.REGISTRATIONS),
-  [ViolationTypes.TOKEN_BALANCE]: violationCache(ViolationTypes.TOKEN_BALANCE),
-  [ViolationTypes.TTS_LIMIT]: violationCache(ViolationTypes.TTS_LIMIT),
-  [ViolationTypes.STT_LIMIT]: violationCache(ViolationTypes.STT_LIMIT),
-  [ViolationTypes.CONVO_ACCESS]: violationCache(ViolationTypes.CONVO_ACCESS),
-  [ViolationTypes.TOOL_CALL_LIMIT]: violationCache(ViolationTypes.TOOL_CALL_LIMIT),
-  [ViolationTypes.FILE_UPLOAD_LIMIT]: violationCache(ViolationTypes.FILE_UPLOAD_LIMIT),
-  [ViolationTypes.VERIFY_EMAIL_LIMIT]: violationCache(ViolationTypes.VERIFY_EMAIL_LIMIT),
-  [ViolationTypes.RESET_PASSWORD_LIMIT]: violationCache(ViolationTypes.RESET_PASSWORD_LIMIT),
-  [ViolationTypes.ILLEGAL_MODEL_REQUEST]: violationCache(ViolationTypes.ILLEGAL_MODEL_REQUEST),
-  [ViolationTypes.BAN]: new Keyv({
-    store: keyvMongo,
-    namespace: CacheKeys.BANS,
-    ttl: cacheConfig.BAN_DURATION,
-  }),
-  [CacheKeys.OPENID_SESSION]: sessionCache(CacheKeys.OPENID_SESSION),
-  [CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
-  [CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
-  [CacheKeys.MCP_TOOLS]: standardCache(CacheKeys.MCP_TOOLS),
-  [CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
-  [CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
-  [CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
-  [CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),
-  [CacheKeys.TOKEN_CONFIG]: standardCache(CacheKeys.TOKEN_CONFIG, Time.THIRTY_MINUTES),
-  [CacheKeys.GEN_TITLE]: standardCache(CacheKeys.GEN_TITLE, Time.TWO_MINUTES),
-  [CacheKeys.S3_EXPIRY_INTERVAL]: standardCache(CacheKeys.S3_EXPIRY_INTERVAL, Time.THIRTY_MINUTES),
-  [CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES),
-  [CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES),
-  [CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE),
-  [CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 3),
-  [CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache(
-    CacheKeys.OPENID_EXCHANGED_TOKENS,
-    Time.TEN_MINUTES,
-  ),
+  [CacheKeys.ROLES]: roles,
+  [CacheKeys.CONFIG_STORE]: config,
+  [CacheKeys.PENDING_REQ]: pending_req,
+  [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
+  [CacheKeys.ENCODED_DOMAINS]: new Keyv({
+    store: keyvMongo,
+    namespace: CacheKeys.ENCODED_DOMAINS,
+    ttl: 0,
+  }),
+  general: new Keyv({ store: logFile, namespace: 'violations' }),
+  concurrent: createViolationInstance('concurrent'),
+  non_browser: createViolationInstance('non_browser'),
+  message_limit: createViolationInstance('message_limit'),
+  token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
+  registrations: createViolationInstance('registrations'),
+  [ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
+  [ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
+  [ViolationTypes.CONVO_ACCESS]: createViolationInstance(ViolationTypes.CONVO_ACCESS),
+  [ViolationTypes.TOOL_CALL_LIMIT]: createViolationInstance(ViolationTypes.TOOL_CALL_LIMIT),
+  [ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
+  [ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
+  [ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
+    ViolationTypes.RESET_PASSWORD_LIMIT,
+  ),
+  [ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
+    ViolationTypes.ILLEGAL_MODEL_REQUEST,
+  ),
+  logins: createViolationInstance('logins'),
+  [CacheKeys.ABORT_KEYS]: abortKeys,
+  [CacheKeys.TOKEN_CONFIG]: tokenConfig,
+  [CacheKeys.GEN_TITLE]: genTitle,
+  [CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
+  [CacheKeys.MODEL_QUERIES]: modelQueries,
+  [CacheKeys.AUDIO_RUNS]: audioRuns,
+  [CacheKeys.MESSAGES]: messages,
+  [CacheKeys.FLOWS]: flows,
 };
 /**
@@ -55,10 +106,7 @@ const namespaces = {
  */
 function getTTLStores() {
   return Object.values(namespaces).filter(
-    (store) =>
-      store instanceof Keyv &&
-      parseInt(store.opts?.ttl ?? '0') > 0 &&
-      !store.opts?.store?.constructor?.name?.includes('Redis'), // Only include non-Redis stores
+    (store) => store instanceof Keyv && typeof store.opts?.ttl === 'number' && store.opts.ttl > 0,
   );
 }
@@ -94,18 +142,18 @@ async function clearExpiredFromCache(cache) {
       if (data?.expires && data.expires <= expiryTime) {
         const deleted = await cache.opts.store.delete(key);
         if (!deleted) {
-          cacheConfig.DEBUG_MEMORY_CACHE &&
+          debugMemoryCache &&
             console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
           continue;
         }
         cleared++;
       }
     } catch (error) {
-      cacheConfig.DEBUG_MEMORY_CACHE &&
+      debugMemoryCache &&
         console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
       const deleted = await cache.opts.store.delete(key);
       if (!deleted) {
-        cacheConfig.DEBUG_MEMORY_CACHE &&
+        debugMemoryCache &&
          console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
        continue;
      }
@@ -114,7 +162,7 @@ async function clearExpiredFromCache(cache) {
   }
   if (cleared > 0) {
-    cacheConfig.DEBUG_MEMORY_CACHE &&
+    debugMemoryCache &&
       console.log(
         `[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`,
       );
@@ -155,7 +203,7 @@ async function clearAllExpiredFromCache() {
   }
 }
-if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
+if (!isRedisEnabled && !isEnabled(CI)) {
   /** @type {Set<NodeJS.Timeout>} */
   const cleanupIntervals = new Set();
@@ -166,7 +214,7 @@ if (!isRedisEnabled && !isEnabled(CI)) {
   cleanupIntervals.add(cleanup);
-  if (cacheConfig.DEBUG_MEMORY_CACHE) {
+  if (debugMemoryCache) {
     const monitor = setInterval(() => {
       const ttlStores = getTTLStores();
       const memory = process.memoryUsage();
@@ -187,13 +235,13 @@ if (!isRedisEnabled && !isEnabled(CI)) {
   }
   const dispose = () => {
-    cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Cleaning up and shutting down...');
+    debugMemoryCache && console.log('[Cache] Cleaning up and shutting down...');
     cleanupIntervals.forEach((interval) => clearInterval(interval));
     cleanupIntervals.clear();
     // One final cleanup before exit
     clearAllExpiredFromCache().then(() => {
-      cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Final cleanup completed');
+      debugMemoryCache && console.log('[Cache] Final cleanup completed');
       process.exit(0);
     });
   };
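Both sides of this diff are consumed through the same lookup; a rough usage sketch, assuming getLogStores simply returns namespaces[key]:

const { CacheKeys } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');

(async () => {
  const flowsCache = getLogStores(CacheKeys.FLOWS);
  await flowsCache.set('my-flow', { status: 'PENDING' }); // expires per the FLOWS ttl
})();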

api/cache/ioredisClient.js vendored Normal file
View File

@@ -0,0 +1,92 @@
const fs = require('fs');
const Redis = require('ioredis');
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_MAX_LISTENERS } = process.env;
/** @type {import('ioredis').Redis | import('ioredis').Cluster} */
let ioredisClient;
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
const value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
ioredisClient = new Redis.Cluster(hosts, { redisOptions });
} else {
ioredisClient = new Redis(REDIS_URI, redisOptions);
}
ioredisClient.on('ready', () => {
logger.info('IoRedis connection ready');
});
ioredisClient.on('reconnecting', () => {
logger.info('IoRedis connection reconnecting');
});
ioredisClient.on('end', () => {
logger.info('IoRedis connection ended');
});
ioredisClient.on('close', () => {
logger.info('IoRedis connection closed');
});
ioredisClient.on('error', (err) => logger.error('IoRedis connection error:', err));
ioredisClient.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] IoRedis initialized for rate limiters. If you have issues, disable Redis or restart the server.',
);
} else {
logger.info('[Optional] IoRedis not initialized for rate limiters.');
}
module.exports = ioredisClient;
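mapURI is internal to this file; purely for illustration, the regex above maps common inputs as follows (outputs shown in comments):

mapURI('redis://user:pass@127.0.0.1:6379');
// => { scheme: 'redis', user: 'user', password: 'pass', host: '127.0.0.1', port: '6379' }
mapURI('10.0.0.1:7000');
// => { scheme: 'none', user: null, password: null, host: '10.0.0.1', port: '7000' }
mapURI('redis-node');
// => { scheme: 'none', user: null, password: null, host: 'redis-node', port: null }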

api/cache/keyvRedis.js vendored Normal file
View File

@@ -0,0 +1,106 @@
const fs = require('fs');
const ioredis = require('ioredis');
const KeyvRedis = require('@keyv/redis').default;
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } =
process.env;
let keyvRedis;
const redis_prefix = REDIS_KEY_PREFIX || '';
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
/** @type {import('@keyv/redis').KeyvRedisOptions} */
let keyvOpts = {
useRedisSets: false,
keyPrefix: redis_prefix,
};
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
const value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
const cluster = new ioredis.Cluster(hosts, { redisOptions });
keyvRedis = new KeyvRedis(cluster, keyvOpts);
} else {
keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts);
}
const pingInterval = setInterval(() => {
logger.debug('KeyvRedis ping');
keyvRedis.client.ping().catch(err => logger.error('Redis keep-alive ping failed:', err));
}, 5 * 60 * 1000);
keyvRedis.on('ready', () => {
logger.info('KeyvRedis connection ready');
});
keyvRedis.on('reconnecting', () => {
logger.info('KeyvRedis connection reconnecting');
});
keyvRedis.on('end', () => {
logger.info('KeyvRedis connection ended');
});
keyvRedis.on('close', () => {
clearInterval(pingInterval);
logger.info('KeyvRedis connection closed');
});
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
keyvRedis.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.',
);
} else {
logger.info('[Optional] Redis not initialized.');
}
module.exports = keyvRedis;
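A short usage sketch for the exported store (the namespace is hypothetical; this mirrors how getLogStores builds its Keyv instances):

const { Keyv } = require('keyv');
const keyvRedis = require('./keyvRedis');

// Fall back to Keyv's default in-memory Map when Redis is not initialized
const cache = keyvRedis
  ? new Keyv({ store: keyvRedis, namespace: 'example' })
  : new Keyv({ namespace: 'example' });

(async () => {
  await cache.set('greeting', 'hello', 60 * 1000); // one-minute TTL
  console.log(await cache.get('greeting')); // 'hello'
})();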

View File

@@ -1,5 +1,4 @@
 const { isEnabled } = require('~/server/utils');
-const { ViolationTypes } = require('librechat-data-provider');
 const getLogStores = require('./getLogStores');
 const banViolation = require('./banViolation');
@@ -10,14 +9,14 @@ const banViolation = require('./banViolation');
  * @param {Object} res - Express response object.
  * @param {string} type - The type of violation.
  * @param {Object} errorMessage - The error message to log.
- * @param {number | string} [score=1] - The severity of the violation. Defaults to 1
+ * @param {number} [score=1] - The severity of the violation. Defaults to 1
  */
 const logViolation = async (req, res, type, errorMessage, score = 1) => {
   const userId = req.user?.id ?? req.user?._id;
   if (!userId) {
     return;
   }
-  const logs = getLogStores(ViolationTypes.GENERAL);
+  const logs = getLogStores('general');
   const violationLogs = getLogStores(type);
   const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;

View File

@@ -1,57 +0,0 @@
const IoRedis = require('ioredis');
const { cacheConfig } = require('./cacheConfig');
const { createClient, createCluster } = require('@keyv/redis');
const GLOBAL_PREFIX_SEPARATOR = '::';
const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri));
const username = urls?.[0].username || cacheConfig.REDIS_USERNAME;
const password = urls?.[0].password || cacheConfig.REDIS_PASSWORD;
const ca = cacheConfig.REDIS_CA;
/** @type {import('ioredis').Redis | import('ioredis').Cluster | null} */
let ioredisClient = null;
if (cacheConfig.USE_REDIS) {
const redisOptions = {
username: username,
password: password,
tls: ca ? { ca } : undefined,
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
};
ioredisClient =
urls.length === 1
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
: new IoRedis.Cluster(cacheConfig.REDIS_URI, { redisOptions });
// Pinging the Redis server every 5 minutes to keep the connection alive
const pingInterval = setInterval(() => ioredisClient.ping(), 5 * 60 * 1000);
ioredisClient.on('close', () => clearInterval(pingInterval));
ioredisClient.on('end', () => clearInterval(pingInterval));
}
/** @type {import('@keyv/redis').RedisClient | import('@keyv/redis').RedisCluster | null} */
let keyvRedisClient = null;
if (cacheConfig.USE_REDIS) {
// ** WARNING ** Keyv Redis client does not support Prefix like ioredis above.
// The prefix feature will be handled by the Keyv-Redis store in cacheFactory.js
const redisOptions = { username, password, socket: { tls: ca != null, ca } };
keyvRedisClient =
urls.length === 1
? createClient({ url: cacheConfig.REDIS_URI, ...redisOptions })
: createCluster({
rootNodes: cacheConfig.REDIS_URI.split(',').map((url) => ({ url })),
defaults: redisOptions,
});
keyvRedisClient.setMaxListeners(cacheConfig.REDIS_MAX_LISTENERS);
// Pinging the Redis server every 5 minutes to keep the connection alive
const keyvPingInterval = setInterval(() => keyvRedisClient.ping(), 5 * 60 * 1000);
keyvRedisClient.on('disconnect', () => clearInterval(keyvPingInterval));
keyvRedisClient.on('end', () => clearInterval(keyvPingInterval));
}
module.exports = { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR };
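One consequence of the keyPrefix configured above, sketched for illustration (assuming REDIS_KEY_PREFIX=LibreChat): ioredis applies the prefix and the '::' separator to every command transparently.

const { ioredisClient } = require('./redisClients');

(async () => {
  // Stored under the key 'LibreChat::ping' rather than 'ping'
  await ioredisClient?.set('ping', 'pong');
})();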

View File

@@ -1,6 +1,7 @@
+const axios = require('axios');
 const { EventSource } = require('eventsource');
-const { Time } = require('librechat-data-provider');
-const { MCPManager, FlowStateManager } = require('@librechat/api');
+const { Time, CacheKeys } = require('librechat-data-provider');
+const { MCPManager, FlowStateManager } = require('librechat-mcp');
 const logger = require('./winston');
 global.EventSource = EventSource;
@@ -15,7 +16,7 @@ let flowManager = null;
  */
 function getMCPManager(userId) {
   if (!mcpManager) {
-    mcpManager = MCPManager.getInstance();
+    mcpManager = MCPManager.getInstance(logger);
   } else {
     mcpManager.checkIdleConnections(userId);
   }
@@ -30,13 +31,66 @@ function getFlowStateManager(flowsCache) {
   if (!flowManager) {
     flowManager = new FlowStateManager(flowsCache, {
       ttl: Time.ONE_MINUTE * 3,
+      logger,
     });
   }
   return flowManager;
 }
/**
* Sends message data in Server Sent Events format.
* @param {ServerResponse} res - The server response.
* @param {{ data: string | Record<string, unknown>, event?: string }} event - The message event.
* @param {string} event.event - The type of event.
* @param {string} event.data - The message to be sent.
*/
const sendEvent = (res, event) => {
if (typeof event.data === 'string' && event.data.length === 0) {
return;
}
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
};
/**
* Creates and configures an Axios instance with optional proxy settings.
*
* @typedef {import('axios').AxiosInstance} AxiosInstance
* @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig
*
* @returns {AxiosInstance} A configured Axios instance
* @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL
*/
function createAxiosInstance() {
const instance = axios.create();
if (process.env.proxy) {
try {
const url = new URL(process.env.proxy);
/** @type {AxiosProxyConfig} */
const proxyConfig = {
host: url.hostname.replace(/^\[|\]$/g, ''),
protocol: url.protocol.replace(':', ''),
};
if (url.port) {
proxyConfig.port = parseInt(url.port, 10);
}
instance.defaults.proxy = proxyConfig;
} catch (error) {
console.error('Error parsing proxy URL:', error);
throw new Error(`Invalid proxy URL: ${process.env.proxy}`);
}
}
return instance;
}
 module.exports = {
   logger,
+  sendEvent,
   getMCPManager,
+  createAxiosInstance,
   getFlowStateManager,
 };
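A minimal sketch of the added sendEvent helper in use (the require path assumes this module's `~/config` alias); the res.write above emits a standard Server-Sent Events frame:

const http = require('http');
const { sendEvent } = require('~/config');

http
  .createServer((req, res) => {
    res.writeHead(200, { 'Content-Type': 'text/event-stream' });
    // Writes: event: message\ndata: {"event":"greeting","data":"hello"}\n\n
    sendEvent(res, { event: 'greeting', data: 'hello' });
    res.end();
  })
  .listen(3081);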

View File

@@ -1,6 +1,7 @@
-import axios from 'axios';
-import { createAxiosInstance } from './axios';
+const axios = require('axios');
+const { createAxiosInstance } = require('./index');
+// Mock axios
 jest.mock('axios', () => ({
   interceptors: {
     request: { use: jest.fn(), eject: jest.fn() },
@@ -19,13 +20,7 @@ jest.mock('axios', () => ({
   post: jest.fn().mockResolvedValue({ data: {} }),
   put: jest.fn().mockResolvedValue({ data: {} }),
   delete: jest.fn().mockResolvedValue({ data: {} }),
-  reset: jest.fn().mockImplementation(function (this: {
-    get: jest.Mock;
-    post: jest.Mock;
-    put: jest.Mock;
-    delete: jest.Mock;
-    create: jest.Mock;
-  }) {
+  reset: jest.fn().mockImplementation(function () {
     this.get.mockClear();
     this.post.mockClear();
     this.put.mockClear();

View File

@@ -1,8 +0,0 @@
const mongoose = require('mongoose');
const { createModels } = require('@librechat/data-schemas');
const { connectDb } = require('./connect');
const indexSync = require('./indexSync');
createModels(mongoose);
module.exports = { connectDb, indexSync };

View File

@@ -1,174 +0,0 @@
const mongoose = require('mongoose');
const { MeiliSearch } = require('meilisearch');
const { logger } = require('@librechat/data-schemas');
const { FlowStateManager } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const { isEnabled } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const Conversation = mongoose.models.Conversation;
const Message = mongoose.models.Message;
const searchEnabled = isEnabled(process.env.SEARCH);
const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC);
let currentTimeout = null;
class MeiliSearchClient {
static instance = null;
static getInstance() {
if (!MeiliSearchClient.instance) {
if (!process.env.MEILI_HOST || !process.env.MEILI_MASTER_KEY) {
throw new Error('Meilisearch configuration is missing.');
}
MeiliSearchClient.instance = new MeiliSearch({
host: process.env.MEILI_HOST,
apiKey: process.env.MEILI_MASTER_KEY,
});
}
return MeiliSearchClient.instance;
}
}
/**
* Performs the actual sync operations for messages and conversations
*/
async function performSync() {
const client = MeiliSearchClient.getInstance();
const { status } = await client.health();
if (status !== 'available') {
throw new Error('Meilisearch not available');
}
if (indexingDisabled === true) {
logger.info('[indexSync] Indexing is disabled, skipping...');
return { messagesSync: false, convosSync: false };
}
let messagesSync = false;
let convosSync = false;
// Check if we need to sync messages
const messageProgress = await Message.getSyncProgress();
if (!messageProgress.isComplete) {
logger.info(
`[indexSync] Messages need syncing: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments} indexed`,
);
// Check if we should do a full sync or incremental
const messageCount = await Message.countDocuments();
const messagesIndexed = messageProgress.totalProcessed;
const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10);
if (messageCount - messagesIndexed > syncThreshold) {
logger.info('[indexSync] Starting full message sync due to large difference');
await Message.syncWithMeili();
messagesSync = true;
} else if (messageCount !== messagesIndexed) {
logger.warn('[indexSync] Messages out of sync, performing incremental sync');
await Message.syncWithMeili();
messagesSync = true;
}
} else {
logger.info(
`[indexSync] Messages are fully synced: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments}`,
);
}
// Check if we need to sync conversations
const convoProgress = await Conversation.getSyncProgress();
if (!convoProgress.isComplete) {
logger.info(
`[indexSync] Conversations need syncing: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments} indexed`,
);
const convoCount = await Conversation.countDocuments();
const convosIndexed = convoProgress.totalProcessed;
const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10);
if (convoCount - convosIndexed > syncThreshold) {
logger.info('[indexSync] Starting full conversation sync due to large difference');
await Conversation.syncWithMeili();
convosSync = true;
} else if (convoCount !== convosIndexed) {
logger.warn('[indexSync] Convos out of sync, performing incremental sync');
await Conversation.syncWithMeili();
convosSync = true;
}
} else {
logger.info(
`[indexSync] Conversations are fully synced: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments}`,
);
}
return { messagesSync, convosSync };
}
/**
* Main index sync function that uses FlowStateManager to prevent concurrent execution
*/
async function indexSync() {
if (!searchEnabled) {
return;
}
logger.info('[indexSync] Starting index synchronization check...');
try {
// Get or create FlowStateManager instance
const flowsCache = getLogStores(CacheKeys.FLOWS);
if (!flowsCache) {
logger.warn('[indexSync] Flows cache not available, falling back to direct sync');
return await performSync();
}
const flowManager = new FlowStateManager(flowsCache, {
ttl: 60000 * 10, // 10 minutes TTL for sync operations
});
// Use a unique flow ID for the sync operation
const flowId = 'meili-index-sync';
const flowType = 'MEILI_SYNC';
// This will only execute the handler if no other instance is running the sync
const result = await flowManager.createFlowWithHandler(flowId, flowType, performSync);
if (result.messagesSync || result.convosSync) {
logger.info('[indexSync] Sync completed successfully');
} else {
logger.debug('[indexSync] No sync was needed');
}
return result;
} catch (err) {
if (err.message.includes('flow already exists')) {
logger.info('[indexSync] Sync already running on another instance');
return;
}
if (err.message.includes('not found')) {
logger.debug('[indexSync] Creating indices...');
currentTimeout = setTimeout(async () => {
try {
await Message.syncWithMeili();
await Conversation.syncWithMeili();
} catch (err) {
logger.error('[indexSync] Trouble creating indices, try restarting the server.', err);
}
}, 750);
} else if (err.message.includes('Meilisearch not configured')) {
logger.info('[indexSync] Meilisearch not configured, search will be disabled.');
} else {
logger.error('[indexSync] error', err);
}
}
}
process.on('exit', () => {
logger.debug('[indexSync] Clearing sync timeouts before exiting...');
clearTimeout(currentTimeout);
});
module.exports = indexSync;
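The concurrency guard in this (removed) file reduces to the pattern below, a sketch assuming the FlowStateManager API used above:

async function syncOnce(flowsCache) {
  const flowManager = new FlowStateManager(flowsCache, { ttl: 60000 * 10 });
  // Only one instance across the deployment executes performSync at a time;
  // concurrent callers hit the 'flow already exists' path handled above.
  return flowManager.createFlowWithHandler('meili-index-sync', 'MEILI_SYNC', performSync);
}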

View File

@@ -1,5 +0,0 @@
const mongoose = require('mongoose');
const { createModels } = require('@librechat/data-schemas');
const models = createModels(mongoose);
module.exports = { ...models };

View File

@@ -11,8 +11,5 @@ module.exports = {
   moduleNameMapper: {
     '~/(.*)': '<rootDir>/$1',
     '~/data/auth.json': '<rootDir>/__mocks__/auth.mock.json',
-    '^openid-client/passport$': '<rootDir>/test/__mocks__/openid-client-passport.js', // Mock for the passport strategy part
-    '^openid-client$': '<rootDir>/test/__mocks__/openid-client.js',
   },
-  transformIgnorePatterns: ['/node_modules/(?!(openid-client|oauth4webapi|jose)/).*/'],
 };

View File

@@ -39,10 +39,7 @@ async function connectDb() {
     });
   }
   cached.conn = await cached.promise;
   return cached.conn;
 }
-module.exports = {
-  connectDb,
-};
+module.exports = connectDb;

api/lib/db/index.js Normal file
View File

@@ -0,0 +1,4 @@
const connectDb = require('./connectDb');
const indexSync = require('./indexSync');
module.exports = { connectDb, indexSync };

api/lib/db/indexSync.js Normal file
View File

@@ -0,0 +1,89 @@
const { MeiliSearch } = require('meilisearch');
const { Conversation } = require('~/models/Conversation');
const { Message } = require('~/models/Message');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const searchEnabled = isEnabled(process.env.SEARCH);
const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC);
let currentTimeout = null;
class MeiliSearchClient {
static instance = null;
static getInstance() {
if (!MeiliSearchClient.instance) {
if (!process.env.MEILI_HOST || !process.env.MEILI_MASTER_KEY) {
throw new Error('Meilisearch configuration is missing.');
}
MeiliSearchClient.instance = new MeiliSearch({
host: process.env.MEILI_HOST,
apiKey: process.env.MEILI_MASTER_KEY,
});
}
return MeiliSearchClient.instance;
}
}
async function indexSync() {
if (!searchEnabled) {
return;
}
try {
const client = MeiliSearchClient.getInstance();
const { status } = await client.health();
if (status !== 'available') {
throw new Error('Meilisearch not available');
}
if (indexingDisabled === true) {
logger.info('[indexSync] Indexing is disabled, skipping...');
return;
}
const messageCount = await Message.countDocuments();
const convoCount = await Conversation.countDocuments();
const messages = await client.index('messages').getStats();
const convos = await client.index('convos').getStats();
const messagesIndexed = messages.numberOfDocuments;
const convosIndexed = convos.numberOfDocuments;
logger.debug(`[indexSync] There are ${messageCount} messages and ${messagesIndexed} indexed`);
logger.debug(`[indexSync] There are ${convoCount} convos and ${convosIndexed} indexed`);
if (messageCount !== messagesIndexed) {
logger.debug('[indexSync] Messages out of sync, indexing');
Message.syncWithMeili();
}
if (convoCount !== convosIndexed) {
logger.debug('[indexSync] Convos out of sync, indexing');
Conversation.syncWithMeili();
}
} catch (err) {
if (err.message.includes('not found')) {
logger.debug('[indexSync] Creating indices...');
currentTimeout = setTimeout(async () => {
try {
await Message.syncWithMeili();
await Conversation.syncWithMeili();
} catch (err) {
logger.error('[indexSync] Trouble creating indices, try restarting the server.', err);
}
}, 750);
} else if (err.message.includes('Meilisearch not configured')) {
logger.info('[indexSync] Meilisearch not configured, search will be disabled.');
} else {
logger.error('[indexSync] error', err);
}
}
}
process.on('exit', () => {
logger.debug('[indexSync] Clearing sync timeouts before exiting...');
clearTimeout(currentTimeout);
});
module.exports = indexSync;
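A rough startup sketch (the import path is assumed from this file's location) showing how the two exports are typically combined:

const { connectDb, indexSync } = require('~/lib/db');

(async () => {
  await connectDb();
  // indexSync logs and degrades gracefully when SEARCH is disabled
  // or Meilisearch is unreachable.
  await indexSync();
})();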

View File

@@ -1,4 +1,7 @@
-const { Action } = require('~/db/models');
+const mongoose = require('mongoose');
+const { actionSchema } = require('@librechat/data-schemas');
+
+const Action = mongoose.model('action', actionSchema);
 /**
  * Update an action with new data without overwriting existing properties,

View File

@@ -1,7 +1,6 @@
 const mongoose = require('mongoose');
-const crypto = require('node:crypto');
-const { logger } = require('@librechat/data-schemas');
-const { SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
+const { agentSchema } = require('@librechat/data-schemas');
+const { SystemRoles, Tools } = require('librechat-data-provider');
 const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } =
   require('librechat-data-provider').Constants;
 const { CONFIG_STORE, STARTUP_CONFIG } = require('librechat-data-provider').CacheKeys;
@@ -11,10 +10,9 @@ const {
   removeAgentIdsFromProject,
   removeAgentFromAllProjects,
 } = require('./Project');
-const { getCachedTools } = require('~/server/services/Config');
 const getLogStores = require('~/cache/getLogStores');
-const { getActions } = require('./Action');
-const { Agent } = require('~/db/models');
+const Agent = mongoose.model('agent', agentSchema);
 /**
  * Create an agent with the provided data.
@@ -23,19 +21,7 @@ const { Agent } = require('~/db/models');
  * @throws {Error} If the agent creation fails.
  */
 const createAgent = async (agentData) => {
-  const { author, ...versionData } = agentData;
-  const timestamp = new Date();
-  const initialAgentData = {
-    ...agentData,
-    versions: [
-      {
-        ...versionData,
-        createdAt: timestamp,
-        updatedAt: timestamp,
-      },
-    ],
-  };
-  return (await Agent.create(initialAgentData)).toObject();
+  return (await Agent.create(agentData)).toObject();
 };
 /**
@@ -56,26 +42,18 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter)
  * @param {string} params.agent_id
  * @param {string} params.endpoint
  * @param {import('@librechat/agents').ClientOptions} [params.model_parameters]
- * @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
+ * @returns {Agent|null} The agent document as a plain object, or null if not found.
  */
-const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => {
+const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) => {
   const { model, ...model_parameters } = _m;
   /** @type {Record<string, FunctionTool>} */
-  const availableTools = await getCachedTools({ userId: req.user.id, includeGlobal: true });
-  /** @type {TEphemeralAgent | null} */
-  const ephemeralAgent = req.body.ephemeralAgent;
-  const mcpServers = new Set(ephemeralAgent?.mcp);
+  const availableTools = req.app.locals.availableTools;
+  const mcpServers = new Set(req.body.ephemeralAgent?.mcp);
   /** @type {string[]} */
   const tools = [];
-  if (ephemeralAgent?.execute_code === true) {
+  if (req.body.ephemeralAgent?.execute_code === true) {
     tools.push(Tools.execute_code);
   }
-  if (ephemeralAgent?.file_search === true) {
-    tools.push(Tools.file_search);
-  }
-  if (ephemeralAgent?.web_search === true) {
-    tools.push(Tools.web_search);
-  }
   if (mcpServers.size > 0) {
     for (const toolName of Object.keys(availableTools)) {
@@ -90,7 +68,7 @@ const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) =
   }
   const instructions = req.body.promptPrefix;
-  const result = {
+  return {
     id: agent_id,
     instructions,
     provider: endpoint,
@@ -98,11 +76,6 @@ const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) =
     model,
     tools,
   };
-  if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
-    result.artifacts = ephemeralAgent.artifacts;
-  }
-  return result;
 };
 /**
@@ -120,7 +93,7 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => {
     return null;
   }
   if (agent_id === EPHEMERAL_AGENT_ID) {
-    return await loadEphemeralAgent({ req, agent_id, endpoint, model_parameters });
+    return loadEphemeralAgent({ req, agent_id, endpoint, model_parameters });
   }
   const agent = await getAgent({
     id: agent_id,
@@ -130,8 +103,6 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => {
     return null;
   }
-  agent.version = agent.versions ? agent.versions.length : 0;
-
   if (agent.author.toString() === req.user.id) {
     return agent;
   }
@@ -156,205 +127,19 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => {
   }
 };
/**
* Check if a version already exists in the versions array, excluding timestamp and author fields
* @param {Object} updateData - The update data to compare
* @param {Object} currentData - The current agent data
* @param {Array} versions - The existing versions array
* @param {string} [actionsHash] - Hash of current action metadata
* @returns {Object|null} - The matching version if found, null otherwise
*/
const isDuplicateVersion = (updateData, currentData, versions, actionsHash = null) => {
if (!versions || versions.length === 0) {
return null;
}
const excludeFields = [
'_id',
'id',
'createdAt',
'updatedAt',
'author',
'updatedBy',
'created_at',
'updated_at',
'__v',
'versions',
'actionsHash', // Exclude actionsHash from direct comparison
];
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
if (Object.keys(directUpdates).length === 0 && !actionsHash) {
return null;
}
const wouldBeVersion = { ...currentData, ...directUpdates };
const lastVersion = versions[versions.length - 1];
if (actionsHash && lastVersion.actionsHash !== actionsHash) {
return null;
}
const allFields = new Set([...Object.keys(wouldBeVersion), ...Object.keys(lastVersion)]);
const importantFields = Array.from(allFields).filter((field) => !excludeFields.includes(field));
let isMatch = true;
for (const field of importantFields) {
if (!wouldBeVersion[field] && !lastVersion[field]) {
continue;
}
if (Array.isArray(wouldBeVersion[field]) && Array.isArray(lastVersion[field])) {
if (wouldBeVersion[field].length !== lastVersion[field].length) {
isMatch = false;
break;
}
// Special handling for projectIds (MongoDB ObjectIds)
if (field === 'projectIds') {
const wouldBeIds = wouldBeVersion[field].map((id) => id.toString()).sort();
const versionIds = lastVersion[field].map((id) => id.toString()).sort();
if (!wouldBeIds.every((id, i) => id === versionIds[i])) {
isMatch = false;
break;
}
}
// Handle arrays of objects like tool_kwargs
else if (typeof wouldBeVersion[field][0] === 'object' && wouldBeVersion[field][0] !== null) {
const sortedWouldBe = [...wouldBeVersion[field]].map((item) => JSON.stringify(item)).sort();
const sortedVersion = [...lastVersion[field]].map((item) => JSON.stringify(item)).sort();
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
isMatch = false;
break;
}
} else {
const sortedWouldBe = [...wouldBeVersion[field]].sort();
const sortedVersion = [...lastVersion[field]].sort();
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
isMatch = false;
break;
}
}
} else if (field === 'model_parameters') {
const wouldBeParams = wouldBeVersion[field] || {};
const lastVersionParams = lastVersion[field] || {};
if (JSON.stringify(wouldBeParams) !== JSON.stringify(lastVersionParams)) {
isMatch = false;
break;
}
} else if (wouldBeVersion[field] !== lastVersion[field]) {
isMatch = false;
break;
}
}
return isMatch ? lastVersion : null;
};
 /**
  * Update an agent with new data without overwriting existing
  * properties, or create a new agent if it doesn't exist.
- * When an agent is updated, a copy of the current state will be saved to the versions array.
  *
  * @param {Object} searchParameter - The search parameters to find the agent to update.
  * @param {string} searchParameter.id - The ID of the agent to update.
  * @param {string} [searchParameter.author] - The user ID of the agent's author.
  * @param {Object} updateData - An object containing the properties to update.
- * @param {Object} [options] - Optional configuration object.
- * @param {string} [options.updatingUserId] - The ID of the user performing the update (used for tracking non-author updates).
- * @param {boolean} [options.forceVersion] - Force creation of a new version even if no fields changed.
- * @param {boolean} [options.skipVersioning] - Skip version creation entirely (useful for isolated operations like sharing).
  * @returns {Promise<Agent>} The updated or newly created agent document as a plain object.
- * @throws {Error} If the update would create a duplicate version
  */
-const updateAgent = async (searchParameter, updateData, options = {}) => {
-  const { updatingUserId = null, forceVersion = false, skipVersioning = false } = options;
-  const mongoOptions = { new: true, upsert: false };
const currentAgent = await Agent.findOne(searchParameter);
if (currentAgent) {
const { __v, _id, id, versions, author, ...versionData } = currentAgent.toObject();
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
let actionsHash = null;
// Generate actions hash if agent has actions
if (currentAgent.actions && currentAgent.actions.length > 0) {
// Extract action IDs from the format "domain_action_id"
const actionIds = currentAgent.actions
.map((action) => {
const parts = action.split(actionDelimiter);
return parts[1]; // Get just the action ID part
})
.filter(Boolean);
if (actionIds.length > 0) {
try {
const actions = await getActions(
{
action_id: { $in: actionIds },
},
true,
); // Include sensitive data for hash
actionsHash = await generateActionMetadataHash(currentAgent.actions, actions);
} catch (error) {
logger.error('Error fetching actions for hash generation:', error);
}
}
}
const shouldCreateVersion =
!skipVersioning &&
(forceVersion || Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet);
if (shouldCreateVersion) {
const duplicateVersion = isDuplicateVersion(updateData, versionData, versions, actionsHash);
if (duplicateVersion && !forceVersion) {
const error = new Error(
'Duplicate version: This would create a version identical to an existing one',
);
error.statusCode = 409;
error.details = {
duplicateVersion,
versionIndex: versions.findIndex(
(v) => JSON.stringify(duplicateVersion) === JSON.stringify(v),
),
};
throw error;
}
}
const versionEntry = {
...versionData,
...directUpdates,
updatedAt: new Date(),
};
// Include actions hash in version if available
if (actionsHash) {
versionEntry.actionsHash = actionsHash;
}
// Always store updatedBy field to track who made the change
if (updatingUserId) {
versionEntry.updatedBy = new mongoose.Types.ObjectId(updatingUserId);
}
if (shouldCreateVersion) {
updateData.$push = {
...($push || {}),
versions: versionEntry,
};
}
}
-  return Agent.findOneAndUpdate(searchParameter, updateData, mongoOptions).lean();
-};
+const updateAgent = async (searchParameter, updateData) => {
+  const options = { new: true, upsert: false };
+  return Agent.findOneAndUpdate(searchParameter, updateData, options).lean();
+};
 /**
@@ -366,7 +151,7 @@ const updateAgent = async (searchParameter, updateData) => {
  * @param {string} params.file_id
  * @returns {Promise<Agent>} The updated agent.
  */
-const addAgentResourceFile = async ({ req, agent_id, tool_resource, file_id }) => {
+const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) => {
   const searchParameter = { id: agent_id };
   let agent = await getAgent(searchParameter);
   if (!agent) {
@@ -392,9 +177,7 @@ const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) =
     },
   };
-  const updatedAgent = await updateAgent(searchParameter, updateData, {
-    updatingUserId: req?.user?.id,
-  });
+  const updatedAgent = await updateAgent(searchParameter, updateData);
   if (updatedAgent) {
     return updatedAgent;
   } else {
@@ -486,6 +269,7 @@ const getListAgents = async (searchParameter) => {
     delete globalQuery.author;
     query = { $or: [globalQuery, query] };
   }
+
   const agents = (
     await Agent.find(query, {
       id: 1,
@@ -557,10 +341,7 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds
     delete updateQuery.author;
   }
-  const updatedAgent = await updateAgent(updateQuery, updateOps, {
-    updatingUserId: user.id,
-    skipVersioning: true,
-  });
+  const updatedAgent = await updateAgent(updateQuery, updateOps);
   if (updatedAgent) {
     return updatedAgent;
   }
@@ -577,107 +358,15 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds
   return await getAgent({ id: agentId });
 };
/**
* Reverts an agent to a specific version in its version history.
* @param {Object} searchParameter - The search parameters to find the agent to revert.
* @param {string} searchParameter.id - The ID of the agent to revert.
* @param {string} [searchParameter.author] - The user ID of the agent's author.
* @param {number} versionIndex - The index of the version to revert to in the versions array.
* @returns {Promise<MongoAgent>} The updated agent document after reverting.
* @throws {Error} If the agent is not found or the specified version does not exist.
*/
const revertAgentVersion = async (searchParameter, versionIndex) => {
const agent = await Agent.findOne(searchParameter);
if (!agent) {
throw new Error('Agent not found');
}
if (!agent.versions || !agent.versions[versionIndex]) {
throw new Error(`Version ${versionIndex} not found`);
}
const revertToVersion = agent.versions[versionIndex];
const updateData = {
...revertToVersion,
};
delete updateData._id;
delete updateData.id;
delete updateData.versions;
delete updateData.author;
delete updateData.updatedBy;
return Agent.findOneAndUpdate(searchParameter, updateData, { new: true }).lean();
};
/**
* Generates a hash of action metadata for version comparison
* @param {string[]} actionIds - Array of action IDs in format "domain_action_id"
* @param {Action[]} actions - Array of action documents
* @returns {Promise<string>} - SHA256 hash of the action metadata
*/
const generateActionMetadataHash = async (actionIds, actions) => {
if (!actionIds || actionIds.length === 0) {
return '';
}
// Create a map of action_id to metadata for quick lookup
const actionMap = new Map();
actions.forEach((action) => {
actionMap.set(action.action_id, action.metadata);
});
// Sort action IDs for consistent hashing
const sortedActionIds = [...actionIds].sort();
// Build a deterministic string representation of all action metadata
const metadataString = sortedActionIds
.map((actionFullId) => {
// Extract just the action_id part (after the delimiter)
const parts = actionFullId.split(actionDelimiter);
const actionId = parts[1];
const metadata = actionMap.get(actionId);
if (!metadata) {
return `${actionId}:null`;
}
// Sort metadata keys for deterministic output
const sortedKeys = Object.keys(metadata).sort();
const metadataStr = sortedKeys
.map((key) => `${key}:${JSON.stringify(metadata[key])}`)
.join(',');
return `${actionId}:{${metadataStr}}`;
})
.join(';');
// Use Web Crypto API to generate hash
const encoder = new TextEncoder();
const data = encoder.encode(metadataString);
const hashBuffer = await crypto.webcrypto.subtle.digest('SHA-256', data);
const hashArray = Array.from(new Uint8Array(hashBuffer));
const hashHex = hashArray.map((b) => b.toString(16).padStart(2, '0')).join('');
return hashHex;
};
/**
* Load a default agent based on the endpoint
* @param {string} endpoint
* @returns {Agent | null}
*/
 module.exports = {
+  Agent,
   getAgent,
   loadAgent,
   createAgent,
   updateAgent,
   deleteAgent,
   getListAgents,
-  revertAgentVersion,
   updateAgentProjects,
   addAgentResourceFile,
   removeAgentResourceFiles,
-  generateActionMetadataHash,
 };
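For reference, a sketch contrasting the two updateAgent call shapes in this diff (identifiers are illustrative):

async function renameAgent(agentId, user) {
  // Left side (removed): versioned update with tracking options.
  await updateAgent(
    { id: agentId },
    { name: 'Renamed' },
    { updatingUserId: user.id, skipVersioning: true },
  );
  // Right side (added): a plain findOneAndUpdate with no version bookkeeping.
  await updateAgent({ id: agentId }, { name: 'Renamed' });
}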

File diff suppressed because it is too large

View File

@@ -1,4 +1,7 @@
-const { Assistant } = require('~/db/models');
+const mongoose = require('mongoose');
+const { assistantSchema } = require('@librechat/data-schemas');
+
+const Assistant = mongoose.model('assistant', assistantSchema);
 /**
  * Update an assistant with new data without overwriting existing properties,

api/models/Balance.js Normal file
View File

@@ -0,0 +1,4 @@
const mongoose = require('mongoose');
const { balanceSchema } = require('@librechat/data-schemas');
module.exports = mongoose.model('Balance', balanceSchema);

View File

@@ -1,5 +1,8 @@
-const { logger } = require('@librechat/data-schemas');
-const { Banner } = require('~/db/models');
+const mongoose = require('mongoose');
+const logger = require('~/config/winston');
+const { bannerSchema } = require('@librechat/data-schemas');
+
+const Banner = mongoose.model('Banner', bannerSchema);
 /**
  * Retrieves the current active banner.
@@ -25,4 +28,4 @@ const getBanner = async (user) => {
   }
 };
-module.exports = { getBanner };
+module.exports = { Banner, getBanner };

api/models/Config.js Normal file
View File

@@ -0,0 +1,86 @@
const mongoose = require('mongoose');
const { logger } = require('~/config');
const major = [0, 0];
const minor = [0, 0];
const patch = [0, 5];
const configSchema = mongoose.Schema(
{
tag: {
type: String,
required: true,
validate: {
validator: function (tag) {
const [part1, part2, part3] = tag.replace('v', '').split('.').map(Number);
// Check if all parts are numbers
if (isNaN(part1) || isNaN(part2) || isNaN(part3)) {
return false;
}
// Check if all parts are within their respective ranges
if (part1 < major[0] || part1 > major[1]) {
return false;
}
if (part2 < minor[0] || part2 > minor[1]) {
return false;
}
if (part3 < patch[0] || part3 > patch[1]) {
return false;
}
return true;
},
message: 'Invalid tag value',
},
},
searchEnabled: {
type: Boolean,
default: false,
},
usersEnabled: {
type: Boolean,
default: false,
},
startupCounts: {
type: Number,
default: 0,
},
},
{ timestamps: true },
);
// Instance method
configSchema.methods.incrementCount = function () {
this.startupCounts += 1;
};
// Static methods
configSchema.statics.findByTag = async function (tag) {
return await this.findOne({ tag }).lean();
};
configSchema.statics.updateByTag = async function (tag, update) {
return await this.findOneAndUpdate({ tag }, update, { new: true });
};
const Config = mongoose.models.Config || mongoose.model('Config', configSchema);
module.exports = {
getConfigs: async (filter) => {
try {
return await Config.find(filter).lean();
} catch (error) {
logger.error('Error getting configs', error);
return { config: 'Error getting configs' };
}
},
deleteConfigs: async (filter) => {
try {
return await Config.deleteMany(filter);
} catch (error) {
logger.error('Error deleting configs', error);
return { config: 'Error deleting configs' };
}
},
};
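Worked through the ranges above (major [0, 0], minor [0, 0], patch [0, 5]), the validator accepts only v0.0.0 through v0.0.5; validation runs when a Config document is saved:

// 'v0.0.5' -> valid:   0 is within [0, 0], 0 is within [0, 0], 5 is within [0, 5]
// 'v0.1.0' -> invalid: minor 1 is outside [0, 0]
// 'v1.0.0' -> invalid: major 1 is outside [0, 0]
// 'va.b.c' -> invalid: the parts are not numbers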

View File

@@ -1,8 +1,6 @@
-const { logger } = require('@librechat/data-schemas');
-const { createTempChatExpirationDate } = require('@librechat/api');
-const getCustomConfig = require('~/server/services/Config/getCustomConfig');
+const Conversation = require('./schema/convoSchema');
 const { getMessages, deleteMessages } = require('./Message');
-const { Conversation } = require('~/db/models');
+const logger = require('~/config/winston');
 /**
  * Searches for a conversation by conversationId and returns a lean document with only conversationId and user.
@@ -77,6 +75,7 @@ const getConvoFiles = async (conversationId) => {
 };
 module.exports = {
+  Conversation,
   getConvoFiles,
   searchConversation,
   deleteNullOrEmptyConversations,
@@ -100,15 +99,10 @@ module.exports = {
       update.conversationId = newConversationId;
     }
-    if (req?.body?.isTemporary) {
-      try {
-        const customConfig = await getCustomConfig();
-        update.expiredAt = createTempChatExpirationDate(customConfig);
-      } catch (err) {
-        logger.error('Error creating temporary chat expiration date:', err);
-        logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
-        update.expiredAt = null;
-      }
+    if (req.body.isTemporary) {
+      const expiredAt = new Date();
+      expiredAt.setDate(expiredAt.getDate() + 30);
+      update.expiredAt = expiredAt;
     } else {
       update.expiredAt = null;
     }
@@ -161,6 +155,7 @@ module.exports = {
     { cursor, limit = 25, isArchived = false, tags, search, order = 'desc' } = {},
   ) => {
     const filters = [{ user }];
+
     if (isArchived) {
       filters.push({ isArchived: true });
     } else {
@@ -293,6 +288,7 @@ module.exports = {
   deleteConvos: async (user, filter) => {
     try {
       const userFilter = { ...filter, user };
+
       const conversations = await Conversation.find(userFilter).select('conversationId');
       const conversationIds = conversations.map((c) => c.conversationId);

View File

@@ -1,5 +1,10 @@
-const { logger } = require('@librechat/data-schemas');
-const { ConversationTag, Conversation } = require('~/db/models');
+const mongoose = require('mongoose');
+const Conversation = require('./schema/convoSchema');
+const logger = require('~/config/winston');
+const { conversationTagSchema } = require('@librechat/data-schemas');
+
+const ConversationTag = mongoose.model('ConversationTag', conversationTagSchema);
 
 /**
  * Retrieves all conversation tags for a user.
@@ -135,13 +140,13 @@ const adjustPositions = async (user, oldPosition, newPosition) => {
   const position =
     oldPosition < newPosition
       ? {
          $gt: Math.min(oldPosition, newPosition),
          $lte: Math.max(oldPosition, newPosition),
        }
       : {
          $gte: Math.min(oldPosition, newPosition),
          $lt: Math.max(oldPosition, newPosition),
        };
 
   await ConversationTag.updateMany(
     {
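To make the boundary logic above concrete: moving a tag from position 2 to 5 shifts the tags currently at 3-5 ($gt: 2, $lte: 5), while moving from 5 to 2 shifts those at 2-4 ($gte: 2, $lt: 5). A sketch of the whole operation this hunk belongs to; the `$inc` direction is an assumption, since only the range construction appears in the diff:

async function adjustPositionsSketch(user, oldPosition, newPosition) {
  const movingDown = oldPosition < newPosition;
  const position = movingDown
    ? { $gt: oldPosition, $lte: newPosition } // e.g. 2 -> 5: shift tags at 3..5
    : { $gte: newPosition, $lt: oldPosition }; // e.g. 5 -> 2: shift tags at 2..4
  // The vacated slot is filled by shifting the affected range one step toward it.
  await ConversationTag.updateMany({ user, position }, { $inc: { position: movingDown ? -1 : 1 } });
}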

View File

@@ -1,8 +1,9 @@
-const { logger } = require('@librechat/data-schemas');
-const { EToolResources, FileContext, Constants } = require('librechat-data-provider');
-const { getProjectByName } = require('./Project');
-const { getAgent } = require('./Agent');
-const { File } = require('~/db/models');
+const mongoose = require('mongoose');
+const { EToolResources } = require('librechat-data-provider');
+const { fileSchema } = require('@librechat/data-schemas');
+const { logger } = require('~/config');
+
+const File = mongoose.model('File', fileSchema);
 
 /**
  * Finds a file by its file_id with additional query options.
@@ -14,124 +15,17 @@ const findFileById = async (file_id, options = {}) => {
   return await File.findOne({ file_id, ...options }).lean();
 };
 
-/**
- * Checks if a user has access to multiple files through a shared agent (batch operation)
- * @param {string} userId - The user ID to check access for
- * @param {string[]} fileIds - Array of file IDs to check
- * @param {string} agentId - The agent ID that might grant access
- * @returns {Promise<Map<string, boolean>>} Map of fileId to access status
- */
-const hasAccessToFilesViaAgent = async (userId, fileIds, agentId, checkCollaborative = true) => {
-  const accessMap = new Map();
-  // Initialize all files as no access
-  fileIds.forEach((fileId) => accessMap.set(fileId, false));
-
-  try {
-    const agent = await getAgent({ id: agentId });
-    if (!agent) {
-      return accessMap;
-    }
-
-    // Check if user is the author - if so, grant access to all files
-    if (agent.author.toString() === userId) {
-      fileIds.forEach((fileId) => accessMap.set(fileId, true));
-      return accessMap;
-    }
-
-    // Check if agent is shared with the user via projects
-    if (!agent.projectIds || agent.projectIds.length === 0) {
-      return accessMap;
-    }
-
-    // Check if agent is in global project
-    const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
-    if (
-      !globalProject ||
-      !agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
-    ) {
-      return accessMap;
-    }
-
-    // Agent is globally shared - check if it's collaborative
-    if (checkCollaborative && !agent.isCollaborative) {
-      return accessMap;
-    }
-
-    // Check which files are actually attached
-    const attachedFileIds = new Set();
-    if (agent.tool_resources) {
-      for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
-        if (resource?.file_ids && Array.isArray(resource.file_ids)) {
-          resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
-        }
-      }
-    }
-
-    // Grant access only to files that are attached to this agent
-    fileIds.forEach((fileId) => {
-      if (attachedFileIds.has(fileId)) {
-        accessMap.set(fileId, true);
-      }
-    });
-
-    return accessMap;
-  } catch (error) {
-    logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
-    return accessMap;
-  }
-};
-
 /**
  * Retrieves files matching a given filter, sorted by the most recently updated.
  * @param {Object} filter - The filter criteria to apply.
  * @param {Object} [_sortOptions] - Optional sort parameters.
  * @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
  *   Default excludes the 'text' field.
- * @param {Object} [options] - Additional options
- * @param {string} [options.userId] - User ID for access control
- * @param {string} [options.agentId] - Agent ID that might grant access to files
  * @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
  */
-const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => {
+const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
   const sortOptions = { updatedAt: -1, ..._sortOptions };
-  const files = await File.find(filter).select(selectFields).sort(sortOptions).lean();
-
-  // If userId and agentId are provided, filter files based on access
-  if (options.userId && options.agentId) {
-    // Collect file IDs that need access check
-    const filesToCheck = [];
-    const ownedFiles = [];
-
-    for (const file of files) {
-      if (file.user && file.user.toString() === options.userId) {
-        ownedFiles.push(file);
-      } else {
-        filesToCheck.push(file);
-      }
-    }
-
-    if (filesToCheck.length === 0) {
-      return ownedFiles;
-    }
-
-    // Batch check access for all non-owned files
-    const fileIds = filesToCheck.map((f) => f.file_id);
-    const accessMap = await hasAccessToFilesViaAgent(
-      options.userId,
-      fileIds,
-      options.agentId,
-      false,
-    );
-
-    // Filter files based on access
-    const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
-    return [...ownedFiles, ...accessibleFiles];
-  }
-
-  return files;
+  return await File.find(filter).select(selectFields).sort(sortOptions).lean();
 };
 
 /**
@@ -141,19 +35,19 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, option
  * @returns {Promise<Array<MongoFile>>} Files that match the criteria
  */
 const getToolFilesByIds = async (fileIds, toolResourceSet) => {
-  if (!fileIds || !fileIds.length || !toolResourceSet?.size) {
+  if (!fileIds || !fileIds.length) {
     return [];
   }
 
   try {
     const filter = {
       file_id: { $in: fileIds },
-      $or: [],
     };
 
-    if (toolResourceSet.has(EToolResources.ocr)) {
-      filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
+    if (toolResourceSet.size) {
+      filter.$or = [];
     }
+
     if (toolResourceSet.has(EToolResources.file_search)) {
       filter.$or.push({ embedded: true });
     }
@@ -275,6 +169,7 @@ async function batchUpdateFiles(updates) {
 }
 
 module.exports = {
+  File,
   findFileById,
   getFiles,
   getToolFilesByIds,
@@ -285,5 +180,4 @@ module.exports = {
   deleteFiles,
   deleteFileByFilter,
   batchUpdateFiles,
-  hasAccessToFilesViaAgent,
 };
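For context, a sketch of how the removed access-control path was invoked (call shapes taken from the deleted code and the spec below; the surrounding route wiring is assumed):

// Old signature: getFiles(filter, sort, select, { userId, agentId })
const visible = await getFiles(
  { file_id: { $in: requestedFileIds } }, // hypothetical ID list
  null,
  { text: 0 },
  { userId: req.user.id, agentId }, // triggered the batch hasAccessToFilesViaAgent check
);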

View File

@@ -1,264 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { fileSchema } = require('@librechat/data-schemas');
const { agentSchema } = require('@librechat/data-schemas');
const { projectSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { getFiles, createFile } = require('./File');
const { getProjectByName } = require('./Project');
const { createAgent } = require('./Agent');
let File;
let Agent;
let Project;
describe('File Access Control', () => {
let mongoServer;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
File = mongoose.models.File || mongoose.model('File', fileSchema);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await File.deleteMany({});
await Agent.deleteMany({});
await Project.deleteMany({});
});
describe('hasAccessToFilesViaAgent', () => {
it('should efficiently check access for multiple files at once', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
// Create files
for (const fileId of fileIds) {
await createFile({
user: authorId,
file_id: fileId,
filename: `file-${fileId}.txt`,
filepath: `/uploads/${fileId}`,
});
}
// Create agent with only first two files attached
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileIds[0], fileIds[1]],
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for all files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have access only to the first two files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(false);
expect(accessMap.get(fileIds[3])).toBe(false);
});
it('should grant access to all files when user is the agent author', async () => {
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
// Create agent
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: [fileIds[0]], // Only one file attached
},
},
});
// Check access as the author
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);
// Author should have access to all files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(true);
});
it('should handle non-existent agent gracefully', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const fileIds = [uuidv4(), uuidv4()];
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');
// Should have no access to any files
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should deny access when agent is not collaborative', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4()];
// Create agent with files but isCollaborative: false
await createAgent({
id: agentId,
name: 'Non-Collaborative Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: false,
tool_resources: {
file_search: {
file_ids: fileIds,
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have no access to any files when isCollaborative is false
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
});
describe('getFiles with agent access control', () => {
test('should return files owned by user and files accessible through agent', async () => {
const authorId = new mongoose.Types.ObjectId();
const userId = new mongoose.Types.ObjectId();
const agentId = `agent_${uuidv4()}`;
const ownedFileId = `file_${uuidv4()}`;
const sharedFileId = `file_${uuidv4()}`;
const inaccessibleFileId = `file_${uuidv4()}`;
// Create/get global project using getProjectByName which will upsert
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);
// Create agent with shared file
await createAgent({
id: agentId,
name: 'Shared Agent',
provider: 'test',
model: 'test-model',
author: authorId,
projectIds: [globalProject._id],
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [sharedFileId],
},
},
});
// Create files
await createFile({
file_id: ownedFileId,
user: userId,
filename: 'owned.txt',
filepath: '/uploads/owned.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: sharedFileId,
user: authorId,
filename: 'shared.txt',
filepath: '/uploads/shared.txt',
type: 'text/plain',
bytes: 200,
embedded: true,
});
await createFile({
file_id: inaccessibleFileId,
user: authorId,
filename: 'inaccessible.txt',
filepath: '/uploads/inaccessible.txt',
type: 'text/plain',
bytes: 300,
});
// Get files with access control
const files = await getFiles(
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
null,
{ text: 0 },
{ userId: userId.toString(), agentId },
);
expect(files).toHaveLength(2);
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
});
test('should return all files when no userId/agentId provided', async () => {
const userId = new mongoose.Types.ObjectId();
const fileId1 = `file_${uuidv4()}`;
const fileId2 = `file_${uuidv4()}`;
await createFile({
file_id: fileId1,
user: userId,
filename: 'file1.txt',
filepath: '/uploads/file1.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: fileId2,
user: new mongoose.Types.ObjectId(),
filename: 'file2.txt',
filepath: '/uploads/file2.txt',
type: 'text/plain',
bytes: 200,
});
const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
expect(files).toHaveLength(2);
});
});
});

api/models/Key.js (new file, +4 lines)
View File

@@ -0,0 +1,4 @@
const mongoose = require('mongoose');
const { keySchema } = require('@librechat/data-schemas');
module.exports = mongoose.model('Key', keySchema);

View File

@@ -1,8 +1,6 @@
 const { z } = require('zod');
-const { logger } = require('@librechat/data-schemas');
-const { createTempChatExpirationDate } = require('@librechat/api');
-const getCustomConfig = require('~/server/services/Config/getCustomConfig');
-const { Message } = require('~/db/models');
+const Message = require('./schema/messageSchema');
+const { logger } = require('~/config');
 
 const idSchema = z.string().uuid();
@@ -56,14 +54,9 @@ async function saveMessage(req, params, metadata) {
   };
 
   if (req?.body?.isTemporary) {
-    try {
-      const customConfig = await getCustomConfig();
-      update.expiredAt = createTempChatExpirationDate(customConfig);
-    } catch (err) {
-      logger.error('Error creating temporary chat expiration date:', err);
-      logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
-      update.expiredAt = null;
-    }
+    const expiredAt = new Date();
+    expiredAt.setDate(expiredAt.getDate() + 30);
+    update.expiredAt = expiredAt;
   } else {
     update.expiredAt = null;
   }
@@ -75,6 +68,7 @@ async function saveMessage(req, params, metadata) {
     logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
     update.tokenCount = 0;
   }
+
   const message = await Message.findOneAndUpdate(
     { messageId: params.messageId, user: req.user.id },
     update,
@@ -146,6 +140,7 @@ async function bulkSaveMessages(messages, overrideTimestamp = false) {
         upsert: true,
       },
     }));
+
     const result = await Message.bulkWrite(bulkOps);
     return result;
   } catch (err) {
@@ -260,7 +255,6 @@ async function updateMessage(req, message, metadata) {
       text: updatedMessage.text,
       isCreatedByUser: updatedMessage.isCreatedByUser,
       tokenCount: updatedMessage.tokenCount,
-      feedback: updatedMessage.feedback,
     };
   } catch (err) {
     logger.error('Error updating message:', err);
@@ -361,6 +355,7 @@ async function deleteMessages(filter) {
 }
 
 module.exports = {
+  Message,
   saveMessage,
   bulkSaveMessages,
   recordMessage,
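The left side of this diff replaced the fixed 30-day window with a config-driven helper. A hypothetical shape for that helper, assuming a retention setting measured in hours (the field name and default are assumptions, not confirmed by this diff):

function createTempChatExpirationDateSketch(customConfig) {
  const retentionHours = customConfig?.interface?.temporaryChatRetention ?? 24 * 30; // assumed key and default
  return new Date(Date.now() + retentionHours * 60 * 60 * 1000);
}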

View File

@@ -1,7 +1,32 @@
 const mongoose = require('mongoose');
-const { MongoMemoryServer } = require('mongodb-memory-server');
 const { v4: uuidv4 } = require('uuid');
-const { messageSchema } = require('@librechat/data-schemas');
+
+jest.mock('mongoose');
+
+const mockFindQuery = {
+  select: jest.fn().mockReturnThis(),
+  sort: jest.fn().mockReturnThis(),
+  lean: jest.fn().mockReturnThis(),
+  deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }),
+};
+
+const mockSchema = {
+  findOneAndUpdate: jest.fn(),
+  updateOne: jest.fn(),
+  findOne: jest.fn(() => ({
+    lean: jest.fn(),
+  })),
+  find: jest.fn(() => mockFindQuery),
+  deleteMany: jest.fn(),
+};
+
+mongoose.model.mockReturnValue(mockSchema);
+jest.mock('~/models/schema/messageSchema', () => mockSchema);
+
+jest.mock('~/config/winston', () => ({
+  error: jest.fn(),
+}));
 
 const {
   saveMessage,
@@ -10,102 +35,77 @@ const {
   deleteMessages,
   updateMessageText,
   deleteMessagesSince,
-} = require('./Message');
-
-/**
- * @type {import('mongoose').Model<import('@librechat/data-schemas').IMessage>}
- */
-let Message;
+} = require('~/models/Message');
 
 describe('Message Operations', () => {
-  let mongoServer;
   let mockReq;
-  let mockMessageData;
+  let mockMessage;
 
-  beforeAll(async () => {
-    mongoServer = await MongoMemoryServer.create();
-    const mongoUri = mongoServer.getUri();
-    Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
-    await mongoose.connect(mongoUri);
-  });
-
-  afterAll(async () => {
-    await mongoose.disconnect();
-    await mongoServer.stop();
-  });
-
-  beforeEach(async () => {
-    // Clear database
-    await Message.deleteMany({});
+  beforeEach(() => {
+    jest.clearAllMocks();
 
     mockReq = {
      user: { id: 'user123' },
    };
 
-    mockMessageData = {
+    mockMessage = {
      messageId: 'msg123',
      conversationId: uuidv4(),
      text: 'Hello, world!',
      user: 'user123',
    };
+
+    mockSchema.findOneAndUpdate.mockResolvedValue({
+      toObject: () => mockMessage,
+    });
   });
 
   describe('saveMessage', () => {
     it('should save a message for an authenticated user', async () => {
-      const result = await saveMessage(mockReq, mockMessageData);
-
-      expect(result.messageId).toBe('msg123');
-      expect(result.user).toBe('user123');
-      expect(result.text).toBe('Hello, world!');
-
-      // Verify the message was actually saved to the database
-      const savedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
-      expect(savedMessage).toBeTruthy();
-      expect(savedMessage.text).toBe('Hello, world!');
+      const result = await saveMessage(mockReq, mockMessage);
+      expect(result).toEqual(mockMessage);
+      expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
+        { messageId: 'msg123', user: 'user123' },
+        expect.objectContaining({ user: 'user123' }),
+        expect.any(Object),
+      );
     });
 
     it('should throw an error for unauthenticated user', async () => {
       mockReq.user = null;
-      await expect(saveMessage(mockReq, mockMessageData)).rejects.toThrow('User not authenticated');
+      await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated');
     });
 
-    it('should handle invalid conversation ID gracefully', async () => {
-      mockMessageData.conversationId = 'invalid-id';
-      const result = await saveMessage(mockReq, mockMessageData);
-      expect(result).toBeUndefined();
+    it('should throw an error for invalid conversation ID', async () => {
+      mockMessage.conversationId = 'invalid-id';
+      await expect(saveMessage(mockReq, mockMessage)).resolves.toBeUndefined();
    });
   });
 
   describe('updateMessageText', () => {
     it('should update message text for the authenticated user', async () => {
-      // First save a message
-      await saveMessage(mockReq, mockMessageData);
-
-      // Then update it
       await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' });
-
-      // Verify the update
-      const updatedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
-      expect(updatedMessage.text).toBe('Updated text');
+      expect(mockSchema.updateOne).toHaveBeenCalledWith(
+        { messageId: 'msg123', user: 'user123' },
+        { text: 'Updated text' },
+      );
     });
   });
 
   describe('updateMessage', () => {
     it('should update a message for the authenticated user', async () => {
-      // First save a message
-      await saveMessage(mockReq, mockMessageData);
-
+      mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage);
       const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' });
-
-      expect(result.messageId).toBe('msg123');
-      expect(result.text).toBe('Updated text');
-
-      // Verify in database
-      const updatedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
-      expect(updatedMessage.text).toBe('Updated text');
+      expect(result).toEqual(
+        expect.objectContaining({
+          messageId: 'msg123',
+          text: 'Hello, world!',
+        }),
+      );
     });
 
     it('should throw an error if message is not found', async () => {
+      mockSchema.findOneAndUpdate.mockResolvedValue(null);
       await expect(
         updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }),
       ).rejects.toThrow('Message not found or user not authorized.');
@@ -114,45 +114,19 @@ describe('Message Operations', () => {
   describe('deleteMessagesSince', () => {
     it('should delete messages only for the authenticated user', async () => {
-      const conversationId = uuidv4();
-
-      // Create multiple messages in the same conversation
-      const message1 = await saveMessage(mockReq, {
-        messageId: 'msg1',
-        conversationId,
-        text: 'First message',
-        user: 'user123',
-      });
-
-      const message2 = await saveMessage(mockReq, {
-        messageId: 'msg2',
-        conversationId,
-        text: 'Second message',
-        user: 'user123',
-      });
-
-      const message3 = await saveMessage(mockReq, {
-        messageId: 'msg3',
-        conversationId,
-        text: 'Third message',
-        user: 'user123',
-      });
-
-      // Delete messages since message2 (this should only delete messages created AFTER msg2)
-      await deleteMessagesSince(mockReq, {
-        messageId: 'msg2',
-        conversationId,
-      });
-
-      // Verify msg1 and msg2 remain, msg3 is deleted
-      const remainingMessages = await Message.find({ conversationId, user: 'user123' });
-      expect(remainingMessages).toHaveLength(2);
-      expect(remainingMessages.map((m) => m.messageId)).toContain('msg1');
-      expect(remainingMessages.map((m) => m.messageId)).toContain('msg2');
-      expect(remainingMessages.map((m) => m.messageId)).not.toContain('msg3');
+      mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() });
+      mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 });
+      const result = await deleteMessagesSince(mockReq, {
+        messageId: 'msg123',
+        conversationId: 'convo123',
+      });
+      expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' });
+      expect(mockSchema.find).not.toHaveBeenCalled();
+      expect(result).toBeUndefined();
     });
 
     it('should return undefined if no message is found', async () => {
+      mockSchema.findOne().lean.mockResolvedValueOnce(null);
       const result = await deleteMessagesSince(mockReq, {
         messageId: 'nonexistent',
         conversationId: 'convo123',
@@ -163,71 +137,29 @@ describe('Message Operations', () => {
   describe('getMessages', () => {
     it('should retrieve messages with the correct filter', async () => {
-      const conversationId = uuidv4();
-
-      // Save some messages
-      await saveMessage(mockReq, {
-        messageId: 'msg1',
-        conversationId,
-        text: 'First message',
-        user: 'user123',
-      });
-
-      await saveMessage(mockReq, {
-        messageId: 'msg2',
-        conversationId,
-        text: 'Second message',
-        user: 'user123',
-      });
-
-      const messages = await getMessages({ conversationId });
-      expect(messages).toHaveLength(2);
-      expect(messages[0].text).toBe('First message');
-      expect(messages[1].text).toBe('Second message');
+      const filter = { conversationId: 'convo123' };
+      await getMessages(filter);
+      expect(mockSchema.find).toHaveBeenCalledWith(filter);
+      expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 });
+      expect(mockFindQuery.lean).toHaveBeenCalled();
     });
   });
 
   describe('deleteMessages', () => {
     it('should delete messages with the correct filter', async () => {
-      // Save some messages for different users
-      await saveMessage(mockReq, mockMessageData);
-      await saveMessage(
-        { user: { id: 'user456' } },
-        {
-          messageId: 'msg456',
-          conversationId: uuidv4(),
-          text: 'Other user message',
-          user: 'user456',
-        },
-      );
-
       await deleteMessages({ user: 'user123' });
-
-      // Verify only user123's messages were deleted
-      const user123Messages = await Message.find({ user: 'user123' });
-      const user456Messages = await Message.find({ user: 'user456' });
-      expect(user123Messages).toHaveLength(0);
-      expect(user456Messages).toHaveLength(1);
+      expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' });
     });
   });
 
   describe('Conversation Hijacking Prevention', () => {
-    it("should not allow editing a message in another user's conversation", async () => {
+    it('should not allow editing a message in another user\'s conversation', async () => {
       const attackerReq = { user: { id: 'attacker123' } };
-      const victimConversationId = uuidv4();
+      const victimConversationId = 'victim-convo-123';
       const victimMessageId = 'victim-msg-123';
 
-      // First, save a message as the victim (but we'll try to edit as attacker)
-      const victimReq = { user: { id: 'victim123' } };
-      await saveMessage(victimReq, {
-        messageId: victimMessageId,
-        conversationId: victimConversationId,
-        text: 'Victim message',
-        user: 'victim123',
-      });
-
-      // Attacker tries to edit the victim's message
+      mockSchema.findOneAndUpdate.mockResolvedValue(null);
       await expect(
         updateMessage(attackerReq, {
           messageId: victimMessageId,
@@ -236,82 +168,71 @@ describe('Message Operations', () => {
         }),
       ).rejects.toThrow('Message not found or user not authorized.');
 
-      // Verify the original message is unchanged
-      const originalMessage = await Message.findOne({
-        messageId: victimMessageId,
-        user: 'victim123',
-      });
-      expect(originalMessage.text).toBe('Victim message');
+      expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
+        { messageId: victimMessageId, user: 'attacker123' },
+        expect.anything(),
+        expect.anything(),
+      );
     });
 
-    it("should not allow deleting messages from another user's conversation", async () => {
+    it('should not allow deleting messages from another user\'s conversation', async () => {
       const attackerReq = { user: { id: 'attacker123' } };
-      const victimConversationId = uuidv4();
+      const victimConversationId = 'victim-convo-123';
       const victimMessageId = 'victim-msg-123';
 
-      // Save a message as the victim
-      const victimReq = { user: { id: 'victim123' } };
-      await saveMessage(victimReq, {
-        messageId: victimMessageId,
-        conversationId: victimConversationId,
-        text: 'Victim message',
-        user: 'victim123',
-      });
+      mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user
 
-      // Attacker tries to delete from victim's conversation
       const result = await deleteMessagesSince(attackerReq, {
         messageId: victimMessageId,
         conversationId: victimConversationId,
       });
 
       expect(result).toBeUndefined();
-
-      // Verify the victim's message still exists
-      const victimMessage = await Message.findOne({
-        messageId: victimMessageId,
-        user: 'victim123',
-      });
-      expect(victimMessage).toBeTruthy();
-      expect(victimMessage.text).toBe('Victim message');
-    });
-
-    it("should not allow inserting a new message into another user's conversation", async () => {
-      const attackerReq = { user: { id: 'attacker123' } };
-      const victimConversationId = uuidv4();
-
-      // Attacker tries to save a message - this should succeed but with attacker's user ID
-      const result = await saveMessage(attackerReq, {
-        conversationId: victimConversationId,
-        text: 'Inserted malicious message',
-        messageId: 'new-msg-123',
-        user: 'attacker123',
-      });
-
-      expect(result).toBeTruthy();
-      expect(result.user).toBe('attacker123');
-
-      // Verify the message was saved with the attacker's user ID, not as an anonymous message
-      const savedMessage = await Message.findOne({ messageId: 'new-msg-123' });
-      expect(savedMessage.user).toBe('attacker123');
-      expect(savedMessage.conversationId).toBe(victimConversationId);
+      expect(mockSchema.findOne).toHaveBeenCalledWith({
+        messageId: victimMessageId,
+        user: 'attacker123',
+      });
+    });
+
+    it('should not allow inserting a new message into another user\'s conversation', async () => {
+      const attackerReq = { user: { id: 'attacker123' } };
+      const victimConversationId = uuidv4(); // Use a valid UUID
+
+      await expect(
+        saveMessage(attackerReq, {
+          conversationId: victimConversationId,
+          text: 'Inserted malicious message',
+          messageId: 'new-msg-123',
+        }),
+      ).resolves.not.toThrow(); // It should not throw an error
+
+      // Check that the message was saved with the attacker's user ID
+      expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
+        { messageId: 'new-msg-123', user: 'attacker123' },
+        expect.objectContaining({
+          user: 'attacker123',
+          conversationId: victimConversationId,
+        }),
+        expect.anything(),
+      );
     });
 
     it('should allow retrieving messages from any conversation', async () => {
-      const victimConversationId = uuidv4();
-
-      // Save a message in the victim's conversation
-      const victimReq = { user: { id: 'victim123' } };
-      await saveMessage(victimReq, {
-        messageId: 'victim-msg',
-        conversationId: victimConversationId,
-        text: 'Victim message',
-        user: 'victim123',
-      });
-
-      // Anyone should be able to retrieve messages by conversation ID
-      const messages = await getMessages({ conversationId: victimConversationId });
-      expect(messages).toHaveLength(1);
-      expect(messages[0].text).toBe('Victim message');
+      const victimConversationId = 'victim-convo-123';
+
+      await getMessages({ conversationId: victimConversationId });
+      expect(mockSchema.find).toHaveBeenCalledWith({
+        conversationId: victimConversationId,
+      });
+
+      mockSchema.find.mockReturnValueOnce({
+        select: jest.fn().mockReturnThis(),
+        sort: jest.fn().mockReturnThis(),
+        lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]),
+      });
+      const result = await getMessages({ conversationId: victimConversationId });
+      expect(result).toEqual([{ text: 'Test message' }]);
     });
   });
 });

View File

@@ -1,5 +1,5 @@
-const { logger } = require('@librechat/data-schemas');
-const { Preset } = require('~/db/models');
+const Preset = require('./schema/presetSchema');
+const { logger } = require('~/config');
 
 const getPreset = async (user, presetId) => {
   try {
@@ -11,6 +11,7 @@ const getPreset = async (user, presetId) => {
 };
 
 module.exports = {
+  Preset,
   getPreset,
   getPresets: async (user, filter) => {
     try {

View File

@@ -1,5 +1,8 @@
+const { model } = require('mongoose');
 const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
-const { Project } = require('~/db/models');
+const { projectSchema } = require('@librechat/data-schemas');
+
+const Project = model('Project', projectSchema);
 
 /**
  * Retrieve a project by ID and convert the found project document to a plain object.

View File

@@ -1,5 +1,5 @@
+const mongoose = require('mongoose');
 const { ObjectId } = require('mongodb');
-const { logger } = require('@librechat/data-schemas');
 const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider');
 const {
   getProjectByName,
@@ -7,8 +7,12 @@ const {
   removeGroupIdsFromProject,
   removeGroupFromAllProjects,
 } = require('./Project');
-const { PromptGroup, Prompt } = require('~/db/models');
+const { promptGroupSchema, promptSchema } = require('@librechat/data-schemas');
 const { escapeRegExp } = require('~/server/utils');
+const { logger } = require('~/config');
+
+const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema);
+const Prompt = mongoose.model('Prompt', promptSchema);
 
 /**
  * Create a pipeline for the aggregation to get prompt groups

View File

@@ -1,3 +1,4 @@
+const mongoose = require('mongoose');
 const {
   CacheKeys,
   SystemRoles,
@@ -6,9 +7,11 @@ const {
   permissionsSchema,
   removeNullishValues,
 } = require('librechat-data-provider');
-const { logger } = require('@librechat/data-schemas');
 const getLogStores = require('~/cache/getLogStores');
-const { Role } = require('~/db/models');
+const { roleSchema } = require('@librechat/data-schemas');
+const { logger } = require('~/config');
+
+const Role = mongoose.model('Role', roleSchema);
 
 /**
  * Retrieve a role by name and convert the found role document to a plain object.
@@ -170,6 +173,35 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
   }
 }
 
+/**
+ * Initialize default roles in the system.
+ * Creates the default roles (ADMIN, USER) if they don't exist in the database.
+ * Updates existing roles with new permission types if they're missing.
+ *
+ * @returns {Promise<void>}
+ */
+const initializeRoles = async function () {
+  for (const roleName of [SystemRoles.ADMIN, SystemRoles.USER]) {
+    let role = await Role.findOne({ name: roleName });
+    const defaultPerms = roleDefaults[roleName].permissions;
+
+    if (!role) {
+      // Create new role if it doesn't exist.
+      role = new Role(roleDefaults[roleName]);
+    } else {
+      // Ensure role.permissions is defined.
+      role.permissions = role.permissions || {};
+      // For each permission type in defaults, add it if missing.
+      for (const permType of Object.keys(defaultPerms)) {
+        if (role.permissions[permType] == null) {
+          role.permissions[permType] = defaultPerms[permType];
+        }
+      }
+    }
+    await role.save();
+  }
+};
+
 /**
  * Migrates roles from old schema to new schema structure.
  * This can be called directly to fix existing roles.
@@ -250,8 +282,10 @@ const migrateRoleSchema = async function (roleName) {
 };
 
 module.exports = {
+  Role,
   getRoleByName,
+  initializeRoles,
   updateRoleByName,
-  migrateRoleSchema,
   updateAccessPermissions,
+  migrateRoleSchema,
 };
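`initializeRoles` is written to be idempotent: it creates ADMIN/USER if absent and backfills missing permission types without overwriting existing values, so it can safely run on every boot. A sketch of the assumed startup wiring:

await mongoose.connect(process.env.MONGO_URI); // connection step assumed
await initializeRoles(); // safe to call repeatedly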

View File

@@ -6,10 +6,8 @@ const {
   roleDefaults,
   PermissionTypes,
 } = require('librechat-data-provider');
-const { getRoleByName, updateAccessPermissions } = require('~/models/Role');
+const { Role, getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role');
 const getLogStores = require('~/cache/getLogStores');
-const { initializeRoles } = require('~/models');
-const { Role } = require('~/db/models');
 
 // Mock the cache
 jest.mock('~/cache/getLogStores', () =>

api/models/Session.js (new file, +275 lines)
View File

@@ -0,0 +1,275 @@
const mongoose = require('mongoose');
const signPayload = require('~/server/services/signPayload');
const { hashToken } = require('~/server/utils/crypto');
const { sessionSchema } = require('@librechat/data-schemas');
const { logger } = require('~/config');
const Session = mongoose.model('Session', sessionSchema);
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default
/**
* Error class for Session-related errors
*/
class SessionError extends Error {
constructor(message, code = 'SESSION_ERROR') {
super(message);
this.name = 'SessionError';
this.code = code;
}
}
/**
* Creates a new session for a user
* @param {string} userId - The ID of the user
* @param {Object} options - Additional options for session creation
* @param {Date} options.expiration - Custom expiration date
* @returns {Promise<{session: Session, refreshToken: string}>}
* @throws {SessionError}
*/
const createSession = async (userId, options = {}) => {
if (!userId) {
throw new SessionError('User ID is required', 'INVALID_USER_ID');
}
try {
const session = new Session({
user: userId,
expiration: options.expiration || new Date(Date.now() + expires),
});
const refreshToken = await generateRefreshToken(session);
return { session, refreshToken };
} catch (error) {
logger.error('[createSession] Error creating session:', error);
throw new SessionError('Failed to create session', 'CREATE_SESSION_FAILED');
}
};
/**
* Finds a session by various parameters
* @param {Object} params - Search parameters
* @param {string} [params.refreshToken] - The refresh token to search by
* @param {string} [params.userId] - The user ID to search by
* @param {string} [params.sessionId] - The session ID to search by
* @param {Object} [options] - Additional options
* @param {boolean} [options.lean=true] - Whether to return plain objects instead of documents
* @returns {Promise<Session|null>}
* @throws {SessionError}
*/
const findSession = async (params, options = { lean: true }) => {
try {
const query = {};
if (!params.refreshToken && !params.userId && !params.sessionId) {
throw new SessionError('At least one search parameter is required', 'INVALID_SEARCH_PARAMS');
}
if (params.refreshToken) {
const tokenHash = await hashToken(params.refreshToken);
query.refreshTokenHash = tokenHash;
}
if (params.userId) {
query.user = params.userId;
}
if (params.sessionId) {
const sessionId = params.sessionId.sessionId || params.sessionId;
if (!mongoose.Types.ObjectId.isValid(sessionId)) {
throw new SessionError('Invalid session ID format', 'INVALID_SESSION_ID');
}
query._id = sessionId;
}
// Add expiration check to only return valid sessions
query.expiration = { $gt: new Date() };
const sessionQuery = Session.findOne(query);
if (options.lean) {
return await sessionQuery.lean();
}
return await sessionQuery.exec();
} catch (error) {
logger.error('[findSession] Error finding session:', error);
throw new SessionError('Failed to find session', 'FIND_SESSION_FAILED');
}
};
/**
* Updates session expiration
* @param {Session|string} session - The session or session ID to update
* @param {Date} [newExpiration] - Optional new expiration date
* @returns {Promise<Session>}
* @throws {SessionError}
*/
const updateExpiration = async (session, newExpiration) => {
try {
const sessionDoc = typeof session === 'string' ? await Session.findById(session) : session;
if (!sessionDoc) {
throw new SessionError('Session not found', 'SESSION_NOT_FOUND');
}
sessionDoc.expiration = newExpiration || new Date(Date.now() + expires);
return await sessionDoc.save();
} catch (error) {
logger.error('[updateExpiration] Error updating session:', error);
throw new SessionError('Failed to update session expiration', 'UPDATE_EXPIRATION_FAILED');
}
};
/**
* Deletes a session by refresh token or session ID
* @param {Object} params - Delete parameters
* @param {string} [params.refreshToken] - The refresh token of the session to delete
* @param {string} [params.sessionId] - The ID of the session to delete
* @returns {Promise<Object>}
* @throws {SessionError}
*/
const deleteSession = async (params) => {
try {
if (!params.refreshToken && !params.sessionId) {
throw new SessionError(
'Either refreshToken or sessionId is required',
'INVALID_DELETE_PARAMS',
);
}
const query = {};
if (params.refreshToken) {
query.refreshTokenHash = await hashToken(params.refreshToken);
}
if (params.sessionId) {
query._id = params.sessionId;
}
const result = await Session.deleteOne(query);
if (result.deletedCount === 0) {
logger.warn('[deleteSession] No session found to delete');
}
return result;
} catch (error) {
logger.error('[deleteSession] Error deleting session:', error);
throw new SessionError('Failed to delete session', 'DELETE_SESSION_FAILED');
}
};
/**
* Deletes all sessions for a user
* @param {string} userId - The ID of the user
* @param {Object} [options] - Additional options
* @param {boolean} [options.excludeCurrentSession] - Whether to exclude the current session
* @param {string} [options.currentSessionId] - The ID of the current session to exclude
* @returns {Promise<Object>}
* @throws {SessionError}
*/
const deleteAllUserSessions = async (userId, options = {}) => {
try {
if (!userId) {
throw new SessionError('User ID is required', 'INVALID_USER_ID');
}
// Extract userId if it's passed as an object
const userIdString = userId.userId || userId;
if (!mongoose.Types.ObjectId.isValid(userIdString)) {
throw new SessionError('Invalid user ID format', 'INVALID_USER_ID_FORMAT');
}
const query = { user: userIdString };
if (options.excludeCurrentSession && options.currentSessionId) {
query._id = { $ne: options.currentSessionId };
}
const result = await Session.deleteMany(query);
if (result.deletedCount > 0) {
logger.debug(
`[deleteAllUserSessions] Deleted ${result.deletedCount} sessions for user ${userIdString}.`,
);
}
return result;
} catch (error) {
logger.error('[deleteAllUserSessions] Error deleting user sessions:', error);
throw new SessionError('Failed to delete user sessions', 'DELETE_ALL_SESSIONS_FAILED');
}
};
/**
* Generates a refresh token for a session
* @param {Session} session - The session to generate a token for
* @returns {Promise<string>}
* @throws {SessionError}
*/
const generateRefreshToken = async (session) => {
if (!session || !session.user) {
throw new SessionError('Invalid session object', 'INVALID_SESSION');
}
try {
const expiresIn = session.expiration ? session.expiration.getTime() : Date.now() + expires;
if (!session.expiration) {
session.expiration = new Date(expiresIn);
}
const refreshToken = await signPayload({
payload: {
id: session.user,
sessionId: session._id,
},
secret: process.env.JWT_REFRESH_SECRET,
expirationTime: Math.floor((expiresIn - Date.now()) / 1000),
});
session.refreshTokenHash = await hashToken(refreshToken);
await session.save();
return refreshToken;
} catch (error) {
logger.error('[generateRefreshToken] Error generating refresh token:', error);
throw new SessionError('Failed to generate refresh token', 'GENERATE_TOKEN_FAILED');
}
};
/**
* Counts active sessions for a user
* @param {string} userId - The ID of the user
* @returns {Promise<number>}
* @throws {SessionError}
*/
const countActiveSessions = async (userId) => {
try {
if (!userId) {
throw new SessionError('User ID is required', 'INVALID_USER_ID');
}
return await Session.countDocuments({
user: userId,
expiration: { $gt: new Date() },
});
} catch (error) {
logger.error('[countActiveSessions] Error counting active sessions:', error);
throw new SessionError('Failed to count active sessions', 'COUNT_SESSIONS_FAILED');
}
};
module.exports = {
createSession,
findSession,
updateExpiration,
deleteSession,
deleteAllUserSessions,
generateRefreshToken,
countActiveSessions,
SessionError,
};
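A minimal lifecycle sketch for the session helpers above (login/refresh handler shape assumed):

const { session, refreshToken } = await createSession(user._id); // signs and stores a hashed refresh token
const found = await findSession({ refreshToken }); // null if expired, revoked, or unknown
await updateExpiration(session); // sliding expiration
await deleteAllUserSessions(user._id, {
  excludeCurrentSession: true,
  currentSessionId: session._id, // "log out everywhere else"
});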

api/models/Share.js (new file, +351 lines)
View File

@@ -0,0 +1,351 @@
const mongoose = require('mongoose');
const { nanoid } = require('nanoid');
const { Constants } = require('librechat-data-provider');
const { Conversation } = require('~/models/Conversation');
const { shareSchema } = require('@librechat/data-schemas');
const SharedLink = mongoose.model('SharedLink', shareSchema);
const { getMessages } = require('./Message');
const logger = require('~/config/winston');
class ShareServiceError extends Error {
constructor(message, code) {
super(message);
this.name = 'ShareServiceError';
this.code = code;
}
}
const memoizedAnonymizeId = (prefix) => {
const memo = new Map();
return (id) => {
if (!memo.has(id)) {
memo.set(id, `${prefix}_${nanoid()}`);
}
return memo.get(id);
};
};
const anonymizeConvoId = memoizedAnonymizeId('convo');
const anonymizeAssistantId = memoizedAnonymizeId('a');
const anonymizeMessageId = (id) =>
id === Constants.NO_PARENT ? id : memoizedAnonymizeId('msg')(id);
function anonymizeConvo(conversation) {
if (!conversation) {
return null;
}
const newConvo = { ...conversation };
if (newConvo.assistant_id) {
newConvo.assistant_id = anonymizeAssistantId(newConvo.assistant_id);
}
return newConvo;
}
function anonymizeMessages(messages, newConvoId) {
if (!Array.isArray(messages)) {
return [];
}
const idMap = new Map();
return messages.map((message) => {
const newMessageId = anonymizeMessageId(message.messageId);
idMap.set(message.messageId, newMessageId);
const anonymizedAttachments = message.attachments?.map((attachment) => {
return {
...attachment,
messageId: newMessageId,
conversationId: newConvoId,
};
});
return {
...message,
messageId: newMessageId,
parentMessageId:
idMap.get(message.parentMessageId) || anonymizeMessageId(message.parentMessageId),
conversationId: newConvoId,
model: message.model?.startsWith('asst_')
? anonymizeAssistantId(message.model)
: message.model,
attachments: anonymizedAttachments,
};
});
}
async function getSharedMessages(shareId) {
try {
const share = await SharedLink.findOne({ shareId, isPublic: true })
.populate({
path: 'messages',
select: '-_id -__v -user',
})
.select('-_id -__v -user')
.lean();
if (!share?.conversationId || !share.isPublic) {
return null;
}
const newConvoId = anonymizeConvoId(share.conversationId);
const result = {
...share,
conversationId: newConvoId,
messages: anonymizeMessages(share.messages, newConvoId),
};
return result;
} catch (error) {
logger.error('[getShare] Error getting share link', {
error: error.message,
shareId,
});
throw new ShareServiceError('Error getting share link', 'SHARE_FETCH_ERROR');
}
}
async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortDirection, search) {
try {
const query = { user, isPublic };
if (pageParam) {
if (sortDirection === 'desc') {
query[sortBy] = { $lt: pageParam };
} else {
query[sortBy] = { $gt: pageParam };
}
}
if (search && search.trim()) {
try {
const searchResults = await Conversation.meiliSearch(search);
if (!searchResults?.hits?.length) {
return {
links: [],
nextCursor: undefined,
hasNextPage: false,
};
}
const conversationIds = searchResults.hits.map((hit) => hit.conversationId);
query['conversationId'] = { $in: conversationIds };
} catch (searchError) {
logger.error('[getSharedLinks] Meilisearch error', {
error: searchError.message,
user,
});
return {
links: [],
nextCursor: undefined,
hasNextPage: false,
};
}
}
const sort = {};
sort[sortBy] = sortDirection === 'desc' ? -1 : 1;
if (Array.isArray(query.conversationId)) {
query.conversationId = { $in: query.conversationId };
}
const sharedLinks = await SharedLink.find(query)
.sort(sort)
.limit(pageSize + 1)
.select('-__v -user')
.lean();
const hasNextPage = sharedLinks.length > pageSize;
const links = sharedLinks.slice(0, pageSize);
const nextCursor = hasNextPage ? links[links.length - 1][sortBy] : undefined;
return {
links: links.map((link) => ({
shareId: link.shareId,
title: link?.title || 'Untitled',
isPublic: link.isPublic,
createdAt: link.createdAt,
conversationId: link.conversationId,
})),
nextCursor,
hasNextPage,
};
} catch (error) {
logger.error('[getSharedLinks] Error getting shares', {
error: error.message,
user,
});
throw new ShareServiceError('Error getting shares', 'SHARES_FETCH_ERROR');
}
}
async function deleteAllSharedLinks(user) {
try {
const result = await SharedLink.deleteMany({ user });
return {
message: 'All shared links deleted successfully',
deletedCount: result.deletedCount,
};
} catch (error) {
logger.error('[deleteAllSharedLinks] Error deleting shared links', {
error: error.message,
user,
});
throw new ShareServiceError('Error deleting shared links', 'BULK_DELETE_ERROR');
}
}
async function createSharedLink(user, conversationId) {
if (!user || !conversationId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
try {
const [existingShare, conversationMessages] = await Promise.all([
SharedLink.findOne({ conversationId, isPublic: true }).select('-_id -__v -user').lean(),
getMessages({ conversationId }),
]);
if (existingShare && existingShare.isPublic) {
throw new ShareServiceError('Share already exists', 'SHARE_EXISTS');
} else if (existingShare) {
await SharedLink.deleteOne({ conversationId });
}
const conversation = await Conversation.findOne({ conversationId }).lean();
const title = conversation?.title || 'Untitled';
const shareId = nanoid();
await SharedLink.create({
shareId,
conversationId,
messages: conversationMessages,
title,
user,
});
return { shareId, conversationId };
} catch (error) {
logger.error('[createSharedLink] Error creating shared link', {
error: error.message,
user,
conversationId,
});
throw new ShareServiceError('Error creating shared link', 'SHARE_CREATE_ERROR');
}
}
async function getSharedLink(user, conversationId) {
if (!user || !conversationId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
try {
const share = await SharedLink.findOne({ conversationId, user, isPublic: true })
.select('shareId -_id')
.lean();
if (!share) {
return { shareId: null, success: false };
}
return { shareId: share.shareId, success: true };
} catch (error) {
logger.error('[getSharedLink] Error getting shared link', {
error: error.message,
user,
conversationId,
});
throw new ShareServiceError('Error getting shared link', 'SHARE_FETCH_ERROR');
}
}
async function updateSharedLink(user, shareId) {
if (!user || !shareId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
try {
const share = await SharedLink.findOne({ shareId }).select('-_id -__v -user').lean();
if (!share) {
throw new ShareServiceError('Share not found', 'SHARE_NOT_FOUND');
}
const [updatedMessages] = await Promise.all([
getMessages({ conversationId: share.conversationId }),
]);
const newShareId = nanoid();
const update = {
messages: updatedMessages,
user,
shareId: newShareId,
};
const updatedShare = await SharedLink.findOneAndUpdate({ shareId, user }, update, {
new: true,
upsert: false,
runValidators: true,
}).lean();
if (!updatedShare) {
throw new ShareServiceError('Share update failed', 'SHARE_UPDATE_ERROR');
}
anonymizeConvo(updatedShare);
return { shareId: newShareId, conversationId: updatedShare.conversationId };
} catch (error) {
logger.error('[updateSharedLink] Error updating shared link', {
error: error.message,
user,
shareId,
});
throw new ShareServiceError(
error.code === 'SHARE_UPDATE_ERROR' ? error.message : 'Error updating shared link',
error.code || 'SHARE_UPDATE_ERROR',
);
}
}
async function deleteSharedLink(user, shareId) {
if (!user || !shareId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
try {
const result = await SharedLink.findOneAndDelete({ shareId, user }).lean();
if (!result) {
return null;
}
return {
success: true,
shareId,
message: 'Share deleted successfully',
};
} catch (error) {
logger.error('[deleteSharedLink] Error deleting shared link', {
error: error.message,
user,
shareId,
});
throw new ShareServiceError('Error deleting shared link', 'SHARE_DELETE_ERROR');
}
}
module.exports = {
SharedLink,
getSharedLink,
getSharedLinks,
createSharedLink,
updateSharedLink,
deleteSharedLink,
getSharedMessages,
deleteAllSharedLinks,
};
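A usage sketch tying the share helpers together (route-handler shape assumed; IDs hypothetical):

const { shareId } = await createSharedLink(req.user.id, conversationId); // throws SHARE_EXISTS if already public
const shared = await getSharedMessages(shareId); // convo/message/assistant IDs come back anonymized
const { shareId: rotated } = await updateSharedLink(req.user.id, shareId); // re-keys the link, invalidating the old shareId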

api/models/Token.js (new file, +199 lines)
View File

@@ -0,0 +1,199 @@
const mongoose = require('mongoose');
const { encryptV2 } = require('~/server/utils/crypto');
const { tokenSchema } = require('@librechat/data-schemas');
const { logger } = require('~/config');
/**
* Token model.
* @type {mongoose.Model}
*/
const Token = mongoose.model('Token', tokenSchema);
/**
* Fixes the indexes for the Token collection from legacy TTL indexes to the new expiresAt index.
*/
async function fixIndexes() {
try {
if (
process.env.NODE_ENV === 'CI' ||
process.env.NODE_ENV === 'development' ||
process.env.NODE_ENV === 'test'
) {
return;
}
const indexes = await Token.collection.indexes();
logger.debug('Existing Token Indexes:', JSON.stringify(indexes, null, 2));
const unwantedTTLIndexes = indexes.filter(
(index) => index.key.createdAt === 1 && index.expireAfterSeconds !== undefined,
);
if (unwantedTTLIndexes.length === 0) {
logger.debug('No unwanted Token indexes found.');
return;
}
for (const index of unwantedTTLIndexes) {
logger.debug(`Dropping unwanted Token index: ${index.name}`);
await Token.collection.dropIndex(index.name);
logger.debug(`Dropped Token index: ${index.name}`);
}
logger.debug('Token index cleanup completed successfully.');
} catch (error) {
logger.error('An error occurred while fixing Token indexes:', error);
}
}
fixIndexes();
/**
* Creates a new Token instance.
* @param {Object} tokenData - The data for the new Token.
* @param {mongoose.Types.ObjectId} tokenData.userId - The user's ID. It is required.
* @param {String} tokenData.email - The user's email.
* @param {String} tokenData.token - The token. It is required.
* @param {Number} tokenData.expiresIn - The number of seconds until the token expires.
* @returns {Promise<mongoose.Document>} The new Token instance.
* @throws Will throw an error if token creation fails.
*/
async function createToken(tokenData) {
try {
const currentTime = new Date();
const expiresAt = new Date(currentTime.getTime() + tokenData.expiresIn * 1000);
const newTokenData = {
...tokenData,
createdAt: currentTime,
expiresAt,
};
return await Token.create(newTokenData);
} catch (error) {
logger.debug('An error occurred while creating token:', error);
throw error;
}
}
/**
* Finds a Token document that matches the provided query.
* @param {Object} query - The query to match against.
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
* @param {String} query.token - The token value.
* @param {String} [query.email] - The email of the user.
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
* @returns {Promise<Object|null>} The matched Token document, or null if not found.
* @throws Will throw an error if the find operation fails.
*/
async function findToken(query) {
try {
const conditions = [];
if (query.userId) {
conditions.push({ userId: query.userId });
}
if (query.token) {
conditions.push({ token: query.token });
}
if (query.email) {
conditions.push({ email: query.email });
}
if (query.identifier) {
conditions.push({ identifier: query.identifier });
}
const token = await Token.findOne({
$and: conditions,
}).lean();
return token;
} catch (error) {
logger.debug('An error occurred while finding token:', error);
throw error;
}
}
/**
* Updates a Token document that matches the provided query.
* @param {Object} query - The query to match against.
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
* @param {String} query.token - The token value.
* @param {String} [query.email] - The email of the user.
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
* @param {Object} updateData - The data to update the Token with.
* @returns {Promise<mongoose.Document|null>} The updated Token document, or null if not found.
* @throws Will throw an error if the update operation fails.
*/
async function updateToken(query, updateData) {
try {
return await Token.findOneAndUpdate(query, updateData, { new: true });
} catch (error) {
logger.debug('An error occurred while updating token:', error);
throw error;
}
}
/**
* Deletes all Token documents that match the provided token, user ID, or email.
* @param {Object} query - The query to match against.
* @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user.
* @param {String} query.token - The token value.
* @param {String} [query.email] - The email of the user.
* @param {String} [query.identifier] - Unique, alternative identifier for the token.
* @returns {Promise<Object>} The result of the delete operation.
* @throws Will throw an error if the delete operation fails.
*/
async function deleteTokens(query) {
try {
return await Token.deleteMany({
$or: [
{ userId: query.userId },
{ token: query.token },
{ email: query.email },
{ identifier: query.identifier },
],
});
} catch (error) {
logger.debug('An error occurred while deleting tokens:', error);
throw error;
}
}
/**
* Handles the OAuth token by creating or updating the token.
* @param {object} fields
* @param {string} fields.userId - The user's ID.
* @param {string} fields.token - The full token to store.
* @param {string} fields.identifier - Unique, alternative identifier for the token.
* @param {number} fields.expiresIn - The number of seconds until the token expires.
* @param {object} fields.metadata - Additional metadata to store with the token.
* @param {string} [fields.type="oauth"] - The type of token. Default is 'oauth'.
*/
async function handleOAuthToken({
token,
userId,
identifier,
expiresIn,
metadata,
type = 'oauth',
}) {
const encryptedToken = await encryptV2(token);
const tokenData = {
type,
userId,
metadata,
identifier,
token: encryptedToken,
expiresIn: parseInt(expiresIn, 10) || 3600,
};
const existingToken = await findToken({ userId, identifier });
if (existingToken) {
return await updateToken({ identifier }, tokenData);
} else {
return await createToken(tokenData);
}
}
module.exports = {
findToken,
createToken,
updateToken,
deleteTokens,
handleOAuthToken,
};
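As a usage sketch (the identifier scheme and expiry below are illustrative assumptions): handleOAuthToken encrypts the token and upserts on identifier, while createToken turns expiresIn (seconds) into an absolute expiresAt.
const { handleOAuthToken, findToken } = require('./Token');
async function storeProviderToken(userId, accessToken) {
  // Upserts on `identifier`; the stored value is encrypted via encryptV2.
  await handleOAuthToken({
    userId,
    token: accessToken,
    identifier: `github:${userId}`, // assumed identifier convention
    expiresIn: 28800, // 8 hours; converted to an absolute expiresAt on create
    metadata: { provider: 'github' },
  });
  // Later lookups match on the same identifier (the token field stays encrypted).
  return await findToken({ userId, identifier: `github:${userId}` });
}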

View File

@@ -1,4 +1,6 @@
-const { ToolCall } = require('~/db/models');
+const mongoose = require('mongoose');
+const { toolCallSchema } = require('@librechat/data-schemas');
+const ToolCall = mongoose.model('ToolCall', toolCallSchema);
/**
* Create a new tool call
View File

@@ -1,7 +1,9 @@
-const { logger } = require('@librechat/data-schemas');
+const mongoose = require('mongoose');
+const { transactionSchema } = require('@librechat/data-schemas');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
-const { Transaction, Balance } = require('~/db/models');
+const { logger } = require('~/config');
+const Balance = require('./Balance');
const cancelRate = 1.15;
@@ -138,19 +140,19 @@ const updateBalance = async ({ user, incrementValue, setValues }) => {
};
/** Method to calculate and set the tokenValue for a transaction */
-function calculateTokenValue(txn) {
-if (!txn.valueKey || !txn.tokenType) {
-txn.tokenValue = txn.rawAmount;
+transactionSchema.methods.calculateTokenValue = function () {
+if (!this.valueKey || !this.tokenType) {
+this.tokenValue = this.rawAmount;
}
-const { valueKey, tokenType, model, endpointTokenConfig } = txn;
+const { valueKey, tokenType, model, endpointTokenConfig } = this;
const multiplier = Math.abs(getMultiplier({ valueKey, tokenType, model, endpointTokenConfig }));
-txn.rate = multiplier;
-txn.tokenValue = txn.rawAmount * multiplier;
-if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
-txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate);
-txn.rate *= cancelRate;
+this.rate = multiplier;
+this.tokenValue = this.rawAmount * multiplier;
+if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
+this.tokenValue = Math.ceil(this.tokenValue * cancelRate);
+this.rate *= cancelRate;
}
-}
+};
/**
* New static method to create an auto-refill transaction that does NOT trigger a balance update.
@@ -161,13 +163,13 @@ function calculateTokenValue(txn) {
* @param {number} txData.rawAmount - The raw amount of tokens.
* @returns {Promise<object>} - The created transaction.
*/
-async function createAutoRefillTransaction(txData) {
+transactionSchema.statics.createAutoRefillTransaction = async function (txData) {
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return;
}
-const transaction = new Transaction(txData);
+const transaction = new this(txData);
transaction.endpointTokenConfig = txData.endpointTokenConfig;
-calculateTokenValue(transaction);
+transaction.calculateTokenValue();
await transaction.save();
const balanceResponse = await updateBalance({
@@ -183,20 +185,21 @@ async function createAutoRefillTransaction(txData) {
logger.debug('[Balance.check] Auto-refill performed', result);
result.transaction = transaction;
return result;
-}
+};
/**
* Static method to create a transaction and update the balance
* @param {txData} txData - Transaction data.
*/
-async function createTransaction(txData) {
+transactionSchema.statics.create = async function (txData) {
+const Transaction = this;
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return;
}
const transaction = new Transaction(txData);
transaction.endpointTokenConfig = txData.endpointTokenConfig;
-calculateTokenValue(transaction);
+transaction.calculateTokenValue();
await transaction.save();
@@ -206,6 +209,7 @@ async function createTransaction(txData) {
}
let incrementValue = transaction.tokenValue;
+
const balanceResponse = await updateBalance({
user: transaction.user,
incrementValue,
@@ -217,19 +221,21 @@ async function createTransaction(txData) {
balance: balanceResponse.tokenCredits,
[transaction.tokenType]: incrementValue,
};
-}
+};
/**
* Static method to create a structured transaction and update the balance
* @param {txData} txData - Transaction data.
*/
-async function createStructuredTransaction(txData) {
+transactionSchema.statics.createStructured = async function (txData) {
+const Transaction = this;
const transaction = new Transaction({
...txData,
endpointTokenConfig: txData.endpointTokenConfig,
});
-calculateStructuredTokenValue(transaction);
+transaction.calculateStructuredTokenValue();
await transaction.save();
@@ -251,69 +257,71 @@ async function createStructuredTransaction(txData) {
balance: balanceResponse.tokenCredits,
[transaction.tokenType]: incrementValue,
};
-}
+};
/** Method to calculate token value for structured tokens */
-function calculateStructuredTokenValue(txn) {
-if (!txn.tokenType) {
-txn.tokenValue = txn.rawAmount;
+transactionSchema.methods.calculateStructuredTokenValue = function () {
+if (!this.tokenType) {
+this.tokenValue = this.rawAmount;
return;
}
-const { model, endpointTokenConfig } = txn;
-if (txn.tokenType === 'prompt') {
+const { model, endpointTokenConfig } = this;
+if (this.tokenType === 'prompt') {
const inputMultiplier = getMultiplier({ tokenType: 'prompt', model, endpointTokenConfig });
const writeMultiplier =
getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? inputMultiplier;
const readMultiplier =
getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig }) ?? inputMultiplier;
-txn.rateDetail = {
+this.rateDetail = {
input: inputMultiplier,
write: writeMultiplier,
read: readMultiplier,
};
const totalPromptTokens =
-Math.abs(txn.inputTokens || 0) +
-Math.abs(txn.writeTokens || 0) +
-Math.abs(txn.readTokens || 0);
+Math.abs(this.inputTokens || 0) +
+Math.abs(this.writeTokens || 0) +
+Math.abs(this.readTokens || 0);
if (totalPromptTokens > 0) {
-txn.rate =
-(Math.abs(inputMultiplier * (txn.inputTokens || 0)) +
-Math.abs(writeMultiplier * (txn.writeTokens || 0)) +
-Math.abs(readMultiplier * (txn.readTokens || 0))) /
+this.rate =
+(Math.abs(inputMultiplier * (this.inputTokens || 0)) +
+Math.abs(writeMultiplier * (this.writeTokens || 0)) +
+Math.abs(readMultiplier * (this.readTokens || 0))) /
totalPromptTokens;
} else {
-txn.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens
+this.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens
}
-txn.tokenValue = -(
-Math.abs(txn.inputTokens || 0) * inputMultiplier +
-Math.abs(txn.writeTokens || 0) * writeMultiplier +
-Math.abs(txn.readTokens || 0) * readMultiplier
+this.tokenValue = -(
+Math.abs(this.inputTokens || 0) * inputMultiplier +
+Math.abs(this.writeTokens || 0) * writeMultiplier +
+Math.abs(this.readTokens || 0) * readMultiplier
);
-txn.rawAmount = -totalPromptTokens;
-} else if (txn.tokenType === 'completion') {
-const multiplier = getMultiplier({ tokenType: txn.tokenType, model, endpointTokenConfig });
-txn.rate = Math.abs(multiplier);
-txn.tokenValue = -Math.abs(txn.rawAmount) * multiplier;
-txn.rawAmount = -Math.abs(txn.rawAmount);
+this.rawAmount = -totalPromptTokens;
+} else if (this.tokenType === 'completion') {
+const multiplier = getMultiplier({ tokenType: this.tokenType, model, endpointTokenConfig });
+this.rate = Math.abs(multiplier);
+this.tokenValue = -Math.abs(this.rawAmount) * multiplier;
+this.rawAmount = -Math.abs(this.rawAmount);
}
-if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') {
-txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate);
-txn.rate *= cancelRate;
-if (txn.rateDetail) {
-txn.rateDetail = Object.fromEntries(
-Object.entries(txn.rateDetail).map(([k, v]) => [k, v * cancelRate]),
+if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
+this.tokenValue = Math.ceil(this.tokenValue * cancelRate);
+this.rate *= cancelRate;
+if (this.rateDetail) {
+this.rateDetail = Object.fromEntries(
+Object.entries(this.rateDetail).map(([k, v]) => [k, v * cancelRate]),
);
}
}
-}
+};
+const Transaction = mongoose.model('Transaction', transactionSchema);
/**
* Queries and retrieves transactions based on a given filter.
@@ -332,9 +340,4 @@ async function getTransactions(filter) {
}
}
-module.exports = {
-getTransactions,
-createTransaction,
-createAutoRefillTransaction,
-createStructuredTransaction,
-};
+module.exports = { Transaction, getTransactions };
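The net effect of this refactor is that the standalone helpers become schema methods and statics, so callers go through the model. Roughly (field values below are illustrative):
const { Transaction } = require('./Transaction');
// calculateTokenValue(txn)            -> txn.calculateTokenValue()
// createTransaction(txData)           -> Transaction.create(txData)
// createStructuredTransaction(txData) -> Transaction.createStructured(txData)
async function recordCompletionSpend(userId, conversationId) {
  return await Transaction.create({
    user: userId,
    conversationId,
    model: 'gpt-4',
    tokenType: 'completion',
    rawAmount: -150, // negative: tokens spent
  });
}
Note that a static named create shadows Mongoose's built-in Model.create, which is presumably deliberate here: all writes are expected to flow through this balance-aware path.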

View File

@@ -3,13 +3,14 @@ const { MongoMemoryServer } = require('mongodb-memory-server');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
-const { createTransaction } = require('./Transaction');
-const { Balance } = require('~/db/models');
+const { Transaction } = require('./Transaction');
+const Balance = require('./Balance');
// Mock the custom config module so we can control the balance flag.
jest.mock('~/server/services/Config');
let mongoServer;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
@@ -367,7 +368,7 @@ describe('NaN Handling Tests', () => {
};
// Act
-const result = await createTransaction(txData);
+const result = await Transaction.create(txData);
// Assert: No transaction should be created and balance remains unchanged.
expect(result).toBeUndefined();

api/models/User.js Normal file
View File

@@ -0,0 +1,6 @@
const mongoose = require('mongoose');
const { userSchema } = require('@librechat/data-schemas');
const User = mongoose.model('User', userSchema);
module.exports = User;

View File

@@ -1,9 +1,9 @@
-const { logger } = require('@librechat/data-schemas');
const { ViolationTypes } = require('librechat-data-provider');
-const { createAutoRefillTransaction } = require('./Transaction');
+const { Transaction } = require('./Transaction');
const { logViolation } = require('~/cache');
const { getMultiplier } = require('./tx');
-const { Balance } = require('~/db/models');
+const { logger } = require('~/config');
+const Balance = require('./Balance');
function isInvalidDate(date) {
return isNaN(date);
@@ -60,7 +60,7 @@ const checkBalanceRecord = async function ({
) {
try {
/** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */
-const result = await createAutoRefillTransaction({
+const result = await Transaction.createAutoRefillTransaction({
user: user,
tokenType: 'credits',
context: 'autoRefill',

View File

@@ -1,7 +1,6 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
-const { getMessages, bulkSaveMessages } = require('./Message');
-const { Message } = require('~/db/models');
+const { Message, getMessages, bulkSaveMessages } = require('./Message');
// Original version of buildTree function
function buildTree({ messages, fileMap }) {
@@ -43,6 +42,7 @@ function buildTree({ messages, fileMap }) {
}
let mongod;
beforeAll(async () => {
mongod = await MongoMemoryServer.create();
const uri = mongod.getUri();

View File

@@ -1,7 +1,13 @@
-const mongoose = require('mongoose');
-const { createMethods } = require('@librechat/data-schemas');
-const methods = createMethods(mongoose);
-const { comparePassword } = require('./userMethods');
+const {
+comparePassword,
+deleteUserById,
+generateToken,
+getUserById,
+updateUser,
+createUser,
+countUsers,
+findUser,
+} = require('./userMethods');
const {
findFileById,
createFile,
@@ -20,12 +26,32 @@ const {
deleteMessagesSince,
deleteMessages,
} = require('./Message');
+const {
+createSession,
+findSession,
+updateExpiration,
+deleteSession,
+deleteAllUserSessions,
+generateRefreshToken,
+countActiveSessions,
+} = require('./Session');
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
+const { createToken, findToken, updateToken, deleteTokens } = require('./Token');
+const Balance = require('./Balance');
+const User = require('./User');
+const Key = require('./Key');
module.exports = {
-...methods,
comparePassword,
+deleteUserById,
+generateToken,
+getUserById,
+updateUser,
+createUser,
+countUsers,
+findUser,
findFileById,
createFile,
updateFile,
@@ -51,4 +77,21 @@ module.exports = {
getPresets,
savePreset,
deletePresets,
+createToken,
+findToken,
+updateToken,
+deleteTokens,
+createSession,
+findSession,
+updateExpiration,
+deleteSession,
+deleteAllUserSessions,
+generateRefreshToken,
+countActiveSessions,
+User,
+Key,
+Balance,
};

View File

@@ -1,7 +1,7 @@
const mongoose = require('mongoose');
-const { getRandomValues } = require('@librechat/api');
-const { logger, hashToken } = require('@librechat/data-schemas');
-const { createToken, findToken } = require('~/models');
+const { getRandomValues, hashToken } = require('~/server/utils/crypto');
+const { createToken, findToken } = require('./Token');
+const logger = require('~/config/winston');
/**
* @module inviteUser

View File

@@ -0,0 +1,475 @@
const _ = require('lodash');
const mongoose = require('mongoose');
const { MeiliSearch } = require('meilisearch');
const { parseTextParts, ContentTypes } = require('librechat-data-provider');
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
const logger = require('~/config/meiliLogger');
// Environment flags
/**
* Flag to indicate if search is enabled based on environment variables.
* @type {boolean}
*/
const searchEnabled = process.env.SEARCH && process.env.SEARCH.toLowerCase() === 'true';
/**
* Flag to indicate if MeiliSearch is enabled based on required environment variables.
* @type {boolean}
*/
const meiliEnabled = process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY && searchEnabled;
/**
* Validates the required options for configuring the mongoMeili plugin.
*
* @param {Object} options - The configuration options.
* @param {string} options.host - The MeiliSearch host.
* @param {string} options.apiKey - The MeiliSearch API key.
* @param {string} options.indexName - The name of the index.
* @throws {Error} Throws an error if any required option is missing.
*/
const validateOptions = function (options) {
const requiredKeys = ['host', 'apiKey', 'indexName'];
requiredKeys.forEach((key) => {
if (!options[key]) {
throw new Error(`Missing mongoMeili Option: ${key}`);
}
});
};
/**
* Factory function to create a MeiliMongooseModel class which extends a Mongoose model.
* This class contains static and instance methods to synchronize and manage the MeiliSearch index
* corresponding to the MongoDB collection.
*
* @param {Object} config - Configuration object.
* @param {Object} config.index - The MeiliSearch index object.
* @param {Array<string>} config.attributesToIndex - List of attributes to index.
* @returns {Function} A class definition that will be loaded into the Mongoose schema.
*/
const createMeiliMongooseModel = function ({ index, attributesToIndex }) {
// The primary key is assumed to be the first attribute in the attributesToIndex array.
const primaryKey = attributesToIndex[0];
class MeiliMongooseModel {
/**
* Synchronizes the data between the MongoDB collection and the MeiliSearch index.
*
* The synchronization process involves:
* 1. Fetching all documents from the MongoDB collection and MeiliSearch index.
* 2. Comparing documents from both sources.
* 3. Deleting documents from MeiliSearch that no longer exist in MongoDB.
* 4. Adding documents to MeiliSearch that exist in MongoDB but not in the index.
* 5. Updating documents in MeiliSearch if key fields (such as `text` or `title`) differ.
* 6. Updating the `_meiliIndex` field in MongoDB to indicate the indexing status.
*
* Note: The function processes documents in batches because MeiliSearch's
* `index.getDocuments` requires an exact limit and `index.addDocuments` does not handle
* partial failures in a batch.
*
* @returns {Promise<void>} Resolves when the synchronization is complete.
*/
static async syncWithMeili() {
try {
let moreDocuments = true;
// Retrieve all MongoDB documents from the collection as plain JavaScript objects.
const mongoDocuments = await this.find().lean();
// Helper function to format a document by selecting only the attributes to index
// and omitting keys starting with '$'.
const format = (doc) =>
_.omitBy(_.pick(doc, attributesToIndex), (v, k) => k.startsWith('$'));
// Build a map of MongoDB documents for quick lookup based on the primary key.
const mongoMap = new Map(mongoDocuments.map((doc) => [doc[primaryKey], format(doc)]));
const indexMap = new Map();
let offset = 0;
const batchSize = 1000;
// Fetch documents from the MeiliSearch index in batches.
while (moreDocuments) {
const batch = await index.getDocuments({ limit: batchSize, offset });
if (batch.results.length === 0) {
moreDocuments = false;
}
for (const doc of batch.results) {
indexMap.set(doc[primaryKey], format(doc));
}
offset += batchSize;
}
logger.debug('[syncWithMeili]', { indexMap: indexMap.size, mongoMap: mongoMap.size });
const updateOps = [];
// Process documents present in the MeiliSearch index.
for (const [id, doc] of indexMap) {
const update = {};
update[primaryKey] = id;
if (mongoMap.has(id)) {
// If document exists in MongoDB, check for discrepancies in key fields.
if (
(doc.text && doc.text !== mongoMap.get(id).text) ||
(doc.title && doc.title !== mongoMap.get(id).title)
) {
logger.debug(
`[syncWithMeili] ${id} had document discrepancy in ${
doc.text ? 'text' : 'title'
} field`,
);
updateOps.push({
updateOne: { filter: update, update: { $set: { _meiliIndex: true } } },
});
await index.addDocuments([doc]);
}
} else {
// If the document does not exist in MongoDB, delete it from MeiliSearch.
await index.deleteDocument(id);
updateOps.push({
updateOne: { filter: update, update: { $set: { _meiliIndex: false } } },
});
}
}
// Process documents present in MongoDB.
for (const [id, doc] of mongoMap) {
const update = {};
update[primaryKey] = id;
// If the document is missing in the Meili index, add it.
if (!indexMap.has(id)) {
await index.addDocuments([doc]);
updateOps.push({
updateOne: { filter: update, update: { $set: { _meiliIndex: true } } },
});
} else if (doc._meiliIndex === false) {
// If the document exists but is marked as not indexed, update the flag.
updateOps.push({
updateOne: { filter: update, update: { $set: { _meiliIndex: true } } },
});
}
}
// Execute bulk update operations in MongoDB to update the _meiliIndex flags.
if (updateOps.length > 0) {
await this.collection.bulkWrite(updateOps);
logger.debug(
`[syncWithMeili] Finished indexing ${
primaryKey === 'messageId' ? 'messages' : 'conversations'
}`,
);
}
} catch (error) {
logger.error('[syncWithMeili] Error adding document to Meili', error);
}
}
/**
* Updates settings for the MeiliSearch index.
*
* @param {Object} settings - The settings to update on the MeiliSearch index.
* @returns {Promise<Object>} Promise resolving to the update result.
*/
static async setMeiliIndexSettings(settings) {
return await index.updateSettings(settings);
}
/**
* Searches the MeiliSearch index and optionally populates the results with data from MongoDB.
*
* @param {string} q - The search query.
* @param {Object} params - Additional search parameters for MeiliSearch.
* @param {boolean} populate - Whether to populate search hits with full MongoDB documents.
* @returns {Promise<Object>} The search results with populated hits if requested.
*/
static async meiliSearch(q, params, populate) {
const data = await index.search(q, params);
if (populate) {
// Build a query using the primary key values from the search hits.
const query = {};
query[primaryKey] = _.map(data.hits, (hit) => cleanUpPrimaryKeyValue(hit[primaryKey]));
// Build a projection object, including only keys that do not start with '$'.
const projection = Object.keys(this.schema.obj).reduce(
(results, key) => {
if (!key.startsWith('$')) {
results[key] = 1;
}
return results;
},
{ _id: 1, __v: 1 },
);
// Retrieve the full documents from MongoDB.
const hitsFromMongoose = await this.find(query, projection).lean();
// Merge the MongoDB documents with the search hits.
const populatedHits = data.hits.map(function (hit) {
const query = {};
query[primaryKey] = hit[primaryKey];
const originalHit = _.find(hitsFromMongoose, query);
return {
...(originalHit ?? {}),
...hit,
};
});
data.hits = populatedHits;
}
return data;
}
/**
* Preprocesses the current document for indexing.
*
* This method:
* - Picks only the defined attributes to index.
* - Omits any keys starting with '$'.
* - Replaces pipe characters ('|') in `conversationId` with '--'.
* - Extracts and concatenates text from an array of content items.
*
* @returns {Object} The preprocessed object ready for indexing.
*/
preprocessObjectForIndex() {
const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) =>
k.startsWith('$'),
);
if (object.conversationId && object.conversationId.includes('|')) {
object.conversationId = object.conversationId.replace(/\|/g, '--');
}
if (object.content && Array.isArray(object.content)) {
object.text = parseTextParts(object.content);
delete object.content;
}
return object;
}
/**
* Adds the current document to the MeiliSearch index.
*
* The method preprocesses the document, adds it to MeiliSearch, and then updates
* the MongoDB document's `_meiliIndex` flag to true.
*
* @returns {Promise<void>}
*/
async addObjectToMeili() {
const object = this.preprocessObjectForIndex();
try {
await index.addDocuments([object]);
} catch (error) {
// Error handling can be enhanced as needed.
logger.error('[addObjectToMeili] Error adding document to Meili', error);
}
await this.collection.updateMany({ _id: this._id }, { $set: { _meiliIndex: true } });
}
/**
* Updates the current document in the MeiliSearch index.
*
* @returns {Promise<void>}
*/
async updateObjectToMeili() {
const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) =>
k.startsWith('$'),
);
await index.updateDocuments([object]);
}
/**
* Deletes the current document from the MeiliSearch index.
*
* @returns {Promise<void>}
*/
async deleteObjectFromMeili() {
await index.deleteDocument(this._id);
}
/**
* Post-save hook to synchronize the document with MeiliSearch.
*
* If the document is already indexed (i.e. `_meiliIndex` is true), it updates it;
* otherwise, it adds the document to the index.
*/
postSaveHook() {
if (this._meiliIndex) {
this.updateObjectToMeili();
} else {
this.addObjectToMeili();
}
}
/**
* Post-update hook to update the document in MeiliSearch.
*
* This hook is triggered after a document update, ensuring that changes are
* propagated to the MeiliSearch index if the document is indexed.
*/
postUpdateHook() {
if (this._meiliIndex) {
this.updateObjectToMeili();
}
}
/**
* Post-remove hook to delete the document from MeiliSearch.
*
* This hook is triggered after a document is removed, ensuring that the document
* is also removed from the MeiliSearch index if it was previously indexed.
*/
postRemoveHook() {
if (this._meiliIndex) {
this.deleteObjectFromMeili();
}
}
}
return MeiliMongooseModel;
};
/**
* Mongoose plugin to synchronize MongoDB collections with a MeiliSearch index.
*
* This plugin:
* - Validates the provided options.
* - Adds a `_meiliIndex` field to the schema to track indexing status.
* - Sets up a MeiliSearch client and creates an index if it doesn't already exist.
* - Loads class methods for syncing, searching, and managing documents in MeiliSearch.
* - Registers Mongoose hooks (post-save, post-update, post-remove, etc.) to maintain index consistency.
*
* @param {mongoose.Schema} schema - The Mongoose schema to which the plugin is applied.
* @param {Object} options - Configuration options.
* @param {string} options.host - The MeiliSearch host.
* @param {string} options.apiKey - The MeiliSearch API key.
* @param {string} options.indexName - The name of the MeiliSearch index.
* @param {string} options.primaryKey - The primary key field for indexing.
*/
module.exports = function mongoMeili(schema, options) {
validateOptions(options);
// Add _meiliIndex field to the schema to track if a document has been indexed in MeiliSearch.
schema.add({
_meiliIndex: {
type: Boolean,
required: false,
select: false,
default: false,
},
});
const { host, apiKey, indexName, primaryKey } = options;
// Setup the MeiliSearch client.
const client = new MeiliSearch({ host, apiKey });
// Create the index asynchronously if it doesn't exist.
client.createIndex(indexName, { primaryKey });
// Setup the MeiliSearch index for this schema.
const index = client.index(indexName);
// Collect attributes from the schema that should be indexed.
const attributesToIndex = [
..._.reduce(
schema.obj,
function (results, value, key) {
return value.meiliIndex ? [...results, key] : results;
},
[],
),
];
// Load the class methods into the schema.
schema.loadClass(createMeiliMongooseModel({ index, attributesToIndex }));
// Register Mongoose hooks to synchronize with MeiliSearch.
// Post-save: synchronize after a document is saved.
schema.post('save', function (doc) {
doc.postSaveHook();
});
// Post-update: synchronize after a document is updated.
schema.post('update', function (doc) {
doc.postUpdateHook();
});
// Post-remove: synchronize after a document is removed.
schema.post('remove', function (doc) {
doc.postRemoveHook();
});
// Pre-deleteMany hook: remove corresponding documents from MeiliSearch when multiple documents are deleted.
schema.pre('deleteMany', async function (next) {
if (!meiliEnabled) {
return next();
}
try {
// Check if the schema has a "messages" field to determine if it's a conversation schema.
if (Object.prototype.hasOwnProperty.call(schema.obj, 'messages')) {
const convoIndex = client.index('convos');
const deletedConvos = await mongoose.model('Conversation').find(this._conditions).lean();
const promises = deletedConvos.map((convo) =>
convoIndex.deleteDocument(convo.conversationId),
);
await Promise.all(promises);
}
// Check if the schema has a "messageId" field to determine if it's a message schema.
if (Object.prototype.hasOwnProperty.call(schema.obj, 'messageId')) {
const messageIndex = client.index('messages');
const deletedMessages = await mongoose.model('Message').find(this._conditions).lean();
const promises = deletedMessages.map((message) =>
messageIndex.deleteDocument(message.messageId),
);
await Promise.all(promises);
}
return next();
} catch (error) {
if (meiliEnabled) {
logger.error(
'[MeiliMongooseModel.deleteMany] There was an issue deleting conversation indexes upon deletion. Next startup may be slow due to syncing.',
error,
);
}
return next();
}
});
// Post-findOneAndUpdate hook: update MeiliSearch index after a document is updated via findOneAndUpdate.
schema.post('findOneAndUpdate', async function (doc) {
if (!meiliEnabled) {
return;
}
// If no document matched the update, or it is unfinished, do not update the index.
if (!doc || doc.unfinished) {
return;
}
let meiliDoc;
// For conversation documents, try to fetch the document from the "convos" index.
if (doc.messages) {
try {
meiliDoc = await client.index('convos').getDocument(doc.conversationId);
} catch (error) {
logger.debug(
'[MeiliMongooseModel.findOneAndUpdate] Convo not found in MeiliSearch and will index ' +
doc.conversationId,
error,
);
}
}
// If the MeiliSearch document exists and the title is unchanged, do nothing.
if (meiliDoc && meiliDoc.title === doc.title) {
return;
}
// Otherwise, trigger a post-save hook to synchronize the document.
doc.postSaveHook();
});
};

View File

@@ -0,0 +1,18 @@
const mongoose = require('mongoose');
const mongoMeili = require('../plugins/mongoMeili');
const { convoSchema } = require('@librechat/data-schemas');
if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
convoSchema.plugin(mongoMeili, {
host: process.env.MEILI_HOST,
apiKey: process.env.MEILI_MASTER_KEY,
/** Note: Will get created automatically if it doesn't exist already */
indexName: 'convos',
primaryKey: 'conversationId',
});
}
const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema);
module.exports = Conversation;
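A sketch of querying through the plugin's meiliSearch static, assuming MEILI_HOST, MEILI_MASTER_KEY, and SEARCH=true are set so the plugin was applied:
const Conversation = require('./Conversation');
async function searchConversations(q) {
  // The third argument populates each Meili hit with its full MongoDB document.
  const results = await Conversation.meiliSearch(q, { limit: 10 }, true);
  return results.hits;
}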

View File

@@ -0,0 +1,16 @@
const mongoose = require('mongoose');
const mongoMeili = require('~/models/plugins/mongoMeili');
const { messageSchema } = require('@librechat/data-schemas');
if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
messageSchema.plugin(mongoMeili, {
host: process.env.MEILI_HOST,
apiKey: process.env.MEILI_MASTER_KEY,
indexName: 'messages',
primaryKey: 'messageId',
});
}
const Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
module.exports = Message;

View File

@@ -0,0 +1,6 @@
const mongoose = require('mongoose');
const { pluginAuthSchema } = require('@librechat/data-schemas');
const PluginAuth = mongoose.models.PluginAuth || mongoose.model('PluginAuth', pluginAuthSchema);
module.exports = PluginAuth;

View File

@@ -0,0 +1,6 @@
const mongoose = require('mongoose');
const { presetSchema } = require('@librechat/data-schemas');
const Preset = mongoose.models.Preset || mongoose.model('Preset', presetSchema);
module.exports = Preset;

View File

@@ -1,5 +1,6 @@
+const { Transaction } = require('./Transaction');
const { logger } = require('~/config');
-const { createTransaction, createStructuredTransaction } = require('./Transaction');
/**
* Creates up to two transactions to record the spending of tokens.
*
@@ -32,7 +33,7 @@ const spendTokens = async (txData, tokenUsage) => {
let prompt, completion;
try {
if (promptTokens !== undefined) {
-prompt = await createTransaction({
+prompt = await Transaction.create({
...txData,
tokenType: 'prompt',
rawAmount: promptTokens === 0 ? 0 : -Math.max(promptTokens, 0),
@@ -40,7 +41,7 @@ const spendTokens = async (txData, tokenUsage) => {
}
if (completionTokens !== undefined) {
-completion = await createTransaction({
+completion = await Transaction.create({
...txData,
tokenType: 'completion',
rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0),
@@ -100,7 +101,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => {
try {
if (promptTokens) {
const { input = 0, write = 0, read = 0 } = promptTokens;
-prompt = await createStructuredTransaction({
+prompt = await Transaction.createStructured({
...txData,
tokenType: 'prompt',
inputTokens: -input,
@@ -110,7 +111,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => {
}
if (completionTokens) {
-completion = await createTransaction({
+completion = await Transaction.create({
...txData,
tokenType: 'completion',
rawAmount: -completionTokens,
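For reference, call sites shaped by this change look roughly like the following (the txData field values are illustrative):
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
async function recordUsage(userId, conversationId) {
  // Plain spend: each defined side becomes one Transaction via Transaction.create.
  await spendTokens(
    { user: userId, conversationId, model: 'gpt-4', context: 'message' },
    { promptTokens: 100, completionTokens: 50 },
  );
  // Structured spend: prompt tokens are split into fresh input vs. cache write/read.
  await spendStructuredTokens(
    { user: userId, conversationId, model: 'claude-3-5-sonnet', context: 'message' },
    { promptTokens: { input: 80, write: 200, read: 500 }, completionTokens: 60 },
  );
}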

View File

@@ -1,9 +1,8 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
+const { Transaction } = require('./Transaction');
+const Balance = require('./Balance');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
-const { createTransaction, createAutoRefillTransaction } = require('./Transaction');
-require('~/db/models');
// Mock the logger to prevent console output during tests
jest.mock('~/config', () => ({
@@ -20,15 +19,11 @@ jest.mock('~/server/services/Config');
describe('spendTokens', () => {
let mongoServer;
let userId;
-let Transaction;
-let Balance;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
-await mongoose.connect(mongoServer.getUri());
-Transaction = mongoose.model('Transaction');
-Balance = mongoose.model('Balance');
+const mongoUri = mongoServer.getUri();
+await mongoose.connect(mongoUri);
});
afterAll(async () => {
@@ -202,7 +197,7 @@ describe('spendTokens', () => {
// Check that the transaction records show the adjusted values
const transactionResults = await Promise.all(
transactions.map((t) =>
-createTransaction({
+Transaction.create({
...txData,
tokenType: t.tokenType,
rawAmount: t.rawAmount,
@@ -285,7 +280,7 @@ describe('spendTokens', () => {
// Check the return values from Transaction.create directly
// This is to verify that the incrementValue is not becoming positive
-const directResult = await createTransaction({
+const directResult = await Transaction.create({
user: userId,
conversationId: 'test-convo-3',
model: 'gpt-4',
@@ -612,7 +607,7 @@ describe('spendTokens', () => {
const promises = [];
for (let i = 0; i < numberOfRefills; i++) {
promises.push(
-createAutoRefillTransaction({
+Transaction.createAutoRefillTransaction({
user: userId,
tokenType: 'credits',
context: 'concurrent-refill-test',

View File

@@ -1,6 +1,49 @@
const { matchModelName } = require('../utils');
const defaultRate = 6;
const customTokenOverrides = {};
const customCacheOverrides = {};
/**
* Allows overriding the default token multipliers.
*
* @param {Object} overrides - An object mapping model keys to their custom token multipliers.
* @param {Object} overrides.<model> - An object containing custom multipliers for the model.
* @param {number} overrides.<model>.prompt - The custom prompt multiplier for the model.
* @param {number} overrides.<model>.completion - The custom completion multiplier for the model.
*
* @example
* // Override the multipliers for "gpt-4o-mini" and "gpt-3.5":
* setCustomTokenOverrides({
* "gpt-4o-mini": { prompt: 0.2, completion: 0.5 },
* "gpt-3.5": { prompt: 1.0, completion: 2.0 }
* });
*/
const setCustomTokenOverrides = (overrides) => {
Object.assign(customTokenOverrides, overrides);
};
/**
* Allows overriding the default cache multipliers.
* The override values should be nested under a key named "cache".
*
* @param {Object} overrides - An object mapping model keys to their custom cache multipliers.
* @param {Object} overrides.<model> - An object that must include a "cache" property.
* @param {Object} overrides.<model>.cache - An object containing custom cache multipliers for the model.
* @param {number} overrides.<model>.cache.write - The custom cache write multiplier for the model.
* @param {number} overrides.<model>.cache.read - The custom cache read multiplier for the model.
*
* @example
* // Override the cache multipliers for "gpt-4o-mini" and "gpt-3.5":
* setCustomCacheOverrides({
* "gpt-4o-mini": { cache: { write: 0.2, read: 0.5 } },
* "gpt-3.5": { cache: { write: 1.0, read: 1.5 } }
* });
*/
const setCustomCacheOverrides = (overrides) => {
Object.assign(customCacheOverrides, overrides);
};
/**
* AWS Bedrock pricing
* source: https://aws.amazon.com/bedrock/pricing/
@@ -78,7 +121,7 @@ const tokenValues = Object.assign(
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
'o4-mini': { prompt: 1.1, completion: 4.4 },
'o3-mini': { prompt: 1.1, completion: 4.4 },
-o3: { prompt: 2, completion: 8 },
+o3: { prompt: 10, completion: 40 },
'o1-mini': { prompt: 1.1, completion: 4.4 },
'o1-preview': { prompt: 15, completion: 60 },
o1: { prompt: 15, completion: 60 },
@@ -100,8 +143,6 @@ const tokenValues = Object.assign(
'claude-3-5-haiku': { prompt: 0.8, completion: 4 },
'claude-3.5-haiku': { prompt: 0.8, completion: 4 },
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
-'claude-sonnet-4': { prompt: 3, completion: 15 },
-'claude-opus-4': { prompt: 15, completion: 75 },
'claude-2.1': { prompt: 8, completion: 24 },
'claude-2': { prompt: 8, completion: 24 },
'claude-instant': { prompt: 0.8, completion: 2.4 },
@@ -135,11 +176,10 @@ const tokenValues = Object.assign(
'grok-2-1212': { prompt: 2.0, completion: 10.0 },
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
'grok-2': { prompt: 2.0, completion: 10.0 },
-'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
+'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
'grok-3-mini': { prompt: 0.3, completion: 0.5 },
'grok-3-fast': { prompt: 5.0, completion: 25.0 },
'grok-3': { prompt: 3.0, completion: 15.0 },
-'grok-4': { prompt: 3.0, completion: 15.0 },
'grok-beta': { prompt: 5.0, completion: 15.0 },
'mistral-large': { prompt: 2.0, completion: 6.0 },
'pixtral-large': { prompt: 2.0, completion: 6.0 },
@@ -165,8 +205,6 @@ const cacheTokenValues = {
'claude-3.5-haiku': { write: 1, read: 0.08 },
'claude-3-5-haiku': { write: 1, read: 0.08 },
'claude-3-haiku': { write: 0.3, read: 0.03 },
-'claude-sonnet-4': { write: 3.75, read: 0.3 },
-'claude-opus-4': { write: 18.75, read: 1.5 },
};
/**
@@ -288,20 +326,23 @@ const getCacheMultiplier = ({ valueKey, cacheType, model, endpoint, endpointToke
return endpointTokenConfig?.[model]?.[cacheType] ?? null;
}
-if (valueKey && cacheType) {
-return cacheTokenValues[valueKey]?.[cacheType] ?? null;
+if (!valueKey && model) {
+valueKey = getValueKey(model, endpoint);
}
-if (!cacheType || !model) {
-return null;
-}
-valueKey = getValueKey(model, endpoint);
if (!valueKey) {
return null;
}
-// If we got this far, and values[cacheType] is undefined somehow, return a rough average of default multipliers
+// Check for custom cache overrides under the "cache" property.
+if (
+customCacheOverrides[valueKey] &&
+customCacheOverrides[valueKey].cache &&
+customCacheOverrides[valueKey].cache[cacheType] != null
+) {
+return customCacheOverrides[valueKey].cache[cacheType];
+}
+// Fallback to the default cacheTokenValues.
return cacheTokenValues[valueKey]?.[cacheType] ?? null;
};
@@ -312,4 +353,6 @@ module.exports = {
getCacheMultiplier,
defaultRate,
cacheTokenValues,
+setCustomTokenOverrides,
+setCustomCacheOverrides,
};
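A quick sketch of how the new cache overrides resolve (the sonnet override values below are made up; the haiku fallback matches the cacheTokenValues entry above):
const { setCustomCacheOverrides, getCacheMultiplier } = require('./tx');
setCustomCacheOverrides({
  'claude-3-5-sonnet': { cache: { write: 2.5, read: 0.2 } }, // illustrative values
});
// Override hit: the model resolves to its valueKey, then the nested `cache` entry wins.
getCacheMultiplier({ model: 'claude-3-5-sonnet', cacheType: 'write' }); // 2.5
// No override registered: falls back to the default cacheTokenValues table.
getCacheMultiplier({ model: 'claude-3-haiku', cacheType: 'read' }); // 0.03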

View File

@@ -636,15 +636,6 @@ describe('Grok Model Tests - Pricing', () => {
);
});
-test('should return correct prompt and completion rates for Grok 4 model', () => {
-expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'prompt' })).toBe(
-tokenValues['grok-4'].prompt,
-);
-expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'completion' })).toBe(
-tokenValues['grok-4'].completion,
-);
-});
test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
tokenValues['grok-3'].prompt,
@@ -671,108 +662,5 @@ describe('Grok Model Tests - Pricing', () => {
tokenValues['grok-3-mini-fast'].completion,
);
});
-test('should return correct prompt and completion rates for Grok 4 model with prefixes', () => {
-expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'prompt' })).toBe(
-tokenValues['grok-4'].prompt,
-);
-expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'completion' })).toBe(
-tokenValues['grok-4'].completion,
-);
-});
-});
-});
-describe('Claude Model Tests', () => {
-it('should return correct prompt and completion rates for Claude 4 models', () => {
-expect(getMultiplier({ model: 'claude-sonnet-4', tokenType: 'prompt' })).toBe(
-tokenValues['claude-sonnet-4'].prompt,
-);
-expect(getMultiplier({ model: 'claude-sonnet-4', tokenType: 'completion' })).toBe(
-tokenValues['claude-sonnet-4'].completion,
-);
-expect(getMultiplier({ model: 'claude-opus-4', tokenType: 'prompt' })).toBe(
-tokenValues['claude-opus-4'].prompt,
-);
-expect(getMultiplier({ model: 'claude-opus-4', tokenType: 'completion' })).toBe(
-tokenValues['claude-opus-4'].completion,
-);
-});
-it('should handle Claude 4 model name variations with different prefixes and suffixes', () => {
-const modelVariations = [
-'claude-sonnet-4',
-'claude-sonnet-4-20240229',
-'claude-sonnet-4-latest',
-'anthropic/claude-sonnet-4',
-'claude-sonnet-4/anthropic',
-'claude-sonnet-4-preview',
-'claude-sonnet-4-20240229-preview',
-'claude-opus-4',
-'claude-opus-4-20240229',
-'claude-opus-4-latest',
-'anthropic/claude-opus-4',
-'claude-opus-4/anthropic',
-'claude-opus-4-preview',
-'claude-opus-4-20240229-preview',
-];
-modelVariations.forEach((model) => {
-const valueKey = getValueKey(model);
-const isSonnet = model.includes('sonnet');
-const expectedKey = isSonnet ? 'claude-sonnet-4' : 'claude-opus-4';
-expect(valueKey).toBe(expectedKey);
-expect(getMultiplier({ model, tokenType: 'prompt' })).toBe(tokenValues[expectedKey].prompt);
-expect(getMultiplier({ model, tokenType: 'completion' })).toBe(
-tokenValues[expectedKey].completion,
-);
-});
-});
-it('should return correct cache rates for Claude 4 models', () => {
-expect(getCacheMultiplier({ model: 'claude-sonnet-4', cacheType: 'write' })).toBe(
-cacheTokenValues['claude-sonnet-4'].write,
-);
-expect(getCacheMultiplier({ model: 'claude-sonnet-4', cacheType: 'read' })).toBe(
-cacheTokenValues['claude-sonnet-4'].read,
-);
-expect(getCacheMultiplier({ model: 'claude-opus-4', cacheType: 'write' })).toBe(
-cacheTokenValues['claude-opus-4'].write,
-);
-expect(getCacheMultiplier({ model: 'claude-opus-4', cacheType: 'read' })).toBe(
-cacheTokenValues['claude-opus-4'].read,
-);
-});
-it('should handle Claude 4 model cache rates with different prefixes and suffixes', () => {
-const modelVariations = [
-'claude-sonnet-4',
-'claude-sonnet-4-20240229',
-'claude-sonnet-4-latest',
-'anthropic/claude-sonnet-4',
-'claude-sonnet-4/anthropic',
-'claude-sonnet-4-preview',
-'claude-sonnet-4-20240229-preview',
-'claude-opus-4',
-'claude-opus-4-20240229',
-'claude-opus-4-latest',
-'anthropic/claude-opus-4',
-'claude-opus-4/anthropic',
-'claude-opus-4-preview',
-'claude-opus-4-20240229-preview',
-];
-modelVariations.forEach((model) => {
-const isSonnet = model.includes('sonnet');
-const expectedKey = isSonnet ? 'claude-sonnet-4' : 'claude-opus-4';
-expect(getCacheMultiplier({ model, cacheType: 'write' })).toBe(
-cacheTokenValues[expectedKey].write,
-);
-expect(getCacheMultiplier({ model, cacheType: 'read' })).toBe(
-cacheTokenValues[expectedKey].read,
-);
-});
-});
});
});

View File

@@ -1,4 +1,159 @@
const bcrypt = require('bcryptjs');
const { getBalanceConfig } = require('~/server/services/Config');
const signPayload = require('~/server/services/signPayload');
const Balance = require('./Balance');
const User = require('./User');
/**
* Retrieve a user by ID and convert the found user document to a plain object.
*
* @param {string} userId - The ID of the user to find and return as a plain object.
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
*/
const getUserById = async function (userId, fieldsToSelect = null) {
const query = User.findById(userId);
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
return await query.lean();
};
/**
* Search for a single user based on partial data and return matching user document as plain object.
* @param {Partial<MongoUser>} searchCriteria - The partial data to use for searching the user.
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
*/
const findUser = async function (searchCriteria, fieldsToSelect = null) {
const query = User.findOne(searchCriteria);
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
return await query.lean();
};
/**
* Update a user with new data without overwriting existing properties.
*
* @param {string} userId - The ID of the user to update.
* @param {Object} updateData - An object containing the properties to update.
* @returns {Promise<MongoUser>} The updated user document as a plain object, or `null` if no user is found.
*/
const updateUser = async function (userId, updateData) {
const updateOperation = {
$set: updateData,
$unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
};
return await User.findByIdAndUpdate(userId, updateOperation, {
new: true,
runValidators: true,
}).lean();
};
/**
* Creates a new user, optionally with a TTL of 1 week.
* @param {MongoUser} data - The user data to be created, must contain user_id.
* @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`.
* @param {boolean} [returnUser=false] - Whether to return the created user object.
* @returns {Promise<ObjectId|MongoUser>} A promise that resolves to the created user document ID or user object.
* @throws {Error} If a user with the same user_id already exists.
*/
const createUser = async (data, disableTTL = true, returnUser = false) => {
const balance = await getBalanceConfig();
const userData = {
...data,
expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds
};
if (disableTTL) {
delete userData.expiresAt;
}
const user = await User.create(userData);
// If balance is enabled, create or update a balance record for the user based on the balance config
if (balance?.enabled && balance?.startBalance) {
const update = {
$inc: { tokenCredits: balance.startBalance },
};
if (
balance.autoRefillEnabled &&
balance.refillIntervalValue != null &&
balance.refillIntervalUnit != null &&
balance.refillAmount != null
) {
update.$set = {
autoRefillEnabled: true,
refillIntervalValue: balance.refillIntervalValue,
refillIntervalUnit: balance.refillIntervalUnit,
refillAmount: balance.refillAmount,
};
}
await Balance.findOneAndUpdate({ user: user._id }, update, { upsert: true, new: true }).lean();
}
if (returnUser) {
return user.toObject();
}
return user._id;
};
/**
* Count the number of user documents in the collection based on the provided filter.
*
* @param {Object} [filter={}] - The filter to apply when counting the documents.
* @returns {Promise<number>} The count of documents that match the filter.
*/
const countUsers = async function (filter = {}) {
return await User.countDocuments(filter);
};
/**
* Delete a user by their unique ID.
*
* @param {string} userId - The ID of the user to delete.
* @returns {Promise<{ deletedCount: number }>} An object indicating the number of deleted documents.
*/
const deleteUserById = async function (userId) {
try {
const result = await User.deleteOne({ _id: userId });
if (result.deletedCount === 0) {
return { deletedCount: 0, message: 'No user found with that ID.' };
}
return { deletedCount: result.deletedCount, message: 'User was deleted successfully.' };
} catch (error) {
throw new Error('Error deleting user: ' + error.message);
}
};
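Taken together, countUsers and deleteUserById cover simple admin flows; a minimal sketch with an illustrative filter and id (the 'ADMIN' role value is an assumption):

(async () => {
  const admins = await countUsers({ role: 'ADMIN' }); // number of matching documents
  const result = await deleteUserById('64b0c0ffee0ddba11ad0b0de');
  console.log(admins, result); // e.g. 2, { deletedCount: 1, message: 'User was deleted successfully.' }
})();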
const { SESSION_EXPIRY } = process.env ?? {};
// SESSION_EXPIRY may hold an arithmetic expression from .env (e.g. '1000 * 60 * 15');
// it is evaluated here, falling back to 15 minutes when unset.
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
/**
* Generates a JWT token for a given user.
*
* @param {MongoUser} user - The user for whom the token is being generated.
* @returns {Promise<string>} A promise that resolves to a JWT token.
*/
const generateToken = async (user) => {
if (!user) {
throw new Error('No user provided');
}
return await signPayload({
payload: {
id: user._id,
username: user.username,
provider: user.provider,
email: user.email,
},
secret: process.env.JWT_SECRET,
expirationTime: expires / 1000,
});
};
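Because the payload is signed with JWT_SECRET, a consumer can verify it with the jsonwebtoken package already in the dependency list below; a minimal sketch, assuming generateToken and a loaded user document are in scope:

const jwt = require('jsonwebtoken');

(async () => {
  const token = await generateToken(user);
  // verify() throws on a bad signature or expired token; otherwise it returns the claims.
  const claims = jwt.verify(token, process.env.JWT_SECRET);
  console.log(claims.id, claims.email); // plus the standard iat/exp fields
})();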
/**
 * Compares the provided password with the user's password.
@@ -12,10 +167,6 @@ const comparePassword = async (user, candidatePassword) => {
     throw new Error('No user provided');
   }
-  if (!user.password) {
-    throw new Error('No password, likely an email first registered via Social/OIDC login');
-  }
   return new Promise((resolve, reject) => {
     bcrypt.compare(candidatePassword, user.password, (err, isMatch) => {
       if (err) {
@@ -28,4 +179,11 @@ const comparePassword = async (user, candidatePassword) => {
 module.exports = {
   comparePassword,
+  deleteUserById,
+  generateToken,
+  getUserById,
+  countUsers,
+  createUser,
+  updateUser,
+  findUser,
 };
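The guard removed in the first hunk matters for accounts created via social/OIDC providers, which have no local password hash: without it, bcrypt.compare is called with undefined and fails less descriptively. A caller can restore the check; a hypothetical sketch, assuming comparePassword is imported from this module:

// Hypothetical wrapper; not part of the diff above.
async function verifyLocalLogin(user, candidatePassword) {
  if (!user?.password) {
    throw new Error('Account has no local password; sign in with the original provider');
  }
  return comparePassword(user, candidatePassword); // resolves to a boolean
}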

View File

@@ -1,6 +1,6 @@
 {
   "name": "@librechat/backend",
-  "version": "v0.7.9",
+  "version": "v0.7.8",
   "description": "",
   "scripts": {
     "start": "echo 'please run this from the root directory'",
@@ -34,30 +34,28 @@
   },
   "homepage": "https://librechat.ai",
   "dependencies": {
-    "@anthropic-ai/sdk": "^0.52.0",
+    "@anthropic-ai/sdk": "^0.37.0",
     "@aws-sdk/client-s3": "^3.758.0",
     "@aws-sdk/s3-request-presigner": "^3.758.0",
     "@azure/identity": "^4.7.0",
     "@azure/search-documents": "^12.0.0",
     "@azure/storage-blob": "^12.27.0",
-    "@google/generative-ai": "^0.24.0",
+    "@google/generative-ai": "^0.23.0",
     "@googleapis/youtube": "^20.0.0",
     "@keyv/redis": "^4.3.3",
-    "@langchain/community": "^0.3.47",
-    "@langchain/core": "^0.3.62",
-    "@langchain/google-genai": "^0.2.13",
-    "@langchain/google-vertexai": "^0.2.13",
-    "@langchain/openai": "^0.5.18",
+    "@langchain/community": "^0.3.42",
+    "@langchain/core": "^0.3.55",
+    "@langchain/google-genai": "^0.2.8",
+    "@langchain/google-vertexai": "^0.2.8",
     "@langchain/textsplitters": "^0.1.0",
-    "@librechat/agents": "^2.4.67",
-    "@librechat/api": "*",
+    "@librechat/agents": "^2.4.317",
     "@librechat/data-schemas": "*",
-    "@node-saml/passport-saml": "^5.0.0",
     "@waylaidwanderer/fetch-event-source": "^3.0.1",
     "axios": "^1.8.2",
     "bcryptjs": "^2.4.3",
-    "compression": "^1.8.1",
-    "connect-redis": "^8.1.0",
+    "cohere-ai": "^7.9.1",
+    "compression": "^1.7.4",
+    "connect-redis": "^7.1.0",
     "cookie": "^0.7.2",
     "cookie-parser": "^1.4.7",
     "cors": "^2.8.5",
@@ -67,36 +65,34 @@
     "express": "^4.21.2",
     "express-mongo-sanitize": "^2.2.0",
     "express-rate-limit": "^7.4.1",
-    "express-session": "^1.18.2",
+    "express-session": "^1.18.1",
     "express-static-gzip": "^2.2.0",
     "file-type": "^18.7.0",
     "firebase": "^11.0.2",
-    "form-data": "^4.0.4",
     "googleapis": "^126.0.1",
     "handlebars": "^4.7.7",
     "https-proxy-agent": "^7.0.6",
     "ioredis": "^5.3.2",
     "js-yaml": "^4.1.0",
     "jsonwebtoken": "^9.0.0",
-    "jwks-rsa": "^3.2.0",
     "keyv": "^5.3.2",
     "keyv-file": "^5.1.2",
     "klona": "^2.0.6",
     "librechat-data-provider": "*",
+    "librechat-mcp": "*",
     "lodash": "^4.17.21",
     "meilisearch": "^0.38.0",
     "memorystore": "^1.6.7",
     "mime": "^3.0.0",
     "module-alias": "^2.2.3",
     "mongoose": "^8.12.1",
-    "multer": "^2.0.2",
+    "multer": "^1.4.5-lts.1",
     "nanoid": "^3.3.7",
+    "node-fetch": "^2.7.0",
     "nodemailer": "^6.9.15",
     "ollama": "^0.5.0",
-    "openai": "^5.10.1",
+    "openai": "^4.96.2",
     "openai-chat-tokens": "^0.2.8",
-    "openid-client": "^6.5.0",
+    "openid-client": "^5.4.2",
     "passport": "^0.6.0",
     "passport-apple": "^2.0.2",
     "passport-discord": "^0.1.4",
@@ -111,9 +107,8 @@
     "tiktoken": "^1.0.15",
     "traverse": "^0.6.7",
     "ua-parser-js": "^1.0.36",
-    "undici": "^7.10.0",
     "winston": "^3.11.0",
-    "winston-daily-rotate-file": "^5.0.0",
+    "winston-daily-rotate-file": "^4.7.1",
     "youtube-transcript": "^1.2.1",
     "zod": "^3.22.4"
   },

Some files were not shown because too many files have changed in this diff.