Compare commits

..

1 Commits

Author SHA1 Message Date
Danny Avila
78283e1686 🤖 : Azure Assistants V2 2024-05-21 17:01:49 -04:00
563 changed files with 8172 additions and 26169 deletions

View File

@@ -1,2 +0,0 @@
# When running devcontainers, you can specify if docker & docker-compose should be installed in your environment
INSTALL_DOCKER=false

View File

@@ -1,33 +1,5 @@
# .devcontainer/Dockerfile
FROM node:18-bullseye
ARG INSTALL_DOCKER="false"
ENV INSTALL_DOCKER=${INSTALL_DOCKER}
# Install Docker and Docker Compose only if INSTALL_DOCKER is "true"
RUN if [ "$INSTALL_DOCKER" = "true" ]; then \
apt-get update && \
apt-get install -y apt-transport-https ca-certificates curl gnupg lsb-release && \
curl -fsSL https://download.docker.com/linux/debian/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg && \
echo "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null && \
apt-get update && \
apt-get install -y docker-ce docker-ce-cli containerd.io docker-compose-plugin; \
fi
# Install sudo
RUN apt-get update && apt-get install -y sudo
# Set up non-root user
RUN useradd -m -s /bin/bash vscode
RUN if [ "$INSTALL_DOCKER" = "true" ]; then usermod -aG docker vscode; fi
# Add vscode user to sudoers
RUN echo "vscode ALL=(ALL) NOPASSWD: ALL" > /etc/sudoers.d/vscode && \
chmod 0440 /etc/sudoers.d/vscode
USER vscode
WORKDIR /workspaces/LibreChat
# Set the default command
CMD ["/bin/bash"]
RUN mkdir -p /workspaces && chown -R vscode:vscode /workspaces
WORKDIR /workspaces

View File

@@ -1,23 +1,18 @@
{
"name": "LibreChat Development",
"dockerComposeFile": "docker-compose.yml",
"service": "app",
"workspaceFolder": "/workspaces/LibreChat",
"workspaceFolder": "/workspaces",
"customizations": {
"vscode": {
"extensions": ["ms-azuretools.vscode-docker"]
"extensions": [],
"settings": {
"terminal.integrated.profiles.linux": {
"bash": null
}
}
}
},
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {
"version": "latest",
"moby": false,
"dockerDashComposeVersion": "v2"
}
},
"remoteUser": "vscode",
"postCreateCommand": "sudo chown root:docker /var/run/docker.sock && sudo chmod 660 /var/run/docker.sock && npm run reinstall && npm run pull:rag && npm run copy-ex && MEILI_MASTER_KEY=$(docker-compose -f .devcontainer/docker-compose.yml exec -T meilisearch printenv MEILI_MASTER_KEY) && sed -i \"s/^MEILI_MASTER_KEY=.*/MEILI_MASTER_KEY=$MEILI_MASTER_KEY/\" .env",
"remoteEnv": {
"INSTALL_DOCKER": "${localEnv:INSTALL_DOCKER:false}"
}
"postCreateCommand": "",
"features": { "ghcr.io/devcontainers/features/git:1": {} },
"remoteUser": "vscode"
}

View File

@@ -1,16 +1,10 @@
# .devcontainer/docker-compose.yml
version: "3.8"
services:
app:
group_add:
- docker
build:
build:
context: ..
dockerfile: .devcontainer/Dockerfile
args:
- INSTALL_DOCKER=${INSTALL_DOCKER:-false}
# restart: always
links:
- mongodb
@@ -23,7 +17,6 @@ services:
volumes:
# This is where VS Code should expect to find your project's source code and the value of "workspaceFolder" in .devcontainer/devcontainer.json
- ..:/workspaces:cached
- /var/run/docker.sock:/var/run/docker.sock
# Uncomment the next line to use Docker from inside the container. See https://aka.ms/vscode-remote/samples/docker-from-docker-compose for details.
# - /var/run/docker.sock:/var/run/docker.sock
environment:
@@ -43,12 +36,14 @@ services:
user: vscode
# Overrides default command so things don't shut down after the process ends.
command: /bin/sh -c "while sleep 1000; do :; done"
command: /bin/sh -c "while sleep 1000; do :; done"
mongodb:
container_name: chat-mongodb
expose:
- 27017
# ports:
# - 27018:27017
image: mongo
# restart: always
volumes:
@@ -60,8 +55,11 @@ services:
# restart: always
expose:
- 7700
# Uncomment this to access meilisearch from outside docker
# ports:
# - 7700:7700 # if exposing these ports, make sure your master key is not the default value
environment:
- MEILI_NO_ANALYTICS=true
- MEILI_MASTER_KEY=${MEILI_MASTER_KEY:-$(openssl rand -hex 16)}
- MEILI_MASTER_KEY=5c71cf56d672d009e36070b5bc5e47b743535ae55c818ae3b735bb6ebfb4ba63
volumes:
- ./meili_data_v1.5:/meili_data

View File

@@ -64,8 +64,6 @@ PROXY=
# ANYSCALE_API_KEY=
# APIPIE_API_KEY=
# COHERE_API_KEY=
# DATABRICKS_API_KEY=
# FIREWORKS_API_KEY=
# GROQ_API_KEY=
# HUGGINGFACE_TOKEN=
@@ -80,7 +78,7 @@ PROXY=
#============#
ANTHROPIC_API_KEY=user_provided
# ANTHROPIC_MODELS=claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
# ANTHROPIC_MODELS=claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
# ANTHROPIC_REVERSE_PROXY=
#============#
@@ -121,9 +119,7 @@ GOOGLE_KEY=user_provided
# GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
# Vertex AI
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
# GOOGLE_TITLE_MODEL=gemini-pro
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0409,gemini-1.0-pro-vision-001,gemini-pro,gemini-pro-vision,chat-bison,chat-bison-32k,codechat-bison,codechat-bison-32k,text-bison,text-bison-32k,text-unicorn,code-gecko,code-bison,code-bison-32k
# Google Gemini Safety Settings
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
@@ -261,14 +257,6 @@ MEILI_NO_ANALYTICS=true
MEILI_HOST=http://0.0.0.0:7700
MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt
#==================================================#
# Speech to Text & Text to Speech #
#==================================================#
STT_API_KEY=
TTS_API_KEY=
#===================================================#
# User System #
#===================================================#
@@ -323,9 +311,6 @@ ALLOW_EMAIL_LOGIN=true
ALLOW_REGISTRATION=true
ALLOW_SOCIAL_LOGIN=false
ALLOW_SOCIAL_REGISTRATION=false
ALLOW_PASSWORD_RESET=false
# ALLOW_ACCOUNT_DELETION=true # note: enabled by default if omitted/commented out
ALLOW_UNVERIFIED_EMAIL_LOGIN=true
SESSION_EXPIRY=1000 * 60 * 15
REFRESH_TOKEN_EXPIRY=(1000 * 60 * 60 * 24) * 7
@@ -367,17 +352,6 @@ OPENID_REQUIRED_ROLE_PARAMETER_PATH=
OPENID_BUTTON_LABEL=
OPENID_IMAGE_URL=
# LDAP
LDAP_URL=
LDAP_BIND_DN=
LDAP_BIND_CREDENTIALS=
LDAP_USER_SEARCH_BASE=
LDAP_SEARCH_FILTER=mail={{username}}
LDAP_CA_CERT_PATH=
# LDAP_ID=
# LDAP_USERNAME=
# LDAP_FULL_NAME=
#========================#
# Email Password Reset #
#========================#
@@ -404,13 +378,6 @@ FIREBASE_STORAGE_BUCKET=
FIREBASE_MESSAGING_SENDER_ID=
FIREBASE_APP_ID=
#========================#
# Shared Links #
#========================#
ALLOW_SHARED_LINKS=true
ALLOW_SHARED_LINKS_PUBLIC=true
#===================================================#
# UI #
#===================================================#
@@ -421,9 +388,6 @@ HELP_AND_FAQ_URL=https://librechat.ai
# SHOW_BIRTHDAY_ICON=true
# Google tag manager id
#ANALYTICS_GTM_ID=user provided google tag manager id
#==================================================#
# Others #
#==================================================#
@@ -436,6 +400,3 @@ HELP_AND_FAQ_URL=https://librechat.ai
# E2E_USER_EMAIL=
# E2E_USER_PASSWORD=
# RAG_PORT
RAG_PORT=8000

View File

@@ -126,18 +126,6 @@ Apply the following naming conventions to branches, labels, and other Git-relate
- **Current Stance**: At present, this backend transition is of lower priority and might not be pursued.
## 7. Module Import Conventions
- `npm` packages first,
- from shortest line (top) to longest (bottom)
- Followed by typescript types (pertains to data-provider and client workspaces)
- longest line (top) to shortest (bottom)
- types from package come first
- Lastly, local imports
- longest line (top) to shortest (bottom)
- imports with alias `~` treated the same as relative import with respect to line length
---

2
.gitignore vendored
View File

@@ -11,7 +11,6 @@ logs
pids
*.pid
*.seed
.git
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
@@ -46,7 +45,6 @@ api/node_modules/
client/node_modules/
bower_components/
*.d.ts
!vite-env.d.ts
# Floobits
.floo

View File

@@ -1,4 +1,4 @@
# v0.7.3
# v0.7.2
# Base node image
FROM node:20-alpine AS node

View File

@@ -1,4 +1,4 @@
# v0.7.3
# v0.7.2
# Build API, Client and Data Provider
FROM node:20-alpine AS base

View File

@@ -27,7 +27,7 @@
</p>
<p align="center">
<a href="https://railway.app/template/b5k2mn?referralCode=HI9hWz">
<a href="https://railway.app/template/b5k2mn?referralCode=myKrVZ">
<img src="https://railway.app/button.svg" alt="Deploy on Railway" height="30">
</a>
<a href="https://zeabur.com/templates/0X2ZY8">
@@ -58,13 +58,9 @@
- 🌎 Multilingual UI:
- English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro,
- Русский, 日本語, Svenska, 한국어, Tiếng Việt, 繁體中文, العربية, Türkçe, Nederlands, עברית
- 🎨 Customizable Dropdown & Interface: Adapts to both power users and newcomers
- 📧 Verify your email to ensure secure access
- 🗣️ Chat hands-free with Speech-to-Text and Text-to-Speech magic
- Automatically send and play Audio
- Supports OpenAI, Azure OpenAI, and Elevenlabs
- 🎨 Customizable Dropdown & Interface: Adapts to both power users and newcomers.
- 📥 Import Conversations from LibreChat, ChatGPT, Chatbot UI
- 📤 Export conversations as screenshots, markdown, text, json
- 📤 Export conversations as screenshots, markdown, text, json.
- 🔍 Search all messages/conversations
- 🔌 Plugins, including web access, image generation with DALL-E-3 and more
- 👥 Multi-User, Secure Authentication with Moderation and Token spend tools
@@ -81,7 +77,7 @@ LibreChat brings together the future of assistant AIs with the revolutionary tec
With LibreChat, you no longer need to opt for ChatGPT Plus and can instead use free or pay-per-call APIs. We welcome contributions, cloning, and forking to enhance the capabilities of this advanced chatbot platform.
[![Watch the video](https://img.youtube.com/vi/bSVHEbVPNl4/maxresdefault.jpg)](https://www.youtube.com/watch?v=bSVHEbVPNl4)
[![Watch the video](https://img.youtube.com/vi/YLVUW5UP9N0/maxresdefault.jpg)](https://www.youtube.com/watch?v=YLVUW5UP9N0)
Click on the thumbnail to open the video☝
---

View File

@@ -1,5 +1,4 @@
const Anthropic = require('@anthropic-ai/sdk');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
getResponseSender,
@@ -124,14 +123,9 @@ class AnthropicClient extends BaseClient {
getClient() {
/** @type {Anthropic.default.RequestOptions} */
const options = {
fetch: this.fetch,
apiKey: this.apiKey,
};
if (this.options.proxy) {
options.httpAgent = new HttpsProxyAgent(this.options.proxy);
}
if (this.options.reverseProxyUrl) {
options.baseURL = this.options.reverseProxyUrl;
}

View File

@@ -1,7 +1,6 @@
const crypto = require('crypto');
const fetch = require('node-fetch');
const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File');
@@ -18,15 +17,6 @@ class BaseClient {
month: 'long',
day: 'numeric',
});
this.fetch = this.fetch.bind(this);
/** @type {boolean} */
this.skipSaveConvo = false;
/** @type {boolean} */
this.skipSaveUserMessage = false;
/** @type {ClientDatabaseSavePromise} */
this.userMessagePromise;
/** @type {ClientDatabaseSavePromise} */
this.responsePromise;
}
setOptions() {
@@ -64,25 +54,6 @@ class BaseClient {
});
}
/**
* Makes an HTTP request and logs the process.
*
* @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
* @param {RequestInit} [init] - Optional init options for the request.
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
*/
async fetch(_url, init) {
let url = _url;
if (this.options.directEndpoint) {
url = this.options.reverseProxyUrl;
}
logger.debug(`Making request to ${url}`);
if (typeof Bun !== 'undefined') {
return await fetch(url, init);
}
return await fetch(url, init);
}
getBuildMessagesOptions() {
throw new Error('Subclasses must implement getBuildMessagesOptions');
}
@@ -92,45 +63,19 @@ class BaseClient {
await stream.processTextStream(onProgress);
}
/**
* @returns {[string|undefined, string|undefined]}
*/
processOverideIds() {
/** @type {Record<string, string | undefined>} */
let { overrideConvoId, overrideUserMessageId } = this.options?.req?.body ?? {};
if (overrideConvoId) {
const [conversationId, index] = overrideConvoId.split(Constants.COMMON_DIVIDER);
overrideConvoId = conversationId;
if (index !== '0') {
this.skipSaveConvo = true;
}
}
if (overrideUserMessageId) {
const [userMessageId, index] = overrideUserMessageId.split(Constants.COMMON_DIVIDER);
overrideUserMessageId = userMessageId;
if (index !== '0') {
this.skipSaveUserMessage = true;
}
}
return [overrideConvoId, overrideUserMessageId];
}
async setMessageOptions(opts = {}) {
if (opts && opts.replaceOptions) {
this.setOptions(opts);
}
const [overrideConvoId, overrideUserMessageId] = this.processOverideIds();
const { isEdited, isContinued } = opts;
const user = opts.user ?? null;
this.user = user;
const saveOptions = this.getSaveOptions();
this.abortController = opts.abortController ?? new AbortController();
const conversationId = overrideConvoId ?? opts.conversationId ?? crypto.randomUUID();
const conversationId = opts.conversationId ?? crypto.randomUUID();
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
const userMessageId =
overrideUserMessageId ?? opts.overrideParentMessageId ?? crypto.randomUUID();
const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID();
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
let head = isEdited ? responseMessageId : parentMessageId;
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
@@ -194,7 +139,7 @@ class BaseClient {
}
if (typeof opts?.onStart === 'function') {
opts.onStart(userMessage, responseMessageId);
opts.onStart(userMessage);
}
return {
@@ -428,14 +373,6 @@ class BaseClient {
const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
await this.handleStartMethods(message, opts);
if (opts.progressCallback) {
opts.onProgress = opts.progressCallback.call(null, {
...(opts.progressOptions ?? {}),
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
}
const { generation = '' } = opts;
// It's not necessary to push to currentMessages
@@ -484,13 +421,8 @@ class BaseClient {
this.handleTokenCountMap(tokenCountMap);
}
if (!isEdited && !this.skipSaveUserMessage) {
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise: this.userMessagePromise,
});
}
if (!isEdited) {
await this.saveMessageToDatabase(userMessage, saveOptions, user);
}
if (
@@ -539,11 +471,15 @@ class BaseClient {
const completionTokens = this.getTokenCount(completion);
await this.recordTokenUsage({ promptTokens, completionTokens });
}
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return responseMessage;
}
async getConversation(conversationId, user = null) {
return await getConvo(user, conversationId);
}
async loadHistory(conversationId, parentMessageId = null) {
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
@@ -598,24 +534,18 @@ class BaseClient {
* @param {string | null} user
*/
async saveMessageToDatabase(message, endpointOptions, user = null) {
const savedMessage = await saveMessage({
await saveMessage({
...message,
endpoint: this.options.endpoint,
unfinished: false,
user,
});
if (this.skipSaveConvo) {
return { message: savedMessage };
}
const conversation = await saveConvo(user, {
await saveConvo(user, {
conversationId: message.conversationId,
endpoint: this.options.endpoint,
endpointType: this.options.endpointType,
...endpointOptions,
});
return { message: savedMessage, conversation };
}
async updateMessageInDatabase(message) {

View File

@@ -438,17 +438,9 @@ class ChatGPTClient extends BaseClient {
if (message.eventType === 'text-generation' && message.text) {
onTokenProgress(message.text);
reply += message.text;
}
/*
Cohere API Chinese Unicode character replacement hotfix.
Should be un-commented when the following issue is resolved:
https://github.com/cohere-ai/cohere-typescript/issues/151
else if (message.eventType === 'stream-end' && message.response) {
} else if (message.eventType === 'stream-end' && message.response) {
reply = message.response.text;
}
*/
}
return reply;

View File

@@ -16,15 +16,10 @@ const {
AuthKeys,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { formatMessage, createContextHandlers } = require('./prompts');
const { getModelMaxTokens } = require('~/utils');
const { logger } = require('~/config');
const {
formatMessage,
createContextHandlers,
titleInstruction,
truncateText,
} = require('./prompts');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const loc = 'us-central1';
const publisher = 'google';
@@ -596,16 +591,12 @@ class GoogleClient extends BaseClient {
createLLM(clientOptions) {
const model = clientOptions.modelName ?? clientOptions.model;
if (this.project_id && this.isTextModel) {
logger.debug('Creating Google VertexAI client');
return new GoogleVertexAI(clientOptions);
} else if (this.project_id && this.isChatModel) {
logger.debug('Creating Chat Google VertexAI client');
return new ChatGoogleVertexAI(clientOptions);
} else if (this.project_id) {
logger.debug('Creating VertexAI client');
return new ChatVertexAI(clientOptions);
} else if (model.includes('1.5')) {
logger.debug('Creating GenAI client');
return new GenAI(this.apiKey).getGenerativeModel(
{
...clientOptions,
@@ -615,7 +606,6 @@ class GoogleClient extends BaseClient {
);
}
logger.debug('Creating Chat Google Generative AI client');
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
}
@@ -727,123 +717,6 @@ class GoogleClient extends BaseClient {
return reply;
}
/**
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
*/
async titleChatCompletion(_payload, options = {}) {
const { abortController } = options;
const { parameters, instances } = _payload;
const { messages: _messages, examples: _examples } = instances?.[0] ?? {};
let clientOptions = { ...parameters, maxRetries: 2 };
logger.debug('Initialized title client options');
if (this.project_id) {
clientOptions['authOptions'] = {
credentials: {
...this.serviceKey,
},
projectId: this.project_id,
};
}
if (!parameters) {
clientOptions = { ...clientOptions, ...this.modelOptions };
}
if (this.isGenerativeModel && !this.project_id) {
clientOptions.modelName = clientOptions.model;
delete clientOptions.model;
}
const model = this.createLLM(clientOptions);
let reply = '';
const messages = this.isTextModel ? _payload.trim() : _messages;
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
logger.debug('Identified titling model as 1.5 version');
/** @type {GenerativeModel} */
const client = model;
const requestOptions = {
contents: _payload,
};
if (this.options?.promptPrefix?.length) {
requestOptions.systemInstruction = {
parts: [
{
text: this.options.promptPrefix,
},
],
};
}
const safetySettings = _payload.safetySettings;
requestOptions.safetySettings = safetySettings;
const result = await client.generateContent(requestOptions);
reply = result.response?.text();
return reply;
} else {
logger.debug('Beginning titling');
const safetySettings = _payload.safetySettings;
const titleResponse = await model.invoke(messages, {
signal: abortController.signal,
timeout: 7000,
safetySettings: safetySettings,
});
reply = titleResponse.content;
return reply;
}
}
async titleConvo({ text, responseText = '' }) {
let title = 'New Chat';
const convo = `||>User:
"${truncateText(text)}"
||>Response:
"${JSON.stringify(truncateText(responseText))}"`;
let { prompt: payload } = await this.buildMessages([
{
text: `Please generate ${titleInstruction}
${convo}
||>Title:`,
isCreatedByUser: true,
author: this.userLabel,
},
]);
if (this.isVisionModel) {
logger.warn(
`Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
);
payload.parameters = { ...payload.parameters, model: settings.model.default };
}
try {
title = await this.titleChatCompletion(payload, {
abortController: new AbortController(),
onProgress: () => {},
});
} catch (e) {
logger.error('[GoogleClient] There was an issue generating the title', e);
}
logger.debug(`Title response: ${title}`);
return title;
}
getSaveOptions() {
return {
promptPrefix: this.options.promptPrefix,

View File

@@ -588,7 +588,7 @@ class OpenAIClient extends BaseClient {
let streamResult = null;
this.modelOptions.user = this.user;
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion);
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
if (typeof opts.onProgress === 'function' && useOldMethod) {
const completionResult = await this.getCompletion(
payload,
@@ -827,7 +827,7 @@ class OpenAIClient extends BaseClient {
const instructionsPayload = [
{
role: this.options.titleMessageRole ?? 'system',
role: 'system',
content: `Please generate ${titleInstruction}
${convo}
@@ -1106,12 +1106,7 @@ ${convo}
}
if (this.azure || this.options.azure) {
/* Azure Bug, extremely short default `max_tokens` response */
if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') {
modelOptions.max_tokens = 4000;
}
/* Azure does not accept `model` in the body, so we need to remove it. */
// Azure does not accept `model` in the body, so we need to remove it.
delete modelOptions.model;
opts.baseURL = this.langchainProxy
@@ -1132,7 +1127,6 @@ ${convo}
let chatCompletion;
/** @type {OpenAI} */
const openai = new OpenAI({
fetch: this.fetch,
apiKey: this.apiKey,
...opts,
});
@@ -1222,7 +1216,6 @@ ${convo}
});
const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
for await (const chunk of stream) {
const token = chunk.choices[0]?.delta?.content || '';
intermediateReply += token;

View File

@@ -238,30 +238,18 @@ class PluginsClient extends OpenAIClient {
await this.recordTokenUsage(responseMessage);
}
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return { ...responseMessage, ...result };
}
async sendMessage(message, opts = {}) {
/** @type {{ filteredTools: string[], includedTools: string[] }} */
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
if (includedTools.length > 0) {
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
this.options.tools = tools;
} else {
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
this.options.tools = tools;
}
// If a message is edited, no tools can be used.
const completionMode = this.options.tools.length === 0 || opts.isEdited;
if (completionMode) {
this.setOptions(opts);
return super.sendMessage(message, opts);
}
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
const {
user,
@@ -276,14 +264,6 @@ class PluginsClient extends OpenAIClient {
onToolEnd,
} = await this.handleStartMethods(message, opts);
if (opts.progressCallback) {
opts.onProgress = opts.progressCallback.call(null, {
...(opts.progressOptions ?? {}),
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
}
this.currentMessages.push(userMessage);
let {
@@ -312,15 +292,7 @@ class PluginsClient extends OpenAIClient {
if (payload) {
this.currentMessages = payload;
}
if (!this.skipSaveUserMessage) {
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise: this.userMessagePromise,
});
}
}
await this.saveMessageToDatabase(userMessage, saveOptions, user);
if (isEnabled(process.env.CHECK_BALANCE)) {
await checkBalance({

View File

@@ -1,3 +1,44 @@
/*
module.exports = `You are ChatGPT, a Large Language model with useful tools.
Talk to the human and provide meaningful answers when questions are asked.
Use the tools when you need them, but use your own knowledge if you are confident of the answer. Keep answers short and concise.
A tool is not usually needed for creative requests, so do your best to answer them without tools.
Avoid repeating identical answers if it appears before. Only fulfill the human's requests, do not create extra steps beyond what the human has asked for.
Your input for 'Action' should be the name of tool used only.
Be honest. If you can't answer something, or a tool is not appropriate, say you don't know or answer to the best of your ability.
Attempt to fulfill the human's requests in as few actions as possible`;
*/
// module.exports = `You are ChatGPT, a highly knowledgeable and versatile large language model.
// Engage with the Human conversationally, providing concise and meaningful answers to questions. Utilize built-in tools when necessary, except for creative requests, where relying on your own knowledge is preferred. Aim for variety and avoid repetitive answers.
// For your 'Action' input, state the name of the tool used only, and honor user requests without adding extra steps. Always be honest; if you cannot provide an appropriate answer or tool, admit that or do your best.
// Strive to meet the user's needs efficiently with minimal actions.`;
// import {
// BasePromptTemplate,
// BaseStringPromptTemplate,
// SerializedBasePromptTemplate,
// renderTemplate,
// } from "langchain/prompts";
// prefix: `You are ChatGPT, a highly knowledgeable and versatile large language model.
// Your objective is to help users by understanding their intent and choosing the best action. Prioritize direct, specific responses. Use concise, varied answers and rely on your knowledge for creative tasks. Utilize tools when needed, and structure results for machine compatibility.
// prefix: `Objective: to comprehend human intentions based on user input and available tools. Goal: identify the best action to directly address the human's query. In your subsequent steps, you will utilize the chosen action. You may select multiple actions and list them in a meaningful order. Prioritize actions that directly relate to the user's query over general ones. Ensure that the generated thought is highly specific and explicit to best match the user's expectations. Construct the result in a manner that an online open-API would most likely expect. Provide concise and meaningful answers to human queries. Utilize tools when necessary. Relying on your own knowledge is preferred for creative requests. Aim for variety and avoid repetitive answers.
// # Available Actions & Tools:
// N/A: no suitable action, use your own knowledge.`,
// suffix: `Remember, all your responses MUST adhere to the described format and only respond if the format is followed. Output exactly with the requested format, avoiding any other text as this will be parsed by a machine. Following 'Action:', provide only one of the actions listed above. If a tool is not necessary, deduce this quickly and finish your response. Honor the human's requests without adding extra steps. Carry out tasks in the sequence written by the human. Always be honest; if you cannot provide an appropriate answer or tool, do your best with your own knowledge. Strive to meet the user's needs efficiently with minimal actions.`;
module.exports = {
'gpt3-v1': {
prefix: `Objective: Understand human intentions using user input and available tools. Goal: Identify the most suitable actions to directly address user queries.

View File

@@ -8,6 +8,8 @@ In your response, remember to follow these guidelines:
- If you don't know the answer, simply say that you don't know.
- If you are unsure how to answer, ask for clarification.
- Avoid mentioning that you obtained the information from the context.
Answer appropriately in the user's language.
`;
function createContextHandlers(req, userMessageContent) {
@@ -92,40 +94,37 @@ function createContextHandlers(req, userMessageContent) {
const resolvedQueries = await Promise.all(queryPromises);
const context =
resolvedQueries.length === 0
? '\n\tThe semantic search did not return any results.'
: resolvedQueries
.map((queryResult, index) => {
const file = processedFiles[index];
let contextItems = queryResult.data;
const context = resolvedQueries
.map((queryResult, index) => {
const file = processedFiles[index];
let contextItems = queryResult.data;
const generateContext = (currentContext) =>
`
const generateContext = (currentContext) =>
`
<file>
<filename>${file.filename}</filename>
<context>${currentContext}
</context>
</file>`;
if (useFullContext) {
return generateContext(`\n${contextItems}`);
}
if (useFullContext) {
return generateContext(`\n${contextItems}`);
}
contextItems = queryResult.data
.map((item) => {
const pageContent = item[0].page_content;
return `
contextItems = queryResult.data
.map((item) => {
const pageContent = item[0].page_content;
return `
<contextItem>
<![CDATA[${pageContent?.trim()}]]>
</contextItem>`;
})
.join('');
return generateContext(contextItems);
})
.join('');
return generateContext(contextItems);
})
.join('');
if (useFullContext) {
const prompt = `${header}
${context}

View File

@@ -28,7 +28,7 @@ ${convo}`,
};
const titleInstruction =
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. Never directly mention the language name or the word "title"';
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"';
const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.
You may call them like this:

View File

@@ -576,11 +576,7 @@ describe('BaseClient', () => {
const onStart = jest.fn();
const opts = { onStart };
await TestClient.sendMessage('Hello, world!', opts);
expect(onStart).toHaveBeenCalledWith(
expect.objectContaining({ text: 'Hello, world!' }),
expect.any(String),
);
expect(onStart).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello, world!' }));
});
test('saveMessageToDatabase is called with the correct arguments', async () => {

View File

@@ -194,7 +194,6 @@ describe('PluginsClient', () => {
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
});
});
describe('Azure OpenAI tests specific to Plugins', () => {
// TODO: add more tests for Azure OpenAI integration with Plugins
// let client;
@@ -221,94 +220,4 @@ describe('PluginsClient', () => {
spy.mockRestore();
});
});
describe('sendMessage with filtered tools', () => {
let TestAgent;
const apiKey = 'fake-api-key';
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, {
tools: mockTools,
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
});
TestAgent.options.req = {
app: {
locals: {},
},
};
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
if (includedTools.length > 0) {
const tools = TestAgent.options.tools.filter((plugin) =>
includedTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
} else {
const tools = TestAgent.options.tools.filter(
(plugin) => !filteredTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
}
return {
text: 'Mocked response',
tools: TestAgent.options.tools,
};
});
});
test('should filter out tools when filteredTools is provided', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should only include specified tools when includedTools is provided', async () => {
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should prioritize includedTools over filteredTools', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool1' }),
expect.objectContaining({ name: 'tool2' }),
]),
);
});
test('should not modify tools when no filters are provided', async () => {
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(4);
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
});
});
});

View File

@@ -7,7 +7,6 @@ const keyvMongo = require('./keyvMongo');
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
const THIRTY_MINUTES = 1800000;
const TEN_MINUTES = 600000;
const duration = math(BAN_DURATION, 7200000);
@@ -25,14 +24,6 @@ const config = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
const roles = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.ROLES });
const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes
? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES })
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES });
const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
@@ -50,7 +41,6 @@ const abortKeys = isEnabled(USE_REDIS)
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });
const namespaces = {
[CacheKeys.ROLES]: roles,
[CacheKeys.CONFIG_STORE]: config,
pending_req,
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
@@ -65,13 +55,7 @@ const namespaces = {
message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
registrations: createViolationInstance('registrations'),
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
[ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
ViolationTypes.RESET_PASSWORD_LIMIT,
),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
ViolationTypes.ILLEGAL_MODEL_REQUEST,
),
@@ -80,7 +64,6 @@ const namespaces = {
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
[CacheKeys.GEN_TITLE]: genTitle,
[CacheKeys.MODEL_QUERIES]: modelQueries,
[CacheKeys.AUDIO_RUNS]: audioRuns,
};
/**

View File

@@ -1,6 +1,6 @@
const { isEnabled } = require('~/server/utils');
const getLogStores = require('./getLogStores');
const banViolation = require('./banViolation');
const { isEnabled } = require('../server/utils');
/**
* Logs the violation.

View File

@@ -27,25 +27,26 @@ function getMatchingSensitivePatterns(valueStr) {
}
/**
* Redacts sensitive information from a console message and trims it to a specified length if provided.
* Redacts sensitive information from a console message.
*
* @param {string} str - The console message to be redacted.
* @param {number} [trimLength] - The optional length at which to trim the redacted message.
* @returns {string} - The redacted and optionally trimmed console message.
* @returns {string} - The redacted console message.
*/
function redactMessage(str, trimLength) {
function redactMessage(str) {
if (!str) {
return '';
}
const patterns = getMatchingSensitivePatterns(str);
if (patterns.length === 0) {
return str;
}
patterns.forEach((pattern) => {
str = str.replace(pattern, '$1[REDACTED]');
});
if (trimLength !== undefined && str.length > trimLength) {
return `${str.substring(0, trimLength)}...`;
}
return str;
}

View File

@@ -14,7 +14,7 @@ const Assistant = mongoose.model('assistant', assistantSchema);
* @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
* @returns {Promise<Object>} The updated or newly created assistant document as a plain object.
*/
const updateAssistantDoc = async (searchParams, updateData, session = null) => {
const updateAssistant = async (searchParams, updateData, session = null) => {
const options = { new: true, upsert: true, session };
return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
};
@@ -52,7 +52,7 @@ const deleteAssistant = async (searchParams) => {
};
module.exports = {
updateAssistantDoc,
updateAssistant,
deleteAssistant,
getAssistants,
getAssistant,

View File

@@ -1,61 +0,0 @@
const { logger } = require('~/config');
// const { Categories } = require('./schema/categories');
const options = [
{
label: '',
value: '',
},
{
label: 'idea',
value: 'idea',
},
{
label: 'travel',
value: 'travel',
},
{
label: 'teach_or_explain',
value: 'teach_or_explain',
},
{
label: 'write',
value: 'write',
},
{
label: 'shop',
value: 'shop',
},
{
label: 'code',
value: 'code',
},
{
label: 'misc',
value: 'misc',
},
{
label: 'roleplay',
value: 'roleplay',
},
{
label: 'finance',
value: 'finance',
},
];
module.exports = {
/**
* Retrieves the categories asynchronously.
* @returns {Promise<TGetCategoriesResponse>} An array of category objects.
* @throws {Error} If there is an error retrieving the categories.
*/
getCategories: async () => {
try {
// const categories = await Categories.find();
return options;
} catch (error) {
logger.error('Error getting categories', error);
return [];
}
},
};

View File

@@ -21,18 +21,16 @@ module.exports = {
Conversation,
saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
try {
const messages = await getMessages({ conversationId }, '_id');
const messages = await getMessages({ conversationId });
const update = { ...convo, messages, user };
if (newConversationId) {
update.conversationId = newConversationId;
}
const conversation = await Conversation.findOneAndUpdate({ conversationId, user }, update, {
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
new: true,
upsert: true,
});
return conversation.toObject();
} catch (error) {
logger.error('[saveConvo] Error saving conversation', error);
return { message: 'Error saving conversation' };

View File

@@ -97,12 +97,8 @@ const deleteFileByFilter = async (filter) => {
* @param {Array<string>} file_ids - The unique identifiers of the files to delete.
* @returns {Promise<Object>} A promise that resolves to the result of the deletion operation.
*/
const deleteFiles = async (file_ids, user) => {
let deleteQuery = { file_id: { $in: file_ids } };
if (user) {
deleteQuery = { user: user };
}
return await File.deleteMany(deleteQuery);
const deleteFiles = async (file_ids) => {
return await File.deleteMany({ file_id: { $in: file_ids } });
};
module.exports = {

View File

@@ -57,13 +57,18 @@ module.exports = {
if (files) {
update.files = files;
}
// may also need to update the conversation here
await Message.findOneAndUpdate({ messageId }, update, { upsert: true, new: true });
const message = await Message.findOneAndUpdate({ messageId }, update, {
upsert: true,
new: true,
});
return message.toObject();
return {
messageId,
conversationId,
parentMessageId,
sender,
text,
isCreatedByUser,
tokenCount,
};
} catch (err) {
logger.error('Error saving message:', err);
throw new Error('Failed to save message.');
@@ -124,14 +129,6 @@ module.exports = {
throw new Error('Failed to save message.');
}
},
async updateMessageText({ messageId, text }) {
try {
await Message.updateOne({ messageId }, { text });
} catch (err) {
logger.error('Error updating message text:', err);
throw new Error('Failed to update message text.');
}
},
async updateMessage(message) {
try {
const { messageId, ...update } = message;
@@ -174,18 +171,8 @@ module.exports = {
}
},
/**
* Retrieves messages from the database.
* @param {Record<string, unknown>} filter
* @param {string | undefined} [select]
* @returns
*/
async getMessages(filter, select) {
async getMessages(filter) {
try {
if (select) {
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
}
return await Message.find(filter).sort({ createdAt: 1 }).lean();
} catch (err) {
logger.error('Error getting messages:', err);

View File

@@ -1,90 +0,0 @@
const { model } = require('mongoose');
const projectSchema = require('~/models/schema/projectSchema');
const Project = model('Project', projectSchema);
/**
* Retrieve a project by ID and convert the found project document to a plain object.
*
* @param {string} projectId - The ID of the project to find and return as a plain object.
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
* @returns {Promise<MongoProject>} A plain object representing the project document, or `null` if no project is found.
*/
const getProjectById = async function (projectId, fieldsToSelect = null) {
const query = Project.findById(projectId);
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
return await query.lean();
};
/**
* Retrieve a project by name and convert the found project document to a plain object.
* If the project with the given name doesn't exist and the name is "instance", create it and return the lean version.
*
* @param {string} projectName - The name of the project to find or create.
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
* @returns {Promise<MongoProject>} A plain object representing the project document.
*/
const getProjectByName = async function (projectName, fieldsToSelect = null) {
const query = { name: projectName };
const update = { $setOnInsert: { name: projectName } };
const options = {
new: true,
upsert: projectName === 'instance',
lean: true,
select: fieldsToSelect,
};
return await Project.findOneAndUpdate(query, update, options);
};
/**
* Add an array of prompt group IDs to a project's promptGroupIds array, ensuring uniqueness.
*
* @param {string} projectId - The ID of the project to update.
* @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project.
* @returns {Promise<MongoProject>} The updated project document.
*/
const addGroupIdsToProject = async function (projectId, promptGroupIds) {
return await Project.findByIdAndUpdate(
projectId,
{ $addToSet: { promptGroupIds: { $each: promptGroupIds } } },
{ new: true },
);
};
/**
* Remove an array of prompt group IDs from a project's promptGroupIds array.
*
* @param {string} projectId - The ID of the project to update.
* @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project.
* @returns {Promise<MongoProject>} The updated project document.
*/
const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
return await Project.findByIdAndUpdate(
projectId,
{ $pull: { promptGroupIds: { $in: promptGroupIds } } },
{ new: true },
);
};
/**
* Remove a prompt group ID from all projects.
*
* @param {string} promptGroupId - The ID of the prompt group to remove from projects.
* @returns {Promise<void>}
*/
const removeGroupFromAllProjects = async (promptGroupId) => {
await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
};
module.exports = {
getProjectById,
getProjectByName,
addGroupIdsToProject,
removeGroupIdsFromProject,
removeGroupFromAllProjects,
};

View File

@@ -1,528 +1,52 @@
const { ObjectId } = require('mongodb');
const { SystemRoles, SystemCategories } = require('librechat-data-provider');
const {
getProjectByName,
addGroupIdsToProject,
removeGroupIdsFromProject,
removeGroupFromAllProjects,
} = require('./Project');
const { Prompt, PromptGroup } = require('./schema/promptSchema');
const mongoose = require('mongoose');
const { logger } = require('~/config');
/**
* Create a pipeline for the aggregation to get prompt groups
* @param {Object} query
* @param {number} skip
* @param {number} limit
* @returns {[Object]} - The pipeline for the aggregation
*/
const createGroupPipeline = (query, skip, limit) => {
return [
{ $match: query },
{ $sort: { createdAt: -1 } },
{ $skip: skip },
{ $limit: limit },
{
$lookup: {
from: 'prompts',
localField: 'productionId',
foreignField: '_id',
as: 'productionPrompt',
},
const promptSchema = mongoose.Schema(
{
title: {
type: String,
required: true,
},
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
{
$project: {
name: 1,
numberOfGenerations: 1,
oneliner: 1,
category: 1,
projectIds: 1,
productionId: 1,
author: 1,
authorName: 1,
createdAt: 1,
updatedAt: 1,
'productionPrompt.prompt': 1,
// 'productionPrompt._id': 1,
// 'productionPrompt.type': 1,
},
prompt: {
type: String,
required: true,
},
category: {
type: String,
},
];
};
/**
* Create a pipeline for the aggregation to get all prompt groups
* @param {Object} query
* @param {Partial<MongoPromptGroup>} $project
* @returns {[Object]} - The pipeline for the aggregation
*/
const createAllGroupsPipeline = (
query,
$project = {
name: 1,
oneliner: 1,
category: 1,
author: 1,
authorName: 1,
createdAt: 1,
updatedAt: 1,
command: 1,
'productionPrompt.prompt': 1,
},
) => {
return [
{ $match: query },
{ $sort: { createdAt: -1 } },
{
$lookup: {
from: 'prompts',
localField: 'productionId',
foreignField: '_id',
as: 'productionPrompt',
},
},
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
{
$project,
},
];
};
{ timestamps: true },
);
/**
* Get all prompt groups with filters
* @param {Object} req
* @param {TPromptGroupsWithFilterRequest} filter
* @returns {Promise<PromptGroupListResponse>}
*/
const getAllPromptGroups = async (req, filter) => {
try {
const { name, ...query } = filter;
if (!query.author) {
throw new Error('Author is required');
}
let searchShared = true;
let searchSharedOnly = false;
if (name) {
query.name = new RegExp(name, 'i');
}
if (!query.category) {
delete query.category;
} else if (query.category === SystemCategories.MY_PROMPTS) {
searchShared = false;
delete query.category;
} else if (query.category === SystemCategories.NO_CATEGORY) {
query.category = '';
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
searchSharedOnly = true;
delete query.category;
}
let combinedQuery = query;
if (searchShared) {
const project = await getProjectByName('instance', 'promptGroupIds');
if (project && project.promptGroupIds.length > 0) {
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
delete projectQuery.author;
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
}
}
const promptGroupsPipeline = createAllGroupsPipeline(combinedQuery);
return await PromptGroup.aggregate(promptGroupsPipeline).exec();
} catch (error) {
console.error('Error getting all prompt groups', error);
return { message: 'Error getting all prompt groups' };
}
};
/**
* Get prompt groups with filters
* @param {Object} req
* @param {TPromptGroupsWithFilterRequest} filter
* @returns {Promise<PromptGroupListResponse>}
*/
const getPromptGroups = async (req, filter) => {
try {
const { pageNumber = 1, pageSize = 10, name, ...query } = filter;
const validatedPageNumber = Math.max(parseInt(pageNumber, 10), 1);
const validatedPageSize = Math.max(parseInt(pageSize, 10), 1);
if (!query.author) {
throw new Error('Author is required');
}
let searchShared = true;
let searchSharedOnly = false;
if (name) {
query.name = new RegExp(name, 'i');
}
if (!query.category) {
delete query.category;
} else if (query.category === SystemCategories.MY_PROMPTS) {
searchShared = false;
delete query.category;
} else if (query.category === SystemCategories.NO_CATEGORY) {
query.category = '';
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
searchSharedOnly = true;
delete query.category;
}
let combinedQuery = query;
if (searchShared) {
// const projects = req.user.projects || []; // TODO: handle multiple projects
const project = await getProjectByName('instance', 'promptGroupIds');
if (project && project.promptGroupIds.length > 0) {
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
delete projectQuery.author;
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
}
}
const skip = (validatedPageNumber - 1) * validatedPageSize;
const limit = validatedPageSize;
const promptGroupsPipeline = createGroupPipeline(combinedQuery, skip, limit);
const totalPromptGroupsPipeline = [{ $match: combinedQuery }, { $count: 'total' }];
const [promptGroupsResults, totalPromptGroupsResults] = await Promise.all([
PromptGroup.aggregate(promptGroupsPipeline).exec(),
PromptGroup.aggregate(totalPromptGroupsPipeline).exec(),
]);
const promptGroups = promptGroupsResults;
const totalPromptGroups =
totalPromptGroupsResults.length > 0 ? totalPromptGroupsResults[0].total : 0;
return {
promptGroups,
pageNumber: validatedPageNumber.toString(),
pageSize: validatedPageSize.toString(),
pages: Math.ceil(totalPromptGroups / validatedPageSize).toString(),
};
} catch (error) {
console.error('Error getting prompt groups', error);
return { message: 'Error getting prompt groups' };
}
};
const Prompt = mongoose.models.Prompt || mongoose.model('Prompt', promptSchema);
module.exports = {
getPromptGroups,
getAllPromptGroups,
/**
* Create a prompt and its respective group
* @param {TCreatePromptRecord} saveData
* @returns {Promise<TCreatePromptResponse>}
*/
createPromptGroup: async (saveData) => {
savePrompt: async ({ title, prompt }) => {
try {
const { prompt, group, author, authorName } = saveData;
let newPromptGroup = await PromptGroup.findOneAndUpdate(
{ ...group, author, authorName, productionId: null },
{ $setOnInsert: { ...group, author, authorName, productionId: null } },
{ new: true, upsert: true },
)
.lean()
.select('-__v')
.exec();
const newPrompt = await Prompt.findOneAndUpdate(
{ ...prompt, author, groupId: newPromptGroup._id },
{ $setOnInsert: { ...prompt, author, groupId: newPromptGroup._id } },
{ new: true, upsert: true },
)
.lean()
.select('-__v')
.exec();
newPromptGroup = await PromptGroup.findByIdAndUpdate(
newPromptGroup._id,
{ productionId: newPrompt._id },
{ new: true },
)
.lean()
.select('-__v')
.exec();
return {
prompt: newPrompt,
group: {
...newPromptGroup,
productionPrompt: { prompt: newPrompt.prompt },
},
};
} catch (error) {
logger.error('Error saving prompt group', error);
throw new Error('Error saving prompt group');
}
},
/**
* Save a prompt
* @param {TCreatePromptRecord} saveData
* @returns {Promise<TCreatePromptResponse>}
*/
savePrompt: async (saveData) => {
try {
const { prompt, author } = saveData;
const newPromptData = {
...prompt,
author,
};
/** @type {TPrompt} */
let newPrompt;
try {
newPrompt = await Prompt.create(newPromptData);
} catch (error) {
if (error?.message?.includes('groupId_1_version_1')) {
await Prompt.db.collection('prompts').dropIndex('groupId_1_version_1');
} else {
throw error;
}
newPrompt = await Prompt.create(newPromptData);
}
return { prompt: newPrompt };
await Prompt.create({
title,
prompt,
});
return { title, prompt };
} catch (error) {
logger.error('Error saving prompt', error);
return { message: 'Error saving prompt' };
return { prompt: 'Error saving prompt' };
}
},
getPrompts: async (filter) => {
try {
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
return await Prompt.find(filter).lean();
} catch (error) {
logger.error('Error getting prompts', error);
return { message: 'Error getting prompts' };
return { prompt: 'Error getting prompts' };
}
},
getPrompt: async (filter) => {
deletePrompts: async (filter) => {
try {
if (filter.groupId) {
filter.groupId = new ObjectId(filter.groupId);
}
return await Prompt.findOne(filter).lean();
return await Prompt.deleteMany(filter);
} catch (error) {
logger.error('Error getting prompt', error);
return { message: 'Error getting prompt' };
}
},
/**
* Get prompt groups with filters
* @param {TGetRandomPromptsRequest} filter
* @returns {Promise<TGetRandomPromptsResponse>}
*/
getRandomPromptGroups: async (filter) => {
try {
const result = await PromptGroup.aggregate([
{
$match: {
category: { $ne: '' },
},
},
{
$group: {
_id: '$category',
promptGroup: { $first: '$$ROOT' },
},
},
{
$replaceRoot: { newRoot: '$promptGroup' },
},
{
$sample: { size: +filter.limit + +filter.skip },
},
{
$skip: +filter.skip,
},
{
$limit: +filter.limit,
},
]);
return { prompts: result };
} catch (error) {
logger.error('Error getting prompt groups', error);
return { message: 'Error getting prompt groups' };
}
},
getPromptGroupsWithPrompts: async (filter) => {
try {
return await PromptGroup.findOne(filter)
.populate({
path: 'prompts',
select: '-_id -__v -user',
})
.select('-_id -__v -user')
.lean();
} catch (error) {
logger.error('Error getting prompt groups', error);
return { message: 'Error getting prompt groups' };
}
},
getPromptGroup: async (filter) => {
try {
return await PromptGroup.findOne(filter).lean();
} catch (error) {
logger.error('Error getting prompt group', error);
return { message: 'Error getting prompt group' };
}
},
/**
* Deletes a prompt and its corresponding prompt group if it is the last prompt in the group.
*
* @param {Object} options - The options for deleting the prompt.
* @param {ObjectId|string} options.promptId - The ID of the prompt to delete.
* @param {ObjectId|string} options.groupId - The ID of the prompt's group.
* @param {ObjectId|string} options.author - The ID of the prompt's author.
* @param {string} options.role - The role of the prompt's author.
* @return {Promise<TDeletePromptResponse>} An object containing the result of the deletion.
* If the prompt was deleted successfully, the object will have a property 'prompt' with the value 'Prompt deleted successfully'.
* If the prompt group was deleted successfully, the object will have a property 'promptGroup' with the message 'Prompt group deleted successfully' and id of the deleted group.
* If there was an error deleting the prompt, the object will have a property 'message' with the value 'Error deleting prompt'.
*/
deletePrompt: async ({ promptId, groupId, author, role }) => {
const query = { _id: promptId, groupId, author };
if (role === SystemRoles.ADMIN) {
delete query.author;
}
const { deletedCount } = await Prompt.deleteOne(query);
if (deletedCount === 0) {
throw new Error('Failed to delete the prompt');
}
const remainingPrompts = await Prompt.find({ groupId })
.select('_id')
.sort({ createdAt: 1 })
.lean();
if (remainingPrompts.length === 0) {
await PromptGroup.deleteOne({ _id: groupId });
await removeGroupFromAllProjects(groupId);
return {
prompt: 'Prompt deleted successfully',
promptGroup: {
message: 'Prompt group deleted successfully',
id: groupId,
},
};
} else {
const promptGroup = await PromptGroup.findById(groupId).lean();
if (promptGroup.productionId.toString() === promptId.toString()) {
await PromptGroup.updateOne(
{ _id: groupId },
{ productionId: remainingPrompts[remainingPrompts.length - 1]._id },
);
}
return { prompt: 'Prompt deleted successfully' };
}
},
/**
* Update prompt group
* @param {Partial<MongoPromptGroup>} filter - Filter to find prompt group
* @param {Partial<MongoPromptGroup>} data - Data to update
* @returns {Promise<TUpdatePromptGroupResponse>}
*/
updatePromptGroup: async (filter, data) => {
try {
const updateOps = {};
if (data.removeProjectIds) {
for (const projectId of data.removeProjectIds) {
await removeGroupIdsFromProject(projectId, [filter._id]);
}
updateOps.$pull = { projectIds: { $in: data.removeProjectIds } };
delete data.removeProjectIds;
}
if (data.projectIds) {
for (const projectId of data.projectIds) {
await addGroupIdsToProject(projectId, [filter._id]);
}
updateOps.$addToSet = { projectIds: { $each: data.projectIds } };
delete data.projectIds;
}
const updateData = { ...data, ...updateOps };
const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
new: true,
upsert: false,
});
if (!updatedDoc) {
throw new Error('Prompt group not found');
}
return updatedDoc;
} catch (error) {
logger.error('Error updating prompt group', error);
return { message: 'Error updating prompt group' };
}
},
/**
* Function to make a prompt production based on its ID.
* @param {String} promptId - The ID of the prompt to make production.
* @returns {Object} The result of the production operation.
*/
makePromptProduction: async (promptId) => {
try {
const prompt = await Prompt.findById(promptId).lean();
if (!prompt) {
throw new Error('Prompt not found');
}
await PromptGroup.findByIdAndUpdate(
prompt.groupId,
{ productionId: prompt._id },
{ new: true },
)
.lean()
.exec();
return {
message: 'Prompt production made successfully',
};
} catch (error) {
logger.error('Error making prompt production', error);
return { message: 'Error making prompt production' };
}
},
updatePromptLabels: async (_id, labels) => {
try {
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
if (response.matchedCount === 0) {
return { message: 'Prompt not found' };
}
return { message: 'Prompt labels updated successfully' };
} catch (error) {
logger.error('Error updating prompt labels', error);
return { message: 'Error updating prompt labels' };
}
},
deletePromptGroup: async (_id) => {
try {
const response = await PromptGroup.deleteOne({ _id });
if (response.deletedCount === 0) {
return { promptGroup: 'Prompt group not found' };
}
await Prompt.deleteMany({ groupId: new ObjectId(_id) });
await removeGroupFromAllProjects(_id);
return { promptGroup: 'Prompt group deleted successfully' };
} catch (error) {
logger.error('Error deleting prompt group', error);
return { message: 'Error deleting prompt group' };
logger.error('Error deleting prompts', error);
return { prompt: 'Error deleting prompts' };
}
},
};

View File

@@ -1,86 +0,0 @@
const { SystemRoles, CacheKeys, roleDefaults } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const Role = require('~/models/schema/roleSchema');
/**
* Retrieve a role by name and convert the found role document to a plain object.
* If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version.
*
* @param {string} roleName - The name of the role to find or create.
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
* @returns {Promise<Object>} A plain object representing the role document.
*/
const getRoleByName = async function (roleName, fieldsToSelect = null) {
try {
const cache = getLogStores(CacheKeys.ROLES);
const cachedRole = await cache.get(roleName);
if (cachedRole) {
return cachedRole;
}
let query = Role.findOne({ name: roleName });
if (fieldsToSelect) {
query = query.select(fieldsToSelect);
}
let role = await query.lean().exec();
if (!role && SystemRoles[roleName]) {
role = roleDefaults[roleName];
role = await new Role(role).save();
await cache.set(roleName, role);
return role.toObject();
}
await cache.set(roleName, role);
return role;
} catch (error) {
throw new Error(`Failed to retrieve or create role: ${error.message}`);
}
};
/**
* Update role values by name.
*
* @param {string} roleName - The name of the role to update.
* @param {Partial<TRole>} updates - The fields to update.
* @returns {Promise<TRole>} Updated role document.
*/
const updateRoleByName = async function (roleName, updates) {
try {
const cache = getLogStores(CacheKeys.ROLES);
const role = await Role.findOneAndUpdate(
{ name: roleName },
{ $set: updates },
{ new: true, lean: true },
)
.select('-__v')
.lean()
.exec();
await cache.set(roleName, role);
return role;
} catch (error) {
throw new Error(`Failed to update role: ${error.message}`);
}
};
/**
* Initialize default roles in the system.
* Creates the default roles (ADMIN, USER) if they don't exist in the database.
*
* @returns {Promise<void>}
*/
const initializeRoles = async function () {
const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER];
for (const roleName of defaultRoles) {
let role = await Role.findOne({ name: roleName }).select('name').lean();
if (!role) {
role = new Role(roleDefaults[roleName]);
await role.save();
}
}
};
module.exports = {
getRoleByName,
initializeRoles,
updateRoleByName,
};

View File

@@ -22,7 +22,7 @@ module.exports = {
return share;
} catch (error) {
logger.error('[getShare] Error getting share link', error);
throw new Error('Error getting share link');
return { message: 'Error getting share link' };
}
},
@@ -41,17 +41,17 @@ module.exports = {
return { sharedLinks: shares, pages: totalPages, pageNumber, pageSize };
} catch (error) {
logger.error('[getShareByPage] Error getting shares', error);
throw new Error('Error getting shares');
return { message: 'Error getting shares' };
}
},
createSharedLink: async (user, { conversationId, ...shareData }) => {
try {
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
if (share) {
return share;
}
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
if (share) {
return share;
}
try {
const shareId = crypto.randomUUID();
const messages = await getMessages({ conversationId });
const update = { ...shareData, shareId, messages, user };
@@ -60,58 +60,30 @@ module.exports = {
upsert: true,
});
} catch (error) {
logger.error('[createSharedLink] Error creating shared link', error);
throw new Error('Error creating shared link');
logger.error('[saveShareMessage] Error saving conversation', error);
return { message: 'Error saving conversation' };
}
},
updateSharedLink: async (user, { conversationId, ...shareData }) => {
try {
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
if (!share) {
return { message: 'Share not found' };
}
// update messages to the latest
const messages = await getMessages({ conversationId });
const update = { ...shareData, messages, user };
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
new: true,
upsert: false,
});
} catch (error) {
logger.error('[updateSharedLink] Error updating shared link', error);
throw new Error('Error updating shared link');
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
if (!share) {
return { message: 'Share not found' };
}
// update messages to the latest
const messages = await getMessages({ conversationId });
const update = { ...shareData, messages, user };
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
new: true,
upsert: false,
});
},
deleteSharedLink: async (user, { shareId }) => {
try {
const share = await SharedLink.findOne({ shareId, user });
if (!share) {
return { message: 'Share not found' };
}
return await SharedLink.findOneAndDelete({ shareId, user });
} catch (error) {
logger.error('[deleteSharedLink] Error deleting shared link', error);
throw new Error('Error deleting shared link');
}
},
/**
* Deletes all shared links for a specific user.
* @param {string} user - The user ID.
* @returns {Promise<{ message: string, deletedCount?: number }>} A result object indicating success or error message.
*/
deleteAllSharedLinks: async (user) => {
try {
const result = await SharedLink.deleteMany({ user });
return {
message: 'All shared links have been deleted successfully',
deletedCount: result.deletedCount,
};
} catch (error) {
logger.error('[deleteAllSharedLinks] Error deleting shared links', error);
throw new Error('Error deleting shared links');
const share = await SharedLink.findOne({ shareId, user });
if (!share) {
return { message: 'Share not found' };
}
return await SharedLink.findOneAndDelete({ shareId, user });
},
};

View File

@@ -1,5 +1,61 @@
const mongoose = require('mongoose');
const userSchema = require('~/models/schema/userSchema');
const bcrypt = require('bcryptjs');
const signPayload = require('../server/services/signPayload');
const userSchema = require('./schema/userSchema.js');
const { SESSION_EXPIRY } = process.env ?? {};
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
userSchema.methods.toJSON = function () {
return {
id: this._id,
provider: this.provider,
email: this.email,
name: this.name,
username: this.username,
avatar: this.avatar,
role: this.role,
emailVerified: this.emailVerified,
plugins: this.plugins,
createdAt: this.createdAt,
updatedAt: this.updatedAt,
};
};
userSchema.methods.generateToken = async function () {
return await signPayload({
payload: {
id: this._id,
username: this.username,
provider: this.provider,
email: this.email,
},
secret: process.env.JWT_SECRET,
expirationTime: expires / 1000,
});
};
userSchema.methods.comparePassword = function (candidatePassword, callback) {
bcrypt.compare(candidatePassword, this.password, (err, isMatch) => {
if (err) {
return callback(err);
}
callback(null, isMatch);
});
};
module.exports.hashPassword = async (password) => {
const hashedPassword = await new Promise((resolve, reject) => {
bcrypt.hash(password, 10, function (err, hash) {
if (err) {
reject(err);
} else {
resolve(hash);
}
});
});
return hashedPassword;
};
const User = mongoose.model('User', userSchema);

View File

@@ -6,18 +6,9 @@ const {
deleteMessagesSince,
deleteMessages,
} = require('./Message');
const {
comparePassword,
deleteUserById,
generateToken,
getUserById,
updateUser,
createUser,
countUsers,
findUser,
} = require('./userMethods');
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
const { hashPassword, getUser, updateUser } = require('./userMethods');
const {
findFileById,
createFile,
@@ -38,14 +29,9 @@ module.exports = {
Session,
Balance,
comparePassword,
deleteUserById,
generateToken,
getUserById,
countUsers,
createUser,
hashPassword,
updateUser,
findUser,
getUser,
getMessages,
saveMessage,

View File

@@ -1,19 +0,0 @@
const mongoose = require('mongoose');
const Schema = mongoose.Schema;
const categoriesSchema = new Schema({
label: {
type: String,
required: true,
unique: true,
},
value: {
type: String,
required: true,
unique: true,
},
});
const categories = mongoose.model('categories', categoriesSchema);
module.exports = { Categories: categories };

View File

@@ -3,9 +3,9 @@ const mongoose = require('mongoose');
/**
* @typedef {Object} MongoFile
* @property {ObjectId} [_id] - MongoDB Document ID
* @property {mongoose.Schema.Types.ObjectId} [_id] - MongoDB Document ID
* @property {number} [__v] - MongoDB Version Key
* @property {ObjectId} user - User ID
* @property {mongoose.Schema.Types.ObjectId} user - User ID
* @property {string} [conversationId] - Optional conversation ID
* @property {string} file_id - File identifier
* @property {string} [temp_file_id] - Temporary File identifier
@@ -14,19 +14,17 @@ const mongoose = require('mongoose');
* @property {string} filepath - Location of the file
* @property {'file'} object - Type of object, always 'file'
* @property {string} type - Type of file
* @property {number} [usage=0] - Number of uses of the file
* @property {number} usage - Number of uses of the file
* @property {string} [context] - Context of the file origin
* @property {boolean} [embedded=false] - Whether or not the file is embedded in vector db
* @property {boolean} [embedded] - Whether or not the file is embedded in vector db
* @property {string} [model] - The model to identify the group region of the file (for Azure OpenAI hosting)
* @property {string} [source] - The source of the file (e.g., from FileSources)
* @property {string} [source] - The source of the file
* @property {number} [width] - Optional width of the file
* @property {number} [height] - Optional height of the file
* @property {Date} [expiresAt] - Optional expiration date of the file
* @property {Date} [expiresAt] - Optional height of the file
* @property {Date} [createdAt] - Date when the file was created
* @property {Date} [updatedAt] - Date when the file was updated
*/
/** @type {MongooseSchema<MongoFile>} */
const fileSchema = mongoose.Schema(
{
user: {
@@ -93,7 +91,7 @@ const fileSchema = mongoose.Schema(
height: Number,
expiresAt: {
type: Date,
expires: 3600, // 1 hour in seconds
expires: 3600,
},
},
{

View File

@@ -11,7 +11,6 @@ const messageSchema = mongoose.Schema(
},
conversationId: {
type: String,
index: true,
required: true,
meiliIndex: true,
},

View File

@@ -1,30 +0,0 @@
const { Schema } = require('mongoose');
/**
* @typedef {Object} MongoProject
* @property {ObjectId} [_id] - MongoDB Document ID
* @property {string} name - The name of the project
* @property {ObjectId[]} promptGroupIds - Array of PromptGroup IDs associated with the project
* @property {Date} [createdAt] - Date when the project was created (added by timestamps)
* @property {Date} [updatedAt] - Date when the project was last updated (added by timestamps)
*/
const projectSchema = new Schema(
{
name: {
type: String,
required: true,
index: true,
},
promptGroupIds: {
type: [Schema.Types.ObjectId],
ref: 'PromptGroup',
default: [],
},
},
{
timestamps: true,
},
);
module.exports = projectSchema;

View File

@@ -1,118 +0,0 @@
const mongoose = require('mongoose');
const { Constants } = require('librechat-data-provider');
const Schema = mongoose.Schema;
/**
* @typedef {Object} MongoPromptGroup
* @property {ObjectId} [_id] - MongoDB Document ID
* @property {string} name - The name of the prompt group
* @property {ObjectId} author - The author of the prompt group
* @property {ObjectId} [projectId=null] - The project ID of the prompt group
* @property {ObjectId} [productionId=null] - The project ID of the prompt group
* @property {string} authorName - The name of the author of the prompt group
* @property {number} [numberOfGenerations=0] - Number of generations the prompt group has
* @property {string} [oneliner=''] - Oneliner description of the prompt group
* @property {string} [category=''] - Category of the prompt group
* @property {string} [command] - Command for the prompt group
* @property {Date} [createdAt] - Date when the prompt group was created (added by timestamps)
* @property {Date} [updatedAt] - Date when the prompt group was last updated (added by timestamps)
*/
const promptGroupSchema = new Schema(
{
name: {
type: String,
required: true,
index: true,
},
numberOfGenerations: {
type: Number,
default: 0,
},
oneliner: {
type: String,
default: '',
},
category: {
type: String,
default: '',
index: true,
},
projectIds: {
type: [Schema.Types.ObjectId],
ref: 'Project',
index: true,
},
productionId: {
type: Schema.Types.ObjectId,
ref: 'Prompt',
required: true,
index: true,
},
author: {
type: Schema.Types.ObjectId,
ref: 'User',
required: true,
index: true,
},
authorName: {
type: String,
required: true,
},
command: {
type: String,
index: true,
validate: {
validator: function (v) {
return v === undefined || v === null || v === '' || /^[a-z0-9-]+$/.test(v);
},
message: (props) =>
`${props.value} is not a valid command. Only lowercase alphanumeric characters and highfins (') are allowed.`,
},
maxlength: [
Constants.COMMANDS_MAX_LENGTH,
`Command cannot be longer than ${Constants.COMMANDS_MAX_LENGTH} characters`,
],
},
},
{
timestamps: true,
},
);
const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema);
const promptSchema = new Schema(
{
groupId: {
type: Schema.Types.ObjectId,
ref: 'PromptGroup',
required: true,
index: true,
},
author: {
type: Schema.Types.ObjectId,
ref: 'User',
required: true,
},
prompt: {
type: String,
required: true,
},
type: {
type: String,
enum: ['text', 'chat'],
required: true,
},
},
{
timestamps: true,
},
);
const Prompt = mongoose.model('Prompt', promptSchema);
promptSchema.index({ createdAt: 1, updatedAt: 1 });
promptGroupSchema.index({ createdAt: 1, updatedAt: 1 });
module.exports = { Prompt, PromptGroup };

View File

@@ -1,29 +0,0 @@
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const mongoose = require('mongoose');
const roleSchema = new mongoose.Schema({
name: {
type: String,
required: true,
unique: true,
index: true,
},
[PermissionTypes.PROMPTS]: {
[Permissions.SHARED_GLOBAL]: {
type: Boolean,
default: false,
},
[Permissions.USE]: {
type: Boolean,
default: true,
},
[Permissions.CREATE]: {
type: Boolean,
default: true,
},
},
});
const Role = mongoose.model('Role', roleSchema);
module.exports = Role;

View File

@@ -7,9 +7,6 @@ const tokenSchema = new Schema({
required: true,
ref: 'user',
},
email: {
type: String,
},
token: {
type: String,
required: true,

View File

@@ -1,36 +1,5 @@
const mongoose = require('mongoose');
const { SystemRoles } = require('librechat-data-provider');
/**
* @typedef {Object} MongoSession
* @property {string} [refreshToken] - The refresh token
*/
/**
* @typedef {Object} MongoUser
* @property {ObjectId} [_id] - MongoDB Document ID
* @property {string} [name] - The user's name
* @property {string} [username] - The user's username, in lowercase
* @property {string} email - The user's email address
* @property {boolean} emailVerified - Whether the user's email is verified
* @property {string} [password] - The user's password, trimmed with 8-128 characters
* @property {string} [avatar] - The URL of the user's avatar
* @property {string} provider - The provider of the user's account (e.g., 'local', 'google')
* @property {string} [role='USER'] - The role of the user
* @property {string} [googleId] - Optional Google ID for the user
* @property {string} [facebookId] - Optional Facebook ID for the user
* @property {string} [openidId] - Optional OpenID ID for the user
* @property {string} [ldapId] - Optional LDAP ID for the user
* @property {string} [githubId] - Optional GitHub ID for the user
* @property {string} [discordId] - Optional Discord ID for the user
* @property {Array} [plugins=[]] - List of plugins used by the user
* @property {Array.<MongoSession>} [refreshToken] - List of sessions with refresh tokens
* @property {Date} [expiresAt] - Optional expiration date of the file
* @property {Date} [createdAt] - Date when the user was created (added by timestamps)
* @property {Date} [updatedAt] - Date when the user was last updated (added by timestamps)
*/
/** @type {MongooseSchema<MongoSession>} */
const Session = mongoose.Schema({
refreshToken: {
type: String,
@@ -38,7 +7,6 @@ const Session = mongoose.Schema({
},
});
/** @type {MongooseSchema<MongoUser>} */
const userSchema = mongoose.Schema(
{
name: {
@@ -79,7 +47,7 @@ const userSchema = mongoose.Schema(
},
role: {
type: String,
default: SystemRoles.USER,
default: 'USER',
},
googleId: {
type: String,
@@ -96,11 +64,6 @@ const userSchema = mongoose.Schema(
unique: true,
sparse: true,
},
ldapId: {
type: String,
unique: true,
sparse: true,
},
githubId: {
type: String,
unique: true,
@@ -118,10 +81,6 @@ const userSchema = mongoose.Schema(
refreshToken: {
type: [Session],
},
expiresAt: {
type: Date,
expires: 604800, // 7 days in seconds
},
},
{ timestamps: true },
);

View File

@@ -17,7 +17,6 @@ const tokenValues = {
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
'claude-3-opus': { prompt: 15, completion: 75 },
'claude-3-sonnet': { prompt: 3, completion: 15 },
'claude-3-5-sonnet': { prompt: 3, completion: 15 },
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
'claude-2.1': { prompt: 8, completion: 24 },
'claude-2': { prompt: 8, completion: 24 },

View File

@@ -48,13 +48,6 @@ describe('getValueKey', () => {
expect(getValueKey('gpt-4o-turbo')).toBe('gpt-4o');
expect(getValueKey('gpt-4o-0125')).toBe('gpt-4o');
});
it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => {
expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet');
expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet');
expect(getValueKey('claude-3-5-sonnet-turbo')).toBe('claude-3-5-sonnet');
expect(getValueKey('claude-3-5-sonnet-0125')).toBe('claude-3-5-sonnet');
});
});
describe('getMultiplier', () => {

View File

@@ -1,37 +1,28 @@
const bcrypt = require('bcryptjs');
const signPayload = require('~/server/services/signPayload');
const User = require('./User');
const hashPassword = async (password) => {
const hashedPassword = await new Promise((resolve, reject) => {
bcrypt.hash(password, 10, function (err, hash) {
if (err) {
reject(err);
} else {
resolve(hash);
}
});
});
return hashedPassword;
};
/**
* Retrieve a user by ID and convert the found user document to a plain object.
*
* @param {string} userId - The ID of the user to find and return as a plain object.
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
* @returns {Promise<Object>} A plain object representing the user document, or `null` if no user is found.
*/
const getUserById = async function (userId, fieldsToSelect = null) {
const query = User.findById(userId);
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
return await query.lean();
};
/**
* Search for a single user based on partial data and return matching user document as plain object.
* @param {Partial<MongoUser>} searchCriteria - The partial data to use for searching the user.
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
*/
const findUser = async function (searchCriteria, fieldsToSelect = null) {
const query = User.findOne(searchCriteria);
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
return await query.lean();
const getUser = async function (userId) {
return await User.findById(userId).lean();
};
/**
@@ -39,127 +30,17 @@ const findUser = async function (searchCriteria, fieldsToSelect = null) {
*
* @param {string} userId - The ID of the user to update.
* @param {Object} updateData - An object containing the properties to update.
* @returns {Promise<MongoUser>} The updated user document as a plain object, or `null` if no user is found.
* @returns {Promise<Object>} The updated user document as a plain object, or `null` if no user is found.
*/
const updateUser = async function (userId, updateData) {
const updateOperation = {
$set: updateData,
$unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
};
return await User.findByIdAndUpdate(userId, updateOperation, {
return await User.findByIdAndUpdate(userId, updateData, {
new: true,
runValidators: true,
}).lean();
};
/**
* Creates a new user, optionally with a TTL of 1 week.
* @param {MongoUser} data - The user data to be created, must contain user_id.
* @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`.
* @param {boolean} [returnUser=false] - Whether to disable the TTL. Defaults to `true`.
* @returns {Promise<ObjectId>} A promise that resolves to the created user document ID.
* @throws {Error} If a user with the same user_id already exists.
*/
const createUser = async (data, disableTTL = true, returnUser = false) => {
const userData = {
...data,
expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds
};
if (disableTTL) {
delete userData.expiresAt;
}
const user = await User.create(userData);
if (returnUser) {
return user.toObject();
}
return user._id;
};
/**
* Count the number of user documents in the collection based on the provided filter.
*
* @param {Object} [filter={}] - The filter to apply when counting the documents.
* @returns {Promise<number>} The count of documents that match the filter.
*/
const countUsers = async function (filter = {}) {
return await User.countDocuments(filter);
};
/**
* Delete a user by their unique ID.
*
* @param {string} userId - The ID of the user to delete.
* @returns {Promise<{ deletedCount: number }>} An object indicating the number of deleted documents.
*/
const deleteUserById = async function (userId) {
try {
const result = await User.deleteOne({ _id: userId });
if (result.deletedCount === 0) {
return { deletedCount: 0, message: 'No user found with that ID.' };
}
return { deletedCount: result.deletedCount, message: 'User was deleted successfully.' };
} catch (error) {
throw new Error('Error deleting user: ' + error.message);
}
};
const { SESSION_EXPIRY } = process.env ?? {};
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
/**
* Generates a JWT token for a given user.
*
* @param {MongoUser} user - ID of the user for whom the token is being generated.
* @returns {Promise<string>} A promise that resolves to a JWT token.
*/
const generateToken = async (user) => {
if (!user) {
throw new Error('No user provided');
}
return await signPayload({
payload: {
id: user._id,
username: user.username,
provider: user.provider,
email: user.email,
},
secret: process.env.JWT_SECRET,
expirationTime: expires / 1000,
});
};
/**
* Compares the provided password with the user's password.
*
* @param {MongoUser} user - the user to compare password for.
* @param {string} candidatePassword - The password to test against the user's password.
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the password matches.
*/
const comparePassword = async (user, candidatePassword) => {
if (!user) {
throw new Error('No user provided');
}
return new Promise((resolve, reject) => {
bcrypt.compare(candidatePassword, user.password, (err, isMatch) => {
if (err) {
reject(err);
}
resolve(isMatch);
});
});
};
module.exports = {
comparePassword,
deleteUserById,
generateToken,
getUserById,
countUsers,
createUser,
hashPassword,
updateUser,
findUser,
getUser,
};

View File

@@ -1,6 +1,6 @@
{
"name": "@librechat/backend",
"version": "0.7.4-rc1",
"version": "0.7.2",
"description": "",
"scripts": {
"start": "echo 'please run this from the root directory'",
@@ -40,7 +40,8 @@
"@keyv/redis": "^2.8.1",
"@langchain/community": "^0.0.46",
"@langchain/google-genai": "^0.0.11",
"@langchain/google-vertexai": "^0.0.17",
"@langchain/google-vertexai": "^0.0.5",
"agenda": "^5.0.0",
"axios": "^1.3.4",
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",
@@ -85,7 +86,6 @@
"passport-github2": "^0.1.12",
"passport-google-oauth20": "^2.0.0",
"passport-jwt": "^4.0.1",
"passport-ldapauth": "^3.0.1",
"passport-local": "^1.0.0",
"pino": "^8.12.1",
"sharp": "^0.32.6",
@@ -94,7 +94,6 @@
"ua-parser-js": "^1.0.36",
"winston": "^3.11.0",
"winston-daily-rotate-file": "^4.7.1",
"ws": "^8.17.0",
"zod": "^3.22.4"
},
"devDependencies": {

View File

@@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { saveMessage, getConvo } = require('~/models');
const { logger } = require('~/config');
const AskController = async (req, res, next, initializeClient, addTitle) => {
@@ -18,7 +18,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
logger.debug('[AskController]', { text, conversationId, ...endpointOption });
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
@@ -35,8 +34,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@@ -77,7 +74,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
@@ -85,7 +81,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
const { abortController, onStart } = createAbortController(req, res, getAbortData);
res.on('close', () => {
logger.debug('[AskController] Request closed');
@@ -109,12 +105,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
getReqData,
onStart,
abortController,
progressCallback,
progressOptions: {
onProgress: progressCallback.call(null, {
res,
text,
// parentMessageId: overrideParentMessageId || userMessageId,
},
parentMessageId: overrideParentMessageId || userMessageId,
}),
};
let response = await client.sendMessage(text, messageOptions);
@@ -125,7 +120,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
response.endpoint = endpointOption.endpoint;
const { conversation = {} } = await client.responsePromise;
const conversation = await getConvo(user, conversationId);
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
@@ -148,9 +143,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
await saveMessage({ ...response, user });
}
if (!client.skipSaveUserMessage) {
await saveMessage(userMessage);
}
await saveMessage(userMessage);
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {

View File

@@ -1,29 +1,45 @@
const crypto = require('crypto');
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const { Session, User } = require('~/models');
const {
registerUser,
resetPassword,
setAuthTokens,
requestPasswordReset,
} = require('~/server/services/AuthService');
const { Session, getUserById } = require('~/models');
const { logger } = require('~/config');
const registrationController = async (req, res) => {
try {
const response = await registerUser(req.body);
const { status, message } = response;
res.status(status).send({ message });
if (response.status === 200) {
const { status, user } = response;
let newUser = await User.findOne({ _id: user._id });
if (!newUser) {
newUser = new User(user);
await newUser.save();
}
const token = await setAuthTokens(user._id, res);
res.setHeader('Authorization', `Bearer ${token}`);
res.status(status).send({ user });
} else {
const { status, message } = response;
res.status(status).send({ message });
}
} catch (err) {
logger.error('[registrationController]', err);
return res.status(500).json({ message: err.message });
}
};
const getUserController = async (req, res) => {
return res.status(200).send(req.user);
};
const resetPasswordRequestController = async (req, res) => {
try {
const resetService = await requestPasswordReset(req);
const resetService = await requestPasswordReset(req.body.email);
if (resetService instanceof Error) {
return res.status(400).json(resetService);
} else {
@@ -61,7 +77,7 @@ const refreshController = async (req, res) => {
try {
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
const user = await getUserById(payload.id, '-password -__v');
const user = await User.findOne({ _id: payload.id });
if (!user) {
return res.status(401).redirect('/login');
}
@@ -70,7 +86,8 @@ const refreshController = async (req, res) => {
if (process.env.NODE_ENV === 'CI') {
const token = await setAuthTokens(userId, res);
return res.status(200).send({ token, user });
const userObj = user.toJSON();
return res.status(200).send({ token, user: userObj });
}
// Hash the refresh token
@@ -81,7 +98,8 @@ const refreshController = async (req, res) => {
const session = await Session.findOne({ user: userId, refreshTokenHash: hashedToken });
if (session && session.expiration > new Date()) {
const token = await setAuthTokens(userId, res, session._id);
res.status(200).send({ token, user });
const userObj = user.toJSON();
res.status(200).send({ token, user: userObj });
} else if (req?.query?.retry) {
// Retrying from a refresh token request that failed (401)
res.status(403).send('No session found');
@@ -97,6 +115,7 @@ const refreshController = async (req, res) => {
};
module.exports = {
getUserController,
refreshController,
registrationController,
resetPasswordController,

View File

@@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { saveMessage, getConvo } = require('~/models');
const { logger } = require('~/config');
const EditController = async (req, res, next, initializeClient) => {
@@ -27,7 +27,6 @@ const EditController = async (req, res, next, initializeClient) => {
});
let userMessage;
let userMessagePromise;
let promptTokens;
const sender = getResponseSender({
...endpointOption,
@@ -41,8 +40,6 @@ const EditController = async (req, res, next, initializeClient) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@@ -76,7 +73,6 @@ const EditController = async (req, res, next, initializeClient) => {
const getAbortData = () => ({
conversationId,
userMessagePromise,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
@@ -85,7 +81,7 @@ const EditController = async (req, res, next, initializeClient) => {
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
const { abortController, onStart } = createAbortController(req, res, getAbortData);
res.on('close', () => {
logger.debug('[EditController] Request closed');
@@ -116,15 +112,14 @@ const EditController = async (req, res, next, initializeClient) => {
getReqData,
onStart,
abortController,
progressCallback,
progressOptions: {
onProgress: progressCallback.call(null, {
res,
text,
// parentMessageId: overrideParentMessageId || userMessageId,
},
parentMessageId: overrideParentMessageId || userMessageId,
}),
});
const { conversation = {} } = await client.responsePromise;
const conversation = await getConvo(user, conversationId);
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View File

@@ -1,37 +1,11 @@
const {
Session,
Balance,
getFiles,
deleteFiles,
deleteConvos,
deletePresets,
deleteMessages,
deleteUserById,
} = require('~/models');
const { updateUserPluginsService } = require('~/server/services/UserService');
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { deleteAllSharedLinks } = require('~/models/Share');
const { Transaction } = require('~/models/Transaction');
const { logger } = require('~/config');
const getUserController = async (req, res) => {
res.status(200).send(req.user);
};
const deleteUserFiles = async (req) => {
try {
const userFiles = await getFiles({ user: req.user.id });
await processDeleteRequest({
req,
files: userFiles,
});
} catch (error) {
logger.error('[deleteUserFiles]', error);
}
};
const updateUserPluginsController = async (req, res) => {
const { user } = req;
const { pluginKey, action, auth, isAssistantTool } = req.body;
@@ -75,68 +49,11 @@ const updateUserPluginsController = async (req, res) => {
res.status(200).send();
} catch (err) {
logger.error('[updateUserPluginsController]', err);
return res.status(500).json({ message: 'Something went wrong.' });
}
};
const deleteUserController = async (req, res) => {
const { user } = req;
try {
await deleteMessages({ user: user.id }); // delete user messages
await Session.deleteMany({ user: user.id }); // delete user sessions
await Transaction.deleteMany({ user: user.id }); // delete user transactions
await deleteUserKey({ userId: user.id, all: true }); // delete user keys
await Balance.deleteMany({ user: user._id }); // delete user balances
await deletePresets(user.id); // delete user presets
/* TODO: Delete Assistant Threads */
await deleteConvos(user.id); // delete user convos
await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth
await deleteUserById(user.id); // delete user
await deleteAllSharedLinks(user.id); // delete user shared links
await deleteUserFiles(req); // delete user files
await deleteFiles(null, user.id); // delete database files in case of orphaned files from previous steps
/* TODO: queue job for cleaning actions and assistants of non-existant users */
logger.info(`User deleted account. Email: ${user.email} ID: ${user.id}`);
res.status(200).send({ message: 'User deleted' });
} catch (err) {
logger.error('[deleteUserController]', err);
return res.status(500).json({ message: 'Something went wrong.' });
}
};
const verifyEmailController = async (req, res) => {
try {
const verifyEmailService = await verifyEmail(req);
if (verifyEmailService instanceof Error) {
return res.status(400).json(verifyEmailService);
} else {
return res.status(200).json(verifyEmailService);
}
} catch (e) {
logger.error('[verifyEmailController]', e);
return res.status(500).json({ message: 'Something went wrong.' });
}
};
const resendVerificationController = async (req, res) => {
try {
const result = await resendVerificationEmail(req);
if (result instanceof Error) {
return res.status(400).json(result);
} else {
return res.status(200).json(result);
}
} catch (e) {
logger.error('[verifyEmailController]', e);
return res.status(500).json({ message: 'Something went wrong.' });
res.status(500).json({ message: err.message });
}
};
module.exports = {
getUserController,
deleteUserController,
verifyEmailController,
updateUserPluginsController,
resendVerificationController,
};

View File

@@ -20,7 +20,6 @@ const {
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
@@ -32,14 +31,15 @@ const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
const { handleAbortError } = require('~/server/middleware');
const ten_minutes = 1000 * 60 * 10;
/**
* @route POST /
* @desc Chat with an assistant
* @access Public
* @param {object} req - The request object, containing the request data.
* @param {object} req.body - The request payload.
* @param {Express.Request} req - The request object, containing the request data.
* @param {Express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
@@ -60,6 +60,30 @@ const chatV1 = async (req, res) => {
parentMessageId: _parentId = Constants.NO_PARENT,
} = req.body;
/** @type {Partial<TAssistantEndpoint>} */
const assistantsConfig = req.app.locals?.[endpoint];
if (assistantsConfig) {
const { supportedIds, excludedIds } = assistantsConfig;
const error = { message: 'Assistant not supported' };
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId: convoId,
messageId: v4(),
parentMessageId: _messageId,
error,
});
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId: convoId,
messageId: v4(),
parentMessageId: _messageId,
});
}
}
/** @type {OpenAIClient} */
let openai;
/** @type {string|undefined} - the current thread id */
@@ -287,7 +311,6 @@ const chatV1 = async (req, res) => {
});
openai = _openai;
await validateAuthor({ req, openai });
if (previousMessages.length) {
parentMessageId = previousMessages[previousMessages.length - 1].messageId;

View File

@@ -19,10 +19,9 @@ const {
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { createOnTextProgress } = require('~/server/services/AssistantService');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { StreamRunManager } = require('~/server/services/Runs');
const { getTransactions } = require('~/models/Transaction');
const checkBalance = require('~/models/checkBalance');
const { getConvo } = require('~/models/Conversation');
@@ -31,6 +30,8 @@ const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
const { handleAbortError } = require('~/server/middleware');
const ten_minutes = 1000 * 60 * 10;
/**
@@ -59,6 +60,30 @@ const chatV2 = async (req, res) => {
parentMessageId: _parentId = Constants.NO_PARENT,
} = req.body;
/** @type {Partial<TAssistantEndpoint>} */
const assistantsConfig = req.app.locals?.[endpoint];
if (assistantsConfig) {
const { supportedIds, excludedIds } = assistantsConfig;
const error = { message: 'Assistant not supported' };
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId: convoId,
messageId: v4(),
parentMessageId: _messageId,
error,
});
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId: convoId,
messageId: v4(),
parentMessageId: _messageId,
});
}
}
/** @type {OpenAIClient} */
let openai;
/** @type {string|undefined} - the current thread id */
@@ -284,7 +309,6 @@ const chatV2 = async (req, res) => {
});
openai = _openai;
await validateAuthor({ req, openai });
if (previousMessages.length) {
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
@@ -447,39 +471,7 @@ const chatV2 = async (req, res) => {
/** @type {RunResponse | typeof StreamRunManager | undefined} */
let response;
const processRun = async (retry = false) => {
if (endpoint === EModelEndpoint.azureAssistants) {
body.model = openai._options.model;
openai.attachedFileIds = attachedFileIds;
if (retry) {
response = await runAssistant({
openai,
thread_id,
run_id,
in_progress: openai.in_progress,
});
return;
}
/* NOTE:
* By default, a Run will use the model and tools configuration specified in Assistant object,
* but you can override most of these when creating the Run for added flexibility:
*/
const run = await createRun({
openai,
thread_id,
body,
});
run_id = run.id;
await cache.set(cacheKey, `${thread_id}:${run_id}`, ten_minutes);
sendInitialResponse();
// todo: retry logic
response = await runAssistant({ openai, thread_id, run_id });
return;
}
const processRun = async () => {
/** @type {{[AssistantStreamEvents.ThreadRunCreated]: (event: ThreadRunCreated) => Promise<void>}} */
const handlers = {
[AssistantStreamEvents.ThreadRunCreated]: async (event) => {
@@ -496,7 +488,6 @@ const chatV2 = async (req, res) => {
handlers,
thread_id,
attachedFileIds,
parentMessageId: userMessageId,
responseMessage: openai.responseMessage,
// streamOptions: {
@@ -509,7 +500,6 @@ const chatV2 = async (req, res) => {
});
response = streamRunManager;
response.text = streamRunManager.intermediateText;
};
await processRun();
@@ -532,7 +522,6 @@ const chatV2 = async (req, res) => {
/** @type {ResponseMessage} */
const responseMessage = {
...(response.responseMessage ?? response.finalMessage),
text: response.text,
parentMessageId: userMessageId,
conversationId,
user: req.user.id,

View File

@@ -1,10 +1,4 @@
const {
CacheKeys,
SystemRoles,
EModelEndpoint,
defaultOrderQuery,
defaultAssistantsVersion,
} = require('librechat-data-provider');
const { EModelEndpoint, CacheKeys, defaultAssistantsVersion } = require('librechat-data-provider');
const {
initializeClient: initAzureClient,
} = require('~/server/services/Endpoints/azureAssistants');
@@ -41,7 +35,6 @@ const getCurrentVersion = async (req, endpoint) => {
* Initializes the client with the current request and response objects and lists assistants
* according to the query parameters. This function abstracts the logic for non-Azure paths.
*
* @deprecated
* @async
* @param {object} params - The parameters object.
* @param {object} params.req - The request object, used for initializing the client.
@@ -50,65 +43,11 @@ const getCurrentVersion = async (req, endpoint) => {
* @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
*/
const _listAssistants = async ({ req, res, version, query }) => {
const listAssistants = async ({ req, res, version, query }) => {
const { openai } = await getOpenAIClient({ req, res, version });
return openai.beta.assistants.list(query);
};
/**
* Fetches all assistants based on provided query params, until `has_more` is `false`.
*
* @async
* @param {object} params - The parameters object.
* @param {object} params.req - The request object, used for initializing the client.
* @param {object} params.res - The response object, used for initializing the client.
* @param {string} params.version - The API version to use.
* @param {Omit<AssistantListParams, 'endpoint'>} params.query - The query parameters to list assistants (e.g., limit, order).
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
*/
const listAllAssistants = async ({ req, res, version, query }) => {
/** @type {{ openai: OpenAIClient }} */
const { openai } = await getOpenAIClient({ req, res, version });
const allAssistants = [];
let first_id;
let last_id;
let afterToken = query.after;
let hasMore = true;
while (hasMore) {
const response = await openai.beta.assistants.list({
...query,
after: afterToken,
});
const { body } = response;
allAssistants.push(...body.data);
hasMore = body.has_more;
if (!first_id) {
first_id = body.first_id;
}
if (hasMore) {
afterToken = body.last_id;
} else {
last_id = body.last_id;
}
}
return {
data: allAssistants,
body: {
data: allAssistants,
has_more: false,
first_id,
last_id,
},
};
};
/**
* Asynchronously lists assistants for Azure configured groups.
*
@@ -143,7 +82,7 @@ const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, que
/* The specified model is only necessary to
fetch assistants for the shared instance */
req.body.model = currentModelTuples[0][0];
promises.push(listAllAssistants({ req, res, version, query }));
promises.push(listAssistants({ req, res, version, query }));
}
const resolvedQueries = await Promise.all(promises);
@@ -194,27 +133,8 @@ async function getOpenAIClient({ req, res, endpointOption, initAppClient, overri
return result;
}
/**
* Returns a list of assistants.
* @param {object} params
* @param {object} params.req - Express Request
* @param {AssistantListParams} [params.req.query] - The assistant list parameters for pagination and sorting.
* @param {object} params.res - Express Response
* @param {string} [params.overrideEndpoint] - The endpoint to override the request endpoint.
* @returns {Promise<AssistantListResponse>} 200 - success response - application/json
*/
const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
const {
limit = 100,
order = 'desc',
after,
before,
endpoint,
} = req.query ?? {
endpoint: overrideEndpoint,
...defaultOrderQuery,
};
const fetchAssistants = async (req, res) => {
const { limit = 100, order = 'desc', after, before, endpoint } = req.query;
const version = await getCurrentVersion(req, endpoint);
const query = { limit, order, after, before };
@@ -222,47 +142,15 @@ const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
let body;
if (endpoint === EModelEndpoint.assistants) {
({ body } = await listAllAssistants({ req, res, version, query }));
({ body } = await listAssistants({ req, res, version, query }));
} else if (endpoint === EModelEndpoint.azureAssistants) {
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
}
if (req.user.role === SystemRoles.ADMIN) {
return body;
} else if (!req.app.locals[endpoint]) {
return body;
}
body.data = filterAssistants({
userId: req.user.id,
assistants: body.data,
assistantsConfig: req.app.locals[endpoint],
});
return body;
};
/**
* Filter assistants based on configuration.
*
* @param {object} params - The parameters object.
* @param {string} params.userId - The user ID to filter private assistants.
* @param {Assistant[]} params.assistants - The list of assistants to filter.
* @param {Partial<TAssistantEndpoint>} params.assistantsConfig - The assistant configuration.
* @returns {Assistant[]} - The filtered list of assistants.
*/
function filterAssistants({ assistants, userId, assistantsConfig }) {
const { supportedIds, excludedIds, privateAssistants } = assistantsConfig;
if (privateAssistants) {
return assistants.filter((assistant) => userId === assistant.metadata?.author);
} else if (supportedIds?.length) {
return assistants.filter((assistant) => supportedIds.includes(assistant.id));
} else if (excludedIds?.length) {
return assistants.filter((assistant) => !excludedIds.includes(assistant.id));
}
return assistants;
}
module.exports = {
getOpenAIClient,
fetchAssistants,

View File

@@ -1,9 +1,8 @@
const { FileContext } = require('librechat-data-provider');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { deleteAssistantActions } = require('~/server/services/ActionService');
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
const { uploadImageBuffer } = require('~/server/services/Files/process');
const { updateAssistant, getAssistants } = require('~/models/Assistant');
const { getOpenAIClient, fetchAssistants } = require('./helpers');
const { deleteFileByFilter } = require('~/models/File');
const { logger } = require('~/config');
@@ -41,11 +40,9 @@ const createAssistant = async (req, res) => {
};
const assistant = await openai.beta.assistants.create(assistantData);
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
if (azureModelIdentifier) {
assistant.model = azureModelIdentifier;
}
await promise;
logger.debug('/assistants/', assistant);
res.status(201).json(assistant);
} catch (error) {
@@ -64,6 +61,7 @@ const retrieveAssistant = async (req, res) => {
try {
/* NOTE: not actually being used right now */
const { openai } = await getOpenAIClient({ req, res });
const assistant_id = req.params.id;
const assistant = await openai.beta.assistants.retrieve(assistant_id);
res.json(assistant);
@@ -85,7 +83,6 @@ const retrieveAssistant = async (req, res) => {
const patchAssistant = async (req, res) => {
try {
const { openai } = await getOpenAIClient({ req, res });
await validateAuthor({ req, openai });
const assistant_id = req.params.id;
const { endpoint: _e, ...updateData } = req.body;
@@ -122,7 +119,6 @@ const patchAssistant = async (req, res) => {
const deleteAssistant = async (req, res) => {
try {
const { openai } = await getOpenAIClient({ req, res });
await validateAuthor({ req, openai });
const assistant_id = req.params.id;
const deletionStatus = await openai.beta.assistants.del(assistant_id);
@@ -145,7 +141,19 @@ const deleteAssistant = async (req, res) => {
*/
const listAssistants = async (req, res) => {
try {
const body = await fetchAssistants({ req, res });
const body = await fetchAssistants(req, res);
if (req.app.locals?.[req.query.endpoint]) {
/** @type {Partial<TAssistantEndpoint>} */
const assistantsConfig = req.app.locals[req.query.endpoint];
const { supportedIds, excludedIds } = assistantsConfig;
if (supportedIds?.length) {
body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id));
} else if (excludedIds?.length) {
body.data = body.data.filter((assistant) => !excludedIds.includes(assistant.id));
}
}
res.json(body);
} catch (error) {
logger.error('[/assistants] Error listing assistants', error);
@@ -187,7 +195,6 @@ const uploadAssistantAvatar = async (req, res) => {
let { metadata: _metadata = '{}' } = req.body;
const { openai } = await getOpenAIClient({ req, res });
await validateAuthor({ req, openai });
const image = await uploadImageBuffer({
req,
@@ -222,7 +229,7 @@ const uploadAssistantAvatar = async (req, res) => {
const promises = [];
promises.push(
updateAssistantDoc(
updateAssistant(
{ assistant_id },
{
avatar: {

View File

@@ -1,7 +1,5 @@
const { ToolCallTypes } = require('librechat-data-provider');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { validateAndUpdateTool } = require('~/server/services/ActionService');
const { updateAssistantDoc } = require('~/models/Assistant');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
@@ -39,11 +37,9 @@ const createAssistant = async (req, res) => {
};
const assistant = await openai.beta.assistants.create(assistantData);
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
if (azureModelIdentifier) {
assistant.model = azureModelIdentifier;
}
await promise;
logger.debug('/assistants/', assistant);
res.status(201).json(assistant);
} catch (error) {
@@ -62,7 +58,6 @@ const createAssistant = async (req, res) => {
* @returns {Promise<Assistant>} The updated assistant.
*/
const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
await validateAuthor({ req, openai });
const tools = [];
let hasFileSearch = false;

View File

@@ -1,22 +1,26 @@
const User = require('~/models/User');
const { setAuthTokens } = require('~/server/services/AuthService');
const { logger } = require('~/config');
const loginController = async (req, res) => {
try {
if (!req.user) {
const user = await User.findById(req.user._id);
// If user doesn't exist, return error
if (!user) {
// typeof user !== User) { // this doesn't seem to resolve the User type ??
return res.status(400).json({ message: 'Invalid credentials' });
}
const { password: _, __v, ...user } = req.user;
user.id = user._id.toString();
const token = await setAuthTokens(req.user._id, res);
const token = await setAuthTokens(user._id, res);
return res.status(200).send({ token, user });
} catch (err) {
logger.error('[loginController]', err);
return res.status(500).json({ message: 'Something went wrong' });
}
// Generic error messages are safer
return res.status(500).json({ message: 'Something went wrong' });
};
module.exports = {

View File

@@ -6,16 +6,16 @@ const axios = require('axios');
const express = require('express');
const passport = require('passport');
const mongoSanitize = require('express-mongo-sanitize');
const { jwtLogin, passportLogin } = require('~/strategies');
const { connectDb, indexSync } = require('~/lib/db');
const { isEnabled } = require('~/server/utils');
const { ldapLogin } = require('~/strategies');
const { logger } = require('~/config');
const validateImageRequest = require('./middleware/validateImageRequest');
const errorController = require('./controllers/ErrorController');
const { jwtLogin, passportLogin } = require('~/strategies');
const configureSocialLogins = require('./socialLogins');
const { connectDb, indexSync } = require('~/lib/db');
const AppService = require('./services/AppService');
const noIndex = require('./middleware/noIndex');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const routes = require('./routes');
const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {};
@@ -60,11 +60,6 @@ const startServer = async () => {
passport.use(await jwtLogin());
passport.use(passportLogin());
// LDAP Auth
if (process.env.LDAP_URL && process.env.LDAP_USER_SEARCH_BASE) {
passport.use(ldapLogin);
}
if (isEnabled(ALLOW_SOCIAL_LOGIN)) {
configureSocialLogins(app);
}
@@ -81,7 +76,6 @@ const startServer = async () => {
app.use('/api/convos', routes.convos);
app.use('/api/presets', routes.presets);
app.use('/api/prompts', routes.prompts);
app.use('/api/categories', routes.categories);
app.use('/api/tokenizer', routes.tokenizer);
app.use('/api/endpoints', routes.endpoints);
app.use('/api/balance', routes.balance);
@@ -92,10 +86,9 @@ const startServer = async () => {
app.use('/api/files', await routes.files.initialize());
app.use('/images/', validateImageRequest, routes.staticRoute);
app.use('/api/share', routes.share);
app.use('/api/roles', routes.roles);
app.use((req, res) => {
res.sendFile(path.join(app.locals.paths.dist, 'index.html'));
res.status(404).sendFile(path.join(app.locals.paths.dist, 'index.html'));
});
app.listen(port, host, () => {

View File

@@ -1,36 +1,31 @@
const { isAssistantsEndpoint } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
const clearPendingReq = require('~/cache/clearPendingReq');
const abortControllers = require('./abortControllers');
const { saveMessage, getConvo } = require('~/models');
const spendTokens = require('~/models/spendTokens');
const { abortRun } = require('./abortRun');
const { logger } = require('~/config');
async function abortMessage(req, res) {
let { abortKey, endpoint } = req.body;
let { abortKey, conversationId, endpoint } = req.body;
if (!abortKey && conversationId) {
abortKey = conversationId;
}
if (isAssistantsEndpoint(endpoint)) {
return await abortRun(req, res);
}
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
abortKey = conversationId;
}
if (!abortControllers.has(abortKey) && !res.headersSent) {
return res.status(204).send({ message: 'Request not found' });
}
const { abortController } = abortControllers.get(abortKey) ?? {};
if (!abortController) {
return res.status(204).send({ message: 'Request not found' });
}
const { abortController } = abortControllers.get(abortKey);
const finalEvent = await abortController.abortCompletion();
logger.info('[abortMessage] Aborted request', { abortKey });
logger.debug('[abortMessage] Aborted request', { abortKey });
abortControllers.delete(abortKey);
if (res.headersSent && finalEvent) {
@@ -55,32 +50,12 @@ const handleAbort = () => {
};
};
const createAbortController = (req, res, getAbortData, getReqData) => {
const createAbortController = (req, res, getAbortData) => {
const abortController = new AbortController();
const { endpointOption } = req.body;
abortController.getAbortData = function () {
return getAbortData();
};
/**
* @param {TMessage} userMessage
* @param {string} responseMessageId
*/
const onStart = (userMessage, responseMessageId) => {
const onStart = (userMessage) => {
sendMessage(res, { message: userMessage, created: true });
const abortKey = userMessage?.conversationId ?? req.user.id;
const prevRequest = abortControllers.get(abortKey);
if (prevRequest && prevRequest?.abortController) {
const data = prevRequest.abortController.getAbortData();
getReqData({ userMessage: data?.userMessage });
const addedAbortKey = `${abortKey}:${responseMessageId}`;
abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
res.on('finish', function () {
abortControllers.delete(addedAbortKey);
});
return;
}
abortControllers.set(abortKey, { abortController, ...endpointOption });
res.on('finish', function () {
@@ -90,8 +65,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
abortController.abortCompletion = async function () {
abortController.abort();
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
getAbortData();
const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData();
const completionTokens = await countTokens(responseData?.text ?? '');
const user = req.user.id;
@@ -115,20 +89,10 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
saveMessage({ ...responseMessage, user });
let conversation;
if (userMessagePromise) {
const resolved = await userMessagePromise;
conversation = resolved?.conversation;
}
if (!conversation) {
conversation = await getConvo(req.user.id, conversationId);
}
return {
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
title: await getConvoTitle(user, conversationId),
final: true,
conversation,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage,
};

View File

@@ -1,7 +1,6 @@
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
const { deleteMessages } = require('~/models/Message');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendMessage } = require('~/server/utils');
@@ -67,19 +66,13 @@ async function abortRun(req, res) {
logger.error('[abortRun] Error fetching or processing run', error);
}
/* TODO: a reconciling strategy between the existing intermediate message would be more optimal than deleting it */
await deleteMessages({
user: req.user.id,
unfinished: true,
conversationId,
});
runMessages = await checkMessageGaps({
openai,
run_id,
endpoint,
thread_id,
conversationId,
run_id,
latestMessageId,
conversationId,
});
const finalEvent = {

View File

@@ -1,43 +0,0 @@
const { v4 } = require('uuid');
const { handleAbortError } = require('~/server/middleware/abortMiddleware');
/**
* Checks if the assistant is supported or excluded
* @param {object} req - Express Request
* @param {object} req.body - The request payload.
* @param {object} res - Express Response
* @param {function} next - Express next middleware function.
* @returns {Promise<void>}
*/
const validateAssistant = async (req, res, next) => {
const { endpoint, conversationId, assistant_id, messageId } = req.body;
/** @type {Partial<TAssistantEndpoint>} */
const assistantsConfig = req.app.locals?.[endpoint];
if (!assistantsConfig) {
return next();
}
const { supportedIds, excludedIds } = assistantsConfig;
const error = { message: 'Assistant not supported' };
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId,
messageId: v4(),
parentMessageId: messageId,
error,
});
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
return await handleAbortError(res, req, error, {
sender: 'System',
conversationId,
messageId: v4(),
parentMessageId: messageId,
});
}
return next();
};
module.exports = validateAssistant;

View File

@@ -1,43 +0,0 @@
const { SystemRoles } = require('librechat-data-provider');
const { getAssistant } = require('~/models/Assistant');
/**
* Checks if the assistant is supported or excluded
* @param {object} params
* @param {object} params.req - Express Request
* @param {object} params.req.body - The request payload.
* @param {string} params.overrideEndpoint - The override endpoint
* @param {string} params.overrideAssistantId - The override assistant ID
* @param {OpenAIClient} params.openai - OpenAI API Client
* @returns {Promise<void>}
*/
const validateAuthor = async ({ req, openai, overrideEndpoint, overrideAssistantId }) => {
if (req.user.role === SystemRoles.ADMIN) {
return;
}
const endpoint = overrideEndpoint ?? req.body.endpoint ?? req.query.endpoint;
const assistant_id =
overrideAssistantId ?? req.params.id ?? req.body.assistant_id ?? req.query.assistant_id;
/** @type {Partial<TAssistantEndpoint>} */
const assistantsConfig = req.app.locals?.[endpoint];
if (!assistantsConfig) {
return;
}
if (!assistantsConfig.privateAssistants) {
return;
}
const assistantDoc = await getAssistant({ assistant_id, user: req.user.id });
if (assistantDoc) {
return;
}
const assistant = await openai.beta.assistants.retrieve(assistant_id);
if (req.user.id !== assistant?.metadata?.author) {
throw new Error(`Assistant ${assistant_id} is not authored by the user.`);
}
};
module.exports = validateAuthor;

View File

@@ -1,28 +0,0 @@
const { SystemRoles } = require('librechat-data-provider');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
/**
* Checks if the user can delete their account
*
* @async
* @function
* @param {Object} req - Express request object
* @param {Object} res - Express response object
* @param {Function} next - Next middleware function
*
* @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if the user can delete their account
*/
const canDeleteAccount = async (req, res, next = () => {}) => {
const { user } = req;
const { ALLOW_ACCOUNT_DELETION = true } = process.env;
if (user?.role === SystemRoles.ADMIN || isEnabled(ALLOW_ACCOUNT_DELETION)) {
return next();
} else {
logger.error(`[User] [Delete Account] [User cannot delete account] [User: ${user?.id}]`);
return res.status(403).send({ message: 'You do not have permission to delete this account' });
}
};
module.exports = canDeleteAccount;

View File

@@ -1,13 +1,15 @@
const Keyv = require('keyv');
const uap = require('ua-parser-js');
const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, removePorts } = require('~/server/utils');
const keyvMongo = require('~/cache/keyvMongo');
const { isEnabled, removePorts } = require('../utils');
const keyvRedis = require('~/cache/keyvRedis');
const denyRequest = require('./denyRequest');
const { getLogStores } = require('~/cache');
const { findUser } = require('~/models');
const User = require('~/models/User');
const banCache = new Keyv({ store: keyvMongo, namespace: ViolationTypes.BAN, ttl: 0 });
const banCache = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: ViolationTypes.BAN, ttl: 0 });
const message = 'Your account has been temporarily banned due to violations of our service.';
/**
@@ -55,7 +57,7 @@ const checkBan = async (req, res, next = () => {}) => {
let userId = req.user?.id ?? req.user?._id ?? null;
if (!userId && req?.body?.email) {
const user = await findUser({ email: req.body.email }, '_id');
const user = await User.findOne({ email: req.body.email }, '_id').lean();
userId = user?._id ? user._id.toString() : userId;
}

View File

@@ -1,45 +1,45 @@
const validatePasswordReset = require('./validatePasswordReset');
const abortMiddleware = require('./abortMiddleware');
const checkBan = require('./checkBan');
const checkDomainAllowed = require('./checkDomainAllowed');
const uaParser = require('./uaParser');
const setHeaders = require('./setHeaders');
const loginLimiter = require('./loginLimiter');
const validateModel = require('./validateModel');
const requireJwtAuth = require('./requireJwtAuth');
const uploadLimiters = require('./uploadLimiters');
const registerLimiter = require('./registerLimiter');
const messageLimiters = require('./messageLimiters');
const requireLocalAuth = require('./requireLocalAuth');
const validateEndpoint = require('./validateEndpoint');
const concurrentLimiter = require('./concurrentLimiter');
const validateMessageReq = require('./validateMessageReq');
const buildEndpointOption = require('./buildEndpointOption');
const validateRegistration = require('./validateRegistration');
const validateImageRequest = require('./validateImageRequest');
const buildEndpointOption = require('./buildEndpointOption');
const validateMessageReq = require('./validateMessageReq');
const checkDomainAllowed = require('./checkDomainAllowed');
const concurrentLimiter = require('./concurrentLimiter');
const validateEndpoint = require('./validateEndpoint');
const requireLocalAuth = require('./requireLocalAuth');
const canDeleteAccount = require('./canDeleteAccount');
const requireLdapAuth = require('./requireLdapAuth');
const abortMiddleware = require('./abortMiddleware');
const requireJwtAuth = require('./requireJwtAuth');
const validateModel = require('./validateModel');
const moderateText = require('./moderateText');
const setHeaders = require('./setHeaders');
const limiters = require('./limiters');
const uaParser = require('./uaParser');
const checkBan = require('./checkBan');
const noIndex = require('./noIndex');
const roles = require('./roles');
const importLimiters = require('./importLimiters');
module.exports = {
...uploadLimiters,
...abortMiddleware,
...limiters,
...roles,
noIndex,
...messageLimiters,
checkBan,
uaParser,
setHeaders,
moderateText,
validateModel,
loginLimiter,
requireJwtAuth,
requireLdapAuth,
registerLimiter,
requireLocalAuth,
canDeleteAccount,
validateEndpoint,
concurrentLimiter,
checkDomainAllowed,
validateMessageReq,
buildEndpointOption,
validateRegistration,
validateImageRequest,
validatePasswordReset,
validateModel,
moderateText,
noIndex,
...importLimiters,
checkDomainAllowed,
};

View File

@@ -1,22 +0,0 @@
const createTTSLimiters = require('./ttsLimiters');
const createSTTLimiters = require('./sttLimiters');
const loginLimiter = require('./loginLimiter');
const importLimiters = require('./importLimiters');
const uploadLimiters = require('./uploadLimiters');
const registerLimiter = require('./registerLimiter');
const messageLimiters = require('./messageLimiters');
const verifyEmailLimiter = require('./verifyEmailLimiter');
const resetPasswordLimiter = require('./resetPasswordLimiter');
module.exports = {
...uploadLimiters,
...importLimiters,
...messageLimiters,
loginLimiter,
registerLimiter,
createTTSLimiters,
createSTTLimiters,
verifyEmailLimiter,
resetPasswordLimiter,
};

View File

@@ -1,35 +0,0 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');
const {
RESET_PASSWORD_WINDOW = 2,
RESET_PASSWORD_MAX = 2,
RESET_PASSWORD_VIOLATION_SCORE: score,
} = process.env;
const windowMs = RESET_PASSWORD_WINDOW * 60 * 1000;
const max = RESET_PASSWORD_MAX;
const windowInMinutes = windowMs / 60000;
const message = `Too many attempts, please try again after ${windowInMinutes} minute(s)`;
const handler = async (req, res) => {
const type = ViolationTypes.RESET_PASSWORD_LIMIT;
const errorMessage = {
type,
max,
windowInMinutes,
};
await logViolation(req, res, type, errorMessage, score);
return res.status(429).json({ message });
};
const resetPasswordLimiter = rateLimit({
windowMs,
max,
handler,
keyGenerator: removePorts,
});
module.exports = resetPasswordLimiter;

View File

@@ -1,68 +0,0 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');
const getEnvironmentVariables = () => {
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
const sttIpMax = STT_IP_MAX;
const sttIpWindowInMinutes = sttIpWindowMs / 60000;
const sttUserWindowMs = STT_USER_WINDOW * 60 * 1000;
const sttUserMax = STT_USER_MAX;
const sttUserWindowInMinutes = sttUserWindowMs / 60000;
return {
sttIpWindowMs,
sttIpMax,
sttIpWindowInMinutes,
sttUserWindowMs,
sttUserMax,
sttUserWindowInMinutes,
};
};
const createSTTHandler = (ip = true) => {
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
const type = ViolationTypes.STT_LIMIT;
const errorMessage = {
type,
max: ip ? sttIpMax : sttUserMax,
limiter: ip ? 'ip' : 'user',
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many STT requests. Try again later' });
};
};
const createSTTLimiters = () => {
const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables();
const sttIpLimiter = rateLimit({
windowMs: sttIpWindowMs,
max: sttIpMax,
handler: createSTTHandler(),
});
const sttUserLimiter = rateLimit({
windowMs: sttUserWindowMs,
max: sttUserMax,
handler: createSTTHandler(false),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
});
return { sttIpLimiter, sttUserLimiter };
};
module.exports = createSTTLimiters;

View File

@@ -1,68 +0,0 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');
const getEnvironmentVariables = () => {
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
const ttsIpMax = TTS_IP_MAX;
const ttsIpWindowInMinutes = ttsIpWindowMs / 60000;
const ttsUserWindowMs = TTS_USER_WINDOW * 60 * 1000;
const ttsUserMax = TTS_USER_MAX;
const ttsUserWindowInMinutes = ttsUserWindowMs / 60000;
return {
ttsIpWindowMs,
ttsIpMax,
ttsIpWindowInMinutes,
ttsUserWindowMs,
ttsUserMax,
ttsUserWindowInMinutes,
};
};
const createTTSHandler = (ip = true) => {
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
getEnvironmentVariables();
return async (req, res) => {
const type = ViolationTypes.TTS_LIMIT;
const errorMessage = {
type,
max: ip ? ttsIpMax : ttsUserMax,
limiter: ip ? 'ip' : 'user',
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
};
await logViolation(req, res, type, errorMessage);
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
};
};
const createTTSLimiters = () => {
const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables();
const ttsIpLimiter = rateLimit({
windowMs: ttsIpWindowMs,
max: ttsIpMax,
handler: createTTSHandler(),
});
const ttsUserLimiter = rateLimit({
windowMs: ttsUserWindowMs,
max: ttsUserMax,
handler: createTTSHandler(false),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
});
return { ttsIpLimiter, ttsUserLimiter };
};
module.exports = createTTSLimiters;

View File

@@ -1,35 +0,0 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');
const {
VERIFY_EMAIL_WINDOW = 2,
VERIFY_EMAIL_MAX = 2,
VERIFY_EMAIL_VIOLATION_SCORE: score,
} = process.env;
const windowMs = VERIFY_EMAIL_WINDOW * 60 * 1000;
const max = VERIFY_EMAIL_MAX;
const windowInMinutes = windowMs / 60000;
const message = `Too many attempts, please try again after ${windowInMinutes} minute(s)`;
const handler = async (req, res) => {
const type = ViolationTypes.VERIFY_EMAIL_LIMIT;
const errorMessage = {
type,
max,
windowInMinutes,
};
await logViolation(req, res, type, errorMessage, score);
return res.status(429).json({ message });
};
const verifyEmailLimiter = rateLimit({
windowMs,
max,
handler,
keyGenerator: removePorts,
});
module.exports = verifyEmailLimiter;

View File

@@ -1,6 +1,6 @@
const rateLimit = require('express-rate-limit');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');
const { logViolation } = require('../../cache');
const { removePorts } = require('../utils');
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
const windowMs = LOGIN_WINDOW * 60 * 1000;

View File

@@ -1,6 +1,6 @@
const rateLimit = require('express-rate-limit');
const denyRequest = require('~/server/middleware/denyRequest');
const { logViolation } = require('~/cache');
const { logViolation } = require('../../cache');
const denyRequest = require('./denyRequest');
const {
MESSAGE_IP_MAX = 40,

View File

@@ -1,6 +1,6 @@
const rateLimit = require('express-rate-limit');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');
const { logViolation } = require('../../cache');
const { removePorts } = require('../utils');
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
const windowMs = REGISTER_WINDOW * 60 * 1000;

View File

@@ -1,22 +0,0 @@
const passport = require('passport');
const requireLdapAuth = (req, res, next) => {
passport.authenticate('ldapauth', (err, user, info) => {
if (err) {
console.log({
title: '(requireLdapAuth) Error at passport.authenticate',
parameters: [{ name: 'error', value: err }],
});
return next(err);
}
if (!user) {
console.log({
title: '(requireLdapAuth) Error: No user',
});
return res.status(404).send(info);
}
req.user = user;
next();
})(req, res, next);
};
module.exports = requireLdapAuth;

View File

@@ -21,13 +21,7 @@ const requireLocalAuth = (req, res, next) => {
log({
title: '(requireLocalAuth) Error: No user',
});
return res.status(404).send(info);
}
if (info && info.message) {
log({
title: '(requireLocalAuth) Error: ' + info.message,
});
return res.status(422).send({ message: info.message });
return res.status(422).send(info);
}
req.user = user;
next();

View File

@@ -1,14 +0,0 @@
const { SystemRoles } = require('librechat-data-provider');
function checkAdmin(req, res, next) {
try {
if (req.user.role !== SystemRoles.ADMIN) {
return res.status(403).json({ message: 'Forbidden' });
}
next();
} catch (error) {
res.status(500).json({ message: 'Internal Server Error' });
}
}
module.exports = checkAdmin;

View File

@@ -1,52 +0,0 @@
const { SystemRoles } = require('librechat-data-provider');
const { getRoleByName } = require('~/models/Role');
/**
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
*
* @param {PermissionTypes} permissionType - The type of permission to check.
* @param {Permissions[]} permissions - The list of specific permissions to check.
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
* @returns {Function} Express middleware function.
*/
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
return async (req, res, next) => {
try {
const { user } = req;
if (!user) {
return res.status(401).json({ message: 'Authorization required' });
}
if (user.role === SystemRoles.ADMIN) {
return next();
}
const role = await getRoleByName(user.role);
if (role && role[permissionType]) {
const hasAnyPermission = permissions.some((permission) => {
if (role[permissionType][permission]) {
return true;
}
if (bodyProps[permission] && req.body) {
return bodyProps[permission].some((prop) =>
Object.prototype.hasOwnProperty.call(req.body, prop),
);
}
return false;
});
if (hasAnyPermission) {
return next();
}
}
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
} catch (error) {
return res.status(500).json({ message: `Server error: ${error.message}` });
}
};
};
module.exports = generateCheckAccess;

View File

@@ -1,7 +0,0 @@
const checkAdmin = require('./checkAdmin');
const generateCheckAccess = require('./generateCheckAccess');
module.exports = {
checkAdmin,
generateCheckAccess,
};

View File

@@ -1,4 +1,4 @@
const { getConvo } = require('~/models');
const { getConvo } = require('../../models');
// Middleware to validate conversationId and user relationship
const validateMessageReq = async (req, res, next) => {

View File

@@ -1,13 +0,0 @@
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
function validatePasswordReset(req, res, next) {
if (isEnabled(process.env.ALLOW_PASSWORD_RESET)) {
next();
} else {
logger.warn(`Password reset attempt while not allowed. IP: ${req.ip}`);
res.status(403).send('Password reset is not allowed.');
}
}
module.exports = validatePasswordReset;

View File

@@ -1,7 +1,6 @@
const { isEnabled } = require('~/server/utils');
function validateRegistration(req, res, next) {
if (isEnabled(process.env.ALLOW_REGISTRATION)) {
const setting = process.env.ALLOW_REGISTRATION?.toLowerCase();
if (setting === 'true') {
next();
} else {
res.status(403).send('Registration is not allowed.');

View File

@@ -25,12 +25,6 @@ afterEach(() => {
delete process.env.DOMAIN_SERVER;
delete process.env.ALLOW_REGISTRATION;
delete process.env.ALLOW_SOCIAL_LOGIN;
delete process.env.ALLOW_PASSWORD_RESET;
delete process.env.LDAP_URL;
delete process.env.LDAP_BIND_DN;
delete process.env.LDAP_BIND_CREDENTIALS;
delete process.env.LDAP_USER_SEARCH_BASE;
delete process.env.LDAP_SEARCH_FILTER;
});
//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why.
@@ -56,12 +50,6 @@ describe.skip('GET /', () => {
process.env.DOMAIN_SERVER = 'http://test-server.com';
process.env.ALLOW_REGISTRATION = 'true';
process.env.ALLOW_SOCIAL_LOGIN = 'true';
process.env.ALLOW_PASSWORD_RESET = 'true';
process.env.LDAP_URL = 'Test LDAP URL';
process.env.LDAP_BIND_DN = 'Test LDAP Bind DN';
process.env.LDAP_BIND_CREDENTIALS = 'Test LDAP Bind Credentials';
process.env.LDAP_USER_SEARCH_BASE = 'Test LDAP User Search Base';
process.env.LDAP_SEARCH_FILTER = 'Test LDAP Search Filter';
const response = await request(app).get('/');
@@ -76,11 +64,9 @@ describe.skip('GET /', () => {
openidLoginEnabled: true,
openidLabel: 'Test OpenID',
openidImageUrl: 'http://test-server.com',
ldapLoginEnabled: true,
serverDomain: 'http://test-server.com',
emailLoginEnabled: 'true',
registrationEnabled: 'true',
passwordResetEnabled: 'true',
socialLoginEnabled: 'true',
});
});

View File

@@ -1,6 +1,6 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
const { initializeClient } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
@@ -20,7 +20,7 @@ router.post(
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
await AskController(req, res, next, initializeClient);
},
);

View File

@@ -2,9 +2,9 @@ const express = require('express');
const throttle = require('lodash/throttle');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const { saveMessage } = require('~/models');
const {
handleAbort,
createAbortController,
@@ -41,7 +41,6 @@ router.post(
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
@@ -59,8 +58,6 @@ router.post(
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@@ -109,11 +106,7 @@ router.post(
const pluginMap = new Map();
const onAgentAction = async (action, runId) => {
pluginMap.set(runId, action.tool);
sendIntermediateMessage(res, {
plugins,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
sendIntermediateMessage(res, { plugins });
};
const onToolStart = async (tool, input, runId, parentRunId) => {
@@ -131,11 +124,7 @@ router.post(
}
const extraTokens = ':::plugin:::\n';
plugins.push(latestPlugin);
sendIntermediateMessage(
res,
{ plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId },
extraTokens,
);
sendIntermediateMessage(res, { plugins }, extraTokens);
};
const onToolEnd = async (output, runId) => {
@@ -151,10 +140,14 @@ router.post(
}
};
const onChainEnd = () => {
saveMessage({ ...userMessage, user });
sendIntermediateMessage(res, { plugins });
};
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
@@ -162,23 +155,12 @@ router.post(
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
const { abortController, onStart } = createAbortController(req, res, getAbortData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
const onChainEnd = () => {
if (!client.skipSaveUserMessage) {
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, {
plugins,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
};
let response = await client.sendMessage(text, {
user,
conversationId,
@@ -192,13 +174,12 @@ router.post(
onStart,
getPartialText,
...endpointOption,
progressCallback,
progressOptions: {
onProgress: progressCallback.call(null, {
res,
text,
// parentMessageId: overrideParentMessageId || userMessageId,
parentMessageId: overrideParentMessageId || userMessageId,
plugins,
},
}),
abortController,
});
@@ -211,14 +192,10 @@ router.post(
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: conversation.title,
title: await getConvoTitle(user, conversationId),
final: true,
conversation,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});

View File

@@ -4,7 +4,7 @@ const { encryptMetadata, domainParser } = require('~/server/services/ActionServi
const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { updateAssistantDoc, getAssistant } = require('~/models/Assistant');
const { updateAssistant, getAssistant } = require('~/models/Assistant');
const { logger } = require('~/config');
const router = express.Router();
@@ -109,7 +109,7 @@ router.post('/:assistant_id', async (req, res) => {
let updatedAssistant = await openai.beta.assistants.update(assistant_id, { tools });
const promises = [];
promises.push(
updateAssistantDoc(
updateAssistant(
{ assistant_id },
{
actions,
@@ -186,7 +186,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
const promises = [];
promises.push(
updateAssistantDoc(
updateAssistant(
{ assistant_id },
{
actions: updatedActions,

View File

@@ -8,7 +8,6 @@ const {
// validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const validateAssistant = require('~/server/middleware/assistants/validate');
const chatController = require('~/server/controllers/assistants/chatV1');
router.post('/abort', handleAbort());
@@ -21,6 +20,6 @@ router.post('/abort', handleAbort());
* @param {express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);
router.post('/', validateModel, buildEndpointOption, setHeaders, chatController);
module.exports = router;

View File

@@ -8,7 +8,6 @@ const {
// validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
const validateAssistant = require('~/server/middleware/assistants/validate');
const chatController = require('~/server/controllers/assistants/chatV2');
router.post('/abort', handleAbort());
@@ -21,6 +20,6 @@ router.post('/abort', handleAbort());
* @param {express.Response} res - The response object, used to send back a response.
* @returns {void}
*/
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);
router.post('/', validateModel, buildEndpointOption, setHeaders, chatController);
module.exports = router;

View File

@@ -1,45 +1,29 @@
const express = require('express');
const {
resetPasswordRequestController,
resetPasswordController,
refreshController,
registrationController,
resetPasswordController,
resetPasswordRequestController,
} = require('~/server/controllers/AuthController');
const { loginController } = require('~/server/controllers/auth/LoginController');
const { logoutController } = require('~/server/controllers/auth/LogoutController');
} = require('../controllers/AuthController');
const { loginController } = require('../controllers/auth/LoginController');
const { logoutController } = require('../controllers/auth/LogoutController');
const {
checkBan,
loginLimiter,
requireJwtAuth,
registerLimiter,
requireLdapAuth,
requireJwtAuth,
requireLocalAuth,
resetPasswordLimiter,
validateRegistration,
validatePasswordReset,
} = require('~/server/middleware');
} = require('../middleware');
const router = express.Router();
const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
//Local
router.post('/logout', requireJwtAuth, logoutController);
router.post(
'/login',
loginLimiter,
checkBan,
ldapAuth ? requireLdapAuth : requireLocalAuth,
loginController,
);
router.post('/login', loginLimiter, checkBan, requireLocalAuth, loginController);
router.post('/refresh', refreshController);
router.post('/register', registerLimiter, checkBan, validateRegistration, registrationController);
router.post(
'/requestPasswordReset',
resetPasswordLimiter,
checkBan,
validatePasswordReset,
resetPasswordRequestController,
);
router.post('/resetPassword', checkBan, validatePasswordReset, resetPasswordController);
router.post('/requestPasswordReset', resetPasswordRequestController);
router.post('/resetPassword', resetPasswordController);
module.exports = router;

View File

@@ -1,15 +0,0 @@
const express = require('express');
const router = express.Router();
const { requireJwtAuth } = require('~/server/middleware');
const { getCategories } = require('~/models/Categories');
router.get('/', requireJwtAuth, async (req, res) => {
try {
const categories = await getCategories();
res.status(200).send(categories);
} catch (error) {
res.status(500).send({ message: 'Failed to retrieve categories', error: error.message });
}
});
module.exports = router;

View File

@@ -1,39 +1,18 @@
const express = require('express');
const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider');
const { getProjectByName } = require('~/models/Project');
const { defaultSocialLogins } = require('librechat-data-provider');
const { isEnabled } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
const router = express.Router();
const emailLoginEnabled =
process.env.ALLOW_EMAIL_LOGIN === undefined || isEnabled(process.env.ALLOW_EMAIL_LOGIN);
const passwordResetEnabled = isEnabled(process.env.ALLOW_PASSWORD_RESET);
const sharedLinksEnabled =
process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS);
const publicSharedLinksEnabled =
sharedLinksEnabled &&
(process.env.ALLOW_SHARED_LINKS_PUBLIC === undefined ||
isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC));
router.get('/', async function (req, res) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
if (cachedStartupConfig) {
res.send(cachedStartupConfig);
return;
}
const isBirthday = () => {
const today = new Date();
return today.getMonth() === 1 && today.getDate() === 11;
};
const instanceProject = await getProjectByName('instance', '_id');
const ldapLoginEnabled = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
try {
/** @type {TStartupConfig} */
const payload = {
@@ -51,17 +30,15 @@ router.get('/', async function (req, res) {
!!process.env.OPENID_SESSION_SECRET,
openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
openidImageUrl: process.env.OPENID_IMAGE_URL,
ldapLoginEnabled,
serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
emailLoginEnabled,
registrationEnabled: !ldapLoginEnabled && isEnabled(process.env.ALLOW_REGISTRATION),
registrationEnabled: isEnabled(process.env.ALLOW_REGISTRATION),
socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN),
emailEnabled:
(!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
!!process.env.EMAIL_USERNAME &&
!!process.env.EMAIL_PASSWORD &&
!!process.env.EMAIL_FROM,
passwordResetEnabled,
checkBalance: isEnabled(process.env.CHECK_BALANCE),
showBirthdayIcon:
isBirthday() ||
@@ -70,17 +47,12 @@ router.get('/', async function (req, res) {
helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai',
interface: req.app.locals.interfaceConfig,
modelSpecs: req.app.locals.modelSpecs,
sharedLinksEnabled,
publicSharedLinksEnabled,
analyticsGtmId: process.env.ANALYTICS_GTM_ID,
instanceProjectId: instanceProject._id.toString(),
};
if (typeof process.env.CUSTOM_FOOTER === 'string') {
payload.customFooter = process.env.CUSTOM_FOOTER;
}
await cache.set(CacheKeys.STARTUP_CONFIG, payload);
return res.status(200).send(payload);
} catch (err) {
logger.error('Error in startup config', err);

View File

@@ -3,11 +3,12 @@ const express = require('express');
const { CacheKeys } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { getConvosByPage, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
const { IMPORT_CONVERSATION_JOB_NAME } = require('~/server/utils/import/jobDefinition');
const { storage, importFileFilter } = require('~/server/routes/files/multer');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { forkConversation } = require('~/server/utils/import/fork');
const { importConversations } = require('~/server/utils/import');
const { createImportLimiters } = require('~/server/middleware');
const jobScheduler = require('~/server/utils/jobScheduler');
const getLogStores = require('~/cache/getLogStores');
const { sleep } = require('~/server/utils');
const { logger } = require('~/config');
@@ -128,9 +129,10 @@ router.post(
upload.single('file'),
async (req, res) => {
try {
/* TODO: optimize to return imported conversations and add manually */
await importConversations({ filepath: req.file.path, requestUserId: req.user.id });
res.status(201).json({ message: 'Conversation(s) imported successfully' });
const filepath = req.file.path;
const job = await jobScheduler.now(IMPORT_CONVERSATION_JOB_NAME, filepath, req.user.id);
res.status(201).json({ message: 'Import started', jobId: job.id });
} catch (error) {
logger.error('Error processing file', error);
res.status(500).send('Error processing file');
@@ -167,4 +169,24 @@ router.post('/fork', async (req, res) => {
}
});
// Get the status of an import job for polling
router.get('/import/jobs/:jobId', async (req, res) => {
try {
const { jobId } = req.params;
const { userId, ...jobStatus } = await jobScheduler.getJobStatus(jobId);
if (!jobStatus) {
return res.status(404).json({ message: 'Job not found.' });
}
if (userId !== req.user.id) {
return res.status(403).json({ message: 'Unauthorized' });
}
res.json(jobStatus);
} catch (error) {
logger.error('Error getting job details', error);
res.status(500).send('Error getting job details');
}
});
module.exports = router;

View File

@@ -13,7 +13,7 @@ const {
} = require('~/server/middleware');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage } = require('~/models');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
@@ -49,7 +49,6 @@ router.post(
});
let userMessage;
let userMessagePromise;
let promptTokens;
const sender = getResponseSender({
...endpointOption,
@@ -69,8 +68,6 @@ router.post(
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@@ -106,23 +103,29 @@ router.post(
},
});
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start) {
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, { plugin });
// logger.debug('PLUGIN ACTION', formattedAction);
};
const onChainEnd = (data) => {
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage({ ...userMessage, user });
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
sendIntermediateMessage(res, { plugin });
// logger.debug('CHAIN END', plugin.outputs);
};
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
@@ -130,27 +133,12 @@ router.post(
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
const { abortController, onStart } = createAbortController(req, res, getAbortData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start && !client.skipSaveUserMessage) {
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
// logger.debug('PLUGIN ACTION', formattedAction);
};
let response = await client.sendMessage(text, {
user,
generation,
@@ -165,13 +153,12 @@ router.post(
onChainEnd,
onStart,
...endpointOption,
progressCallback,
progressOptions: {
onProgress: progressCallback.call(null, {
res,
text,
plugin,
// parentMessageId: overrideParentMessageId || userMessageId,
},
parentMessageId: overrideParentMessageId || userMessageId,
}),
abortController,
});
@@ -183,14 +170,10 @@ router.post(
response.plugin = { ...plugin, loading: false };
await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: conversation.title,
title: await getConvoTitle(user, conversationId),
final: true,
conversation,
conversation: await getConvo(user, conversationId),
requestMessage: userMessage,
responseMessage: response,
});

View File

@@ -5,7 +5,6 @@ const { createMulterInstance } = require('./multer');
const files = require('./files');
const images = require('./images');
const avatar = require('./avatar');
const speech = require('./speech');
const initialize = async () => {
const router = express.Router();
@@ -13,9 +12,6 @@ const initialize = async () => {
router.use(checkBan);
router.use(uaParser);
/* Important: speech route must be added before the upload limiters */
router.use('/speech', speech);
const upload = await createMulterInstance();
const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters();
router.post('*', fileUploadIpLimiter, fileUploadUserLimiter);

View File

@@ -1,10 +0,0 @@
const express = require('express');
const router = express.Router();
const { getCustomConfigSpeech } = require('~/server/services/Files/Audio');
router.get('/get', async (req, res) => {
await getCustomConfigSpeech(req, res);
});
module.exports = router;

View File

@@ -1,17 +0,0 @@
const express = require('express');
const { createTTSLimiters, createSTTLimiters } = require('~/server/middleware');
const stt = require('./stt');
const tts = require('./tts');
const customConfigSpeech = require('./customConfigSpeech');
const router = express.Router();
const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();
const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
router.use('/stt', sttIpLimiter, sttUserLimiter, stt);
router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);
router.use('/config', customConfigSpeech);
module.exports = router;

View File

@@ -1,13 +0,0 @@
const express = require('express');
const router = express.Router();
const multer = require('multer');
const { requireJwtAuth } = require('~/server/middleware/');
const { speechToText } = require('~/server/services/Files/Audio');
const upload = multer();
router.post('/', requireJwtAuth, upload.single('audio'), async (req, res) => {
await speechToText(req, res);
});
module.exports = router;

View File

@@ -1,42 +0,0 @@
const multer = require('multer');
const express = require('express');
const { CacheKeys } = require('librechat-data-provider');
const { getVoices, streamAudio, textToSpeech } = require('~/server/services/Files/Audio');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');
const router = express.Router();
const upload = multer();
router.post('/manual', upload.none(), async (req, res) => {
await textToSpeech(req, res);
});
const logDebugMessage = (req, message) =>
logger.debug(`[streamAudio] user: ${req?.user?.id ?? 'UNDEFINED_USER'} | ${message}`);
// TODO: test caching
router.post('/', async (req, res) => {
try {
const audioRunsCache = getLogStores(CacheKeys.AUDIO_RUNS);
const audioRun = await audioRunsCache.get(req.body.runId);
logDebugMessage(req, 'start stream audio');
if (audioRun) {
logDebugMessage(req, 'stream audio already running');
return res.status(401).json({ error: 'Audio stream already running' });
}
audioRunsCache.set(req.body.runId, true);
await streamAudio(req, res);
logDebugMessage(req, 'end stream audio');
res.status(200).end();
} catch (error) {
logger.error(`[streamAudio] user: ${req.user.id} | Failed to stream audio: ${error}`);
res.status(500).json({ error: 'Failed to stream audio' });
}
});
router.get('/voices', async (req, res) => {
await getVoices(req, res);
});
module.exports = router;

View File

@@ -19,8 +19,6 @@ const assistants = require('./assistants');
const files = require('./files');
const staticRoute = require('./static');
const share = require('./share');
const categories = require('./categories');
const roles = require('./roles');
module.exports = {
search,
@@ -44,6 +42,4 @@ module.exports = {
files,
staticRoute,
share,
categories,
roles,
};

Some files were not shown because too many files have changed in this diff Show More