Compare commits

..

1 Commits

Author SHA1 Message Date
Danny Avila
d69465ea3d refactor(useHandleKeyUp): enter from beginning only 2024-07-11 13:16:42 -04:00
360 changed files with 6535 additions and 16246 deletions

View File

@@ -87,6 +87,7 @@ ANTHROPIC_API_KEY=user_provided
# Azure #
#============#
# Note: these variables are DEPRECATED
# Use the `librechat.yaml` configuration for `azureOpenAI` instead
# You may also continue to use them if you opt out of using the `librechat.yaml` configuration
@@ -116,7 +117,7 @@ BINGAI_TOKEN=user_provided
GOOGLE_KEY=user_provided
# GOOGLE_REVERSE_PROXY=
# Gemini API (AI Studio)
# Gemini API
# GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
# Vertex AI
@@ -124,30 +125,26 @@ GOOGLE_KEY=user_provided
# GOOGLE_TITLE_MODEL=gemini-pro
# Google Safety Settings
# NOTE: These settings apply to both Vertex AI and Gemini API (AI Studio)
# Google Gemini Safety Settings
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
# To use this restricted HarmBlockThreshold setting, you will need to either:
#
# For Vertex AI:
# To use the BLOCK_NONE setting, you need either:
# (a) Access through an allowlist via your Google account team, or
# (b) Switch to monthly invoiced billing: https://cloud.google.com/billing/docs/how-to/invoiced-billing
#
# For Gemini API (AI Studio):
# BLOCK_NONE is available by default, no special account requirements.
#
# Available options: BLOCK_NONE, BLOCK_ONLY_HIGH, BLOCK_MEDIUM_AND_ABOVE, BLOCK_LOW_AND_ABOVE
# (a) Get access through an allowlist via your Google account team
# (b) Switch your account type to monthly invoiced billing following this instruction:
# https://cloud.google.com/billing/docs/how-to/invoiced-billing
#
# GOOGLE_SAFETY_SEXUALLY_EXPLICIT=BLOCK_ONLY_HIGH
# GOOGLE_SAFETY_HATE_SPEECH=BLOCK_ONLY_HIGH
# GOOGLE_SAFETY_HARASSMENT=BLOCK_ONLY_HIGH
# GOOGLE_SAFETY_DANGEROUS_CONTENT=BLOCK_ONLY_HIGH
#============#
# OpenAI #
#============#
OPENAI_API_KEY=user_provided
# OPENAI_MODELS=gpt-4o,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k
# OPENAI_MODELS=gpt-4o,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k
DEBUG_OPENAI=false
@@ -169,7 +166,7 @@ DEBUG_OPENAI=false
ASSISTANTS_API_KEY=user_provided
# ASSISTANTS_BASE_URL=
# ASSISTANTS_MODELS=gpt-4o,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview
# ASSISTANTS_MODELS=gpt-4o,gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview
#==========================#
# Azure Assistants API #
@@ -191,7 +188,7 @@ ASSISTANTS_API_KEY=user_provided
# Plugins #
#============#
# PLUGIN_MODELS=gpt-4o,gpt-4o-mini,gpt-4,gpt-4-turbo-preview,gpt-4-0125-preview,gpt-4-1106-preview,gpt-4-0613,gpt-3.5-turbo,gpt-3.5-turbo-0125,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613
# PLUGIN_MODELS=gpt-4o,gpt-4,gpt-4-turbo-preview,gpt-4-0125-preview,gpt-4-1106-preview,gpt-4-0613,gpt-3.5-turbo,gpt-3.5-turbo-0125,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613
DEBUG_PLUGINS=true
@@ -264,6 +261,7 @@ MEILI_NO_ANALYTICS=true
MEILI_HOST=http://0.0.0.0:7700
MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt
#==================================================#
# Speech to Text & Text to Speech #
#==================================================#
@@ -271,16 +269,6 @@ MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt
STT_API_KEY=
TTS_API_KEY=
#==================================================#
# RAG #
#==================================================#
# More info: https://www.librechat.ai/docs/configuration/rag_api
# RAG_OPENAI_BASEURL=
# RAG_OPENAI_API_KEY=
# EMBEDDINGS_PROVIDER=openai
# EMBEDDINGS_MODEL=text-embedding-3-small
#===================================================#
# User System #
#===================================================#
@@ -386,8 +374,6 @@ LDAP_BIND_CREDENTIALS=
LDAP_USER_SEARCH_BASE=
LDAP_SEARCH_FILTER=mail={{username}}
LDAP_CA_CERT_PATH=
# LDAP_TLS_REJECT_UNAUTHORIZED=
# LDAP_LOGIN_USES_USERNAME=true
# LDAP_ID=
# LDAP_USERNAME=
# LDAP_FULL_NAME=
@@ -425,18 +411,6 @@ FIREBASE_APP_ID=
ALLOW_SHARED_LINKS=true
ALLOW_SHARED_LINKS_PUBLIC=true
#==============================#
# Static File Cache Control #
#==============================#
# Leave commented out to use default of 1 month for max-age and 1 week for s-maxage
# NODE_ENV must be set to production for these to take effect
# STATIC_CACHE_MAX_AGE=604800
# STATIC_CACHE_S_MAX_AGE=259200
# If you have another service in front of your LibreChat doing compression, disable express based compression here
# DISABLE_COMPRESSION=true
#===================================================#
# UI #
#===================================================#

View File

@@ -12,7 +12,6 @@ module.exports = {
'plugin:react-hooks/recommended',
'plugin:jest/recommended',
'prettier',
'plugin:jsx-a11y/recommended',
],
ignorePatterns: [
'client/dist/**/*',
@@ -33,7 +32,7 @@ module.exports = {
jsx: true,
},
},
plugins: ['react', 'react-hooks', '@typescript-eslint', 'import', 'jsx-a11y'],
plugins: ['react', 'react-hooks', '@typescript-eslint', 'import'],
rules: {
'react/react-in-jsx-scope': 'off',
'@typescript-eslint/ban-ts-comment': ['error', { 'ts-ignore': 'allow' }],
@@ -66,7 +65,6 @@ module.exports = {
'no-restricted-syntax': 'off',
'react/prop-types': ['off'],
'react/display-name': ['off'],
'no-nested-ternary': 'error',
'no-unused-vars': ['error', { varsIgnorePattern: '^_' }],
quotes: ['error', 'single'],
},
@@ -120,8 +118,6 @@ module.exports = {
],
rules: {
'@typescript-eslint/no-explicit-any': 'error',
'@typescript-eslint/no-unnecessary-condition': 'warn',
'@typescript-eslint/strict-boolean-expressions': 'warn',
},
},
{

View File

@@ -1,17 +0,0 @@
name: Lint for accessibility issues
on:
pull_request:
paths:
- 'client/src/**'
jobs:
axe-linter:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dequelabs/axe-linter-action@v1
with:
api_key: ${{ secrets.AXE_LINTER_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -1,41 +0,0 @@
name: Update Test Server
on:
workflow_run:
workflows: ["Docker Dev Images Build"]
types:
- completed
workflow_dispatch:
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.repository == 'danny-avila/LibreChat' &&
(github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success')
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install SSH Key
uses: shimataro/ssh-key-action@v2
with:
key: ${{ secrets.DO_SSH_PRIVATE_KEY }}
known_hosts: ${{ secrets.DO_KNOWN_HOSTS }}
- name: Run update script on DigitalOcean Droplet
env:
DO_HOST: ${{ secrets.DO_HOST }}
DO_USER: ${{ secrets.DO_USER }}
run: |
ssh -o StrictHostKeyChecking=no ${DO_USER}@${DO_HOST} << EOF
sudo -i -u danny bash << EEOF
cd ~/LibreChat && \
git fetch origin main && \
npm run update:deployed && \
git checkout do-deploy && \
git rebase main && \
npm run start:deployed && \
echo "Update completed. Application should be running now."
EEOF
EOF

View File

@@ -1,35 +0,0 @@
name: Build Helm Charts on Tag
# The workflow is triggered when a tag is pushed
on:
push:
tags:
- "*"
jobs:
release:
permissions:
contents: write
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Configure Git
run: |
git config user.name "$GITHUB_ACTOR"
git config user.email "$GITHUB_ACTOR@users.noreply.github.com"
- name: Install Helm
uses: azure/setup-helm@v4
env:
GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}"
- name: Run chart-releaser
uses: helm/chart-releaser-action@v1.6.0
with:
charts_dir: helmchart
env:
CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}"

View File

@@ -1,4 +1,4 @@
# v0.7.4
# v0.7.3
# Base node image
FROM node:20-alpine AS node

View File

@@ -1,4 +1,4 @@
# v0.7.4
# v0.7.3
# Build API, Client and Data Provider
FROM node:20-alpine AS base
@@ -38,6 +38,6 @@ ENV HOST=0.0.0.0
CMD ["node", "server/index.js"]
# Nginx setup
FROM nginx:1.27.0-alpine AS prod-stage
FROM nginx:1.21.1-alpine AS prod-stage
COPY ./client/nginx.conf /etc/nginx/conf.d/default.conf
CMD ["nginx", "-g", "daemon off;"]

View File

@@ -50,7 +50,7 @@
- 🔄 Edit, Resubmit, and Continue Messages with Conversation branching
- 🌿 Fork Messages & Conversations for Advanced Context control
- 💬 Multimodal Chat:
- Upload and analyze images with Claude 3, GPT-4 (including `gpt-4o` and `gpt-4o-mini`), and Gemini Vision 📸
- Upload and analyze images with Claude 3, GPT-4 (including `gpt-4o`), and Gemini Vision 📸
- Chat with Files using Custom Endpoints, OpenAI, Azure, Anthropic, & Google. 🗃️
- Advanced Agents with Files, Code Interpreter, Tools, and API Actions 🔦
- Available through the [OpenAI Assistants API](https://platform.openai.com/docs/assistants/overview) 🌤️

View File

@@ -2,10 +2,8 @@ const Anthropic = require('@anthropic-ai/sdk');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
Constants,
EModelEndpoint,
anthropicSettings,
getResponseSender,
EModelEndpoint,
validateVisionModel,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
@@ -18,7 +16,6 @@ const {
} = require('./prompts');
const spendTokens = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
const { sleep } = require('~/server/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
@@ -32,8 +29,6 @@ function delayBeforeRetry(attempts, baseDelay = 1000) {
return new Promise((resolve) => setTimeout(resolve, baseDelay * attempts));
}
const { legacy } = anthropicSettings;
class AnthropicClient extends BaseClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
@@ -66,20 +61,15 @@ class AnthropicClient extends BaseClient {
const modelOptions = this.options.modelOptions || {};
this.modelOptions = {
...modelOptions,
model: modelOptions.model || anthropicSettings.model.default,
// set some good defaults (check for undefined in some cases because they may be 0)
model: modelOptions.model || 'claude-1',
temperature: typeof modelOptions.temperature === 'undefined' ? 1 : modelOptions.temperature, // 0 - 1, 1 is default
topP: typeof modelOptions.topP === 'undefined' ? 0.7 : modelOptions.topP, // 0 - 1, default: 0.7
topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
stop: modelOptions.stop, // no stop method for now
};
this.isClaude3 = this.modelOptions.model.includes('claude-3');
this.isLegacyOutput = !this.modelOptions.model.includes('claude-3-5-sonnet');
if (
this.isLegacyOutput &&
this.modelOptions.maxOutputTokens &&
this.modelOptions.maxOutputTokens > legacy.maxOutputTokens.default
) {
this.modelOptions.maxOutputTokens = legacy.maxOutputTokens.default;
}
this.useMessages = this.isClaude3 || !!this.options.attachments;
this.defaultVisionModel = this.options.visionModel ?? 'claude-3-sonnet-20240229';
@@ -129,11 +119,10 @@ class AnthropicClient extends BaseClient {
/**
* Get the initialized Anthropic client.
* @param {Partial<Anthropic.ClientOptions>} requestOptions - The options for the client.
* @returns {Anthropic} The Anthropic client instance.
*/
getClient(requestOptions) {
/** @type {Anthropic.ClientOptions} */
getClient() {
/** @type {Anthropic.default.RequestOptions} */
const options = {
fetch: this.fetch,
apiKey: this.apiKey,
@@ -147,12 +136,6 @@ class AnthropicClient extends BaseClient {
options.baseURL = this.options.reverseProxyUrl;
}
if (requestOptions?.model && requestOptions.model.includes('claude-3-5-sonnet')) {
options.defaultHeaders = {
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
};
}
return new Anthropic(options);
}
@@ -573,6 +556,8 @@ class AnthropicClient extends BaseClient {
}
logger.debug('modelOptions', { modelOptions });
const client = this.getClient();
const metadata = {
user_id: this.user,
};
@@ -600,7 +585,7 @@ class AnthropicClient extends BaseClient {
if (this.useMessages) {
requestOptions.messages = payload;
requestOptions.max_tokens = maxOutputTokens || legacy.maxOutputTokens.default;
requestOptions.max_tokens = maxOutputTokens || 1500;
} else {
requestOptions.prompt = payload;
requestOptions.max_tokens_to_sample = maxOutputTokens || 1500;
@@ -620,14 +605,12 @@ class AnthropicClient extends BaseClient {
};
const maxRetries = 3;
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
async function processResponse() {
let attempts = 0;
while (attempts < maxRetries) {
let response;
try {
const client = this.getClient(requestOptions);
response = await this.createResponse(client, requestOptions);
signal.addEventListener('abort', () => {
@@ -644,8 +627,6 @@ class AnthropicClient extends BaseClient {
} else if (completion.completion) {
handleChunk(completion.completion);
}
await sleep(streamRate);
}
// Successful processing, exit loop
@@ -756,11 +737,7 @@ class AnthropicClient extends BaseClient {
};
try {
const response = await this.createResponse(
this.getClient(requestOptions),
requestOptions,
true,
);
const response = await this.createResponse(this.getClient(), requestOptions, true);
let promptTokens = response?.usage?.input_tokens;
let completionTokens = response?.usage?.output_tokens;
if (!promptTokens) {

View File

@@ -1,11 +1,10 @@
const crypto = require('crypto');
const fetch = require('node-fetch');
const { supportsBalanceCheck, Constants, CacheKeys, Time } = require('librechat-data-provider');
const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File');
const { getLogStores } = require('~/cache');
const TextStream = require('./TextStream');
const { logger } = require('~/config');
@@ -540,19 +539,7 @@ class BaseClient {
const completionTokens = this.getTokenCount(completion);
await this.recordTokenUsage({ promptTokens, completionTokens });
}
if (this.userMessagePromise) {
await this.userMessagePromise;
}
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
delete responseMessage.tokenCount;
return responseMessage;
}
@@ -611,41 +598,28 @@ class BaseClient {
* @param {string | null} user
*/
async saveMessageToDatabase(message, endpointOptions, user = null) {
if (this.user && user !== this.user) {
throw new Error('User mismatch.');
}
const savedMessage = await saveMessage(
this.options.req,
{
...message,
endpoint: this.options.endpoint,
unfinished: false,
user,
},
{ context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveMessage' },
);
const savedMessage = await saveMessage({
...message,
endpoint: this.options.endpoint,
unfinished: false,
user,
});
if (this.skipSaveConvo) {
return { message: savedMessage };
}
const conversation = await saveConvo(
this.options.req,
{
conversationId: message.conversationId,
endpoint: this.options.endpoint,
endpointType: this.options.endpointType,
...endpointOptions,
},
{ context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo' },
);
const conversation = await saveConvo(user, {
conversationId: message.conversationId,
endpoint: this.options.endpoint,
endpointType: this.options.endpointType,
...endpointOptions,
});
return { message: savedMessage, conversation };
}
async updateMessageInDatabase(message) {
await updateMessage(this.options.req, message);
await updateMessage(message);
}
/**

View File

@@ -13,12 +13,10 @@ const {
endpointSettings,
EModelEndpoint,
VisionModes,
Constants,
AuthKeys,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { getModelMaxTokens } = require('~/utils');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
const {
formatMessage,
@@ -622,9 +620,8 @@ class GoogleClient extends BaseClient {
}
async getCompletion(_payload, options = {}) {
const { parameters, instances } = _payload;
const { onProgress, abortController } = options;
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
const { parameters, instances } = _payload;
const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
let examples;
@@ -677,6 +674,7 @@ class GoogleClient extends BaseClient {
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
/** @type {GenerativeModel} */
const client = model;
const requestOptions = {
contents: _payload,
@@ -692,7 +690,8 @@ class GoogleClient extends BaseClient {
};
}
requestOptions.safetySettings = _payload.safetySettings;
const safetySettings = _payload.safetySettings;
requestOptions.safetySettings = safetySettings;
const delay = modelName.includes('flash') ? 8 : 14;
const result = await client.generateContentStream(requestOptions);
@@ -702,28 +701,21 @@ class GoogleClient extends BaseClient {
delay,
});
reply += chunkText;
await sleep(streamRate);
}
return reply;
}
const safetySettings = _payload.safetySettings;
const stream = await model.stream(messages, {
signal: abortController.signal,
timeout: 7000,
safetySettings: _payload.safetySettings,
safetySettings: safetySettings,
});
let delay = this.options.streamRate || 8;
if (!this.options.streamRate) {
if (this.isGenerativeModel) {
delay = 12;
}
if (modelName.includes('flash')) {
delay = 5;
}
let delay = this.isGenerativeModel ? 12 : 8;
if (modelName.includes('flash')) {
delay = 5;
}
for await (const chunk of stream) {
const chunkText = chunk?.content ?? chunk;
await this.generateTextStream(chunkText, onProgress, {
@@ -868,36 +860,38 @@ class GoogleClient extends BaseClient {
}
async sendCompletion(payload, opts = {}) {
payload.safetySettings = this.getSafetySettings();
const modelName = payload.parameters?.model;
if (modelName && modelName.toLowerCase().includes('gemini')) {
const safetySettings = [
{
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
threshold:
process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_HATE_SPEECH',
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_HARASSMENT',
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
threshold:
process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
];
payload.safetySettings = safetySettings;
}
let reply = '';
reply = await this.getCompletion(payload, opts);
return reply.trim();
}
getSafetySettings() {
return [
{
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
threshold:
process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_HATE_SPEECH',
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_HARASSMENT',
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
{
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
threshold:
process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
},
];
}
/* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
if (tokenizersCache[encoding]) {

View File

@@ -1,9 +1,7 @@
const { z } = require('zod');
const axios = require('axios');
const { Ollama } = require('ollama');
const { Constants } = require('librechat-data-provider');
const { deriveBaseURL } = require('~/utils');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
const ollamaPayloadSchema = z.object({
@@ -42,7 +40,6 @@ const getValidBase64 = (imageUrl) => {
class OllamaClient {
constructor(options = {}) {
const host = deriveBaseURL(options.baseURL ?? 'http://localhost:11434');
this.streamRate = options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
/** @type {Ollama} */
this.client = new Ollama({ host });
}
@@ -139,8 +136,6 @@ class OllamaClient {
stream.controller.abort();
break;
}
await sleep(this.streamRate);
}
}
// TODO: regular completion

View File

@@ -1182,10 +1182,8 @@ ${convo}
});
}
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
if (this.message_file_map && this.isOllama) {
const ollamaClient = new OllamaClient({ baseURL, streamRate });
const ollamaClient = new OllamaClient({ baseURL });
return await ollamaClient.chatCompletion({
payload: modelOptions,
onProgress,
@@ -1223,6 +1221,8 @@ ${convo}
}
});
const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
for await (const chunk of stream) {
const token = chunk.choices[0]?.delta?.content || '';
intermediateReply += token;
@@ -1232,7 +1232,9 @@ ${convo}
break;
}
await sleep(streamRate);
if (this.azure) {
await sleep(azureDelay);
}
}
if (!UnexpectedRoleError) {

View File

@@ -1,6 +1,5 @@
const OpenAIClient = require('./OpenAIClient');
const { CallbackManager } = require('langchain/callbacks');
const { CacheKeys, Time } = require('librechat-data-provider');
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
@@ -12,7 +11,6 @@ const { SelfReflectionTool } = require('./tools');
const { isEnabled } = require('~/server/utils');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
class PluginsClient extends OpenAIClient {
@@ -222,13 +220,6 @@ class PluginsClient extends OpenAIClient {
}
}
/**
*
* @param {TMessage} responseMessage
* @param {Partial<TMessage>} saveOptions
* @param {string} user
* @returns
*/
async handleResponseMessage(responseMessage, saveOptions, user) {
const { output, errorMessage, ...result } = this.result;
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
@@ -248,15 +239,6 @@ class PluginsClient extends OpenAIClient {
}
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessage.messageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
delete responseMessage.tokenCount;
return { ...responseMessage, ...result };
}

View File

@@ -60,10 +60,10 @@ function addImages(intermediateSteps, responseMessage) {
if (!observation || !observation.includes('![')) {
return;
}
const observedImagePath = observation.match(/!\[[^(]*\]\([^)]*\)/g);
const observedImagePath = observation.match(/!\[.*\]\([^)]*\)/g);
if (observedImagePath && !responseMessage.text.includes(observedImagePath[0])) {
responseMessage.text += '\n' + observedImagePath[0];
logger.debug('[addImages] added image from intermediateSteps:', observedImagePath[0]);
responseMessage.text += '\n' + observation;
logger.debug('[addImages] added image from intermediateSteps:', observation);
}
});
}

View File

@@ -81,62 +81,4 @@ describe('addImages', () => {
addImages(intermediateSteps, responseMessage);
expect(responseMessage.text).toBe(`${originalText}\n${imageMarkdown}`);
});
it('should extract only image markdowns when there is text between them', () => {
const markdownWithTextBetweenImages = `
![image1](/images/image1.png)
Some text between images that should not be included.
![image2](/images/image2.png)
More text that should be ignored.
![image3](/images/image3.png)
`;
intermediateSteps.push({ observation: markdownWithTextBetweenImages });
addImages(intermediateSteps, responseMessage);
expect(responseMessage.text).toBe('\n![image1](/images/image1.png)');
});
it('should only return the first image when multiple images are present', () => {
const markdownWithMultipleImages = `
![image1](/images/image1.png)
![image2](/images/image2.png)
![image3](/images/image3.png)
`;
intermediateSteps.push({ observation: markdownWithMultipleImages });
addImages(intermediateSteps, responseMessage);
expect(responseMessage.text).toBe('\n![image1](/images/image1.png)');
});
it('should not include any text or metadata surrounding the image markdown', () => {
const markdownWithMetadata = `
Title: Test Document
Author: John Doe
![image1](/images/image1.png)
Some content after the image.
Vector values: [0.1, 0.2, 0.3]
`;
intermediateSteps.push({ observation: markdownWithMetadata });
addImages(intermediateSteps, responseMessage);
expect(responseMessage.text).toBe('\n![image1](/images/image1.png)');
});
it('should handle complex markdown with multiple images and only return the first one', () => {
const complexMarkdown = `
# Document Title
## Section 1
Here's some text with an embedded image:
![image1](/images/image1.png)
## Section 2
More text here...
![image2](/images/image2.png)
### Subsection
Even more content
![image3](/images/image3.png)
`;
intermediateSteps.push({ observation: complexMarkdown });
addImages(intermediateSteps, responseMessage);
expect(responseMessage.text).toBe('\n![image1](/images/image1.png)');
});
});

View File

@@ -1,6 +1,4 @@
const { anthropicSettings } = require('librechat-data-provider');
const AnthropicClient = require('~/app/clients/AnthropicClient');
const AnthropicClient = require('../AnthropicClient');
const HUMAN_PROMPT = '\n\nHuman:';
const AI_PROMPT = '\n\nAssistant:';
@@ -24,7 +22,7 @@ describe('AnthropicClient', () => {
const options = {
modelOptions: {
model,
temperature: anthropicSettings.temperature.default,
temperature: 0.7,
},
};
client = new AnthropicClient('test-api-key');
@@ -35,42 +33,7 @@ describe('AnthropicClient', () => {
it('should set the options correctly', () => {
expect(client.apiKey).toBe('test-api-key');
expect(client.modelOptions.model).toBe(model);
expect(client.modelOptions.temperature).toBe(anthropicSettings.temperature.default);
});
it('should set legacy maxOutputTokens for non-Claude-3 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-2',
maxOutputTokens: anthropicSettings.maxOutputTokens.default,
},
});
expect(client.modelOptions.maxOutputTokens).toBe(
anthropicSettings.legacy.maxOutputTokens.default,
);
});
it('should not set maxOutputTokens if not provided', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-3',
},
});
expect(client.modelOptions.maxOutputTokens).toBeUndefined();
});
it('should not set legacy maxOutputTokens for Claude-3 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-3-opus-20240229',
maxOutputTokens: anthropicSettings.legacy.maxOutputTokens.default,
},
});
expect(client.modelOptions.maxOutputTokens).toBe(
anthropicSettings.legacy.maxOutputTokens.default,
);
expect(client.modelOptions.temperature).toBe(0.7);
});
});
@@ -173,57 +136,4 @@ describe('AnthropicClient', () => {
expect(prompt).toContain('You are Claude-2');
});
});
describe('getClient', () => {
it('should set legacy maxOutputTokens for non-Claude-3 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-2',
maxOutputTokens: anthropicSettings.legacy.maxOutputTokens.default,
},
});
expect(client.modelOptions.maxOutputTokens).toBe(
anthropicSettings.legacy.maxOutputTokens.default,
);
});
it('should not set legacy maxOutputTokens for Claude-3 models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-3-opus-20240229',
maxOutputTokens: anthropicSettings.legacy.maxOutputTokens.default,
},
});
expect(client.modelOptions.maxOutputTokens).toBe(
anthropicSettings.legacy.maxOutputTokens.default,
);
});
it('should add beta header for claude-3-5-sonnet model', () => {
const client = new AnthropicClient('test-api-key');
const modelOptions = {
model: 'claude-3-5-sonnet-20240307',
};
client.setOptions({ modelOptions });
const anthropicClient = client.getClient(modelOptions);
expect(anthropicClient._options.defaultHeaders).toBeDefined();
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
'max-tokens-3-5-sonnet-2024-07-15',
);
});
it('should not add beta header for other models', () => {
const client = new AnthropicClient('test-api-key');
client.setOptions({
modelOptions: {
model: 'claude-2',
},
});
const anthropicClient = client.getClient();
expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta');
});
});
});

View File

@@ -1,7 +1,7 @@
const { Constants } = require('librechat-data-provider');
const { initializeFakeClient } = require('./FakeClient');
jest.mock('~/lib/db/connectDb');
jest.mock('../../../lib/db/connectDb');
jest.mock('~/models', () => ({
User: jest.fn(),
Key: jest.fn(),
@@ -631,32 +631,5 @@ describe('BaseClient', () => {
}),
);
});
test('userMessagePromise is awaited before saving response message', async () => {
// Mock the saveMessageToDatabase method
TestClient.saveMessageToDatabase = jest.fn().mockImplementation(() => {
return new Promise((resolve) => setTimeout(resolve, 100)); // Simulate a delay
});
// Send a message
const messagePromise = TestClient.sendMessage('Hello, world!');
// Wait a short time to ensure the user message save has started
await new Promise((resolve) => setTimeout(resolve, 50));
// Check that saveMessageToDatabase has been called once (for the user message)
expect(TestClient.saveMessageToDatabase).toHaveBeenCalledTimes(1);
// Wait for the message to be fully processed
await messagePromise;
// Check that saveMessageToDatabase has been called twice (once for user message, once for response)
expect(TestClient.saveMessageToDatabase).toHaveBeenCalledTimes(2);
// Check the order of calls
const calls = TestClient.saveMessageToDatabase.mock.calls;
expect(calls[0][0].isCreatedByUser).toBe(true); // First call should be for user message
expect(calls[1][0].isCreatedByUser).toBe(false); // Second call should be for response message
});
});
});

View File

@@ -38,12 +38,7 @@ const run = async () => {
"On the other hand, we denounce with righteous indignation and dislike men who are so beguiled and demoralized by the charms of pleasure of the moment, so blinded by desire, that they cannot foresee the pain and trouble that are bound to ensue; and equal blame belongs to those who fail in their duty through weakness of will, which is the same as saying through shrinking from toil and pain. These cases are perfectly simple and easy to distinguish. In a free hour, when our power of choice is untrammelled and when nothing prevents our being able to do what we like best, every pleasure is to be welcomed and every pain avoided. But in certain circumstances and owing to the claims of duty or the obligations of business it will frequently occur that pleasures have to be repudiated and annoyances accepted. The wise man therefore always holds in these matters to this principle of selection: he rejects pleasures to secure other greater pleasures, or else he endures pains to avoid worse pains."
`;
const model = 'gpt-3.5-turbo';
let maxContextTokens = 4095;
if (model === 'gpt-4') {
maxContextTokens = 8191;
} else if (model === 'gpt-4-32k') {
maxContextTokens = 32767;
}
const maxContextTokens = model === 'gpt-4' ? 8191 : model === 'gpt-4-32k' ? 32767 : 4095; // 1 less than maximum
const clientOptions = {
reverseProxyUrl: process.env.OPENAI_REVERSE_PROXY || null,
maxContextTokens,

View File

@@ -12,15 +12,9 @@ class GoogleSearchResults extends Tool {
this.envVarApiKey = 'GOOGLE_SEARCH_API_KEY';
this.envVarSearchEngineId = 'GOOGLE_CSE_ID';
this.override = fields.override ?? false;
this.apiKey = fields[this.envVarApiKey] ?? getEnvironmentVariable(this.envVarApiKey);
this.apiKey = fields.apiKey ?? getEnvironmentVariable(this.envVarApiKey);
this.searchEngineId =
fields[this.envVarSearchEngineId] ?? getEnvironmentVariable(this.envVarSearchEngineId);
if (!this.override && (!this.apiKey || !this.searchEngineId)) {
throw new Error(
`Missing ${this.envVarApiKey} or ${this.envVarSearchEngineId} environment variable.`,
);
}
fields.searchEngineId ?? getEnvironmentVariable(this.envVarSearchEngineId);
this.kwargs = fields?.kwargs ?? {};
this.name = 'google';

View File

@@ -12,7 +12,7 @@ class TavilySearchResults extends Tool {
this.envVar = 'TAVILY_API_KEY';
/* Used to initialize the Tool without necessary variables. */
this.override = fields.override ?? false;
this.apiKey = fields[this.envVar] ?? this.getApiKey();
this.apiKey = fields.apiKey ?? this.getApiKey();
this.kwargs = fields?.kwargs ?? {};
this.name = 'tavily_search_results_json';
@@ -82,9 +82,7 @@ class TavilySearchResults extends Tool {
const json = await response.json();
if (!response.ok) {
throw new Error(
`Request failed with status ${response.status}: ${json?.detail?.error || json?.error}`,
);
throw new Error(`Request failed with status ${response.status}: ${json.error}`);
}
return JSON.stringify(json);

View File

@@ -1,50 +0,0 @@
const GoogleSearch = require('../GoogleSearch');
jest.mock('node-fetch');
jest.mock('@langchain/core/utils/env');
describe('GoogleSearch', () => {
let originalEnv;
const mockApiKey = 'mock_api';
const mockSearchEngineId = 'mock_search_engine_id';
beforeAll(() => {
originalEnv = { ...process.env };
});
beforeEach(() => {
jest.resetModules();
process.env = {
...originalEnv,
GOOGLE_SEARCH_API_KEY: mockApiKey,
GOOGLE_CSE_ID: mockSearchEngineId,
};
});
afterEach(() => {
jest.clearAllMocks();
process.env = originalEnv;
});
it('should use mockApiKey and mockSearchEngineId when environment variables are not set', () => {
const instance = new GoogleSearch({
GOOGLE_SEARCH_API_KEY: mockApiKey,
GOOGLE_CSE_ID: mockSearchEngineId,
});
expect(instance.apiKey).toBe(mockApiKey);
expect(instance.searchEngineId).toBe(mockSearchEngineId);
});
it('should throw an error if GOOGLE_SEARCH_API_KEY or GOOGLE_CSE_ID is missing', () => {
delete process.env.GOOGLE_SEARCH_API_KEY;
expect(() => new GoogleSearch()).toThrow(
'Missing GOOGLE_SEARCH_API_KEY or GOOGLE_CSE_ID environment variable.',
);
process.env.GOOGLE_SEARCH_API_KEY = mockApiKey;
delete process.env.GOOGLE_CSE_ID;
expect(() => new GoogleSearch()).toThrow(
'Missing GOOGLE_SEARCH_API_KEY or GOOGLE_CSE_ID environment variable.',
);
});
});

View File

@@ -1,38 +0,0 @@
const TavilySearchResults = require('../TavilySearchResults');
jest.mock('node-fetch');
jest.mock('@langchain/core/utils/env');
describe('TavilySearchResults', () => {
let originalEnv;
const mockApiKey = 'mock_api_key';
beforeAll(() => {
originalEnv = { ...process.env };
});
beforeEach(() => {
jest.resetModules();
process.env = {
...originalEnv,
TAVILY_API_KEY: mockApiKey,
};
});
afterEach(() => {
jest.clearAllMocks();
process.env = originalEnv;
});
it('should throw an error if TAVILY_API_KEY is missing', () => {
delete process.env.TAVILY_API_KEY;
expect(() => new TavilySearchResults()).toThrow('Missing TAVILY_API_KEY environment variable.');
});
it('should use mockApiKey when TAVILY_API_KEY is not set in the environment', () => {
const instance = new TavilySearchResults({
TAVILY_API_KEY: mockApiKey,
});
expect(instance.apiKey).toBe(mockApiKey);
});
});

View File

@@ -1,11 +1,13 @@
const Keyv = require('keyv');
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { logFile, violationFile } = require('./keyvFiles');
const { math, isEnabled } = require('~/server/utils');
const keyvRedis = require('./keyvRedis');
const keyvMongo = require('./keyvMongo');
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
const THIRTY_MINUTES = 1800000;
const TEN_MINUTES = 600000;
const duration = math(BAN_DURATION, 7200000);
@@ -27,21 +29,17 @@ const roles = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ROLES });
const audioRuns = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes
? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES });
const messages = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis, ttl: Time.FIVE_MINUTES })
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.FIVE_MINUTES });
const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
const tokenConfig = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
const genTitle = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes
? new Keyv({ store: keyvRedis, ttl: 120000 })
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: 120000 });
const modelQueries = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
@@ -49,7 +47,7 @@ const modelQueries = isEnabled(process.env.USE_REDIS)
const abortKeys = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });
const namespaces = {
[CacheKeys.ROLES]: roles,
@@ -69,7 +67,6 @@ const namespaces = {
registrations: createViolationInstance('registrations'),
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
[ViolationTypes.CONVO_ACCESS]: createViolationInstance(ViolationTypes.CONVO_ACCESS),
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
[ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
@@ -84,7 +81,6 @@ const namespaces = {
[CacheKeys.GEN_TITLE]: genTitle,
[CacheKeys.MODEL_QUERIES]: modelQueries,
[CacheKeys.AUDIO_RUNS]: audioRuns,
[CacheKeys.MESSAGES]: messages,
};
/**

View File

@@ -109,14 +109,6 @@ const condenseArray = (item) => {
* @returns {string} - The formatted log message.
*/
const debugTraverse = winston.format.printf(({ level, message, timestamp, ...metadata }) => {
if (!message) {
return `${timestamp} ${level}`;
}
if (!message?.trim || typeof message !== 'string') {
return `${timestamp} ${level}: ${JSON.stringify(message)}`;
}
let msg = `${timestamp} ${level}: ${truncateLongStrings(message?.trim(), 150)}`;
try {
if (level !== 'debug') {

View File

@@ -2,20 +2,6 @@ const Conversation = require('./schema/convoSchema');
const { getMessages, deleteMessages } = require('./Message');
const logger = require('~/config/winston');
/**
* Searches for a conversation by conversationId and returns a lean document with only conversationId and user.
* @param {string} conversationId - The conversation's ID.
* @returns {Promise<{conversationId: string, user: string} | null>} The conversation object with selected fields or null if not found.
*/
const searchConversation = async (conversationId) => {
try {
return await Conversation.findOne({ conversationId }, 'conversationId user').lean();
} catch (error) {
logger.error('[searchConversation] Error searching conversation', error);
throw new Error('Error searching conversation');
}
};
/**
* Retrieves a single conversation for a given user and conversation ID.
* @param {string} user - The user's ID.
@@ -33,40 +19,22 @@ const getConvo = async (user, conversationId) => {
module.exports = {
Conversation,
searchConversation,
/**
* Saves a conversation to the database.
* @param {Object} req - The request object.
* @param {string} conversationId - The conversation's ID.
* @param {Object} metadata - Additional metadata to log for operation.
* @returns {Promise<TConversation>} The conversation object.
*/
saveConvo: async (req, { conversationId, newConversationId, ...convo }, metadata) => {
saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
try {
if (metadata && metadata?.context) {
logger.debug(`[saveConvo] ${metadata.context}`);
}
const messages = await getMessages({ conversationId }, '_id');
const update = { ...convo, messages, user: req.user.id };
const update = { ...convo, messages, user };
if (newConversationId) {
update.conversationId = newConversationId;
}
const conversation = await Conversation.findOneAndUpdate(
{ conversationId, user: req.user.id },
update,
{
new: true,
upsert: true,
},
);
const conversation = await Conversation.findOneAndUpdate({ conversationId, user }, update, {
new: true,
upsert: true,
});
return conversation.toObject();
} catch (error) {
logger.error('[saveConvo] Error saving conversation', error);
if (metadata && metadata?.context) {
logger.info(`[saveConvo] ${metadata.context}`);
}
return { message: 'Error saving conversation' };
}
},
@@ -88,16 +56,13 @@ module.exports = {
throw new Error('Failed to save conversations in bulk.');
}
},
getConvosByPage: async (user, pageNumber = 1, pageSize = 25, isArchived = false, tags) => {
getConvosByPage: async (user, pageNumber = 1, pageSize = 25, isArchived = false) => {
const query = { user };
if (isArchived) {
query.isArchived = true;
} else {
query.$or = [{ isArchived: false }, { isArchived: { $exists: false } }];
}
if (Array.isArray(tags) && tags.length > 0) {
query.tags = { $in: tags };
}
try {
const totalConvos = (await Conversation.countDocuments(query)) || 1;
const totalPages = Math.ceil(totalConvos / pageSize);

View File

@@ -1,242 +0,0 @@
const ConversationTag = require('./schema/conversationTagSchema');
const Conversation = require('./schema/convoSchema');
const logger = require('~/config/winston');
/**
* Retrieves all conversation tags for a user.
* @param {string} user - The user ID.
* @returns {Promise<Array>} An array of conversation tags.
*/
const getConversationTags = async (user) => {
try {
return await ConversationTag.find({ user }).sort({ position: 1 }).lean();
} catch (error) {
logger.error('[getConversationTags] Error getting conversation tags', error);
throw new Error('Error getting conversation tags');
}
};
/**
* Creates a new conversation tag.
* @param {string} user - The user ID.
* @param {Object} data - The tag data.
* @param {string} data.tag - The tag name.
* @param {string} [data.description] - The tag description.
* @param {boolean} [data.addToConversation] - Whether to add the tag to a conversation.
* @param {string} [data.conversationId] - The conversation ID to add the tag to.
* @returns {Promise<Object>} The created tag.
*/
const createConversationTag = async (user, data) => {
try {
const { tag, description, addToConversation, conversationId } = data;
const existingTag = await ConversationTag.findOne({ user, tag }).lean();
if (existingTag) {
return existingTag;
}
const maxPosition = await ConversationTag.findOne({ user }).sort('-position').lean();
const position = (maxPosition?.position || 0) + 1;
const newTag = await ConversationTag.findOneAndUpdate(
{ tag, user },
{
tag,
user,
count: addToConversation ? 1 : 0,
position,
description,
$setOnInsert: { createdAt: new Date() },
},
{
new: true,
upsert: true,
lean: true,
},
);
if (addToConversation && conversationId) {
await Conversation.findOneAndUpdate(
{ user, conversationId },
{ $addToSet: { tags: tag } },
{ new: true },
);
}
return newTag;
} catch (error) {
logger.error('[createConversationTag] Error creating conversation tag', error);
throw new Error('Error creating conversation tag');
}
};
/**
* Updates an existing conversation tag.
* @param {string} user - The user ID.
* @param {string} oldTag - The current tag name.
* @param {Object} data - The updated tag data.
* @param {string} [data.tag] - The new tag name.
* @param {string} [data.description] - The updated description.
* @param {number} [data.position] - The new position.
* @returns {Promise<Object>} The updated tag.
*/
const updateConversationTag = async (user, oldTag, data) => {
try {
const { tag: newTag, description, position } = data;
const existingTag = await ConversationTag.findOne({ user, tag: oldTag }).lean();
if (!existingTag) {
return null;
}
if (newTag && newTag !== oldTag) {
const tagAlreadyExists = await ConversationTag.findOne({ user, tag: newTag }).lean();
if (tagAlreadyExists) {
throw new Error('Tag already exists');
}
await Conversation.updateMany({ user, tags: oldTag }, { $set: { 'tags.$': newTag } });
}
const updateData = {};
if (newTag) {
updateData.tag = newTag;
}
if (description !== undefined) {
updateData.description = description;
}
if (position !== undefined) {
await adjustPositions(user, existingTag.position, position);
updateData.position = position;
}
return await ConversationTag.findOneAndUpdate({ user, tag: oldTag }, updateData, {
new: true,
lean: true,
});
} catch (error) {
logger.error('[updateConversationTag] Error updating conversation tag', error);
throw new Error('Error updating conversation tag');
}
};
/**
* Adjusts positions of tags when a tag's position is changed.
* @param {string} user - The user ID.
* @param {number} oldPosition - The old position of the tag.
* @param {number} newPosition - The new position of the tag.
* @returns {Promise<void>}
*/
const adjustPositions = async (user, oldPosition, newPosition) => {
if (oldPosition === newPosition) {
return;
}
const update = oldPosition < newPosition ? { $inc: { position: -1 } } : { $inc: { position: 1 } };
await ConversationTag.updateMany(
{
user,
position: {
$gt: Math.min(oldPosition, newPosition),
$lte: Math.max(oldPosition, newPosition),
},
},
update,
);
};
/**
* Deletes a conversation tag.
* @param {string} user - The user ID.
* @param {string} tag - The tag to delete.
* @returns {Promise<Object>} The deleted tag.
*/
const deleteConversationTag = async (user, tag) => {
try {
const deletedTag = await ConversationTag.findOneAndDelete({ user, tag }).lean();
if (!deletedTag) {
return null;
}
await Conversation.updateMany({ user, tags: tag }, { $pull: { tags: tag } });
await ConversationTag.updateMany(
{ user, position: { $gt: deletedTag.position } },
{ $inc: { position: -1 } },
);
return deletedTag;
} catch (error) {
logger.error('[deleteConversationTag] Error deleting conversation tag', error);
throw new Error('Error deleting conversation tag');
}
};
/**
* Updates tags for a specific conversation.
* @param {string} user - The user ID.
* @param {string} conversationId - The conversation ID.
* @param {string[]} tags - The new set of tags for the conversation.
* @returns {Promise<string[]>} The updated list of tags for the conversation.
*/
const updateTagsForConversation = async (user, conversationId, tags) => {
try {
const conversation = await Conversation.findOne({ user, conversationId }).lean();
if (!conversation) {
throw new Error('Conversation not found');
}
const oldTags = new Set(conversation.tags);
const newTags = new Set(tags);
const addedTags = [...newTags].filter((tag) => !oldTags.has(tag));
const removedTags = [...oldTags].filter((tag) => !newTags.has(tag));
const bulkOps = [];
for (const tag of addedTags) {
bulkOps.push({
updateOne: {
filter: { user, tag },
update: { $inc: { count: 1 } },
upsert: true,
},
});
}
for (const tag of removedTags) {
bulkOps.push({
updateOne: {
filter: { user, tag },
update: { $inc: { count: -1 } },
},
});
}
if (bulkOps.length > 0) {
await ConversationTag.bulkWrite(bulkOps);
}
const updatedConversation = (
await Conversation.findOneAndUpdate(
{ user, conversationId },
{ $set: { tags: [...newTags] } },
{ new: true },
)
).toObject();
return updatedConversation.tags;
} catch (error) {
logger.error('[updateTagsForConversation] Error updating tags', error);
throw new Error('Error updating tags for conversation');
}
};
module.exports = {
getConversationTags,
createConversationTag,
updateConversationTag,
deleteConversationTag,
updateTagsForConversation,
};

View File

@@ -1,342 +1,204 @@
const { z } = require('zod');
const Message = require('./schema/messageSchema');
const { logger } = require('~/config');
const logger = require('~/config/winston');
const idSchema = z.string().uuid();
/**
* Saves a message in the database.
*
* @async
* @function saveMessage
* @param {Express.Request} req - The request object containing user information.
* @param {Object} params - The message data object.
* @param {string} params.endpoint - The endpoint where the message originated.
* @param {string} params.iconURL - The URL of the sender's icon.
* @param {string} params.messageId - The unique identifier for the message.
* @param {string} params.newMessageId - The new unique identifier for the message (if applicable).
* @param {string} params.conversationId - The identifier of the conversation.
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
* @param {string} params.sender - The identifier of the sender.
* @param {string} params.text - The text content of the message.
* @param {boolean} params.isCreatedByUser - Indicates if the message was created by the user.
* @param {string} [params.error] - Any error associated with the message.
* @param {boolean} [params.unfinished] - Indicates if the message is unfinished.
* @param {Object[]} [params.files] - An array of files associated with the message.
* @param {boolean} [params.isEdited] - Indicates if the message was edited.
* @param {string} [params.finish_reason] - Reason for finishing the message.
* @param {number} [params.tokenCount] - The number of tokens in the message.
* @param {string} [params.plugin] - Plugin associated with the message.
* @param {string[]} [params.plugins] - An array of plugins associated with the message.
* @param {string} [params.model] - The model used to generate the message.
* @param {Object} [metadata] - Additional metadata for this operation
* @param {string} [metadata.context] - The context of the operation
* @returns {Promise<TMessage>} The updated or newly inserted message document.
* @throws {Error} If there is an error in saving the message.
*/
async function saveMessage(req, params, metadata) {
try {
if (!req || !req.user || !req.user.id) {
throw new Error('User not authenticated');
}
const {
text,
error,
model,
files,
plugin,
sender,
plugins,
iconURL,
endpoint,
isEdited,
messageId,
unfinished,
tokenCount,
newMessageId,
finish_reason,
conversationId,
parentMessageId,
isCreatedByUser,
} = params;
const validConvoId = idSchema.safeParse(conversationId);
if (!validConvoId.success) {
logger.warn(`Invalid conversation ID: ${conversationId}`);
if (metadata && metadata?.context) {
logger.info(`---\`saveMessage\` context: ${metadata.context}`);
}
logger.info(`---Invalid conversation ID Params:
${JSON.stringify(params, null, 2)}
`);
return;
}
const update = {
user: req.user.id,
iconURL,
endpoint,
messageId: newMessageId || messageId,
conversationId,
parentMessageId,
sender,
text,
isCreatedByUser,
isEdited,
finish_reason,
error,
unfinished,
tokenCount,
plugin,
plugins,
model,
};
if (files) {
update.files = files;
}
const message = await Message.findOneAndUpdate({ messageId, user: req.user.id }, update, {
upsert: true,
new: true,
});
return message.toObject();
} catch (err) {
logger.error('Error saving message:', err);
if (metadata && metadata?.context) {
logger.info(`---\`saveMessage\` context: ${metadata.context}`);
}
throw err;
}
}
/**
* Saves multiple messages in the database in bulk.
*
* @async
* @function bulkSaveMessages
* @param {Object[]} messages - An array of message objects to save.
* @returns {Promise<Object>} The result of the bulk write operation.
* @throws {Error} If there is an error in saving messages in bulk.
*/
async function bulkSaveMessages(messages) {
try {
const bulkOps = messages.map((message) => ({
updateOne: {
filter: { messageId: message.messageId },
update: message,
upsert: true,
},
}));
const result = await Message.bulkWrite(bulkOps);
return result;
} catch (err) {
logger.error('Error saving messages in bulk:', err);
throw err;
}
}
/**
* Records a message in the database.
*
* @async
* @function recordMessage
* @param {Object} params - The message data object.
* @param {string} params.user - The identifier of the user.
* @param {string} params.endpoint - The endpoint where the message originated.
* @param {string} params.messageId - The unique identifier for the message.
* @param {string} params.conversationId - The identifier of the conversation.
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
* @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed.
* @returns {Promise<Object>} The updated or newly inserted message document.
* @throws {Error} If there is an error in saving the message.
*/
async function recordMessage({
user,
endpoint,
messageId,
conversationId,
parentMessageId,
...rest
}) {
try {
// No parsing of convoId as may use threadId
const message = {
user,
endpoint,
messageId,
conversationId,
parentMessageId,
...rest,
};
return await Message.findOneAndUpdate({ user, messageId }, message, {
upsert: true,
new: true,
});
} catch (err) {
logger.error('Error recording message:', err);
throw err;
}
}
/**
* Updates the text of a message.
*
* @async
* @function updateMessageText
* @param {Object} params - The update data object.
* @param {Object} req - The request object.
* @param {string} params.messageId - The unique identifier for the message.
* @param {string} params.text - The new text content of the message.
* @returns {Promise<void>}
* @throws {Error} If there is an error in updating the message text.
*/
async function updateMessageText(req, { messageId, text }) {
try {
await Message.updateOne({ messageId, user: req.user.id }, { text });
} catch (err) {
logger.error('Error updating message text:', err);
throw err;
}
}
/**
* Updates a message.
*
* @async
* @function updateMessage
* @param {Object} message - The message object containing update data.
* @param {Object} req - The request object.
* @param {string} message.messageId - The unique identifier for the message.
* @param {string} [message.text] - The new text content of the message.
* @param {Object[]} [message.files] - The files associated with the message.
* @param {boolean} [message.isCreatedByUser] - Indicates if the message was created by the user.
* @param {string} [message.sender] - The identifier of the sender.
* @param {number} [message.tokenCount] - The number of tokens in the message.
* @param {Object} [metadata] - The operation metadata
* @param {string} [metadata.context] - The operation metadata
* @returns {Promise<TMessage>} The updated message document.
* @throws {Error} If there is an error in updating the message or if the message is not found.
*/
async function updateMessage(req, message, metadata) {
try {
const { messageId, ...update } = message;
update.isEdited = true;
const updatedMessage = await Message.findOneAndUpdate(
{ messageId, user: req.user.id },
update,
{
new: true,
},
);
if (!updatedMessage) {
throw new Error('Message not found or user not authorized.');
}
return {
messageId: updatedMessage.messageId,
conversationId: updatedMessage.conversationId,
parentMessageId: updatedMessage.parentMessageId,
sender: updatedMessage.sender,
text: updatedMessage.text,
isCreatedByUser: updatedMessage.isCreatedByUser,
tokenCount: updatedMessage.tokenCount,
isEdited: true,
};
} catch (err) {
logger.error('Error updating message:', err);
if (metadata && metadata?.context) {
logger.info(`---\`updateMessage\` context: ${metadata.context}`);
}
throw err;
}
}
/**
* Deletes messages in a conversation since a specific message.
*
* @async
* @function deleteMessagesSince
* @param {Object} params - The parameters object.
* @param {Object} req - The request object.
* @param {string} params.messageId - The unique identifier for the message.
* @param {string} params.conversationId - The identifier of the conversation.
* @returns {Promise<Number>} The number of deleted messages.
* @throws {Error} If there is an error in deleting messages.
*/
async function deleteMessagesSince(req, { messageId, conversationId }) {
try {
const message = await Message.findOne({ messageId, user: req.user.id }).lean();
if (message) {
const query = Message.find({ conversationId, user: req.user.id });
return await query.deleteMany({
createdAt: { $gt: message.createdAt },
});
}
return undefined;
} catch (err) {
logger.error('Error deleting messages:', err);
throw err;
}
}
/**
* Retrieves messages from the database.
* @async
* @function getMessages
* @param {Record<string, unknown>} filter - The filter criteria.
* @param {string | undefined} [select] - The fields to select.
* @returns {Promise<TMessage[]>} The messages that match the filter criteria.
* @throws {Error} If there is an error in retrieving messages.
*/
async function getMessages(filter, select) {
try {
if (select) {
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
}
return await Message.find(filter).sort({ createdAt: 1 }).lean();
} catch (err) {
logger.error('Error getting messages:', err);
throw err;
}
}
/**
* Deletes messages from the database.
*
* @async
* @function deleteMessages
* @param {Object} filter - The filter criteria to find messages to delete.
* @returns {Promise<Object>} The metadata with count of deleted messages.
* @throws {Error} If there is an error in deleting messages.
*/
async function deleteMessages(filter) {
try {
return await Message.deleteMany(filter);
} catch (err) {
logger.error('Error deleting messages:', err);
throw err;
}
}
module.exports = {
Message,
saveMessage,
bulkSaveMessages,
recordMessage,
updateMessageText,
updateMessage,
deleteMessagesSince,
getMessages,
deleteMessages,
async saveMessage({
user,
endpoint,
iconURL,
messageId,
newMessageId,
conversationId,
parentMessageId,
sender,
text,
isCreatedByUser,
error,
unfinished,
files,
isEdited,
finish_reason,
tokenCount,
plugin,
plugins,
model,
}) {
try {
const validConvoId = idSchema.safeParse(conversationId);
if (!validConvoId.success) {
return;
}
const update = {
user,
iconURL,
endpoint,
messageId: newMessageId || messageId,
conversationId,
parentMessageId,
sender,
text,
isCreatedByUser,
isEdited,
finish_reason,
error,
unfinished,
tokenCount,
plugin,
plugins,
model,
};
if (files) {
update.files = files;
}
const message = await Message.findOneAndUpdate({ messageId }, update, {
upsert: true,
new: true,
});
return message.toObject();
} catch (err) {
logger.error('Error saving message:', err);
throw new Error('Failed to save message.');
}
},
async bulkSaveMessages(messages) {
try {
const bulkOps = messages.map((message) => ({
updateOne: {
filter: { messageId: message.messageId },
update: message,
upsert: true,
},
}));
const result = await Message.bulkWrite(bulkOps);
return result;
} catch (err) {
logger.error('Error saving messages in bulk:', err);
throw new Error('Failed to save messages in bulk.');
}
},
/**
* Records a message in the database.
*
* @async
* @function recordMessage
* @param {Object} params - The message data object.
* @param {string} params.user - The identifier of the user.
* @param {string} params.endpoint - The endpoint where the message originated.
* @param {string} params.messageId - The unique identifier for the message.
* @param {string} params.conversationId - The identifier of the conversation.
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
* @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed.
* @returns {Promise<Object>} The updated or newly inserted message document.
* @throws {Error} If there is an error in saving the message.
*/
async recordMessage({ user, endpoint, messageId, conversationId, parentMessageId, ...rest }) {
try {
// No parsing of convoId as may use threadId
const message = {
user,
endpoint,
messageId,
conversationId,
parentMessageId,
...rest,
};
return await Message.findOneAndUpdate({ user, messageId }, message, {
upsert: true,
new: true,
});
} catch (err) {
logger.error('Error saving message:', err);
throw new Error('Failed to save message.');
}
},
async updateMessageText({ messageId, text }) {
try {
await Message.updateOne({ messageId }, { text });
} catch (err) {
logger.error('Error updating message text:', err);
throw new Error('Failed to update message text.');
}
},
async updateMessage(message) {
try {
const { messageId, ...update } = message;
update.isEdited = true;
const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, {
new: true,
});
if (!updatedMessage) {
throw new Error('Message not found.');
}
return {
messageId: updatedMessage.messageId,
conversationId: updatedMessage.conversationId,
parentMessageId: updatedMessage.parentMessageId,
sender: updatedMessage.sender,
text: updatedMessage.text,
isCreatedByUser: updatedMessage.isCreatedByUser,
tokenCount: updatedMessage.tokenCount,
isEdited: true,
};
} catch (err) {
logger.error('Error updating message:', err);
throw new Error('Failed to update message.');
}
},
async deleteMessagesSince({ messageId, conversationId }) {
try {
const message = await Message.findOne({ messageId }).lean();
if (message) {
return await Message.find({ conversationId }).deleteMany({
createdAt: { $gt: message.createdAt },
});
}
} catch (err) {
logger.error('Error deleting messages:', err);
throw new Error('Failed to delete messages.');
}
},
/**
* Retrieves messages from the database.
* @param {Record<string, unknown>} filter
* @param {string | undefined} [select]
* @returns
*/
async getMessages(filter, select) {
try {
if (select) {
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
}
return await Message.find(filter).sort({ createdAt: 1 }).lean();
} catch (err) {
logger.error('Error getting messages:', err);
throw new Error('Failed to get messages.');
}
},
async deleteMessages(filter) {
try {
return await Message.deleteMany(filter);
} catch (err) {
logger.error('Error deleting messages:', err);
throw new Error('Failed to delete messages.');
}
},
};

View File

@@ -1,239 +0,0 @@
const mongoose = require('mongoose');
const { v4: uuidv4 } = require('uuid');
jest.mock('mongoose');
const mockFindQuery = {
select: jest.fn().mockReturnThis(),
sort: jest.fn().mockReturnThis(),
lean: jest.fn().mockReturnThis(),
deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }),
};
const mockSchema = {
findOneAndUpdate: jest.fn(),
updateOne: jest.fn(),
findOne: jest.fn(() => ({
lean: jest.fn(),
})),
find: jest.fn(() => mockFindQuery),
deleteMany: jest.fn(),
};
mongoose.model.mockReturnValue(mockSchema);
jest.mock('~/models/schema/messageSchema', () => mockSchema);
jest.mock('~/config/winston', () => ({
error: jest.fn(),
}));
const {
saveMessage,
getMessages,
updateMessage,
deleteMessages,
updateMessageText,
deleteMessagesSince,
} = require('~/models/Message');
describe('Message Operations', () => {
let mockReq;
let mockMessage;
beforeEach(() => {
jest.clearAllMocks();
mockReq = {
user: { id: 'user123' },
};
mockMessage = {
messageId: 'msg123',
conversationId: uuidv4(),
text: 'Hello, world!',
user: 'user123',
};
mockSchema.findOneAndUpdate.mockResolvedValue({
toObject: () => mockMessage,
});
});
describe('saveMessage', () => {
it('should save a message for an authenticated user', async () => {
const result = await saveMessage(mockReq, mockMessage);
expect(result).toEqual(mockMessage);
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
{ messageId: 'msg123', user: 'user123' },
expect.objectContaining({ user: 'user123' }),
expect.any(Object),
);
});
it('should throw an error for unauthenticated user', async () => {
mockReq.user = null;
await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated');
});
it('should throw an error for invalid conversation ID', async () => {
mockMessage.conversationId = 'invalid-id';
await expect(saveMessage(mockReq, mockMessage)).resolves.toBeUndefined();
});
});
describe('updateMessageText', () => {
it('should update message text for the authenticated user', async () => {
await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' });
expect(mockSchema.updateOne).toHaveBeenCalledWith(
{ messageId: 'msg123', user: 'user123' },
{ text: 'Updated text' },
);
});
});
describe('updateMessage', () => {
it('should update a message for the authenticated user', async () => {
mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage);
const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' });
expect(result).toEqual(
expect.objectContaining({
messageId: 'msg123',
text: 'Hello, world!',
isEdited: true,
}),
);
});
it('should throw an error if message is not found', async () => {
mockSchema.findOneAndUpdate.mockResolvedValue(null);
await expect(
updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }),
).rejects.toThrow('Message not found or user not authorized.');
});
});
describe('deleteMessagesSince', () => {
it('should delete messages only for the authenticated user', async () => {
mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() });
mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 });
const result = await deleteMessagesSince(mockReq, {
messageId: 'msg123',
conversationId: 'convo123',
});
expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' });
expect(mockSchema.find).not.toHaveBeenCalled();
expect(result).toBeUndefined();
});
it('should return undefined if no message is found', async () => {
mockSchema.findOne().lean.mockResolvedValueOnce(null);
const result = await deleteMessagesSince(mockReq, {
messageId: 'nonexistent',
conversationId: 'convo123',
});
expect(result).toBeUndefined();
});
});
describe('getMessages', () => {
it('should retrieve messages with the correct filter', async () => {
const filter = { conversationId: 'convo123' };
await getMessages(filter);
expect(mockSchema.find).toHaveBeenCalledWith(filter);
expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 });
expect(mockFindQuery.lean).toHaveBeenCalled();
});
});
describe('deleteMessages', () => {
it('should delete messages with the correct filter', async () => {
await deleteMessages({ user: 'user123' });
expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' });
});
});
describe('Conversation Hijacking Prevention', () => {
it('should not allow editing a message in another user\'s conversation', async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = 'victim-convo-123';
const victimMessageId = 'victim-msg-123';
mockSchema.findOneAndUpdate.mockResolvedValue(null);
await expect(
updateMessage(attackerReq, {
messageId: victimMessageId,
conversationId: victimConversationId,
text: 'Hacked message',
}),
).rejects.toThrow('Message not found or user not authorized.');
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
{ messageId: victimMessageId, user: 'attacker123' },
expect.anything(),
expect.anything(),
);
});
it('should not allow deleting messages from another user\'s conversation', async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = 'victim-convo-123';
const victimMessageId = 'victim-msg-123';
mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user
const result = await deleteMessagesSince(attackerReq, {
messageId: victimMessageId,
conversationId: victimConversationId,
});
expect(result).toBeUndefined();
expect(mockSchema.findOne).toHaveBeenCalledWith({
messageId: victimMessageId,
user: 'attacker123',
});
});
it('should not allow inserting a new message into another user\'s conversation', async () => {
const attackerReq = { user: { id: 'attacker123' } };
const victimConversationId = uuidv4(); // Use a valid UUID
await expect(
saveMessage(attackerReq, {
conversationId: victimConversationId,
text: 'Inserted malicious message',
messageId: 'new-msg-123',
}),
).resolves.not.toThrow(); // It should not throw an error
// Check that the message was saved with the attacker's user ID
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
{ messageId: 'new-msg-123', user: 'attacker123' },
expect.objectContaining({
user: 'attacker123',
conversationId: victimConversationId,
}),
expect.anything(),
);
});
it('should allow retrieving messages from any conversation', async () => {
const victimConversationId = 'victim-convo-123';
await getMessages({ conversationId: victimConversationId });
expect(mockSchema.find).toHaveBeenCalledWith({
conversationId: victimConversationId,
});
mockSchema.find.mockReturnValueOnce({
select: jest.fn().mockReturnThis(),
sort: jest.fn().mockReturnThis(),
lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]),
});
const result = await getMessages({ conversationId: victimConversationId });
expect(result).toEqual([{ text: 'Test message' }]);
});
});
});

View File

@@ -1,6 +1,6 @@
const crypto = require('crypto');
const mongoose = require('mongoose');
const signPayload = require('~/server/services/signPayload');
const { hashToken } = require('~/server/utils/crypto');
const { logger } = require('~/config');
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
@@ -39,7 +39,8 @@ sessionSchema.methods.generateRefreshToken = async function () {
expirationTime: Math.floor((expiresIn - Date.now()) / 1000),
});
this.refreshTokenHash = await hashToken(refreshToken);
const hash = crypto.createHash('sha256');
this.refreshTokenHash = hash.update(refreshToken).digest('hex');
await this.save();

View File

@@ -1,252 +1,117 @@
const { nanoid } = require('nanoid');
const { Constants } = require('librechat-data-provider');
const SharedLink = require('./schema/shareSchema');
const crypto = require('crypto');
const { getMessages } = require('./Message');
const SharedLink = require('./schema/shareSchema');
const logger = require('~/config/winston');
/**
* Anonymizes a conversation ID
* @returns {string} The anonymized conversation ID
*/
function anonymizeConvoId() {
return `convo_${nanoid()}`;
}
module.exports = {
SharedLink,
getSharedMessages: async (shareId) => {
try {
const share = await SharedLink.findOne({ shareId })
.populate({
path: 'messages',
select: '-_id -__v -user',
})
.select('-_id -__v -user')
.lean();
/**
* Anonymizes an assistant ID
* @returns {string} The anonymized assistant ID
*/
function anonymizeAssistantId() {
return `a_${nanoid()}`;
}
if (!share || !share.conversationId || !share.isPublic) {
return null;
}
/**
* Anonymizes a message ID
* @param {string} id - The original message ID
* @returns {string} The anonymized message ID
*/
function anonymizeMessageId(id) {
return id === Constants.NO_PARENT ? id : `msg_${nanoid()}`;
}
/**
* Anonymizes a conversation object
* @param {object} conversation - The conversation object
* @returns {object} The anonymized conversation object
*/
function anonymizeConvo(conversation) {
const newConvo = { ...conversation };
if (newConvo.assistant_id) {
newConvo.assistant_id = anonymizeAssistantId();
}
return newConvo;
}
/**
* Anonymizes messages in a conversation
* @param {TMessage[]} messages - The original messages
* @param {string} newConvoId - The new conversation ID
* @returns {TMessage[]} The anonymized messages
*/
function anonymizeMessages(messages, newConvoId) {
const idMap = new Map();
return messages.map((message) => {
const newMessageId = anonymizeMessageId(message.messageId);
idMap.set(message.messageId, newMessageId);
const anonymizedMessage = Object.assign(message, {
messageId: newMessageId,
parentMessageId:
idMap.get(message.parentMessageId) || anonymizeMessageId(message.parentMessageId),
conversationId: newConvoId,
});
if (anonymizedMessage.model && anonymizedMessage.model.startsWith('asst_')) {
anonymizedMessage.model = anonymizeAssistantId();
return share;
} catch (error) {
logger.error('[getShare] Error getting share link', error);
throw new Error('Error getting share link');
}
},
return anonymizedMessage;
});
}
/**
* Retrieves shared messages for a given share ID
* @param {string} shareId - The share ID
* @returns {Promise<object|null>} The shared conversation data or null if not found
*/
async function getSharedMessages(shareId) {
try {
const share = await SharedLink.findOne({ shareId })
.populate({
path: 'messages',
select: '-_id -__v -user',
})
.select('-_id -__v -user')
.lean();
if (!share || !share.conversationId || !share.isPublic) {
return null;
}
const newConvoId = anonymizeConvoId();
return Object.assign(share, {
conversationId: newConvoId,
messages: anonymizeMessages(share.messages, newConvoId),
});
} catch (error) {
logger.error('[getShare] Error getting share link', error);
throw new Error('Error getting share link');
}
}
/**
* Retrieves shared links for a user
* @param {string} user - The user ID
* @param {number} [pageNumber=1] - The page number
* @param {number} [pageSize=25] - The page size
* @param {boolean} [isPublic=true] - Whether to retrieve public links only
* @returns {Promise<object>} The shared links and pagination data
*/
async function getSharedLinks(user, pageNumber = 1, pageSize = 25, isPublic = true) {
const query = { user, isPublic };
try {
const [totalConvos, sharedLinks] = await Promise.all([
SharedLink.countDocuments(query),
SharedLink.find(query)
getSharedLinks: async (user, pageNumber = 1, pageSize = 25, isPublic = true) => {
const query = { user, isPublic };
try {
const totalConvos = (await SharedLink.countDocuments(query)) || 1;
const totalPages = Math.ceil(totalConvos / pageSize);
const shares = await SharedLink.find(query)
.sort({ updatedAt: -1 })
.skip((pageNumber - 1) * pageSize)
.limit(pageSize)
.select('-_id -__v -user')
.lean(),
]);
.lean();
const totalPages = Math.ceil((totalConvos || 1) / pageSize);
return { sharedLinks: shares, pages: totalPages, pageNumber, pageSize };
} catch (error) {
logger.error('[getShareByPage] Error getting shares', error);
throw new Error('Error getting shares');
}
},
return {
sharedLinks,
pages: totalPages,
pageNumber,
pageSize,
};
} catch (error) {
logger.error('[getShareByPage] Error getting shares', error);
throw new Error('Error getting shares');
}
}
createSharedLink: async (user, { conversationId, ...shareData }) => {
try {
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
if (share) {
return share;
}
/**
* Creates a new shared link
* @param {string} user - The user ID
* @param {object} shareData - The share data
* @param {string} shareData.conversationId - The conversation ID
* @returns {Promise<object>} The created shared link
*/
async function createSharedLink(user, { conversationId, ...shareData }) {
try {
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
if (share) {
const newConvoId = anonymizeConvoId();
const sharedConvo = anonymizeConvo(share);
return Object.assign(sharedConvo, {
conversationId: newConvoId,
messages: anonymizeMessages(share.messages, newConvoId),
const shareId = crypto.randomUUID();
const messages = await getMessages({ conversationId });
const update = { ...shareData, shareId, messages, user };
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
new: true,
upsert: true,
});
} catch (error) {
logger.error('[createSharedLink] Error creating shared link', error);
throw new Error('Error creating shared link');
}
},
const shareId = nanoid();
const messages = await getMessages({ conversationId });
const update = { ...shareData, shareId, messages, user };
const newShare = await SharedLink.findOneAndUpdate({ conversationId, user }, update, {
new: true,
upsert: true,
}).lean();
updateSharedLink: async (user, { conversationId, ...shareData }) => {
try {
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
if (!share) {
return { message: 'Share not found' };
}
const newConvoId = anonymizeConvoId();
const sharedConvo = anonymizeConvo(newShare);
return Object.assign(sharedConvo, {
conversationId: newConvoId,
messages: anonymizeMessages(newShare.messages, newConvoId),
});
} catch (error) {
logger.error('[createSharedLink] Error creating shared link', error);
throw new Error('Error creating shared link');
}
}
/**
* Updates an existing shared link
* @param {string} user - The user ID
* @param {object} shareData - The share data to update
* @param {string} shareData.conversationId - The conversation ID
* @returns {Promise<object>} The updated shared link
*/
async function updateSharedLink(user, { conversationId, ...shareData }) {
try {
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
if (!share) {
return { message: 'Share not found' };
// update messages to the latest
const messages = await getMessages({ conversationId });
const update = { ...shareData, messages, user };
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
new: true,
upsert: false,
});
} catch (error) {
logger.error('[updateSharedLink] Error updating shared link', error);
throw new Error('Error updating shared link');
}
},
const messages = await getMessages({ conversationId });
const update = { ...shareData, messages, user };
const updatedShare = await SharedLink.findOneAndUpdate({ conversationId, user }, update, {
new: true,
upsert: false,
}).lean();
const newConvoId = anonymizeConvoId();
const sharedConvo = anonymizeConvo(updatedShare);
return Object.assign(sharedConvo, {
conversationId: newConvoId,
messages: anonymizeMessages(updatedShare.messages, newConvoId),
});
} catch (error) {
logger.error('[updateSharedLink] Error updating shared link', error);
throw new Error('Error updating shared link');
}
}
/**
* Deletes a shared link
* @param {string} user - The user ID
* @param {object} params - The deletion parameters
* @param {string} params.shareId - The share ID to delete
* @returns {Promise<object>} The result of the deletion
*/
async function deleteSharedLink(user, { shareId }) {
try {
const result = await SharedLink.findOneAndDelete({ shareId, user });
return result ? { message: 'Share deleted successfully' } : { message: 'Share not found' };
} catch (error) {
logger.error('[deleteSharedLink] Error deleting shared link', error);
throw new Error('Error deleting shared link');
}
}
/**
* Deletes all shared links for a specific user
* @param {string} user - The user ID
* @returns {Promise<object>} The result of the deletion
*/
async function deleteAllSharedLinks(user) {
try {
const result = await SharedLink.deleteMany({ user });
return {
message: 'All shared links have been deleted successfully',
deletedCount: result.deletedCount,
};
} catch (error) {
logger.error('[deleteAllSharedLinks] Error deleting shared links', error);
throw new Error('Error deleting shared links');
}
}
module.exports = {
SharedLink,
getSharedLinks,
createSharedLink,
updateSharedLink,
deleteSharedLink,
getSharedMessages,
deleteAllSharedLinks,
deleteSharedLink: async (user, { shareId }) => {
try {
const share = await SharedLink.findOne({ shareId, user });
if (!share) {
return { message: 'Share not found' };
}
return await SharedLink.findOneAndDelete({ shareId, user });
} catch (error) {
logger.error('[deleteSharedLink] Error deleting shared link', error);
throw new Error('Error deleting shared link');
}
},
/**
* Deletes all shared links for a specific user.
* @param {string} user - The user ID.
* @returns {Promise<{ message: string, deletedCount?: number }>} A result object indicating success or error message.
*/
deleteAllSharedLinks: async (user) => {
try {
const result = await SharedLink.deleteMany({ user });
return {
message: 'All shared links have been deleted successfully',
deletedCount: result.deletedCount,
};
} catch (error) {
logger.error('[deleteAllSharedLinks] Error deleting shared links', error);
throw new Error('Error deleting shared links');
}
},
};

View File

@@ -1,31 +0,0 @@
const mongoose = require('mongoose');
const conversationTagSchema = mongoose.Schema(
{
tag: {
type: String,
index: true,
},
user: {
type: String,
index: true,
},
description: {
type: String,
index: true,
},
count: {
type: Number,
default: 0,
},
position: {
type: Number,
default: 0,
},
},
{ timestamps: true },
);
conversationTagSchema.index({ tag: 1, user: 1 }, { unique: true });
module.exports = mongoose.model('ConversationTag', conversationTagSchema);

View File

@@ -42,11 +42,6 @@ const convoSchema = mongoose.Schema(
invocationId: {
type: Number,
},
tags: {
type: [String],
default: [],
meiliIndex: true,
},
},
{ timestamps: true },
);
@@ -61,7 +56,6 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
}
convoSchema.index({ createdAt: 1, updatedAt: 1 });
convoSchema.index({ conversationId: 1, user: 1 }, { unique: true });
const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema);

View File

@@ -103,10 +103,6 @@ const conversationPreset = {
spec: {
type: String,
},
tags: {
type: [String],
default: [],
},
tools: { type: [{ type: String }], default: undefined },
maxContextTokens: {
type: Number,

View File

@@ -129,9 +129,7 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
}
messageSchema.index({ createdAt: 1 });
messageSchema.index({ messageId: 1, user: 1 }, { unique: true });
/** @type {mongoose.Model<TMessage>} */
const Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
module.exports = Message;

View File

@@ -1,74 +1,37 @@
const { matchModelName } = require('../utils');
const defaultRate = 6;
/** AWS Bedrock pricing */
const bedrockValues = {
'anthropic.claude-3-haiku-20240307-v1:0': { prompt: 0.25, completion: 1.25 },
'anthropic.claude-3-sonnet-20240229-v1:0': { prompt: 3.0, completion: 15.0 },
'anthropic.claude-3-opus-20240229-v1:0': { prompt: 15.0, completion: 75.0 },
'anthropic.claude-3-5-sonnet-20240620-v1:0': { prompt: 3.0, completion: 15.0 },
'anthropic.claude-v2:1': { prompt: 8.0, completion: 24.0 },
'anthropic.claude-instant-v1': { prompt: 0.8, completion: 2.4 },
'meta.llama2-13b-chat-v1': { prompt: 0.75, completion: 1.0 },
'meta.llama2-70b-chat-v1': { prompt: 1.95, completion: 2.56 },
'meta.llama3-8b-instruct-v1:0': { prompt: 0.3, completion: 0.6 },
'meta.llama3-70b-instruct-v1:0': { prompt: 2.65, completion: 3.5 },
'meta.llama3-1-8b-instruct-v1:0': { prompt: 0.3, completion: 0.6 },
'meta.llama3-1-70b-instruct-v1:0': { prompt: 2.65, completion: 3.5 },
'meta.llama3-1-405b-instruct-v1:0': { prompt: 5.32, completion: 16.0 },
'mistral.mistral-7b-instruct-v0:2': { prompt: 0.15, completion: 0.2 },
'mistral.mistral-small-2402-v1:0': { prompt: 0.15, completion: 0.2 },
'mistral.mixtral-8x7b-instruct-v0:1': { prompt: 0.45, completion: 0.7 },
'mistral.mistral-large-2402-v1:0': { prompt: 4.0, completion: 12.0 },
'mistral.mistral-large-2407-v1:0': { prompt: 3.0, completion: 9.0 },
'cohere.command-text-v14': { prompt: 1.5, completion: 2.0 },
'cohere.command-light-text-v14': { prompt: 0.3, completion: 0.6 },
'cohere.command-r-v1:0': { prompt: 0.5, completion: 1.5 },
'cohere.command-r-plus-v1:0': { prompt: 3.0, completion: 15.0 },
'ai21.j2-mid-v1': { prompt: 12.5, completion: 12.5 },
'ai21.j2-ultra-v1': { prompt: 18.8, completion: 18.8 },
'amazon.titan-text-lite-v1': { prompt: 0.15, completion: 0.2 },
'amazon.titan-text-express-v1': { prompt: 0.2, completion: 0.6 },
};
for (const [key, value] of Object.entries(bedrockValues)) {
bedrockValues[`bedrock/${key}`] = value;
}
/**
* Mapping of model token sizes to their respective multipliers for prompt and completion.
* The rates are 1 USD per 1M tokens.
* @type {Object.<string, {prompt: number, completion: number}>}
*/
const tokenValues = Object.assign(
{
'8k': { prompt: 30, completion: 60 },
'32k': { prompt: 60, completion: 120 },
'4k': { prompt: 1.5, completion: 2 },
'16k': { prompt: 3, completion: 4 },
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
'gpt-4o-2024-08-06': { prompt: 2.5, completion: 10 },
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
'gpt-4o': { prompt: 5, completion: 15 },
'gpt-4-1106': { prompt: 10, completion: 30 },
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
'claude-3-opus': { prompt: 15, completion: 75 },
'claude-3-sonnet': { prompt: 3, completion: 15 },
'claude-3-5-sonnet': { prompt: 3, completion: 15 },
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
'claude-2.1': { prompt: 8, completion: 24 },
'claude-2': { prompt: 8, completion: 24 },
'claude-': { prompt: 0.8, completion: 2.4 },
'command-r-plus': { prompt: 3, completion: 15 },
'command-r': { prompt: 0.5, completion: 1.5 },
/* cohere doesn't have rates for the older command models,
const tokenValues = {
'8k': { prompt: 30, completion: 60 },
'32k': { prompt: 60, completion: 120 },
'4k': { prompt: 1.5, completion: 2 },
'16k': { prompt: 3, completion: 4 },
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
'gpt-4o': { prompt: 5, completion: 15 },
'gpt-4-1106': { prompt: 10, completion: 30 },
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
'claude-3-opus': { prompt: 15, completion: 75 },
'claude-3-sonnet': { prompt: 3, completion: 15 },
'claude-3-5-sonnet': { prompt: 3, completion: 15 },
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
'claude-2.1': { prompt: 8, completion: 24 },
'claude-2': { prompt: 8, completion: 24 },
'claude-': { prompt: 0.8, completion: 2.4 },
'command-r-plus': { prompt: 3, completion: 15 },
'command-r': { prompt: 0.5, completion: 1.5 },
/* cohere doesn't have rates for the older command models,
so this was from https://artificialanalysis.ai/models/command-light/providers */
command: { prompt: 0.38, completion: 0.38 },
'gemini-1.5': { prompt: 7, completion: 21 }, // May 2nd, 2024 pricing
gemini: { prompt: 0.5, completion: 1.5 }, // May 2nd, 2024 pricing
},
bedrockValues,
);
command: { prompt: 0.38, completion: 0.38 },
// 'gemini-1.5': { prompt: 7, completion: 21 }, // May 2nd, 2024 pricing
// 'gemini': { prompt: 0.5, completion: 1.5 }, // May 2nd, 2024 pricing
'gemini-1.5': { prompt: 0, completion: 0 }, // currently free
gemini: { prompt: 0, completion: 0 }, // currently free
};
/**
* Retrieves the key associated with a given model name.
@@ -91,10 +54,6 @@ const getValueKey = (model, endpoint) => {
return 'gpt-3.5-turbo-1106';
} else if (modelName.includes('gpt-3.5')) {
return '4k';
} else if (modelName.includes('gpt-4o-2024-08-06')) {
return 'gpt-4o-2024-08-06';
} else if (modelName.includes('gpt-4o-mini')) {
return 'gpt-4o-mini';
} else if (modelName.includes('gpt-4o')) {
return 'gpt-4o';
} else if (modelName.includes('gpt-4-vision')) {

View File

@@ -49,20 +49,6 @@ describe('getValueKey', () => {
expect(getValueKey('gpt-4o-0125')).toBe('gpt-4o');
});
it('should return "gpt-4o-mini" for model type of "gpt-4o-mini"', () => {
expect(getValueKey('gpt-4o-mini-2024-07-18')).toBe('gpt-4o-mini');
expect(getValueKey('openai/gpt-4o-mini')).toBe('gpt-4o-mini');
expect(getValueKey('gpt-4o-mini-0718')).toBe('gpt-4o-mini');
expect(getValueKey('gpt-4o-2024-08-06-0718')).not.toBe('gpt-4o');
});
it('should return "gpt-4o-2024-08-06" for model type of "gpt-4o-2024-08-06"', () => {
expect(getValueKey('gpt-4o-2024-08-06-2024-07-18')).toBe('gpt-4o-2024-08-06');
expect(getValueKey('openai/gpt-4o-2024-08-06')).toBe('gpt-4o-2024-08-06');
expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o-2024-08-06');
expect(getValueKey('gpt-4o-2024-08-06-0718')).not.toBe('gpt-4o');
});
it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => {
expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet');
expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet');
@@ -123,19 +109,6 @@ describe('getMultiplier', () => {
);
});
it('should return the correct multiplier for gpt-4o-mini', () => {
const valueKey = getValueKey('gpt-4o-mini-2024-07-18');
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(
tokenValues['gpt-4o-mini'].prompt,
);
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
tokenValues['gpt-4o-mini'].completion,
);
expect(getMultiplier({ valueKey, tokenType: 'completion' })).not.toBe(
tokenValues['gpt-4-1106'].completion,
);
});
it('should derive the valueKey from the model if not provided for new models', () => {
expect(
getMultiplier({ tokenType: 'prompt', model: 'gpt-3.5-turbo-1106-some-other-info' }),
@@ -160,68 +133,3 @@ describe('getMultiplier', () => {
);
});
});
describe('AWS Bedrock Model Tests', () => {
const awsModels = [
'anthropic.claude-3-haiku-20240307-v1:0',
'anthropic.claude-3-sonnet-20240229-v1:0',
'anthropic.claude-3-opus-20240229-v1:0',
'anthropic.claude-3-5-sonnet-20240620-v1:0',
'anthropic.claude-v2:1',
'anthropic.claude-instant-v1',
'meta.llama2-13b-chat-v1',
'meta.llama2-70b-chat-v1',
'meta.llama3-8b-instruct-v1:0',
'meta.llama3-70b-instruct-v1:0',
'meta.llama3-1-8b-instruct-v1:0',
'meta.llama3-1-70b-instruct-v1:0',
'meta.llama3-1-405b-instruct-v1:0',
'mistral.mistral-7b-instruct-v0:2',
'mistral.mistral-small-2402-v1:0',
'mistral.mixtral-8x7b-instruct-v0:1',
'mistral.mistral-large-2402-v1:0',
'mistral.mistral-large-2407-v1:0',
'cohere.command-text-v14',
'cohere.command-light-text-v14',
'cohere.command-r-v1:0',
'cohere.command-r-plus-v1:0',
'ai21.j2-mid-v1',
'ai21.j2-ultra-v1',
'amazon.titan-text-lite-v1',
'amazon.titan-text-express-v1',
];
it('should return the correct prompt multipliers for all models', () => {
const results = awsModels.map((model) => {
const multiplier = getMultiplier({ valueKey: model, tokenType: 'prompt' });
return multiplier === tokenValues[model].prompt;
});
expect(results.every(Boolean)).toBe(true);
});
it('should return the correct completion multipliers for all models', () => {
const results = awsModels.map((model) => {
const multiplier = getMultiplier({ valueKey: model, tokenType: 'completion' });
return multiplier === tokenValues[model].completion;
});
expect(results.every(Boolean)).toBe(true);
});
it('should return the correct prompt multipliers for all models with Bedrock prefix', () => {
const results = awsModels.map((model) => {
const modelName = `bedrock/${model}`;
const multiplier = getMultiplier({ valueKey: modelName, tokenType: 'prompt' });
return multiplier === tokenValues[model].prompt;
});
expect(results.every(Boolean)).toBe(true);
});
it('should return the correct completion multipliers for all models with Bedrock prefix', () => {
const results = awsModels.map((model) => {
const modelName = `bedrock/${model}`;
const multiplier = getMultiplier({ valueKey: modelName, tokenType: 'completion' });
return multiplier === tokenValues[model].completion;
});
expect(results.every(Boolean)).toBe(true);
});
});

View File

@@ -1,6 +1,6 @@
{
"name": "@librechat/backend",
"version": "0.7.4",
"version": "0.7.4-rc1",
"description": "",
"scripts": {
"start": "echo 'please run this from the root directory'",
@@ -45,7 +45,6 @@
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",
"cohere-ai": "^7.9.1",
"compression": "^1.7.4",
"connect-redis": "^7.1.0",
"cookie": "^0.5.0",
"cors": "^2.8.5",
@@ -73,7 +72,6 @@
"module-alias": "^2.2.3",
"mongoose": "^7.1.1",
"multer": "^1.4.5-lts.1",
"nanoid": "^3.3.7",
"nodejs-gpt": "^1.37.4",
"nodemailer": "^6.9.4",
"ollama": "^0.5.0",

View File

@@ -1,8 +1,7 @@
const throttle = require('lodash/throttle');
const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
@@ -52,13 +51,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
try {
const { client } = await initializeClient({ req, res, endpointOption });
const messageCache = getLogStores(CacheKeys.MESSAGES);
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: throttle(
({ text: partialText }) => {
/*
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
messageCache.set(responseMessageId, {
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
@@ -68,10 +65,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
unfinished,
error: false,
user,
}, Time.FIVE_MINUTES);
*/
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
});
},
3000,
{ trailing: false },
@@ -150,17 +144,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
});
res.end();
await saveMessage(
req,
{ ...response, user },
{ context: 'api/server/controllers/AskController.js - response end' },
);
await saveMessage({ ...response, user });
}
if (!client.skipSaveUserMessage) {
await saveMessage(req, userMessage, {
context: 'api/server/controllers/AskController.js - don\'t skip saving user message',
});
await saveMessage(userMessage);
}
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {

View File

@@ -1,3 +1,4 @@
const crypto = require('crypto');
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const {
@@ -6,7 +7,6 @@ const {
setAuthTokens,
requestPasswordReset,
} = require('~/server/services/AuthService');
const { hashToken } = require('~/server/utils/crypto');
const { Session, getUserById } = require('~/models');
const { logger } = require('~/config');
@@ -74,7 +74,8 @@ const refreshController = async (req, res) => {
}
// Hash the refresh token
const hashedToken = await hashToken(refreshToken);
const hash = crypto.createHash('sha256');
const hashedToken = hash.update(refreshToken).digest('hex');
// Find the session with the hashed refresh token
const session = await Session.findOne({ user: userId, refreshTokenHash: hashedToken });

View File

@@ -1,8 +1,7 @@
const throttle = require('lodash/throttle');
const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
@@ -52,14 +51,12 @@ const EditController = async (req, res, next, initializeClient) => {
}
};
const messageCache = getLogStores(CacheKeys.MESSAGES);
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: throttle(
({ text: partialText }) => {
/*
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
{
saveMessage({
messageId: responseMessageId,
sender,
conversationId,
@@ -70,8 +67,7 @@ const EditController = async (req, res, next, initializeClient) => {
isEdited: true,
error: false,
user,
} */
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
});
},
3000,
{ trailing: false },
@@ -145,11 +141,7 @@ const EditController = async (req, res, next, initializeClient) => {
});
res.end();
await saveMessage(
req,
{ ...response, user },
{ context: 'api/server/controllers/EditController.js - response end' },
);
await saveMessage({ ...response, user });
}
} catch (error) {
const partialText = getPartialText();

View File

@@ -120,22 +120,21 @@ const chatV1 = async (req, res) => {
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);
return sendResponse(res, messageData, errorMessage);
} else if (error?.message?.includes('string too long')) {
return sendResponse(
req,
res,
messageData,
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
);
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
return sendResponse(req, res, messageData, error.message);
return sendResponse(res, messageData, error.message);
} else {
logger.error('[/assistants/chat/]', error);
}
if (!openai || !thread_id || !run_id) {
return sendResponse(req, res, messageData, defaultErrorMessage);
return sendResponse(res, messageData, defaultErrorMessage);
}
await sleep(2000);
@@ -222,10 +221,10 @@ const chatV1 = async (req, res) => {
};
} catch (error) {
logger.error('[/assistants/chat/] Error finalizing error process', error);
return sendResponse(req, res, messageData, 'The Assistant run failed');
return sendResponse(res, messageData, 'The Assistant run failed');
}
return sendResponse(req, res, finalEvent);
return sendResponse(res, finalEvent);
};
try {
@@ -383,9 +382,6 @@ const chatV1 = async (req, res) => {
return files;
};
/** @type {Promise<Run>|undefined} */
let userMessagePromise;
const initializeThread = async () => {
/** @type {[ undefined | MongoFile[]]}*/
const [processedFiles] = await Promise.all([addVisionPrompt(), getRequestFileIds()]);
@@ -442,7 +438,7 @@ const chatV1 = async (req, res) => {
previousMessages.push(requestMessage);
/* asynchronous */
userMessagePromise = saveUserMessage(req, { ...requestMessage, model });
saveUserMessage({ ...requestMessage, model });
conversation = {
conversationId,
@@ -586,10 +582,7 @@ const chatV1 = async (req, res) => {
});
res.end();
if (userMessagePromise) {
await userMessagePromise;
}
await saveAssistantMessage(req, { ...responseMessage, model });
await saveAssistantMessage({ ...responseMessage, model });
if (parentMessageId === Constants.NO_PARENT && !_thread_id) {
addTitle(req, {

View File

@@ -1,12 +1,12 @@
const { v4 } = require('uuid');
const {
Time,
Constants,
RunStatus,
CacheKeys,
ContentTypes,
ToolCallTypes,
EModelEndpoint,
ViolationTypes,
retrievalMimeTypes,
AssistantStreamEvents,
} = require('librechat-data-provider');
@@ -14,12 +14,12 @@ const {
initThread,
recordUsage,
saveUserMessage,
checkMessageGaps,
addThreadMetadata,
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { createErrorHandler } = require('~/server/controllers/assistants/errors');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
@@ -44,7 +44,7 @@ const ten_minutes = 1000 * 60 * 10;
const chatV2 = async (req, res) => {
logger.debug('[/assistants/chat/] req.body', req.body);
/** @type {{files: MongoFile[]}} */
/** @type {{ files: MongoFile[]}} */
const {
text,
model,
@@ -90,20 +90,139 @@ const chatV2 = async (req, res) => {
/** @type {Run | undefined} - The completed run, undefined if incomplete */
let completedRun;
const getContext = () => ({
openai,
run_id,
endpoint,
cacheKey,
thread_id,
completedRun,
assistant_id,
conversationId,
parentMessageId,
responseMessageId,
});
const handleError = async (error) => {
const defaultErrorMessage =
'The Assistant run failed to initialize. Try sending a message in a new conversation.';
const messageData = {
thread_id,
assistant_id,
conversationId,
parentMessageId,
sender: 'System',
user: req.user.id,
shouldSaveMessage: false,
messageId: responseMessageId,
endpoint,
};
const handleError = createErrorHandler({ req, res, getContext });
if (error.message === 'Run cancelled') {
return res.end();
} else if (error.message === 'Request closed' && completedRun) {
return;
} else if (error.message === 'Request closed') {
logger.debug('[/assistants/chat/] Request aborted on close');
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === EModelEndpoint.azureAssistants
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(res, messageData, errorMessage);
} else if (error?.message?.includes('string too long')) {
return sendResponse(
res,
messageData,
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
);
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
return sendResponse(res, messageData, error.message);
} else {
logger.error('[/assistants/chat/]', error);
}
if (!openai || !thread_id || !run_id) {
return sendResponse(res, messageData, defaultErrorMessage);
}
await sleep(2000);
try {
const status = await cache.get(cacheKey);
if (status === 'cancelled') {
logger.debug('[/assistants/chat/] Run already cancelled');
return res.end();
}
await cache.delete(cacheKey);
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun);
} catch (error) {
logger.error('[/assistants/chat/] Error cancelling run', error);
}
await sleep(2000);
let run;
try {
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
await recordUsage({
...run.usage,
model: run.model,
user: req.user.id,
conversationId,
});
} catch (error) {
logger.error('[/assistants/chat/] Error fetching or processing run', error);
}
let finalEvent;
try {
const runMessages = await checkMessageGaps({
openai,
run_id,
endpoint,
thread_id,
conversationId,
latestMessageId: responseMessageId,
});
const errorContentPart = {
text: {
value:
error?.message ?? 'There was an error processing your request. Please try again later.',
},
type: ContentTypes.ERROR,
};
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
runMessages[runMessages.length - 1].content = [errorContentPart];
} else {
const contentParts = runMessages[runMessages.length - 1].content;
for (let i = 0; i < contentParts.length; i++) {
const currentPart = contentParts[i];
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
if (
toolCall &&
toolCall?.function &&
!(toolCall?.function?.output || toolCall?.function?.output?.length)
) {
contentParts[i] = {
...currentPart,
[ContentTypes.TOOL_CALL]: {
...toolCall,
function: {
...toolCall.function,
output: 'error processing tool',
},
},
};
}
}
runMessages[runMessages.length - 1].content.push(errorContentPart);
}
finalEvent = {
final: true,
conversation: await getConvo(req.user.id, conversationId),
runMessages,
};
} catch (error) {
logger.error('[/assistants/chat/] Error finalizing error process', error);
return sendResponse(res, messageData, 'The Assistant run failed');
}
return sendResponse(res, finalEvent);
};
try {
res.on('close', async () => {
@@ -246,9 +365,6 @@ const chatV2 = async (req, res) => {
}
};
/** @type {Promise<Run>|undefined} */
let userMessagePromise;
const initializeThread = async () => {
await getRequestFileIds();
@@ -291,7 +407,7 @@ const chatV2 = async (req, res) => {
previousMessages.push(requestMessage);
/* asynchronous */
userMessagePromise = saveUserMessage(req, { ...requestMessage, model });
saveUserMessage({ ...requestMessage, model });
conversation = {
conversationId,
@@ -373,11 +489,6 @@ const chatV2 = async (req, res) => {
},
};
/** @type {undefined | TAssistantEndpoint} */
const config = req.app.locals[endpoint] ?? {};
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
const streamRunManager = new StreamRunManager({
req,
res,
@@ -387,7 +498,6 @@ const chatV2 = async (req, res) => {
attachedFileIds,
parentMessageId: userMessageId,
responseMessage: openai.responseMessage,
streamRate: allConfig?.streamRate ?? config.streamRate,
// streamOptions: {
// },
@@ -400,16 +510,6 @@ const chatV2 = async (req, res) => {
response = streamRunManager;
response.text = streamRunManager.intermediateText;
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessageId,
{
complete: true,
text: response.text,
},
Time.FIVE_MINUTES,
);
};
await processRun();
@@ -452,10 +552,7 @@ const chatV2 = async (req, res) => {
});
res.end();
if (userMessagePromise) {
await userMessagePromise;
}
await saveAssistantMessage(req, { ...responseMessage, model });
await saveAssistantMessage({ ...responseMessage, model });
if (parentMessageId === Constants.NO_PARENT && !_thread_id) {
addTitle(req, {

View File

@@ -1,193 +0,0 @@
// errorHandler.js
const { sendResponse } = require('~/server/utils');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { getConvo } = require('~/models/Conversation');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
/**
* @typedef {Object} ErrorHandlerContext
* @property {OpenAIClient} openai - The OpenAI client
* @property {string} thread_id - The thread ID
* @property {string} run_id - The run ID
* @property {boolean} completedRun - Whether the run has completed
* @property {string} assistant_id - The assistant ID
* @property {string} conversationId - The conversation ID
* @property {string} parentMessageId - The parent message ID
* @property {string} responseMessageId - The response message ID
* @property {string} endpoint - The endpoint being used
* @property {string} cacheKey - The cache key for the current request
*/
/**
* @typedef {Object} ErrorHandlerDependencies
* @property {Express.Request} req - The Express request object
* @property {Express.Response} res - The Express response object
* @property {() => ErrorHandlerContext} getContext - Function to get the current context
* @property {string} [originPath] - The origin path for the error handler
*/
/**
* Creates an error handler function with the given dependencies
* @param {ErrorHandlerDependencies} dependencies - The dependencies for the error handler
* @returns {(error: Error) => Promise<void>} The error handler function
*/
const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/chat/' }) => {
const cache = getLogStores(CacheKeys.ABORT_KEYS);
/**
* Handles errors that occur during the chat process
* @param {Error} error - The error that occurred
* @returns {Promise<void>}
*/
return async (error) => {
const {
openai,
run_id,
endpoint,
cacheKey,
thread_id,
completedRun,
assistant_id,
conversationId,
parentMessageId,
responseMessageId,
} = getContext();
const defaultErrorMessage =
'The Assistant run failed to initialize. Try sending a message in a new conversation.';
const messageData = {
thread_id,
assistant_id,
conversationId,
parentMessageId,
sender: 'System',
user: req.user.id,
shouldSaveMessage: false,
messageId: responseMessageId,
endpoint,
};
if (error.message === 'Run cancelled') {
return res.end();
} else if (error.message === 'Request closed' && completedRun) {
return;
} else if (error.message === 'Request closed') {
logger.debug(`[${originPath}] Request aborted on close`);
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);
} else if (error?.message?.includes('string too long')) {
return sendResponse(
req,
res,
messageData,
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
);
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
return sendResponse(req, res, messageData, error.message);
} else {
logger.error(`[${originPath}]`, error);
}
if (!openai || !thread_id || !run_id) {
return sendResponse(req, res, messageData, defaultErrorMessage);
}
await new Promise((resolve) => setTimeout(resolve, 2000));
try {
const status = await cache.get(cacheKey);
if (status === 'cancelled') {
logger.debug(`[${originPath}] Run already cancelled`);
return res.end();
}
await cache.delete(cacheKey);
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
} catch (error) {
logger.error(`[${originPath}] Error cancelling run`, error);
}
await new Promise((resolve) => setTimeout(resolve, 2000));
let run;
try {
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
await recordUsage({
...run.usage,
model: run.model,
user: req.user.id,
conversationId,
});
} catch (error) {
logger.error(`[${originPath}] Error fetching or processing run`, error);
}
let finalEvent;
try {
const runMessages = await checkMessageGaps({
openai,
run_id,
endpoint,
thread_id,
conversationId,
latestMessageId: responseMessageId,
});
const errorContentPart = {
text: {
value:
error?.message ?? 'There was an error processing your request. Please try again later.',
},
type: ContentTypes.ERROR,
};
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
runMessages[runMessages.length - 1].content = [errorContentPart];
} else {
const contentParts = runMessages[runMessages.length - 1].content;
for (let i = 0; i < contentParts.length; i++) {
const currentPart = contentParts[i];
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
if (
toolCall &&
toolCall?.function &&
!(toolCall?.function?.output || toolCall?.function?.output?.length)
) {
contentParts[i] = {
...currentPart,
[ContentTypes.TOOL_CALL]: {
...toolCall,
function: {
...toolCall.function,
output: 'error processing tool',
},
},
};
}
}
runMessages[runMessages.length - 1].content.push(errorContentPart);
}
finalEvent = {
final: true,
conversation: await getConvo(req.user.id, conversationId),
runMessages,
};
} catch (error) {
logger.error(`[${originPath}] Error finalizing error process`, error);
return sendResponse(req, res, messageData, 'The Assistant run failed');
}
return sendResponse(req, res, finalEvent);
};
};
module.exports = { createErrorHandler };

View File

@@ -4,7 +4,6 @@ require('module-alias')({ base: path.resolve(__dirname, '..') });
const cors = require('cors');
const axios = require('axios');
const express = require('express');
const compression = require('compression');
const passport = require('passport');
const mongoSanitize = require('express-mongo-sanitize');
const { jwtLogin, passportLogin } = require('~/strategies');
@@ -18,9 +17,8 @@ const configureSocialLogins = require('./socialLogins');
const AppService = require('./services/AppService');
const noIndex = require('./middleware/noIndex');
const routes = require('./routes');
const staticCache = require('./utils/staticCache');
const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION } = process.env ?? {};
const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {};
const port = Number(PORT) || 3080;
const host = HOST || 'localhost';
@@ -45,16 +43,12 @@ const startServer = async () => {
app.use(express.json({ limit: '3mb' }));
app.use(mongoSanitize());
app.use(express.urlencoded({ extended: true, limit: '3mb' }));
app.use(staticCache(app.locals.paths.dist));
app.use(staticCache(app.locals.paths.fonts));
app.use(staticCache(app.locals.paths.assets));
app.use(express.static(app.locals.paths.dist));
app.use(express.static(app.locals.paths.fonts));
app.use(express.static(app.locals.paths.assets));
app.set('trust proxy', 1); // trust first proxy
app.use(cors());
if (DISABLE_COMPRESSION !== 'true') {
app.use(compression());
}
if (!ALLOW_SOCIAL_LOGIN) {
console.warn(
'Social logins are disabled. Set Environment Variable "ALLOW_SOCIAL_LOGIN" to true to enable them.',
@@ -100,7 +94,6 @@ const startServer = async () => {
app.use('/api/share', routes.share);
app.use('/api/roles', routes.roles);
app.use('/api/tags', routes.tags);
app.use((req, res) => {
res.sendFile(path.join(app.locals.paths.dist, 'index.html'));
});

View File

@@ -30,10 +30,7 @@ async function abortMessage(req, res) {
return res.status(204).send({ message: 'Request not found' });
}
const finalEvent = await abortController.abortCompletion();
logger.debug(
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
JSON.stringify({ abortKey }),
);
logger.info('[abortMessage] Aborted request', { abortKey });
abortControllers.delete(abortKey);
if (res.headersSent && finalEvent) {
@@ -119,11 +116,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
{ promptTokens, completionTokens },
);
saveMessage(
req,
{ ...responseMessage, user },
{ context: 'api/server/middleware/abortMiddleware.js' },
);
saveMessage({ ...responseMessage, user });
let conversation;
if (userMessagePromise) {
@@ -197,7 +190,7 @@ const handleAbortError = async (res, req, error, data) => {
}
};
await sendError(req, res, options, callback);
await sendError(res, options, callback);
};
if (partialText && partialText.length > 5) {

View File

@@ -19,8 +19,7 @@ const validateAssistant = async (req, res, next) => {
}
const { supportedIds, excludedIds } = assistantsConfig;
const error = { message: 'validateAssistant: Assistant not supported' };
const error = { message: 'Assistant not supported' };
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',

View File

@@ -1,7 +1,5 @@
const { Time } = require('librechat-data-provider');
const clearPendingReq = require('~/cache/clearPendingReq');
const { logViolation, getLogStores } = require('~/cache');
const { isEnabled } = require('~/server/utils');
const clearPendingReq = require('../../cache/clearPendingReq');
const { logViolation, getLogStores } = require('../../cache');
const denyRequest = require('./denyRequest');
const {
@@ -9,6 +7,7 @@ const {
CONCURRENT_MESSAGE_MAX = 1,
CONCURRENT_VIOLATION_SCORE: score,
} = process.env ?? {};
const ttl = 1000 * 60 * 1;
/**
* Middleware to limit concurrent requests for a user.
@@ -39,7 +38,7 @@ const concurrentLimiter = async (req, res, next) => {
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
const type = 'concurrent';
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}`;
const key = `${USE_REDIS ? namespace : ''}:${userId}`;
const pendingRequests = +((await cache.get(key)) ?? 0);
if (pendingRequests >= limit) {
@@ -52,7 +51,7 @@ const concurrentLimiter = async (req, res, next) => {
await logViolation(req, res, type, errorMessage, score);
return await denyRequest(req, res, errorMessage);
} else {
await cache.set(key, pendingRequests + 1, Time.ONE_MINUTE);
await cache.set(key, pendingRequests + 1, ttl);
}
// Ensure the requests are removed from the store once the request is done

View File

@@ -41,14 +41,10 @@ const denyRequest = async (req, res, errorMessage) => {
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;
if (shouldSaveMessage) {
await saveMessage(
req,
{ ...userMessage, user: req.user.id },
{ context: `api/server/middleware/denyRequest.js - ${responseText}` },
);
await saveMessage({ ...userMessage, user: req.user.id });
}
return await sendError(req, res, {
return await sendError(res, {
sender: getResponseSender(req.body),
messageId: crypto.randomUUID(),
conversationId,

View File

@@ -14,7 +14,6 @@ const requireJwtAuth = require('./requireJwtAuth');
const validateModel = require('./validateModel');
const moderateText = require('./moderateText');
const setHeaders = require('./setHeaders');
const validate = require('./validate');
const limiters = require('./limiters');
const uaParser = require('./uaParser');
const checkBan = require('./checkBan');
@@ -23,7 +22,6 @@ const roles = require('./roles');
module.exports = {
...abortMiddleware,
...validate,
...limiters,
...roles,
noIndex,

View File

@@ -1,112 +0,0 @@
const jwt = require('jsonwebtoken');
const validateImageRequest = require('~/server/middleware/validateImageRequest');
describe('validateImageRequest middleware', () => {
let req, res, next;
beforeEach(() => {
req = {
app: { locals: { secureImageLinks: true } },
headers: {},
originalUrl: '',
};
res = {
status: jest.fn().mockReturnThis(),
send: jest.fn(),
};
next = jest.fn();
process.env.JWT_REFRESH_SECRET = 'test-secret';
});
afterEach(() => {
jest.clearAllMocks();
});
test('should call next() if secureImageLinks is false', () => {
req.app.locals.secureImageLinks = false;
validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 401 if refresh token is not provided', () => {
validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(401);
expect(res.send).toHaveBeenCalledWith('Unauthorized');
});
test('should return 403 if refresh token is invalid', () => {
req.headers.cookie = 'refreshToken=invalid-token';
validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should return 403 if refresh token is expired', () => {
const expiredToken = jwt.sign(
{ id: '123', exp: Math.floor(Date.now() / 1000) - 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${expiredToken}`;
validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
test('should call next() for valid image path', () => {
const validToken = jwt.sign(
{ id: '123', exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/123/example.jpg';
validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should return 403 for invalid image path', () => {
const validToken = jwt.sign(
{ id: '123', exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/456/example.jpg';
validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
});
// File traversal tests
test('should prevent file traversal attempts', () => {
const validToken = jwt.sign(
{ id: '123', exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
const traversalAttempts = [
'/images/123/../../../etc/passwd',
'/images/123/..%2F..%2F..%2Fetc%2Fpasswd',
'/images/123/image.jpg/../../../etc/passwd',
'/images/123/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd',
];
traversalAttempts.forEach((attempt) => {
req.originalUrl = attempt;
validateImageRequest(req, res, next);
expect(res.status).toHaveBeenCalledWith(403);
expect(res.send).toHaveBeenCalledWith('Access Denied');
jest.clearAllMocks();
});
});
test('should handle URL encoded characters in valid paths', () => {
const validToken = jwt.sign(
{ id: '123', exp: Math.floor(Date.now() / 1000) + 3600 },
process.env.JWT_REFRESH_SECRET,
);
req.headers.cookie = `refreshToken=${validToken}`;
req.originalUrl = '/images/123/image%20with%20spaces.jpg';
validateImageRequest(req, res, next);
expect(next).toHaveBeenCalled();
});
});

View File

@@ -1,73 +0,0 @@
const { Constants, ViolationTypes, Time } = require('librechat-data-provider');
const { searchConversation } = require('~/models/Conversation');
const denyRequest = require('~/server/middleware/denyRequest');
const { logViolation, getLogStores } = require('~/cache');
const { isEnabled } = require('~/server/utils');
const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {};
/**
* Middleware to validate user's authorization for a conversation.
*
* This middleware checks if a user has the right to access a specific conversation.
* If the user doesn't have access, an error is returned. If the conversation doesn't exist,
* a not found error is returned. If the access is valid, the middleware allows the request to proceed.
* If the `cache` store is not available, the middleware will skip its logic.
*
* @function
* @param {Express.Request} req - Express request object containing user information.
* @param {Express.Response} res - Express response object.
* @param {function} next - Express next middleware function.
* @throws {Error} Throws an error if the user doesn't have access to the conversation.
*/
const validateConvoAccess = async (req, res, next) => {
const namespace = ViolationTypes.CONVO_ACCESS;
const cache = getLogStores(namespace);
const conversationId = req.body.conversationId;
if (!conversationId || conversationId === Constants.NEW_CONVO) {
return next();
}
const userId = req.user?.id ?? req.user?._id ?? '';
const type = ViolationTypes.CONVO_ACCESS;
const key = `${isEnabled(USE_REDIS) ? namespace : ''}:${userId}:${conversationId}`;
try {
if (cache) {
const cachedAccess = await cache.get(key);
if (cachedAccess === 'authorized') {
return next();
}
}
const conversation = await searchConversation(conversationId);
if (!conversation) {
return next();
}
if (conversation.user !== userId) {
const errorMessage = {
type,
error: 'User not authorized for this conversation',
};
if (cache) {
await logViolation(req, res, type, errorMessage, score);
}
return await denyRequest(req, res, errorMessage);
}
if (cache) {
await cache.set(key, 'authorized', Time.TEN_MINUTES);
}
next();
} catch (error) {
console.error('Error validating conversation access:', error);
res.status(500).json({ error: 'Internal server error' });
}
};
module.exports = validateConvoAccess;

View File

@@ -1,4 +0,0 @@
const validateConvoAccess = require('./convoAccess');
module.exports = {
validateConvoAccess,
};

View File

@@ -31,14 +31,10 @@ function validateImageRequest(req, res, next) {
return res.status(403).send('Access Denied');
}
const fullPath = decodeURIComponent(req.originalUrl);
const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`);
if (pathPattern.test(fullPath)) {
if (req.path.includes(payload.id)) {
logger.debug('[validateImageRequest] Image request validated');
next();
} else {
logger.warn('[validateImageRequest] Invalid image path');
res.status(403).send('Access Denied');
}
}

View File

@@ -76,9 +76,7 @@ describe.skip('GET /', () => {
openidLoginEnabled: true,
openidLabel: 'Test OpenID',
openidImageUrl: 'http://test-server.com',
ldap: {
enabled: true,
},
ldapLoginEnabled: true,
serverDomain: 'http://test-server.com',
emailLoginEnabled: 'true',
registrationEnabled: 'true',

View File

@@ -1,55 +0,0 @@
const request = require('supertest');
const express = require('express');
const { getLdapConfig } = require('~/server/services/Config/ldap');
const { isEnabled } = require('~/server/utils');
jest.mock('~/server/services/Config/ldap');
jest.mock('~/server/utils');
const app = express();
// Mock the route handler
app.get('/api/config', (req, res) => {
const ldapConfig = getLdapConfig();
res.json({ ldap: ldapConfig });
});
describe('LDAP Config Tests', () => {
afterEach(() => {
jest.resetAllMocks();
});
it('should return LDAP config with username property when LDAP_LOGIN_USES_USERNAME is enabled', async () => {
getLdapConfig.mockReturnValue({ enabled: true, username: true });
isEnabled.mockReturnValue(true);
const response = await request(app).get('/api/config');
expect(response.statusCode).toBe(200);
expect(response.body.ldap).toEqual({
enabled: true,
username: true,
});
});
it('should return LDAP config without username property when LDAP_LOGIN_USES_USERNAME is not enabled', async () => {
getLdapConfig.mockReturnValue({ enabled: true });
isEnabled.mockReturnValue(false);
const response = await request(app).get('/api/config');
expect(response.statusCode).toBe(200);
expect(response.body.ldap).toEqual({
enabled: true,
});
});
it('should not return LDAP config when LDAP is not enabled', async () => {
getLdapConfig.mockReturnValue(undefined);
const response = await request(app).get('/api/config');
expect(response.statusCode).toBe(200);
expect(response.body.ldap).toBeUndefined();
});
});

View File

@@ -51,8 +51,8 @@ router.post('/', setHeaders, async (req, res) => {
});
if (!overrideParentMessageId) {
await saveMessage(req, { ...userMessage, user: req.user.id });
await saveConvo(req, {
await saveMessage({ ...userMessage, user: req.user.id });
await saveConvo(req.user.id, {
...userMessage,
...endpointOption,
conversationId,
@@ -93,7 +93,7 @@ const ask = async ({
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage(req, {
saveMessage({
messageId: responseMessageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
conversationId,
@@ -159,7 +159,7 @@ const ask = async ({
isCreatedByUser: false,
};
await saveMessage(req, { ...responseMessage, user });
await saveMessage({ ...responseMessage, user });
responseMessage.messageId = newResponseMessageId;
// STEP2 update the conversation
@@ -183,7 +183,7 @@ const ask = async ({
}
}
await saveConvo(req, conversationUpdate);
await saveConvo(user, conversationUpdate);
conversationId = newConversationId;
// STEP3 update the user message
@@ -192,7 +192,7 @@ const ask = async ({
// If response has parentMessageId, the fake userMessage.messageId should be updated to the real one.
if (!overrideParentMessageId) {
await saveMessage(req, {
await saveMessage({
...userMessage,
user,
messageId: userMessageId,
@@ -213,7 +213,7 @@ const ask = async ({
if (userParentMessageId == Constants.NO_PARENT) {
// const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage });
const title = await response.details.title;
await saveConvo(req, {
await saveConvo(user, {
conversationId: conversationId,
title,
});
@@ -229,7 +229,7 @@ const ask = async ({
isCreatedByUser: false,
text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"`,
};
await saveMessage(req, { ...errorMessage, user });
await saveMessage({ ...errorMessage, user });
handleError(res, errorMessage);
}
};

View File

@@ -70,8 +70,8 @@ router.post('/', setHeaders, async (req, res) => {
});
if (!overrideParentMessageId) {
await saveMessage(req, { ...userMessage, user: req.user.id });
await saveConvo(req, {
await saveMessage({ ...userMessage, user: req.user.id });
await saveConvo(req.user.id, {
...userMessage,
...endpointOption,
conversationId,
@@ -118,7 +118,7 @@ const ask = async ({
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage(req, {
saveMessage({
messageId: responseMessageId,
sender: model,
conversationId,
@@ -197,7 +197,7 @@ const ask = async ({
isCreatedByUser: false,
};
await saveMessage(req, { ...responseMessage, user });
await saveMessage({ ...responseMessage, user });
responseMessage.messageId = newResponseMessageId;
let conversationUpdate = {
@@ -216,12 +216,12 @@ const ask = async ({
conversationUpdate.invocationId = response.invocationId;
}
await saveConvo(req, conversationUpdate);
await saveConvo(user, conversationUpdate);
userMessage.messageId = newUserMessageId;
// If response has parentMessageId, the fake userMessage.messageId should be updated to the real one.
if (!overrideParentMessageId) {
await saveMessage(req, {
await saveMessage({
...userMessage,
user,
messageId: userMessageId,
@@ -245,7 +245,7 @@ const ask = async ({
response: responseMessage,
});
await saveConvo(req, {
await saveConvo(user, {
conversationId: conversationId,
title,
});
@@ -266,7 +266,7 @@ const ask = async ({
isCreatedByUser: false,
};
saveMessage(req, { ...responseMessage, user });
saveMessage({ ...responseMessage, user });
return {
title: await getConvoTitle(user, conversationId),
@@ -288,7 +288,7 @@ const ask = async ({
model,
isCreatedByUser: false,
};
await saveMessage(req, { ...errorMessage, user });
await saveMessage({ ...errorMessage, user });
handleError(res, errorMessage);
}
}

View File

@@ -1,11 +1,10 @@
const express = require('express');
const throttle = require('lodash/throttle');
const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const { saveMessage, updateMessage } = require('~/models');
const { getLogStores } = require('~/cache');
const { saveMessage } = require('~/models');
const {
handleAbort,
createAbortController,
@@ -72,15 +71,7 @@ router.post(
}
};
const messageCache = getLogStores(CacheKeys.MESSAGES);
const throttledCacheSet = throttle(
(text) => {
messageCache.set(responseMessageId, text, Time.FIVE_MINUTES);
},
3000,
{ trailing: false },
);
const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
let streaming = null;
let timer = null;
@@ -94,7 +85,18 @@ router.post(
clearTimeout(timer);
}
throttledCacheSet(partialText);
throttledSaveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
error: false,
plugins,
user,
});
streaming = new Promise((resolve) => {
timer = setTimeout(() => {
@@ -168,11 +170,7 @@ router.post(
const onChainEnd = () => {
if (!client.skipSaveUserMessage) {
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
);
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, {
plugins,
@@ -209,6 +207,9 @@ router.post(
logger.debug('[/ask/gptPlugins]', response);
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
@@ -229,15 +230,6 @@ router.post(
client,
});
}
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
if (response.plugins?.length > 0) {
await updateMessage(
req,
{ ...response, user },
{ context: 'api/server/routes/ask/gptPlugins.js - save plugins used' },
);
}
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {

View File

@@ -12,10 +12,9 @@ const {
uaParser,
checkBan,
requireJwtAuth,
messageIpLimiter,
concurrentLimiter,
messageIpLimiter,
messageUserLimiter,
validateConvoAccess,
} = require('~/server/middleware');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
@@ -38,8 +37,6 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use(validateConvoAccess);
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.chatGPTBrowser}`, askChatGPTBrowser);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);

View File

@@ -42,7 +42,7 @@ router.post('/:assistant_id', async (req, res) => {
return res.status(400).json({ message: 'No functions provided' });
}
let metadata = await encryptMetadata(_metadata);
let metadata = encryptMetadata(_metadata);
let { domain } = metadata;
domain = await domainParser(req, domain, true);

View File

@@ -8,7 +8,6 @@ const {
// validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const validateConvoAccess = require('~/server/middleware/validate/convoAccess');
const validateAssistant = require('~/server/middleware/assistants/validate');
const chatController = require('~/server/controllers/assistants/chatV1');
@@ -22,14 +21,6 @@ router.post('/abort', handleAbort());
* @param {express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
router.post(
'/',
validateModel,
buildEndpointOption,
validateAssistant,
validateConvoAccess,
setHeaders,
chatController,
);
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);
module.exports = router;

View File

@@ -8,7 +8,6 @@ const {
// validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const validateConvoAccess = require('~/server/middleware/validate/convoAccess');
const validateAssistant = require('~/server/middleware/assistants/validate');
const chatController = require('~/server/controllers/assistants/chatV2');
@@ -22,14 +21,6 @@ router.post('/abort', handleAbort());
* @param {express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
router.post(
'/',
validateModel,
buildEndpointOption,
validateAssistant,
validateConvoAccess,
setHeaders,
chatController,
);
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);
module.exports = router;

View File

@@ -1,6 +1,13 @@
const express = require('express');
const router = express.Router();
const { uaParser, checkBan, requireJwtAuth } = require('~/server/middleware');
const {
uaParser,
checkBan,
requireJwtAuth,
// concurrentLimiter,
// messageIpLimiter,
// messageUserLimiter,
} = require('~/server/middleware');
const v1 = require('./v1');
const chatV1 = require('./chatV1');

View File

@@ -1,6 +1,5 @@
const express = require('express');
const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider');
const { getLdapConfig } = require('~/server/services/Config/ldap');
const { getProjectByName } = require('~/models/Project');
const { isEnabled } = require('~/server/utils');
const { getLogStores } = require('~/cache');
@@ -34,8 +33,7 @@ router.get('/', async function (req, res) {
const instanceProject = await getProjectByName('instance', '_id');
const ldap = getLdapConfig();
const ldapLoginEnabled = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
try {
/** @type {TStartupConfig} */
const payload = {
@@ -53,9 +51,10 @@ router.get('/', async function (req, res) {
!!process.env.OPENID_SESSION_SECRET,
openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
openidImageUrl: process.env.OPENID_IMAGE_URL,
ldapLoginEnabled,
serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
emailLoginEnabled,
registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION),
registrationEnabled: !ldapLoginEnabled && isEnabled(process.env.ALLOW_REGISTRATION),
socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN),
emailEnabled:
(!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
@@ -77,10 +76,6 @@ router.get('/', async function (req, res) {
instanceProjectId: instanceProject._id.toString(),
};
if (ldap) {
payload.ldap = ldap;
}
if (typeof process.env.CUSTOM_FOOTER === 'string') {
payload.customFooter = process.env.CUSTOM_FOOTER;
}

View File

@@ -8,7 +8,6 @@ const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { forkConversation } = require('~/server/utils/import/fork');
const { importConversations } = require('~/server/utils/import');
const { createImportLimiters } = require('~/server/middleware');
const { updateTagsForConversation } = require('~/models/ConversationTag');
const getLogStores = require('~/cache/getLogStores');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
@@ -31,14 +30,8 @@ router.get('/', async (req, res) => {
return res.status(400).json({ error: 'Invalid page size' });
}
const isArchived = req.query.isArchived === 'true';
let tags;
if (req.query.tags) {
tags = Array.isArray(req.query.tags) ? req.query.tags : [req.query.tags];
} else {
tags = undefined;
}
res.status(200).send(await getConvosByPage(req.user.id, pageNumber, pageSize, isArchived, tags));
res.status(200).send(await getConvosByPage(req.user.id, pageNumber, pageSize, isArchived));
});
router.get('/:conversationId', async (req, res) => {
@@ -111,7 +104,7 @@ router.post('/update', async (req, res) => {
const update = req.body.arg;
try {
const dbResponse = await saveConvo(req, update, { context: 'POST /api/convos/update' });
const dbResponse = await saveConvo(req.user.id, update);
res.status(201).json(dbResponse);
} catch (error) {
logger.error('Error updating conversation', error);
@@ -174,18 +167,4 @@ router.post('/fork', async (req, res) => {
}
});
router.put('/tags/:conversationId', async (req, res) => {
try {
const conversationTags = await updateTagsForConversation(
req.user.id,
req.params.conversationId,
req.body.tags,
);
res.status(200).json(conversationTags);
} catch (error) {
logger.error('Error updating conversation tags', error);
res.status(500).send('Error updating conversation tags');
}
});
module.exports = router;

View File

@@ -1,20 +1,19 @@
const express = require('express');
const throttle = require('lodash/throttle');
const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
const { getResponseSender } = require('librechat-data-provider');
const {
setHeaders,
handleAbort,
moderateText,
validateModel,
createAbortController,
handleAbortError,
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
createAbortController,
moderateText,
} = require('~/server/middleware');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, updateMessage } = require('~/models');
const { getLogStores } = require('~/cache');
const { saveMessage } = require('~/models');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
@@ -80,15 +79,7 @@ router.post(
}
};
const messageCache = getLogStores(CacheKeys.MESSAGES);
const throttledCacheSet = throttle(
(text) => {
messageCache.set(responseMessageId, text, Time.FIVE_MINUTES);
},
3000,
{ trailing: false },
);
const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
const {
onProgress: progressCallback,
sendIntermediateMessage,
@@ -99,7 +90,19 @@ router.post(
if (plugin.loading === true) {
plugin.loading = false;
}
throttledCacheSet(partialText);
throttledSaveMessage({
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
isEdited: true,
error: false,
user,
});
},
});
@@ -107,11 +110,7 @@ router.post(
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
);
saveMessage({ ...userMessage, user });
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
@@ -142,11 +141,7 @@ router.post(
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start && !client.skipSaveUserMessage) {
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onAgentAction' },
);
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, {
plugin,
@@ -184,6 +179,8 @@ router.post(
}
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
response.plugin = { ...plugin, loading: false };
await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
@@ -197,13 +194,6 @@ router.post(
responseMessage: response,
});
res.end();
response.plugin = { ...plugin, loading: false };
await updateMessage(
req,
{ ...response, user },
{ context: 'api/server/routes/edit/gptPlugins.js' },
);
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {

View File

@@ -13,7 +13,6 @@ const {
messageIpLimiter,
concurrentLimiter,
messageUserLimiter,
validateConvoAccess,
} = require('~/server/middleware');
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
@@ -36,8 +35,6 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(messageUserLimiter);
}
router.use(validateConvoAccess);
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);

View File

@@ -21,7 +21,6 @@ const staticRoute = require('./static');
const share = require('./share');
const categories = require('./categories');
const roles = require('./roles');
const tags = require('./tags');
module.exports = {
search,
@@ -47,5 +46,4 @@ module.exports = {
share,
categories,
roles,
tags,
};

View File

@@ -1,79 +1,49 @@
const express = require('express');
const { saveConvo, saveMessage, getMessages, updateMessage, deleteMessages } = require('~/models');
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
const { countTokens } = require('~/server/utils');
const { logger } = require('~/config');
const router = express.Router();
const {
getMessages,
updateMessage,
saveConvo,
saveMessage,
deleteMessages,
} = require('../../models');
const { countTokens } = require('../utils');
const { requireJwtAuth, validateMessageReq } = require('../middleware/');
router.use(requireJwtAuth);
/* Note: It's necessary to add `validateMessageReq` within route definition for correct params */
router.get('/:conversationId', validateMessageReq, async (req, res) => {
try {
const { conversationId } = req.params;
const messages = await getMessages({ conversationId }, '-_id -__v -user');
res.status(200).json(messages);
} catch (error) {
logger.error('Error fetching messages:', error);
res.status(500).json({ error: 'Internal server error' });
}
const { conversationId } = req.params;
res.status(200).send(await getMessages({ conversationId }, '-_id -__v -user'));
});
// CREATE
router.post('/:conversationId', validateMessageReq, async (req, res) => {
try {
const message = req.body;
const savedMessage = await saveMessage(
req,
{ ...message, user: req.user.id },
{ context: 'POST /api/messages/:conversationId' },
);
if (!savedMessage) {
return res.status(400).json({ error: 'Message not saved' });
}
await saveConvo(req, savedMessage, { context: 'POST /api/messages/:conversationId' });
res.status(201).json(savedMessage);
} catch (error) {
logger.error('Error saving message:', error);
res.status(500).json({ error: 'Internal server error' });
}
const message = req.body;
const savedMessage = await saveMessage({ ...message, user: req.user.id });
await saveConvo(req.user.id, savedMessage);
res.status(201).send(savedMessage);
});
// READ
router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
try {
const { conversationId, messageId } = req.params;
const message = await getMessages({ conversationId, messageId }, '-_id -__v -user');
if (!message) {
return res.status(404).json({ error: 'Message not found' });
}
res.status(200).json(message);
} catch (error) {
logger.error('Error fetching message:', error);
res.status(500).json({ error: 'Internal server error' });
}
const { conversationId, messageId } = req.params;
res.status(200).send(await getMessages({ conversationId, messageId }, '-_id -__v -user'));
});
// UPDATE
router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
try {
const { messageId, model } = req.params;
const { text } = req.body;
const tokenCount = await countTokens(text, model);
const result = await updateMessage(req, { messageId, text, tokenCount });
res.status(200).json(result);
} catch (error) {
logger.error('Error updating message:', error);
res.status(500).json({ error: 'Internal server error' });
}
const { messageId, model } = req.params;
const { text } = req.body;
const tokenCount = await countTokens(text, model);
res.status(201).json(await updateMessage({ messageId, text, tokenCount }));
});
// DELETE
router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
try {
const { messageId } = req.params;
await deleteMessages({ messageId });
res.status(204).send();
} catch (error) {
logger.error('Error deleting message:', error);
res.status(500).json({ error: 'Internal server error' });
}
const { messageId } = req.params;
await deleteMessages({ messageId });
res.status(204).send();
});
module.exports = router;

View File

@@ -1,8 +1,7 @@
const express = require('express');
const staticCache = require('../utils/staticCache');
const paths = require('~/config/paths');
const router = express.Router();
router.use(staticCache(paths.imageOutput));
router.use(express.static(paths.imageOutput));
module.exports = router;

View File

@@ -1,88 +0,0 @@
const express = require('express');
const {
getConversationTags,
updateConversationTag,
createConversationTag,
deleteConversationTag,
} = require('~/models/ConversationTag');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const router = express.Router();
router.use(requireJwtAuth);
/**
* GET /
* Retrieves all conversation tags for the authenticated user.
* @param {Object} req - Express request object
* @param {Object} res - Express response object
*/
router.get('/', async (req, res) => {
try {
const tags = await getConversationTags(req.user.id);
if (tags) {
res.status(200).json(tags);
} else {
res.status(404).end();
}
} catch (error) {
console.error('Error getting conversation tags:', error);
res.status(500).json({ error: 'Internal server error' });
}
});
/**
* POST /
* Creates a new conversation tag for the authenticated user.
* @param {Object} req - Express request object
* @param {Object} res - Express response object
*/
router.post('/', async (req, res) => {
try {
const tag = await createConversationTag(req.user.id, req.body);
res.status(200).json(tag);
} catch (error) {
console.error('Error creating conversation tag:', error);
res.status(500).json({ error: 'Internal server error' });
}
});
/**
* PUT /:tag
* Updates an existing conversation tag for the authenticated user.
* @param {Object} req - Express request object
* @param {Object} res - Express response object
*/
router.put('/:tag', async (req, res) => {
try {
const tag = await updateConversationTag(req.user.id, req.params.tag, req.body);
if (tag) {
res.status(200).json(tag);
} else {
res.status(404).json({ error: 'Tag not found' });
}
} catch (error) {
console.error('Error updating conversation tag:', error);
res.status(500).json({ error: 'Internal server error' });
}
});
/**
* DELETE /:tag
* Deletes a conversation tag for the authenticated user.
* @param {Object} req - Express request object
* @param {Object} res - Express response object
*/
router.delete('/:tag', async (req, res) => {
try {
const tag = await deleteConversationTag(req.user.id, req.params.tag);
if (tag) {
res.status(200).json(tag);
} else {
res.status(404).json({ error: 'Tag not found' });
}
} catch (error) {
console.error('Error deleting conversation tag:', error);
res.status(500).json({ error: 'Internal server error' });
}
});
module.exports = router;

View File

@@ -116,8 +116,8 @@ async function loadActionSets(searchParams) {
* @param {ActionRequest} params.requestBuilder - The ActionRequest builder class to execute the API call.
* @returns { { _call: (toolInput: Object) => unknown} } An object with `_call` method to execute the tool input.
*/
async function createActionTool({ action, requestBuilder }) {
action.metadata = await decryptMetadata(action.metadata);
function createActionTool({ action, requestBuilder }) {
action.metadata = decryptMetadata(action.metadata);
const _call = async (toolInput) => {
try {
requestBuilder.setParams(toolInput);
@@ -153,23 +153,23 @@ async function createActionTool({ action, requestBuilder }) {
* @param {ActionMetadata} metadata - The action metadata to encrypt.
* @returns {ActionMetadata} The updated action metadata with encrypted values.
*/
async function encryptMetadata(metadata) {
function encryptMetadata(metadata) {
const encryptedMetadata = { ...metadata };
// ServiceHttp
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
if (metadata.api_key) {
encryptedMetadata.api_key = await encryptV2(metadata.api_key);
encryptedMetadata.api_key = encryptV2(metadata.api_key);
}
}
// OAuth
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
if (metadata.oauth_client_id) {
encryptedMetadata.oauth_client_id = await encryptV2(metadata.oauth_client_id);
encryptedMetadata.oauth_client_id = encryptV2(metadata.oauth_client_id);
}
if (metadata.oauth_client_secret) {
encryptedMetadata.oauth_client_secret = await encryptV2(metadata.oauth_client_secret);
encryptedMetadata.oauth_client_secret = encryptV2(metadata.oauth_client_secret);
}
}
@@ -182,23 +182,23 @@ async function encryptMetadata(metadata) {
* @param {ActionMetadata} metadata - The action metadata to decrypt.
* @returns {ActionMetadata} The updated action metadata with decrypted values.
*/
async function decryptMetadata(metadata) {
function decryptMetadata(metadata) {
const decryptedMetadata = { ...metadata };
// ServiceHttp
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
if (metadata.api_key) {
decryptedMetadata.api_key = await decryptV2(metadata.api_key);
decryptedMetadata.api_key = decryptV2(metadata.api_key);
}
}
// OAuth
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
if (metadata.oauth_client_id) {
decryptedMetadata.oauth_client_id = await decryptV2(metadata.oauth_client_id);
decryptedMetadata.oauth_client_id = decryptV2(metadata.oauth_client_id);
}
if (metadata.oauth_client_secret) {
decryptedMetadata.oauth_client_secret = await decryptV2(metadata.oauth_client_secret);
decryptedMetadata.oauth_client_secret = decryptV2(metadata.oauth_client_secret);
}
}

View File

@@ -67,18 +67,17 @@ const AppService = async (app) => {
handleRateLimits(config?.rateLimits);
const endpointLocals = {};
const endpoints = config?.endpoints;
if (endpoints?.[EModelEndpoint.azureOpenAI]) {
if (config?.endpoints?.[EModelEndpoint.azureOpenAI]) {
endpointLocals[EModelEndpoint.azureOpenAI] = azureConfigSetup(config);
checkAzureVariables();
}
if (endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
if (config?.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
endpointLocals[EModelEndpoint.azureAssistants] = azureAssistantsDefaults();
}
if (endpoints?.[EModelEndpoint.azureAssistants]) {
if (config?.endpoints?.[EModelEndpoint.azureAssistants]) {
endpointLocals[EModelEndpoint.azureAssistants] = assistantsConfigSetup(
config,
EModelEndpoint.azureAssistants,
@@ -86,7 +85,7 @@ const AppService = async (app) => {
);
}
if (endpoints?.[EModelEndpoint.assistants]) {
if (config?.endpoints?.[EModelEndpoint.assistants]) {
endpointLocals[EModelEndpoint.assistants] = assistantsConfigSetup(
config,
EModelEndpoint.assistants,
@@ -94,19 +93,6 @@ const AppService = async (app) => {
);
}
if (endpoints?.[EModelEndpoint.openAI]) {
endpointLocals[EModelEndpoint.openAI] = endpoints[EModelEndpoint.openAI];
}
if (endpoints?.[EModelEndpoint.google]) {
endpointLocals[EModelEndpoint.google] = endpoints[EModelEndpoint.google];
}
if (endpoints?.[EModelEndpoint.anthropic]) {
endpointLocals[EModelEndpoint.anthropic] = endpoints[EModelEndpoint.anthropic];
}
if (endpoints?.[EModelEndpoint.gptPlugins]) {
endpointLocals[EModelEndpoint.gptPlugins] = endpoints[EModelEndpoint.gptPlugins];
}
app.locals = {
...defaultLocals,
modelSpecs: config.modelSpecs,

View File

@@ -1,5 +1,5 @@
const crypto = require('crypto');
const bcrypt = require('bcryptjs');
const { webcrypto } = require('node:crypto');
const { SystemRoles, errorsToString } = require('librechat-data-provider');
const {
findUser,
@@ -12,7 +12,6 @@ const {
} = require('~/models/userMethods');
const { sendEmail, checkEmailConfig } = require('~/server/utils');
const { registerSchema } = require('~/strategies/validators');
const { hashToken } = require('~/server/utils/crypto');
const isDomainAllowed = require('./isDomainAllowed');
const Token = require('~/models/schema/tokenSchema');
const Session = require('~/models/Session');
@@ -35,7 +34,7 @@ const genericVerificationMessage = 'Please check your email to verify your email
*/
const logoutUser = async (userId, refreshToken) => {
try {
const hash = await hashToken(refreshToken);
const hash = crypto.createHash('sha256').update(refreshToken).digest('hex');
// Find the session with the matching user and refreshTokenHash
const session = await Session.findOne({ user: userId, refreshTokenHash: hash });
@@ -54,23 +53,14 @@ const logoutUser = async (userId, refreshToken) => {
}
};
/**
* Creates Token and corresponding Hash for verification
* @returns {[string, string]}
*/
const createTokenHash = () => {
const token = Buffer.from(webcrypto.getRandomValues(new Uint8Array(32))).toString('hex');
const hash = bcrypt.hashSync(token, 10);
return [token, hash];
};
/**
* Send Verification Email
* @param {Partial<MongoUser> & { _id: ObjectId, email: string, name: string}} user
* @returns {Promise<void>}
*/
const sendVerificationEmail = async (user) => {
const [verifyToken, hash] = createTokenHash();
let verifyToken = crypto.randomBytes(32).toString('hex');
const hash = bcrypt.hashSync(verifyToken, 10);
const verificationLink = `${
domains.client
@@ -236,7 +226,8 @@ const requestPasswordReset = async (req) => {
await token.deleteOne();
}
const [resetToken, hash] = createTokenHash();
let resetToken = crypto.randomBytes(32).toString('hex');
const hash = bcrypt.hashSync(resetToken, 10);
await new Token({
userId: user._id,
@@ -374,7 +365,8 @@ const resendVerificationEmail = async (req) => {
return { status: 200, message: genericVerificationMessage };
}
const [verifyToken, hash] = createTokenHash();
let verifyToken = crypto.randomBytes(32).toString('hex');
const hash = bcrypt.hashSync(verifyToken, 10);
const verificationLink = `${
domains.client

View File

@@ -1,24 +0,0 @@
const { isEnabled } = require('~/server/utils');
/** @returns {TStartupConfig['ldap'] | undefined} */
const getLdapConfig = () => {
const ldapLoginEnabled = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
const ldap = {
enabled: ldapLoginEnabled,
};
const ldapLoginUsesUsername = isEnabled(process.env.LDAP_LOGIN_USES_USERNAME);
if (!ldapLoginEnabled) {
return ldap;
}
if (ldapLoginUsesUsername) {
ldap.username = true;
}
return ldap;
};
module.exports = {
getLdapConfig,
};

View File

@@ -23,14 +23,10 @@ const addTitle = async (req, { text, response, client }) => {
const title = await client.titleConvo({ text, responseText: response?.text });
await titleCache.set(key, title, 120000);
await saveConvo(
req,
{
conversationId: response.conversationId,
title,
},
{ context: 'api/server/services/Endpoints/anthropic/addTitle.js' },
);
await saveConvo(req.user.id, {
conversationId: response.conversationId,
title,
});
};
module.exports = addTitle;

View File

@@ -19,27 +19,11 @@ const initializeClient = async ({ req, res, endpointOption }) => {
checkUserKeyExpiry(expiresAt, EModelEndpoint.anthropic);
}
const clientOptions = {};
/** @type {undefined | TBaseEndpoint} */
const anthropicConfig = req.app.locals[EModelEndpoint.anthropic];
if (anthropicConfig) {
clientOptions.streamRate = anthropicConfig.streamRate;
}
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
if (allConfig) {
clientOptions.streamRate = allConfig.streamRate;
}
const client = new AnthropicClient(anthropicApiKey, {
req,
res,
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
...clientOptions,
...endpointOption,
});

View File

@@ -19,14 +19,10 @@ const addTitle = async (req, { text, responseText, conversationId, client }) =>
const title = await client.titleConvo({ text, conversationId, responseText });
await titleCache.set(key, title, 120000);
await saveConvo(
req,
{
conversationId,
title,
},
{ context: 'api/server/services/Endpoints/assistants/addTitle.js' },
);
await saveConvo(req.user.id, {
conversationId,
title,
});
};
module.exports = addTitle;

View File

@@ -114,16 +114,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
contextStrategy: endpointConfig.summarize ? 'summarize' : null,
directEndpoint: endpointConfig.directEndpoint,
titleMessageRole: endpointConfig.titleMessageRole,
streamRate: endpointConfig.streamRate,
endpointTokenConfig,
};
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
if (allConfig) {
customOptions.streamRate = allConfig.streamRate;
}
const clientOptions = {
reverseProxyUrl: baseURL ?? null,
proxy: PROXY ?? null,

View File

@@ -49,14 +49,10 @@ const addTitle = async (req, { text, response, client }) => {
const title = await titleClient.titleConvo({ text, responseText: response?.text });
await titleCache.set(key, title, 120000);
await saveConvo(
req,
{
conversationId: response.conversationId,
title,
},
{ context: 'api/server/services/Endpoints/google/addTitle.js' },
);
await saveConvo(req.user.id, {
conversationId: response.conversationId,
title,
});
};
module.exports = addTitle;

View File

@@ -27,27 +27,11 @@ const initializeClient = async ({ req, res, endpointOption }) => {
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
};
const clientOptions = {};
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
/** @type {undefined | TBaseEndpoint} */
const googleConfig = req.app.locals[EModelEndpoint.google];
if (googleConfig) {
clientOptions.streamRate = googleConfig.streamRate;
}
if (allConfig) {
clientOptions.streamRate = allConfig.streamRate;
}
const client = new GoogleClient(credentials, {
req,
res,
reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
...clientOptions,
...endpointOption,
});

View File

@@ -8,8 +8,6 @@ jest.mock('~/server/services/UserService', () => ({
getUserKey: jest.fn().mockImplementation(() => ({})),
}));
const app = { locals: {} };
describe('google/initializeClient', () => {
afterEach(() => {
jest.clearAllMocks();
@@ -25,7 +23,6 @@ describe('google/initializeClient', () => {
const req = {
body: { key: expiresAt },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
@@ -47,7 +44,6 @@ describe('google/initializeClient', () => {
const req = {
body: { key: null },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
@@ -70,7 +66,6 @@ describe('google/initializeClient', () => {
const req = {
body: { key: expiresAt },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };

View File

@@ -86,9 +86,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
clientOptions.titleModel = azureConfig.titleModel;
clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
const azureRate = modelName.includes('gpt-4') ? 30 : 17;
clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
const groupName = modelGroupMap[modelName].group;
clientOptions.addParams = azureConfig.groupMap[groupName].addParams;
clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams;
@@ -101,19 +98,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
apiKey = clientOptions.azure.azureOpenAIApiKey;
}
/** @type {undefined | TBaseEndpoint} */
const pluginsConfig = req.app.locals[EModelEndpoint.gptPlugins];
if (!useAzure && pluginsConfig) {
clientOptions.streamRate = pluginsConfig.streamRate;
}
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
if (allConfig) {
clientOptions.streamRate = allConfig.streamRate;
}
if (!apiKey) {
throw new Error(`${endpoint} API key not provided. Please provide it again.`);
}

View File

@@ -23,14 +23,10 @@ const addTitle = async (req, { text, response, client }) => {
const title = await client.titleConvo({ text, responseText: response?.text });
await titleCache.set(key, title, 120000);
await saveConvo(
req,
{
conversationId: response.conversationId,
title,
},
{ context: 'api/server/services/Endpoints/openAI/addTitle.js' },
);
await saveConvo(req.user.id, {
conversationId: response.conversationId,
title,
});
};
module.exports = addTitle;

View File

@@ -76,10 +76,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
clientOptions.titleConvo = azureConfig.titleConvo;
clientOptions.titleModel = azureConfig.titleModel;
const azureRate = modelName.includes('gpt-4') ? 30 : 17;
clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
const groupName = modelGroupMap[modelName].group;
@@ -94,19 +90,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
apiKey = clientOptions.azure.azureOpenAIApiKey;
}
/** @type {undefined | TBaseEndpoint} */
const openAIConfig = req.app.locals[EModelEndpoint.openAI];
if (!isAzureOpenAI && openAIConfig) {
clientOptions.streamRate = openAIConfig.streamRate;
}
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
if (allConfig) {
clientOptions.streamRate = allConfig.streamRate;
}
if (userProvidesKey & !apiKey) {
throw new Error(
JSON.stringify({

View File

@@ -1,248 +0,0 @@
const { Readable } = require('stream');
const axios = require('axios');
const { extractEnvVariable, STTProviders } = require('librechat-data-provider');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { genAzureEndpoint } = require('~/utils');
const { logger } = require('~/config');
/**
* Service class for handling Speech-to-Text (STT) operations.
* @class
*/
class STTService {
/**
* Creates an instance of STTService.
* @param {Object} customConfig - The custom configuration object.
*/
constructor(customConfig) {
this.customConfig = customConfig;
this.providerStrategies = {
[STTProviders.OPENAI]: this.openAIProvider,
[STTProviders.AZURE_OPENAI]: this.azureOpenAIProvider,
};
}
/**
* Creates a singleton instance of STTService.
* @static
* @async
* @returns {Promise<STTService>} The STTService instance.
* @throws {Error} If the custom config is not found.
*/
static async getInstance() {
const customConfig = await getCustomConfig();
if (!customConfig) {
throw new Error('Custom config not found');
}
return new STTService(customConfig);
}
/**
* Retrieves the configured STT provider and its schema.
* @returns {Promise<[string, Object]>} A promise that resolves to an array containing the provider name and its schema.
* @throws {Error} If no STT schema is set, multiple providers are set, or no provider is set.
*/
async getProviderSchema() {
const sttSchema = this.customConfig.speech.stt;
if (!sttSchema) {
throw new Error(
'No STT schema is set. Did you configure STT in the custom config (librechat.yaml)?',
);
}
const providers = Object.entries(sttSchema).filter(
([, value]) => Object.keys(value).length > 0,
);
if (providers.length !== 1) {
throw new Error(
providers.length > 1
? 'Multiple providers are set. Please set only one provider.'
: 'No provider is set. Please set a provider.',
);
}
const [provider, schema] = providers[0];
return [provider, schema];
}
/**
* Recursively removes undefined properties from an object.
* @param {Object} obj - The object to clean.
* @returns {void}
*/
removeUndefined(obj) {
Object.keys(obj).forEach((key) => {
if (obj[key] && typeof obj[key] === 'object') {
this.removeUndefined(obj[key]);
if (Object.keys(obj[key]).length === 0) {
delete obj[key];
}
} else if (obj[key] === undefined) {
delete obj[key];
}
});
}
/**
* Prepares the request for the OpenAI STT provider.
* @param {Object} sttSchema - The STT schema for OpenAI.
* @param {Stream} audioReadStream - The audio data to be transcribed.
* @returns {Array} An array containing the URL, data, and headers for the request.
*/
openAIProvider(sttSchema, audioReadStream) {
const url = sttSchema?.url || 'https://api.openai.com/v1/audio/transcriptions';
const apiKey = extractEnvVariable(sttSchema.apiKey) || '';
const data = {
file: audioReadStream,
model: sttSchema.model,
};
const headers = {
'Content-Type': 'multipart/form-data',
...(apiKey && { Authorization: `Bearer ${apiKey}` }),
};
[headers].forEach(this.removeUndefined);
return [url, data, headers];
}
/**
* Prepares the request for the Azure OpenAI STT provider.
* @param {Object} sttSchema - The STT schema for Azure OpenAI.
* @param {Buffer} audioBuffer - The audio data to be transcribed.
* @param {Object} audioFile - The audio file object containing originalname, mimetype, and size.
* @returns {Array} An array containing the URL, data, and headers for the request.
* @throws {Error} If the audio file size exceeds 25MB or the audio file format is not accepted.
*/
azureOpenAIProvider(sttSchema, audioBuffer, audioFile) {
const url = `${genAzureEndpoint({
azureOpenAIApiInstanceName: sttSchema?.instanceName,
azureOpenAIApiDeploymentName: sttSchema?.deploymentName,
})}/audio/transcriptions?api-version=${sttSchema?.apiVersion}`;
const apiKey = sttSchema.apiKey ? extractEnvVariable(sttSchema.apiKey) : '';
if (audioBuffer.byteLength > 25 * 1024 * 1024) {
throw new Error('The audio file size exceeds the limit of 25MB');
}
const acceptedFormats = ['flac', 'mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'ogg', 'wav', 'webm'];
const fileFormat = audioFile.mimetype.split('/')[1];
if (!acceptedFormats.includes(fileFormat)) {
throw new Error(`The audio file format ${fileFormat} is not accepted`);
}
const formData = new FormData();
const audioBlob = new Blob([audioBuffer], { type: audioFile.mimetype });
formData.append('file', audioBlob, audioFile.originalname);
const headers = {
'Content-Type': 'multipart/form-data',
...(apiKey && { 'api-key': apiKey }),
};
[headers].forEach(this.removeUndefined);
return [url, formData, headers];
}
/**
* Sends an STT request to the specified provider.
* @async
* @param {string} provider - The STT provider to use.
* @param {Object} sttSchema - The STT schema for the provider.
* @param {Object} requestData - The data required for the STT request.
* @param {Buffer} requestData.audioBuffer - The audio data to be transcribed.
* @param {Object} requestData.audioFile - The audio file object containing originalname, mimetype, and size.
* @returns {Promise<string>} A promise that resolves to the transcribed text.
* @throws {Error} If the provider is invalid, the response status is not 200, or the response data is missing.
*/
async sttRequest(provider, sttSchema, { audioBuffer, audioFile }) {
const strategy = this.providerStrategies[provider];
if (!strategy) {
throw new Error('Invalid provider');
}
const audioReadStream = Readable.from(audioBuffer);
audioReadStream.path = 'audio.wav';
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);
if (!Readable.from && data instanceof FormData) {
const audioBlob = new Blob([audioBuffer], { type: audioFile.mimetype });
data.set('file', audioBlob, audioFile.originalname);
}
try {
const response = await axios.post(url, data, { headers });
if (response.status !== 200) {
throw new Error('Invalid response from the STT API');
}
if (!response.data || !response.data.text) {
throw new Error('Missing data in response from the STT API');
}
return response.data.text.trim();
} catch (error) {
logger.error(`STT request failed for provider ${provider}:`, error);
throw error;
}
}
/**
* Processes a speech-to-text request.
* @async
* @param {Object} req - The request object.
* @param {Object} res - The response object.
* @returns {Promise<void>}
*/
async processTextToSpeech(req, res) {
if (!req.file || !req.file.buffer) {
return res.status(400).json({ message: 'No audio file provided in the FormData' });
}
const audioBuffer = req.file.buffer;
const audioFile = {
originalname: req.file.originalname,
mimetype: req.file.mimetype,
size: req.file.size,
};
try {
const [provider, sttSchema] = await this.getProviderSchema();
const text = await this.sttRequest(provider, sttSchema, { audioBuffer, audioFile });
res.json({ text });
} catch (error) {
logger.error('An error occurred while processing the audio:', error);
res.sendStatus(500);
}
}
}
/**
* Factory function to create an STTService instance.
* @async
* @returns {Promise<STTService>} A promise that resolves to an STTService instance.
*/
async function createSTTService() {
return STTService.getInstance();
}
/**
* Wrapper function for speech-to-text processing.
* @async
* @param {Object} req - The request object.
* @param {Object} res - The response object.
* @returns {Promise<void>}
*/
async function speechToText(req, res) {
const sttService = await createSTTService();
await sttService.processTextToSpeech(req, res);
}
module.exports = { speechToText };

View File

@@ -1,477 +0,0 @@
const axios = require('axios');
const { extractEnvVariable, TTSProviders } = require('librechat-data-provider');
const { logger } = require('~/config');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { genAzureEndpoint } = require('~/utils');
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
/**
* Service class for handling Text-to-Speech (TTS) operations.
* @class
*/
class TTSService {
/**
* Creates an instance of TTSService.
* @param {Object} customConfig - The custom configuration object.
*/
constructor(customConfig) {
this.customConfig = customConfig;
this.providerStrategies = {
[TTSProviders.OPENAI]: this.openAIProvider.bind(this),
[TTSProviders.AZURE_OPENAI]: this.azureOpenAIProvider.bind(this),
[TTSProviders.ELEVENLABS]: this.elevenLabsProvider.bind(this),
[TTSProviders.LOCALAI]: this.localAIProvider.bind(this),
};
}
/**
* Creates a singleton instance of TTSService.
* @static
* @async
* @returns {Promise<TTSService>} The TTSService instance.
* @throws {Error} If the custom config is not found.
*/
static async getInstance() {
const customConfig = await getCustomConfig();
if (!customConfig) {
throw new Error('Custom config not found');
}
return new TTSService(customConfig);
}
/**
* Retrieves the configured TTS provider.
* @returns {string} The name of the configured provider.
* @throws {Error} If no provider is set or multiple providers are set.
*/
getProvider() {
const ttsSchema = this.customConfig.speech.tts;
if (!ttsSchema) {
throw new Error(
'No TTS schema is set. Did you configure TTS in the custom config (librechat.yaml)?',
);
}
const providers = Object.entries(ttsSchema).filter(
([, value]) => Object.keys(value).length > 0,
);
if (providers.length !== 1) {
throw new Error(
providers.length > 1
? 'Multiple providers are set. Please set only one provider.'
: 'No provider is set. Please set a provider.',
);
}
return providers[0][0];
}
/**
* Selects a voice for TTS based on provider schema and request.
* @async
* @param {Object} providerSchema - The schema for the selected provider.
* @param {string} requestVoice - The requested voice.
* @returns {Promise<string>} The selected voice.
*/
async getVoice(providerSchema, requestVoice) {
const voices = providerSchema.voices.filter((voice) => voice && voice.toUpperCase() !== 'ALL');
let voice = requestVoice;
if (!voice || !voices.includes(voice) || (voice.toUpperCase() === 'ALL' && voices.length > 1)) {
voice = getRandomVoiceId(voices);
}
return voice;
}
/**
* Recursively removes undefined properties from an object.
* @param {Object} obj - The object to clean.
*/
removeUndefined(obj) {
Object.keys(obj).forEach((key) => {
if (obj[key] && typeof obj[key] === 'object') {
this.removeUndefined(obj[key]);
if (Object.keys(obj[key]).length === 0) {
delete obj[key];
}
} else if (obj[key] === undefined) {
delete obj[key];
}
});
}
/**
* Prepares the request for OpenAI TTS provider.
* @param {Object} ttsSchema - The TTS schema for OpenAI.
* @param {string} input - The input text.
* @param {string} voice - The selected voice.
* @returns {Array} An array containing the URL, data, and headers for the request.
* @throws {Error} If the selected voice is not available.
*/
openAIProvider(ttsSchema, input, voice) {
const url = ttsSchema?.url || 'https://api.openai.com/v1/audio/speech';
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
throw new Error(`Voice ${voice} is not available.`);
}
const data = {
input,
model: ttsSchema?.model,
voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
backend: ttsSchema?.backend,
};
const headers = {
'Content-Type': 'application/json',
Authorization: `Bearer ${extractEnvVariable(ttsSchema?.apiKey)}`,
};
return [url, data, headers];
}
/**
* Prepares the request for Azure OpenAI TTS provider.
* @param {Object} ttsSchema - The TTS schema for Azure OpenAI.
* @param {string} input - The input text.
* @param {string} voice - The selected voice.
* @returns {Array} An array containing the URL, data, and headers for the request.
* @throws {Error} If the selected voice is not available.
*/
azureOpenAIProvider(ttsSchema, input, voice) {
const url = `${genAzureEndpoint({
azureOpenAIApiInstanceName: ttsSchema?.instanceName,
azureOpenAIApiDeploymentName: ttsSchema?.deploymentName,
})}/audio/speech?api-version=${ttsSchema?.apiVersion}`;
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
throw new Error(`Voice ${voice} is not available.`);
}
const data = {
model: ttsSchema?.model,
input,
voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
};
const headers = {
'Content-Type': 'application/json',
'api-key': ttsSchema.apiKey ? extractEnvVariable(ttsSchema.apiKey) : '',
};
return [url, data, headers];
}
/**
* Prepares the request for ElevenLabs TTS provider.
* @param {Object} ttsSchema - The TTS schema for ElevenLabs.
* @param {string} input - The input text.
* @param {string} voice - The selected voice.
* @param {boolean} stream - Whether to use streaming.
* @returns {Array} An array containing the URL, data, and headers for the request.
* @throws {Error} If the selected voice is not available.
*/
elevenLabsProvider(ttsSchema, input, voice, stream) {
let url =
ttsSchema?.url ||
`https://api.elevenlabs.io/v1/text-to-speech/${voice}${stream ? '/stream' : ''}`;
if (!ttsSchema?.voices.includes(voice) && !ttsSchema?.voices.includes('ALL')) {
throw new Error(`Voice ${voice} is not available.`);
}
const data = {
model_id: ttsSchema?.model,
text: input,
voice_settings: {
similarity_boost: ttsSchema?.voice_settings?.similarity_boost,
stability: ttsSchema?.voice_settings?.stability,
style: ttsSchema?.voice_settings?.style,
use_speaker_boost: ttsSchema?.voice_settings?.use_speaker_boost,
},
pronunciation_dictionary_locators: ttsSchema?.pronunciation_dictionary_locators,
};
const headers = {
'Content-Type': 'application/json',
'xi-api-key': extractEnvVariable(ttsSchema?.apiKey),
Accept: 'audio/mpeg',
};
return [url, data, headers];
}
/**
* Prepares the request for LocalAI TTS provider.
* @param {Object} ttsSchema - The TTS schema for LocalAI.
* @param {string} input - The input text.
* @param {string} voice - The selected voice.
* @returns {Array} An array containing the URL, data, and headers for the request.
* @throws {Error} If the selected voice is not available.
*/
localAIProvider(ttsSchema, input, voice) {
const url = ttsSchema?.url;
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
throw new Error(`Voice ${voice} is not available.`);
}
const data = {
input,
model: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
backend: ttsSchema?.backend,
};
const headers = {
'Content-Type': 'application/json',
Authorization: `Bearer ${extractEnvVariable(ttsSchema?.apiKey)}`,
};
if (extractEnvVariable(ttsSchema.apiKey) === '') {
delete headers.Authorization;
}
return [url, data, headers];
}
/**
* Sends a TTS request to the specified provider.
* @async
* @param {string} provider - The TTS provider to use.
* @param {Object} ttsSchema - The TTS schema for the provider.
* @param {Object} options - The options for the TTS request.
* @param {string} options.input - The input text.
* @param {string} options.voice - The voice to use.
* @param {boolean} [options.stream=true] - Whether to use streaming.
* @returns {Promise<Object>} The axios response object.
* @throws {Error} If the provider is invalid or the request fails.
*/
async ttsRequest(provider, ttsSchema, { input, voice, stream = true }) {
const strategy = this.providerStrategies[provider];
if (!strategy) {
throw new Error('Invalid provider');
}
const [url, data, headers] = strategy.call(this, ttsSchema, input, voice, stream);
[data, headers].forEach(this.removeUndefined.bind(this));
const options = { headers, responseType: stream ? 'stream' : 'arraybuffer' };
try {
return await axios.post(url, data, options);
} catch (error) {
logger.error(`TTS request failed for provider ${provider}:`, error);
throw error;
}
}
/**
* Processes a text-to-speech request.
* @async
* @param {Object} req - The request object.
* @param {Object} res - The response object.
* @returns {Promise<void>}
*/
async processTextToSpeech(req, res) {
const { input, voice: requestVoice } = req.body;
if (!input) {
return res.status(400).send('Missing text in request body');
}
try {
res.setHeader('Content-Type', 'audio/mpeg');
const provider = this.getProvider();
const ttsSchema = this.customConfig.speech.tts[provider];
const voice = await this.getVoice(ttsSchema, requestVoice);
if (input.length < 4096) {
const response = await this.ttsRequest(provider, ttsSchema, { input, voice });
response.data.pipe(res);
return;
}
const textChunks = splitTextIntoChunks(input, 1000);
for (const chunk of textChunks) {
try {
const response = await this.ttsRequest(provider, ttsSchema, {
voice,
input: chunk.text,
stream: true,
});
logger.debug(`[textToSpeech] user: ${req?.user?.id} | writing audio stream`);
await new Promise((resolve) => {
response.data.pipe(res, { end: chunk.isFinished });
response.data.on('end', resolve);
});
if (chunk.isFinished) {
break;
}
} catch (innerError) {
logger.error('Error processing manual update:', chunk, innerError);
if (!res.headersSent) {
return res.status(500).end();
}
return;
}
}
if (!res.headersSent) {
res.end();
}
} catch (error) {
logger.error('Error creating the audio stream:', error);
if (!res.headersSent) {
return res.status(500).send('An error occurred');
}
}
}
/**
* Streams audio data from the TTS provider.
* @async
* @param {Object} req - The request object.
* @param {Object} res - The response object.
* @returns {Promise<void>}
*/
async streamAudio(req, res) {
res.setHeader('Content-Type', 'audio/mpeg');
const provider = this.getProvider();
const ttsSchema = this.customConfig.speech.tts[provider];
const voice = await this.getVoice(ttsSchema, req.body.voice);
let shouldContinue = true;
req.on('close', () => {
logger.warn('[streamAudio] Audio Stream Request closed by client');
shouldContinue = false;
});
const processChunks = createChunkProcessor(req.body.messageId);
try {
while (shouldContinue) {
const updates = await processChunks();
if (typeof updates === 'string') {
logger.error(`Error processing audio stream updates: ${updates}`);
return res.status(500).end();
}
if (updates.length === 0) {
await new Promise((resolve) => setTimeout(resolve, 1250));
continue;
}
for (const update of updates) {
try {
const response = await this.ttsRequest(provider, ttsSchema, {
voice,
input: update.text,
stream: true,
});
if (!shouldContinue) {
break;
}
logger.debug(`[streamAudio] user: ${req?.user?.id} | writing audio stream`);
await new Promise((resolve) => {
response.data.pipe(res, { end: update.isFinished });
response.data.on('end', resolve);
});
if (update.isFinished) {
shouldContinue = false;
break;
}
} catch (innerError) {
logger.error('Error processing audio stream update:', update, innerError);
if (!res.headersSent) {
return res.status(500).end();
}
return;
}
}
if (!shouldContinue) {
break;
}
}
if (!res.headersSent) {
res.end();
}
} catch (error) {
logger.error('Failed to fetch audio:', error);
if (!res.headersSent) {
res.status(500).end();
}
}
}
}
/**
* Factory function to create a TTSService instance.
* @async
* @returns {Promise<TTSService>} A promise that resolves to a TTSService instance.
*/
async function createTTSService() {
return TTSService.getInstance();
}
/**
* Wrapper function for text-to-speech processing.
* @async
* @param {Object} req - The request object.
* @param {Object} res - The response object.
* @returns {Promise<void>}
*/
async function textToSpeech(req, res) {
const ttsService = await createTTSService();
await ttsService.processTextToSpeech(req, res);
}
/**
* Wrapper function for audio streaming.
* @async
* @param {Object} req - The request object.
* @param {Object} res - The response object.
* @returns {Promise<void>}
*/
async function streamAudio(req, res) {
const ttsService = await createTTSService();
await ttsService.streamAudio(req, res);
}
/**
* Wrapper function to get the configured TTS provider.
* @async
* @returns {Promise<string>} A promise that resolves to the name of the configured provider.
*/
async function getProvider() {
const ttsService = await createTTSService();
return ttsService.getProvider();
}
module.exports = {
textToSpeech,
streamAudio,
getProvider,
};

View File

@@ -15,43 +15,37 @@ const getCustomConfig = require('~/server/services/Config/getCustomConfig');
async function getCustomConfigSpeech(req, res) {
try {
const customConfig = await getCustomConfig();
const sttExternal = !!customConfig.speech?.stt;
const ttsExternal = !!customConfig.speech?.tts;
let settings = {
sttExternal,
ttsExternal,
};
if (!customConfig || !customConfig.speech?.speechTab) {
return res.status(200).send(settings);
throw new Error('Configuration or speechTab schema is missing');
}
const speechTab = customConfig.speech.speechTab;
const ttsSchema = customConfig.speech?.speechTab;
let settings = {};
if (speechTab.advancedMode !== undefined) {
settings.advancedMode = speechTab.advancedMode;
if (ttsSchema.advancedMode !== undefined) {
settings.advancedMode = ttsSchema.advancedMode;
}
if (speechTab.speechToText) {
for (const key in speechTab.speechToText) {
if (speechTab.speechToText[key] !== undefined) {
settings[key] = speechTab.speechToText[key];
if (ttsSchema.speechToText) {
for (const key in ttsSchema.speechToText) {
if (ttsSchema.speechToText[key] !== undefined) {
settings[key] = ttsSchema.speechToText[key];
}
}
}
if (speechTab.textToSpeech) {
for (const key in speechTab.textToSpeech) {
if (speechTab.textToSpeech[key] !== undefined) {
settings[key] = speechTab.textToSpeech[key];
if (ttsSchema.textToSpeech) {
for (const key in ttsSchema.textToSpeech) {
if (ttsSchema.textToSpeech[key] !== undefined) {
settings[key] = ttsSchema.textToSpeech[key];
}
}
}
return res.status(200).send(settings);
} catch (error) {
console.error('Failed to get custom config speech settings:', error);
res.status(500).send('Internal Server Error');
res.status(200).send();
}
}

View File

@@ -1,6 +1,6 @@
const { TTSProviders } = require('librechat-data-provider');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getProvider } = require('./TTSService');
const { getProvider } = require('./textToSpeech');
/**
* This function retrieves the available voices for the current TTS provider
@@ -21,7 +21,7 @@ async function getVoices(req, res) {
}
const ttsSchema = customConfig?.speech?.tts;
const provider = await getProvider(ttsSchema);
const provider = getProvider(ttsSchema);
let voices;
switch (provider) {

View File

@@ -1,11 +1,11 @@
const getVoices = require('./getVoices');
const getCustomConfigSpeech = require('./getCustomConfigSpeech');
const TTSService = require('./TTSService');
const STTService = require('./STTService');
const textToSpeech = require('./textToSpeech');
const speechToText = require('./speechToText');
module.exports = {
getVoices,
getCustomConfigSpeech,
...STTService,
...TTSService,
speechToText,
...textToSpeech,
};

View File

@@ -0,0 +1,225 @@
const { Readable } = require('stream');
const axios = require('axios');
const { extractEnvVariable, STTProviders } = require('librechat-data-provider');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { genAzureEndpoint } = require('~/utils');
const { logger } = require('~/config');
/**
* Handle the response from the STT API
* @param {Object} response - The response from the STT API
*
* @returns {string} The text from the response data
*
* @throws Will throw an error if the response status is not 200 or the response data is missing
*/
async function handleResponse(response) {
if (response.status !== 200) {
throw new Error('Invalid response from the STT API');
}
if (!response.data || !response.data.text) {
throw new Error('Missing data in response from the STT API');
}
return response.data.text.trim();
}
/**
* getProviderSchema function
* This function takes the customConfig object and returns the name of the provider and its schema
* If more than one provider is set or no provider is set, it throws an error
*
* @param {Object} customConfig - The custom configuration containing the STT schema
* @returns {Promise<[string, Object]>} The name of the provider and its schema
* @throws {Error} Throws an error if multiple providers are set or no provider is set
*/
async function getProviderSchema(customConfig) {
const sttSchema = customConfig.speech.stt;
if (!sttSchema) {
throw new Error(`No STT schema is set. Did you configure STT in the custom config (librechat.yaml)?
https://www.librechat.ai/docs/configuration/stt_tts#stt`);
}
const providers = Object.entries(sttSchema).filter(([, value]) => Object.keys(value).length > 0);
if (providers.length > 1) {
throw new Error('Multiple providers are set. Please set only one provider.');
} else if (providers.length === 0) {
throw new Error('No provider is set. Please set a provider.');
} else {
const provider = providers[0][0];
return [provider, sttSchema[provider]];
}
}
function removeUndefined(obj) {
Object.keys(obj).forEach((key) => {
if (obj[key] && typeof obj[key] === 'object') {
removeUndefined(obj[key]);
if (Object.keys(obj[key]).length === 0) {
delete obj[key];
}
} else if (obj[key] === undefined) {
delete obj[key];
}
});
}
/**
* This function prepares the necessary data and headers for making a request to the OpenAI API
* It uses the provided speech-to-text schema and audio stream to create the request
*
* @param {Object} sttSchema - The speech-to-text schema containing the OpenAI configuration
* @param {Stream} audioReadStream - The audio data to be transcribed
*
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
* If an error occurs, it returns an array with three null values and logs the error with logger
*/
function openAIProvider(sttSchema, audioReadStream) {
try {
const url = sttSchema.openai?.url || 'https://api.openai.com/v1/audio/transcriptions';
const apiKey = sttSchema.openai.apiKey ? extractEnvVariable(sttSchema.openai.apiKey) : '';
let data = {
file: audioReadStream,
model: sttSchema.openai.model,
};
let headers = {
'Content-Type': 'multipart/form-data',
};
[headers].forEach(removeUndefined);
if (apiKey) {
headers.Authorization = 'Bearer ' + apiKey;
}
return [url, data, headers];
} catch (error) {
logger.error('An error occurred while preparing the OpenAI API STT request: ', error);
return [null, null, null];
}
}
/**
* Prepares the necessary data and headers for making a request to the Azure API.
* It uses the provided Speech-to-Text (STT) schema and audio file to create the request.
*
* @param {Object} sttSchema - The STT schema object, which should contain instanceName, deploymentName, apiVersion, and apiKey.
* @param {Buffer} audioBuffer - The audio data to be transcribed
* @param {Object} audioFile - The audio file object, which should contain originalname, mimetype, and size.
*
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request.
* If an error occurs, it logs the error with logger and returns an array with three null values.
*/
function azureOpenAIProvider(sttSchema, audioBuffer, audioFile) {
try {
const instanceName = sttSchema?.instanceName;
const deploymentName = sttSchema?.deploymentName;
const apiVersion = sttSchema?.apiVersion;
const url =
genAzureEndpoint({
azureOpenAIApiInstanceName: instanceName,
azureOpenAIApiDeploymentName: deploymentName,
}) +
'/audio/transcriptions?api-version=' +
apiVersion;
const apiKey = sttSchema.apiKey ? extractEnvVariable(sttSchema.apiKey) : '';
if (audioBuffer.byteLength > 25 * 1024 * 1024) {
throw new Error('The audio file size exceeds the limit of 25MB');
}
const acceptedFormats = ['flac', 'mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'ogg', 'wav', 'webm'];
const fileFormat = audioFile.mimetype.split('/')[1];
if (!acceptedFormats.includes(fileFormat)) {
throw new Error(`The audio file format ${fileFormat} is not accepted`);
}
const formData = new FormData();
const audioBlob = new Blob([audioBuffer], { type: audioFile.mimetype });
formData.append('file', audioBlob, audioFile.originalname);
let data = formData;
let headers = {
'Content-Type': 'multipart/form-data',
};
[headers].forEach(removeUndefined);
if (apiKey) {
headers['api-key'] = apiKey;
}
return [url, data, headers];
} catch (error) {
logger.error('An error occurred while preparing the Azure OpenAI API STT request: ', error);
throw error;
}
}
/**
* Convert speech to text
* @param {Object} req - The request object
* @param {Object} res - The response object
*
* @returns {Object} The response object with the text from the STT API
*
* @throws Will throw an error if an error occurs while processing the audio
*/
async function speechToText(req, res) {
const customConfig = await getCustomConfig();
if (!customConfig) {
return res.status(500).send('Custom config not found');
}
if (!req.file || !req.file.buffer) {
return res.status(400).json({ message: 'No audio file provided in the FormData' });
}
const audioBuffer = req.file.buffer;
const audioReadStream = Readable.from(audioBuffer);
audioReadStream.path = 'audio.wav';
const [provider, sttSchema] = await getProviderSchema(customConfig);
let [url, data, headers] = [];
switch (provider) {
case STTProviders.OPENAI:
[url, data, headers] = openAIProvider(sttSchema, audioReadStream);
break;
case STTProviders.AZURE_OPENAI:
[url, data, headers] = azureOpenAIProvider(sttSchema, audioBuffer, req.file);
break;
default:
throw new Error('Invalid provider');
}
if (!Readable.from) {
const audioBlob = new Blob([audioBuffer], { type: req.file.mimetype });
delete data['file'];
data['file'] = audioBlob;
}
try {
const response = await axios.post(url, data, { headers: headers });
const text = await handleResponse(response);
res.json({ text });
} catch (error) {
logger.error('An error occurred while processing the audio:', error);
res.sendStatus(500);
}
}
module.exports = speechToText;

View File

@@ -1,6 +1,5 @@
const WebSocket = require('ws');
const { CacheKeys } = require('librechat-data-provider');
const { getLogStores } = require('~/cache');
const { Message } = require('~/models/Message');
/**
* @param {string[]} voiceIds - Array of voice IDs
@@ -105,8 +104,6 @@ function createChunkProcessor(messageId) {
throw new Error('Message ID is required');
}
const messageCache = getLogStores(CacheKeys.MESSAGES);
/**
* @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
*/
@@ -119,17 +116,14 @@ function createChunkProcessor(messageId) {
return `No change in message after ${MAX_NO_CHANGE_COUNT} attempts`;
}
/** @type { string | { text: string; complete: boolean } } */
const message = await messageCache.get(messageId);
const message = await Message.findOne({ messageId }, 'text unfinished').lean();
if (!message) {
if (!message || !message.text) {
notFoundCount++;
return [];
}
const text = typeof message === 'string' ? message : message.text;
const complete = typeof message === 'string' ? false : message.complete;
const { text, unfinished } = message;
if (text === processedText) {
noChangeCount++;
}
@@ -137,7 +131,7 @@ function createChunkProcessor(messageId) {
const remainingText = text.slice(processedText.length);
const chunks = [];
if (!complete && remainingText.length >= 20) {
if (unfinished && remainingText.length >= 20) {
const separatorIndex = findLastSeparatorIndex(remainingText);
if (separatorIndex !== -1) {
const chunkText = remainingText.slice(0, separatorIndex + 1);
@@ -147,7 +141,7 @@ function createChunkProcessor(messageId) {
chunks.push({ text: remainingText, isFinished: false });
processedText = text;
}
} else if (complete && remainingText.trim().length > 0) {
} else if (!unfinished && remainingText.trim().length > 0) {
chunks.push({ text: remainingText.trim(), isFinished: true });
processedText = text;
}

View File

@@ -1,145 +1,89 @@
const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
const { Message } = require('~/models/Message');
jest.mock('keyv');
const globalCache = {};
jest.mock('~/cache/getLogStores', () => {
return jest.fn().mockImplementation(() => {
const EventEmitter = require('events');
const { CacheKeys } = require('librechat-data-provider');
class KeyvMongo extends EventEmitter {
constructor(url = 'mongodb://127.0.0.1:27017', options) {
super();
this.ttlSupport = false;
url = url ?? {};
if (typeof url === 'string') {
url = { url };
}
if (url.uri) {
url = { url: url.uri, ...url };
}
this.opts = {
url,
collection: 'keyv',
...url,
...options,
};
}
get = async (key) => {
return new Promise((resolve) => {
resolve(globalCache[key] || null);
});
};
set = async (key, value) => {
return new Promise((resolve) => {
globalCache[key] = value;
resolve(true);
});
};
}
return new KeyvMongo('', {
namespace: CacheKeys.MESSAGES,
ttl: 0,
});
});
});
jest.mock('~/models/Message', () => ({
Message: {
findOne: jest.fn().mockReturnValue({
lean: jest.fn(),
}),
},
}));
describe('processChunks', () => {
let processChunks;
let mockMessageCache;
beforeEach(() => {
jest.resetAllMocks();
mockMessageCache = {
get: jest.fn(),
};
require('~/cache/getLogStores').mockReturnValue(mockMessageCache);
processChunks = createChunkProcessor('message-id');
Message.findOne.mockClear();
Message.findOne().lean.mockClear();
});
it('should return an empty array when the message is not found', async () => {
mockMessageCache.get.mockResolvedValueOnce(null);
Message.findOne().lean.mockResolvedValueOnce(null);
const result = await processChunks();
expect(result).toEqual([]);
expect(mockMessageCache.get).toHaveBeenCalledWith('message-id');
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});
it('should return an error message after MAX_NOT_FOUND_COUNT attempts', async () => {
mockMessageCache.get.mockResolvedValue(null);
it('should return an empty array when the message does not have a text property', async () => {
Message.findOne().lean.mockResolvedValueOnce({ unfinished: true });
for (let i = 0; i < 6; i++) {
await processChunks();
}
const result = await processChunks();
expect(result).toBe('Message not found after 6 attempts');
expect(result).toEqual([]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});
it('should return chunks for an incomplete message with separators', async () => {
it('should return chunks for an unfinished message with separators', async () => {
const messageText = 'This is a long message. It should be split into chunks. Lol hi mom';
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
const result = await processChunks();
expect(result).toEqual([
{ text: 'This is a long message. It should be split into chunks.', isFinished: false },
]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});
it('should return chunks for an incomplete message without separators', async () => {
it('should return chunks for an unfinished message without separators', async () => {
const messageText = 'This is a long message without separators hello there my friend';
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
const result = await processChunks();
expect(result).toEqual([{ text: messageText, isFinished: false }]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});
it('should return the remaining text as a chunk for a complete message', async () => {
it('should return the remaining text as a chunk for a finished message', async () => {
const messageText = 'This is a finished message.';
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
const result = await processChunks();
expect(result).toEqual([{ text: messageText, isFinished: true }]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalled();
});
it('should return an empty array for a complete message with no remaining text', async () => {
it('should return an empty array for a finished message with no remaining text', async () => {
const messageText = 'This is a finished message.';
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
await processChunks();
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
const result = await processChunks();
expect(result).toEqual([]);
});
it('should return an error message after MAX_NO_CHANGE_COUNT attempts with no change', async () => {
const messageText = 'This is a message that does not change.';
mockMessageCache.get.mockResolvedValue({ text: messageText, complete: false });
for (let i = 0; i < 11; i++) {
await processChunks();
}
const result = await processChunks();
expect(result).toBe('No change in message after 10 attempts');
});
it('should handle string messages as incomplete', async () => {
const messageText = 'This is a message as a string.';
mockMessageCache.get.mockResolvedValueOnce(messageText);
const result = await processChunks();
expect(result).toEqual([{ text: messageText, isFinished: false }]);
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
expect(Message.findOne().lean).toHaveBeenCalledTimes(2);
});
});

View File

@@ -0,0 +1,473 @@
const axios = require('axios');
const { extractEnvVariable, TTSProviders } = require('librechat-data-provider');
const { logger } = require('~/config');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { genAzureEndpoint } = require('~/utils');
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
/**
* getProvider function
* This function takes the ttsSchema object and returns the name of the provider
* If more than one provider is set or no provider is set, it throws an error
*
* @param {Object} ttsSchema - The TTS schema containing the provider configuration
* @returns {string} The name of the provider
* @throws {Error} Throws an error if multiple providers are set or no provider is set
*/
function getProvider(ttsSchema) {
if (!ttsSchema) {
throw new Error(`No TTS schema is set. Did you configure TTS in the custom config (librechat.yaml)?
https://www.librechat.ai/docs/configuration/stt_tts#tts`);
}
const providers = Object.entries(ttsSchema).filter(([, value]) => Object.keys(value).length > 0);
if (providers.length > 1) {
throw new Error('Multiple providers are set. Please set only one provider.');
} else if (providers.length === 0) {
throw new Error('No provider is set. Please set a provider.');
} else {
return providers[0][0];
}
}
/**
* removeUndefined function
* This function takes an object and removes all keys with undefined values
* It also removes keys with empty objects as values
*
* @param {Object} obj - The object to be cleaned
* @returns {void} This function does not return a value. It modifies the input object directly
*/
function removeUndefined(obj) {
Object.keys(obj).forEach((key) => {
if (obj[key] && typeof obj[key] === 'object') {
removeUndefined(obj[key]);
if (Object.keys(obj[key]).length === 0) {
delete obj[key];
}
} else if (obj[key] === undefined) {
delete obj[key];
}
});
}
/**
* This function prepares the necessary data and headers for making a request to the OpenAI TTS
* It uses the provided TTS schema, input text, and voice to create the request
*
* @param {TCustomConfig['tts']['openai']} ttsSchema - The TTS schema containing the OpenAI configuration
* @param {string} input - The text to be converted to speech
* @param {string} voice - The voice to be used for the speech
*
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
* If an error occurs, it throws an error with a message indicating that the selected voice is not available
*/
function openAIProvider(ttsSchema, input, voice) {
const url = ttsSchema?.url || 'https://api.openai.com/v1/audio/speech';
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
throw new Error(`Voice ${voice} is not available.`);
}
let data = {
input,
model: ttsSchema?.model,
voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
backend: ttsSchema?.backend,
};
let headers = {
'Content-Type': 'application/json',
Authorization: 'Bearer ' + extractEnvVariable(ttsSchema?.apiKey),
};
[data, headers].forEach(removeUndefined);
return [url, data, headers];
}
/**
* Generates the necessary parameters for making a request to Azure's OpenAI Text-to-Speech API.
*
* @param {TCustomConfig['tts']['azureOpenAI']} ttsSchema - The TTS schema containing the AzureOpenAI configuration
* @param {string} input - The text to be converted to speech
* @param {string} voice - The voice to be used for the speech
*
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
* If an error occurs, it throws an error with a message indicating that the selected voice is not available
*/
function azureOpenAIProvider(ttsSchema, input, voice) {
const instanceName = ttsSchema?.instanceName;
const deploymentName = ttsSchema?.deploymentName;
const apiVersion = ttsSchema?.apiVersion;
const url =
genAzureEndpoint({
azureOpenAIApiInstanceName: instanceName,
azureOpenAIApiDeploymentName: deploymentName,
}) +
'/audio/speech?api-version=' +
apiVersion;
const apiKey = ttsSchema.apiKey ? extractEnvVariable(ttsSchema.apiKey) : '';
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
throw new Error(`Voice ${voice} is not available.`);
}
let data = {
model: ttsSchema?.model,
input,
voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
};
let headers = {
'Content-Type': 'application/json',
};
[data, headers].forEach(removeUndefined);
if (apiKey) {
headers['api-key'] = apiKey;
}
return [url, data, headers];
}
/**
* elevenLabsProvider function
* This function prepares the necessary data and headers for making a request to the Eleven Labs TTS
* It uses the provided TTS schema, input text, and voice to create the request
*
* @param {TCustomConfig['tts']['elevenLabs']} ttsSchema - The TTS schema containing the Eleven Labs configuration
* @param {string} input - The text to be converted to speech
* @param {string} voice - The voice to be used for the speech
* @param {boolean} stream - Whether to stream the audio or not
*
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
* @throws {Error} Throws an error if the selected voice is not available
*/
function elevenLabsProvider(ttsSchema, input, voice, stream) {
let url =
ttsSchema?.url ||
`https://api.elevenlabs.io/v1/text-to-speech/{voice_id}${stream ? '/stream' : ''}`;
if (!ttsSchema?.voices.includes(voice) && !ttsSchema?.voices.includes('ALL')) {
throw new Error(`Voice ${voice} is not available.`);
}
url = url.replace('{voice_id}', voice);
let data = {
model_id: ttsSchema?.model,
text: input,
// voice_id: voice,
voice_settings: {
similarity_boost: ttsSchema?.voice_settings?.similarity_boost,
stability: ttsSchema?.voice_settings?.stability,
style: ttsSchema?.voice_settings?.style,
use_speaker_boost: ttsSchema?.voice_settings?.use_speaker_boost || undefined,
},
pronunciation_dictionary_locators: ttsSchema?.pronunciation_dictionary_locators,
};
let headers = {
'Content-Type': 'application/json',
'xi-api-key': extractEnvVariable(ttsSchema?.apiKey),
Accept: 'audio/mpeg',
};
[data, headers].forEach(removeUndefined);
return [url, data, headers];
}
/**
* localAIProvider function
* This function prepares the necessary data and headers for making a request to the LocalAI TTS
* It uses the provided TTS schema, input text, and voice to create the request
*
* @param {TCustomConfig['tts']['localai']} ttsSchema - The TTS schema containing the LocalAI configuration
* @param {string} input - The text to be converted to speech
* @param {string} voice - The voice to be used for the speech
*
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
* @throws {Error} Throws an error if the selected voice is not available
*/
function localAIProvider(ttsSchema, input, voice) {
let url = ttsSchema?.url;
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
throw new Error(`Voice ${voice} is not available.`);
}
let data = {
input,
model: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
backend: ttsSchema?.backend,
};
let headers = {
'Content-Type': 'application/json',
Authorization: 'Bearer ' + extractEnvVariable(ttsSchema?.apiKey),
};
[data, headers].forEach(removeUndefined);
if (extractEnvVariable(ttsSchema.apiKey) === '') {
delete headers.Authorization;
}
return [url, data, headers];
}
/**
*
* Returns provider and its schema for use with TTS requests
* @param {TCustomConfig} customConfig
* @param {string} _voice
* @returns {Promise<[string, TProviderSchema]>}
*/
async function getProviderSchema(customConfig) {
const provider = getProvider(customConfig.speech.tts);
return [provider, customConfig.speech.tts[provider]];
}
/**
*
* Returns a tuple of the TTS schema as well as the voice for the TTS request
* @param {TProviderSchema} providerSchema
* @param {string} requestVoice
* @returns {Promise<string>}
*/
async function getVoice(providerSchema, requestVoice) {
const voices = providerSchema.voices.filter((voice) => voice && voice.toUpperCase() !== 'ALL');
let voice = requestVoice;
if (!voice || !voices.includes(voice) || (voice.toUpperCase() === 'ALL' && voices.length > 1)) {
voice = getRandomVoiceId(voices);
}
return voice;
}
/**
*
* @param {string} provider
* @param {TProviderSchema} ttsSchema
* @param {object} params
* @param {string} params.voice
* @param {string} params.input
* @param {boolean} [params.stream]
* @returns {Promise<ArrayBuffer>}
*/
async function ttsRequest(provider, ttsSchema, { input, voice, stream = true } = { stream: true }) {
let [url, data, headers] = [];
switch (provider) {
case TTSProviders.OPENAI:
[url, data, headers] = openAIProvider(ttsSchema, input, voice);
break;
case TTSProviders.AZURE_OPENAI:
[url, data, headers] = azureOpenAIProvider(ttsSchema, input, voice);
break;
case TTSProviders.ELEVENLABS:
[url, data, headers] = elevenLabsProvider(ttsSchema, input, voice, stream);
break;
case TTSProviders.LOCALAI:
[url, data, headers] = localAIProvider(ttsSchema, input, voice);
break;
default:
throw new Error('Invalid provider');
}
if (stream) {
return await axios.post(url, data, { headers, responseType: 'stream' });
}
return await axios.post(url, data, { headers, responseType: 'arraybuffer' });
}
/**
* Handles a text-to-speech request. Extracts input and voice from the request, retrieves the TTS configuration,
* and sends a request to the appropriate provider. The resulting audio data is sent in the response
*
* @param {Object} req - The request object, which should contain the input text and voice in its body
* @param {Object} res - The response object, used to send the audio data or an error message
*
* @returns {Promise<void>} This function does not return a value. It sends the audio data or an error message in the response
*
* @throws {Error} Throws an error if the provider is invalid
*/
async function textToSpeech(req, res) {
const { input } = req.body;
if (!input) {
return res.status(400).send('Missing text in request body');
}
const customConfig = await getCustomConfig();
if (!customConfig) {
res.status(500).send('Custom config not found');
}
try {
res.setHeader('Content-Type', 'audio/mpeg');
const [provider, ttsSchema] = await getProviderSchema(customConfig);
const voice = await getVoice(ttsSchema, req.body.voice);
if (input.length < 4096) {
const response = await ttsRequest(provider, ttsSchema, { input, voice });
response.data.pipe(res);
return;
}
const textChunks = splitTextIntoChunks(input, 1000);
for (const chunk of textChunks) {
try {
const response = await ttsRequest(provider, ttsSchema, {
voice,
input: chunk.text,
stream: true,
});
logger.debug(`[textToSpeech] user: ${req?.user?.id} | writing audio stream`);
await new Promise((resolve) => {
response.data.pipe(res, { end: chunk.isFinished });
response.data.on('end', () => {
resolve();
});
});
if (chunk.isFinished) {
break;
}
} catch (innerError) {
logger.error('Error processing manual update:', chunk, innerError);
if (!res.headersSent) {
res.status(500).end();
}
return;
}
}
if (!res.headersSent) {
res.end();
}
} catch (error) {
logger.error(
'Error creating the audio stream. Suggestion: check your provider quota. Error:',
error,
);
res.status(500).send('An error occurred');
}
}
async function streamAudio(req, res) {
res.setHeader('Content-Type', 'audio/mpeg');
const customConfig = await getCustomConfig();
if (!customConfig) {
return res.status(500).send('Custom config not found');
}
const [provider, ttsSchema] = await getProviderSchema(customConfig);
const voice = await getVoice(ttsSchema, req.body.voice);
try {
let shouldContinue = true;
req.on('close', () => {
logger.warn('[streamAudio] Audio Stream Request closed by client');
shouldContinue = false;
});
const processChunks = createChunkProcessor(req.body.messageId);
while (shouldContinue) {
// example updates
// const updates = [
// { text: 'This is a test.', isFinished: false },
// { text: 'This is only a test.', isFinished: false },
// { text: 'Your voice is like a combination of Fergie and Jesus!', isFinished: true },
// ];
const updates = await processChunks();
if (typeof updates === 'string') {
logger.error(`Error processing audio stream updates: ${JSON.stringify(updates)}`);
res.status(500).end();
return;
}
if (updates.length === 0) {
await new Promise((resolve) => setTimeout(resolve, 1250));
continue;
}
for (const update of updates) {
try {
const response = await ttsRequest(provider, ttsSchema, {
voice,
input: update.text,
stream: true,
});
if (!shouldContinue) {
break;
}
logger.debug(`[streamAudio] user: ${req?.user?.id} | writing audio stream`);
await new Promise((resolve) => {
response.data.pipe(res, { end: update.isFinished });
response.data.on('end', () => {
resolve();
});
});
if (update.isFinished) {
shouldContinue = false;
break;
}
} catch (innerError) {
logger.error('Error processing update:', update, innerError);
if (!res.headersSent) {
res.status(500).end();
}
return;
}
}
if (!shouldContinue) {
break;
}
}
if (!res.headersSent) {
res.end();
}
} catch (error) {
logger.error('Failed to fetch audio:', error);
if (!res.headersSent) {
res.status(500).end();
}
}
}
module.exports = {
textToSpeech,
getProvider,
streamAudio,
};

View File

@@ -29,7 +29,7 @@ const getUserPluginAuthValue = async (userId, authField) => {
throw new Error(`No plugin auth ${authField} found for user ${userId}`);
}
const decryptedValue = await decrypt(pluginAuth.value);
const decryptedValue = decrypt(pluginAuth.value);
return decryptedValue;
} catch (err) {
logger.error('[getUserPluginAuthValue]', err);
@@ -64,7 +64,7 @@ const getUserPluginAuthValue = async (userId, authField) => {
const updateUserPluginAuth = async (userId, authField, pluginKey, value) => {
try {
const encryptedValue = await encrypt(value);
const encryptedValue = encrypt(value);
const pluginAuth = await PluginAuth.findOne({ userId, authField }).lean();
if (pluginAuth) {
const pluginAuth = await PluginAuth.updateOne(

View File

@@ -1,19 +1,17 @@
const throttle = require('lodash/throttle');
const {
Time,
CacheKeys,
StepTypes,
ContentTypes,
ToolCallTypes,
// StepStatus,
MessageContentTypes,
AssistantStreamEvents,
Constants,
} = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { processRequiredActions } = require('~/server/services/ToolService');
const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
const { saveMessage, updateMessageText } = require('~/models/Message');
const { createOnProgress, sendMessage } = require('~/server/utils');
const { processMessages } = require('~/server/services/Threads');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
/**
@@ -70,8 +68,8 @@ class StreamRunManager {
this.attachedFileIds = fields.attachedFileIds;
/** @type {undefined | Promise<ChatCompletion>} */
this.visionPromise = fields.visionPromise;
/** @type {number} */
this.streamRate = fields.streamRate ?? Constants.DEFAULT_STREAM_RATE;
/** @type {boolean} */
this.savedInitialMessage = false;
/**
* @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
@@ -141,11 +139,11 @@ class StreamRunManager {
return this.intermediateText;
}
/** Returns the current, intermediate message
* @returns {TMessage}
/** Saves the initial intermediate message
* @returns {Promise<void>}
*/
getIntermediateMessage() {
return {
async saveInitialMessage() {
return saveMessage({
conversationId: this.finalMessage.conversationId,
messageId: this.finalMessage.messageId,
parentMessageId: this.parentMessageId,
@@ -157,7 +155,7 @@ class StreamRunManager {
sender: 'Assistant',
unfinished: true,
error: false,
};
});
}
/* <------------------ Main Event Handlers ------------------> */
@@ -349,8 +347,6 @@ class StreamRunManager {
type: ContentTypes.TOOL_CALL,
index,
});
await sleep(this.streamRate);
}
};
@@ -448,7 +444,6 @@ class StreamRunManager {
if (content && content.type === MessageContentTypes.TEXT) {
this.intermediateText += content.text.value;
onProgress(content.text.value);
await sleep(this.streamRate);
}
}
@@ -594,14 +589,21 @@ class StreamRunManager {
const index = this.getStepIndex(stepKey);
this.orderedRunSteps.set(index, message_creation);
const messageCache = getLogStores(CacheKeys.MESSAGES);
// Create the Factory Function to stream the message
const { onProgress: progressCallback } = createOnProgress({
onProgress: throttle(
() => {
messageCache.set(this.finalMessage.messageId, this.getText(), Time.FIVE_MINUTES);
if (!this.savedInitialMessage) {
this.saveInitialMessage();
this.savedInitialMessage = true;
} else {
updateMessageText({
messageId: this.finalMessage.messageId,
text: this.getText(),
});
}
},
3000,
2000,
{ trailing: false },
),
});

View File

@@ -11,6 +11,7 @@ const { recordMessage, getMessages } = require('~/models/Message');
const { saveConvo } = require('~/models/Conversation');
const spendTokens = require('~/models/spendTokens');
const { countTokens } = require('~/server/utils');
const { logger } = require('~/config');
/**
* Initializes a new thread or adds messages to an existing thread.
@@ -40,7 +41,6 @@ async function initThread({ openai, body, thread_id: _thread_id }) {
/**
* Saves a user message to the DB in the Assistants endpoint format.
*
* @param {Object} req - The request object.
* @param {Object} params - The parameters of the user message
* @param {string} params.user - The user's ID.
* @param {string} params.text - The user's prompt.
@@ -59,7 +59,7 @@ async function initThread({ openai, body, thread_id: _thread_id }) {
* @param {string[]} [params.file_ids] - Optional. List of File IDs attached to the userMessage.
* @return {Promise<Run>} A promise that resolves to the created run object.
*/
async function saveUserMessage(req, params) {
async function saveUserMessage(params) {
const tokenCount = await countTokens(params.text);
// todo: do this on the frontend
@@ -110,16 +110,14 @@ async function saveUserMessage(req, params) {
}
const message = await recordMessage(userMessage);
await saveConvo(req, convo, {
context: 'api/server/services/Threads/manage.js #saveUserMessage',
});
await saveConvo(params.user, convo);
return message;
}
/**
* Saves an Assistant message to the DB in the Assistants endpoint format.
*
* @param {Object} req - The request object.
* @param {Object} params - The parameters of the Assistant message
* @param {string} params.user - The user's ID.
* @param {string} params.messageId - The message Id.
@@ -136,7 +134,7 @@ async function saveUserMessage(req, params) {
* @param {string} [params.promptPrefix] - Optional: from preset for `additional_instructions` field.
* @return {Promise<Run>} A promise that resolves to the created run object.
*/
async function saveAssistantMessage(req, params) {
async function saveAssistantMessage(params) {
// const tokenCount = // TODO: need to count each content part
const message = await recordMessage({
@@ -156,18 +154,14 @@ async function saveAssistantMessage(req, params) {
// tokenCount,
});
await saveConvo(
req,
{
endpoint: params.endpoint,
conversationId: params.conversationId,
promptPrefix: params.promptPrefix,
instructions: params.instructions,
assistant_id: params.assistant_id,
model: params.model,
},
{ context: 'api/server/services/Threads/manage.js #saveAssistantMessage' },
);
await saveConvo(params.user, {
endpoint: params.endpoint,
conversationId: params.conversationId,
promptPrefix: params.promptPrefix,
instructions: params.instructions,
assistant_id: params.assistant_id,
model: params.model,
});
return message;
}
@@ -344,14 +338,10 @@ async function syncMessages({
await Promise.all(modifyPromises);
await Promise.all(recordPromises);
await saveConvo(
openai.req,
{
conversationId,
file_ids: attached_file_ids,
},
{ context: 'api/server/services/Threads/manage.js #syncMessages' },
);
await saveConvo(openai.req.user.id, {
conversationId,
file_ids: attached_file_ids,
});
return result;
}
@@ -515,34 +505,80 @@ const recordUsage = async ({
);
};
const uniqueCitationStart = '^====||===';
const uniqueCitationEnd = '==|||||^';
/** Helper function to escape special characters in regex
* @param {string} string - The string to escape.
* @returns {string} The escaped string.
/**
* Creates a replaceAnnotation function with internal state for tracking the index offset.
*
* @returns {function} The replaceAnnotation function with closure for index offset.
*/
function escapeRegExp(string) {
return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
function createReplaceAnnotation() {
let indexOffset = 0;
/**
* Safely replaces the annotated text within the specified range denoted by start_index and end_index,
* after verifying that the text within that range matches the given annotation text.
* Proceeds with the replacement even if a mismatch is found, but logs a warning.
*
* @param {object} params The original text content.
* @param {string} params.currentText The current text content, with/without replacements.
* @param {number} params.start_index The starting index where replacement should begin.
* @param {number} params.end_index The ending index where replacement should end.
* @param {string} params.expectedText The text expected to be found in the specified range.
* @param {string} params.replacementText The text to insert in place of the existing content.
* @returns {string} The text with the replacement applied, regardless of text match.
*/
function replaceAnnotation({
currentText,
start_index,
end_index,
expectedText,
replacementText,
}) {
const adjustedStartIndex = start_index + indexOffset;
const adjustedEndIndex = end_index + indexOffset;
if (
adjustedStartIndex < 0 ||
adjustedEndIndex > currentText.length ||
adjustedStartIndex > adjustedEndIndex
) {
logger.warn(`Invalid range specified for annotation replacement.
Attempting replacement with \`replace\` method instead...
length: ${currentText.length}
start_index: ${adjustedStartIndex}
end_index: ${adjustedEndIndex}`);
return currentText.replace(expectedText, replacementText);
}
if (currentText.substring(adjustedStartIndex, adjustedEndIndex) !== expectedText) {
return currentText.replace(expectedText, replacementText);
}
indexOffset += replacementText.length - (adjustedEndIndex - adjustedStartIndex);
return (
currentText.slice(0, adjustedStartIndex) +
replacementText +
currentText.slice(adjustedEndIndex)
);
}
return replaceAnnotation;
}
/**
* Sorts, processes, and flattens messages to a single string.
*
* @param {object} params - The parameters for processing messages.
* @param {object} params - The OpenAI client instance.
* @param {OpenAIClient} params.openai - The OpenAI client instance.
* @param {RunClient} params.client - The LibreChat client that manages the run: either refers to `OpenAI` or `StreamRunManager`.
* @param {ThreadMessage[]} params.messages - An array of messages.
* @returns {Promise<{messages: ThreadMessage[], text: string, edited: boolean}>} The sorted messages, the flattened text, and whether it was edited.
* @returns {Promise<{messages: ThreadMessage[], text: string}>} The sorted messages and the flattened text.
*/
async function processMessages({ openai, client, messages = [] }) {
const sorted = messages.sort((a, b) => a.created_at - b.created_at);
let text = '';
let edited = false;
const sources = new Map();
const fileRetrievalPromises = [];
const sources = [];
for (const message of sorted) {
message.files = [];
for (const content of message.content) {
@@ -551,21 +587,15 @@ async function processMessages({ openai, client, messages = [] }) {
const currentFileId = contentType?.file_id;
if (type === ContentTypes.IMAGE_FILE && !client.processedFileIds.has(currentFileId)) {
fileRetrievalPromises.push(
retrieveAndProcessFile({
openai,
client,
file_id: currentFileId,
basename: `${currentFileId}.png`,
})
.then((file) => {
client.processedFileIds.add(currentFileId);
message.files.push(file);
})
.catch((error) => {
console.error(`Failed to retrieve file: ${error.message}`);
}),
);
const file = await retrieveAndProcessFile({
openai,
client,
file_id: currentFileId,
basename: `${currentFileId}.png`,
});
client.processedFileIds.add(currentFileId);
message.files.push(file);
continue;
}
@@ -574,110 +604,78 @@ async function processMessages({ openai, client, messages = [] }) {
/** @type {{ annotations: Annotation[] }} */
const { annotations } = contentType ?? {};
// Process annotations if they exist
if (!annotations?.length) {
text += currentText;
text += currentText + ' ';
continue;
}
const replacements = [];
const annotationPromises = annotations.map(async (annotation) => {
const originalText = currentText;
text += originalText;
const replaceAnnotation = createReplaceAnnotation();
logger.debug('[processMessages] Processing annotations:', annotations);
for (const annotation of annotations) {
let file;
const type = annotation.type;
const annotationType = annotation[type];
const file_id = annotationType?.file_id;
const alreadyProcessed = client.processedFileIds.has(file_id);
let file;
let replacementText = '';
const replaceCurrentAnnotation = (replacementText = '') => {
const { start_index, end_index, text: expectedText } = annotation;
currentText = replaceAnnotation({
originalText,
currentText,
start_index,
end_index,
expectedText,
replacementText,
});
edited = true;
};
try {
if (alreadyProcessed) {
file = await retrieveAndProcessFile({ openai, client, file_id, unknownType: true });
} else if (type === AnnotationTypes.FILE_PATH) {
const basename = path.basename(annotation.text);
file = await retrieveAndProcessFile({
openai,
client,
file_id,
basename,
});
replacementText = file.filepath;
} else if (type === AnnotationTypes.FILE_CITATION && file_id) {
file = await retrieveAndProcessFile({
openai,
client,
file_id,
unknownType: true,
});
if (file && file.filename) {
if (!sources.has(file.filename)) {
sources.set(file.filename, sources.size + 1);
}
replacementText = `${uniqueCitationStart}${sources.get(
file.filename,
)}${uniqueCitationEnd}`;
}
}
if (file && replacementText) {
replacements.push({
start: annotation.start_index,
end: annotation.end_index,
text: replacementText,
});
edited = true;
if (!alreadyProcessed) {
client.processedFileIds.add(file_id);
message.files.push(file);
}
}
} catch (error) {
console.error(`Failed to process annotation: ${error.message}`);
if (alreadyProcessed) {
const { file_id } = annotationType || {};
file = await retrieveAndProcessFile({ openai, client, file_id, unknownType: true });
} else if (type === AnnotationTypes.FILE_PATH) {
const basename = path.basename(annotation.text);
file = await retrieveAndProcessFile({
openai,
client,
file_id,
basename,
});
replaceCurrentAnnotation(file.filepath);
} else if (type === AnnotationTypes.FILE_CITATION) {
file = await retrieveAndProcessFile({
openai,
client,
file_id,
unknownType: true,
});
sources.push(file.filename);
replaceCurrentAnnotation(`^${sources.length}^`);
}
});
await Promise.all(annotationPromises);
text = currentText;
// Apply replacements in reverse order
replacements.sort((a, b) => b.start - a.start);
for (const { start, end, text: replacementText } of replacements) {
currentText = currentText.slice(0, start) + replacementText + currentText.slice(end);
if (!file) {
continue;
}
client.processedFileIds.add(file_id);
message.files.push(file);
}
text += currentText;
}
}
await Promise.all(fileRetrievalPromises);
// Handle adjacent identical citations with the unique format
const adjacentCitationRegex = new RegExp(
`${escapeRegExp(uniqueCitationStart)}(\\d+)${escapeRegExp(
uniqueCitationEnd,
)}(\\s*)${escapeRegExp(uniqueCitationStart)}(\\d+)${escapeRegExp(uniqueCitationEnd)}`,
'g',
);
text = text.replace(adjacentCitationRegex, (match, num1, space, num2) => {
return num1 === num2
? `${uniqueCitationStart}${num1}${uniqueCitationEnd}`
: `${uniqueCitationStart}${num1}${uniqueCitationEnd}${space}${uniqueCitationStart}${num2}${uniqueCitationEnd}`;
});
// Remove any remaining adjacent identical citations
const remainingAdjacentRegex = new RegExp(
`(${escapeRegExp(uniqueCitationStart)}(\\d+)${escapeRegExp(uniqueCitationEnd)})\\s*\\1+`,
'g',
);
text = text.replace(remainingAdjacentRegex, '$1');
// Replace the unique citation format with the final format
text = text.replace(new RegExp(escapeRegExp(uniqueCitationStart), 'g'), '^');
text = text.replace(new RegExp(escapeRegExp(uniqueCitationEnd), 'g'), '^');
if (sources.size) {
if (sources.length) {
text += '\n\n';
Array.from(sources.entries()).forEach(([source, index], arrayIndex) => {
text += `^${index}.^ ${source}${arrayIndex === sources.size - 1 ? '' : '\n'}`;
});
for (let i = 0; i < sources.length; i++) {
text += `^${i + 1}.^ ${sources[i]}${i === sources.length - 1 ? '' : '\n'}`;
}
}
return { messages: sorted, text, edited };

View File

@@ -1,983 +0,0 @@
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { processMessages } = require('./manage');
jest.mock('~/server/services/Files/process', () => ({
retrieveAndProcessFile: jest.fn(),
}));
describe('processMessages', () => {
let openai, client;
beforeEach(() => {
openai = {};
client = {
processedFileIds: new Set(),
};
jest.clearAllMocks();
retrieveAndProcessFile.mockReset();
});
test('handles normal case with single source', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value: 'This is a test ^1^ and another^1^',
annotations: [
{
type: 'file_citation',
start_index: 15,
end_index: 18,
file_citation: { file_id: 'file1' },
},
{
type: 'file_citation',
start_index: 30,
end_index: 33,
file_citation: { file_id: 'file1' },
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile.mockResolvedValue({ filename: 'test.txt' });
const result = await processMessages({ openai, client, messages });
expect(result.text).toBe('This is a test ^1^ and another^1^\n\n^1.^ test.txt');
expect(result.edited).toBe(true);
});
test('handles multiple different sources', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value: 'This is a test ^1^ and another^2^',
annotations: [
{
type: 'file_citation',
start_index: 15,
end_index: 18,
file_citation: { file_id: 'file1' },
},
{
type: 'file_citation',
start_index: 30,
end_index: 33,
file_citation: { file_id: 'file2' },
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile
.mockResolvedValueOnce({ filename: 'test1.txt' })
.mockResolvedValueOnce({ filename: 'test2.txt' });
const result = await processMessages({ openai, client, messages });
expect(result.text).toBe('This is a test ^1^ and another^2^\n\n^1.^ test1.txt\n^2.^ test2.txt');
expect(result.edited).toBe(true);
});
test('handles file retrieval failure', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value: 'This is a test ^1^',
annotations: [
{
type: 'file_citation',
start_index: 15,
end_index: 18,
file_citation: { file_id: 'file1' },
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile.mockRejectedValue(new Error('File not found'));
const result = await processMessages({ openai, client, messages });
expect(result.text).toBe('This is a test ^1^');
expect(result.edited).toBe(false);
});
test('handles citations without file ids', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value: 'This is a test ^1^',
annotations: [{ type: 'file_citation', start_index: 15, end_index: 18 }],
},
},
],
created_at: 1,
},
];
const result = await processMessages({ openai, client, messages });
expect(result.text).toBe('This is a test ^1^');
expect(result.edited).toBe(false);
});
test('handles mixed valid and invalid citations', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value: 'This is a test ^1^ and ^2^ and ^3^',
annotations: [
{
type: 'file_citation',
start_index: 15,
end_index: 18,
file_citation: { file_id: 'file1' },
},
{ type: 'file_citation', start_index: 23, end_index: 26 },
{
type: 'file_citation',
start_index: 31,
end_index: 34,
file_citation: { file_id: 'file3' },
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile
.mockResolvedValueOnce({ filename: 'test1.txt' })
.mockResolvedValueOnce({ filename: 'test3.txt' });
const result = await processMessages({ openai, client, messages });
expect(result.text).toBe(
'This is a test ^1^ and ^2^ and ^2^\n\n^1.^ test1.txt\n^2.^ test3.txt',
);
expect(result.edited).toBe(true);
});
test('handles adjacent identical citations', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value: 'This is a test ^1^^1^ and ^1^ ^1^',
annotations: [
{
type: 'file_citation',
start_index: 15,
end_index: 18,
file_citation: { file_id: 'file1' },
},
{
type: 'file_citation',
start_index: 18,
end_index: 21,
file_citation: { file_id: 'file1' },
},
{
type: 'file_citation',
start_index: 26,
end_index: 29,
file_citation: { file_id: 'file1' },
},
{
type: 'file_citation',
start_index: 30,
end_index: 33,
file_citation: { file_id: 'file1' },
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile.mockResolvedValue({ filename: 'test.txt' });
const result = await processMessages({ openai, client, messages });
expect(result.text).toBe('This is a test ^1^ and ^1^\n\n^1.^ test.txt');
expect(result.edited).toBe(true);
});
test('handles real data with multiple adjacent citations', async () => {
const messages = [
{
id: 'msg_XXXXXXXXXXXXXXXXXXXX',
object: 'thread.message',
created_at: 1722980324,
assistant_id: 'asst_XXXXXXXXXXXXXXXXXXXX',
thread_id: 'thread_XXXXXXXXXXXXXXXXXXXX',
run_id: 'run_XXXXXXXXXXXXXXXXXXXX',
status: 'completed',
incomplete_details: null,
incomplete_at: null,
completed_at: 1722980331,
role: 'assistant',
content: [
{
type: 'text',
text: {
value:
'The text you have uploaded is from the book "Harry Potter and the Philosopher\'s Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander\'s【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher\'s Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry\'s initial experiences in the magical world and set the stage for his adventures at Hogwarts.',
annotations: [
{
type: 'file_citation',
text: '【11:2†source】',
start_index: 420,
end_index: 433,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:4†source】',
start_index: 433,
end_index: 446,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:9†source】',
start_index: 578,
end_index: 591,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:14†source】',
start_index: 591,
end_index: 605,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:12†source】',
start_index: 767,
end_index: 781,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:18†source】',
start_index: 781,
end_index: 795,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:16†source】',
start_index: 935,
end_index: 949,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:1†source】',
start_index: 1114,
end_index: 1127,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:10†source】',
start_index: 1127,
end_index: 1141,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:7†source】',
start_index: 1141,
end_index: 1154,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
],
},
},
],
attachments: [],
metadata: {},
files: [
{
object: 'file',
id: 'file-XXXXXXXXXXXXXXXXXXXX',
purpose: 'assistants',
filename: 'hp1.txt',
bytes: 439742,
created_at: 1722962139,
status: 'processed',
status_details: null,
type: 'text/plain',
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
filepath:
'https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-XXXXXXXXXXXXXXXXXXXX/hp1.txt',
usage: 1,
user: 'XXXXXXXXXXXXXXXXXXXX',
context: 'assistants',
source: 'openai',
model: 'gpt-4o',
},
],
},
];
retrieveAndProcessFile.mockResolvedValue({ filename: 'hp1.txt' });
const result = await processMessages({
openai: {},
client: { processedFileIds: new Set() },
messages,
});
const expectedText = `The text you have uploaded is from the book "Harry Potter and the Philosopher's Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:
1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry^1^.
2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander's^1^.
3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background^1^.
4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy^1^.
5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher's Stone and its connection to the dark wizard Voldemort^1^.
These points highlight Harry's initial experiences in the magical world and set the stage for his adventures at Hogwarts.
^1.^ hp1.txt`;
expect(result.text).toBe(expectedText);
expect(result.edited).toBe(true);
});
test('handles real data with multiple adjacent citations with multiple sources', async () => {
const messages = [
{
id: 'msg_XXXXXXXXXXXXXXXXXXXX',
object: 'thread.message',
created_at: 1722980324,
assistant_id: 'asst_XXXXXXXXXXXXXXXXXXXX',
thread_id: 'thread_XXXXXXXXXXXXXXXXXXXX',
run_id: 'run_XXXXXXXXXXXXXXXXXXXX',
status: 'completed',
incomplete_details: null,
incomplete_at: null,
completed_at: 1722980331,
role: 'assistant',
content: [
{
type: 'text',
text: {
value:
'The text you have uploaded is from the book "Harry Potter and the Philosopher\'s Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:\n\n1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry【11:2†source】【11:4†source】.\n\n2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander\'s【11:9†source】【11:14†source】.\n\n3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background【11:12†source】【11:18†source】.\n\n4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy【11:16†source】.\n\n5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher\'s Stone and its connection to the dark wizard Voldemort【11:1†source】【11:10†source】【11:7†source】.\n\nThese points highlight Harry\'s initial experiences in the magical world and set the stage for his adventures at Hogwarts.',
annotations: [
{
type: 'file_citation',
text: '【11:2†source】',
start_index: 420,
end_index: 433,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:4†source】',
start_index: 433,
end_index: 446,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:9†source】',
start_index: 578,
end_index: 591,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:14†source】',
start_index: 591,
end_index: 605,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:12†source】',
start_index: 767,
end_index: 781,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:18†source】',
start_index: 781,
end_index: 795,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:16†source】',
start_index: 935,
end_index: 949,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:1†source】',
start_index: 1114,
end_index: 1127,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:10†source】',
start_index: 1127,
end_index: 1141,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_citation',
text: '【11:7†source】',
start_index: 1141,
end_index: 1154,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
],
},
},
],
attachments: [],
metadata: {},
files: [
{
object: 'file',
id: 'file-XXXXXXXXXXXXXXXXXXXX',
purpose: 'assistants',
filename: 'hp1.txt',
bytes: 439742,
created_at: 1722962139,
status: 'processed',
status_details: null,
type: 'text/plain',
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
filepath:
'https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-XXXXXXXXXXXXXXXXXXXX/hp1.txt',
usage: 1,
user: 'XXXXXXXXXXXXXXXXXXXX',
context: 'assistants',
source: 'openai',
model: 'gpt-4o',
},
],
},
];
retrieveAndProcessFile.mockResolvedValue({ filename: 'hp1.txt' });
const result = await processMessages({
openai: {},
client: { processedFileIds: new Set() },
messages,
});
const expectedText = `The text you have uploaded is from the book "Harry Potter and the Philosopher's Stone" by J.K. Rowling. It follows the story of a young boy named Harry Potter who discovers that he is a wizard on his eleventh birthday. Here are some key points of the narrative:
1. **Discovery and Invitation to Hogwarts**: Harry learns that he is a wizard and receives an invitation to attend Hogwarts School of Witchcraft and Wizardry^1^.
2. **Shopping for Supplies**: Hagrid takes Harry to Diagon Alley to buy his school supplies, including his wand from Ollivander's^1^.
3. **Introduction to Hogwarts**: Harry is introduced to Hogwarts, the magical school where he will learn about magic and discover more about his own background^1^.
4. **Meeting Friends and Enemies**: At Hogwarts, Harry makes friends like Ron Weasley and Hermione Granger, and enemies like Draco Malfoy^1^.
5. **Uncovering the Mystery**: Harry, along with Ron and Hermione, uncovers the mystery of the Philosopher's Stone and its connection to the dark wizard Voldemort^1^.
These points highlight Harry's initial experiences in the magical world and set the stage for his adventures at Hogwarts.
^1.^ hp1.txt`;
expect(result.text).toBe(expectedText);
expect(result.edited).toBe(true);
});
test('handles edge case with pre-existing citation-like text', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value:
'This is a test ^1^ with pre-existing citation-like text. Here\'s a real citation【11:2†source】.',
annotations: [
{
type: 'file_citation',
text: '【11:2†source】',
start_index: 79,
end_index: 92,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile.mockResolvedValue({ filename: 'test.txt' });
const result = await processMessages({
openai: {},
client: { processedFileIds: new Set() },
messages,
});
const expectedText =
'This is a test ^1^ with pre-existing citation-like text. Here\'s a real citation^1^.\n\n^1.^ test.txt';
expect(result.text).toBe(expectedText);
expect(result.edited).toBe(true);
});
test('handles FILE_PATH annotation type', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value: 'Here is a file path: [file_path]',
annotations: [
{
type: 'file_path',
text: '[file_path]',
start_index: 21,
end_index: 32,
file_path: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile.mockResolvedValue({
filename: 'test.txt',
filepath: '/path/to/test.txt',
});
const result = await processMessages({
openai: {},
client: { processedFileIds: new Set() },
messages,
});
const expectedText = 'Here is a file path: /path/to/test.txt';
expect(result.text).toBe(expectedText);
expect(result.edited).toBe(true);
});
test('handles FILE_CITATION annotation type', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value: 'Here is a citation: [citation]',
annotations: [
{
type: 'file_citation',
text: '[citation]',
start_index: 20,
end_index: 30,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile.mockResolvedValue({ filename: 'test.txt' });
const result = await processMessages({
openai: {},
client: { processedFileIds: new Set() },
messages,
});
const expectedText = 'Here is a citation: ^1^\n\n^1.^ test.txt';
expect(result.text).toBe(expectedText);
expect(result.edited).toBe(true);
});
test('handles multiple annotation types in a single message', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value:
'File path: [file_path]. Citation: [citation1]. Another citation: [citation2].',
annotations: [
{
type: 'file_path',
text: '[file_path]',
start_index: 11,
end_index: 22,
file_path: {
file_id: 'file-XXXXXXXXXXXXXXXX1',
},
},
{
type: 'file_citation',
text: '[citation1]',
start_index: 34,
end_index: 45,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXX2',
},
},
{
type: 'file_citation',
text: '[citation2]',
start_index: 65,
end_index: 76,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXX3',
},
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile.mockResolvedValueOnce({
filename: 'file1.txt',
filepath: '/path/to/file1.txt',
});
retrieveAndProcessFile.mockResolvedValueOnce({ filename: 'file2.txt' });
retrieveAndProcessFile.mockResolvedValueOnce({ filename: 'file3.txt' });
const result = await processMessages({
openai: {},
client: { processedFileIds: new Set() },
messages,
});
const expectedText =
'File path: /path/to/file1.txt. Citation: ^1^. Another citation: ^2^.\n\n^1.^ file2.txt\n^2.^ file3.txt';
expect(result.text).toBe(expectedText);
expect(result.edited).toBe(true);
});
test('handles annotation processing failure', async () => {
const messages = [
{
content: [
{
type: 'text',
text: {
value: 'This citation will fail: [citation]',
annotations: [
{
type: 'file_citation',
text: '[citation]',
start_index: 25,
end_index: 35,
file_citation: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
],
},
},
],
created_at: 1,
},
];
retrieveAndProcessFile.mockRejectedValue(new Error('File not found'));
const result = await processMessages({
openai: {},
client: { processedFileIds: new Set() },
messages,
});
const expectedText = 'This citation will fail: [citation]';
expect(result.text).toBe(expectedText);
expect(result.edited).toBe(false);
});
test('handles multiple FILE_PATH annotations with sandbox links', async () => {
const messages = [
{
id: 'msg_XXXXXXXXXXXXXXXXXXXX',
object: 'thread.message',
created_at: 1722983745,
assistant_id: 'asst_XXXXXXXXXXXXXXXXXXXX',
thread_id: 'thread_XXXXXXXXXXXXXXXXXXXX',
run_id: 'run_XXXXXXXXXXXXXXXXXXXX',
status: 'completed',
incomplete_details: null,
incomplete_at: null,
completed_at: 1722983747,
role: 'assistant',
content: [
{
type: 'text',
text: {
value:
'I have generated three dummy CSV files for you. You can download them using the links below:\n\n1. [Download Dummy Data 1](sandbox:/mnt/data/dummy_data1.csv)\n2. [Download Dummy Data 2](sandbox:/mnt/data/dummy_data2.csv)\n3. [Download Dummy Data 3](sandbox:/mnt/data/dummy_data3.csv)',
annotations: [
{
type: 'file_path',
text: 'sandbox:/mnt/data/dummy_data1.csv',
start_index: 121,
end_index: 154,
file_path: {
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
},
},
{
type: 'file_path',
text: 'sandbox:/mnt/data/dummy_data2.csv',
start_index: 183,
end_index: 216,
file_path: {
file_id: 'file-YYYYYYYYYYYYYYYYYYYY',
},
},
{
type: 'file_path',
text: 'sandbox:/mnt/data/dummy_data3.csv',
start_index: 245,
end_index: 278,
file_path: {
file_id: 'file-ZZZZZZZZZZZZZZZZZZZZ',
},
},
],
},
},
],
attachments: [
{
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
tools: [
{
type: 'code_interpreter',
},
],
},
{
file_id: 'file-YYYYYYYYYYYYYYYYYYYY',
tools: [
{
type: 'code_interpreter',
},
],
},
{
file_id: 'file-ZZZZZZZZZZZZZZZZZZZZ',
tools: [
{
type: 'code_interpreter',
},
],
},
],
metadata: {},
files: [
{
object: 'file',
id: 'file-XXXXXXXXXXXXXXXXXXXX',
purpose: 'assistants_output',
filename: 'dummy_data1.csv',
bytes: 1925,
created_at: 1722983746,
status: 'processed',
status_details: null,
type: 'text/csv',
file_id: 'file-XXXXXXXXXXXXXXXXXXXX',
filepath:
'https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-XXXXXXXXXXXXXXXXXXXX/dummy_data1.csv',
usage: 1,
user: 'XXXXXXXXXXXXXXXXXXXX',
context: 'assistants_output',
source: 'openai',
model: 'gpt-4o-mini',
},
{
object: 'file',
id: 'file-YYYYYYYYYYYYYYYYYYYY',
purpose: 'assistants_output',
filename: 'dummy_data2.csv',
bytes: 4221,
created_at: 1722983746,
status: 'processed',
status_details: null,
type: 'text/csv',
file_id: 'file-YYYYYYYYYYYYYYYYYYYY',
filepath:
'https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-YYYYYYYYYYYYYYYYYYYY/dummy_data2.csv',
usage: 1,
user: 'XXXXXXXXXXXXXXXXXXXX',
context: 'assistants_output',
source: 'openai',
model: 'gpt-4o-mini',
},
{
object: 'file',
id: 'file-ZZZZZZZZZZZZZZZZZZZZ',
purpose: 'assistants_output',
filename: 'dummy_data3.csv',
bytes: 3534,
created_at: 1722983747,
status: 'processed',
status_details: null,
type: 'text/csv',
file_id: 'file-ZZZZZZZZZZZZZZZZZZZZ',
filepath:
'https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-ZZZZZZZZZZZZZZZZZZZZ/dummy_data3.csv',
usage: 1,
user: 'XXXXXXXXXXXXXXXXXXXX',
context: 'assistants_output',
source: 'openai',
model: 'gpt-4o-mini',
},
],
},
];
const mockClient = {
processedFileIds: new Set(),
};
// Mock the retrieveAndProcessFile function for each file
retrieveAndProcessFile.mockImplementation(({ file_id }) => {
const fileMap = {
'file-XXXXXXXXXXXXXXXXXXXX': {
filename: 'dummy_data1.csv',
filepath:
'https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-XXXXXXXXXXXXXXXXXXXX/dummy_data1.csv',
},
'file-YYYYYYYYYYYYYYYYYYYY': {
filename: 'dummy_data2.csv',
filepath:
'https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-YYYYYYYYYYYYYYYYYYYY/dummy_data2.csv',
},
'file-ZZZZZZZZZZZZZZZZZZZZ': {
filename: 'dummy_data3.csv',
filepath:
'https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-ZZZZZZZZZZZZZZZZZZZZ/dummy_data3.csv',
},
};
return Promise.resolve(fileMap[file_id]);
});
const result = await processMessages({ openai: {}, client: mockClient, messages });
const expectedText =
'I have generated three dummy CSV files for you. You can download them using the links below:\n\n1. [Download Dummy Data 1](https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-XXXXXXXXXXXXXXXXXXXX/dummy_data1.csv)\n2. [Download Dummy Data 2](https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-YYYYYYYYYYYYYYYYYYYY/dummy_data2.csv)\n3. [Download Dummy Data 3](https://api.openai.com/v1/files/XXXXXXXXXXXXXXXXXXXX/file-ZZZZZZZZZZZZZZZZZZZZ/dummy_data3.csv)';
expect(result.text).toBe(expectedText);
expect(result.edited).toBe(true);
});
});

View File

@@ -73,6 +73,10 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] })
continue;
}
if (included.size > 0 && !included.has(file)) {
continue;
}
let toolInstance = null;
try {
toolInstance = new ToolClass({ override: true });
@@ -88,14 +92,6 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] })
continue;
}
if (filter.has(toolInstance.name) && included.size === 0) {
continue;
}
if (included.size > 0 && !included.has(file) && !included.has(toolInstance.name)) {
continue;
}
const formattedTool = formatToOpenAIAssistantTool(toolInstance);
tools.push(formattedTool);
}
@@ -335,7 +331,7 @@ async function processRequiredActions(client, requiredActions) {
continue;
}
tool = await createActionTool({ action: actionSet, requestBuilder });
tool = createActionTool({ action: actionSet, requestBuilder });
isActionTool = !!tool;
ActionToolMap[currentAction.tool] = tool;
}

Some files were not shown because too many files have changed in this diff Show More