Compare commits

..

41 Commits

Author SHA1 Message Date
Danny Avila
148052c473 chore: remove test code 2025-05-30 15:05:18 -04:00
Danny Avila
331014cc98 refactor(config): update mongoose imports to resolve path dynamically 2025-05-30 15:01:27 -04:00
Danny Avila
9c0deed34a refactor(config): update user-related imports to utilize mongoose models 2025-05-30 14:54:32 -04:00
Danny Avila
226bd90ede refactor(openidStrategy): remove unused crypto imports to clean up code 2025-05-30 14:47:32 -04:00
Danny Avila
4fea3d4274 fix(crypto): update key and IV to use environment variables for enhanced security 2025-05-30 14:45:09 -04:00
Danny Avila
ad6716f6ef refactor(PluginService): update crypto imports for better organization 2025-05-30 14:43:38 -04:00
Danny Avila
b2f7f5c904 fix(samlStrategy): update user creation to include balance configuration
- Modified the user creation process to incorporate balance configuration retrieved from the new getBalanceConfig function.
- Adjusted imports for user methods to streamline the code structure.
2025-05-30 14:42:30 -04:00
Danny Avila
4ae1d82a75 chore: remove unused mongoose imports from Message model and message routes 2025-05-30 14:39:49 -04:00
Danny Avila
494c6d2596 refactor(crypto): reorganize token hashing and signing functionality 2025-05-30 14:38:01 -04:00
Danny Avila
6f4c8ef114 refactor(token): simplify token deletion and retrieval logic
- Consolidated query conditions for token deletion and retrieval into a single array for improved readability.
- Removed redundant error handling for empty query conditions, as the logic now directly checks for provided parameters.
- Enhanced the return statement for the findToken method to streamline the code structure.
2025-05-30 14:29:21 -04:00
Danny Avila
a4c6553695 chore(session): remove commented-out code for clarity 2025-05-30 14:18:59 -04:00
Danny Avila
edb977c1bc feat(session): enhance session management with new methods and error handling
- Introduced a custom SessionError class for better error management.
- Updated session creation and querying methods to use type imports for improved type safety.
- Added updateExpiration and countActiveSessions methods to manage session lifecycle.
- Refactored deleteAllUserSessions to include logging and error handling.
- Streamlined session document creation to align with Mongoose practices.
2025-05-30 14:17:56 -04:00
Danny Avila
8ec7781672 chore: remove unused mongoose import from Role model 2025-05-30 14:06:40 -04:00
Danny Avila
99731e98dd chore: revert connectDb function to original pattern 2025-05-30 14:05:51 -04:00
Danny Avila
f57d920bd5 chore: remove unused imports 2025-05-30 14:04:31 -04:00
Danny Avila
3831ad8202 fix(models): update user and token operations to use centralized functions 2025-05-30 13:59:30 -04:00
Danny Avila
6e278f6932 fix(auth): replace mongoose model references with new function imports
- Updated AuthController, checkBan middleware, localStrategy, and openidStrategy to use new function imports for user operations.
- Removed unused mongoose imports to streamline the codebase.
- Enhanced consistency across user-related operations by utilizing the centralized methods for user management.
2025-05-30 13:46:31 -04:00
Danny Avila
90ac2b51cd feat(data-schemas): add new Mongoose models for conversationTag, key, pluginAuth, preset, project, prompt, promptGroup, sharedLink, toolCall, and transaction
- Introduced new model files for conversationTag, key, pluginAuth, preset, project, prompt, promptGroup, sharedLink, toolCall, and transaction.
- Each model includes a function to create or return the respective Mongoose model using the provided instance and schema.
- Updated the centralized models index to include these new models for better organization and accessibility.
2025-05-30 13:42:49 -04:00
Danny Avila
20ad7d52f3 refactor(db): streamline model imports and remove unused model exports
- Removed the export of models from the database connection module to simplify the structure.
- Updated various files to import models directly from the new centralized models module.
- Ensured consistency across the codebase by replacing mongoose model references with the new import paths.
2025-05-30 13:13:10 -04:00
Danny Avila
eb368fcb70 refactor(db): replace connectDb import paths and introduce new connect module
- Updated import paths for connectDb across various files to use the new centralized connect module.
- Removed the old connectDb file to streamline the database connection logic.
- Ensured all tests and models reference the new connection method for consistency.
2025-05-30 13:04:09 -04:00
Danny Avila
7cf3f98475 chore: remove Config model file to streamline codebase 2025-05-30 12:55:06 -04:00
Danny Avila
ab5450be8b WIP: first pass, massive refactor of model imports 2025-05-30 12:54:24 -04:00
Danny Avila
c682d45fb2 chore(data-schemas): update package dependencies and restructure peerDependencies
- Moved dependencies to peerDependencies in package.json for better compatibility.
- Added "peer": true to several entries in package-lock.json to indicate peer dependencies.
2025-05-30 12:23:29 -04:00
Danny Avila
5fb6b91e71 chore: remove unused file 2025-05-30 12:20:32 -04:00
Danny Avila
76e070048c refactor(data-schemas): update model and method creation for improved modularity
- Refactored model creation functions to enhance clarity and consistency across the data-schemas.
- Introduced createModels and createMethods functions to streamline the instantiation of Mongoose models and methods.
- Updated test-role.js to utilize the new createModels and createMethods for better organization.
2025-05-30 12:20:01 -04:00
Danny Avila
728d19e361 refactor(data-schemas): reintroduce mongoMeili plugin for conversation and message schemas
- Added mongoMeili plugin back to convoSchema and messageSchema for enhanced search capabilities.
- Updated import statements to use Schema directly from mongoose for consistency.
- Removed conditional checks for the plugin from model files, centralizing the logic in the schema definitions.
2025-05-30 12:13:54 -04:00
Danny Avila
2d492b932f refactor(data-schemas): enhance method organization and add librechat-data-provider dependency 2025-05-30 12:13:42 -04:00
Danny Avila
c201d54cac WIP: first pass, factory models and methods 2025-05-30 12:02:22 -04:00
Danny Avila
a2a3f5c044 experimental: npm link test 2025-05-30 11:13:34 -04:00
Danny Avila
f9c0e9853f refactor: original changes 2025-05-30 04:28:22 -04:00
Danny Avila
fa9177180f refactor(data-schemas): introduce new models and types for balance, conversation, message, and session
- Added new model files for Balance, Conversation, Message, and Session, enhancing modularity.
- Created corresponding type definitions for IBalance, IConversation, IMessage, and updated existing types.
- Refactored index files to export models from their individual files for better organization.
2025-05-30 02:13:35 -04:00
Danny Avila
f6ca8caf7e refactor(data-schemas): restructure schemas, models, and methods for improved modularity 2025-05-30 01:42:06 -04:00
Danny Avila
30b8a1c6c4 refactor(data-schemas): update tsconfig and import paths for improved module resolution
- Added baseUrl and paths configuration to tsconfig.json for better module resolution.
- Updated import statement in mongoMeili.ts to use the new path alias for the meiliLogger configuration.
2025-05-30 00:54:50 -04:00
Danny Avila
848cb6f871 refactor(data-schemas): remove legacy mongoMeili plugin and related schemas
- Deleted the mongoMeili plugin and its associated schemas (messageSchema, pluginAuthSchema) to streamline the codebase.
- Updated PluginService to import PluginAuth directly from data-schemas.
- Introduced a new meiliLogger configuration file for improved logging functionality.
2025-05-30 00:34:28 -04:00
Danny Avila
ea459749f9 refactor(data-schemas): enhance type safety in log formatting functions
- Introduced type guards to ensure message and symbol values are strings in redactFormat.
- Updated parameter types in truncateLongStrings and condenseArray for better type safety.
- Improved type handling in debugTraverse and jsonTruncateFormat to prevent runtime errors.
- Ensured proper handling of circular references and object types in logging functions.
2025-05-29 16:18:30 -04:00
Danny Avila
63c56c8dd9 refactor(data-schemas): simplify environment variable checks in winston configuration 2025-05-29 15:18:52 -04:00
Danny Avila
7caffda81a fix(data-schemas): resolve circular dependencies and add missing model registrations
- Break circular dependency by importing schemas directly from individual files
- Add missing actionSchema and pluginAuthSchema imports
- Add registerActionModel and registerPluginAuthModel functions
- Fix typo in Transaction model registration (Trasaction → Transaction)
- Include Action and PluginAuth models in registerModels return object
2025-05-29 15:16:13 -04:00
Danny Avila
0cb5ed4063 fix: change generateToken method to a static method on userSchema 2025-05-29 14:45:40 -04:00
Danny Avila
85d0688f38 chore: remove legacy TTL index cleanup from Token model 2025-05-29 14:39:30 -04:00
Danny Avila
2c14fe1e9a fix: align known working version of meilisearch @ v0.38.0 2025-05-29 14:39:30 -04:00
Cha
4049b5572c Move usermethods and models to data-schema 2025-05-29 14:39:27 -04:00
692 changed files with 17179 additions and 45687 deletions

View File

@@ -58,7 +58,7 @@ DEBUG_CONSOLE=false
# Endpoints #
#===================================================#
# ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic
# ENDPOINTS=openAI,assistants,azureOpenAI,google,gptPlugins,anthropic
PROXY=
@@ -142,10 +142,10 @@ GOOGLE_KEY=user_provided
# GOOGLE_AUTH_HEADER=true
# Gemini API (AI Studio)
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# Vertex AI
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
@@ -349,11 +349,6 @@ REGISTRATION_VIOLATION_SCORE=1
CONCURRENT_VIOLATION_SCORE=1
MESSAGE_VIOLATION_SCORE=1
NON_BROWSER_VIOLATION_SCORE=20
TTS_VIOLATION_SCORE=0
STT_VIOLATION_SCORE=0
FORK_VIOLATION_SCORE=0
IMPORT_VIOLATION_SCORE=0
FILE_UPLOAD_VIOLATION_SCORE=0
LOGIN_MAX=7
LOGIN_WINDOW=5
@@ -458,8 +453,8 @@ OPENID_REUSE_TOKENS=
OPENID_JWKS_URL_CACHE_ENABLED=
OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching
#Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint.
OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED=
OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API
OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED=
OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API
# Set to true to use the OpenID Connect end session endpoint for logout
OPENID_USE_END_SESSION_ENDPOINT=
@@ -520,18 +515,6 @@ EMAIL_PASSWORD=
EMAIL_FROM_NAME=
EMAIL_FROM=noreply@librechat.ai
#========================#
# Mailgun API #
#========================#
# MAILGUN_API_KEY=your-mailgun-api-key
# MAILGUN_DOMAIN=mg.yourdomain.com
# EMAIL_FROM=noreply@yourdomain.com
# EMAIL_FROM_NAME="LibreChat"
# # Optional: For EU region
# MAILGUN_HOST=https://api.eu.mailgun.net
#========================#
# Firebase CDN #
#========================#
@@ -580,10 +563,6 @@ ALLOW_SHARED_LINKS_PUBLIC=true
# If you have another service in front of your LibreChat doing compression, disable express based compression here
# DISABLE_COMPRESSION=true
# If you have gzipped version of uploaded image images in the same folder, this will enable gzip scan and serving of these images
# Note: The images folder will be scanned on startup and a ma kept in memory. Be careful for large number of images.
# ENABLE_IMAGE_OUTPUT_GZIP_SCAN=true
#===================================================#
# UI #
#===================================================#
@@ -601,31 +580,11 @@ HELP_AND_FAQ_URL=https://librechat.ai
# REDIS Options #
#===============#
# Enable Redis for caching and session storage
# REDIS_URI=10.10.10.10:6379
# USE_REDIS=true
# Single Redis instance
# REDIS_URI=redis://127.0.0.1:6379
# Redis cluster (multiple nodes)
# REDIS_URI=redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003
# Redis with TLS/SSL encryption and CA certificate
# REDIS_URI=rediss://127.0.0.1:6380
# REDIS_CA=/path/to/ca-cert.pem
# Redis authentication (if required)
# REDIS_USERNAME=your_redis_username
# REDIS_PASSWORD=your_redis_password
# Redis key prefix configuration
# Use environment variable name for dynamic prefix (recommended for cloud deployments)
# REDIS_KEY_PREFIX_VAR=K_REVISION
# Or use static prefix directly
# REDIS_KEY_PREFIX=librechat
# Redis connection limits
# REDIS_MAX_LISTENERS=40
# USE_REDIS_CLUSTER=true
# REDIS_CA=/path/to/ca.crt
#==================================================#
# Others #
@@ -686,4 +645,4 @@ OPENWEATHER_API_KEY=
# Reranker (Required)
# JINA_API_KEY=your_jina_api_key
# or
# COHERE_API_KEY=your_cohere_api_key
# COHERE_API_KEY=your_cohere_api_key

View File

@@ -30,8 +30,8 @@ Project maintainers have the right and responsibility to remove, edit, or reject
2. Install typescript globally: `npm i -g typescript`.
3. Run `npm ci` to install dependencies.
4. Build the data provider: `npm run build:data-provider`.
5. Build data schemas: `npm run build:data-schemas`.
6. Build API methods: `npm run build:api`.
5. Build MCP: `npm run build:mcp`.
6. Build data schemas: `npm run build:data-schemas`.
7. Setup and run unit tests:
- Copy `.env.test`: `cp api/test/.env.test.example api/test/.env.test`.
- Run backend unit tests: `npm run test:api`.

View File

@@ -7,7 +7,6 @@ on:
- release/*
paths:
- 'api/**'
- 'packages/api/**'
jobs:
tests_Backend:
name: Run Backend unit tests
@@ -37,12 +36,12 @@ jobs:
- name: Install Data Provider Package
run: npm run build:data-provider
- name: Install MCP Package
run: npm run build:mcp
- name: Install Data Schemas Package
run: npm run build:data-schemas
- name: Install API Package
run: npm run build:api
- name: Create empty auth.json file
run: |
mkdir -p api/data
@@ -67,8 +66,5 @@ jobs:
- name: Run librechat-data-provider unit tests
run: cd packages/data-provider && npm run test:ci
- name: Run @librechat/data-schemas unit tests
run: cd packages/data-schemas && npm run test:ci
- name: Run @librechat/api unit tests
run: cd packages/api && npm run test:ci
- name: Run librechat-mcp unit tests
run: cd packages/mcp && npm run test:ci

View File

@@ -2,7 +2,7 @@ name: Update Test Server
on:
workflow_run:
workflows: ["Docker Dev Branch Images Build"]
workflows: ["Docker Dev Images Build"]
types:
- completed
workflow_dispatch:
@@ -12,8 +12,7 @@ jobs:
runs-on: ubuntu-latest
if: |
github.repository == 'danny-avila/LibreChat' &&
(github.event_name == 'workflow_dispatch' ||
(github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'dev'))
(github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success')
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -30,17 +29,13 @@ jobs:
DO_USER: ${{ secrets.DO_USER }}
run: |
ssh -o StrictHostKeyChecking=no ${DO_USER}@${DO_HOST} << EOF
sudo -i -u danny bash << 'EEOF'
sudo -i -u danny bash << EEOF
cd ~/LibreChat && \
git fetch origin main && \
sudo npm run stop:deployed && \
sudo docker images --format "{{.Repository}}:{{.ID}}" | grep -E "lc-dev|librechat" | cut -d: -f2 | xargs -r sudo docker rmi -f || true && \
sudo npm run update:deployed && \
git checkout dev && \
git pull origin dev && \
npm run update:deployed && \
git checkout do-deploy && \
git rebase dev && \
sudo npm run start:deployed && \
git rebase main && \
npm run start:deployed && \
echo "Update completed. Application should be running now."
EEOF
EOF

View File

@@ -1,72 +0,0 @@
name: Docker Dev Branch Images Build
on:
workflow_dispatch:
push:
branches:
- dev
paths:
- 'api/**'
- 'client/**'
- 'packages/**'
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
include:
- target: api-build
file: Dockerfile.multi
image_name: lc-dev-api
- target: node
file: Dockerfile
image_name: lc-dev
steps:
# Check out the repository
- name: Checkout
uses: actions/checkout@v4
# Set up QEMU
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
# Set up Docker Buildx
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Log in to GitHub Container Registry
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# Login to Docker Hub
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Prepare the environment
- name: Prepare environment
run: |
cp .env.example .env
# Build and push Docker images for each target
- name: Build and push Docker images
uses: docker/build-push-action@v5
with:
context: .
file: ${{ matrix.file }}
push: true
tags: |
ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:${{ github.sha }}
ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:latest
${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:${{ github.sha }}
${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:latest
platforms: linux/amd64,linux/arm64
target: ${{ matrix.target }}

View File

@@ -5,13 +5,12 @@ on:
paths:
- "client/src/**"
- "api/**"
- "packages/data-provider/src/**"
jobs:
detect-unused-i18n-keys:
runs-on: ubuntu-latest
permissions:
pull-requests: write
pull-requests: write # Required for posting PR comments
steps:
- name: Checkout repository
uses: actions/checkout@v3

View File

@@ -98,8 +98,6 @@ jobs:
cd client
UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "")
UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../client_used_deps.txt ../client_used_code.txt | sort) || echo "")
# Filter out false positives
UNUSED=$(echo "$UNUSED" | grep -v "^micromark-extension-llm-math$" || echo "")
echo "CLIENT_UNUSED<<EOF" >> $GITHUB_ENV
echo "$UNUSED" >> $GITHUB_ENV
echo "EOF" >> $GITHUB_ENV

10
.gitignore vendored
View File

@@ -55,7 +55,6 @@ bower_components/
# AI
.clineignore
.cursor
.aider*
# Floobits
.floo
@@ -125,12 +124,3 @@ helm/**/.values.yaml
# SAML Idp cert
*.cert
# AI Assistants
/.claude/
/.cursor/
/.copilot/
/.aider/
/.openai/
/.tabnine/
/.codeium

View File

@@ -1,4 +1,4 @@
# v0.7.9-rc1
# v0.7.8
# Base node image
FROM node:20-alpine AS node

View File

@@ -1,5 +1,5 @@
# Dockerfile.multi
# v0.7.9-rc1
# v0.7.8
# Base for all builds
FROM node:20-alpine AS base-min
@@ -14,7 +14,7 @@ RUN npm config set fetch-retry-maxtimeout 600000 && \
npm config set fetch-retry-mintimeout 15000
COPY package*.json ./
COPY packages/data-provider/package*.json ./packages/data-provider/
COPY packages/api/package*.json ./packages/api/
COPY packages/mcp/package*.json ./packages/mcp/
COPY packages/data-schemas/package*.json ./packages/data-schemas/
COPY client/package*.json ./client/
COPY api/package*.json ./api/
@@ -24,27 +24,26 @@ FROM base-min AS base
WORKDIR /app
RUN npm ci
# Build `data-provider` package
# Build data-provider
FROM base AS data-provider-build
WORKDIR /app/packages/data-provider
COPY packages/data-provider ./
RUN npm run build
# Build `data-schemas` package
# Build mcp package
FROM base AS mcp-build
WORKDIR /app/packages/mcp
COPY packages/mcp ./
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
RUN npm run build
# Build data-schemas
FROM base AS data-schemas-build
WORKDIR /app/packages/data-schemas
COPY packages/data-schemas ./
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
RUN npm run build
# Build `api` package
FROM base AS api-package-build
WORKDIR /app/packages/api
COPY packages/api ./
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
COPY --from=data-schemas-build /app/packages/data-schemas/dist /app/packages/data-schemas/dist
RUN npm run build
# Client build
FROM base AS client-build
WORKDIR /app/client
@@ -64,8 +63,8 @@ RUN npm ci --omit=dev
COPY api ./api
COPY config ./config
COPY --from=data-provider-build /app/packages/data-provider/dist ./packages/data-provider/dist
COPY --from=mcp-build /app/packages/mcp/dist ./packages/mcp/dist
COPY --from=data-schemas-build /app/packages/data-schemas/dist ./packages/data-schemas/dist
COPY --from=api-package-build /app/packages/api/dist ./packages/api/dist
COPY --from=client-build /app/client/dist ./client/dist
WORKDIR /app/api
EXPOSE 3080

View File

@@ -52,7 +52,7 @@
- 🖥️ **UI & Experience** inspired by ChatGPT with enhanced design and features
- 🤖 **AI Model Selection**:
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Responses API (incl. Azure)
- Anthropic (Claude), AWS Bedrock, OpenAI, Azure OpenAI, Google, Vertex AI, OpenAI Assistants API (incl. Azure)
- [Custom Endpoints](https://www.librechat.ai/docs/quick_start/custom_endpoints): Use any OpenAI-compatible API with LibreChat, no proxy required
- Compatible with [Local & Remote AI Providers](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):
- Ollama, groq, Cohere, Mistral AI, Apple MLX, koboldcpp, together.ai,
@@ -66,9 +66,10 @@
- 🔦 **Agents & Tools Integration**:
- **[LibreChat Agents](https://www.librechat.ai/docs/features/agents)**:
- No-Code Custom Assistants: Build specialized, AI-driven helpers without coding
- Flexible & Extensible: Use MCP Servers, tools, file search, code execution, and more
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, Google, Vertex AI, Responses API, and more
- Flexible & Extensible: Attach tools like DALL-E-3, file search, code execution, and more
- Compatible with Custom Endpoints, OpenAI, Azure, Anthropic, AWS Bedrock, and more
- [Model Context Protocol (MCP) Support](https://modelcontextprotocol.io/clients#librechat) for Tools
- Use LibreChat Agents and OpenAI Assistants with Files, Code Interpreter, Tools, and API Actions
- 🔍 **Web Search**:
- Search the internet and retrieve relevant information to enhance your AI context
@@ -149,8 +150,8 @@ Click on the thumbnail to open the video☝
**Other:**
- **Website:** [librechat.ai](https://librechat.ai)
- **Documentation:** [librechat.ai/docs](https://librechat.ai/docs)
- **Blog:** [librechat.ai/blog](https://librechat.ai/blog)
- **Documentation:** [docs.librechat.ai](https://docs.librechat.ai)
- **Blog:** [blog.librechat.ai](https://blog.librechat.ai)
---

View File

@@ -10,7 +10,6 @@ const {
validateVisionModel,
} = require('librechat-data-provider');
const { SplitStreamHandler: _Handler } = require('@librechat/agents');
const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api');
const {
truncateText,
formatMessage,
@@ -27,6 +26,8 @@ const {
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { createFetch, createStreamEventHandlers } = require('./generators');
const Tokenizer = require('~/server/services/Tokenizer');
const { sleep } = require('~/server/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
@@ -190,11 +191,10 @@ class AnthropicClient extends BaseClient {
reverseProxyUrl: this.options.reverseProxyUrl,
}),
apiKey: this.apiKey,
fetchOptions: {},
};
if (this.options.proxy) {
options.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy);
options.httpAgent = new HttpsProxyAgent(this.options.proxy);
}
if (this.options.reverseProxyUrl) {

View File

@@ -13,6 +13,7 @@ const {
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
const { checkBalance } = require('~/models/balanceMethods');
const { truncateToolCallOutputs } = require('./prompts');
const { addSpaceIfNeeded } = require('~/server/utils');
const { getFiles } = require('~/models/File');
const TextStream = require('./TextStream');
const { logger } = require('~/config');
@@ -197,10 +198,6 @@ class BaseClient {
this.currentMessages[this.currentMessages.length - 1].messageId = head;
}
if (opts.isRegenerate && responseMessageId.endsWith('_')) {
responseMessageId = crypto.randomUUID();
}
this.responseMessageId = responseMessageId;
return {
@@ -575,7 +572,7 @@ class BaseClient {
});
}
const { editedContent } = opts;
const { generation = '' } = opts;
// It's not necessary to push to currentMessages
// depending on subclass implementation of handling messages
@@ -590,21 +587,11 @@ class BaseClient {
isCreatedByUser: false,
model: this.modelOptions?.model ?? this.model,
sender: this.sender,
text: generation,
};
this.currentMessages.push(userMessage, latestMessage);
} else if (editedContent != null) {
// Handle editedContent for content parts
if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) {
const { index, text, type } = editedContent;
if (index >= 0 && index < latestMessage.content.length) {
const contentPart = latestMessage.content[index];
if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) {
contentPart[ContentTypes.THINK] = text;
} else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) {
contentPart[ContentTypes.TEXT] = text;
}
}
}
} else {
latestMessage.text = generation;
}
this.continued = true;
} else {
@@ -685,32 +672,16 @@ class BaseClient {
};
if (typeof completion === 'string') {
responseMessage.text = completion;
responseMessage.text = addSpaceIfNeeded(generation) + completion;
} else if (
Array.isArray(completion) &&
(this.clientName === EModelEndpoint.agents ||
isParamEndpoint(this.options.endpoint, this.options.endpointType))
) {
responseMessage.text = '';
if (!opts.editedContent || this.currentMessages.length === 0) {
responseMessage.content = completion;
} else {
const latestMessage = this.currentMessages[this.currentMessages.length - 1];
if (!latestMessage?.content) {
responseMessage.content = completion;
} else {
const existingContent = [...latestMessage.content];
const { type: editedType } = opts.editedContent;
responseMessage.content = this.mergeEditedContent(
existingContent,
completion,
editedType,
);
}
}
responseMessage.content = completion;
} else if (Array.isArray(completion)) {
responseMessage.text = completion.join('');
responseMessage.text = addSpaceIfNeeded(generation) + completion.join('');
}
if (
@@ -821,8 +792,7 @@ class BaseClient {
userMessage.tokenCount = userMessageTokenCount;
/*
Note: `AgentController` saves the user message if not saved here
(noted by `savedMessageIds`), so we update the count of its `userMessage` reference
Note: `AskController` saves the user message, so we update the count of its `userMessage` reference
*/
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
@@ -831,8 +801,7 @@ class BaseClient {
}
/*
Note: we update the user message to be sure it gets the calculated token count;
though `AgentController` saves the user message if not saved here
(noted by `savedMessageIds`), EditController does not
though `AskController` saves the user message, EditController does not
*/
await userMessagePromise;
await this.updateMessageInDatabase({
@@ -1124,50 +1093,6 @@ class BaseClient {
return numTokens;
}
/**
* Merges completion content with existing content when editing TEXT or THINK types
* @param {Array} existingContent - The existing content array
* @param {Array} newCompletion - The new completion content
* @param {string} editedType - The type of content being edited
* @returns {Array} The merged content array
*/
mergeEditedContent(existingContent, newCompletion, editedType) {
if (!newCompletion.length) {
return existingContent.concat(newCompletion);
}
if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) {
return existingContent.concat(newCompletion);
}
const lastIndex = existingContent.length - 1;
const lastExisting = existingContent[lastIndex];
const firstNew = newCompletion[0];
if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) {
return existingContent.concat(newCompletion);
}
const mergedContent = [...existingContent];
if (editedType === ContentTypes.TEXT) {
mergedContent[lastIndex] = {
...mergedContent[lastIndex],
[ContentTypes.TEXT]:
(mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''),
};
} else {
mergedContent[lastIndex] = {
...mergedContent[lastIndex],
[ContentTypes.THINK]:
(mergedContent[lastIndex][ContentTypes.THINK] || '') +
(firstNew[ContentTypes.THINK] || ''),
};
}
// Add remaining completion items
return mergedContent.concat(newCompletion.slice(1));
}
async sendPayload(payload, opts = {}) {
if (opts && typeof opts === 'object') {
this.setOptions(opts);

View File

@@ -0,0 +1,804 @@
const { Keyv } = require('keyv');
const crypto = require('crypto');
const { CohereClient } = require('cohere-ai');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
ImageDetail,
EModelEndpoint,
resolveHeaders,
CohereConstants,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils');
const { createContextHandlers } = require('./prompts');
const { createCoherePayload } = require('./llm');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const CHATGPT_MODEL = 'gpt-3.5-turbo';
const tokenizersCache = {};
class ChatGPTClient extends BaseClient {
constructor(apiKey, options = {}, cacheOptions = {}) {
super(apiKey, options, cacheOptions);
cacheOptions.namespace = cacheOptions.namespace || 'chatgpt';
this.conversationsCache = new Keyv(cacheOptions);
this.setOptions(options);
}
setOptions(options) {
if (this.options && !this.options.replaceOptions) {
// nested options aren't spread properly, so we need to do this manually
this.options.modelOptions = {
...this.options.modelOptions,
...options.modelOptions,
};
delete options.modelOptions;
// now we can merge options
this.options = {
...this.options,
...options,
};
} else {
this.options = options;
}
if (this.options.openaiApiKey) {
this.apiKey = this.options.openaiApiKey;
}
const modelOptions = this.options.modelOptions || {};
this.modelOptions = {
...modelOptions,
// set some good defaults (check for undefined in some cases because they may be 0)
model: modelOptions.model || CHATGPT_MODEL,
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
presence_penalty:
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
stop: modelOptions.stop,
};
this.isChatGptModel = this.modelOptions.model.includes('gpt-');
const { isChatGptModel } = this;
this.isUnofficialChatGptModel =
this.modelOptions.model.startsWith('text-chat') ||
this.modelOptions.model.startsWith('text-davinci-002-render');
const { isUnofficialChatGptModel } = this;
// Davinci models have a max context length of 4097 tokens.
this.maxContextTokens = this.options.maxContextTokens || (isChatGptModel ? 4095 : 4097);
// I decided to reserve 1024 tokens for the response.
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
// Earlier messages will be dropped until the prompt is within the limit.
this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
throw new Error(
`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
);
}
this.userLabel = this.options.userLabel || 'User';
this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
if (isChatGptModel) {
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
// without tripping the stop sequences, so I'm using "||>" instead.
this.startToken = '||>';
this.endToken = '';
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
} else if (isUnofficialChatGptModel) {
this.startToken = '<|im_start|>';
this.endToken = '<|im_end|>';
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
});
} else {
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
// as a single token. So we're using this instead.
this.startToken = '||>';
this.endToken = '';
try {
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
} catch {
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
}
}
if (!this.modelOptions.stop) {
const stopTokens = [this.startToken];
if (this.endToken && this.endToken !== this.startToken) {
stopTokens.push(this.endToken);
}
stopTokens.push(`\n${this.userLabel}:`);
stopTokens.push('<|diff_marker|>');
// I chose not to do one for `chatGptLabel` because I've never seen it happen
this.modelOptions.stop = stopTokens;
}
if (this.options.reverseProxyUrl) {
this.completionsUrl = this.options.reverseProxyUrl;
} else if (isChatGptModel) {
this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
} else {
this.completionsUrl = 'https://api.openai.com/v1/completions';
}
return this;
}
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
if (tokenizersCache[encoding]) {
return tokenizersCache[encoding];
}
let tokenizer;
if (isModelName) {
tokenizer = encodingForModel(encoding, extendSpecialTokens);
} else {
tokenizer = getEncoding(encoding, extendSpecialTokens);
}
tokenizersCache[encoding] = tokenizer;
return tokenizer;
}
/** @type {getCompletion} */
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
let modelOptions = { ...this.modelOptions };
if (typeof onProgress === 'function') {
modelOptions.stream = true;
}
if (this.isChatGptModel) {
modelOptions.messages = input;
} else {
modelOptions.prompt = input;
}
if (this.useOpenRouter && modelOptions.prompt) {
delete modelOptions.stop;
}
const { debug } = this.options;
let baseURL = this.completionsUrl;
if (debug) {
console.debug();
console.debug(baseURL);
console.debug(modelOptions);
console.debug();
}
const opts = {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
};
if (this.isVisionModel) {
modelOptions.max_tokens = 4000;
}
/** @type {TAzureConfig | undefined} */
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
const isAzure = this.azure || this.options.azure;
if (
(isAzure && this.isVisionModel && azureConfig) ||
(azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
) {
const { modelGroupMap, groupMap } = azureConfig;
const {
azureOptions,
baseURL,
headers = {},
serverless,
} = mapModelToAzureConfig({
modelName: modelOptions.model,
modelGroupMap,
groupMap,
});
opts.headers = resolveHeaders(headers);
this.langchainProxy = extractBaseURL(baseURL);
this.apiKey = azureOptions.azureOpenAIApiKey;
const groupName = modelGroupMap[modelOptions.model].group;
this.options.addParams = azureConfig.groupMap[groupName].addParams;
this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
// Note: `forcePrompt` not re-assigned as only chat models are vision models
this.azure = !serverless && azureOptions;
this.azureEndpoint =
!serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
if (serverless === true) {
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
this.options.headers['api-key'] = this.apiKey;
}
}
if (this.options.defaultQuery) {
opts.defaultQuery = this.options.defaultQuery;
}
if (this.options.headers) {
opts.headers = { ...opts.headers, ...this.options.headers };
}
if (isAzure) {
// Azure does not accept `model` in the body, so we need to remove it.
delete modelOptions.model;
baseURL = this.langchainProxy
? constructAzureURL({
baseURL: this.langchainProxy,
azureOptions: this.azure,
})
: this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
if (this.options.forcePrompt) {
baseURL += '/completions';
} else {
baseURL += '/chat/completions';
}
opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
opts.headers = { ...opts.headers, 'api-key': this.apiKey };
} else if (this.apiKey) {
opts.headers.Authorization = `Bearer ${this.apiKey}`;
}
if (process.env.OPENAI_ORGANIZATION) {
opts.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION;
}
if (this.useOpenRouter) {
opts.headers['HTTP-Referer'] = 'https://librechat.ai';
opts.headers['X-Title'] = 'LibreChat';
}
/* hacky fixes for Mistral AI API:
- Re-orders system message to the top of the messages payload, as not allowed anywhere else
- If there is only one message and it's a system message, change the role to user
*/
if (baseURL.includes('https://api.mistral.ai/v1') && modelOptions.messages) {
const { messages } = modelOptions;
const systemMessageIndex = messages.findIndex((msg) => msg.role === 'system');
if (systemMessageIndex > 0) {
const [systemMessage] = messages.splice(systemMessageIndex, 1);
messages.unshift(systemMessage);
}
modelOptions.messages = messages;
if (messages.length === 1 && messages[0].role === 'system') {
modelOptions.messages[0].role = 'user';
}
}
if (this.options.addParams && typeof this.options.addParams === 'object') {
modelOptions = {
...modelOptions,
...this.options.addParams,
};
logger.debug('[ChatGPTClient] chatCompletion: added params', {
addParams: this.options.addParams,
modelOptions,
});
}
if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
this.options.dropParams.forEach((param) => {
delete modelOptions[param];
});
logger.debug('[ChatGPTClient] chatCompletion: dropped params', {
dropParams: this.options.dropParams,
modelOptions,
});
}
if (baseURL.startsWith(CohereConstants.API_URL)) {
const payload = createCoherePayload({ modelOptions });
return await this.cohereChatCompletion({ payload, onTokenProgress });
}
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
baseURL = baseURL.split('v1')[0] + 'v1/completions';
} else if (
baseURL.includes('v1') &&
!baseURL.includes('/chat/completions') &&
this.isChatCompletion
) {
baseURL = baseURL.split('v1')[0] + 'v1/chat/completions';
}
const BASE_URL = new URL(baseURL);
if (opts.defaultQuery) {
Object.entries(opts.defaultQuery).forEach(([key, value]) => {
BASE_URL.searchParams.append(key, value);
});
delete opts.defaultQuery;
}
const completionsURL = BASE_URL.toString();
opts.body = JSON.stringify(modelOptions);
if (modelOptions.stream) {
return new Promise(async (resolve, reject) => {
try {
let done = false;
await fetchEventSource(completionsURL, {
...opts,
signal: abortController.signal,
async onopen(response) {
if (response.status === 200) {
return;
}
if (debug) {
console.debug(response);
}
let error;
try {
const body = await response.text();
error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
error.json = JSON.parse(body);
} catch {
error = error || new Error(`Failed to send message. HTTP ${response.status}`);
}
throw error;
},
onclose() {
if (debug) {
console.debug('Server closed the connection unexpectedly, returning...');
}
// workaround for private API not sending [DONE] event
if (!done) {
onProgress('[DONE]');
resolve();
}
},
onerror(err) {
if (debug) {
console.debug(err);
}
// rethrow to stop the operation
throw err;
},
onmessage(message) {
if (debug) {
console.debug(message);
}
if (!message.data || message.event === 'ping') {
return;
}
if (message.data === '[DONE]') {
onProgress('[DONE]');
resolve();
done = true;
return;
}
onProgress(JSON.parse(message.data));
},
});
} catch (err) {
reject(err);
}
});
}
const response = await fetch(completionsURL, {
...opts,
signal: abortController.signal,
});
if (response.status !== 200) {
const body = await response.text();
const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
try {
error.json = JSON.parse(body);
} catch {
error.body = body;
}
throw error;
}
return response.json();
}
/** @type {cohereChatCompletion} */
async cohereChatCompletion({ payload, onTokenProgress }) {
const cohere = new CohereClient({
token: this.apiKey,
environment: this.completionsUrl,
});
if (!payload.stream) {
const chatResponse = await cohere.chat(payload);
return chatResponse.text;
}
const chatStream = await cohere.chatStream(payload);
let reply = '';
for await (const message of chatStream) {
if (!message) {
continue;
}
if (message.eventType === 'text-generation' && message.text) {
onTokenProgress(message.text);
reply += message.text;
}
/*
Cohere API Chinese Unicode character replacement hotfix.
Should be un-commented when the following issue is resolved:
https://github.com/cohere-ai/cohere-typescript/issues/151
else if (message.eventType === 'stream-end' && message.response) {
reply = message.response.text;
}
*/
}
return reply;
}
async generateTitle(userMessage, botMessage) {
const instructionsPayload = {
role: 'system',
content: `Write an extremely concise subtitle for this conversation with no more than a few words. All words should be capitalized. Exclude punctuation.
||>Message:
${userMessage.message}
||>Response:
${botMessage.message}
||>Title:`,
};
const titleGenClientOptions = JSON.parse(JSON.stringify(this.options));
titleGenClientOptions.modelOptions = {
model: 'gpt-3.5-turbo',
temperature: 0,
presence_penalty: 0,
frequency_penalty: 0,
};
const titleGenClient = new ChatGPTClient(this.apiKey, titleGenClientOptions);
const result = await titleGenClient.getCompletion([instructionsPayload], null);
// remove any non-alphanumeric characters, replace multiple spaces with 1, and then trim
return result.choices[0].message.content
.replace(/[^a-zA-Z0-9' ]/g, '')
.replace(/\s+/g, ' ')
.trim();
}
async sendMessage(message, opts = {}) {
if (opts.clientOptions && typeof opts.clientOptions === 'object') {
this.setOptions(opts.clientOptions);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || crypto.randomUUID();
let conversation =
typeof opts.conversation === 'object'
? opts.conversation
: await this.conversationsCache.get(conversationId);
let isNewConversation = false;
if (!conversation) {
conversation = {
messages: [],
createdAt: Date.now(),
};
isNewConversation = true;
}
const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation;
const userMessage = {
id: crypto.randomUUID(),
parentMessageId,
role: 'User',
message,
};
conversation.messages.push(userMessage);
// Doing it this way instead of having each message be a separate element in the array seems to be more reliable,
// especially when it comes to keeping the AI in character. It also seems to improve coherency and context retention.
const { prompt: payload, context } = await this.buildPrompt(
conversation.messages,
userMessage.id,
{
isChatGptModel: this.isChatGptModel,
promptPrefix: opts.promptPrefix,
},
);
if (this.options.keepNecessaryMessagesOnly) {
conversation.messages = context;
}
let reply = '';
let result = null;
if (typeof opts.onProgress === 'function') {
await this.getCompletion(
payload,
(progressMessage) => {
if (progressMessage === '[DONE]') {
return;
}
const token = this.isChatGptModel
? progressMessage.choices[0].delta.content
: progressMessage.choices[0].text;
// first event's delta content is always undefined
if (!token) {
return;
}
if (this.options.debug) {
console.debug(token);
}
if (token === this.endToken) {
return;
}
opts.onProgress(token);
reply += token;
},
opts.abortController || new AbortController(),
);
} else {
result = await this.getCompletion(
payload,
null,
opts.abortController || new AbortController(),
);
if (this.options.debug) {
console.debug(JSON.stringify(result));
}
if (this.isChatGptModel) {
reply = result.choices[0].message.content;
} else {
reply = result.choices[0].text.replace(this.endToken, '');
}
}
// avoids some rendering issues when using the CLI app
if (this.options.debug) {
console.debug();
}
reply = reply.trim();
const replyMessage = {
id: crypto.randomUUID(),
parentMessageId: userMessage.id,
role: 'ChatGPT',
message: reply,
};
conversation.messages.push(replyMessage);
const returnData = {
response: replyMessage.message,
conversationId,
parentMessageId: replyMessage.parentMessageId,
messageId: replyMessage.id,
details: result || {},
};
if (shouldGenerateTitle) {
conversation.title = await this.generateTitle(userMessage, replyMessage);
returnData.title = conversation.title;
}
await this.conversationsCache.set(conversationId, conversation);
if (this.options.returnConversation) {
returnData.conversation = conversation;
}
return returnData;
}
async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) {
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
// Handle attachments and create augmentedPrompt
if (this.options.attachments) {
const attachments = await this.options.attachments;
const lastMessage = messages[messages.length - 1];
if (this.message_file_map) {
this.message_file_map[lastMessage.messageId] = attachments;
} else {
this.message_file_map = {
[lastMessage.messageId]: attachments,
};
}
const files = await this.addImageURLs(lastMessage, attachments);
this.options.attachments = files;
this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text);
}
if (this.message_file_map) {
this.contextHandlers = createContextHandlers(
this.options.req,
messages[messages.length - 1].text,
);
}
// Calculate image token cost and process embedded files
messages.forEach((message, i) => {
if (this.message_file_map && this.message_file_map[message.messageId]) {
const attachments = this.message_file_map[message.messageId];
for (const file of attachments) {
if (file.embedded) {
this.contextHandlers?.processFile(file);
continue;
}
messages[i].tokenCount =
(messages[i].tokenCount || 0) +
this.calculateImageTokenCost({
width: file.width,
height: file.height,
detail: this.options.imageDetail ?? ImageDetail.auto,
});
}
}
});
if (this.contextHandlers) {
this.augmentedPrompt = await this.contextHandlers.createContext();
promptPrefix = this.augmentedPrompt + promptPrefix;
}
if (promptPrefix) {
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
}
const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond.
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
let currentTokenCount;
if (isChatGptModel) {
currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
} else {
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
}
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
const context = [];
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && messages.length > 0) {
const message = messages.pop();
const roleLabel =
message?.isCreatedByUser || message?.role?.toLowerCase() === 'user'
? this.userLabel
: this.chatGptLabel;
const messageString = `${this.startToken}${roleLabel}:\n${
message?.text ?? message?.message
}${this.endToken}\n`;
let newPromptBody;
if (promptBody || isChatGptModel) {
newPromptBody = `${messageString}${promptBody}`;
} else {
// Always insert prompt prefix before the last user message, if not gpt-3.5-turbo.
// This makes the AI obey the prompt instructions better, which is important for custom instructions.
// After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things
// like "what's the last thing I wrote?".
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
}
context.unshift(message);
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setImmediate(resolve));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = `${promptBody}${promptSuffix}`;
if (isChatGptModel) {
messagePayload.content = prompt;
// Add 3 tokens for Assistant Label priming after all messages have been counted.
currentTokenCount += 3;
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (isChatGptModel) {
return { prompt: [instructionsPayload, messagePayload], context };
}
return { prompt, context, promptTokens: currentTokenCount };
}
getTokenCount(text) {
return this.gptEncoder.encode(text, 'all').length;
}
/**
* Algorithm adapted from "6. Counting tokens for chat API calls" of
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
*
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
*
* @param {Object} message
*/
getTokenCountForMessage(message) {
// Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
let tokensPerMessage = 3;
let tokensPerName = 1;
if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
tokensPerMessage = 4;
tokensPerName = -1;
}
let numTokens = tokensPerMessage;
for (let [key, value] of Object.entries(message)) {
numTokens += this.getTokenCount(value);
if (key === 'name') {
numTokens += tokensPerName;
}
}
return numTokens;
}
}
module.exports = ChatGPTClient;

View File

@@ -1,7 +1,6 @@
const { google } = require('googleapis');
const { concat } = require('@langchain/core/utils/stream');
const { ChatVertexAI } = require('@langchain/google-vertexai');
const { Tokenizer, getSafetySettings } = require('@librechat/api');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
@@ -12,14 +11,15 @@ const {
endpointSettings,
parseTextParts,
EModelEndpoint,
googleSettings,
ContentTypes,
VisionModes,
ErrorTypes,
Constants,
AuthKeys,
} = require('librechat-data-provider');
const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
const { encodeAndFormat } = require('~/server/services/Files/images');
const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
const { sleep } = require('~/server/utils');
@@ -34,8 +34,7 @@ const BaseClient = require('./BaseClient');
const loc = process.env.GOOGLE_LOC || 'us-central1';
const publisher = 'google';
const endpointPrefix =
loc === 'global' ? 'aiplatform.googleapis.com' : `${loc}-aiplatform.googleapis.com`;
const endpointPrefix = `${loc}-aiplatform.googleapis.com`;
const settings = endpointSettings[EModelEndpoint.google];
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
@@ -166,16 +165,6 @@ class GoogleClient extends BaseClient {
);
}
// Add thinking configuration
this.modelOptions.thinkingConfig = {
thinkingBudget:
(this.modelOptions.thinking ?? googleSettings.thinking.default)
? this.modelOptions.thinkingBudget
: 0,
};
delete this.modelOptions.thinking;
delete this.modelOptions.thinkingBudget;
this.sender =
this.options.sender ??
getResponseSender({
@@ -247,11 +236,11 @@ class GoogleClient extends BaseClient {
msg.content = (
!Array.isArray(msg.content)
? [
{
type: ContentTypes.TEXT,
[ContentTypes.TEXT]: msg.content,
},
]
{
type: ContentTypes.TEXT,
[ContentTypes.TEXT]: msg.content,
},
]
: msg.content
).concat(message.image_urls);

View File

@@ -1,11 +1,10 @@
const { z } = require('zod');
const axios = require('axios');
const { Ollama } = require('ollama');
const { sleep } = require('@librechat/agents');
const { logAxiosError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { Constants } = require('librechat-data-provider');
const { deriveBaseURL } = require('~/utils');
const { deriveBaseURL, logAxiosError } = require('~/utils');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
const ollamaPayloadSchema = z.object({
mirostat: z.number().optional(),
@@ -68,7 +67,7 @@ class OllamaClient {
return models;
} catch (error) {
const logMessage =
"Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn't start with `ollama` (case-insensitive).";
'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).';
logAxiosError({ message: logMessage, error });
return [];
}

View File

@@ -1,21 +1,13 @@
const { OllamaClient } = require('./OllamaClient');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { SplitStreamHandler, CustomOpenAIClient: OpenAI } = require('@librechat/agents');
const {
isEnabled,
Tokenizer,
createFetch,
resolveHeaders,
constructAzureURL,
genAzureChatCompletion,
createStreamEventHandlers,
} = require('@librechat/api');
const {
Constants,
ImageDetail,
ContentTypes,
parseTextParts,
EModelEndpoint,
resolveHeaders,
KnownEndpoints,
openAISettings,
ImageDetailCost,
@@ -24,6 +16,13 @@ const {
validateVisionModel,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const {
extractBaseURL,
constructAzureURL,
getModelMaxTokens,
genAzureChatCompletion,
getModelMaxOutputTokens,
} = require('~/utils');
const {
truncateText,
formatMessage,
@@ -31,12 +30,14 @@ const {
titleInstruction,
createContextHandlers,
} = require('./prompts');
const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { addSpaceIfNeeded, sleep } = require('~/server/utils');
const { createFetch, createStreamEventHandlers } = require('./generators');
const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils');
const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens');
const { handleOpenAIErrors } = require('./tools/util');
const { createLLM, RunManager } = require('./llm');
const ChatGPTClient = require('./ChatGPTClient');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document');
@@ -46,6 +47,12 @@ const { logger } = require('~/config');
class OpenAIClient extends BaseClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
this.ChatGPTClient = new ChatGPTClient();
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
/** @type {getCompletion} */
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
/** @type {cohereChatCompletion} */
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
this.contextStrategy = options.contextStrategy
? options.contextStrategy.toLowerCase()
: 'discard';
@@ -372,12 +379,23 @@ class OpenAIClient extends BaseClient {
return files;
}
async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) {
async buildMessages(
messages,
parentMessageId,
{ isChatCompletion = false, promptPrefix = null },
opts,
) {
let orderedMessages = this.constructor.getMessagesForConversation({
messages,
parentMessageId,
summary: this.shouldSummarize,
});
if (!isChatCompletion) {
return await this.buildPrompt(orderedMessages, {
isChatGptModel: isChatCompletion,
promptPrefix,
});
}
let payload;
let instructions;
@@ -1141,7 +1159,6 @@ ${convo}
logger.debug('[OpenAIClient] chatCompletion', { baseURL, modelOptions });
const opts = {
baseURL,
fetchOptions: {},
};
if (this.useOpenRouter) {
@@ -1160,7 +1177,7 @@ ${convo}
}
if (this.options.proxy) {
opts.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy);
opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
}
/** @type {TAzureConfig | undefined} */
@@ -1378,7 +1395,7 @@ ${convo}
...modelOptions,
stream: true,
};
const stream = await openai.chat.completions
const stream = await openai.beta.chat.completions
.stream(params)
.on('abort', () => {
/* Do nothing here */

View File

@@ -0,0 +1,542 @@
const OpenAIClient = require('./OpenAIClient');
const { CallbackManager } = require('@langchain/core/callbacks/manager');
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
const { processFileURL } = require('~/server/services/Files/process');
const { EModelEndpoint } = require('librechat-data-provider');
const { checkBalance } = require('~/models/balanceMethods');
const { formatLangChainMessages } = require('./prompts');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
const { logger } = require('~/config');
class PluginsClient extends OpenAIClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
this.sender = options.sender ?? 'Assistant';
this.tools = [];
this.actions = [];
this.setOptions(options);
this.openAIApiKey = this.apiKey;
this.executor = null;
}
setOptions(options) {
this.agentOptions = { ...options.agentOptions };
this.functionsAgent = this.agentOptions?.agent === 'functions';
this.agentIsGpt3 = this.agentOptions?.model?.includes('gpt-3');
super.setOptions(options);
this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');
if (this.options.reverseProxyUrl) {
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
}
}
getSaveOptions() {
return {
artifacts: this.options.artifacts,
chatGptLabel: this.options.chatGptLabel,
modelLabel: this.options.modelLabel,
promptPrefix: this.options.promptPrefix,
tools: this.options.tools,
...this.modelOptions,
agentOptions: this.agentOptions,
iconURL: this.options.iconURL,
greeting: this.options.greeting,
spec: this.options.spec,
};
}
saveLatestAction(action) {
this.actions.push(action);
}
getFunctionModelName(input) {
if (/-(?!0314)\d{4}/.test(input)) {
return input;
} else if (input.includes('gpt-3.5-turbo')) {
return 'gpt-3.5-turbo';
} else if (input.includes('gpt-4')) {
return 'gpt-4';
} else {
return 'gpt-3.5-turbo';
}
}
getBuildMessagesOptions(opts) {
return {
isChatCompletion: true,
promptPrefix: opts.promptPrefix,
abortController: opts.abortController,
};
}
async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
const modelOptions = {
modelName: this.agentOptions.model,
temperature: this.agentOptions.temperature,
};
const model = this.initializeLLM({
...modelOptions,
context: 'plugins',
initialMessageCount: this.currentMessages.length + 1,
});
logger.debug(
`[PluginsClient] Agent Model: ${model.modelName} | Temp: ${model.temperature} | Functions: ${this.functionsAgent}`,
);
// Map Messages to Langchain format
const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), {
userName: this.options?.name,
});
logger.debug('[PluginsClient] pastMessages: ' + pastMessages.length);
// TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS)
const memory = new BufferMemory({
llm: model,
chatHistory: new ChatMessageHistory(pastMessages),
});
const { loadedTools } = await loadTools({
user,
model,
tools: this.options.tools,
functions: this.functionsAgent,
options: {
memory,
signal: this.abortController.signal,
openAIApiKey: this.openAIApiKey,
conversationId: this.conversationId,
fileStrategy: this.options.req.app.locals.fileStrategy,
processFileURL,
message,
},
useSpecs: true,
});
if (loadedTools.length === 0) {
return;
}
this.tools = loadedTools;
logger.debug('[PluginsClient] Requested Tools', this.options.tools);
logger.debug(
'[PluginsClient] Loaded Tools',
this.tools.map((tool) => tool.name),
);
const handleAction = (action, runId, callback = null) => {
this.saveLatestAction(action);
logger.debug('[PluginsClient] Latest Agent Action ', this.actions[this.actions.length - 1]);
if (typeof callback === 'function') {
callback(action, runId);
}
};
// initialize agent
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
let customInstructions = (this.options.promptPrefix ?? '').trim();
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
customInstructions = `${customInstructions ?? ''}\n${this.options.artifactsPrompt}`.trim();
}
this.executor = await initializer({
model,
signal,
pastMessages,
tools: this.tools,
customInstructions,
verbose: this.options.debug,
returnIntermediateSteps: true,
customName: this.options.chatGptLabel,
currentDateString: this.currentDateString,
callbackManager: CallbackManager.fromHandlers({
async handleAgentAction(action, runId) {
handleAction(action, runId, onAgentAction);
},
async handleChainEnd(action) {
if (typeof onChainEnd === 'function') {
onChainEnd(action);
}
},
}),
});
logger.debug('[PluginsClient] Loaded agent.');
}
async executorCall(message, { signal, stream, onToolStart, onToolEnd }) {
let errorMessage = '';
const maxAttempts = 1;
for (let attempts = 1; attempts <= maxAttempts; attempts++) {
const errorInput = buildErrorInput({
message,
errorMessage,
actions: this.actions,
functionsAgent: this.functionsAgent,
});
const input = attempts > 1 ? errorInput : message;
logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`);
if (errorMessage.length > 0) {
logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));
}
try {
this.result = await this.executor.call({ input, signal }, [
{
async handleToolStart(...args) {
await onToolStart(...args);
},
async handleToolEnd(...args) {
await onToolEnd(...args);
},
async handleLLMEnd(output) {
const { generations } = output;
const { text } = generations[0][0];
if (text && typeof stream === 'function') {
await stream(text);
}
},
},
]);
break; // Exit the loop if the function call is successful
} catch (err) {
logger.error('[PluginsClient] executorCall error:', err);
if (attempts === maxAttempts) {
const { run } = this.runManager.getRunByConversationId(this.conversationId);
const defaultOutput = `Encountered an error while attempting to respond: ${err.message}`;
this.result.output = run && run.error ? run.error : defaultOutput;
this.result.errorMessage = run && run.error ? run.error : err.message;
this.result.intermediateSteps = this.actions;
break;
}
}
}
}
/**
*
* @param {TMessage} responseMessage
* @param {Partial<TMessage>} saveOptions
* @param {string} user
* @returns
*/
async handleResponseMessage(responseMessage, saveOptions, user) {
const { output, errorMessage, ...result } = this.result;
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
output,
errorMessage,
...result,
});
const { error } = responseMessage;
if (!error) {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
responseMessage.completionTokens = this.getTokenCount(responseMessage.text);
}
// Record usage only when completion is skipped as it is already recorded in the agent phase.
if (!this.agentOptions.skipCompletion && !error) {
await this.recordTokenUsage(responseMessage);
}
const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return { ...responseMessage, ...result, databasePromise };
}
async sendMessage(message, opts = {}) {
/** @type {Promise<TMessage>} */
let userMessagePromise;
/** @type {{ filteredTools: string[], includedTools: string[] }} */
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
if (includedTools.length > 0) {
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
this.options.tools = tools;
} else {
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
this.options.tools = tools;
}
// If a message is edited, no tools can be used.
const completionMode = this.options.tools.length === 0 || opts.isEdited;
if (completionMode) {
this.setOptions(opts);
return super.sendMessage(message, opts);
}
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
const {
user,
conversationId,
responseMessageId,
saveOptions,
userMessage,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
} = await this.handleStartMethods(message, opts);
if (opts.progressCallback) {
opts.onProgress = opts.progressCallback.call(null, {
...(opts.progressOptions ?? {}),
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
}
this.currentMessages.push(userMessage);
let {
prompt: payload,
tokenCountMap,
promptTokens,
} = await this.buildMessages(
this.currentMessages,
userMessage.messageId,
this.getBuildMessagesOptions({
promptPrefix: null,
abortController: this.abortController,
}),
);
if (tokenCountMap) {
logger.debug('[PluginsClient] tokenCountMap', { tokenCountMap });
if (tokenCountMap[userMessage.messageId]) {
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
logger.debug('[PluginsClient] userMessage.tokenCount', userMessage.tokenCount);
}
this.handleTokenCountMap(tokenCountMap);
}
this.result = {};
if (payload) {
this.currentMessages = payload;
}
if (!this.skipSaveUserMessage) {
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise,
});
}
}
const balance = this.options.req?.app?.locals?.balance;
if (balance?.enabled) {
await checkBalance({
req: this.options.req,
res: this.options.res,
txData: {
user: this.user,
tokenType: 'prompt',
amount: promptTokens,
debug: this.options.debug,
model: this.modelOptions.model,
endpoint: EModelEndpoint.openAI,
},
});
}
const responseMessage = {
endpoint: EModelEndpoint.gptPlugins,
iconURL: this.options.iconURL,
messageId: responseMessageId,
conversationId,
parentMessageId: userMessage.messageId,
isCreatedByUser: false,
model: this.modelOptions.model,
sender: this.sender,
promptTokens,
};
await this.initialize({
user,
message,
onAgentAction,
onChainEnd,
signal: this.abortController.signal,
onProgress: opts.onProgress,
});
// const stream = async (text) => {
// await this.generateTextStream.call(this, text, opts.onProgress, { delay: 1 });
// };
await this.executorCall(message, {
signal: this.abortController.signal,
// stream,
onToolStart,
onToolEnd,
});
// If message was aborted mid-generation
if (this.result?.errorMessage?.length > 0 && this.result?.errorMessage?.includes('cancel')) {
responseMessage.text = 'Cancelled.';
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
// If error occurred during generation (likely token_balance)
if (this.result?.errorMessage?.length > 0) {
responseMessage.error = true;
responseMessage.text = this.result.output;
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) {
const partialText = opts.getPartialText();
const trimmedPartial = opts.getPartialText().replaceAll(':::plugin:::\n', '');
responseMessage.text =
trimmedPartial.length === 0 ? `${partialText}${this.result.output}` : partialText;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output) {
responseMessage.text = this.result.output;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
logger.debug('[PluginsClient] Completion phase: this.result', this.result);
const promptPrefix = buildPromptPrefix({
result: this.result,
message,
functionsAgent: this.functionsAgent,
});
logger.debug('[PluginsClient]', { promptPrefix });
payload = await this.buildCompletionPrompt({
messages: this.currentMessages,
promptPrefix,
});
logger.debug('[PluginsClient] buildCompletionPrompt Payload', payload);
responseMessage.text = await this.sendCompletion(payload, opts);
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
async buildCompletionPrompt({ messages, promptPrefix: _promptPrefix }) {
logger.debug('[PluginsClient] buildCompletionPrompt messages', messages);
const orderedMessages = messages;
let promptPrefix = _promptPrefix.trim();
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
const promptSuffix = `${this.startToken}${this.chatGptLabel ?? 'Assistant'}:\n`;
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
if (this.isGpt3) {
instructionsPayload.role = 'user';
messagePayload.role = 'user';
instructionsPayload.content += `\n${promptSuffix}`;
}
// testing if this works with browser endpoint
if (!this.isGpt3 && this.options.reverseProxyUrl) {
instructionsPayload.role = 'user';
}
let currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
const message = orderedMessages.pop();
const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
let messageString = `${this.startToken}${roleLabel}:\n${
message.text ?? message.content ?? ''
}${this.endToken}\n`;
let newPromptBody = `${messageString}${promptBody}`;
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setTimeout(resolve, 0));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = promptBody;
messagePayload.content = prompt;
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;
if (this.isGpt3 && messagePayload.content.length > 0) {
const context = 'Chat History:\n';
messagePayload.content = `${context}${prompt}`;
currentTokenCount += this.getTokenCount(context);
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (this.isGpt3) {
messagePayload.content += promptSuffix;
return [instructionsPayload, messagePayload];
}
const result = [messagePayload, instructionsPayload];
if (this.functionsAgent && !this.isGpt3) {
result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`;
}
return result.filter((message) => message.content.length > 0);
}
}
module.exports = PluginsClient;

View File

@@ -0,0 +1,71 @@
const fetch = require('node-fetch');
const { GraphEvents } = require('@librechat/agents');
const { logger, sendEvent } = require('~/config');
const { sleep } = require('~/server/utils');
/**
* Makes a function to make HTTP request and logs the process.
* @param {Object} params
* @param {boolean} [params.directEndpoint] - Whether to use a direct endpoint.
* @param {string} [params.reverseProxyUrl] - The reverse proxy URL to use for the request.
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
*/
function createFetch({ directEndpoint = false, reverseProxyUrl = '' }) {
/**
* Makes an HTTP request and logs the process.
* @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
* @param {RequestInit} [init] - Optional init options for the request.
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
*/
return async (_url, init) => {
let url = _url;
if (directEndpoint) {
url = reverseProxyUrl;
}
logger.debug(`Making request to ${url}`);
if (typeof Bun !== 'undefined') {
return await fetch(url, init);
}
return await fetch(url, init);
};
}
// Add this at the module level outside the class
/**
* Creates event handlers for stream events that don't capture client references
* @param {Object} res - The response object to send events to
* @returns {Object} Object containing handler functions
*/
function createStreamEventHandlers(res) {
return {
[GraphEvents.ON_RUN_STEP]: (event) => {
if (res) {
sendEvent(res, event);
}
},
[GraphEvents.ON_MESSAGE_DELTA]: (event) => {
if (res) {
sendEvent(res, event);
}
},
[GraphEvents.ON_REASONING_DELTA]: (event) => {
if (res) {
sendEvent(res, event);
}
},
};
}
function createHandleLLMNewToken(streamRate) {
return async () => {
if (streamRate) {
await sleep(streamRate);
}
};
}
module.exports = {
createFetch,
createHandleLLMNewToken,
createStreamEventHandlers,
};

View File

@@ -1,11 +1,15 @@
const ChatGPTClient = require('./ChatGPTClient');
const OpenAIClient = require('./OpenAIClient');
const PluginsClient = require('./PluginsClient');
const GoogleClient = require('./GoogleClient');
const TextStream = require('./TextStream');
const AnthropicClient = require('./AnthropicClient');
const toolUtils = require('./tools/util');
module.exports = {
ChatGPTClient,
OpenAIClient,
PluginsClient,
GoogleClient,
TextStream,
AnthropicClient,

View File

@@ -1,5 +1,6 @@
const { ChatOpenAI } = require('@langchain/openai');
const { isEnabled, sanitizeModelName, constructAzureURL } = require('@librechat/api');
const { sanitizeModelName, constructAzureURL } = require('~/utils');
const { isEnabled } = require('~/server/utils');
/**
* Creates a new instance of a language model (LLM) for chat interactions.

View File

@@ -1,7 +1,6 @@
const axios = require('axios');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const footer = `Use the context as your learned knowledge to better answer the user.
@@ -19,7 +18,7 @@ function createContextHandlers(req, userMessageContent) {
const queryPromises = [];
const processedFiles = [];
const processedIds = new Set();
const jwtToken = generateShortLivedToken(req.user.id);
const jwtToken = req.headers.authorization.split(' ')[1];
const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);
const query = async (file) => {
@@ -97,35 +96,35 @@ function createContextHandlers(req, userMessageContent) {
resolvedQueries.length === 0
? '\n\tThe semantic search did not return any results.'
: resolvedQueries
.map((queryResult, index) => {
const file = processedFiles[index];
let contextItems = queryResult.data;
.map((queryResult, index) => {
const file = processedFiles[index];
let contextItems = queryResult.data;
const generateContext = (currentContext) =>
`
const generateContext = (currentContext) =>
`
<file>
<filename>${file.filename}</filename>
<context>${currentContext}
</context>
</file>`;
if (useFullContext) {
return generateContext(`\n${contextItems}`);
}
if (useFullContext) {
return generateContext(`\n${contextItems}`);
}
contextItems = queryResult.data
.map((item) => {
const pageContent = item[0].page_content;
return `
contextItems = queryResult.data
.map((item) => {
const pageContent = item[0].page_content;
return `
<contextItem>
<![CDATA[${pageContent?.trim()}]]>
</contextItem>`;
})
.join('');
})
.join('');
return generateContext(contextItems);
})
.join('');
return generateContext(contextItems);
})
.join('');
if (useFullContext) {
const prompt = `${header}

View File

@@ -309,7 +309,7 @@ describe('AnthropicClient', () => {
};
client.setOptions({ modelOptions, promptCache: true });
const anthropicClient = client.getClient(modelOptions);
expect(anthropicClient._options.defaultHeaders).toBeUndefined();
expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta');
});
it('should not add beta header for other models', () => {
@@ -320,7 +320,7 @@ describe('AnthropicClient', () => {
},
});
const anthropicClient = client.getClient();
expect(anthropicClient._options.defaultHeaders).toBeUndefined();
expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta');
});
});

View File

@@ -33,9 +33,7 @@ jest.mock('~/models', () => ({
const { getConvo, saveConvo } = require('~/models');
jest.mock('@librechat/agents', () => {
const { Providers } = jest.requireActual('@librechat/agents');
return {
Providers,
ChatOpenAI: jest.fn().mockImplementation(() => {
return {};
}),
@@ -422,46 +420,6 @@ describe('BaseClient', () => {
expect(response).toEqual(expectedResult);
});
test('should replace responseMessageId with new UUID when isRegenerate is true and messageId ends with underscore', async () => {
const mockCrypto = require('crypto');
const newUUID = 'new-uuid-1234';
jest.spyOn(mockCrypto, 'randomUUID').mockReturnValue(newUUID);
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe(newUUID);
expect(TestClient.responseMessageId).not.toBe('existing-message-id_');
mockCrypto.randomUUID.mockRestore();
});
test('should not replace responseMessageId when isRegenerate is false', async () => {
const opts = {
isRegenerate: false,
responseMessageId: 'existing-message-id_',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id_');
});
test('should not replace responseMessageId when it does not end with underscore', async () => {
const opts = {
isRegenerate: true,
responseMessageId: 'existing-message-id',
};
await TestClient.setMessageOptions(opts);
expect(TestClient.responseMessageId).toBe('existing-message-id');
});
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
const userMessage = 'Second message in the conversation';
const opts = {

View File

@@ -531,6 +531,44 @@ describe('OpenAIClient', () => {
});
});
describe('sendMessage/getCompletion/chatCompletion', () => {
afterEach(() => {
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
});
it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => {
const model = 'text-davinci-003';
const onProgress = jest.fn().mockImplementation(() => ({}));
const testClient = new OpenAIClient('test-api-key', {
...defaultOptions,
modelOptions: { model },
});
const getCompletion = jest.spyOn(testClient, 'getCompletion');
await testClient.sendMessage('Hi mom!', { onProgress });
expect(getCompletion).toHaveBeenCalled();
expect(getCompletion.mock.calls.length).toBe(1);
expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n');
expect(fetchEventSource).toHaveBeenCalled();
expect(fetchEventSource.mock.calls.length).toBe(1);
// Check if the first argument (url) is correct
const firstCallArgs = fetchEventSource.mock.calls[0];
const expectedURL = 'https://api.openai.com/v1/completions';
expect(firstCallArgs[0]).toBe(expectedURL);
const requestBody = JSON.parse(firstCallArgs[1].body);
expect(requestBody).toHaveProperty('model');
expect(requestBody.model).toBe(model);
});
});
describe('checkVisionRequest functionality', () => {
let client;
const attachments = [{ type: 'image/png' }];

View File

@@ -0,0 +1,314 @@
const crypto = require('crypto');
const { Constants } = require('librechat-data-provider');
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
const PluginsClient = require('../PluginsClient');
jest.mock('~/db/connect');
jest.mock('~/models/Conversation', () => {
return function () {
return {
save: jest.fn(),
deleteConvos: jest.fn(),
};
};
});
const defaultAzureOptions = {
azureOpenAIApiInstanceName: 'your-instance-name',
azureOpenAIApiDeploymentName: 'your-deployment-name',
azureOpenAIApiVersion: '2020-07-01-preview',
};
describe('PluginsClient', () => {
let TestAgent;
let options = {
tools: [],
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
};
let parentMessageId;
let conversationId;
const fakeMessages = [];
const userMessage = 'Hello, ChatGPT!';
const apiKey = 'fake-api-key';
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, options);
TestAgent.loadHistory = jest
.fn()
.mockImplementation((conversationId, parentMessageId = null) => {
if (!conversationId) {
TestAgent.currentMessages = [];
return Promise.resolve([]);
}
const orderedMessages = TestAgent.constructor.getMessagesForConversation({
messages: fakeMessages,
parentMessageId,
});
const chatMessages = orderedMessages.map((msg) =>
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
? new HumanMessage(msg.text)
: new AIMessage(msg.text),
);
TestAgent.currentMessages = orderedMessages;
return Promise.resolve(chatMessages);
});
TestAgent.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => {
if (opts && typeof opts === 'object') {
TestAgent.setOptions(opts);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || Constants.NO_PARENT;
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
this.pastMessages = await TestAgent.loadHistory(
conversationId,
TestAgent.options?.parentMessageId,
);
const userMessage = {
text: message,
sender: 'ChatGPT',
isCreatedByUser: true,
messageId: userMessageId,
parentMessageId,
conversationId,
};
const response = {
sender: 'ChatGPT',
text: 'Hello, User!',
isCreatedByUser: false,
messageId: crypto.randomUUID(),
parentMessageId: userMessage.messageId,
conversationId,
};
fakeMessages.push(userMessage);
fakeMessages.push(response);
return response;
});
});
test('initializes PluginsClient without crashing', () => {
expect(TestAgent).toBeInstanceOf(PluginsClient);
});
test('check setOptions function', () => {
expect(TestAgent.agentIsGpt3).toBe(true);
});
describe('sendMessage', () => {
test('sendMessage should return a response message', async () => {
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: expect.any(String),
});
const response = await TestAgent.sendMessage(userMessage);
parentMessageId = response.messageId;
conversationId = response.conversationId;
expect(response).toEqual(expectedResult);
});
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
const userMessage = 'Second message in the conversation';
const opts = {
conversationId,
parentMessageId,
};
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: opts.conversationId,
});
const response = await TestAgent.sendMessage(userMessage, opts);
parentMessageId = response.messageId;
expect(response.conversationId).toEqual(conversationId);
expect(response).toEqual(expectedResult);
});
test('should return chat history', async () => {
const chatMessages = await TestAgent.loadHistory(conversationId, parentMessageId);
expect(TestAgent.currentMessages).toHaveLength(4);
expect(chatMessages[0].text).toEqual(userMessage);
});
});
describe('getFunctionModelName', () => {
let client;
beforeEach(() => {
client = new PluginsClient('dummy_api_key');
});
test('should return the input when it includes a dash followed by four digits', () => {
expect(client.getFunctionModelName('-1234')).toBe('-1234');
expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview');
});
test('should return the input for all function-capable models (`0613` models and above)', () => {
expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613');
expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview');
expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106');
});
test('should return the corresponding model if input is non-function capable (`0314` models)', () => {
expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => {
expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-4" when the input includes "gpt-4"', () => {
expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4');
});
test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => {
expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
});
});
describe('Azure OpenAI tests specific to Plugins', () => {
// TODO: add more tests for Azure OpenAI integration with Plugins
// let client;
// beforeEach(() => {
// client = new PluginsClient('dummy_api_key');
// });
test('should not call getFunctionModelName when azure options are set', () => {
const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName');
const model = 'gpt-4-turbo';
// note, without the azure change in PR #1766, `getFunctionModelName` is called twice
const testClient = new PluginsClient('dummy_api_key', {
agentOptions: {
model,
agent: 'functions',
},
azure: defaultAzureOptions,
});
expect(spy).not.toHaveBeenCalled();
expect(testClient.agentOptions.model).toBe(model);
spy.mockRestore();
});
});
describe('sendMessage with filtered tools', () => {
let TestAgent;
const apiKey = 'fake-api-key';
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, {
tools: mockTools,
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
});
TestAgent.options.req = {
app: {
locals: {},
},
};
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
if (includedTools.length > 0) {
const tools = TestAgent.options.tools.filter((plugin) =>
includedTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
} else {
const tools = TestAgent.options.tools.filter(
(plugin) => !filteredTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
}
return {
text: 'Mocked response',
tools: TestAgent.options.tools,
};
});
});
test('should filter out tools when filteredTools is provided', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should only include specified tools when includedTools is provided', async () => {
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should prioritize includedTools over filteredTools', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool1' }),
expect.objectContaining({ name: 'tool2' }),
]),
);
});
test('should not modify tools when no filters are provided', async () => {
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(4);
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
});
});
});

View File

@@ -0,0 +1,184 @@
require('dotenv').config();
const fs = require('fs');
const { z } = require('zod');
const path = require('path');
const yaml = require('js-yaml');
const { createOpenAPIChain } = require('langchain/chains');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { ChatPromptTemplate, HumanMessagePromptTemplate } = require('@langchain/core/prompts');
const { logger } = require('~/config');
function addLinePrefix(text, prefix = '// ') {
return text
.split('\n')
.map((line) => prefix + line)
.join('\n');
}
function createPrompt(name, functions) {
const prefix = `// The ${name} tool has the following functions. Determine the desired or most optimal function for the user's query:`;
const functionDescriptions = functions
.map((func) => `// - ${func.name}: ${func.description}`)
.join('\n');
return `${prefix}\n${functionDescriptions}
// You are an expert manager and scrum master. You must provide a detailed intent to better execute the function.
// Always format as such: {{"func": "function_name", "intent": "intent and expected result"}}`;
}
const AuthBearer = z
.object({
type: z.string().includes('service_http'),
authorization_type: z.string().includes('bearer'),
verification_tokens: z.object({
openai: z.string(),
}),
})
.catch(() => false);
const AuthDefinition = z
.object({
type: z.string(),
authorization_type: z.string(),
verification_tokens: z.object({
openai: z.string(),
}),
})
.catch(() => false);
async function readSpecFile(filePath) {
try {
const fileContents = await fs.promises.readFile(filePath, 'utf8');
if (path.extname(filePath) === '.json') {
return JSON.parse(fileContents);
}
return yaml.load(fileContents);
} catch (e) {
logger.error('[readSpecFile] error', e);
return false;
}
}
async function getSpec(url) {
const RegularUrl = z
.string()
.url()
.catch(() => false);
if (RegularUrl.parse(url) && path.extname(url) === '.json') {
const response = await fetch(url);
return await response.json();
}
const ValidSpecPath = z
.string()
.url()
.catch(async () => {
const spec = path.join(__dirname, '..', '.well-known', 'openapi', url);
if (!fs.existsSync(spec)) {
return false;
}
return await readSpecFile(spec);
});
return ValidSpecPath.parse(url);
}
async function createOpenAPIPlugin({ data, llm, user, message, memory, signal }) {
let spec;
try {
spec = await getSpec(data.api.url);
} catch (error) {
logger.error('[createOpenAPIPlugin] getSpec error', error);
return null;
}
if (!spec) {
logger.warn('[createOpenAPIPlugin] No spec found');
return null;
}
const headers = {};
const { auth, name_for_model, description_for_model, description_for_human } = data;
if (auth && AuthDefinition.parse(auth)) {
logger.debug('[createOpenAPIPlugin] auth detected', auth);
const { openai } = auth.verification_tokens;
if (AuthBearer.parse(auth)) {
headers.authorization = `Bearer ${openai}`;
logger.debug('[createOpenAPIPlugin] added auth bearer', headers);
}
}
const chainOptions = { llm };
if (data.headers && data.headers['librechat_user_id']) {
logger.debug('[createOpenAPIPlugin] id detected', headers);
headers[data.headers['librechat_user_id']] = user;
}
if (Object.keys(headers).length > 0) {
logger.debug('[createOpenAPIPlugin] headers detected', headers);
chainOptions.headers = headers;
}
if (data.params) {
logger.debug('[createOpenAPIPlugin] params detected', data.params);
chainOptions.params = data.params;
}
let history = '';
if (memory) {
logger.debug('[createOpenAPIPlugin] openAPI chain: memory detected', memory);
const { history: chat_history } = await memory.loadMemoryVariables({});
history = chat_history?.length > 0 ? `\n\n## Chat History:\n${chat_history}\n` : '';
}
chainOptions.prompt = ChatPromptTemplate.fromMessages([
HumanMessagePromptTemplate.fromTemplate(
`# Use the provided API's to respond to this query:\n\n{query}\n\n## Instructions:\n${addLinePrefix(
description_for_model,
)}${history}`,
),
]);
const chain = await createOpenAPIChain(spec, chainOptions);
const { functions } = chain.chains[0].lc_kwargs.llmKwargs;
return new DynamicStructuredTool({
name: name_for_model,
description_for_model: `${addLinePrefix(description_for_human)}${createPrompt(
name_for_model,
functions,
)}`,
description: `${description_for_human}`,
schema: z.object({
func: z
.string()
.describe(
`The function to invoke. The functions available are: ${functions
.map((func) => func.name)
.join(', ')}`,
),
intent: z
.string()
.describe('Describe your intent with the function and your expected result'),
}),
func: async ({ func = '', intent = '' }) => {
const filteredFunctions = functions.filter((f) => f.name === func);
chain.chains[0].lc_kwargs.llmKwargs.functions = filteredFunctions;
const query = `${message}${func?.length > 0 ? `\n// Intent: ${intent}` : ''}`;
const result = await chain.call({
query,
signal,
});
return result.response;
},
});
}
module.exports = {
getSpec,
readSpecFile,
createOpenAPIPlugin,
};

View File

@@ -0,0 +1,72 @@
const fs = require('fs');
const { createOpenAPIPlugin, getSpec, readSpecFile } = require('./OpenAPIPlugin');
global.fetch = jest.fn().mockImplementationOnce(() => {
return new Promise((resolve) => {
resolve({
ok: true,
json: () => Promise.resolve({ key: 'value' }),
});
});
});
jest.mock('fs', () => ({
promises: {
readFile: jest.fn(),
},
existsSync: jest.fn(),
}));
describe('readSpecFile', () => {
it('reads JSON file correctly', async () => {
fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' }));
const result = await readSpecFile('test.json');
expect(result).toEqual({ test: 'value' });
});
it('reads YAML file correctly', async () => {
fs.promises.readFile.mockResolvedValue('test: value');
const result = await readSpecFile('test.yaml');
expect(result).toEqual({ test: 'value' });
});
it('handles error correctly', async () => {
fs.promises.readFile.mockRejectedValue(new Error('test error'));
const result = await readSpecFile('test.json');
expect(result).toBe(false);
});
});
describe('getSpec', () => {
it('fetches spec from url correctly', async () => {
const parsedJson = await getSpec('https://www.instacart.com/.well-known/ai-plugin.json');
const isObject = typeof parsedJson === 'object';
expect(isObject).toEqual(true);
});
it('reads spec from file correctly', async () => {
fs.existsSync.mockReturnValue(true);
fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' }));
const result = await getSpec('test.json');
expect(result).toEqual({ test: 'value' });
});
it('returns false when file does not exist', async () => {
fs.existsSync.mockReturnValue(false);
const result = await getSpec('test.json');
expect(result).toBe(false);
});
});
describe('createOpenAPIPlugin', () => {
it('returns null when getSpec throws an error', async () => {
const result = await createOpenAPIPlugin({ data: { api: { url: 'invalid' } } });
expect(result).toBe(null);
});
it('returns null when no spec is found', async () => {
const result = await createOpenAPIPlugin({});
expect(result).toBe(null);
});
// Add more tests here for different scenarios
});

View File

@@ -8,10 +8,10 @@ const { HttpsProxyAgent } = require('https-proxy-agent');
const { FileContext, ContentTypes } = require('librechat-data-provider');
const { getImageBasename } = require('~/server/services/Files/images');
const extractBaseURL = require('~/utils/extractBaseURL');
const logger = require('~/config/winston');
const { logger } = require('~/config');
const displayMessage =
"DALL-E displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.";
'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.';
class DALLE3 extends Tool {
constructor(fields = {}) {
super();

View File

@@ -4,13 +4,12 @@ const { v4 } = require('uuid');
const OpenAI = require('openai');
const FormData = require('form-data');
const { tool } = require('@langchain/core/tools');
const { logAxiosError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { extractBaseURL } = require('~/utils');
const { logAxiosError, extractBaseURL } = require('~/utils');
const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
/** Default descriptions for image generation tool */
const DEFAULT_IMAGE_GEN_DESCRIPTION = `
@@ -65,7 +64,7 @@ const DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION = `Describe the changes, enhancement
Always base this prompt on the most recently uploaded reference images.`;
const displayMessage =
"The tool displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.";
'The tool displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.';
/**
* Replaces unwanted characters from the input string
@@ -107,12 +106,6 @@ const getImageEditPromptDescription = () => {
return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
};
function createAbortHandler() {
return function () {
logger.debug('[ImageGenOAI] Image generation aborted');
};
}
/**
* Creates OpenAI Image tools (generation and editing)
* @param {Object} fields - Configuration fields
@@ -207,18 +200,10 @@ function createOpenAIImageTools(fields = {}) {
}
let resp;
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try {
if (runnableConfig?.signal) {
derivedSignal = AbortSignal.any([runnableConfig.signal]);
abortHandler = createAbortHandler();
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
const derivedSignal = runnableConfig?.signal
? AbortSignal.any([runnableConfig.signal])
: undefined;
resp = await openai.images.generate(
{
model: 'gpt-image-1',
@@ -242,10 +227,6 @@ function createOpenAIImageTools(fields = {}) {
logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable:
Error Message: ${error.message}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
}
if (!resp) {
@@ -427,17 +408,10 @@ Error Message: ${error.message}`);
headers['Authorization'] = `Bearer ${apiKey}`;
}
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try {
if (runnableConfig?.signal) {
derivedSignal = AbortSignal.any([runnableConfig.signal]);
abortHandler = createAbortHandler();
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
const derivedSignal = runnableConfig?.signal
? AbortSignal.any([runnableConfig.signal])
: undefined;
/** @type {import('axios').AxiosRequestConfig} */
const axiosConfig = {
@@ -492,10 +466,6 @@ Error Message: ${error.message}`);
logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable:
Error Message: ${error.message || 'Unknown error'}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
}
},
{

View File

@@ -1,29 +1,10 @@
const OpenAI = require('openai');
const DALLE3 = require('../DALLE3');
const logger = require('~/config/winston');
const { logger } = require('~/config');
jest.mock('openai');
jest.mock('@librechat/data-schemas', () => {
return {
logger: {
info: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
error: jest.fn(),
},
};
});
jest.mock('tiktoken', () => {
return {
encoding_for_model: jest.fn().mockReturnValue({
encode: jest.fn(),
decode: jest.fn(),
}),
};
});
const processFileURL = jest.fn();
jest.mock('~/server/services/Files/images', () => ({
@@ -56,11 +37,6 @@ jest.mock('fs', () => {
return {
existsSync: jest.fn(),
mkdirSync: jest.fn(),
promises: {
writeFile: jest.fn(),
readFile: jest.fn(),
unlink: jest.fn(),
},
};
});

View File

@@ -1,35 +1,26 @@
const { z } = require('zod');
const axios = require('axios');
const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Tools, EToolResources } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
/**
*
* @param {Object} options
* @param {ServerRequest} options.req
* @param {Agent['tool_resources']} options.tool_resources
* @param {string} [options.agentId] - The agent ID for file access control
* @returns {Promise<{
* files: Array<{ file_id: string; filename: string }>,
* toolContext: string
* }>}
*/
const primeFiles = async (options) => {
const { tool_resources, req, agentId } = options;
const { tool_resources } = options;
const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
const agentResourceIds = new Set(file_ids);
const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
const dbFiles = (
(await getFiles(
{ file_id: { $in: file_ids } },
null,
{ text: 0 },
{ userId: req?.user?.id, agentId },
)) ?? []
).concat(resourceFiles);
const dbFiles = ((await getFiles({ file_id: { $in: file_ids } })) ?? []).concat(resourceFiles);
let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;
@@ -68,7 +59,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
if (files.length === 0) {
return 'No files to search. Instruct the user to add files for the search.';
}
const jwtToken = generateShortLivedToken(req.user.id);
const jwtToken = req.headers.authorization.split(' ')[1];
if (!jwtToken) {
return 'There was an error authenticating the file search request.';
}
@@ -144,7 +135,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
query: z
.string()
.describe(
"A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you're looking for. The query will be used for semantic similarity matching against the file contents.",
'A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you\'re looking for. The query will be used for semantic similarity matching against the file contents.',
),
}),
},

View File

@@ -1,14 +1,14 @@
const { mcpToolPattern } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { SerpAPI } = require('@langchain/community/tools/serpapi');
const { Calculator } = require('@langchain/community/tools/calculator');
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
const {
Tools,
Constants,
EToolResources,
loadWebSearchAuth,
replaceSpecialVars,
} = require('librechat-data-provider');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const {
availableTools,
manifestToolMap,
@@ -28,10 +28,11 @@ const {
} = require('../');
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { getCachedTools } = require('~/server/services/Config');
const { createMCPTool } = require('~/server/services/MCP');
const { logger } = require('~/config');
const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`);
/**
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
@@ -92,7 +93,7 @@ const validateTools = async (user, tools = []) => {
return Array.from(validToolsSet.values());
} catch (err) {
logger.error('[validateTools] There was a problem validating tools', err);
throw new Error(err);
throw new Error('There was a problem validating tools');
}
};
@@ -235,7 +236,7 @@ const loadTools = async ({
/** @type {Record<string, string>} */
const toolContextMap = {};
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
const appTools = options.req?.app?.locals?.availableTools ?? {};
for (const tool of tools) {
if (tool === Tools.execute_code) {
@@ -245,13 +246,7 @@ const loadTools = async ({
authFields: [EnvVar.CODE_API_KEY],
});
const codeApiKey = authValues[EnvVar.CODE_API_KEY];
const { files, toolContext } = await primeCodeFiles(
{
...options,
agentId: agent?.id,
},
codeApiKey,
);
const { files, toolContext } = await primeCodeFiles(options, codeApiKey);
if (toolContext) {
toolContextMap[tool] = toolContext;
}
@@ -266,10 +261,7 @@ const loadTools = async ({
continue;
} else if (tool === Tools.file_search) {
requestedTools[tool] = async () => {
const { files, toolContext } = await primeSearchFiles({
...options,
agentId: agent?.id,
});
const { files, toolContext } = await primeSearchFiles(options);
if (toolContext) {
toolContextMap[tool] = toolContext;
}
@@ -303,11 +295,10 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
});
};
continue;
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
} else if (tool && appTools[tool] && mcpToolPattern.test(tool)) {
requestedTools[tool] = async () =>
createMCPTool({
req: options.req,
res: options.res,
toolKey: tool,
model: agent?.model ?? model,
provider: agent?.provider ?? endpoint,

View File

@@ -1,5 +1,8 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const mockUser = {
_id: 'fakeId',
save: jest.fn(),
findByIdAndDelete: jest.fn(),
};
const mockPluginService = {
updateUserPluginAuth: jest.fn(),
@@ -7,18 +10,29 @@ const mockPluginService = {
getUserPluginAuthValue: jest.fn(),
};
const mockModels = {
User: mockUser,
};
jest.mock('~/db/connect', () => {
return {
connectDb: jest.fn(),
User: mockModels.mockUser,
};
});
jest.mock('~/models/File', () => ({
File: jest.fn(),
}));
jest.mock('~/server/services/PluginService', () => mockPluginService);
const { BaseLLM } = require('@langchain/openai');
const { Calculator } = require('@langchain/community/tools/calculator');
const { User } = require('~/db/models');
const PluginService = require('~/server/services/PluginService');
const { validateTools, loadTools, loadToolWithAuth } = require('./handleTools');
const { StructuredSD, availableTools, DALLE3 } = require('../');
describe('Tool Handlers', () => {
let mongoServer;
let fakeUser;
const pluginKey = 'dalle';
const pluginKey2 = 'wolfram';
@@ -29,9 +43,7 @@ describe('Tool Handlers', () => {
const authConfigs = mainPlugin.authConfig;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
mockUser.save.mockResolvedValue(undefined);
const userAuthValues = {};
mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => {
@@ -46,7 +58,7 @@ describe('Tool Handlers', () => {
},
);
fakeUser = new User({
fakeUser = await mockModels.User.createUser({
name: 'Fake User',
username: 'fakeuser',
email: 'fakeuser@example.com',
@@ -72,36 +84,9 @@ describe('Tool Handlers', () => {
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
// Clear mocks but not the database since we need the user to persist
jest.clearAllMocks();
// Reset the mock implementations
const userAuthValues = {};
mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => {
return userAuthValues[`${userId}-${authField}`];
});
mockPluginService.updateUserPluginAuth.mockImplementation(
(userId, authField, _pluginKey, credential) => {
const fields = authField.split('||');
fields.forEach((field) => {
userAuthValues[`${userId}-${field}`] = credential;
});
},
);
// Re-add the auth configs for the user
await mockUser.findByIdAndDelete(fakeUser._id);
for (const authConfig of authConfigs) {
await PluginService.updateUserPluginAuth(
fakeUser._id,
authConfig.authField,
pluginKey,
mockCredential,
);
await PluginService.deleteUserPluginAuth(fakeUser._id, authConfig.authField);
}
});

View File

@@ -1,8 +1,7 @@
const { logger } = require('@librechat/data-schemas');
const { isEnabled, math } = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, math, removePorts } = require('~/server/utils');
const { deleteAllUserSessions } = require('~/models');
const { removePorts } = require('~/server/utils');
const getLogStores = require('./getLogStores');
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};

View File

@@ -1,28 +1,69 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const banViolation = require('./banViolation');
// Mock deleteAllUserSessions since we're testing ban logic, not session deletion
jest.mock('~/models', () => ({
...jest.requireActual('~/models'),
deleteAllUserSessions: jest.fn().mockResolvedValue(true),
const mockModels = {
Session: {
deleteAllUserSessions: jest.fn(),
},
};
jest.mock('~/db/connect', () => {
return {
connectDb: jest.fn(),
get models() {
return mockModels;
},
};
});
jest.mock('~/server/utils', () => ({
isEnabled: jest.fn(() => true), // default to false, override per test if needed
math: jest.fn(() => 20), // default to false, override per test if needed
removePorts: jest.fn(),
}));
jest.mock('keyv');
// jest.mock('../models/Session');
// Mocking the getLogStores function
jest.mock('./getLogStores', () => {
return jest.fn().mockImplementation(() => {
const EventEmitter = require('events');
const { CacheKeys } = require('librechat-data-provider');
const math = require('../server/utils/math');
const mockGet = jest.fn();
const mockSet = jest.fn();
class KeyvMongo extends EventEmitter {
constructor(url = 'mongodb://127.0.0.1:27017', options) {
super();
this.ttlSupport = false;
url = url ?? {};
if (typeof url === 'string') {
url = { url };
}
if (url.uri) {
url = { url: url.uri, ...url };
}
this.opts = {
url,
collection: 'keyv',
...url,
...options,
};
}
get = mockGet;
set = mockSet;
}
return new KeyvMongo('', {
namespace: CacheKeys.BANS,
ttl: math(process.env.BAN_DURATION, 7200000),
});
});
});
describe('banViolation', () => {
let mongoServer;
let req, res, errorMessage;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(() => {
req = {
ip: '127.0.0.1',
@@ -35,7 +76,7 @@ describe('banViolation', () => {
};
errorMessage = {
type: 'someViolation',
user_id: new mongoose.Types.ObjectId().toString(), // Use valid ObjectId
user_id: '12345',
prev_count: 0,
violation_count: 0,
};

View File

@@ -1,33 +0,0 @@
const fs = require('fs');
const { math, isEnabled } = require('@librechat/api');
// To ensure that different deployments do not interfere with each other's cache, we use a prefix for the Redis keys.
// This prefix is usually the deployment ID, which is often passed to the container or pod as an env var.
// Set REDIS_KEY_PREFIX_VAR to the env var that contains the deployment ID.
const REDIS_KEY_PREFIX_VAR = process.env.REDIS_KEY_PREFIX_VAR;
const REDIS_KEY_PREFIX = process.env.REDIS_KEY_PREFIX;
if (REDIS_KEY_PREFIX_VAR && REDIS_KEY_PREFIX) {
throw new Error('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
}
const USE_REDIS = isEnabled(process.env.USE_REDIS);
if (USE_REDIS && !process.env.REDIS_URI) {
throw new Error('USE_REDIS is enabled but REDIS_URI is not set.');
}
const cacheConfig = {
USE_REDIS,
REDIS_URI: process.env.REDIS_URI,
REDIS_USERNAME: process.env.REDIS_USERNAME,
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
REDIS_CA: process.env.REDIS_CA ? fs.readFileSync(process.env.REDIS_CA, 'utf8') : null,
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR] || REDIS_KEY_PREFIX || '',
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
CI: isEnabled(process.env.CI),
DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
BAN_DURATION: math(process.env.BAN_DURATION, 7200000), // 2 hours
};
module.exports = { cacheConfig };

View File

@@ -1,108 +0,0 @@
const fs = require('fs');
describe('cacheConfig', () => {
let originalEnv;
let originalReadFileSync;
beforeEach(() => {
originalEnv = { ...process.env };
originalReadFileSync = fs.readFileSync;
// Clear all related env vars first
delete process.env.REDIS_URI;
delete process.env.REDIS_CA;
delete process.env.REDIS_KEY_PREFIX_VAR;
delete process.env.REDIS_KEY_PREFIX;
delete process.env.USE_REDIS;
// Clear require cache
jest.resetModules();
});
afterEach(() => {
process.env = originalEnv;
fs.readFileSync = originalReadFileSync;
jest.resetModules();
});
describe('REDIS_KEY_PREFIX validation and resolution', () => {
test('should throw error when both REDIS_KEY_PREFIX_VAR and REDIS_KEY_PREFIX are set', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.REDIS_KEY_PREFIX = 'manual-prefix';
expect(() => {
require('./cacheConfig');
}).toThrow('Only either REDIS_KEY_PREFIX_VAR or REDIS_KEY_PREFIX can be set.');
});
test('should resolve REDIS_KEY_PREFIX from variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'DEPLOYMENT_ID';
process.env.DEPLOYMENT_ID = 'test-deployment-123';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('test-deployment-123');
});
test('should use direct REDIS_KEY_PREFIX value', () => {
process.env.REDIS_KEY_PREFIX = 'direct-prefix';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('direct-prefix');
});
test('should default to empty string when no prefix is configured', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle empty variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'EMPTY_VAR';
process.env.EMPTY_VAR = '';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
test('should handle undefined variable reference', () => {
process.env.REDIS_KEY_PREFIX_VAR = 'UNDEFINED_VAR';
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_KEY_PREFIX).toBe('');
});
});
describe('USE_REDIS and REDIS_URI validation', () => {
test('should throw error when USE_REDIS is enabled but REDIS_URI is not set', () => {
process.env.USE_REDIS = 'true';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
test('should not throw error when USE_REDIS is enabled and REDIS_URI is set', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
expect(() => {
require('./cacheConfig');
}).not.toThrow();
});
test('should handle empty REDIS_URI when USE_REDIS is enabled', () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = '';
expect(() => {
require('./cacheConfig');
}).toThrow('USE_REDIS is enabled but REDIS_URI is not set.');
});
});
describe('REDIS_CA file reading', () => {
test('should be null when REDIS_CA is not set', () => {
const { cacheConfig } = require('./cacheConfig');
expect(cacheConfig.REDIS_CA).toBeNull();
});
});
});

View File

@@ -1,66 +0,0 @@
const KeyvRedis = require('@keyv/redis').default;
const { Keyv } = require('keyv');
const { cacheConfig } = require('./cacheConfig');
const { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } = require('./redisClients');
const { Time } = require('librechat-data-provider');
const ConnectRedis = require('connect-redis').default;
const MemoryStore = require('memorystore')(require('express-session'));
const { violationFile } = require('./keyvFiles');
const { RedisStore } = require('rate-limit-redis');
/**
* Creates a cache instance using Redis or a fallback store. Suitable for general caching needs.
* @param {string} namespace - The cache namespace.
* @param {number} [ttl] - Time to live for cache entries.
* @param {object} [fallbackStore] - Optional fallback store if Redis is not used.
* @returns {Keyv} Cache instance.
*/
const standardCache = (namespace, ttl = undefined, fallbackStore = undefined) => {
if (cacheConfig.USE_REDIS) {
const keyvRedis = new KeyvRedis(keyvRedisClient);
const cache = new Keyv(keyvRedis, { namespace, ttl });
keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
return cache;
}
if (fallbackStore) return new Keyv({ store: fallbackStore, namespace, ttl });
return new Keyv({ namespace, ttl });
};
/**
* Creates a cache instance for storing violation data.
* Uses a file-based fallback store if Redis is not enabled.
* @param {string} namespace - The cache namespace for violations.
* @param {number} [ttl] - Time to live for cache entries.
* @returns {Keyv} Cache instance for violations.
*/
const violationCache = (namespace, ttl = undefined) => {
return standardCache(`violations:${namespace}`, ttl, violationFile);
};
/**
* Creates a session cache instance using Redis or in-memory store.
* @param {string} namespace - The session namespace.
* @param {number} [ttl] - Time to live for session entries.
* @returns {MemoryStore | ConnectRedis} Session store instance.
*/
const sessionCache = (namespace, ttl = undefined) => {
namespace = namespace.endsWith(':') ? namespace : `${namespace}:`;
if (!cacheConfig.USE_REDIS) return new MemoryStore({ ttl, checkPeriod: Time.ONE_DAY });
return new ConnectRedis({ client: ioredisClient, ttl, prefix: namespace });
};
/**
* Creates a rate limiter cache using Redis.
* @param {string} prefix - The key prefix for rate limiting.
* @returns {RedisStore|undefined} RedisStore instance or undefined if Redis is not used.
*/
const limiterCache = (prefix) => {
if (!prefix) throw new Error('prefix is required');
if (!cacheConfig.USE_REDIS) return undefined;
prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;
return new RedisStore({ sendCommand, prefix });
};
const sendCommand = (...args) => ioredisClient?.call(...args);
module.exports = { standardCache, sessionCache, violationCache, limiterCache };

View File

@@ -1,272 +0,0 @@
const { Time } = require('librechat-data-provider');
// Mock dependencies first
const mockKeyvRedis = {
namespace: '',
keyPrefixSeparator: '',
};
const mockKeyv = jest.fn().mockReturnValue({ mock: 'keyv' });
const mockConnectRedis = jest.fn().mockReturnValue({ mock: 'connectRedis' });
const mockMemoryStore = jest.fn().mockReturnValue({ mock: 'memoryStore' });
const mockRedisStore = jest.fn().mockReturnValue({ mock: 'redisStore' });
const mockIoredisClient = {
call: jest.fn(),
};
const mockKeyvRedisClient = {};
const mockViolationFile = {};
// Mock modules before requiring the main module
jest.mock('@keyv/redis', () => ({
default: jest.fn().mockImplementation(() => mockKeyvRedis),
}));
jest.mock('keyv', () => ({
Keyv: mockKeyv,
}));
jest.mock('./cacheConfig', () => ({
cacheConfig: {
USE_REDIS: false,
REDIS_KEY_PREFIX: 'test',
},
}));
jest.mock('./redisClients', () => ({
keyvRedisClient: mockKeyvRedisClient,
ioredisClient: mockIoredisClient,
GLOBAL_PREFIX_SEPARATOR: '::',
}));
jest.mock('./keyvFiles', () => ({
violationFile: mockViolationFile,
}));
jest.mock('connect-redis', () => ({
default: mockConnectRedis,
}));
jest.mock('memorystore', () => jest.fn(() => mockMemoryStore));
jest.mock('rate-limit-redis', () => ({
RedisStore: mockRedisStore,
}));
// Import after mocking
const { standardCache, sessionCache, violationCache, limiterCache } = require('./cacheFactory');
const { cacheConfig } = require('./cacheConfig');
describe('cacheFactory', () => {
beforeEach(() => {
jest.clearAllMocks();
// Reset cache config mock
cacheConfig.USE_REDIS = false;
cacheConfig.REDIS_KEY_PREFIX = 'test';
});
describe('redisCache', () => {
it('should create Redis cache when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(require('@keyv/redis').default).toHaveBeenCalledWith(mockKeyvRedisClient);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl });
expect(mockKeyvRedis.namespace).toBe(cacheConfig.REDIS_KEY_PREFIX);
expect(mockKeyvRedis.keyPrefixSeparator).toBe('::');
});
it('should create Redis cache with undefined ttl when not provided', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'test-namespace';
standardCache(namespace);
expect(mockKeyv).toHaveBeenCalledWith(mockKeyvRedis, { namespace, ttl: undefined });
});
it('should use fallback store when USE_REDIS is false and fallbackStore is provided', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
const fallbackStore = { some: 'store' };
standardCache(namespace, ttl, fallbackStore);
expect(mockKeyv).toHaveBeenCalledWith({ store: fallbackStore, namespace, ttl });
});
it('should create default Keyv instance when USE_REDIS is false and no fallbackStore', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'test-namespace';
const ttl = 3600;
standardCache(namespace, ttl);
expect(mockKeyv).toHaveBeenCalledWith({ namespace, ttl });
});
it('should handle namespace and ttl as undefined', () => {
cacheConfig.USE_REDIS = false;
standardCache();
expect(mockKeyv).toHaveBeenCalledWith({ namespace: undefined, ttl: undefined });
});
});
describe('violationCache', () => {
it('should create violation cache with prefixed namespace', () => {
const namespace = 'test-violations';
const ttl = 7200;
// We can't easily mock the internal redisCache call since it's in the same module
// But we can test that the function executes without throwing
expect(() => violationCache(namespace, ttl)).not.toThrow();
});
it('should create violation cache with undefined ttl', () => {
const namespace = 'test-violations';
violationCache(namespace);
// The function should call redisCache with violations: prefixed namespace
// Since we can't easily mock the internal redisCache call, we test the behavior
expect(() => violationCache(namespace)).not.toThrow();
});
it('should handle undefined namespace', () => {
expect(() => violationCache(undefined)).not.toThrow();
});
});
describe('sessionCache', () => {
it('should return MemoryStore when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockMemoryStore).toHaveBeenCalledWith({ ttl, checkPeriod: Time.ONE_DAY });
expect(result).toBe(mockMemoryStore());
});
it('should return ConnectRedis when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
const ttl = 86400;
const result = sessionCache(namespace, ttl);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl,
prefix: `${namespace}:`,
});
expect(result).toBe(mockConnectRedis());
});
it('should add colon to namespace if not present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should not add colon to namespace if already present', () => {
cacheConfig.USE_REDIS = true;
const namespace = 'sessions:';
sessionCache(namespace);
expect(mockConnectRedis).toHaveBeenCalledWith({
client: mockIoredisClient,
ttl: undefined,
prefix: 'sessions:',
});
});
it('should handle undefined ttl', () => {
cacheConfig.USE_REDIS = false;
const namespace = 'sessions';
sessionCache(namespace);
expect(mockMemoryStore).toHaveBeenCalledWith({
ttl: undefined,
checkPeriod: Time.ONE_DAY,
});
});
});
describe('limiterCache', () => {
it('should return undefined when USE_REDIS is false', () => {
cacheConfig.USE_REDIS = false;
const result = limiterCache('prefix');
expect(result).toBeUndefined();
});
it('should return RedisStore when USE_REDIS is true', () => {
cacheConfig.USE_REDIS = true;
const result = limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: `rate-limit:`,
});
expect(result).toBe(mockRedisStore());
});
it('should add colon to prefix if not present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should not add colon to prefix if already present', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit:');
expect(mockRedisStore).toHaveBeenCalledWith({
sendCommand: expect.any(Function),
prefix: 'rate-limit:',
});
});
it('should pass sendCommand function that calls ioredisClient.call', () => {
cacheConfig.USE_REDIS = true;
limiterCache('rate-limit');
const sendCommandCall = mockRedisStore.mock.calls[0][0];
const sendCommand = sendCommandCall.sendCommand;
// Test that sendCommand properly delegates to ioredisClient.call
const args = ['GET', 'test-key'];
sendCommand(...args);
expect(mockIoredisClient.call).toHaveBeenCalledWith(...args);
});
it('should handle undefined prefix', () => {
cacheConfig.USE_REDIS = true;
expect(() => limiterCache()).toThrow('prefix is required');
});
});
});

View File

@@ -1,52 +1,108 @@
const { cacheConfig } = require('./cacheConfig');
const { Keyv } = require('keyv');
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
const { logFile } = require('./keyvFiles');
const { logFile, violationFile } = require('./keyvFiles');
const { math, isEnabled } = require('~/server/utils');
const keyvRedis = require('./keyvRedis');
const keyvMongo = require('./keyvMongo');
const { standardCache, sessionCache, violationCache } = require('./cacheFactory');
const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE, CI } = process.env ?? {};
const duration = math(BAN_DURATION, 7200000);
const isRedisEnabled = isEnabled(USE_REDIS);
const debugMemoryCache = isEnabled(DEBUG_MEMORY_CACHE);
const createViolationInstance = (namespace) => {
const config = isRedisEnabled ? { store: keyvRedis } : { store: violationFile, namespace };
return new Keyv(config);
};
// Serve cache from memory so no need to clear it on startup/exit
const pending_req = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.PENDING_REQ });
const config = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
const roles = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ROLES });
const audioRuns = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
const messages = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.ONE_MINUTE })
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.ONE_MINUTE });
const flows = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.FLOWS, ttl: Time.ONE_MINUTE * 3 });
const tokenConfig = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
const genTitle = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
const s3ExpiryInterval = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });
const modelQueries = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
const abortKeys = isRedisEnabled
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
const openIdExchangedTokensCache = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.OPENID_EXCHANGED_TOKENS, ttl: Time.TEN_MINUTES });
const namespaces = {
[ViolationTypes.GENERAL]: new Keyv({ store: logFile, namespace: 'violations' }),
[ViolationTypes.LOGINS]: violationCache(ViolationTypes.LOGINS),
[ViolationTypes.CONCURRENT]: violationCache(ViolationTypes.CONCURRENT),
[ViolationTypes.NON_BROWSER]: violationCache(ViolationTypes.NON_BROWSER),
[ViolationTypes.MESSAGE_LIMIT]: violationCache(ViolationTypes.MESSAGE_LIMIT),
[ViolationTypes.REGISTRATIONS]: violationCache(ViolationTypes.REGISTRATIONS),
[ViolationTypes.TOKEN_BALANCE]: violationCache(ViolationTypes.TOKEN_BALANCE),
[ViolationTypes.TTS_LIMIT]: violationCache(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: violationCache(ViolationTypes.STT_LIMIT),
[ViolationTypes.CONVO_ACCESS]: violationCache(ViolationTypes.CONVO_ACCESS),
[ViolationTypes.TOOL_CALL_LIMIT]: violationCache(ViolationTypes.TOOL_CALL_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: violationCache(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.VERIFY_EMAIL_LIMIT]: violationCache(ViolationTypes.VERIFY_EMAIL_LIMIT),
[ViolationTypes.RESET_PASSWORD_LIMIT]: violationCache(ViolationTypes.RESET_PASSWORD_LIMIT),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: violationCache(ViolationTypes.ILLEGAL_MODEL_REQUEST),
[ViolationTypes.BAN]: new Keyv({
[CacheKeys.ROLES]: roles,
[CacheKeys.CONFIG_STORE]: config,
[CacheKeys.PENDING_REQ]: pending_req,
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({
store: keyvMongo,
namespace: CacheKeys.BANS,
ttl: cacheConfig.BAN_DURATION,
namespace: CacheKeys.ENCODED_DOMAINS,
ttl: 0,
}),
[CacheKeys.OPENID_SESSION]: sessionCache(CacheKeys.OPENID_SESSION),
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
[CacheKeys.MCP_TOOLS]: standardCache(CacheKeys.MCP_TOOLS),
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
[CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }),
[CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES),
[CacheKeys.TOKEN_CONFIG]: standardCache(CacheKeys.TOKEN_CONFIG, Time.THIRTY_MINUTES),
[CacheKeys.GEN_TITLE]: standardCache(CacheKeys.GEN_TITLE, Time.TWO_MINUTES),
[CacheKeys.S3_EXPIRY_INTERVAL]: standardCache(CacheKeys.S3_EXPIRY_INTERVAL, Time.THIRTY_MINUTES),
[CacheKeys.MODEL_QUERIES]: standardCache(CacheKeys.MODEL_QUERIES),
[CacheKeys.AUDIO_RUNS]: standardCache(CacheKeys.AUDIO_RUNS, Time.TEN_MINUTES),
[CacheKeys.MESSAGES]: standardCache(CacheKeys.MESSAGES, Time.ONE_MINUTE),
[CacheKeys.FLOWS]: standardCache(CacheKeys.FLOWS, Time.ONE_MINUTE * 3),
[CacheKeys.OPENID_EXCHANGED_TOKENS]: standardCache(
CacheKeys.OPENID_EXCHANGED_TOKENS,
Time.TEN_MINUTES,
general: new Keyv({ store: logFile, namespace: 'violations' }),
concurrent: createViolationInstance('concurrent'),
non_browser: createViolationInstance('non_browser'),
message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
registrations: createViolationInstance('registrations'),
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
[ViolationTypes.CONVO_ACCESS]: createViolationInstance(ViolationTypes.CONVO_ACCESS),
[ViolationTypes.TOOL_CALL_LIMIT]: createViolationInstance(ViolationTypes.TOOL_CALL_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
[ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
ViolationTypes.RESET_PASSWORD_LIMIT,
),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
ViolationTypes.ILLEGAL_MODEL_REQUEST,
),
logins: createViolationInstance('logins'),
[CacheKeys.ABORT_KEYS]: abortKeys,
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
[CacheKeys.GEN_TITLE]: genTitle,
[CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
[CacheKeys.MODEL_QUERIES]: modelQueries,
[CacheKeys.AUDIO_RUNS]: audioRuns,
[CacheKeys.MESSAGES]: messages,
[CacheKeys.FLOWS]: flows,
[CacheKeys.OPENID_EXCHANGED_TOKENS]: openIdExchangedTokensCache,
};
/**
@@ -55,10 +111,7 @@ const namespaces = {
*/
function getTTLStores() {
return Object.values(namespaces).filter(
(store) =>
store instanceof Keyv &&
parseInt(store.opts?.ttl ?? '0') > 0 &&
!store.opts?.store?.constructor?.name?.includes('Redis'), // Only include non-Redis stores
(store) => store instanceof Keyv && typeof store.opts?.ttl === 'number' && store.opts.ttl > 0,
);
}
@@ -94,18 +147,18 @@ async function clearExpiredFromCache(cache) {
if (data?.expires && data.expires <= expiryTime) {
const deleted = await cache.opts.store.delete(key);
if (!deleted) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
continue;
}
cleared++;
}
} catch (error) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
const deleted = await cache.opts.store.delete(key);
if (!deleted) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
continue;
}
@@ -114,7 +167,7 @@ async function clearExpiredFromCache(cache) {
}
if (cleared > 0) {
cacheConfig.DEBUG_MEMORY_CACHE &&
debugMemoryCache &&
console.log(
`[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`,
);
@@ -155,7 +208,7 @@ async function clearAllExpiredFromCache() {
}
}
if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
if (!isRedisEnabled && !isEnabled(CI)) {
/** @type {Set<NodeJS.Timeout>} */
const cleanupIntervals = new Set();
@@ -166,7 +219,7 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
cleanupIntervals.add(cleanup);
if (cacheConfig.DEBUG_MEMORY_CACHE) {
if (debugMemoryCache) {
const monitor = setInterval(() => {
const ttlStores = getTTLStores();
const memory = process.memoryUsage();
@@ -187,13 +240,13 @@ if (!cacheConfig.USE_REDIS && !cacheConfig.CI) {
}
const dispose = () => {
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Cleaning up and shutting down...');
debugMemoryCache && console.log('[Cache] Cleaning up and shutting down...');
cleanupIntervals.forEach((interval) => clearInterval(interval));
cleanupIntervals.clear();
// One final cleanup before exit
clearAllExpiredFromCache().then(() => {
cacheConfig.DEBUG_MEMORY_CACHE && console.log('[Cache] Final cleanup completed');
debugMemoryCache && console.log('[Cache] Final cleanup completed');
process.exit(0);
});
};

92
api/cache/ioredisClient.js vendored Normal file
View File

@@ -0,0 +1,92 @@
const fs = require('fs');
const Redis = require('ioredis');
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_MAX_LISTENERS } = process.env;
/** @type {import('ioredis').Redis | import('ioredis').Cluster} */
let ioredisClient;
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
var value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
ioredisClient = new Redis.Cluster(hosts, { redisOptions });
} else {
ioredisClient = new Redis(REDIS_URI, redisOptions);
}
ioredisClient.on('ready', () => {
logger.info('IoRedis connection ready');
});
ioredisClient.on('reconnecting', () => {
logger.info('IoRedis connection reconnecting');
});
ioredisClient.on('end', () => {
logger.info('IoRedis connection ended');
});
ioredisClient.on('close', () => {
logger.info('IoRedis connection closed');
});
ioredisClient.on('error', (err) => logger.error('IoRedis connection error:', err));
ioredisClient.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] IoRedis initialized for rate limiters. If you have issues, disable Redis or restart the server.',
);
} else {
logger.info('[Optional] IoRedis not initialized for rate limiters.');
}
module.exports = ioredisClient;

109
api/cache/keyvRedis.js vendored Normal file
View File

@@ -0,0 +1,109 @@
const fs = require('fs');
const ioredis = require('ioredis');
const KeyvRedis = require('@keyv/redis').default;
const { isEnabled } = require('~/server/utils');
const logger = require('~/config/winston');
const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } =
process.env;
let keyvRedis;
const redis_prefix = REDIS_KEY_PREFIX || '';
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
function mapURI(uri) {
const regex =
/^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/;
const match = uri.match(regex);
if (match) {
const { scheme, user, password, host, port } = match.groups;
return {
scheme: scheme || 'none',
user: user || null,
password: password || null,
host: host || null,
port: port || null,
};
} else {
const parts = uri.split(':');
if (parts.length === 2) {
return {
scheme: 'none',
user: null,
password: null,
host: parts[0],
port: parts[1],
};
}
return {
scheme: 'none',
user: null,
password: null,
host: uri,
port: null,
};
}
}
if (REDIS_URI && isEnabled(USE_REDIS)) {
let redisOptions = null;
/** @type {import('@keyv/redis').KeyvRedisOptions} */
let keyvOpts = {
useRedisSets: false,
keyPrefix: redis_prefix,
};
if (REDIS_CA) {
const ca = fs.readFileSync(REDIS_CA);
redisOptions = { tls: { ca } };
}
if (isEnabled(USE_REDIS_CLUSTER)) {
const hosts = REDIS_URI.split(',').map((item) => {
var value = mapURI(item);
return {
host: value.host,
port: value.port,
};
});
const cluster = new ioredis.Cluster(hosts, { redisOptions });
keyvRedis = new KeyvRedis(cluster, keyvOpts);
} else {
keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts);
}
const pingInterval = setInterval(
() => {
logger.debug('KeyvRedis ping');
keyvRedis.client.ping().catch((err) => logger.error('Redis keep-alive ping failed:', err));
},
5 * 60 * 1000,
);
keyvRedis.on('ready', () => {
logger.info('KeyvRedis connection ready');
});
keyvRedis.on('reconnecting', () => {
logger.info('KeyvRedis connection reconnecting');
});
keyvRedis.on('end', () => {
logger.info('KeyvRedis connection ended');
});
keyvRedis.on('close', () => {
clearInterval(pingInterval);
logger.info('KeyvRedis connection closed');
});
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
keyvRedis.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.',
);
} else {
logger.info('[Optional] Redis not initialized.');
}
module.exports = keyvRedis;

View File

@@ -1,5 +1,4 @@
const { isEnabled } = require('~/server/utils');
const { ViolationTypes } = require('librechat-data-provider');
const getLogStores = require('./getLogStores');
const banViolation = require('./banViolation');
@@ -10,14 +9,14 @@ const banViolation = require('./banViolation');
* @param {Object} res - Express response object.
* @param {string} type - The type of violation.
* @param {Object} errorMessage - The error message to log.
* @param {number | string} [score=1] - The severity of the violation. Defaults to 1
* @param {number} [score=1] - The severity of the violation. Defaults to 1
*/
const logViolation = async (req, res, type, errorMessage, score = 1) => {
const userId = req.user?.id ?? req.user?._id;
if (!userId) {
return;
}
const logs = getLogStores(ViolationTypes.GENERAL);
const logs = getLogStores('general');
const violationLogs = getLogStores(type);
const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;

View File

@@ -1,57 +0,0 @@
const IoRedis = require('ioredis');
const { cacheConfig } = require('./cacheConfig');
const { createClient, createCluster } = require('@keyv/redis');
const GLOBAL_PREFIX_SEPARATOR = '::';
const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri));
const username = urls?.[0].username || cacheConfig.REDIS_USERNAME;
const password = urls?.[0].password || cacheConfig.REDIS_PASSWORD;
const ca = cacheConfig.REDIS_CA;
/** @type {import('ioredis').Redis | import('ioredis').Cluster | null} */
let ioredisClient = null;
if (cacheConfig.USE_REDIS) {
const redisOptions = {
username: username,
password: password,
tls: ca ? { ca } : undefined,
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
};
ioredisClient =
urls.length === 1
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
: new IoRedis.Cluster(cacheConfig.REDIS_URI, { redisOptions });
// Pinging the Redis server every 5 minutes to keep the connection alive
const pingInterval = setInterval(() => ioredisClient.ping(), 5 * 60 * 1000);
ioredisClient.on('close', () => clearInterval(pingInterval));
ioredisClient.on('end', () => clearInterval(pingInterval));
}
/** @type {import('@keyv/redis').RedisClient | import('@keyv/redis').RedisCluster | null} */
let keyvRedisClient = null;
if (cacheConfig.USE_REDIS) {
// ** WARNING ** Keyv Redis client does not support Prefix like ioredis above.
// The prefix feature will be handled by the Keyv-Redis store in cacheFactory.js
const redisOptions = { username, password, socket: { tls: ca != null, ca } };
keyvRedisClient =
urls.length === 1
? createClient({ url: cacheConfig.REDIS_URI, ...redisOptions })
: createCluster({
rootNodes: cacheConfig.REDIS_URI.split(',').map((url) => ({ url })),
defaults: redisOptions,
});
keyvRedisClient.setMaxListeners(cacheConfig.REDIS_MAX_LISTENERS);
// Pinging the Redis server every 5 minutes to keep the connection alive
const keyvPingInterval = setInterval(() => keyvRedisClient.ping(), 5 * 60 * 1000);
keyvRedisClient.on('disconnect', () => clearInterval(keyvPingInterval));
keyvRedisClient.on('end', () => clearInterval(keyvPingInterval));
}
module.exports = { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR };

View File

@@ -1,6 +1,7 @@
const axios = require('axios');
const { EventSource } = require('eventsource');
const { Time } = require('librechat-data-provider');
const { MCPManager, FlowStateManager } = require('@librechat/api');
const { Time, CacheKeys } = require('librechat-data-provider');
const { MCPManager, FlowStateManager } = require('librechat-mcp');
const logger = require('./winston');
global.EventSource = EventSource;
@@ -15,7 +16,7 @@ let flowManager = null;
*/
function getMCPManager(userId) {
if (!mcpManager) {
mcpManager = MCPManager.getInstance();
mcpManager = MCPManager.getInstance(logger);
} else {
mcpManager.checkIdleConnections(userId);
}
@@ -30,13 +31,66 @@ function getFlowStateManager(flowsCache) {
if (!flowManager) {
flowManager = new FlowStateManager(flowsCache, {
ttl: Time.ONE_MINUTE * 3,
logger,
});
}
return flowManager;
}
/**
* Sends message data in Server Sent Events format.
* @param {ServerResponse} res - The server response.
* @param {{ data: string | Record<string, unknown>, event?: string }} event - The message event.
* @param {string} event.event - The type of event.
* @param {string} event.data - The message to be sent.
*/
const sendEvent = (res, event) => {
if (typeof event.data === 'string' && event.data.length === 0) {
return;
}
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
};
/**
* Creates and configures an Axios instance with optional proxy settings.
*
* @typedef {import('axios').AxiosInstance} AxiosInstance
* @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig
*
* @returns {AxiosInstance} A configured Axios instance
* @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL
*/
function createAxiosInstance() {
const instance = axios.create();
if (process.env.proxy) {
try {
const url = new URL(process.env.proxy);
/** @type {AxiosProxyConfig} */
const proxyConfig = {
host: url.hostname.replace(/^\[|\]$/g, ''),
protocol: url.protocol.replace(':', ''),
};
if (url.port) {
proxyConfig.port = parseInt(url.port, 10);
}
instance.defaults.proxy = proxyConfig;
} catch (error) {
console.error('Error parsing proxy URL:', error);
throw new Error(`Invalid proxy URL: ${process.env.proxy}`);
}
}
return instance;
}
module.exports = {
logger,
sendEvent,
getMCPManager,
createAxiosInstance,
getFlowStateManager,
};

View File

@@ -1,6 +1,7 @@
import axios from 'axios';
import { createAxiosInstance } from './axios';
const axios = require('axios');
const { createAxiosInstance } = require('./index');
// Mock axios
jest.mock('axios', () => ({
interceptors: {
request: { use: jest.fn(), eject: jest.fn() },
@@ -19,13 +20,7 @@ jest.mock('axios', () => ({
post: jest.fn().mockResolvedValue({ data: {} }),
put: jest.fn().mockResolvedValue({ data: {} }),
delete: jest.fn().mockResolvedValue({ data: {} }),
reset: jest.fn().mockImplementation(function (this: {
get: jest.Mock;
post: jest.Mock;
put: jest.Mock;
delete: jest.Mock;
create: jest.Mock;
}) {
reset: jest.fn().mockImplementation(function () {
this.get.mockClear();
this.post.mockClear();
this.put.mockClear();

View File

@@ -1,11 +1,8 @@
const mongoose = require('mongoose');
const { MeiliSearch } = require('meilisearch');
const { logger } = require('@librechat/data-schemas');
const { FlowStateManager } = require('@librechat/api');
const { CacheKeys } = require('librechat-data-provider');
const { isEnabled } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const Conversation = mongoose.models.Conversation;
const Message = mongoose.models.Message;
@@ -31,123 +28,43 @@ class MeiliSearchClient {
}
}
/**
* Performs the actual sync operations for messages and conversations
*/
async function performSync() {
const client = MeiliSearchClient.getInstance();
const { status } = await client.health();
if (status !== 'available') {
throw new Error('Meilisearch not available');
}
if (indexingDisabled === true) {
logger.info('[indexSync] Indexing is disabled, skipping...');
return { messagesSync: false, convosSync: false };
}
let messagesSync = false;
let convosSync = false;
// Check if we need to sync messages
const messageProgress = await Message.getSyncProgress();
if (!messageProgress.isComplete) {
logger.info(
`[indexSync] Messages need syncing: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments} indexed`,
);
// Check if we should do a full sync or incremental
const messageCount = await Message.countDocuments();
const messagesIndexed = messageProgress.totalProcessed;
const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10);
if (messageCount - messagesIndexed > syncThreshold) {
logger.info('[indexSync] Starting full message sync due to large difference');
await Message.syncWithMeili();
messagesSync = true;
} else if (messageCount !== messagesIndexed) {
logger.warn('[indexSync] Messages out of sync, performing incremental sync');
await Message.syncWithMeili();
messagesSync = true;
}
} else {
logger.info(
`[indexSync] Messages are fully synced: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments}`,
);
}
// Check if we need to sync conversations
const convoProgress = await Conversation.getSyncProgress();
if (!convoProgress.isComplete) {
logger.info(
`[indexSync] Conversations need syncing: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments} indexed`,
);
const convoCount = await Conversation.countDocuments();
const convosIndexed = convoProgress.totalProcessed;
const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10);
if (convoCount - convosIndexed > syncThreshold) {
logger.info('[indexSync] Starting full conversation sync due to large difference');
await Conversation.syncWithMeili();
convosSync = true;
} else if (convoCount !== convosIndexed) {
logger.warn('[indexSync] Convos out of sync, performing incremental sync');
await Conversation.syncWithMeili();
convosSync = true;
}
} else {
logger.info(
`[indexSync] Conversations are fully synced: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments}`,
);
}
return { messagesSync, convosSync };
}
/**
* Main index sync function that uses FlowStateManager to prevent concurrent execution
*/
async function indexSync() {
if (!searchEnabled) {
return;
}
logger.info('[indexSync] Starting index synchronization check...');
try {
// Get or create FlowStateManager instance
const flowsCache = getLogStores(CacheKeys.FLOWS);
if (!flowsCache) {
logger.warn('[indexSync] Flows cache not available, falling back to direct sync');
return await performSync();
const client = MeiliSearchClient.getInstance();
const { status } = await client.health();
if (status !== 'available') {
throw new Error('Meilisearch not available');
}
const flowManager = new FlowStateManager(flowsCache, {
ttl: 60000 * 10, // 10 minutes TTL for sync operations
});
// Use a unique flow ID for the sync operation
const flowId = 'meili-index-sync';
const flowType = 'MEILI_SYNC';
// This will only execute the handler if no other instance is running the sync
const result = await flowManager.createFlowWithHandler(flowId, flowType, performSync);
if (result.messagesSync || result.convosSync) {
logger.info('[indexSync] Sync completed successfully');
} else {
logger.debug('[indexSync] No sync was needed');
}
return result;
} catch (err) {
if (err.message.includes('flow already exists')) {
logger.info('[indexSync] Sync already running on another instance');
if (indexingDisabled === true) {
logger.info('[indexSync] Indexing is disabled, skipping...');
return;
}
const messageCount = await Message.countDocuments();
const convoCount = await Conversation.countDocuments();
const messages = await client.index('messages').getStats();
const convos = await client.index('convos').getStats();
const messagesIndexed = messages.numberOfDocuments;
const convosIndexed = convos.numberOfDocuments;
logger.debug(`[indexSync] There are ${messageCount} messages and ${messagesIndexed} indexed`);
logger.debug(`[indexSync] There are ${convoCount} convos and ${convosIndexed} indexed`);
if (messageCount !== messagesIndexed) {
logger.debug('[indexSync] Messages out of sync, indexing');
Message.syncWithMeili();
}
if (convoCount !== convosIndexed) {
logger.debug('[indexSync] Convos out of sync, indexing');
Conversation.syncWithMeili();
}
} catch (err) {
if (err.message.includes('not found')) {
logger.debug('[indexSync] Creating indices...');
currentTimeout = setTimeout(async () => {

View File

@@ -1,4 +1,5 @@
const { Action } = require('~/db/models');
const mongoose = require('mongoose');
const Action = require('~/db/models').Action;
/**
* Update an action with new data without overwriting existing properties,

View File

@@ -11,10 +11,10 @@ const {
removeAgentIdsFromProject,
removeAgentFromAllProjects,
} = require('./Project');
const { getCachedTools } = require('~/server/services/Config');
const getLogStores = require('~/cache/getLogStores');
const { getActions } = require('./Action');
const { Agent } = require('~/db/models');
const Agent = require('~/db/models').Agent;
/**
* Create an agent with the provided data.
@@ -56,12 +56,12 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter)
* @param {string} params.agent_id
* @param {string} params.endpoint
* @param {import('@librechat/agents').ClientOptions} [params.model_parameters]
* @returns {Promise<Agent|null>} The agent document as a plain object, or null if not found.
* @returns {Agent|null} The agent document as a plain object, or null if not found.
*/
const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => {
const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) => {
const { model, ...model_parameters } = _m;
/** @type {Record<string, FunctionTool>} */
const availableTools = await getCachedTools({ userId: req.user.id, includeGlobal: true });
const availableTools = req.app.locals.availableTools;
/** @type {TEphemeralAgent | null} */
const ephemeralAgent = req.body.ephemeralAgent;
const mcpServers = new Set(ephemeralAgent?.mcp);
@@ -70,9 +70,6 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
if (ephemeralAgent?.execute_code === true) {
tools.push(Tools.execute_code);
}
if (ephemeralAgent?.file_search === true) {
tools.push(Tools.file_search);
}
if (ephemeralAgent?.web_search === true) {
tools.push(Tools.web_search);
}
@@ -90,7 +87,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
}
const instructions = req.body.promptPrefix;
const result = {
return {
id: agent_id,
instructions,
provider: endpoint,
@@ -98,11 +95,6 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
model,
tools,
};
if (ephemeralAgent?.artifacts != null && ephemeralAgent.artifacts) {
result.artifacts = ephemeralAgent.artifacts;
}
return result;
};
/**
@@ -120,7 +112,7 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => {
return null;
}
if (agent_id === EPHEMERAL_AGENT_ID) {
return await loadEphemeralAgent({ req, agent_id, endpoint, model_parameters });
return loadEphemeralAgent({ req, agent_id, endpoint, model_parameters });
}
const agent = await getAgent({
id: agent_id,
@@ -179,6 +171,7 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul
'created_at',
'updated_at',
'__v',
'agent_ids',
'versions',
'actionsHash', // Exclude actionsHash from direct comparison
];
@@ -268,12 +261,11 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul
* @param {Object} [options] - Optional configuration object.
* @param {string} [options.updatingUserId] - The ID of the user performing the update (used for tracking non-author updates).
* @param {boolean} [options.forceVersion] - Force creation of a new version even if no fields changed.
* @param {boolean} [options.skipVersioning] - Skip version creation entirely (useful for isolated operations like sharing).
* @returns {Promise<Agent>} The updated or newly created agent document as a plain object.
* @throws {Error} If the update would create a duplicate version
*/
const updateAgent = async (searchParameter, updateData, options = {}) => {
const { updatingUserId = null, forceVersion = false, skipVersioning = false } = options;
const { updatingUserId = null, forceVersion = false } = options;
const mongoOptions = { new: true, upsert: false };
const currentAgent = await Agent.findOne(searchParameter);
@@ -310,8 +302,10 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
}
const shouldCreateVersion =
!skipVersioning &&
(forceVersion || Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet);
forceVersion ||
(versions &&
versions.length > 0 &&
(Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet));
if (shouldCreateVersion) {
const duplicateVersion = isDuplicateVersion(updateData, versionData, versions, actionsHash);
@@ -346,7 +340,7 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
versionEntry.updatedBy = new mongoose.Types.ObjectId(updatingUserId);
}
if (shouldCreateVersion) {
if (shouldCreateVersion || forceVersion) {
updateData.$push = {
...($push || {}),
versions: versionEntry,
@@ -557,10 +551,7 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds
delete updateQuery.author;
}
const updatedAgent = await updateAgent(updateQuery, updateOps, {
updatingUserId: user.id,
skipVersioning: true,
});
const updatedAgent = await updateAgent(updateQuery, updateOps, { updatingUserId: user.id });
if (updatedAgent) {
return updatedAgent;
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,5 @@
const { Assistant } = require('~/db/models');
const mongoose = require('mongoose');
const Assistant = require('~/db/models').Assistant;
/**
* Update an assistant with new data without overwriting existing properties,

View File

@@ -1,5 +1,7 @@
const mongoose = require('mongoose');
const { logger } = require('@librechat/data-schemas');
const { Banner } = require('~/db/models');
const Banner = require('~/db/models').Banner;
/**
* Retrieves the current active banner.

View File

@@ -1,8 +1,8 @@
const mongoose = require('mongoose');
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getMessages, deleteMessages } = require('./Message');
const { Conversation } = require('~/db/models');
const Conversation = require('~/db/models').Conversation;
/**
* Searches for a conversation by conversationId and returns a lean document with only conversationId and user.
@@ -100,15 +100,10 @@ module.exports = {
update.conversationId = newConversationId;
}
if (req?.body?.isTemporary) {
try {
const customConfig = await getCustomConfig();
update.expiredAt = createTempChatExpirationDate(customConfig);
} catch (err) {
logger.error('Error creating temporary chat expiration date:', err);
logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
update.expiredAt = null;
}
if (req.body.isTemporary) {
const expiredAt = new Date();
expiredAt.setDate(expiredAt.getDate() + 30);
update.expiredAt = expiredAt;
} else {
update.expiredAt = null;
}

View File

@@ -1,5 +1,8 @@
const mongoose = require('mongoose');
const { logger } = require('@librechat/data-schemas');
const { ConversationTag, Conversation } = require('~/db/models');
const ConversationTag = require('~/db/models').ConversationTag;
const Conversation = require('~/db/models').Conversation;
/**
* Retrieves all conversation tags for a user.

View File

@@ -1,8 +1,8 @@
const mongoose = require('mongoose');
const { logger } = require('@librechat/data-schemas');
const { EToolResources, FileContext, Constants } = require('librechat-data-provider');
const { getProjectByName } = require('./Project');
const { getAgent } = require('./Agent');
const { File } = require('~/db/models');
const { EToolResources } = require('librechat-data-provider');
const File = require('~/db/models').File;
/**
* Finds a file by its file_id with additional query options.
@@ -14,119 +14,17 @@ const findFileById = async (file_id, options = {}) => {
return await File.findOne({ file_id, ...options }).lean();
};
/**
* Checks if a user has access to multiple files through a shared agent (batch operation)
* @param {string} userId - The user ID to check access for
* @param {string[]} fileIds - Array of file IDs to check
* @param {string} agentId - The agent ID that might grant access
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
*/
const hasAccessToFilesViaAgent = async (userId, fileIds, agentId) => {
const accessMap = new Map();
// Initialize all files as no access
fileIds.forEach((fileId) => accessMap.set(fileId, false));
try {
const agent = await getAgent({ id: agentId });
if (!agent) {
return accessMap;
}
// Check if user is the author - if so, grant access to all files
if (agent.author.toString() === userId) {
fileIds.forEach((fileId) => accessMap.set(fileId, true));
return accessMap;
}
// Check if agent is shared with the user via projects
if (!agent.projectIds || agent.projectIds.length === 0) {
return accessMap;
}
// Check if agent is in global project
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
if (
!globalProject ||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
) {
return accessMap;
}
// Agent is globally shared - check if it's collaborative
if (!agent.isCollaborative) {
return accessMap;
}
// Agent is globally shared and collaborative - check which files are actually attached
const attachedFileIds = new Set();
if (agent.tool_resources) {
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
}
}
}
// Grant access only to files that are attached to this agent
fileIds.forEach((fileId) => {
if (attachedFileIds.has(fileId)) {
accessMap.set(fileId, true);
}
});
return accessMap;
} catch (error) {
logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
return accessMap;
}
};
/**
* Retrieves files matching a given filter, sorted by the most recently updated.
* @param {Object} filter - The filter criteria to apply.
* @param {Object} [_sortOptions] - Optional sort parameters.
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
* Default excludes the 'text' field.
* @param {Object} [options] - Additional options
* @param {string} [options.userId] - User ID for access control
* @param {string} [options.agentId] - Agent ID that might grant access to files
* @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
*/
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => {
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
const sortOptions = { updatedAt: -1, ..._sortOptions };
const files = await File.find(filter).select(selectFields).sort(sortOptions).lean();
// If userId and agentId are provided, filter files based on access
if (options.userId && options.agentId) {
// Collect file IDs that need access check
const filesToCheck = [];
const ownedFiles = [];
for (const file of files) {
if (file.user && file.user.toString() === options.userId) {
ownedFiles.push(file);
} else {
filesToCheck.push(file);
}
}
if (filesToCheck.length === 0) {
return ownedFiles;
}
// Batch check access for all non-owned files
const fileIds = filesToCheck.map((f) => f.file_id);
const accessMap = await hasAccessToFilesViaAgent(options.userId, fileIds, options.agentId);
// Filter files based on access
const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
return [...ownedFiles, ...accessibleFiles];
}
return files;
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
};
/**
@@ -136,19 +34,19 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, option
* @returns {Promise<Array<MongoFile>>} Files that match the criteria
*/
const getToolFilesByIds = async (fileIds, toolResourceSet) => {
if (!fileIds || !fileIds.length || !toolResourceSet?.size) {
if (!fileIds || !fileIds.length) {
return [];
}
try {
const filter = {
file_id: { $in: fileIds },
$or: [],
};
if (toolResourceSet.has(EToolResources.ocr)) {
filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
if (toolResourceSet.size) {
filter.$or = [];
}
if (toolResourceSet.has(EToolResources.file_search)) {
filter.$or.push({ embedded: true });
}
@@ -280,5 +178,4 @@ module.exports = {
deleteFiles,
deleteFileByFilter,
batchUpdateFiles,
hasAccessToFilesViaAgent,
};

View File

@@ -1,264 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { fileSchema } = require('@librechat/data-schemas');
const { agentSchema } = require('@librechat/data-schemas');
const { projectSchema } = require('@librechat/data-schemas');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { getFiles, createFile } = require('./File');
const { getProjectByName } = require('./Project');
const { createAgent } = require('./Agent');
let File;
let Agent;
let Project;
describe('File Access Control', () => {
let mongoServer;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
File = mongoose.models.File || mongoose.model('File', fileSchema);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await File.deleteMany({});
await Agent.deleteMany({});
await Project.deleteMany({});
});
describe('hasAccessToFilesViaAgent', () => {
it('should efficiently check access for multiple files at once', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
// Create files
for (const fileId of fileIds) {
await createFile({
user: authorId,
file_id: fileId,
filename: `file-${fileId}.txt`,
filepath: `/uploads/${fileId}`,
});
}
// Create agent with only first two files attached
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [fileIds[0], fileIds[1]],
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for all files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have access only to the first two files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(false);
expect(accessMap.get(fileIds[3])).toBe(false);
});
it('should grant access to all files when user is the agent author', async () => {
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
// Create agent
await createAgent({
id: agentId,
name: 'Test Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
tool_resources: {
file_search: {
file_ids: [fileIds[0]], // Only one file attached
},
},
});
// Check access as the author
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);
// Author should have access to all files
expect(accessMap.get(fileIds[0])).toBe(true);
expect(accessMap.get(fileIds[1])).toBe(true);
expect(accessMap.get(fileIds[2])).toBe(true);
});
it('should handle non-existent agent gracefully', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const fileIds = [uuidv4(), uuidv4()];
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');
// Should have no access to any files
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
it('should deny access when agent is not collaborative', async () => {
const userId = new mongoose.Types.ObjectId().toString();
const authorId = new mongoose.Types.ObjectId().toString();
const agentId = uuidv4();
const fileIds = [uuidv4(), uuidv4()];
// Create agent with files but isCollaborative: false
await createAgent({
id: agentId,
name: 'Non-Collaborative Agent',
author: authorId,
model: 'gpt-4',
provider: 'openai',
isCollaborative: false,
tool_resources: {
file_search: {
file_ids: fileIds,
},
},
});
// Get or create global project
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
// Share agent globally
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
// Check access for files
const { hasAccessToFilesViaAgent } = require('./File');
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
// Should have no access to any files when isCollaborative is false
expect(accessMap.get(fileIds[0])).toBe(false);
expect(accessMap.get(fileIds[1])).toBe(false);
});
});
describe('getFiles with agent access control', () => {
test('should return files owned by user and files accessible through agent', async () => {
const authorId = new mongoose.Types.ObjectId();
const userId = new mongoose.Types.ObjectId();
const agentId = `agent_${uuidv4()}`;
const ownedFileId = `file_${uuidv4()}`;
const sharedFileId = `file_${uuidv4()}`;
const inaccessibleFileId = `file_${uuidv4()}`;
// Create/get global project using getProjectByName which will upsert
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);
// Create agent with shared file
await createAgent({
id: agentId,
name: 'Shared Agent',
provider: 'test',
model: 'test-model',
author: authorId,
projectIds: [globalProject._id],
isCollaborative: true,
tool_resources: {
file_search: {
file_ids: [sharedFileId],
},
},
});
// Create files
await createFile({
file_id: ownedFileId,
user: userId,
filename: 'owned.txt',
filepath: '/uploads/owned.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: sharedFileId,
user: authorId,
filename: 'shared.txt',
filepath: '/uploads/shared.txt',
type: 'text/plain',
bytes: 200,
embedded: true,
});
await createFile({
file_id: inaccessibleFileId,
user: authorId,
filename: 'inaccessible.txt',
filepath: '/uploads/inaccessible.txt',
type: 'text/plain',
bytes: 300,
});
// Get files with access control
const files = await getFiles(
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
null,
{ text: 0 },
{ userId: userId.toString(), agentId },
);
expect(files).toHaveLength(2);
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
expect(files.map((f) => f.file_id)).not.toContain(inaccessibleFileId);
});
test('should return all files when no userId/agentId provided', async () => {
const userId = new mongoose.Types.ObjectId();
const fileId1 = `file_${uuidv4()}`;
const fileId2 = `file_${uuidv4()}`;
await createFile({
file_id: fileId1,
user: userId,
filename: 'file1.txt',
filepath: '/uploads/file1.txt',
type: 'text/plain',
bytes: 100,
});
await createFile({
file_id: fileId2,
user: new mongoose.Types.ObjectId(),
filename: 'file2.txt',
filepath: '/uploads/file2.txt',
type: 'text/plain',
bytes: 200,
});
const files = await getFiles({ file_id: { $in: [fileId1, fileId2] } });
expect(files).toHaveLength(2);
});
});
});

View File

@@ -1,9 +1,7 @@
const { z } = require('zod');
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { Message } = require('~/db/models');
const Message = require('~/db/models').Message;
const idSchema = z.string().uuid();
/**
@@ -56,14 +54,9 @@ async function saveMessage(req, params, metadata) {
};
if (req?.body?.isTemporary) {
try {
const customConfig = await getCustomConfig();
update.expiredAt = createTempChatExpirationDate(customConfig);
} catch (err) {
logger.error('Error creating temporary chat expiration date:', err);
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
update.expiredAt = null;
}
const expiredAt = new Date();
expiredAt.setDate(expiredAt.getDate() + 30);
update.expiredAt = expiredAt;
} else {
update.expiredAt = null;
}
@@ -260,7 +253,6 @@ async function updateMessage(req, message, metadata) {
text: updatedMessage.text,
isCreatedByUser: updatedMessage.isCreatedByUser,
tokenCount: updatedMessage.tokenCount,
feedback: updatedMessage.feedback,
};
} catch (err) {
logger.error('Error updating message:', err);

View File

@@ -1,7 +1,37 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { v4: uuidv4 } = require('uuid');
const { messageSchema } = require('@librechat/data-schemas');
jest.mock('mongoose');
const mockFindQuery = {
select: jest.fn().mockReturnThis(),
sort: jest.fn().mockReturnThis(),
lean: jest.fn().mockReturnThis(),
deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }),
};
const mockSchema = {
findOneAndUpdate: jest.fn(),
updateOne: jest.fn(),
findOne: jest.fn(() => ({
lean: jest.fn(),
})),
find: jest.fn(() => mockFindQuery),
deleteMany: jest.fn(),
};
jest.mock('~/config/winston', () => ({
error: jest.fn(),
}));
const mockModels = {
Message: {
findOneAndUpdate: mockSchema.findOneAndUpdate,
updateOne: mockSchema.updateOne,
findOne: mockSchema.findOne,
find: mockSchema.find,
deleteMany: mockSchema.deleteMany,
},
};
const {
saveMessage,
@@ -10,102 +40,77 @@ const {
deleteMessages,
updateMessageText,
deleteMessagesSince,
} = require('./Message');
/**
* @type {import('mongoose').Model<import('@librechat/data-schemas').IMessage>}
*/
let Message;
} = require('~/models/Message');
describe('Message Operations', () => {
let mongoServer;
let mockReq;
let mockMessageData;
let mockMessage;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
// Clear database
await Message.deleteMany({});
beforeEach(() => {
jest.clearAllMocks();
mockReq = {
user: { id: 'user123' },
};
mockMessageData = {
mockMessage = {
messageId: 'msg123',
conversationId: uuidv4(),
text: 'Hello, world!',
user: 'user123',
};
mockSchema.findOneAndUpdate.mockResolvedValue({
toObject: () => mockMessage,
});
});
describe('saveMessage', () => {
it('should save a message for an authenticated user', async () => {
const result = await saveMessage(mockReq, mockMessageData);
expect(result.messageId).toBe('msg123');
expect(result.user).toBe('user123');
expect(result.text).toBe('Hello, world!');
// Verify the message was actually saved to the database
const savedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
expect(savedMessage).toBeTruthy();
expect(savedMessage.text).toBe('Hello, world!');
const result = await saveMessage(mockReq, mockMessage);
expect(result).toEqual(mockMessage);
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
{ messageId: 'msg123', user: 'user123' },
expect.objectContaining({ user: 'user123' }),
expect.any(Object),
);
});
it('should throw an error for unauthenticated user', async () => {
mockReq.user = null;
await expect(saveMessage(mockReq, mockMessageData)).rejects.toThrow('User not authenticated');
await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated');
});
it('should handle invalid conversation ID gracefully', async () => {
mockMessageData.conversationId = 'invalid-id';
const result = await saveMessage(mockReq, mockMessageData);
expect(result).toBeUndefined();
it('should throw an error for invalid conversation ID', async () => {
mockMessage.conversationId = 'invalid-id';
await expect(saveMessage(mockReq, mockMessage)).resolves.toBeUndefined();
});
});
describe('updateMessageText', () => {
it('should update message text for the authenticated user', async () => {
// First save a message
await saveMessage(mockReq, mockMessageData);
// Then update it
await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' });
// Verify the update
const updatedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
expect(updatedMessage.text).toBe('Updated text');
expect(mockSchema.updateOne).toHaveBeenCalledWith(
{ messageId: 'msg123', user: 'user123' },
{ text: 'Updated text' },
);
});
});
describe('updateMessage', () => {
it('should update a message for the authenticated user', async () => {
// First save a message
await saveMessage(mockReq, mockMessageData);
mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage);
const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' });
expect(result.messageId).toBe('msg123');
expect(result.text).toBe('Updated text');
// Verify in database
const updatedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' });
expect(updatedMessage.text).toBe('Updated text');
expect(result).toEqual(
expect.objectContaining({
messageId: 'msg123',
text: 'Hello, world!',
}),
);
});
it('should throw an error if message is not found', async () => {
mockSchema.findOneAndUpdate.mockResolvedValue(null);
await expect(
updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }),
).rejects.toThrow('Message not found or user not authorized.');
@@ -114,45 +119,19 @@ describe('Message Operations', () => {
describe('deleteMessagesSince', () => {
it('should delete messages only for the authenticated user', async () => {
const conversationId = uuidv4();
// Create multiple messages in the same conversation
const message1 = await saveMessage(mockReq, {
messageId: 'msg1',
conversationId,
text: 'First message',
user: 'user123',
mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() });
mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 });
const result = await deleteMessagesSince(mockReq, {
messageId: 'msg123',
conversationId: 'convo123',
});
const message2 = await saveMessage(mockReq, {
messageId: 'msg2',
conversationId,
text: 'Second message',
user: 'user123',
});
const message3 = await saveMessage(mockReq, {
messageId: 'msg3',
conversationId,
text: 'Third message',
user: 'user123',
});
// Delete messages since message2 (this should only delete messages created AFTER msg2)
await deleteMessagesSince(mockReq, {
messageId: 'msg2',
conversationId,
});
// Verify msg1 and msg2 remain, msg3 is deleted
const remainingMessages = await Message.find({ conversationId, user: 'user123' });
expect(remainingMessages).toHaveLength(2);
expect(remainingMessages.map((m) => m.messageId)).toContain('msg1');
expect(remainingMessages.map((m) => m.messageId)).toContain('msg2');
expect(remainingMessages.map((m) => m.messageId)).not.toContain('msg3');
expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' });
expect(mockSchema.find).not.toHaveBeenCalled();
expect(result).toBeUndefined();
});
it('should return undefined if no message is found', async () => {
mockSchema.findOne().lean.mockResolvedValueOnce(null);
const result = await deleteMessagesSince(mockReq, {
messageId: 'nonexistent',
conversationId: 'convo123',
@@ -163,71 +142,29 @@ describe('Message Operations', () => {
describe('getMessages', () => {
it('should retrieve messages with the correct filter', async () => {
const conversationId = uuidv4();
// Save some messages
await saveMessage(mockReq, {
messageId: 'msg1',
conversationId,
text: 'First message',
user: 'user123',
});
await saveMessage(mockReq, {
messageId: 'msg2',
conversationId,
text: 'Second message',
user: 'user123',
});
const messages = await getMessages({ conversationId });
expect(messages).toHaveLength(2);
expect(messages[0].text).toBe('First message');
expect(messages[1].text).toBe('Second message');
const filter = { conversationId: 'convo123' };
await getMessages(filter);
expect(mockSchema.find).toHaveBeenCalledWith(filter);
expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 });
expect(mockFindQuery.lean).toHaveBeenCalled();
});
});
describe('deleteMessages', () => {
it('should delete messages with the correct filter', async () => {
// Save some messages for different users
await saveMessage(mockReq, mockMessageData);
await saveMessage(
{ user: { id: 'user456' } },
{
messageId: 'msg456',
conversationId: uuidv4(),
text: 'Other user message',
user: 'user456',
},
);
await deleteMessages({ user: 'user123' });
// Verify only user123's messages were deleted
const user123Messages = await Message.find({ user: 'user123' });
const user456Messages = await Message.find({ user: 'user456' });
expect(user123Messages).toHaveLength(0);
expect(user456Messages).toHaveLength(1);
expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' });
});
});
describe('Conversation Hijacking Prevention', () => {
it("should not allow editing a message in another user's conversation", async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = uuidv4();
const victimConversationId = 'victim-convo-123';
const victimMessageId = 'victim-msg-123';
// First, save a message as the victim (but we'll try to edit as attacker)
const victimReq = { user: { id: 'victim123' } };
await saveMessage(victimReq, {
messageId: victimMessageId,
conversationId: victimConversationId,
text: 'Victim message',
user: 'victim123',
});
mockSchema.findOneAndUpdate.mockResolvedValue(null);
// Attacker tries to edit the victim's message
await expect(
updateMessage(attackerReq, {
messageId: victimMessageId,
@@ -236,82 +173,71 @@ describe('Message Operations', () => {
}),
).rejects.toThrow('Message not found or user not authorized.');
// Verify the original message is unchanged
const originalMessage = await Message.findOne({
messageId: victimMessageId,
user: 'victim123',
});
expect(originalMessage.text).toBe('Victim message');
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
{ messageId: victimMessageId, user: 'attacker123' },
expect.anything(),
expect.anything(),
);
});
it("should not allow deleting messages from another user's conversation", async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = uuidv4();
const victimConversationId = 'victim-convo-123';
const victimMessageId = 'victim-msg-123';
// Save a message as the victim
const victimReq = { user: { id: 'victim123' } };
await saveMessage(victimReq, {
messageId: victimMessageId,
conversationId: victimConversationId,
text: 'Victim message',
user: 'victim123',
});
// Attacker tries to delete from victim's conversation
mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user
const result = await deleteMessagesSince(attackerReq, {
messageId: victimMessageId,
conversationId: victimConversationId,
});
expect(result).toBeUndefined();
// Verify the victim's message still exists
const victimMessage = await Message.findOne({
expect(mockSchema.findOne).toHaveBeenCalledWith({
messageId: victimMessageId,
user: 'victim123',
user: 'attacker123',
});
expect(victimMessage).toBeTruthy();
expect(victimMessage.text).toBe('Victim message');
});
it("should not allow inserting a new message into another user's conversation", async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = uuidv4();
const victimConversationId = uuidv4(); // Use a valid UUID
// Attacker tries to save a message - this should succeed but with attacker's user ID
const result = await saveMessage(attackerReq, {
conversationId: victimConversationId,
text: 'Inserted malicious message',
messageId: 'new-msg-123',
user: 'attacker123',
});
await expect(
saveMessage(attackerReq, {
conversationId: victimConversationId,
text: 'Inserted malicious message',
messageId: 'new-msg-123',
}),
).resolves.not.toThrow(); // It should not throw an error
expect(result).toBeTruthy();
expect(result.user).toBe('attacker123');
// Verify the message was saved with the attacker's user ID, not as an anonymous message
const savedMessage = await Message.findOne({ messageId: 'new-msg-123' });
expect(savedMessage.user).toBe('attacker123');
expect(savedMessage.conversationId).toBe(victimConversationId);
// Check that the message was saved with the attacker's user ID
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
{ messageId: 'new-msg-123', user: 'attacker123' },
expect.objectContaining({
user: 'attacker123',
conversationId: victimConversationId,
}),
expect.anything(),
);
});
it('should allow retrieving messages from any conversation', async () => {
const victimConversationId = uuidv4();
const victimConversationId = 'victim-convo-123';
// Save a message in the victim's conversation
const victimReq = { user: { id: 'victim123' } };
await saveMessage(victimReq, {
messageId: 'victim-msg',
await getMessages({ conversationId: victimConversationId });
expect(mockSchema.find).toHaveBeenCalledWith({
conversationId: victimConversationId,
text: 'Victim message',
user: 'victim123',
});
// Anyone should be able to retrieve messages by conversation ID
const messages = await getMessages({ conversationId: victimConversationId });
expect(messages).toHaveLength(1);
expect(messages[0].text).toBe('Victim message');
mockSchema.find.mockReturnValueOnce({
select: jest.fn().mockReturnThis(),
sort: jest.fn().mockReturnThis(),
lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]),
});
const result = await getMessages({ conversationId: victimConversationId });
expect(result).toEqual([{ text: 'Test message' }]);
});
});
});

View File

@@ -1,5 +1,7 @@
const mongoose = require('mongoose');
const { logger } = require('@librechat/data-schemas');
const { Preset } = require('~/db/models');
const Preset = require('~/db/models').Preset;
const getPreset = async (user, presetId) => {
try {

View File

@@ -1,5 +1,7 @@
const mongoose = require('mongoose');
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
const { Project } = require('~/db/models');
const Project = require('~/db/models').Project;
/**
* Retrieve a project by ID and convert the found project document to a plain object.

View File

@@ -1,3 +1,4 @@
const mongoose = require('mongoose');
const { ObjectId } = require('mongodb');
const { logger } = require('@librechat/data-schemas');
const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider');
@@ -7,9 +8,11 @@ const {
removeGroupIdsFromProject,
removeGroupFromAllProjects,
} = require('./Project');
const { PromptGroup, Prompt } = require('~/db/models');
const { escapeRegExp } = require('~/server/utils');
const PromptGroup = require('~/db/models').PromptGroup;
const Prompt = require('~/db/models').Prompt;
/**
* Create a pipeline for the aggregation to get prompt groups
* @param {Object} query

View File

@@ -8,7 +8,8 @@ const {
} = require('librechat-data-provider');
const { logger } = require('@librechat/data-schemas');
const getLogStores = require('~/cache/getLogStores');
const { Role } = require('~/db/models');
const Role = require('~/db/models').Role;
/**
* Retrieve a role by name and convert the found role document to a plain object.
@@ -170,6 +171,35 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
}
}
/**
* Initialize default roles in the system.
* Creates the default roles (ADMIN, USER) if they don't exist in the database.
* Updates existing roles with new permission types if they're missing.
*
* @returns {Promise<void>}
*/
const initializeRoles = async function () {
for (const roleName of [SystemRoles.ADMIN, SystemRoles.USER]) {
let role = await Role.findOne({ name: roleName });
const defaultPerms = roleDefaults[roleName].permissions;
if (!role) {
// Create new role if it doesn't exist.
role = new Role(roleDefaults[roleName]);
} else {
// Ensure role.permissions is defined.
role.permissions = role.permissions || {};
// For each permission type in defaults, add it if missing.
for (const permType of Object.keys(defaultPerms)) {
if (role.permissions[permType] == null) {
role.permissions[permType] = defaultPerms[permType];
}
}
}
await role.save();
}
};
/**
* Migrates roles from old schema to new schema structure.
* This can be called directly to fix existing roles.
@@ -251,7 +281,8 @@ const migrateRoleSchema = async function (roleName) {
module.exports = {
getRoleByName,
initializeRoles,
updateRoleByName,
migrateRoleSchema,
updateAccessPermissions,
migrateRoleSchema,
};

View File

@@ -6,10 +6,10 @@ const {
roleDefaults,
PermissionTypes,
} = require('librechat-data-provider');
const { getRoleByName, updateAccessPermissions } = require('~/models/Role');
const { getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role');
const getLogStores = require('~/cache/getLogStores');
const { initializeRoles } = require('~/models');
const { Role } = require('~/db/models');
const Role = require('~/db/models').Role;
// Mock the cache
jest.mock('~/cache/getLogStores', () =>

349
api/models/Share.js Normal file
View File

@@ -0,0 +1,349 @@
const { nanoid } = require('nanoid');
const mongoose = require('mongoose');
const { Constants } = require('librechat-data-provider');
const { logger } = require('@librechat/data-schemas');
const { getMessages } = require('./Message');
const Conversation = require('~/db/models').Conversation;
const SharedLink = require('~/db/models').SharedLink;
class ShareServiceError extends Error {
constructor(message, code) {
super(message);
this.name = 'ShareServiceError';
this.code = code;
}
}
const memoizedAnonymizeId = (prefix) => {
const memo = new Map();
return (id) => {
if (!memo.has(id)) {
memo.set(id, `${prefix}_${nanoid()}`);
}
return memo.get(id);
};
};
const anonymizeConvoId = memoizedAnonymizeId('convo');
const anonymizeAssistantId = memoizedAnonymizeId('a');
const anonymizeMessageId = (id) =>
id === Constants.NO_PARENT ? id : memoizedAnonymizeId('msg')(id);
function anonymizeConvo(conversation) {
if (!conversation) {
return null;
}
const newConvo = { ...conversation };
if (newConvo.assistant_id) {
newConvo.assistant_id = anonymizeAssistantId(newConvo.assistant_id);
}
return newConvo;
}
function anonymizeMessages(messages, newConvoId) {
if (!Array.isArray(messages)) {
return [];
}
const idMap = new Map();
return messages.map((message) => {
const newMessageId = anonymizeMessageId(message.messageId);
idMap.set(message.messageId, newMessageId);
const anonymizedAttachments = message.attachments?.map((attachment) => {
return {
...attachment,
messageId: newMessageId,
conversationId: newConvoId,
};
});
return {
...message,
messageId: newMessageId,
parentMessageId:
idMap.get(message.parentMessageId) || anonymizeMessageId(message.parentMessageId),
conversationId: newConvoId,
model: message.model?.startsWith('asst_')
? anonymizeAssistantId(message.model)
: message.model,
attachments: anonymizedAttachments,
};
});
}
async function getSharedMessages(shareId) {
try {
const share = await SharedLink.findOne({ shareId, isPublic: true })
.populate({
path: 'messages',
select: '-_id -__v -user',
})
.select('-_id -__v -user')
.lean();
if (!share?.conversationId || !share.isPublic) {
return null;
}
const newConvoId = anonymizeConvoId(share.conversationId);
const result = {
...share,
conversationId: newConvoId,
messages: anonymizeMessages(share.messages, newConvoId),
};
return result;
} catch (error) {
logger.error('[getShare] Error getting share link', {
error: error.message,
shareId,
});
throw new ShareServiceError('Error getting share link', 'SHARE_FETCH_ERROR');
}
}
async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortDirection, search) {
try {
const query = { user, isPublic };
if (pageParam) {
if (sortDirection === 'desc') {
query[sortBy] = { $lt: pageParam };
} else {
query[sortBy] = { $gt: pageParam };
}
}
if (search && search.trim()) {
try {
const searchResults = await Conversation.meiliSearch(search);
if (!searchResults?.hits?.length) {
return {
links: [],
nextCursor: undefined,
hasNextPage: false,
};
}
const conversationIds = searchResults.hits.map((hit) => hit.conversationId);
query['conversationId'] = { $in: conversationIds };
} catch (searchError) {
logger.error('[getSharedLinks] Meilisearch error', {
error: searchError.message,
user,
});
return {
links: [],
nextCursor: undefined,
hasNextPage: false,
};
}
}
const sort = {};
sort[sortBy] = sortDirection === 'desc' ? -1 : 1;
if (Array.isArray(query.conversationId)) {
query.conversationId = { $in: query.conversationId };
}
const sharedLinks = await SharedLink.find(query)
.sort(sort)
.limit(pageSize + 1)
.select('-__v -user')
.lean();
const hasNextPage = sharedLinks.length > pageSize;
const links = sharedLinks.slice(0, pageSize);
const nextCursor = hasNextPage ? links[links.length - 1][sortBy] : undefined;
return {
links: links.map((link) => ({
shareId: link.shareId,
title: link?.title || 'Untitled',
isPublic: link.isPublic,
createdAt: link.createdAt,
conversationId: link.conversationId,
})),
nextCursor,
hasNextPage,
};
} catch (error) {
logger.error('[getSharedLinks] Error getting shares', {
error: error.message,
user,
});
throw new ShareServiceError('Error getting shares', 'SHARES_FETCH_ERROR');
}
}
async function deleteAllSharedLinks(user) {
try {
const result = await SharedLink.deleteMany({ user });
return {
message: 'All shared links deleted successfully',
deletedCount: result.deletedCount,
};
} catch (error) {
logger.error('[deleteAllSharedLinks] Error deleting shared links', {
error: error.message,
user,
});
throw new ShareServiceError('Error deleting shared links', 'BULK_DELETE_ERROR');
}
}
async function createSharedLink(user, conversationId) {
if (!user || !conversationId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
try {
const [existingShare, conversationMessages] = await Promise.all([
SharedLink.findOne({ conversationId, isPublic: true }).select('-_id -__v -user').lean(),
getMessages({ conversationId }),
]);
if (existingShare && existingShare.isPublic) {
throw new ShareServiceError('Share already exists', 'SHARE_EXISTS');
} else if (existingShare) {
await SharedLink.deleteOne({ conversationId });
}
const conversation = await Conversation.findOne({ conversationId }).lean();
const title = conversation?.title || 'Untitled';
const shareId = nanoid();
await SharedLink.create({
shareId,
conversationId,
messages: conversationMessages,
title,
user,
});
return { shareId, conversationId };
} catch (error) {
logger.error('[createSharedLink] Error creating shared link', {
error: error.message,
user,
conversationId,
});
throw new ShareServiceError('Error creating shared link', 'SHARE_CREATE_ERROR');
}
}
async function getSharedLink(user, conversationId) {
if (!user || !conversationId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
try {
const share = await SharedLink.findOne({ conversationId, user, isPublic: true })
.select('shareId -_id')
.lean();
if (!share) {
return { shareId: null, success: false };
}
return { shareId: share.shareId, success: true };
} catch (error) {
logger.error('[getSharedLink] Error getting shared link', {
error: error.message,
user,
conversationId,
});
throw new ShareServiceError('Error getting shared link', 'SHARE_FETCH_ERROR');
}
}
async function updateSharedLink(user, shareId) {
if (!user || !shareId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
try {
const share = await SharedLink.findOne({ shareId }).select('-_id -__v -user').lean();
if (!share) {
throw new ShareServiceError('Share not found', 'SHARE_NOT_FOUND');
}
const [updatedMessages] = await Promise.all([
getMessages({ conversationId: share.conversationId }),
]);
const newShareId = nanoid();
const update = {
messages: updatedMessages,
user,
shareId: newShareId,
};
const updatedShare = await SharedLink.findOneAndUpdate({ shareId, user }, update, {
new: true,
upsert: false,
runValidators: true,
}).lean();
if (!updatedShare) {
throw new ShareServiceError('Share update failed', 'SHARE_UPDATE_ERROR');
}
anonymizeConvo(updatedShare);
return { shareId: newShareId, conversationId: updatedShare.conversationId };
} catch (error) {
logger.error('[updateSharedLink] Error updating shared link', {
error: error.message,
user,
shareId,
});
throw new ShareServiceError(
error.code === 'SHARE_UPDATE_ERROR' ? error.message : 'Error updating shared link',
error.code || 'SHARE_UPDATE_ERROR',
);
}
}
async function deleteSharedLink(user, shareId) {
if (!user || !shareId) {
throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS');
}
try {
const result = await SharedLink.findOneAndDelete({ shareId, user }).lean();
if (!result) {
return null;
}
return {
success: true,
shareId,
message: 'Share deleted successfully',
};
} catch (error) {
logger.error('[deleteSharedLink] Error deleting shared link', {
error: error.message,
user,
shareId,
});
throw new ShareServiceError('Error deleting shared link', 'SHARE_DELETE_ERROR');
}
}
module.exports = {
getSharedLink,
getSharedLinks,
createSharedLink,
updateSharedLink,
deleteSharedLink,
getSharedMessages,
deleteAllSharedLinks,
};

42
api/models/Token.js Normal file
View File

@@ -0,0 +1,42 @@
const { findToken, updateToken, createToken } = require('~/models');
const { encryptV2 } = require('~/server/utils/crypto');
/**
* Handles the OAuth token by creating or updating the token.
* @param {object} fields
* @param {string} fields.userId - The user's ID.
* @param {string} fields.token - The full token to store.
* @param {string} fields.identifier - Unique, alternative identifier for the token.
* @param {number} fields.expiresIn - The number of seconds until the token expires.
* @param {object} fields.metadata - Additional metadata to store with the token.
* @param {string} [fields.type="oauth"] - The type of token. Default is 'oauth'.
*/
async function handleOAuthToken({
token,
userId,
identifier,
expiresIn,
metadata,
type = 'oauth',
}) {
const encrypedToken = await encryptV2(token);
const tokenData = {
type,
userId,
metadata,
identifier,
token: encrypedToken,
expiresIn: parseInt(expiresIn, 10) || 3600,
};
const existingToken = await findToken({ userId, identifier });
if (existingToken) {
return await updateToken({ identifier }, tokenData);
} else {
return await createToken(tokenData);
}
}
module.exports = {
handleOAuthToken,
};

View File

@@ -1,4 +1,6 @@
const { ToolCall } = require('~/db/models');
const mongoose = require('mongoose');
const ToolCall = require('~/db/models').ToolCall;
/**
* Create a new tool call

View File

@@ -1,7 +1,10 @@
const mongoose = require('mongoose');
const { logger } = require('@librechat/data-schemas');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { Transaction, Balance } = require('~/db/models');
const Transaction = require('~/db/models').Transaction;
const Balance = require('~/db/models').Balance;
const cancelRate = 1.15;

View File

@@ -4,7 +4,7 @@ const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { createTransaction } = require('./Transaction');
const { Balance } = require('~/db/models');
const Balance = require('~/db/models').Balance;
// Mock the custom config module so we can control the balance flag.
jest.mock('~/server/services/Config');

View File

@@ -1,9 +1,11 @@
const mongoose = require('mongoose');
const { logger } = require('@librechat/data-schemas');
const { ViolationTypes } = require('librechat-data-provider');
const { createAutoRefillTransaction } = require('./Transaction');
const { logViolation } = require('~/cache');
const { getMultiplier } = require('./tx');
const { Balance } = require('~/db/models');
const Balance = require('~/db/models').Balance;
function isInvalidDate(date) {
return isNaN(date);

View File

@@ -1,7 +1,8 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { getMessages, bulkSaveMessages } = require('./Message');
const { Message } = require('~/db/models');
const Message = require('~/db/models').Message;
// Original version of buildTree function
function buildTree({ messages, fileMap }) {

View File

@@ -1,6 +1,6 @@
const mongoose = require('mongoose');
const { getRandomValues } = require('@librechat/api');
const { logger, hashToken } = require('@librechat/data-schemas');
const { getRandomValues } = require('~/server/utils/crypto');
const { createToken, findToken } = require('~/models');
/**

View File

@@ -2,8 +2,8 @@ const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { createTransaction, createAutoRefillTransaction } = require('./Transaction');
require('~/db/models');
const Transaction = require('~/db/models').Transaction;
const Balance = require('~/db/models').Balance;
// Mock the logger to prevent console output during tests
jest.mock('~/config', () => ({
@@ -20,15 +20,10 @@ jest.mock('~/server/services/Config');
describe('spendTokens', () => {
let mongoServer;
let userId;
let Transaction;
let Balance;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
Transaction = mongoose.model('Transaction');
Balance = mongoose.model('Balance');
});
afterAll(async () => {

View File

@@ -78,7 +78,7 @@ const tokenValues = Object.assign(
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
'o4-mini': { prompt: 1.1, completion: 4.4 },
'o3-mini': { prompt: 1.1, completion: 4.4 },
o3: { prompt: 2, completion: 8 },
o3: { prompt: 10, completion: 40 },
'o1-mini': { prompt: 1.1, completion: 4.4 },
'o1-preview': { prompt: 15, completion: 60 },
o1: { prompt: 15, completion: 60 },
@@ -135,11 +135,10 @@ const tokenValues = Object.assign(
'grok-2-1212': { prompt: 2.0, completion: 10.0 },
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
'grok-2': { prompt: 2.0, completion: 10.0 },
'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
'grok-3-mini-fast': { prompt: 0.4, completion: 4 },
'grok-3-mini': { prompt: 0.3, completion: 0.5 },
'grok-3-fast': { prompt: 5.0, completion: 25.0 },
'grok-3': { prompt: 3.0, completion: 15.0 },
'grok-4': { prompt: 3.0, completion: 15.0 },
'grok-beta': { prompt: 5.0, completion: 15.0 },
'mistral-large': { prompt: 2.0, completion: 6.0 },
'pixtral-large': { prompt: 2.0, completion: 6.0 },

View File

@@ -636,15 +636,6 @@ describe('Grok Model Tests - Pricing', () => {
);
});
test('should return correct prompt and completion rates for Grok 4 model', () => {
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
tokenValues['grok-3'].prompt,
@@ -671,15 +662,6 @@ describe('Grok Model Tests - Pricing', () => {
tokenValues['grok-3-mini-fast'].completion,
);
});
test('should return correct prompt and completion rates for Grok 4 model with prefixes', () => {
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'prompt' })).toBe(
tokenValues['grok-4'].prompt,
);
expect(getMultiplier({ model: 'xai/grok-4-0709', tokenType: 'completion' })).toBe(
tokenValues['grok-4'].completion,
);
});
});
});

View File

@@ -12,10 +12,6 @@ const comparePassword = async (user, candidatePassword) => {
throw new Error('No user provided');
}
if (!user.password) {
throw new Error('No password, likely an email first registered via Social/OIDC login');
}
return new Promise((resolve, reject) => {
bcrypt.compare(candidatePassword, user.password, (err, isMatch) => {
if (err) {

View File

@@ -1,6 +1,6 @@
{
"name": "@librechat/backend",
"version": "v0.7.9-rc1",
"version": "v0.7.8",
"description": "",
"scripts": {
"start": "echo 'please run this from the root directory'",
@@ -34,27 +34,27 @@
},
"homepage": "https://librechat.ai",
"dependencies": {
"@anthropic-ai/sdk": "^0.52.0",
"@anthropic-ai/sdk": "^0.37.0",
"@aws-sdk/client-s3": "^3.758.0",
"@aws-sdk/s3-request-presigner": "^3.758.0",
"@azure/identity": "^4.7.0",
"@azure/search-documents": "^12.0.0",
"@azure/storage-blob": "^12.27.0",
"@google/generative-ai": "^0.24.0",
"@google/generative-ai": "^0.23.0",
"@googleapis/youtube": "^20.0.0",
"@keyv/redis": "^4.3.3",
"@langchain/community": "^0.3.47",
"@langchain/core": "^0.3.60",
"@langchain/google-genai": "^0.2.13",
"@langchain/google-vertexai": "^0.2.13",
"@langchain/community": "^0.3.44",
"@langchain/core": "^0.3.57",
"@langchain/google-genai": "^0.2.9",
"@langchain/google-vertexai": "^0.2.9",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.60",
"@librechat/api": "*",
"@librechat/agents": "^2.4.37",
"@librechat/data-schemas": "*",
"@node-saml/passport-saml": "^5.0.0",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.8.2",
"bcryptjs": "^2.4.3",
"cohere-ai": "^7.9.1",
"compression": "^1.7.4",
"connect-redis": "^7.1.0",
"cookie": "^0.7.2",
@@ -81,15 +81,15 @@
"keyv-file": "^5.1.2",
"klona": "^2.0.6",
"librechat-data-provider": "*",
"librechat-mcp": "*",
"lodash": "^4.17.21",
"meilisearch": "^0.38.0",
"memorystore": "^1.6.7",
"mime": "^3.0.0",
"module-alias": "^2.2.3",
"mongoose": "^8.12.1",
"multer": "^2.0.1",
"multer": "^2.0.0",
"nanoid": "^3.3.7",
"node-fetch": "^2.7.0",
"nodemailer": "^6.9.15",
"ollama": "^0.5.0",
"openai": "^4.96.2",
@@ -109,9 +109,8 @@
"tiktoken": "^1.0.15",
"traverse": "^0.6.7",
"ua-parser-js": "^1.0.36",
"undici": "^7.10.0",
"winston": "^3.11.0",
"winston-daily-rotate-file": "^5.0.0",
"winston-daily-rotate-file": "^4.7.1",
"youtube-transcript": "^1.2.1",
"zod": "^3.22.4"
},

View File

@@ -169,6 +169,9 @@ function disposeClient(client) {
client.isGenerativeModel = null;
}
// Properties specific to OpenAIClient
if (client.ChatGPTClient) {
client.ChatGPTClient = null;
}
if (client.completionsUrl) {
client.completionsUrl = null;
}
@@ -217,9 +220,6 @@ function disposeClient(client) {
if (client.maxResponseTokens) {
client.maxResponseTokens = null;
}
if (client.processMemory) {
client.processMemory = null;
}
if (client.run) {
// Break circular references in run
if (client.run.Graph) {

View File

@@ -0,0 +1,282 @@
const { getResponseSender, Constants } = require('librechat-data-provider');
const {
handleAbortError,
createAbortController,
cleanupAbortController,
} = require('~/server/middleware');
const {
disposeClient,
processReqData,
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const AskController = async (req, res, next, initializeClient, addTitle) => {
let {
text,
endpointOption,
conversationId,
modelDisplayLabel,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
let client = null;
let abortKey = null;
let cleanupHandlers = [];
let clientRef = null;
logger.debug('[AskController]', {
text,
conversationId,
...endpointOption,
modelsConfig: endpointOption?.modelsConfig ? 'exists' : '',
});
let userMessage = null;
let userMessagePromise = null;
let promptTokens = null;
let userMessageId = null;
let responseMessageId = null;
let getAbortData = null;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
modelDisplayLabel,
});
const initialConversationId = conversationId;
const newConvo = !initialConversationId;
const userId = req.user.id;
let reqDataContext = {
userMessage,
userMessagePromise,
responseMessageId,
promptTokens,
conversationId,
userMessageId,
};
const updateReqData = (data = {}) => {
reqDataContext = processReqData(data, reqDataContext);
abortKey = reqDataContext.abortKey;
userMessage = reqDataContext.userMessage;
userMessagePromise = reqDataContext.userMessagePromise;
responseMessageId = reqDataContext.responseMessageId;
promptTokens = reqDataContext.promptTokens;
conversationId = reqDataContext.conversationId;
userMessageId = reqDataContext.userMessageId;
};
let { onProgress: progressCallback, getPartialText } = createOnProgress();
const performCleanup = () => {
logger.debug('[AskController] Performing cleanup');
if (Array.isArray(cleanupHandlers)) {
for (const handler of cleanupHandlers) {
try {
if (typeof handler === 'function') {
handler();
}
} catch (e) {
// Ignore
}
}
}
if (abortKey) {
logger.debug('[AskController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
if (client) {
disposeClient(client);
client = null;
}
reqDataContext = null;
userMessage = null;
userMessagePromise = null;
promptTokens = null;
getAbortData = null;
progressCallback = null;
endpointOption = null;
cleanupHandlers = null;
addTitle = null;
if (requestDataMap.has(req)) {
requestDataMap.delete(req);
}
logger.debug('[AskController] Cleanup completed');
};
try {
({ client } = await initializeClient({ req, res, endpointOption }));
if (clientRegistry && client) {
clientRegistry.register(client, { userId }, client);
}
if (client) {
requestDataMap.set(req, { client });
}
clientRef = new WeakRef(client);
getAbortData = () => {
const currentClient = clientRef?.deref();
const currentText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
return {
sender,
conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: currentText,
userMessage: userMessage,
userMessagePromise: userMessagePromise,
promptTokens: reqDataContext.promptTokens,
};
};
const { onStart, abortController } = createAbortController(
req,
res,
getAbortData,
updateReqData,
);
const closeHandler = () => {
logger.debug('[AskController] Request closed');
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
return;
}
abortController.abort();
logger.debug('[AskController] Request aborted on close');
};
res.on('close', closeHandler);
cleanupHandlers.push(() => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
// Ignore
}
});
const messageOptions = {
user: userId,
parentMessageId,
conversationId: reqDataContext.conversationId,
overrideParentMessageId,
getReqData: updateReqData,
onStart,
abortController,
progressCallback,
progressOptions: {
res,
},
};
/** @type {TMessage} */
let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint;
const databasePromise = response.databasePromise;
delete response.databasePromise;
const { conversation: convoData = {} } = await databasePromise;
const conversation = { ...convoData };
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
const latestUserMessage = reqDataContext.userMessage;
if (client?.options?.attachments && latestUserMessage) {
latestUserMessage.files = client.options.attachments;
if (endpointOption?.modelOptions?.model) {
conversation.model = endpointOption.modelOptions.model;
}
delete latestUserMessage.image_urls;
}
if (!abortController.signal.aborted) {
const finalResponseMessage = { ...response };
sendMessage(res, {
final: true,
conversation,
title: conversation.title,
requestMessage: latestUserMessage,
responseMessage: finalResponseMessage,
});
res.end();
if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) {
await saveMessage(
req,
{ ...finalResponseMessage, user: userId },
{ context: 'api/server/controllers/AskController.js - response end' },
);
}
}
if (!client?.skipSaveUserMessage && latestUserMessage) {
await saveMessage(req, latestUserMessage, {
context: "api/server/controllers/AskController.js - don't skip saving user message",
});
}
if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {
text,
response: { ...response },
client,
})
.then(() => {
logger.debug('[AskController] Title generation started');
})
.catch((err) => {
logger.error('[AskController] Error in title generation', err);
})
.finally(() => {
logger.debug('[AskController] Title generation completed');
performCleanup();
});
} else {
performCleanup();
}
} catch (error) {
logger.error('[AskController] Error handling request', error);
let partialText = '';
try {
const currentClient = clientRef?.deref();
partialText =
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
} catch (getTextError) {
logger.error('[AskController] Error calling getText() during error handling', getTextError);
}
handleAbortError(res, req, error, {
sender,
partialText,
conversationId: reqDataContext.conversationId,
messageId: reqDataContext.responseMessageId,
parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId,
userMessageId: reqDataContext.userMessageId,
})
.catch((err) => {
logger.error('[AskController] Error in `handleAbortError` during catch block', err);
})
.finally(() => {
performCleanup();
});
}
};
module.exports = AskController;

View File

@@ -1,17 +1,17 @@
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const openIdClient = require('openid-client');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
requestPasswordReset,
setOpenIDAuthTokens,
registerUser,
resetPassword,
setAuthTokens,
registerUser,
requestPasswordReset,
setOpenIDAuthTokens,
} = require('~/server/services/AuthService');
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
const { getOpenIdConfig } = require('~/strategies');
const { isEnabled } = require('~/server/utils');
const registrationController = async (req, res) => {
try {

View File

@@ -1,4 +1,6 @@
const { Balance } = require('~/db/models');
const mongoose = require('mongoose');
const Balance = require('~/db/models').Balance;
async function balanceController(req, res) {
const balanceData = await Balance.findOne(

View File

@@ -1,5 +1,3 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { getResponseSender } = require('librechat-data-provider');
const {
handleAbortError,
@@ -12,8 +10,9 @@ const {
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { createOnProgress } = require('~/server/utils');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const EditController = async (req, res, next, initializeClient) => {
let {
@@ -85,7 +84,7 @@ const EditController = async (req, res, next, initializeClient) => {
}
if (abortKey) {
logger.debug('[EditController] Cleaning up abort controller');
logger.debug('[AskController] Cleaning up abort controller');
cleanupAbortController(abortKey);
abortKey = null;
}
@@ -199,7 +198,7 @@ const EditController = async (req, res, next, initializeClient) => {
const finalUserMessage = reqDataContext.userMessage;
const finalResponseMessage = { ...response };
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
title: conversation.title,

View File

@@ -24,23 +24,17 @@ const handleValidationError = (err, res) => {
}
};
module.exports = (err, _req, res, _next) => {
// eslint-disable-next-line no-unused-vars
module.exports = (err, req, res, next) => {
try {
if (err.name === 'ValidationError') {
return handleValidationError(err, res);
return (err = handleValidationError(err, res));
}
if (err.code && err.code == 11000) {
return handleDuplicateKeyError(err, res);
return (err = handleDuplicateKeyError(err, res));
}
// Special handling for errors like SyntaxError
if (err.statusCode && err.body) {
return res.status(err.statusCode).send(err.body);
}
logger.error('ErrorController => error', err);
return res.status(500).send('An unknown error occurred.');
} catch (err) {
logger.error('ErrorController => processing error', err);
return res.status(500).send('Processing error in ErrorController.');
logger.error('ErrorController => error', err);
res.status(500).send('An unknown error occurred.');
}
};

View File

@@ -1,241 +0,0 @@
const errorController = require('./ErrorController');
const { logger } = require('~/config');
// Mock the logger
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
},
}));
describe('ErrorController', () => {
let mockReq, mockRes, mockNext;
beforeEach(() => {
mockReq = {};
mockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn(),
};
mockNext = jest.fn();
logger.error.mockClear();
});
describe('ValidationError handling', () => {
it('should handle ValidationError with single error', () => {
const validationError = {
name: 'ValidationError',
errors: {
email: { message: 'Email is required', path: 'email' },
},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '["Email is required"]',
fields: '["email"]',
});
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
});
it('should handle ValidationError with multiple errors', () => {
const validationError = {
name: 'ValidationError',
errors: {
email: { message: 'Email is required', path: 'email' },
password: { message: 'Password is required', path: 'password' },
},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '"Email is required Password is required"',
fields: '["email","password"]',
});
expect(logger.error).toHaveBeenCalledWith('Validation error:', validationError.errors);
});
it('should handle ValidationError with empty errors object', () => {
const validationError = {
name: 'ValidationError',
errors: {},
};
errorController(validationError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith({
messages: '[]',
fields: '[]',
});
});
});
describe('Duplicate key error handling', () => {
it('should handle duplicate key error (code 11000)', () => {
const duplicateKeyError = {
code: 11000,
keyValue: { email: 'test@example.com' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email"] already exists.',
fields: '["email"]',
});
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
});
it('should handle duplicate key error with multiple fields', () => {
const duplicateKeyError = {
code: 11000,
keyValue: { email: 'test@example.com', username: 'testuser' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email","username"] already exists.',
fields: '["email","username"]',
});
expect(logger.error).toHaveBeenCalledWith('Duplicate key error:', duplicateKeyError.keyValue);
});
it('should handle error with code 11000 as string', () => {
const duplicateKeyError = {
code: '11000',
keyValue: { email: 'test@example.com' },
};
errorController(duplicateKeyError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(409);
expect(mockRes.send).toHaveBeenCalledWith({
messages: 'An document with that ["email"] already exists.',
fields: '["email"]',
});
});
});
describe('SyntaxError handling', () => {
it('should handle errors with statusCode and body', () => {
const syntaxError = {
statusCode: 400,
body: 'Invalid JSON syntax',
};
errorController(syntaxError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.send).toHaveBeenCalledWith('Invalid JSON syntax');
});
it('should handle errors with different statusCode and body', () => {
const customError = {
statusCode: 422,
body: { error: 'Unprocessable entity' },
};
errorController(customError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(422);
expect(mockRes.send).toHaveBeenCalledWith({ error: 'Unprocessable entity' });
});
it('should handle error with statusCode but no body', () => {
const partialError = {
statusCode: 400,
};
errorController(partialError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
});
it('should handle error with body but no statusCode', () => {
const partialError = {
body: 'Some error message',
};
errorController(partialError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
});
});
describe('Unknown error handling', () => {
it('should handle unknown errors', () => {
const unknownError = new Error('Some unknown error');
errorController(unknownError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', unknownError);
});
it('should handle errors with code other than 11000', () => {
const mongoError = {
code: 11100,
message: 'Some MongoDB error',
};
errorController(mongoError, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('An unknown error occurred.');
expect(logger.error).toHaveBeenCalledWith('ErrorController => error', mongoError);
});
it('should handle null/undefined errors', () => {
errorController(null, mockReq, mockRes, mockNext);
expect(mockRes.status).toHaveBeenCalledWith(500);
expect(mockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
expect(logger.error).toHaveBeenCalledWith(
'ErrorController => processing error',
expect.any(Error),
);
});
});
describe('Catch block handling', () => {
beforeEach(() => {
// Restore logger mock to normal behavior for these tests
logger.error.mockRestore();
logger.error = jest.fn();
});
it('should handle errors when logger.error throws', () => {
// Create fresh mocks for this test
const freshMockRes = {
status: jest.fn().mockReturnThis(),
send: jest.fn(),
};
// Mock logger to throw on the first call, succeed on the second
logger.error
.mockImplementationOnce(() => {
throw new Error('Logger error');
})
.mockImplementation(() => {});
const testError = new Error('Test error');
errorController(testError, mockReq, freshMockRes, mockNext);
expect(freshMockRes.status).toHaveBeenCalledWith(500);
expect(freshMockRes.send).toHaveBeenCalledWith('Processing error in ErrorController.');
expect(logger.error).toHaveBeenCalledTimes(2);
});
});
});

View File

@@ -1,9 +1,8 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, AuthType, Constants } = require('librechat-data-provider');
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
const { CacheKeys, AuthType } = require('librechat-data-provider');
const { getToolkitKey } = require('~/server/services/ToolService');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCustomConfig } = require('~/server/services/Config');
const { availableTools } = require('~/app/clients/tools');
const { getMCPManager } = require('~/config');
const { getLogStores } = require('~/cache');
/**
@@ -85,45 +84,6 @@ const getAvailablePluginsController = async (req, res) => {
}
};
function createServerToolsCallback() {
/**
* @param {string} serverName
* @param {TPlugin[] | null} serverTools
*/
return async function (serverName, serverTools) {
try {
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
if (!serverName || !mcpToolsCache) {
return;
}
await mcpToolsCache.set(serverName, serverTools);
logger.warn(`MCP tools for ${serverName} added to cache.`);
} catch (error) {
logger.error('Error retrieving MCP tools from cache:', error);
}
};
}
function createGetServerTools() {
/**
* Retrieves cached server tools
* @param {string} serverName
* @returns {Promise<TPlugin[] | null>}
*/
return async function (serverName) {
try {
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
if (!mcpToolsCache) {
return null;
}
return await mcpToolsCache.get(serverName);
} catch (error) {
logger.error('Error retrieving MCP tools from cache:', error);
return null;
}
};
}
/**
* Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file.
*
@@ -138,33 +98,18 @@ function createGetServerTools() {
*/
const getAvailableTools = async (req, res) => {
try {
const userId = req.user?.id;
const customConfig = await getCustomConfig();
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
const cachedUserTools = await getCachedTools({ userId });
const userPlugins = await convertMCPToolsToPlugins(cachedUserTools, customConfig, userId);
if (cachedToolsArray && userPlugins) {
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
res.status(200).json(dedupedTools);
const cachedTools = await cache.get(CacheKeys.TOOLS);
if (cachedTools) {
res.status(200).json(cachedTools);
return;
}
// If not in cache, build from manifest
let pluginManifest = availableTools;
const customConfig = await getCustomConfig();
if (customConfig?.mcpServers != null) {
const mcpManager = getMCPManager();
const flowsCache = getLogStores(CacheKeys.FLOWS);
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
const serverToolsCallback = createServerToolsCallback();
const getServerTools = createGetServerTools();
const mcpTools = await mcpManager.loadManifestTools({
flowManager,
serverToolsCallback,
getServerTools,
});
pluginManifest = [...mcpTools, ...pluginManifest];
pluginManifest = await mcpManager.loadManifestTools(pluginManifest);
}
/** @type {TPlugin[]} */
@@ -178,212 +123,21 @@ const getAvailableTools = async (req, res) => {
}
});
const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {};
const toolDefinitions = req.app.locals.availableTools;
const tools = authenticatedPlugins.filter(
(plugin) =>
toolDefinitions[plugin.pluginKey] !== undefined ||
(plugin.toolkit === true &&
Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey)),
);
const toolsOutput = [];
for (const plugin of authenticatedPlugins) {
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
const isToolkit =
plugin.toolkit === true &&
Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey);
if (!isToolDefined && !isToolkit) {
continue;
}
const toolToAdd = { ...plugin };
if (!plugin.pluginKey.includes(Constants.mcp_delimiter)) {
toolsOutput.push(toolToAdd);
continue;
}
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
const serverName = parts[parts.length - 1];
const serverConfig = customConfig?.mcpServers?.[serverName];
logger.warn(
`[getAvailableTools] Processing MCP tool:`,
JSON.stringify({
pluginKey: plugin.pluginKey,
serverName,
hasServerConfig: !!serverConfig,
hasCustomUserVars: !!serverConfig?.customUserVars,
}),
);
if (!serverConfig) {
logger.warn(
`[getAvailableTools] No server config found for ${serverName}, skipping auth check`,
);
toolsOutput.push(toolToAdd);
continue;
}
// Handle MCP servers with customUserVars (user-level auth required)
if (serverConfig.customUserVars) {
logger.warn(`[getAvailableTools] Processing user-level MCP server: ${serverName}`);
const customVarKeys = Object.keys(serverConfig.customUserVars);
// Build authConfig for MCP tools
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}));
// Check actual connection status for MCP tools with auth requirements
if (userId) {
try {
const mcpManager = getMCPManager(userId);
const connectionStatus = await mcpManager.getUserConnectionStatus(userId, serverName);
toolToAdd.authenticated = connectionStatus.connected;
logger.warn(`[getAvailableTools] User-level connection status for ${serverName}:`, {
connected: connectionStatus.connected,
hasConnection: connectionStatus.hasConnection,
});
} catch (error) {
logger.error(
`[getAvailableTools] Error checking connection status for ${serverName}:`,
error,
);
toolToAdd.authenticated = false;
}
} else {
// For non-authenticated requests, default to false
toolToAdd.authenticated = false;
}
} else {
// Handle app-level MCP servers (no auth required)
logger.warn(`[getAvailableTools] Processing app-level MCP server: ${serverName}`);
toolToAdd.authConfig = [];
// Check if the app-level connection is active
try {
const mcpManager = getMCPManager();
const allConnections = mcpManager.getAllConnections();
logger.warn(`[getAvailableTools] All app-level connections:`, {
connectionNames: Array.from(allConnections.keys()),
serverName,
});
const appConnection = mcpManager.getConnection(serverName);
logger.warn(`[getAvailableTools] Checking app-level connection for ${serverName}:`, {
hasConnection: !!appConnection,
connectionState: appConnection?.getConnectionState?.(),
});
if (appConnection) {
const connectionState = appConnection.getConnectionState();
logger.warn(`[getAvailableTools] App-level connection status for ${serverName}:`, {
connectionState,
hasConnection: !!appConnection,
});
// For app-level connections, consider them authenticated if they're in 'connected' state
// This is more reliable than isConnected() which does network calls
toolToAdd.authenticated = connectionState === 'connected';
logger.warn(`[getAvailableTools] Final authenticated status for ${serverName}:`, {
authenticated: toolToAdd.authenticated,
connectionState,
});
} else {
logger.warn(`[getAvailableTools] No app-level connection found for ${serverName}`);
toolToAdd.authenticated = false;
}
} catch (error) {
logger.error(
`[getAvailableTools] Error checking app-level connection status for ${serverName}:`,
error,
);
toolToAdd.authenticated = false;
}
}
toolsOutput.push(toolToAdd);
}
const finalTools = filterUniquePlugins(toolsOutput);
await cache.set(CacheKeys.TOOLS, finalTools);
const dedupedTools = filterUniquePlugins([...userPlugins, ...finalTools]);
res.status(200).json(dedupedTools);
await cache.set(CacheKeys.TOOLS, tools);
res.status(200).json(tools);
} catch (error) {
logger.error('[getAvailableTools]', error);
res.status(500).json({ message: error.message });
}
};
/**
* Converts MCP function format tools to plugin format
* @param {Object} functionTools - Object with function format tools
* @param {Object} customConfig - Custom configuration for MCP servers
* @returns {Array} Array of plugin objects
*/
async function convertMCPToolsToPlugins(functionTools, customConfig, userId = null) {
const plugins = [];
for (const [toolKey, toolData] of Object.entries(functionTools)) {
if (!toolData.function || !toolKey.includes(Constants.mcp_delimiter)) {
continue;
}
const functionData = toolData.function;
const parts = toolKey.split(Constants.mcp_delimiter);
const serverName = parts[parts.length - 1];
const plugin = {
name: parts[0], // Use the tool name without server suffix
pluginKey: toolKey,
description: functionData.description || '',
authenticated: false, // Default to false, will be updated based on connection status
icon: undefined,
};
// Build authConfig for MCP tools
const serverConfig = customConfig?.mcpServers?.[serverName];
if (!serverConfig?.customUserVars) {
plugin.authConfig = [];
plugin.authenticated = true; // No auth required
plugins.push(plugin);
continue;
}
const customVarKeys = Object.keys(serverConfig.customUserVars);
if (customVarKeys.length === 0) {
plugin.authConfig = [];
plugin.authenticated = true; // No auth required
} else {
plugin.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
authField: key,
label: value.title || key,
description: value.description || '',
}));
// Check actual connection status for MCP tools with auth requirements
if (userId) {
try {
const mcpManager = getMCPManager(userId);
const connectionStatus = await mcpManager.getUserConnectionStatus(userId, serverName);
plugin.authenticated = connectionStatus.connected;
} catch (error) {
logger.error(
`[convertMCPToolsToPlugins] Error checking connection status for ${serverName}:`,
error,
);
plugin.authenticated = false;
}
} else {
plugin.authenticated = false;
}
}
plugins.push(plugin);
}
return plugins;
}
module.exports = {
getAvailableTools,
getAvailablePluginsController,

View File

@@ -1,4 +1,3 @@
const { encryptV3 } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
verifyTOTP,
@@ -8,6 +7,7 @@ const {
generateBackupCodes,
} = require('~/server/services/twoFactorService');
const { getUserById, updateUser } = require('~/models');
const { encryptV3 } = require('~/server/utils/crypto');
const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');

View File

@@ -1,6 +1,5 @@
const {
Tools,
Constants,
FileSources,
webSearchKeys,
extractWebSearchEnvVars,
@@ -21,10 +20,12 @@ const { updateUserPluginsService, deleteUserKey } = require('~/server/services/U
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { Transaction, Balance, User } = require('~/db/models');
const { deleteAllSharedLinks } = require('~/models/Share');
const { deleteToolCalls } = require('~/models/ToolCall');
const { deleteAllSharedLinks } = require('~/models');
const { getMCPManager } = require('~/config');
const Transaction = require('~/db/models').Transaction;
const Balance = require('~/db/models').Balance;
const User = require('~/db/models').User;
const getUserController = async (req, res) => {
/** @type {MongoUser} */
@@ -104,22 +105,10 @@ const updateUserPluginsController = async (req, res) => {
}
let keys = Object.keys(auth);
const values = Object.values(auth); // Used in 'install' block
const isMCPTool = pluginKey.startsWith('mcp_') || pluginKey.includes(Constants.mcp_delimiter);
// Early exit condition:
// If keys are empty (meaning auth: {} was likely sent for uninstall, or auth was empty for install)
// AND it's not web_search (which has special key handling to populate `keys` for uninstall)
// AND it's NOT (an uninstall action FOR an MCP tool - we need to proceed for this case to clear all its auth)
// THEN return.
if (
keys.length === 0 &&
pluginKey !== Tools.web_search &&
!(action === 'uninstall' && isMCPTool)
) {
if (keys.length === 0 && pluginKey !== Tools.web_search) {
return res.status(200).send();
}
const values = Object.values(auth);
/** @type {number} */
let status = 200;
@@ -146,57 +135,16 @@ const updateUserPluginsController = async (req, res) => {
}
}
} else if (action === 'uninstall') {
// const isMCPTool was defined earlier
if (isMCPTool && keys.length === 0) {
// This handles the case where auth: {} is sent for an MCP tool uninstall.
// It means "delete all credentials associated with this MCP pluginKey".
authService = await deleteUserPluginAuth(user.id, null, true, pluginKey);
for (let i = 0; i < keys.length; i++) {
authService = await deleteUserPluginAuth(user.id, keys[i]);
if (authService instanceof Error) {
logger.error(
`[authService] Error deleting all auth for MCP tool ${pluginKey}:`,
authService,
);
logger.error('[authService]', authService);
({ status, message } = authService);
}
} else {
// This handles:
// 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}).
// 2. Other tools uninstall (if keys were provided).
// 3. MCP tool uninstall if specific keys were provided in `auth` (not current frontend behavior).
// If keys is empty for non-MCP tools (and not web_search), this loop won't run, and nothing is deleted.
for (let i = 0; i < keys.length; i++) {
authService = await deleteUserPluginAuth(user.id, keys[i]); // Deletes by authField name
if (authService instanceof Error) {
logger.error('[authService] Error deleting specific auth key:', authService);
({ status, message } = authService);
}
}
}
}
if (status === 200) {
// If auth was updated successfully, disconnect MCP sessions as they might use these credentials
if (pluginKey.startsWith(Constants.mcp_prefix)) {
try {
const mcpManager = getMCPManager(user.id);
if (mcpManager) {
// Extract server name from pluginKey (e.g., "mcp_myserver" -> "myserver")
const serverName = pluginKey.replace(Constants.mcp_prefix, '');
logger.info(
`[updateUserPluginsController] Disconnecting MCP connection for user ${user.id} and server ${serverName} after plugin auth update for ${pluginKey}.`,
);
// COMMENTED OUT: Don't kill the server connection on revoke
// await mcpManager.disconnectUserConnection(user.id, serverName);
}
} catch (disconnectError) {
logger.error(
`[updateUserPluginsController] Error disconnecting MCP connection for user ${user.id} after plugin auth update:`,
disconnectError,
);
// Do not fail the request for this, but log it.
}
}
return res.status(status).send();
}
@@ -218,11 +166,7 @@ const deleteUserController = async (req, res) => {
await Balance.deleteMany({ user: user._id }); // delete user balances
await deletePresets(user.id); // delete user presets
/* TODO: Delete Assistant Threads */
try {
await deleteConvos(user.id); // delete user convos
} catch (error) {
logger.error('[deleteUserController] Error deleting user convos, likely no convos', error);
}
await deleteConvos(user.id); // delete user convos
await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth
await deleteUserById(user.id); // delete user
await deleteAllSharedLinks(user.id); // delete user shared links

View File

@@ -1,195 +0,0 @@
const { duplicateAgent } = require('../v1');
const { getAgent, createAgent } = require('~/models/Agent');
const { getActions } = require('~/models/Action');
const { nanoid } = require('nanoid');
jest.mock('~/models/Agent');
jest.mock('~/models/Action');
jest.mock('nanoid');
describe('duplicateAgent', () => {
let req, res;
beforeEach(() => {
req = {
params: { id: 'agent_123' },
user: { id: 'user_456' },
};
res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
};
jest.clearAllMocks();
});
it('should duplicate an agent successfully', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
author: 'user_789',
versions: [{ name: 'Test Agent', version: 1 }],
__v: 0,
};
const mockNewAgent = {
id: 'agent_new_123',
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
author: 'user_456',
versions: [
{
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
createdAt: new Date(),
updatedAt: new Date(),
},
],
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue(mockNewAgent);
await duplicateAgent(req, res);
expect(getAgent).toHaveBeenCalledWith({ id: 'agent_123' });
expect(getActions).toHaveBeenCalledWith({ agent_id: 'agent_123' }, true);
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
id: 'agent_new_123',
author: 'user_456',
name: expect.stringContaining('Test Agent ('),
description: 'Test Description',
instructions: 'Test Instructions',
provider: 'openai',
model: 'gpt-4',
tools: ['file_search'],
actions: [],
}),
);
expect(createAgent).toHaveBeenCalledWith(
expect.not.objectContaining({
versions: expect.anything(),
__v: expect.anything(),
}),
);
expect(res.status).toHaveBeenCalledWith(201);
expect(res.json).toHaveBeenCalledWith({
agent: mockNewAgent,
actions: [],
});
});
it('should ensure duplicated agent has clean versions array without nested fields', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
description: 'Test Description',
versions: [
{
name: 'Test Agent',
versions: [{ name: 'Nested' }],
__v: 1,
},
],
__v: 2,
};
const mockNewAgent = {
id: 'agent_new_123',
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
versions: [
{
name: 'Test Agent (1/2/23, 12:34)',
description: 'Test Description',
createdAt: new Date(),
updatedAt: new Date(),
},
],
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue(mockNewAgent);
await duplicateAgent(req, res);
expect(mockNewAgent.versions).toHaveLength(1);
const firstVersion = mockNewAgent.versions[0];
expect(firstVersion).not.toHaveProperty('versions');
expect(firstVersion).not.toHaveProperty('__v');
expect(mockNewAgent).not.toHaveProperty('__v');
expect(res.status).toHaveBeenCalledWith(201);
});
it('should return 404 if agent not found', async () => {
getAgent.mockResolvedValue(null);
await duplicateAgent(req, res);
expect(res.status).toHaveBeenCalledWith(404);
expect(res.json).toHaveBeenCalledWith({
error: 'Agent not found',
status: 'error',
});
});
it('should handle tool_resources.ocr correctly', async () => {
const mockAgent = {
id: 'agent_123',
name: 'Test Agent',
tool_resources: {
ocr: { enabled: true, config: 'test' },
other: { should: 'not be copied' },
},
};
getAgent.mockResolvedValue(mockAgent);
getActions.mockResolvedValue([]);
nanoid.mockReturnValue('new_123');
createAgent.mockResolvedValue({ id: 'agent_new_123' });
await duplicateAgent(req, res);
expect(createAgent).toHaveBeenCalledWith(
expect.objectContaining({
tool_resources: {
ocr: { enabled: true, config: 'test' },
},
}),
);
});
it('should handle errors gracefully', async () => {
getAgent.mockRejectedValue(new Error('Database error'));
await duplicateAgent(req, res);
expect(res.status).toHaveBeenCalledWith(500);
expect(res.json).toHaveBeenCalledWith({ error: 'Database error' });
});
});

View File

@@ -1,6 +1,4 @@
const { nanoid } = require('nanoid');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { Tools, StepTypes, FileContext } = require('librechat-data-provider');
const {
EnvVar,
@@ -14,6 +12,7 @@ const {
const { processCodeOutput } = require('~/server/services/Files/Code/process');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { saveBase64Image } = require('~/server/services/Files/process');
const { logger, sendEvent } = require('~/config');
class ModelEndHandler {
/**
@@ -241,7 +240,9 @@ function createToolEndCallback({ req, res, artifactPromises }) {
if (output.artifact[Tools.web_search]) {
artifactPromises.push(
(async () => {
const name = `${output.name}_${output.tool_call_id}_${nanoid()}`;
const attachment = {
name,
type: Tools.web_search,
messageId: metadata.run_id,
toolCallId: output.tool_call_id,

View File

@@ -1,18 +1,15 @@
// const { HttpsProxyAgent } = require('https-proxy-agent');
// const {
// Constants,
// ImageDetail,
// EModelEndpoint,
// resolveHeaders,
// validateVisionModel,
// mapModelToAzureConfig,
// } = require('librechat-data-provider');
require('events').EventEmitter.defaultMaxListeners = 100;
const { logger } = require('@librechat/data-schemas');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const {
sendEvent,
createRun,
Tokenizer,
checkAccess,
memoryInstructions,
createMemoryProcessor,
} = require('@librechat/api');
const {
Callback,
Providers,
GraphEvents,
formatMessage,
formatAgentMessages,
@@ -22,45 +19,25 @@ const {
} = require('@librechat/agents');
const {
Constants,
Permissions,
VisionModes,
ContentTypes,
EModelEndpoint,
KnownEndpoints,
PermissionTypes,
isAgentsEndpoint,
AgentCapabilities,
bedrockInputSchema,
removeNullishValues,
} = require('librechat-data-provider');
const {
findPluginAuthsByKeys,
getFormattedMemories,
deleteMemory,
setMemory,
} = require('~/models');
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { getProviderConfig } = require('~/server/services/Endpoints');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const Tokenizer = require('~/server/services/Tokenizer');
const BaseClient = require('~/app/clients/BaseClient');
const { getRoleByName } = require('~/models/Role');
const { loadAgent } = require('~/models/Agent');
const { getMCPManager } = require('~/config');
const omitTitleOptions = new Set([
'stream',
'thinking',
'streaming',
'clientOptions',
'thinkingConfig',
'thinkingBudget',
'includeThoughts',
'maxOutputTokens',
'additionalModelRequestFields',
]);
const { logger, sendEvent } = require('~/config');
const { createRun } = require('./run');
/**
* @param {ServerRequest} req
@@ -80,8 +57,12 @@ const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deep
const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
// const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory');
// const { getFormattedMemories } = require('~/models/Memory');
// const { getCurrentDateTime } = require('~/utils');
function createTokenCounter(encoding) {
return function (message) {
return (message) => {
const countTokens = (text) => Tokenizer.getTokenCount(text, encoding);
return getTokenCountForMessage(message, countTokens);
};
@@ -142,8 +123,6 @@ class AgentClient extends BaseClient {
this.usage;
/** @type {Record<string, number>} */
this.indexTokenCountMap = {};
/** @type {(messages: BaseMessage[]) => Promise<void>} */
this.processMemory;
}
/**
@@ -158,10 +137,55 @@ class AgentClient extends BaseClient {
}
/**
* `AgentClient` is not opinionated about vision requests, so we don't do anything here
*
* Checks if the model is a vision model based on request attachments and sets the appropriate options:
* - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request.
* - Sets `this.isVisionModel` to `true` if vision request.
* - Deletes `this.modelOptions.stop` if vision request.
* @param {MongoFile[]} attachments
*/
checkVisionRequest() {}
checkVisionRequest(attachments) {
// if (!attachments) {
// return;
// }
// const availableModels = this.options.modelsConfig?.[this.options.endpoint];
// if (!availableModels) {
// return;
// }
// let visionRequestDetected = false;
// for (const file of attachments) {
// if (file?.type?.includes('image')) {
// visionRequestDetected = true;
// break;
// }
// }
// if (!visionRequestDetected) {
// return;
// }
// this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
// if (this.isVisionModel) {
// delete this.modelOptions.stop;
// return;
// }
// for (const model of availableModels) {
// if (!validateVisionModel({ model, availableModels })) {
// continue;
// }
// this.modelOptions.model = model;
// this.isVisionModel = true;
// delete this.modelOptions.stop;
// return;
// }
// if (!availableModels.includes(this.defaultVisionModel)) {
// return;
// }
// if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) {
// return;
// }
// this.modelOptions.model = this.defaultVisionModel;
// this.isVisionModel = true;
// delete this.modelOptions.stop;
}
getSaveOptions() {
// TODO:
@@ -245,6 +269,24 @@ class AgentClient extends BaseClient {
.filter(Boolean)
.join('\n')
.trim();
// this.systemMessage = getCurrentDateTime();
// const { withKeys, withoutKeys } = await getFormattedMemories({
// userId: this.options.req.user.id,
// });
// processMemory({
// userId: this.options.req.user.id,
// message: this.options.req.body.text,
// parentMessageId,
// memory: withKeys,
// thread_id: this.conversationId,
// }).catch((error) => {
// logger.error('Memory Agent failed to process memory', error);
// });
// this.systemMessage += '\n\n' + memoryInstructions;
// if (withoutKeys) {
// this.systemMessage += `\n\n# Existing memory about the user:\n${withoutKeys}`;
// }
if (this.options.attachments) {
const attachments = await this.options.attachments;
@@ -328,37 +370,6 @@ class AgentClient extends BaseClient {
systemContent = this.augmentedPrompt + systemContent;
}
// Inject MCP server instructions if available
const ephemeralAgent = this.options.req.body.ephemeralAgent;
let mcpServers = [];
// Check for ephemeral agent MCP servers
if (ephemeralAgent && ephemeralAgent.mcp && ephemeralAgent.mcp.length > 0) {
mcpServers = ephemeralAgent.mcp;
}
// Check for regular agent MCP tools
else if (this.options.agent && this.options.agent.tools) {
mcpServers = this.options.agent.tools
.filter(
(tool) =>
tool instanceof DynamicStructuredTool && tool.name.includes(Constants.mcp_delimiter),
)
.map((tool) => tool.name.split(Constants.mcp_delimiter).pop())
.filter(Boolean);
}
if (mcpServers.length > 0) {
try {
const mcpInstructions = getMCPManager().formatInstructionsForContext(mcpServers);
if (mcpInstructions) {
systemContent = [systemContent, mcpInstructions].filter(Boolean).join('\n\n');
logger.debug('[AgentClient] Injected MCP instructions for servers:', mcpServers);
}
} catch (error) {
logger.error('[AgentClient] Failed to inject MCP instructions:', error);
}
}
if (systemContent) {
this.options.agent.instructions = systemContent;
}
@@ -388,164 +399,9 @@ class AgentClient extends BaseClient {
opts.getReqData({ promptTokens });
}
const withoutKeys = await this.useMemory();
if (withoutKeys) {
systemContent += `${memoryInstructions}\n\n# Existing memory about the user:\n${withoutKeys}`;
}
if (systemContent) {
this.options.agent.instructions = systemContent;
}
return result;
}
/**
* @returns {Promise<string | undefined>}
*/
async useMemory() {
const user = this.options.req.user;
if (user.personalization?.memories === false) {
return;
}
const hasAccess = await checkAccess({
user,
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE],
getRoleByName,
});
if (!hasAccess) {
logger.debug(
`[api/server/controllers/agents/client.js #useMemory] User ${user.id} does not have USE permission for memories`,
);
return;
}
/** @type {TCustomConfig['memory']} */
const memoryConfig = this.options.req?.app?.locals?.memory;
if (!memoryConfig || memoryConfig.disabled === true) {
return;
}
/** @type {Agent} */
let prelimAgent;
const allowedProviders = new Set(
this.options.req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders,
);
try {
if (memoryConfig.agent?.id != null && memoryConfig.agent.id !== this.options.agent.id) {
prelimAgent = await loadAgent({
req: this.options.req,
agent_id: memoryConfig.agent.id,
endpoint: EModelEndpoint.agents,
});
} else if (
memoryConfig.agent?.id == null &&
memoryConfig.agent?.model != null &&
memoryConfig.agent?.provider != null
) {
prelimAgent = { id: Constants.EPHEMERAL_AGENT_ID, ...memoryConfig.agent };
}
} catch (error) {
logger.error(
'[api/server/controllers/agents/client.js #useMemory] Error loading agent for memory',
error,
);
}
const agent = await initializeAgent({
req: this.options.req,
res: this.options.res,
agent: prelimAgent,
allowedProviders,
endpointOption: {
endpoint:
prelimAgent.id !== Constants.EPHEMERAL_AGENT_ID
? EModelEndpoint.agents
: memoryConfig.agent?.provider,
},
});
if (!agent) {
logger.warn(
'[api/server/controllers/agents/client.js #useMemory] No agent found for memory',
memoryConfig,
);
return;
}
const llmConfig = Object.assign(
{
provider: agent.provider,
model: agent.model,
},
agent.model_parameters,
);
/** @type {import('@librechat/api').MemoryConfig} */
const config = {
validKeys: memoryConfig.validKeys,
instructions: agent.instructions,
llmConfig,
tokenLimit: memoryConfig.tokenLimit,
};
const userId = this.options.req.user.id + '';
const messageId = this.responseMessageId + '';
const conversationId = this.conversationId + '';
const [withoutKeys, processMemory] = await createMemoryProcessor({
userId,
config,
messageId,
conversationId,
memoryMethods: {
setMemory,
deleteMemory,
getFormattedMemories,
},
res: this.options.res,
});
this.processMemory = processMemory;
return withoutKeys;
}
/**
* @param {BaseMessage[]} messages
* @returns {Promise<void | (TAttachment | null)[]>}
*/
async runMemory(messages) {
try {
if (this.processMemory == null) {
return;
}
/** @type {TCustomConfig['memory']} */
const memoryConfig = this.options.req?.app?.locals?.memory;
const messageWindowSize = memoryConfig?.messageWindowSize ?? 5;
let messagesToProcess = [...messages];
if (messages.length > messageWindowSize) {
for (let i = messages.length - messageWindowSize; i >= 0; i--) {
const potentialWindow = messages.slice(i, i + messageWindowSize);
if (potentialWindow[0]?.role === 'user') {
messagesToProcess = [...potentialWindow];
break;
}
}
if (messagesToProcess.length === messages.length) {
messagesToProcess = [...messages.slice(-messageWindowSize)];
}
}
const bufferString = getBufferString(messagesToProcess);
const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`);
return await this.processMemory([bufferMessage]);
} catch (error) {
logger.error('Memory Agent failed to process memory', error);
}
}
/** @type {sendCompletion} */
async sendCompletion(payload, opts = {}) {
await this.chatCompletion({
@@ -688,13 +544,100 @@ class AgentClient extends BaseClient {
let config;
/** @type {ReturnType<createRun>} */
let run;
/** @type {Promise<(TAttachment | null)[] | undefined>} */
let memoryPromise;
try {
if (!abortController) {
abortController = new AbortController();
}
// if (this.options.headers) {
// opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers };
// }
// if (this.options.proxy) {
// opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
// }
// if (this.isVisionModel) {
// modelOptions.max_tokens = 4000;
// }
// /** @type {TAzureConfig | undefined} */
// const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
// if (
// (this.azure && this.isVisionModel && azureConfig) ||
// (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
// ) {
// const { modelGroupMap, groupMap } = azureConfig;
// const {
// azureOptions,
// baseURL,
// headers = {},
// serverless,
// } = mapModelToAzureConfig({
// modelName: modelOptions.model,
// modelGroupMap,
// groupMap,
// });
// opts.defaultHeaders = resolveHeaders(headers);
// this.langchainProxy = extractBaseURL(baseURL);
// this.apiKey = azureOptions.azureOpenAIApiKey;
// const groupName = modelGroupMap[modelOptions.model].group;
// this.options.addParams = azureConfig.groupMap[groupName].addParams;
// this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
// // Note: `forcePrompt` not re-assigned as only chat models are vision models
// this.azure = !serverless && azureOptions;
// this.azureEndpoint =
// !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
// }
// if (this.azure || this.options.azure) {
// /* Azure Bug, extremely short default `max_tokens` response */
// if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') {
// modelOptions.max_tokens = 4000;
// }
// /* Azure does not accept `model` in the body, so we need to remove it. */
// delete modelOptions.model;
// opts.baseURL = this.langchainProxy
// ? constructAzureURL({
// baseURL: this.langchainProxy,
// azureOptions: this.azure,
// })
// : this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
// opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
// opts.defaultHeaders = { ...opts.defaultHeaders, 'api-key': this.apiKey };
// }
// if (process.env.OPENAI_ORGANIZATION) {
// opts.organization = process.env.OPENAI_ORGANIZATION;
// }
// if (this.options.addParams && typeof this.options.addParams === 'object') {
// modelOptions = {
// ...modelOptions,
// ...this.options.addParams,
// };
// logger.debug('[api/server/controllers/agents/client.js #chatCompletion] added params', {
// addParams: this.options.addParams,
// modelOptions,
// });
// }
// if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
// this.options.dropParams.forEach((param) => {
// delete modelOptions[param];
// });
// logger.debug('[api/server/controllers/agents/client.js #chatCompletion] dropped params', {
// dropParams: this.options.dropParams,
// modelOptions,
// });
// }
/** @type {TCustomConfig['endpoints']['agents']} */
const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
@@ -704,9 +647,8 @@ class AgentClient extends BaseClient {
last_agent_index: this.agentConfigs?.size ?? 0,
user_id: this.user ?? this.options.req.user?.id,
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
user: this.options.req.user,
},
recursionLimit: agentsEConfig?.recursionLimit ?? 25,
recursionLimit: agentsEConfig?.recursionLimit,
signal: abortController.signal,
streamMode: 'values',
version: 'v2',
@@ -792,10 +734,6 @@ class AgentClient extends BaseClient {
messages = addCacheControl(messages);
}
if (i === 0) {
memoryPromise = this.runMemory(messages);
}
run = await createRun({
agent,
req: this.options.req,
@@ -831,24 +769,10 @@ class AgentClient extends BaseClient {
run.Graph.contentData = contentData;
}
try {
if (await hasCustomUserVars()) {
config.configurable.userMCPAuthMap = await getMCPAuthMap({
tools: agent.tools,
userId: this.options.req.user.id,
findPluginAuthsByKeys,
});
}
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`,
err,
);
}
const encoding = this.getEncoding();
await run.processStream({ messages }, config, {
keepContent: i !== 0,
tokenCounter: createTokenCounter(this.getEncoding()),
tokenCounter: createTokenCounter(encoding),
indexTokenCountMap: currentIndexCountMap,
maxContextTokens: agent.maxContextTokens,
callbacks: {
@@ -963,12 +887,6 @@ class AgentClient extends BaseClient {
});
try {
if (memoryPromise) {
const attachments = await memoryPromise;
if (attachments && attachments.length > 0) {
this.artifactPromises.push(...attachments);
}
}
await this.recordCollectedUsage({ context: 'message' });
} catch (err) {
logger.error(
@@ -977,12 +895,6 @@ class AgentClient extends BaseClient {
);
}
} catch (err) {
if (memoryPromise) {
const attachments = await memoryPromise;
if (attachments && attachments.length > 0) {
this.artifactPromises.push(...attachments);
}
}
logger.error(
'[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
err,
@@ -1011,26 +923,23 @@ class AgentClient extends BaseClient {
throw new Error('Run not initialized');
}
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
const { req, res, agent } = this.options;
const endpoint = agent.endpoint;
const endpoint = this.options.agent.endpoint;
const { req, res } = this.options;
/** @type {import('@librechat/agents').ClientOptions} */
let clientOptions = {
maxTokens: 75,
model: agent.model_parameters.model,
};
const { getOptions, overrideProvider, customEndpointConfig } =
await getProviderConfig(endpoint);
/** @type {TEndpoint | undefined} */
const endpointConfig = req.app.locals[endpoint] ?? customEndpointConfig;
let endpointConfig = req.app.locals[endpoint];
if (!endpointConfig) {
logger.warn(
'[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config',
);
try {
endpointConfig = await getCustomEndpointConfig(endpoint);
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
err,
);
}
}
if (
endpointConfig &&
endpointConfig.titleModel &&
@@ -1038,56 +947,30 @@ class AgentClient extends BaseClient {
) {
clientOptions.model = endpointConfig.titleModel;
}
const options = await getOptions({
req,
res,
optionsOnly: true,
overrideEndpoint: endpoint,
overrideModel: clientOptions.model,
endpointOption: { model_parameters: clientOptions },
});
let provider = options.provider ?? overrideProvider ?? agent.provider;
if (
endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName == null
clientOptions.model &&
this.options.agent.model_parameters.model !== clientOptions.model
) {
provider = Providers.OPENAI;
} else if (
endpoint === EModelEndpoint.azureOpenAI &&
options.llmConfig?.azureOpenAIApiInstanceName != null &&
provider !== Providers.AZURE
) {
provider = Providers.AZURE;
clientOptions =
(
await initOpenAI({
req,
res,
optionsOnly: true,
overrideModel: clientOptions.model,
overrideEndpoint: endpoint,
endpointOption: {
model_parameters: clientOptions,
},
})
)?.llmConfig ?? clientOptions;
}
/** @type {import('@librechat/agents').ClientOptions} */
clientOptions = { ...options.llmConfig };
if (options.configOptions) {
clientOptions.configuration = options.configOptions;
}
// Ensure maxTokens is set for non-o1 models
if (!/\b(o\d)\b/i.test(clientOptions.model) && !clientOptions.maxTokens) {
clientOptions.maxTokens = 75;
} else if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
delete clientOptions.maxTokens;
}
clientOptions = Object.assign(
Object.fromEntries(
Object.entries(clientOptions).filter(([key]) => !omitTitleOptions.has(key)),
),
);
if (provider === Providers.GOOGLE) {
clientOptions.json = true;
}
try {
const titleResult = await this.run.generateTitle({
provider,
inputText: text,
contentParts: this.contentParts,
clientOptions,
@@ -1105,10 +988,8 @@ class AgentClient extends BaseClient {
let input_tokens, output_tokens;
if (item.usage) {
input_tokens =
item.usage.prompt_tokens || item.usage.input_tokens || item.usage.inputTokens;
output_tokens =
item.usage.completion_tokens || item.usage.output_tokens || item.usage.outputTokens;
input_tokens = item.usage.input_tokens || item.usage.inputTokens;
output_tokens = item.usage.output_tokens || item.usage.outputTokens;
} else if (item.tokenUsage) {
input_tokens = item.tokenUsage.promptTokens;
output_tokens = item.tokenUsage.completionTokens;

View File

@@ -1,10 +1,10 @@
// errorHandler.js
const { logger } = require('@librechat/data-schemas');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { sendResponse } = require('~/server/middleware/error');
const { recordUsage } = require('~/server/services/Threads');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendResponse } = require('~/server/utils');
/**
* @typedef {Object} ErrorHandlerContext
@@ -75,7 +75,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);

View File

@@ -0,0 +1,106 @@
const { HttpsProxyAgent } = require('https-proxy-agent');
const { resolveHeaders } = require('librechat-data-provider');
const { createLLM } = require('~/app/clients/llm');
/**
* Initializes and returns a Language Learning Model (LLM) instance.
*
* @param {Object} options - Configuration options for the LLM.
* @param {string} options.model - The model identifier.
* @param {string} options.modelName - The specific name of the model.
* @param {number} options.temperature - The temperature setting for the model.
* @param {number} options.presence_penalty - The presence penalty for the model.
* @param {number} options.frequency_penalty - The frequency penalty for the model.
* @param {number} options.max_tokens - The maximum number of tokens for the model output.
* @param {boolean} options.streaming - Whether to use streaming for the model output.
* @param {Object} options.context - The context for the conversation.
* @param {number} options.tokenBuffer - The token buffer size.
* @param {number} options.initialMessageCount - The initial message count.
* @param {string} options.conversationId - The ID of the conversation.
* @param {string} options.user - The user identifier.
* @param {string} options.langchainProxy - The langchain proxy URL.
* @param {boolean} options.useOpenRouter - Whether to use OpenRouter.
* @param {Object} options.options - Additional options.
* @param {Object} options.options.headers - Custom headers for the request.
* @param {string} options.options.proxy - Proxy URL.
* @param {Object} options.options.req - The request object.
* @param {Object} options.options.res - The response object.
* @param {boolean} options.options.debug - Whether to enable debug mode.
* @param {string} options.apiKey - The API key for authentication.
* @param {Object} options.azure - Azure-specific configuration.
* @param {Object} options.abortController - The AbortController instance.
* @returns {Object} The initialized LLM instance.
*/
function initializeLLM(options) {
const {
model,
modelName,
temperature,
presence_penalty,
frequency_penalty,
max_tokens,
streaming,
user,
langchainProxy,
useOpenRouter,
options: { headers, proxy },
apiKey,
azure,
} = options;
const modelOptions = {
modelName: modelName || model,
temperature,
presence_penalty,
frequency_penalty,
user,
};
if (max_tokens) {
modelOptions.max_tokens = max_tokens;
}
const configOptions = {};
if (langchainProxy) {
configOptions.basePath = langchainProxy;
}
if (useOpenRouter) {
configOptions.basePath = 'https://openrouter.ai/api/v1';
configOptions.baseOptions = {
headers: {
'HTTP-Referer': 'https://librechat.ai',
'X-Title': 'LibreChat',
},
};
}
if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
configOptions.baseOptions = {
headers: resolveHeaders({
...headers,
...configOptions?.baseOptions?.headers,
}),
};
}
if (proxy) {
configOptions.httpAgent = new HttpsProxyAgent(proxy);
configOptions.httpsAgent = new HttpsProxyAgent(proxy);
}
const llm = createLLM({
modelOptions,
configOptions,
openAIApiKey: apiKey,
azure,
streaming,
});
return llm;
}
module.exports = {
initializeLLM,
};

View File

@@ -1,5 +1,3 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { Constants } = require('librechat-data-provider');
const {
handleAbortError,
@@ -7,19 +5,17 @@ const {
cleanupAbortController,
} = require('~/server/middleware');
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { sendMessage } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const AgentController = async (req, res, next, initializeClient, addTitle) => {
let {
text,
isRegenerate,
endpointOption,
conversationId,
isContinued = false,
editedContent = null,
parentMessageId = null,
overrideParentMessageId = null,
responseMessageId: editedResponseMessageId = null,
} = req.body;
let sender;
@@ -71,7 +67,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
handler();
}
} catch (e) {
logger.error('[AgentController] Error in cleanup handler', e);
// Ignore cleanup errors
}
}
}
@@ -159,7 +155,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
try {
res.removeListener('close', closeHandler);
} catch (e) {
logger.error('[AgentController] Error removing close listener', e);
// Ignore
}
});
@@ -167,15 +163,10 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
user: userId,
onStart,
getReqData,
isContinued,
isRegenerate,
editedContent,
conversationId,
parentMessageId,
abortController,
overrideParentMessageId,
isEdited: !!editedContent,
responseMessageId: editedResponseMessageId,
progressOptions: {
res,
},
@@ -215,7 +206,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
// Create a new response object with minimal copies
const finalResponse = { ...response };
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
title: conversation.title,
@@ -237,7 +228,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
// Save user message if needed
if (!client.skipSaveUserMessage) {
await saveMessage(req, userMessage, {
context: "api/server/controllers/agents/request.js - don't skip saving user message",
context: 'api/server/controllers/agents/request.js - don\'t skip saving user message',
});
}

View File

@@ -0,0 +1,94 @@
const { Run, Providers } = require('@librechat/agents');
const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider');
/**
* @typedef {import('@librechat/agents').t} t
* @typedef {import('@librechat/agents').StandardGraphConfig} StandardGraphConfig
* @typedef {import('@librechat/agents').StreamEventData} StreamEventData
* @typedef {import('@librechat/agents').EventHandler} EventHandler
* @typedef {import('@librechat/agents').GraphEvents} GraphEvents
* @typedef {import('@librechat/agents').LLMConfig} LLMConfig
* @typedef {import('@librechat/agents').IState} IState
*/
const customProviders = new Set([
Providers.XAI,
Providers.OLLAMA,
Providers.DEEPSEEK,
Providers.OPENROUTER,
]);
/**
* Creates a new Run instance with custom handlers and configuration.
*
* @param {Object} options - The options for creating the Run instance.
* @param {ServerRequest} [options.req] - The server request.
* @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated.
* @param {Agent} options.agent - The agent for this run.
* @param {AbortSignal} options.signal - The signal for this run.
* @param {Record<GraphEvents, EventHandler> | undefined} [options.customHandlers] - Custom event handlers.
* @param {boolean} [options.streaming=true] - Whether to use streaming.
* @param {boolean} [options.streamUsage=true] - Whether to stream usage information.
* @returns {Promise<Run<IState>>} A promise that resolves to a new Run instance.
*/
async function createRun({
runId,
agent,
signal,
customHandlers,
streaming = true,
streamUsage = true,
}) {
const provider = providerEndpointMap[agent.provider] ?? agent.provider;
/** @type {LLMConfig} */
const llmConfig = Object.assign(
{
provider,
streaming,
streamUsage,
},
agent.model_parameters,
);
/** Resolves issues with new OpenAI usage field */
if (
customProviders.has(agent.provider) ||
(agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
) {
llmConfig.streamUsage = false;
llmConfig.usage = true;
}
/** @type {'reasoning_content' | 'reasoning'} */
let reasoningKey;
if (
llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) ||
(agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
) {
reasoningKey = 'reasoning';
}
/** @type {StandardGraphConfig} */
const graphConfig = {
signal,
llmConfig,
reasoningKey,
tools: agent.tools,
instructions: agent.instructions,
additional_instructions: agent.additional_instructions,
// toolEnd: agent.end_after_tools,
};
// TEMPORARY FOR TESTING
if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
graphConfig.streamBuffer = 2000;
}
return Run.create({
runId,
graphConfig,
customHandlers,
});
}
module.exports = { createRun };

View File

@@ -1,16 +1,13 @@
const { z } = require('zod');
const fs = require('fs').promises;
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
const {
Tools,
Constants,
FileContext,
FileSources,
SystemRoles,
EToolResources,
actionDelimiter,
removeNullishValues,
} = require('librechat-data-provider');
const {
getAgent,
@@ -19,21 +16,19 @@ const {
deleteAgent,
getListAgents,
} = require('~/models/Agent');
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
const { filterFile } = require('~/server/services/Files/process');
const { updateAction, getActions } = require('~/models/Action');
const { getCachedTools } = require('~/server/services/Config');
const { updateAgentProjects } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
const { revertAgentVersion } = require('~/models/Agent');
const { deleteFileByFilter } = require('~/models/File');
const { revertAgentVersion } = require('~/models/Agent');
const { logger } = require('~/config');
const systemTools = {
[Tools.execute_code]: true,
[Tools.file_search]: true,
[Tools.web_search]: true,
};
/**
@@ -46,18 +41,13 @@ const systemTools = {
*/
const createAgentHandler = async (req, res) => {
try {
const validatedData = agentCreateSchema.parse(req.body);
const { tools = [], ...agentData } = removeNullishValues(validatedData);
const { tools = [], provider, name, description, instructions, model, ...agentData } = req.body;
const { id: userId } = req.user;
agentData.id = `agent_${nanoid()}`;
agentData.author = userId;
agentData.tools = [];
const availableTools = await getCachedTools({ includeGlobal: true });
for (const tool of tools) {
if (availableTools[tool]) {
if (req.app.locals.availableTools[tool]) {
agentData.tools.push(tool);
}
@@ -66,13 +56,19 @@ const createAgentHandler = async (req, res) => {
}
}
Object.assign(agentData, {
author: userId,
name,
description,
instructions,
provider,
model,
});
agentData.id = `agent_${nanoid()}`;
const agent = await createAgent(agentData);
res.status(201).json(agent);
} catch (error) {
if (error instanceof z.ZodError) {
logger.error('[/Agents] Validation error', error.errors);
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
}
logger.error('[/Agents] Error creating agent', error);
res.status(500).json({ error: error.message });
}
@@ -156,16 +152,14 @@ const getAgentHandler = async (req, res) => {
const updateAgentHandler = async (req, res) => {
try {
const id = req.params.id;
const validatedData = agentUpdateSchema.parse(req.body);
const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
const { projectIds, removeProjectIds, ...updateData } = req.body;
const isAdmin = req.user.role === SystemRoles.ADMIN;
const existingAgent = await getAgent({ id });
const isAuthor = existingAgent.author.toString() === req.user.id;
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
const isAuthor = existingAgent.author.toString() === req.user.id;
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) {
@@ -174,18 +168,12 @@ const updateAgentHandler = async (req, res) => {
});
}
/** @type {boolean} */
const isProjectUpdate = (projectIds?.length ?? 0) > 0 || (removeProjectIds?.length ?? 0) > 0;
let updatedAgent =
Object.keys(updateData).length > 0
? await updateAgent({ id }, updateData, {
updatingUserId: req.user.id,
skipVersioning: isProjectUpdate,
})
? await updateAgent({ id }, updateData, { updatingUserId: req.user.id })
: existingAgent;
if (isProjectUpdate) {
if (projectIds || removeProjectIds) {
updatedAgent = await updateAgentProjects({
user: req.user,
agentId: id,
@@ -204,11 +192,6 @@ const updateAgentHandler = async (req, res) => {
return res.json(updatedAgent);
} catch (error) {
if (error instanceof z.ZodError) {
logger.error('[/Agents/:id] Validation error', error.errors);
return res.status(400).json({ error: 'Invalid request data', details: error.errors });
}
logger.error('[/Agents/:id] Error updating Agent', error);
if (error.statusCode === 409) {
@@ -251,8 +234,6 @@ const duplicateAgentHandler = async (req, res) => {
createdAt: _createdAt,
updatedAt: _updatedAt,
tool_resources: _tool_resources = {},
versions: _versions,
__v: _v,
...cloneData
} = agent;
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
@@ -391,45 +372,21 @@ const uploadAgentAvatarHandler = async (req, res) => {
return res.status(400).json({ message: 'Agent ID is required' });
}
const isAdmin = req.user.role === SystemRoles.ADMIN;
const existingAgent = await getAgent({ id: agent_id });
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
const isAuthor = existingAgent.author.toString() === req.user.id;
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) {
return res.status(403).json({
error: 'You do not have permission to modify this non-collaborative agent',
});
}
const buffer = await fs.readFile(req.file.path);
const fileStrategy = req.app.locals.fileStrategy;
const resizedBuffer = await resizeAvatar({
userId: req.user.id,
input: buffer,
const image = await uploadImageBuffer({
req,
context: FileContext.avatar,
metadata: { buffer },
});
const { processAvatar } = getStrategyFunctions(fileStrategy);
const avatarUrl = await processAvatar({
buffer: resizedBuffer,
userId: req.user.id,
manual: 'false',
agentId: agent_id,
});
const image = {
filepath: avatarUrl,
source: fileStrategy,
};
let _avatar = existingAgent.avatar;
let _avatar;
try {
const agent = await getAgent({ id: agent_id });
_avatar = agent.avatar;
} catch (error) {
logger.error('[/:agent_id/avatar] Error fetching agent', error);
_avatar = {};
}
if (_avatar && _avatar.source) {
const { deleteFile } = getStrategyFunctions(_avatar.source);
@@ -446,12 +403,12 @@ const uploadAgentAvatarHandler = async (req, res) => {
const data = {
avatar: {
filepath: image.filepath,
source: image.source,
source: req.app.locals.fileStrategy,
},
};
promises.push(
await updateAgent({ id: agent_id }, data, {
await updateAgent({ id: agent_id, author: req.user.id }, data, {
updatingUserId: req.user.id,
}),
);
@@ -466,7 +423,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
try {
await fs.unlink(req.file.path);
logger.debug('[/:agent_id/avatar] Temp. image upload file deleted');
} catch {
} catch (error) {
logger.debug('[/:agent_id/avatar] Temp. image upload file already deleted');
}
}

View File

@@ -1,659 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { agentSchema } = require('@librechat/data-schemas');
// Only mock the dependencies that are not database-related
jest.mock('~/server/services/Config', () => ({
getCachedTools: jest.fn().mockResolvedValue({
web_search: true,
execute_code: true,
file_search: true,
}),
}));
jest.mock('~/models/Project', () => ({
getProjectByName: jest.fn().mockResolvedValue(null),
}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(),
}));
jest.mock('~/server/services/Files/images/avatar', () => ({
resizeAvatar: jest.fn(),
}));
jest.mock('~/server/services/Files/S3/crud', () => ({
refreshS3Url: jest.fn(),
}));
jest.mock('~/server/services/Files/process', () => ({
filterFile: jest.fn(),
}));
jest.mock('~/models/Action', () => ({
updateAction: jest.fn(),
getActions: jest.fn().mockResolvedValue([]),
}));
jest.mock('~/models/File', () => ({
deleteFileByFilter: jest.fn(),
}));
const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');
/**
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
*/
let Agent;
describe('Agent Controllers - Mass Assignment Protection', () => {
let mongoServer;
let mockReq;
let mockRes;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
}, 20000);
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await Agent.deleteMany({});
// Reset all mocks
jest.clearAllMocks();
// Setup mock request and response objects
mockReq = {
user: {
id: new mongoose.Types.ObjectId().toString(),
role: 'USER',
},
body: {},
params: {},
app: {
locals: {
fileStrategy: 'local',
},
},
};
mockRes = {
status: jest.fn().mockReturnThis(),
json: jest.fn().mockReturnThis(),
};
});
describe('createAgentHandler', () => {
test('should create agent with allowed fields only', async () => {
const validData = {
name: 'Test Agent',
description: 'A test agent',
instructions: 'Be helpful',
provider: 'openai',
model: 'gpt-4',
tools: ['web_search'],
model_parameters: { temperature: 0.7 },
tool_resources: {
file_search: { file_ids: ['file1', 'file2'] },
},
};
mockReq.body = validData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
expect(mockRes.json).toHaveBeenCalled();
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.name).toBe('Test Agent');
expect(createdAgent.description).toBe('A test agent');
expect(createdAgent.provider).toBe('openai');
expect(createdAgent.model).toBe('gpt-4');
expect(createdAgent.author.toString()).toBe(mockReq.user.id);
expect(createdAgent.tools).toContain('web_search');
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb).toBeDefined();
expect(agentInDb.name).toBe('Test Agent');
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
});
test('should reject creation with unauthorized fields (mass assignment protection)', async () => {
const maliciousData = {
// Required fields
provider: 'openai',
model: 'gpt-4',
name: 'Malicious Agent',
// Unauthorized fields that should be stripped
author: new mongoose.Types.ObjectId().toString(), // Should not be able to set author
authorName: 'Hacker', // Should be stripped
isCollaborative: true, // Should be stripped on creation
versions: [], // Should be stripped
_id: new mongoose.Types.ObjectId(), // Should be stripped
id: 'custom_agent_id', // Should be overridden
createdAt: new Date('2020-01-01'), // Should be stripped
updatedAt: new Date('2020-01-01'), // Should be stripped
};
mockReq.body = maliciousData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify unauthorized fields were not set
expect(createdAgent.author.toString()).toBe(mockReq.user.id); // Should be the request user, not the malicious value
expect(createdAgent.authorName).toBeUndefined();
expect(createdAgent.isCollaborative).toBeFalsy();
expect(createdAgent.versions).toHaveLength(1); // Should have exactly 1 version from creation
expect(createdAgent.id).not.toBe('custom_agent_id'); // Should have generated ID
expect(createdAgent.id).toMatch(/^agent_/); // Should have proper prefix
// Verify timestamps are recent (not the malicious dates)
const createdTime = new Date(createdAgent.createdAt).getTime();
const now = Date.now();
expect(now - createdTime).toBeLessThan(5000); // Created within last 5 seconds
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.author.toString()).toBe(mockReq.user.id);
expect(agentInDb.authorName).toBeUndefined();
});
test('should validate required fields', async () => {
const invalidData = {
name: 'Missing Required Fields',
// Missing provider and model
};
mockReq.body = invalidData;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
details: expect.any(Array),
}),
);
// Verify nothing was created in database
const count = await Agent.countDocuments();
expect(count).toBe(0);
});
test('should handle tool_resources validation', async () => {
const dataWithInvalidToolResources = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Tool Resources',
tool_resources: {
// Valid resources
file_search: {
file_ids: ['file1', 'file2'],
vector_store_ids: ['vs1'],
},
execute_code: {
file_ids: ['file3'],
},
// Invalid resource (should be stripped by schema)
invalid_resource: {
file_ids: ['file4'],
},
},
};
mockReq.body = dataWithInvalidToolResources;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.tool_resources).toBeDefined();
expect(createdAgent.tool_resources.file_search).toBeDefined();
expect(createdAgent.tool_resources.execute_code).toBeDefined();
expect(createdAgent.tool_resources.invalid_resource).toBeUndefined(); // Should be stripped
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
});
test('should handle avatar validation', async () => {
const dataWithAvatar = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Avatar',
avatar: {
filepath: 'https://example.com/avatar.png',
source: 's3',
},
};
mockReq.body = dataWithAvatar;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
expect(createdAgent.avatar).toEqual({
filepath: 'https://example.com/avatar.png',
source: 's3',
});
});
test('should handle invalid avatar format', async () => {
const dataWithInvalidAvatar = {
provider: 'openai',
model: 'gpt-4',
name: 'Agent with Invalid Avatar',
avatar: 'just-a-string', // Invalid format
};
mockReq.body = dataWithInvalidAvatar;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
}),
);
});
});
describe('updateAgentHandler', () => {
let existingAgentId;
let existingAgentAuthorId;
beforeEach(async () => {
// Create an existing agent for update tests
existingAgentAuthorId = new mongoose.Types.ObjectId();
const agent = await Agent.create({
id: `agent_${uuidv4()}`,
name: 'Original Agent',
provider: 'openai',
model: 'gpt-3.5-turbo',
author: existingAgentAuthorId,
description: 'Original description',
isCollaborative: false,
versions: [
{
name: 'Original Agent',
provider: 'openai',
model: 'gpt-3.5-turbo',
description: 'Original description',
createdAt: new Date(),
updatedAt: new Date(),
},
],
});
existingAgentId = agent.id;
});
test('should update agent with allowed fields only', async () => {
mockReq.user.id = existingAgentAuthorId.toString(); // Set as author
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Updated Agent',
description: 'Updated description',
model: 'gpt-4',
isCollaborative: true, // This IS allowed in updates
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(400);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Updated Agent');
expect(updatedAgent.description).toBe('Updated description');
expect(updatedAgent.model).toBe('gpt-4');
expect(updatedAgent.isCollaborative).toBe(true);
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString());
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Updated Agent');
expect(agentInDb.isCollaborative).toBe(true);
});
test('should reject update with unauthorized fields (mass assignment protection)', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Updated Name',
// Unauthorized fields that should be stripped
author: new mongoose.Types.ObjectId().toString(), // Should not be able to change author
authorName: 'Hacker', // Should be stripped
id: 'different_agent_id', // Should be stripped
_id: new mongoose.Types.ObjectId(), // Should be stripped
versions: [], // Should be stripped
createdAt: new Date('2020-01-01'), // Should be stripped
updatedAt: new Date('2020-01-01'), // Should be stripped
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
// Verify unauthorized fields were not changed
expect(updatedAgent.author).toBe(existingAgentAuthorId.toString()); // Should not have changed
expect(updatedAgent.authorName).toBeUndefined();
expect(updatedAgent.id).toBe(existingAgentId); // Should not have changed
expect(updatedAgent.name).toBe('Updated Name'); // Only this should have changed
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.author.toString()).toBe(existingAgentAuthorId.toString());
expect(agentInDb.id).toBe(existingAgentId);
});
test('should reject update from non-author when not collaborative', async () => {
const differentUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = differentUserId; // Different user
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Unauthorized Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalledWith({
error: 'You do not have permission to modify this non-collaborative agent',
});
// Verify agent was not modified in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Original Agent');
});
test('should allow update from non-author when collaborative', async () => {
// First make the agent collaborative
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
const differentUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = differentUserId; // Different user
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Collaborative Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Collaborative Update');
// Author field should be removed for non-author
expect(updatedAgent.author).toBeUndefined();
// Verify in database
const agentInDb = await Agent.findOne({ id: existingAgentId });
expect(agentInDb.name).toBe('Collaborative Update');
});
test('should allow admin to update any agent', async () => {
const adminUserId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = adminUserId;
mockReq.user.role = 'ADMIN'; // Set as admin
mockReq.params.id = existingAgentId;
mockReq.body = {
name: 'Admin Update',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).not.toHaveBeenCalledWith(403);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.name).toBe('Admin Update');
});
test('should handle projectIds updates', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
const projectId1 = new mongoose.Types.ObjectId().toString();
const projectId2 = new mongoose.Types.ObjectId().toString();
mockReq.body = {
projectIds: [projectId1, projectId2],
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent).toBeDefined();
// Note: updateAgentProjects requires more setup, so we just verify the handler doesn't crash
});
test('should validate tool_resources in updates', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
tool_resources: {
ocr: {
file_ids: ['ocr1', 'ocr2'],
},
execute_code: {
file_ids: ['img1'],
},
// Invalid tool resource
invalid_tool: {
file_ids: ['invalid'],
},
},
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.json).toHaveBeenCalled();
const updatedAgent = mockRes.json.mock.calls[0][0];
expect(updatedAgent.tool_resources).toBeDefined();
expect(updatedAgent.tool_resources.ocr).toBeDefined();
expect(updatedAgent.tool_resources.execute_code).toBeDefined();
expect(updatedAgent.tool_resources.invalid_tool).toBeUndefined();
});
test('should return 404 for non-existent agent', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = `agent_${uuidv4()}`; // Non-existent ID
mockReq.body = {
name: 'Update Non-existent',
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(404);
expect(mockRes.json).toHaveBeenCalledWith({ error: 'Agent not found' });
});
test('should handle validation errors properly', async () => {
mockReq.user.id = existingAgentAuthorId.toString();
mockReq.params.id = existingAgentId;
mockReq.body = {
model_parameters: 'invalid-not-an-object', // Should be an object
};
await updateAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(400);
expect(mockRes.json).toHaveBeenCalledWith(
expect.objectContaining({
error: 'Invalid request data',
details: expect.any(Array),
}),
);
});
});
describe('Mass Assignment Attack Scenarios', () => {
test('should prevent setting system fields during creation', async () => {
const systemFields = {
provider: 'openai',
model: 'gpt-4',
name: 'System Fields Test',
// System fields that should never be settable by users
__v: 99,
_id: new mongoose.Types.ObjectId(),
versions: [
{
name: 'Fake Version',
provider: 'fake',
model: 'fake-model',
},
],
};
mockReq.body = systemFields;
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify system fields were not affected
expect(createdAgent.__v).not.toBe(99);
expect(createdAgent.versions).toHaveLength(1); // Should only have the auto-created version
expect(createdAgent.versions[0].name).toBe('System Fields Test'); // From actual creation
expect(createdAgent.versions[0].provider).toBe('openai'); // From actual creation
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.__v).not.toBe(99);
});
test('should prevent privilege escalation through isCollaborative', async () => {
// Create a non-collaborative agent
const authorId = new mongoose.Types.ObjectId();
const agent = await Agent.create({
id: `agent_${uuidv4()}`,
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
author: authorId,
isCollaborative: false,
versions: [
{
name: 'Private Agent',
provider: 'openai',
model: 'gpt-4',
createdAt: new Date(),
updatedAt: new Date(),
},
],
});
// Try to make it collaborative as a different user
const attackerId = new mongoose.Types.ObjectId().toString();
mockReq.user.id = attackerId;
mockReq.params.id = agent.id;
mockReq.body = {
isCollaborative: true, // Trying to escalate privileges
};
await updateAgentHandler(mockReq, mockRes);
// Should be rejected
expect(mockRes.status).toHaveBeenCalledWith(403);
// Verify in database that it's still not collaborative
const agentInDb = await Agent.findOne({ id: agent.id });
expect(agentInDb.isCollaborative).toBe(false);
});
test('should prevent author hijacking', async () => {
const originalAuthorId = new mongoose.Types.ObjectId();
const attackerId = new mongoose.Types.ObjectId();
// Admin creates an agent
mockReq.user.id = originalAuthorId.toString();
mockReq.user.role = 'ADMIN';
mockReq.body = {
provider: 'openai',
model: 'gpt-4',
name: 'Admin Agent',
author: attackerId.toString(), // Trying to set different author
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Author should be the actual user, not the attempted value
expect(createdAgent.author.toString()).toBe(originalAuthorId.toString());
expect(createdAgent.author.toString()).not.toBe(attackerId.toString());
// Verify in database
const agentInDb = await Agent.findOne({ id: createdAgent.id });
expect(agentInDb.author.toString()).toBe(originalAuthorId.toString());
});
test('should strip unknown fields to prevent future vulnerabilities', async () => {
mockReq.body = {
provider: 'openai',
model: 'gpt-4',
name: 'Future Proof Test',
// Unknown fields that might be added in future
superAdminAccess: true,
bypassAllChecks: true,
internalFlag: 'secret',
futureFeature: 'exploit',
};
await createAgentHandler(mockReq, mockRes);
expect(mockRes.status).toHaveBeenCalledWith(201);
const createdAgent = mockRes.json.mock.calls[0][0];
// Verify unknown fields were stripped
expect(createdAgent.superAdminAccess).toBeUndefined();
expect(createdAgent.bypassAllChecks).toBeUndefined();
expect(createdAgent.internalFlag).toBeUndefined();
expect(createdAgent.futureFeature).toBeUndefined();
// Also check in database
const agentInDb = await Agent.findOne({ id: createdAgent.id }).lean();
expect(agentInDb.superAdminAccess).toBeUndefined();
expect(agentInDb.bypassAllChecks).toBeUndefined();
expect(agentInDb.internalFlag).toBeUndefined();
expect(agentInDb.futureFeature).toBeUndefined();
});
});
});

View File

@@ -1,7 +1,4 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Time,
Constants,
@@ -22,20 +19,20 @@ const {
addThreadMetadata,
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { createRunBody } = require('~/server/services/createRunBody');
const { sendResponse } = require('~/server/middleware/error');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
/**
* @route POST /
@@ -474,7 +471,7 @@ const chatV1 = async (req, res) => {
await Promise.all(promises);
const sendInitialResponse = () => {
sendEvent(res, {
sendMessage(res, {
sync: true,
conversationId,
// messages: previousMessages,
@@ -590,7 +587,7 @@ const chatV1 = async (req, res) => {
iconURL: endpointOption.iconURL,
};
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
requestMessage: {

View File

@@ -1,7 +1,4 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Time,
Constants,
@@ -25,14 +22,15 @@ const { createErrorHandler } = require('~/server/controllers/assistants/errors')
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { sendMessage, sleep, countTokens } = require('~/server/utils');
const { createRunBody } = require('~/server/services/createRunBody');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
/**
* @route POST /
@@ -311,7 +309,7 @@ const chatV2 = async (req, res) => {
await Promise.all(promises);
const sendInitialResponse = () => {
sendEvent(res, {
sendMessage(res, {
sync: true,
conversationId,
// messages: previousMessages,
@@ -434,7 +432,7 @@ const chatV2 = async (req, res) => {
iconURL: endpointOption.iconURL,
};
sendEvent(res, {
sendMessage(res, {
final: true,
conversation,
requestMessage: {

View File

@@ -1,10 +1,10 @@
// errorHandler.js
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
const { sendResponse } = require('~/server/middleware/error');
const { getConvo } = require('~/models/Conversation');
const { sendResponse } = require('~/server/utils');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { getConvo } = require('~/models/Conversation');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
/**
* @typedef {Object} ErrorHandlerContext
@@ -78,7 +78,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);

View File

@@ -1,5 +1,4 @@
const fs = require('fs').promises;
const { logger } = require('@librechat/data-schemas');
const { FileContext } = require('librechat-data-provider');
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
@@ -7,9 +6,9 @@ const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { deleteAssistantActions } = require('~/server/services/ActionService');
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
const { getOpenAIClient, fetchAssistants } = require('./helpers');
const { getCachedTools } = require('~/server/services/Config');
const { manifestToolMap } = require('~/app/clients/tools');
const { deleteFileByFilter } = require('~/models/File');
const { logger } = require('~/config');
/**
* Create an assistant.
@@ -31,20 +30,21 @@ const createAssistant = async (req, res) => {
delete assistantData.conversation_starters;
delete assistantData.append_current_datetime;
const toolDefinitions = await getCachedTools({ includeGlobal: true });
assistantData.tools = tools
.map((tool) => {
if (typeof tool !== 'string') {
return tool;
}
const toolDefinitions = req.app.locals.availableTools;
const toolDef = toolDefinitions[tool];
if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
return Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
.map(([_, val]) => val);
return (
Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
// eslint-disable-next-line no-unused-vars
.map(([_, val]) => val)
);
}
return toolDef;
@@ -135,21 +135,21 @@ const patchAssistant = async (req, res) => {
append_current_datetime,
...updateData
} = req.body;
const toolDefinitions = await getCachedTools({ includeGlobal: true });
updateData.tools = (updateData.tools ?? [])
.map((tool) => {
if (typeof tool !== 'string') {
return tool;
}
const toolDefinitions = req.app.locals.availableTools;
const toolDef = toolDefinitions[tool];
if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) {
return Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
.map(([_, val]) => val);
return (
Object.entries(toolDefinitions)
.filter(([key]) => key.startsWith(`${tool}_`))
// eslint-disable-next-line no-unused-vars
.map(([_, val]) => val)
);
}
return toolDef;

Some files were not shown because too many files have changed in this diff Show More