Compare commits
119 Commits
update-tit
...
refactor/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d69465ea3d | ||
|
|
5ef71a7a36 | ||
|
|
326069d7a6 | ||
|
|
785430daf5 | ||
|
|
03fe361917 | ||
|
|
b34a4ddac1 | ||
|
|
7d5b03dd98 | ||
|
|
f959ee302c | ||
|
|
cd00df69bb | ||
|
|
a05e2c1dcc | ||
|
|
87bdbda10a | ||
|
|
605a8ae8c9 | ||
|
|
a724635998 | ||
|
|
6c306a662c | ||
|
|
55f8d9910e | ||
|
|
7edb54889b | ||
|
|
71d9e841b1 | ||
|
|
e76777d298 | ||
|
|
1edbfdbce2 | ||
|
|
1aad315de6 | ||
|
|
5d985746cb | ||
|
|
04654014b2 | ||
|
|
456793772b | ||
|
|
a87d4e0b75 | ||
|
|
a2fd975cd5 | ||
|
|
83619de158 | ||
|
|
b8f2bee3fc | ||
|
|
81292bb4dd | ||
|
|
ed5ee1f86f | ||
|
|
791b0139bc | ||
|
|
156c52e293 | ||
|
|
eef894e608 | ||
|
|
e2867eecc9 | ||
|
|
dd563e0796 | ||
|
|
c99cf1b4b1 | ||
|
|
b5081bfe86 | ||
|
|
aac01df80c | ||
|
|
24467dd626 | ||
|
|
b2b469bd3d | ||
|
|
cec2e57ee9 | ||
|
|
a8c874267f | ||
|
|
a53312bbd4 | ||
|
|
ab74685476 | ||
|
|
015215b790 | ||
|
|
4e4de88faa | ||
|
|
3172381bad | ||
|
|
54b1095239 | ||
|
|
0424f8fe55 | ||
|
|
4319c62e66 | ||
|
|
d3a0b862db | ||
|
|
5d8793c5d1 | ||
|
|
54db67449a | ||
|
|
0cd3c83328 | ||
|
|
302b28fc9b | ||
|
|
dad25bd297 | ||
|
|
a338decf90 | ||
|
|
2cf5228021 | ||
|
|
0294cfc881 | ||
|
|
8d8b17e7ed | ||
|
|
04502e9525 | ||
|
|
bcaa7d5d29 | ||
|
|
c288b458b6 | ||
|
|
447bbcb8ca | ||
|
|
68bf7ac7c0 | ||
|
|
97d12d03d1 | ||
|
|
4416f69a9b | ||
|
|
29e71e98ad | ||
|
|
e9bbf39618 | ||
|
|
08b8ae120e | ||
|
|
803fd63121 | ||
|
|
ef76cc195e | ||
|
|
2e559137ae | ||
|
|
92232afaca | ||
|
|
084cf266a2 | ||
|
|
baf0848021 | ||
|
|
1da92111aa | ||
|
|
35f8053f45 | ||
|
|
ee673d682e | ||
|
|
b7fef6958b | ||
|
|
5452d4c20c | ||
|
|
a7f5b57272 | ||
|
|
f69b317171 | ||
|
|
4469ba72fc | ||
|
|
0e3e45e77d | ||
|
|
9f0c1914a5 | ||
|
|
37ae484fbc | ||
|
|
8939d8af37 | ||
|
|
f9a0166352 | ||
|
|
248dfb8b5b | ||
|
|
b8e35002f4 | ||
|
|
8318f26d66 | ||
|
|
08d6bea359 | ||
|
|
a6058c5669 | ||
|
|
e0402b71f0 | ||
|
|
a618266905 | ||
|
|
d5a7806e32 | ||
|
|
e2cb2905e7 | ||
|
|
3f600f0d3f | ||
|
|
c9e7d4ac18 | ||
|
|
40685f6eb4 | ||
|
|
0ee060d730 | ||
|
|
5dc5d875ba | ||
|
|
9f2538fcd9 | ||
|
|
2b7a973a33 | ||
|
|
c704a23749 | ||
|
|
eb5733083e | ||
|
|
b80f38e49e | ||
|
|
4369e75ca7 | ||
|
|
35ba4ba1a4 | ||
|
|
dcd2e3e62d | ||
|
|
514a502b9c | ||
|
|
8e66683577 | ||
|
|
dc1778b11f | ||
|
|
795bb9c568 | ||
|
|
a937650df6 | ||
|
|
6cf1c85363 | ||
|
|
b3e03b75d0 | ||
|
|
9d8fd92dd3 | ||
|
|
f00a8f87f7 |
40
.env.example
40
.env.example
@@ -64,6 +64,8 @@ PROXY=
|
|||||||
|
|
||||||
# ANYSCALE_API_KEY=
|
# ANYSCALE_API_KEY=
|
||||||
# APIPIE_API_KEY=
|
# APIPIE_API_KEY=
|
||||||
|
# COHERE_API_KEY=
|
||||||
|
# DATABRICKS_API_KEY=
|
||||||
# FIREWORKS_API_KEY=
|
# FIREWORKS_API_KEY=
|
||||||
# GROQ_API_KEY=
|
# GROQ_API_KEY=
|
||||||
# HUGGINGFACE_TOKEN=
|
# HUGGINGFACE_TOKEN=
|
||||||
@@ -78,7 +80,7 @@ PROXY=
|
|||||||
#============#
|
#============#
|
||||||
|
|
||||||
ANTHROPIC_API_KEY=user_provided
|
ANTHROPIC_API_KEY=user_provided
|
||||||
# ANTHROPIC_MODELS=claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
|
# ANTHROPIC_MODELS=claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
|
||||||
# ANTHROPIC_REVERSE_PROXY=
|
# ANTHROPIC_REVERSE_PROXY=
|
||||||
|
|
||||||
#============#
|
#============#
|
||||||
@@ -119,7 +121,9 @@ GOOGLE_KEY=user_provided
|
|||||||
# GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
|
# GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
|
||||||
|
|
||||||
# Vertex AI
|
# Vertex AI
|
||||||
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0409,gemini-1.0-pro-vision-001,gemini-pro,gemini-pro-vision,chat-bison,chat-bison-32k,codechat-bison,codechat-bison-32k,text-bison,text-bison-32k,text-unicorn,code-gecko,code-bison,code-bison-32k
|
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
|
||||||
|
|
||||||
|
# GOOGLE_TITLE_MODEL=gemini-pro
|
||||||
|
|
||||||
# Google Gemini Safety Settings
|
# Google Gemini Safety Settings
|
||||||
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
|
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
|
||||||
@@ -257,6 +261,14 @@ MEILI_NO_ANALYTICS=true
|
|||||||
MEILI_HOST=http://0.0.0.0:7700
|
MEILI_HOST=http://0.0.0.0:7700
|
||||||
MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt
|
MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt
|
||||||
|
|
||||||
|
|
||||||
|
#==================================================#
|
||||||
|
# Speech to Text & Text to Speech #
|
||||||
|
#==================================================#
|
||||||
|
|
||||||
|
STT_API_KEY=
|
||||||
|
TTS_API_KEY=
|
||||||
|
|
||||||
#===================================================#
|
#===================================================#
|
||||||
# User System #
|
# User System #
|
||||||
#===================================================#
|
#===================================================#
|
||||||
@@ -311,6 +323,9 @@ ALLOW_EMAIL_LOGIN=true
|
|||||||
ALLOW_REGISTRATION=true
|
ALLOW_REGISTRATION=true
|
||||||
ALLOW_SOCIAL_LOGIN=false
|
ALLOW_SOCIAL_LOGIN=false
|
||||||
ALLOW_SOCIAL_REGISTRATION=false
|
ALLOW_SOCIAL_REGISTRATION=false
|
||||||
|
ALLOW_PASSWORD_RESET=false
|
||||||
|
# ALLOW_ACCOUNT_DELETION=true # note: enabled by default if omitted/commented out
|
||||||
|
ALLOW_UNVERIFIED_EMAIL_LOGIN=true
|
||||||
|
|
||||||
SESSION_EXPIRY=1000 * 60 * 15
|
SESSION_EXPIRY=1000 * 60 * 15
|
||||||
REFRESH_TOKEN_EXPIRY=(1000 * 60 * 60 * 24) * 7
|
REFRESH_TOKEN_EXPIRY=(1000 * 60 * 60 * 24) * 7
|
||||||
@@ -352,6 +367,17 @@ OPENID_REQUIRED_ROLE_PARAMETER_PATH=
|
|||||||
OPENID_BUTTON_LABEL=
|
OPENID_BUTTON_LABEL=
|
||||||
OPENID_IMAGE_URL=
|
OPENID_IMAGE_URL=
|
||||||
|
|
||||||
|
# LDAP
|
||||||
|
LDAP_URL=
|
||||||
|
LDAP_BIND_DN=
|
||||||
|
LDAP_BIND_CREDENTIALS=
|
||||||
|
LDAP_USER_SEARCH_BASE=
|
||||||
|
LDAP_SEARCH_FILTER=mail={{username}}
|
||||||
|
LDAP_CA_CERT_PATH=
|
||||||
|
# LDAP_ID=
|
||||||
|
# LDAP_USERNAME=
|
||||||
|
# LDAP_FULL_NAME=
|
||||||
|
|
||||||
#========================#
|
#========================#
|
||||||
# Email Password Reset #
|
# Email Password Reset #
|
||||||
#========================#
|
#========================#
|
||||||
@@ -378,6 +404,13 @@ FIREBASE_STORAGE_BUCKET=
|
|||||||
FIREBASE_MESSAGING_SENDER_ID=
|
FIREBASE_MESSAGING_SENDER_ID=
|
||||||
FIREBASE_APP_ID=
|
FIREBASE_APP_ID=
|
||||||
|
|
||||||
|
#========================#
|
||||||
|
# Shared Links #
|
||||||
|
#========================#
|
||||||
|
|
||||||
|
ALLOW_SHARED_LINKS=true
|
||||||
|
ALLOW_SHARED_LINKS_PUBLIC=true
|
||||||
|
|
||||||
#===================================================#
|
#===================================================#
|
||||||
# UI #
|
# UI #
|
||||||
#===================================================#
|
#===================================================#
|
||||||
@@ -388,6 +421,9 @@ HELP_AND_FAQ_URL=https://librechat.ai
|
|||||||
|
|
||||||
# SHOW_BIRTHDAY_ICON=true
|
# SHOW_BIRTHDAY_ICON=true
|
||||||
|
|
||||||
|
# Google tag manager id
|
||||||
|
#ANALYTICS_GTM_ID=user provided google tag manager id
|
||||||
|
|
||||||
#==================================================#
|
#==================================================#
|
||||||
# Others #
|
# Others #
|
||||||
#==================================================#
|
#==================================================#
|
||||||
|
|||||||
12
.github/CONTRIBUTING.md
vendored
12
.github/CONTRIBUTING.md
vendored
@@ -126,6 +126,18 @@ Apply the following naming conventions to branches, labels, and other Git-relate
|
|||||||
|
|
||||||
- **Current Stance**: At present, this backend transition is of lower priority and might not be pursued.
|
- **Current Stance**: At present, this backend transition is of lower priority and might not be pursued.
|
||||||
|
|
||||||
|
## 7. Module Import Conventions
|
||||||
|
|
||||||
|
- `npm` packages first,
|
||||||
|
- from shortest line (top) to longest (bottom)
|
||||||
|
|
||||||
|
- Followed by typescript types (pertains to data-provider and client workspaces)
|
||||||
|
- longest line (top) to shortest (bottom)
|
||||||
|
- types from package come first
|
||||||
|
|
||||||
|
- Lastly, local imports
|
||||||
|
- longest line (top) to shortest (bottom)
|
||||||
|
- imports with alias `~` treated the same as relative import with respect to line length
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -11,6 +11,7 @@ logs
|
|||||||
pids
|
pids
|
||||||
*.pid
|
*.pid
|
||||||
*.seed
|
*.seed
|
||||||
|
.git
|
||||||
|
|
||||||
# Directory for instrumented libs generated by jscoverage/JSCover
|
# Directory for instrumented libs generated by jscoverage/JSCover
|
||||||
lib-cov
|
lib-cov
|
||||||
@@ -45,6 +46,7 @@ api/node_modules/
|
|||||||
client/node_modules/
|
client/node_modules/
|
||||||
bower_components/
|
bower_components/
|
||||||
*.d.ts
|
*.d.ts
|
||||||
|
!vite-env.d.ts
|
||||||
|
|
||||||
# Floobits
|
# Floobits
|
||||||
.floo
|
.floo
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# v0.7.2
|
# v0.7.3
|
||||||
|
|
||||||
# Base node image
|
# Base node image
|
||||||
FROM node:20-alpine AS node
|
FROM node:20-alpine AS node
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# v0.7.2
|
# v0.7.3
|
||||||
|
|
||||||
# Build API, Client and Data Provider
|
# Build API, Client and Data Provider
|
||||||
FROM node:20-alpine AS base
|
FROM node:20-alpine AS base
|
||||||
|
|||||||
12
README.md
12
README.md
@@ -27,7 +27,7 @@
|
|||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://railway.app/template/b5k2mn?referralCode=myKrVZ">
|
<a href="https://railway.app/template/b5k2mn?referralCode=HI9hWz">
|
||||||
<img src="https://railway.app/button.svg" alt="Deploy on Railway" height="30">
|
<img src="https://railway.app/button.svg" alt="Deploy on Railway" height="30">
|
||||||
</a>
|
</a>
|
||||||
<a href="https://zeabur.com/templates/0X2ZY8">
|
<a href="https://zeabur.com/templates/0X2ZY8">
|
||||||
@@ -58,9 +58,13 @@
|
|||||||
- 🌎 Multilingual UI:
|
- 🌎 Multilingual UI:
|
||||||
- English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro,
|
- English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro,
|
||||||
- Русский, 日本語, Svenska, 한국어, Tiếng Việt, 繁體中文, العربية, Türkçe, Nederlands, עברית
|
- Русский, 日本語, Svenska, 한국어, Tiếng Việt, 繁體中文, العربية, Türkçe, Nederlands, עברית
|
||||||
- 🎨 Customizable Dropdown & Interface: Adapts to both power users and newcomers.
|
- 🎨 Customizable Dropdown & Interface: Adapts to both power users and newcomers
|
||||||
|
- 📧 Verify your email to ensure secure access
|
||||||
|
- 🗣️ Chat hands-free with Speech-to-Text and Text-to-Speech magic
|
||||||
|
- Automatically send and play Audio
|
||||||
|
- Supports OpenAI, Azure OpenAI, and Elevenlabs
|
||||||
- 📥 Import Conversations from LibreChat, ChatGPT, Chatbot UI
|
- 📥 Import Conversations from LibreChat, ChatGPT, Chatbot UI
|
||||||
- 📤 Export conversations as screenshots, markdown, text, json.
|
- 📤 Export conversations as screenshots, markdown, text, json
|
||||||
- 🔍 Search all messages/conversations
|
- 🔍 Search all messages/conversations
|
||||||
- 🔌 Plugins, including web access, image generation with DALL-E-3 and more
|
- 🔌 Plugins, including web access, image generation with DALL-E-3 and more
|
||||||
- 👥 Multi-User, Secure Authentication with Moderation and Token spend tools
|
- 👥 Multi-User, Secure Authentication with Moderation and Token spend tools
|
||||||
@@ -77,7 +81,7 @@ LibreChat brings together the future of assistant AIs with the revolutionary tec
|
|||||||
|
|
||||||
With LibreChat, you no longer need to opt for ChatGPT Plus and can instead use free or pay-per-call APIs. We welcome contributions, cloning, and forking to enhance the capabilities of this advanced chatbot platform.
|
With LibreChat, you no longer need to opt for ChatGPT Plus and can instead use free or pay-per-call APIs. We welcome contributions, cloning, and forking to enhance the capabilities of this advanced chatbot platform.
|
||||||
|
|
||||||
[](https://www.youtube.com/watch?v=YLVUW5UP9N0)
|
[](https://www.youtube.com/watch?v=bSVHEbVPNl4)
|
||||||
Click on the thumbnail to open the video☝️
|
Click on the thumbnail to open the video☝️
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
const Anthropic = require('@anthropic-ai/sdk');
|
const Anthropic = require('@anthropic-ai/sdk');
|
||||||
|
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||||
const {
|
const {
|
||||||
getResponseSender,
|
getResponseSender,
|
||||||
@@ -123,9 +124,14 @@ class AnthropicClient extends BaseClient {
|
|||||||
getClient() {
|
getClient() {
|
||||||
/** @type {Anthropic.default.RequestOptions} */
|
/** @type {Anthropic.default.RequestOptions} */
|
||||||
const options = {
|
const options = {
|
||||||
|
fetch: this.fetch,
|
||||||
apiKey: this.apiKey,
|
apiKey: this.apiKey,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (this.options.proxy) {
|
||||||
|
options.httpAgent = new HttpsProxyAgent(this.options.proxy);
|
||||||
|
}
|
||||||
|
|
||||||
if (this.options.reverseProxyUrl) {
|
if (this.options.reverseProxyUrl) {
|
||||||
options.baseURL = this.options.reverseProxyUrl;
|
options.baseURL = this.options.reverseProxyUrl;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
|
const fetch = require('node-fetch');
|
||||||
const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
|
const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
|
||||||
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||||
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
|
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
|
||||||
const checkBalance = require('~/models/checkBalance');
|
const checkBalance = require('~/models/checkBalance');
|
||||||
const { getFiles } = require('~/models/File');
|
const { getFiles } = require('~/models/File');
|
||||||
@@ -17,6 +18,15 @@ class BaseClient {
|
|||||||
month: 'long',
|
month: 'long',
|
||||||
day: 'numeric',
|
day: 'numeric',
|
||||||
});
|
});
|
||||||
|
this.fetch = this.fetch.bind(this);
|
||||||
|
/** @type {boolean} */
|
||||||
|
this.skipSaveConvo = false;
|
||||||
|
/** @type {boolean} */
|
||||||
|
this.skipSaveUserMessage = false;
|
||||||
|
/** @type {ClientDatabaseSavePromise} */
|
||||||
|
this.userMessagePromise;
|
||||||
|
/** @type {ClientDatabaseSavePromise} */
|
||||||
|
this.responsePromise;
|
||||||
}
|
}
|
||||||
|
|
||||||
setOptions() {
|
setOptions() {
|
||||||
@@ -54,6 +64,25 @@ class BaseClient {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Makes an HTTP request and logs the process.
|
||||||
|
*
|
||||||
|
* @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
|
||||||
|
* @param {RequestInit} [init] - Optional init options for the request.
|
||||||
|
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
|
||||||
|
*/
|
||||||
|
async fetch(_url, init) {
|
||||||
|
let url = _url;
|
||||||
|
if (this.options.directEndpoint) {
|
||||||
|
url = this.options.reverseProxyUrl;
|
||||||
|
}
|
||||||
|
logger.debug(`Making request to ${url}`);
|
||||||
|
if (typeof Bun !== 'undefined') {
|
||||||
|
return await fetch(url, init);
|
||||||
|
}
|
||||||
|
return await fetch(url, init);
|
||||||
|
}
|
||||||
|
|
||||||
getBuildMessagesOptions() {
|
getBuildMessagesOptions() {
|
||||||
throw new Error('Subclasses must implement getBuildMessagesOptions');
|
throw new Error('Subclasses must implement getBuildMessagesOptions');
|
||||||
}
|
}
|
||||||
@@ -63,19 +92,45 @@ class BaseClient {
|
|||||||
await stream.processTextStream(onProgress);
|
await stream.processTextStream(onProgress);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @returns {[string|undefined, string|undefined]}
|
||||||
|
*/
|
||||||
|
processOverideIds() {
|
||||||
|
/** @type {Record<string, string | undefined>} */
|
||||||
|
let { overrideConvoId, overrideUserMessageId } = this.options?.req?.body ?? {};
|
||||||
|
if (overrideConvoId) {
|
||||||
|
const [conversationId, index] = overrideConvoId.split(Constants.COMMON_DIVIDER);
|
||||||
|
overrideConvoId = conversationId;
|
||||||
|
if (index !== '0') {
|
||||||
|
this.skipSaveConvo = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (overrideUserMessageId) {
|
||||||
|
const [userMessageId, index] = overrideUserMessageId.split(Constants.COMMON_DIVIDER);
|
||||||
|
overrideUserMessageId = userMessageId;
|
||||||
|
if (index !== '0') {
|
||||||
|
this.skipSaveUserMessage = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [overrideConvoId, overrideUserMessageId];
|
||||||
|
}
|
||||||
|
|
||||||
async setMessageOptions(opts = {}) {
|
async setMessageOptions(opts = {}) {
|
||||||
if (opts && opts.replaceOptions) {
|
if (opts && opts.replaceOptions) {
|
||||||
this.setOptions(opts);
|
this.setOptions(opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const [overrideConvoId, overrideUserMessageId] = this.processOverideIds();
|
||||||
const { isEdited, isContinued } = opts;
|
const { isEdited, isContinued } = opts;
|
||||||
const user = opts.user ?? null;
|
const user = opts.user ?? null;
|
||||||
this.user = user;
|
this.user = user;
|
||||||
const saveOptions = this.getSaveOptions();
|
const saveOptions = this.getSaveOptions();
|
||||||
this.abortController = opts.abortController ?? new AbortController();
|
this.abortController = opts.abortController ?? new AbortController();
|
||||||
const conversationId = opts.conversationId ?? crypto.randomUUID();
|
const conversationId = overrideConvoId ?? opts.conversationId ?? crypto.randomUUID();
|
||||||
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
|
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
|
||||||
const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID();
|
const userMessageId =
|
||||||
|
overrideUserMessageId ?? opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||||
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
|
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
|
||||||
let head = isEdited ? responseMessageId : parentMessageId;
|
let head = isEdited ? responseMessageId : parentMessageId;
|
||||||
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
|
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
|
||||||
@@ -139,7 +194,7 @@ class BaseClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (typeof opts?.onStart === 'function') {
|
if (typeof opts?.onStart === 'function') {
|
||||||
opts.onStart(userMessage);
|
opts.onStart(userMessage, responseMessageId);
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -373,6 +428,14 @@ class BaseClient {
|
|||||||
const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
|
const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
|
||||||
await this.handleStartMethods(message, opts);
|
await this.handleStartMethods(message, opts);
|
||||||
|
|
||||||
|
if (opts.progressCallback) {
|
||||||
|
opts.onProgress = opts.progressCallback.call(null, {
|
||||||
|
...(opts.progressOptions ?? {}),
|
||||||
|
parentMessageId: userMessage.messageId,
|
||||||
|
messageId: responseMessageId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const { generation = '' } = opts;
|
const { generation = '' } = opts;
|
||||||
|
|
||||||
// It's not necessary to push to currentMessages
|
// It's not necessary to push to currentMessages
|
||||||
@@ -421,8 +484,13 @@ class BaseClient {
|
|||||||
this.handleTokenCountMap(tokenCountMap);
|
this.handleTokenCountMap(tokenCountMap);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isEdited) {
|
if (!isEdited && !this.skipSaveUserMessage) {
|
||||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||||
|
if (typeof opts?.getReqData === 'function') {
|
||||||
|
opts.getReqData({
|
||||||
|
userMessagePromise: this.userMessagePromise,
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -471,15 +539,11 @@ class BaseClient {
|
|||||||
const completionTokens = this.getTokenCount(completion);
|
const completionTokens = this.getTokenCount(completion);
|
||||||
await this.recordTokenUsage({ promptTokens, completionTokens });
|
await this.recordTokenUsage({ promptTokens, completionTokens });
|
||||||
}
|
}
|
||||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||||
delete responseMessage.tokenCount;
|
delete responseMessage.tokenCount;
|
||||||
return responseMessage;
|
return responseMessage;
|
||||||
}
|
}
|
||||||
|
|
||||||
async getConversation(conversationId, user = null) {
|
|
||||||
return await getConvo(user, conversationId);
|
|
||||||
}
|
|
||||||
|
|
||||||
async loadHistory(conversationId, parentMessageId = null) {
|
async loadHistory(conversationId, parentMessageId = null) {
|
||||||
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
|
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
|
||||||
|
|
||||||
@@ -534,18 +598,24 @@ class BaseClient {
|
|||||||
* @param {string | null} user
|
* @param {string | null} user
|
||||||
*/
|
*/
|
||||||
async saveMessageToDatabase(message, endpointOptions, user = null) {
|
async saveMessageToDatabase(message, endpointOptions, user = null) {
|
||||||
await saveMessage({
|
const savedMessage = await saveMessage({
|
||||||
...message,
|
...message,
|
||||||
endpoint: this.options.endpoint,
|
endpoint: this.options.endpoint,
|
||||||
unfinished: false,
|
unfinished: false,
|
||||||
user,
|
user,
|
||||||
});
|
});
|
||||||
await saveConvo(user, {
|
|
||||||
|
if (this.skipSaveConvo) {
|
||||||
|
return { message: savedMessage };
|
||||||
|
}
|
||||||
|
const conversation = await saveConvo(user, {
|
||||||
conversationId: message.conversationId,
|
conversationId: message.conversationId,
|
||||||
endpoint: this.options.endpoint,
|
endpoint: this.options.endpoint,
|
||||||
endpointType: this.options.endpointType,
|
endpointType: this.options.endpointType,
|
||||||
...endpointOptions,
|
...endpointOptions,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
return { message: savedMessage, conversation };
|
||||||
}
|
}
|
||||||
|
|
||||||
async updateMessageInDatabase(message) {
|
async updateMessageInDatabase(message) {
|
||||||
|
|||||||
@@ -438,9 +438,17 @@ class ChatGPTClient extends BaseClient {
|
|||||||
|
|
||||||
if (message.eventType === 'text-generation' && message.text) {
|
if (message.eventType === 'text-generation' && message.text) {
|
||||||
onTokenProgress(message.text);
|
onTokenProgress(message.text);
|
||||||
} else if (message.eventType === 'stream-end' && message.response) {
|
reply += message.text;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
Cohere API Chinese Unicode character replacement hotfix.
|
||||||
|
Should be un-commented when the following issue is resolved:
|
||||||
|
https://github.com/cohere-ai/cohere-typescript/issues/151
|
||||||
|
|
||||||
|
else if (message.eventType === 'stream-end' && message.response) {
|
||||||
reply = message.response.text;
|
reply = message.response.text;
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
return reply;
|
return reply;
|
||||||
|
|||||||
@@ -16,10 +16,15 @@ const {
|
|||||||
AuthKeys,
|
AuthKeys,
|
||||||
} = require('librechat-data-provider');
|
} = require('librechat-data-provider');
|
||||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||||
const { formatMessage, createContextHandlers } = require('./prompts');
|
|
||||||
const { getModelMaxTokens } = require('~/utils');
|
const { getModelMaxTokens } = require('~/utils');
|
||||||
const BaseClient = require('./BaseClient');
|
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
const {
|
||||||
|
formatMessage,
|
||||||
|
createContextHandlers,
|
||||||
|
titleInstruction,
|
||||||
|
truncateText,
|
||||||
|
} = require('./prompts');
|
||||||
|
const BaseClient = require('./BaseClient');
|
||||||
|
|
||||||
const loc = 'us-central1';
|
const loc = 'us-central1';
|
||||||
const publisher = 'google';
|
const publisher = 'google';
|
||||||
@@ -591,12 +596,16 @@ class GoogleClient extends BaseClient {
|
|||||||
createLLM(clientOptions) {
|
createLLM(clientOptions) {
|
||||||
const model = clientOptions.modelName ?? clientOptions.model;
|
const model = clientOptions.modelName ?? clientOptions.model;
|
||||||
if (this.project_id && this.isTextModel) {
|
if (this.project_id && this.isTextModel) {
|
||||||
|
logger.debug('Creating Google VertexAI client');
|
||||||
return new GoogleVertexAI(clientOptions);
|
return new GoogleVertexAI(clientOptions);
|
||||||
} else if (this.project_id && this.isChatModel) {
|
} else if (this.project_id && this.isChatModel) {
|
||||||
|
logger.debug('Creating Chat Google VertexAI client');
|
||||||
return new ChatGoogleVertexAI(clientOptions);
|
return new ChatGoogleVertexAI(clientOptions);
|
||||||
} else if (this.project_id) {
|
} else if (this.project_id) {
|
||||||
|
logger.debug('Creating VertexAI client');
|
||||||
return new ChatVertexAI(clientOptions);
|
return new ChatVertexAI(clientOptions);
|
||||||
} else if (model.includes('1.5')) {
|
} else if (model.includes('1.5')) {
|
||||||
|
logger.debug('Creating GenAI client');
|
||||||
return new GenAI(this.apiKey).getGenerativeModel(
|
return new GenAI(this.apiKey).getGenerativeModel(
|
||||||
{
|
{
|
||||||
...clientOptions,
|
...clientOptions,
|
||||||
@@ -606,6 +615,7 @@ class GoogleClient extends BaseClient {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.debug('Creating Chat Google Generative AI client');
|
||||||
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
|
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -717,6 +727,123 @@ class GoogleClient extends BaseClient {
|
|||||||
return reply;
|
return reply;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
|
||||||
|
*/
|
||||||
|
async titleChatCompletion(_payload, options = {}) {
|
||||||
|
const { abortController } = options;
|
||||||
|
const { parameters, instances } = _payload;
|
||||||
|
const { messages: _messages, examples: _examples } = instances?.[0] ?? {};
|
||||||
|
|
||||||
|
let clientOptions = { ...parameters, maxRetries: 2 };
|
||||||
|
|
||||||
|
logger.debug('Initialized title client options');
|
||||||
|
|
||||||
|
if (this.project_id) {
|
||||||
|
clientOptions['authOptions'] = {
|
||||||
|
credentials: {
|
||||||
|
...this.serviceKey,
|
||||||
|
},
|
||||||
|
projectId: this.project_id,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!parameters) {
|
||||||
|
clientOptions = { ...clientOptions, ...this.modelOptions };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.isGenerativeModel && !this.project_id) {
|
||||||
|
clientOptions.modelName = clientOptions.model;
|
||||||
|
delete clientOptions.model;
|
||||||
|
}
|
||||||
|
|
||||||
|
const model = this.createLLM(clientOptions);
|
||||||
|
|
||||||
|
let reply = '';
|
||||||
|
const messages = this.isTextModel ? _payload.trim() : _messages;
|
||||||
|
|
||||||
|
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
|
||||||
|
if (modelName?.includes('1.5') && !this.project_id) {
|
||||||
|
logger.debug('Identified titling model as 1.5 version');
|
||||||
|
/** @type {GenerativeModel} */
|
||||||
|
const client = model;
|
||||||
|
const requestOptions = {
|
||||||
|
contents: _payload,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (this.options?.promptPrefix?.length) {
|
||||||
|
requestOptions.systemInstruction = {
|
||||||
|
parts: [
|
||||||
|
{
|
||||||
|
text: this.options.promptPrefix,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const safetySettings = _payload.safetySettings;
|
||||||
|
requestOptions.safetySettings = safetySettings;
|
||||||
|
|
||||||
|
const result = await client.generateContent(requestOptions);
|
||||||
|
|
||||||
|
reply = result.response?.text();
|
||||||
|
|
||||||
|
return reply;
|
||||||
|
} else {
|
||||||
|
logger.debug('Beginning titling');
|
||||||
|
const safetySettings = _payload.safetySettings;
|
||||||
|
|
||||||
|
const titleResponse = await model.invoke(messages, {
|
||||||
|
signal: abortController.signal,
|
||||||
|
timeout: 7000,
|
||||||
|
safetySettings: safetySettings,
|
||||||
|
});
|
||||||
|
|
||||||
|
reply = titleResponse.content;
|
||||||
|
|
||||||
|
return reply;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async titleConvo({ text, responseText = '' }) {
|
||||||
|
let title = 'New Chat';
|
||||||
|
const convo = `||>User:
|
||||||
|
"${truncateText(text)}"
|
||||||
|
||>Response:
|
||||||
|
"${JSON.stringify(truncateText(responseText))}"`;
|
||||||
|
|
||||||
|
let { prompt: payload } = await this.buildMessages([
|
||||||
|
{
|
||||||
|
text: `Please generate ${titleInstruction}
|
||||||
|
|
||||||
|
${convo}
|
||||||
|
|
||||||
|
||>Title:`,
|
||||||
|
isCreatedByUser: true,
|
||||||
|
author: this.userLabel,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
if (this.isVisionModel) {
|
||||||
|
logger.warn(
|
||||||
|
`Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
payload.parameters = { ...payload.parameters, model: settings.model.default };
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
title = await this.titleChatCompletion(payload, {
|
||||||
|
abortController: new AbortController(),
|
||||||
|
onProgress: () => {},
|
||||||
|
});
|
||||||
|
} catch (e) {
|
||||||
|
logger.error('[GoogleClient] There was an issue generating the title', e);
|
||||||
|
}
|
||||||
|
logger.debug(`Title response: ${title}`);
|
||||||
|
return title;
|
||||||
|
}
|
||||||
|
|
||||||
getSaveOptions() {
|
getSaveOptions() {
|
||||||
return {
|
return {
|
||||||
promptPrefix: this.options.promptPrefix,
|
promptPrefix: this.options.promptPrefix,
|
||||||
|
|||||||
@@ -588,7 +588,7 @@ class OpenAIClient extends BaseClient {
|
|||||||
let streamResult = null;
|
let streamResult = null;
|
||||||
this.modelOptions.user = this.user;
|
this.modelOptions.user = this.user;
|
||||||
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
|
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
|
||||||
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
|
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion);
|
||||||
if (typeof opts.onProgress === 'function' && useOldMethod) {
|
if (typeof opts.onProgress === 'function' && useOldMethod) {
|
||||||
const completionResult = await this.getCompletion(
|
const completionResult = await this.getCompletion(
|
||||||
payload,
|
payload,
|
||||||
@@ -827,7 +827,7 @@ class OpenAIClient extends BaseClient {
|
|||||||
|
|
||||||
const instructionsPayload = [
|
const instructionsPayload = [
|
||||||
{
|
{
|
||||||
role: 'system',
|
role: this.options.titleMessageRole ?? 'system',
|
||||||
content: `Please generate ${titleInstruction}
|
content: `Please generate ${titleInstruction}
|
||||||
|
|
||||||
${convo}
|
${convo}
|
||||||
@@ -1106,7 +1106,12 @@ ${convo}
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (this.azure || this.options.azure) {
|
if (this.azure || this.options.azure) {
|
||||||
// Azure does not accept `model` in the body, so we need to remove it.
|
/* Azure Bug, extremely short default `max_tokens` response */
|
||||||
|
if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') {
|
||||||
|
modelOptions.max_tokens = 4000;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Azure does not accept `model` in the body, so we need to remove it. */
|
||||||
delete modelOptions.model;
|
delete modelOptions.model;
|
||||||
|
|
||||||
opts.baseURL = this.langchainProxy
|
opts.baseURL = this.langchainProxy
|
||||||
@@ -1127,6 +1132,7 @@ ${convo}
|
|||||||
let chatCompletion;
|
let chatCompletion;
|
||||||
/** @type {OpenAI} */
|
/** @type {OpenAI} */
|
||||||
const openai = new OpenAI({
|
const openai = new OpenAI({
|
||||||
|
fetch: this.fetch,
|
||||||
apiKey: this.apiKey,
|
apiKey: this.apiKey,
|
||||||
...opts,
|
...opts,
|
||||||
});
|
});
|
||||||
@@ -1216,6 +1222,7 @@ ${convo}
|
|||||||
});
|
});
|
||||||
|
|
||||||
const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
|
const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
const token = chunk.choices[0]?.delta?.content || '';
|
const token = chunk.choices[0]?.delta?.content || '';
|
||||||
intermediateReply += token;
|
intermediateReply += token;
|
||||||
|
|||||||
@@ -238,18 +238,30 @@ class PluginsClient extends OpenAIClient {
|
|||||||
await this.recordTokenUsage(responseMessage);
|
await this.recordTokenUsage(responseMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||||
delete responseMessage.tokenCount;
|
delete responseMessage.tokenCount;
|
||||||
return { ...responseMessage, ...result };
|
return { ...responseMessage, ...result };
|
||||||
}
|
}
|
||||||
|
|
||||||
async sendMessage(message, opts = {}) {
|
async sendMessage(message, opts = {}) {
|
||||||
|
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||||
|
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
|
||||||
|
|
||||||
|
if (includedTools.length > 0) {
|
||||||
|
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
|
||||||
|
this.options.tools = tools;
|
||||||
|
} else {
|
||||||
|
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
|
||||||
|
this.options.tools = tools;
|
||||||
|
}
|
||||||
|
|
||||||
// If a message is edited, no tools can be used.
|
// If a message is edited, no tools can be used.
|
||||||
const completionMode = this.options.tools.length === 0 || opts.isEdited;
|
const completionMode = this.options.tools.length === 0 || opts.isEdited;
|
||||||
if (completionMode) {
|
if (completionMode) {
|
||||||
this.setOptions(opts);
|
this.setOptions(opts);
|
||||||
return super.sendMessage(message, opts);
|
return super.sendMessage(message, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
|
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
|
||||||
const {
|
const {
|
||||||
user,
|
user,
|
||||||
@@ -264,6 +276,14 @@ class PluginsClient extends OpenAIClient {
|
|||||||
onToolEnd,
|
onToolEnd,
|
||||||
} = await this.handleStartMethods(message, opts);
|
} = await this.handleStartMethods(message, opts);
|
||||||
|
|
||||||
|
if (opts.progressCallback) {
|
||||||
|
opts.onProgress = opts.progressCallback.call(null, {
|
||||||
|
...(opts.progressOptions ?? {}),
|
||||||
|
parentMessageId: userMessage.messageId,
|
||||||
|
messageId: responseMessageId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
this.currentMessages.push(userMessage);
|
this.currentMessages.push(userMessage);
|
||||||
|
|
||||||
let {
|
let {
|
||||||
@@ -292,7 +312,15 @@ class PluginsClient extends OpenAIClient {
|
|||||||
if (payload) {
|
if (payload) {
|
||||||
this.currentMessages = payload;
|
this.currentMessages = payload;
|
||||||
}
|
}
|
||||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
|
||||||
|
if (!this.skipSaveUserMessage) {
|
||||||
|
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||||
|
if (typeof opts?.getReqData === 'function') {
|
||||||
|
opts.getReqData({
|
||||||
|
userMessagePromise: this.userMessagePromise,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (isEnabled(process.env.CHECK_BALANCE)) {
|
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||||
await checkBalance({
|
await checkBalance({
|
||||||
|
|||||||
@@ -1,44 +1,3 @@
|
|||||||
/*
|
|
||||||
module.exports = `You are ChatGPT, a Large Language model with useful tools.
|
|
||||||
|
|
||||||
Talk to the human and provide meaningful answers when questions are asked.
|
|
||||||
|
|
||||||
Use the tools when you need them, but use your own knowledge if you are confident of the answer. Keep answers short and concise.
|
|
||||||
|
|
||||||
A tool is not usually needed for creative requests, so do your best to answer them without tools.
|
|
||||||
|
|
||||||
Avoid repeating identical answers if it appears before. Only fulfill the human's requests, do not create extra steps beyond what the human has asked for.
|
|
||||||
|
|
||||||
Your input for 'Action' should be the name of tool used only.
|
|
||||||
|
|
||||||
Be honest. If you can't answer something, or a tool is not appropriate, say you don't know or answer to the best of your ability.
|
|
||||||
|
|
||||||
Attempt to fulfill the human's requests in as few actions as possible`;
|
|
||||||
*/
|
|
||||||
|
|
||||||
// module.exports = `You are ChatGPT, a highly knowledgeable and versatile large language model.
|
|
||||||
|
|
||||||
// Engage with the Human conversationally, providing concise and meaningful answers to questions. Utilize built-in tools when necessary, except for creative requests, where relying on your own knowledge is preferred. Aim for variety and avoid repetitive answers.
|
|
||||||
|
|
||||||
// For your 'Action' input, state the name of the tool used only, and honor user requests without adding extra steps. Always be honest; if you cannot provide an appropriate answer or tool, admit that or do your best.
|
|
||||||
|
|
||||||
// Strive to meet the user's needs efficiently with minimal actions.`;
|
|
||||||
|
|
||||||
// import {
|
|
||||||
// BasePromptTemplate,
|
|
||||||
// BaseStringPromptTemplate,
|
|
||||||
// SerializedBasePromptTemplate,
|
|
||||||
// renderTemplate,
|
|
||||||
// } from "langchain/prompts";
|
|
||||||
|
|
||||||
// prefix: `You are ChatGPT, a highly knowledgeable and versatile large language model.
|
|
||||||
// Your objective is to help users by understanding their intent and choosing the best action. Prioritize direct, specific responses. Use concise, varied answers and rely on your knowledge for creative tasks. Utilize tools when needed, and structure results for machine compatibility.
|
|
||||||
// prefix: `Objective: to comprehend human intentions based on user input and available tools. Goal: identify the best action to directly address the human's query. In your subsequent steps, you will utilize the chosen action. You may select multiple actions and list them in a meaningful order. Prioritize actions that directly relate to the user's query over general ones. Ensure that the generated thought is highly specific and explicit to best match the user's expectations. Construct the result in a manner that an online open-API would most likely expect. Provide concise and meaningful answers to human queries. Utilize tools when necessary. Relying on your own knowledge is preferred for creative requests. Aim for variety and avoid repetitive answers.
|
|
||||||
|
|
||||||
// # Available Actions & Tools:
|
|
||||||
// N/A: no suitable action, use your own knowledge.`,
|
|
||||||
// suffix: `Remember, all your responses MUST adhere to the described format and only respond if the format is followed. Output exactly with the requested format, avoiding any other text as this will be parsed by a machine. Following 'Action:', provide only one of the actions listed above. If a tool is not necessary, deduce this quickly and finish your response. Honor the human's requests without adding extra steps. Carry out tasks in the sequence written by the human. Always be honest; if you cannot provide an appropriate answer or tool, do your best with your own knowledge. Strive to meet the user's needs efficiently with minimal actions.`;
|
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
'gpt3-v1': {
|
'gpt3-v1': {
|
||||||
prefix: `Objective: Understand human intentions using user input and available tools. Goal: Identify the most suitable actions to directly address user queries.
|
prefix: `Objective: Understand human intentions using user input and available tools. Goal: Identify the most suitable actions to directly address user queries.
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ In your response, remember to follow these guidelines:
|
|||||||
- If you don't know the answer, simply say that you don't know.
|
- If you don't know the answer, simply say that you don't know.
|
||||||
- If you are unsure how to answer, ask for clarification.
|
- If you are unsure how to answer, ask for clarification.
|
||||||
- Avoid mentioning that you obtained the information from the context.
|
- Avoid mentioning that you obtained the information from the context.
|
||||||
|
|
||||||
Answer appropriately in the user's language.
|
|
||||||
`;
|
`;
|
||||||
|
|
||||||
function createContextHandlers(req, userMessageContent) {
|
function createContextHandlers(req, userMessageContent) {
|
||||||
@@ -94,37 +92,40 @@ function createContextHandlers(req, userMessageContent) {
|
|||||||
|
|
||||||
const resolvedQueries = await Promise.all(queryPromises);
|
const resolvedQueries = await Promise.all(queryPromises);
|
||||||
|
|
||||||
const context = resolvedQueries
|
const context =
|
||||||
.map((queryResult, index) => {
|
resolvedQueries.length === 0
|
||||||
const file = processedFiles[index];
|
? '\n\tThe semantic search did not return any results.'
|
||||||
let contextItems = queryResult.data;
|
: resolvedQueries
|
||||||
|
.map((queryResult, index) => {
|
||||||
|
const file = processedFiles[index];
|
||||||
|
let contextItems = queryResult.data;
|
||||||
|
|
||||||
const generateContext = (currentContext) =>
|
const generateContext = (currentContext) =>
|
||||||
`
|
`
|
||||||
<file>
|
<file>
|
||||||
<filename>${file.filename}</filename>
|
<filename>${file.filename}</filename>
|
||||||
<context>${currentContext}
|
<context>${currentContext}
|
||||||
</context>
|
</context>
|
||||||
</file>`;
|
</file>`;
|
||||||
|
|
||||||
if (useFullContext) {
|
if (useFullContext) {
|
||||||
return generateContext(`\n${contextItems}`);
|
return generateContext(`\n${contextItems}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
contextItems = queryResult.data
|
contextItems = queryResult.data
|
||||||
.map((item) => {
|
.map((item) => {
|
||||||
const pageContent = item[0].page_content;
|
const pageContent = item[0].page_content;
|
||||||
return `
|
return `
|
||||||
<contextItem>
|
<contextItem>
|
||||||
<![CDATA[${pageContent?.trim()}]]>
|
<![CDATA[${pageContent?.trim()}]]>
|
||||||
</contextItem>`;
|
</contextItem>`;
|
||||||
|
})
|
||||||
|
.join('');
|
||||||
|
|
||||||
|
return generateContext(contextItems);
|
||||||
})
|
})
|
||||||
.join('');
|
.join('');
|
||||||
|
|
||||||
return generateContext(contextItems);
|
|
||||||
})
|
|
||||||
.join('');
|
|
||||||
|
|
||||||
if (useFullContext) {
|
if (useFullContext) {
|
||||||
const prompt = `${header}
|
const prompt = `${header}
|
||||||
${context}
|
${context}
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ ${convo}`,
|
|||||||
};
|
};
|
||||||
|
|
||||||
const titleInstruction =
|
const titleInstruction =
|
||||||
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"';
|
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. Never directly mention the language name or the word "title"';
|
||||||
const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.
|
const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.
|
||||||
|
|
||||||
You may call them like this:
|
You may call them like this:
|
||||||
|
|||||||
@@ -576,7 +576,11 @@ describe('BaseClient', () => {
|
|||||||
const onStart = jest.fn();
|
const onStart = jest.fn();
|
||||||
const opts = { onStart };
|
const opts = { onStart };
|
||||||
await TestClient.sendMessage('Hello, world!', opts);
|
await TestClient.sendMessage('Hello, world!', opts);
|
||||||
expect(onStart).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello, world!' }));
|
|
||||||
|
expect(onStart).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ text: 'Hello, world!' }),
|
||||||
|
expect.any(String),
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
test('saveMessageToDatabase is called with the correct arguments', async () => {
|
test('saveMessageToDatabase is called with the correct arguments', async () => {
|
||||||
|
|||||||
@@ -194,6 +194,7 @@ describe('PluginsClient', () => {
|
|||||||
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
|
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('Azure OpenAI tests specific to Plugins', () => {
|
describe('Azure OpenAI tests specific to Plugins', () => {
|
||||||
// TODO: add more tests for Azure OpenAI integration with Plugins
|
// TODO: add more tests for Azure OpenAI integration with Plugins
|
||||||
// let client;
|
// let client;
|
||||||
@@ -220,4 +221,94 @@ describe('PluginsClient', () => {
|
|||||||
spy.mockRestore();
|
spy.mockRestore();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('sendMessage with filtered tools', () => {
|
||||||
|
let TestAgent;
|
||||||
|
const apiKey = 'fake-api-key';
|
||||||
|
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
TestAgent = new PluginsClient(apiKey, {
|
||||||
|
tools: mockTools,
|
||||||
|
modelOptions: {
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
temperature: 0,
|
||||||
|
max_tokens: 2,
|
||||||
|
},
|
||||||
|
agentOptions: {
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
TestAgent.options.req = {
|
||||||
|
app: {
|
||||||
|
locals: {},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
|
||||||
|
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
|
||||||
|
|
||||||
|
if (includedTools.length > 0) {
|
||||||
|
const tools = TestAgent.options.tools.filter((plugin) =>
|
||||||
|
includedTools.includes(plugin.name),
|
||||||
|
);
|
||||||
|
TestAgent.options.tools = tools;
|
||||||
|
} else {
|
||||||
|
const tools = TestAgent.options.tools.filter(
|
||||||
|
(plugin) => !filteredTools.includes(plugin.name),
|
||||||
|
);
|
||||||
|
TestAgent.options.tools = tools;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
text: 'Mocked response',
|
||||||
|
tools: TestAgent.options.tools,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should filter out tools when filteredTools is provided', async () => {
|
||||||
|
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||||
|
const response = await TestAgent.sendMessage('Test message');
|
||||||
|
expect(response.tools).toHaveLength(2);
|
||||||
|
expect(response.tools).toEqual(
|
||||||
|
expect.arrayContaining([
|
||||||
|
expect.objectContaining({ name: 'tool2' }),
|
||||||
|
expect.objectContaining({ name: 'tool4' }),
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should only include specified tools when includedTools is provided', async () => {
|
||||||
|
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
|
||||||
|
const response = await TestAgent.sendMessage('Test message');
|
||||||
|
expect(response.tools).toHaveLength(2);
|
||||||
|
expect(response.tools).toEqual(
|
||||||
|
expect.arrayContaining([
|
||||||
|
expect.objectContaining({ name: 'tool2' }),
|
||||||
|
expect.objectContaining({ name: 'tool4' }),
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should prioritize includedTools over filteredTools', async () => {
|
||||||
|
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||||
|
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
|
||||||
|
const response = await TestAgent.sendMessage('Test message');
|
||||||
|
expect(response.tools).toHaveLength(2);
|
||||||
|
expect(response.tools).toEqual(
|
||||||
|
expect.arrayContaining([
|
||||||
|
expect.objectContaining({ name: 'tool1' }),
|
||||||
|
expect.objectContaining({ name: 'tool2' }),
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('should not modify tools when no filters are provided', async () => {
|
||||||
|
const response = await TestAgent.sendMessage('Test message');
|
||||||
|
expect(response.tools).toHaveLength(4);
|
||||||
|
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -80,13 +80,18 @@ class StableDiffusionAPI extends StructuredTool {
|
|||||||
const payload = {
|
const payload = {
|
||||||
prompt,
|
prompt,
|
||||||
negative_prompt,
|
negative_prompt,
|
||||||
sampler_index: 'DPM++ 2M Karras',
|
|
||||||
cfg_scale: 4.5,
|
cfg_scale: 4.5,
|
||||||
steps: 22,
|
steps: 22,
|
||||||
width: 1024,
|
width: 1024,
|
||||||
height: 1024,
|
height: 1024,
|
||||||
};
|
};
|
||||||
const generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
|
let generationResponse;
|
||||||
|
try {
|
||||||
|
generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('[StableDiffusion] Error while generating image:', error);
|
||||||
|
return 'Error making API request.';
|
||||||
|
}
|
||||||
const image = generationResponse.data.images[0];
|
const image = generationResponse.data.images[0];
|
||||||
|
|
||||||
/** @type {{ height: number, width: number, seed: number, infotexts: string[] }} */
|
/** @type {{ height: number, width: number, seed: number, infotexts: string[] }} */
|
||||||
|
|||||||
17
api/cache/getLogStores.js
vendored
17
api/cache/getLogStores.js
vendored
@@ -7,6 +7,7 @@ const keyvMongo = require('./keyvMongo');
|
|||||||
|
|
||||||
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
|
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
|
||||||
const THIRTY_MINUTES = 1800000;
|
const THIRTY_MINUTES = 1800000;
|
||||||
|
const TEN_MINUTES = 600000;
|
||||||
|
|
||||||
const duration = math(BAN_DURATION, 7200000);
|
const duration = math(BAN_DURATION, 7200000);
|
||||||
|
|
||||||
@@ -24,6 +25,14 @@ const config = isEnabled(USE_REDIS)
|
|||||||
? new Keyv({ store: keyvRedis })
|
? new Keyv({ store: keyvRedis })
|
||||||
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
|
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
|
||||||
|
|
||||||
|
const roles = isEnabled(USE_REDIS)
|
||||||
|
? new Keyv({ store: keyvRedis })
|
||||||
|
: new Keyv({ namespace: CacheKeys.ROLES });
|
||||||
|
|
||||||
|
const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes
|
||||||
|
? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES })
|
||||||
|
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES });
|
||||||
|
|
||||||
const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
|
const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
|
||||||
? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
|
? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
|
||||||
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
|
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
|
||||||
@@ -41,6 +50,7 @@ const abortKeys = isEnabled(USE_REDIS)
|
|||||||
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });
|
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });
|
||||||
|
|
||||||
const namespaces = {
|
const namespaces = {
|
||||||
|
[CacheKeys.ROLES]: roles,
|
||||||
[CacheKeys.CONFIG_STORE]: config,
|
[CacheKeys.CONFIG_STORE]: config,
|
||||||
pending_req,
|
pending_req,
|
||||||
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
|
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
|
||||||
@@ -55,7 +65,13 @@ const namespaces = {
|
|||||||
message_limit: createViolationInstance('message_limit'),
|
message_limit: createViolationInstance('message_limit'),
|
||||||
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
|
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
|
||||||
registrations: createViolationInstance('registrations'),
|
registrations: createViolationInstance('registrations'),
|
||||||
|
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
|
||||||
|
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
|
||||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||||
|
[ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
|
||||||
|
[ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
|
||||||
|
ViolationTypes.RESET_PASSWORD_LIMIT,
|
||||||
|
),
|
||||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
||||||
ViolationTypes.ILLEGAL_MODEL_REQUEST,
|
ViolationTypes.ILLEGAL_MODEL_REQUEST,
|
||||||
),
|
),
|
||||||
@@ -64,6 +80,7 @@ const namespaces = {
|
|||||||
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
|
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
|
||||||
[CacheKeys.GEN_TITLE]: genTitle,
|
[CacheKeys.GEN_TITLE]: genTitle,
|
||||||
[CacheKeys.MODEL_QUERIES]: modelQueries,
|
[CacheKeys.MODEL_QUERIES]: modelQueries,
|
||||||
|
[CacheKeys.AUDIO_RUNS]: audioRuns,
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
2
api/cache/logViolation.js
vendored
2
api/cache/logViolation.js
vendored
@@ -1,6 +1,6 @@
|
|||||||
|
const { isEnabled } = require('~/server/utils');
|
||||||
const getLogStores = require('./getLogStores');
|
const getLogStores = require('./getLogStores');
|
||||||
const banViolation = require('./banViolation');
|
const banViolation = require('./banViolation');
|
||||||
const { isEnabled } = require('../server/utils');
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Logs the violation.
|
* Logs the violation.
|
||||||
|
|||||||
@@ -27,26 +27,25 @@ function getMatchingSensitivePatterns(valueStr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Redacts sensitive information from a console message.
|
* Redacts sensitive information from a console message and trims it to a specified length if provided.
|
||||||
*
|
|
||||||
* @param {string} str - The console message to be redacted.
|
* @param {string} str - The console message to be redacted.
|
||||||
* @returns {string} - The redacted console message.
|
* @param {number} [trimLength] - The optional length at which to trim the redacted message.
|
||||||
|
* @returns {string} - The redacted and optionally trimmed console message.
|
||||||
*/
|
*/
|
||||||
function redactMessage(str) {
|
function redactMessage(str, trimLength) {
|
||||||
if (!str) {
|
if (!str) {
|
||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
|
|
||||||
const patterns = getMatchingSensitivePatterns(str);
|
const patterns = getMatchingSensitivePatterns(str);
|
||||||
|
|
||||||
if (patterns.length === 0) {
|
|
||||||
return str;
|
|
||||||
}
|
|
||||||
|
|
||||||
patterns.forEach((pattern) => {
|
patterns.forEach((pattern) => {
|
||||||
str = str.replace(pattern, '$1[REDACTED]');
|
str = str.replace(pattern, '$1[REDACTED]');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (trimLength !== undefined && str.length > trimLength) {
|
||||||
|
return `${str.substring(0, trimLength)}...`;
|
||||||
|
}
|
||||||
|
|
||||||
return str;
|
return str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ const Assistant = mongoose.model('assistant', assistantSchema);
|
|||||||
* @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
|
* @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
|
||||||
* @returns {Promise<Object>} The updated or newly created assistant document as a plain object.
|
* @returns {Promise<Object>} The updated or newly created assistant document as a plain object.
|
||||||
*/
|
*/
|
||||||
const updateAssistant = async (searchParams, updateData, session = null) => {
|
const updateAssistantDoc = async (searchParams, updateData, session = null) => {
|
||||||
const options = { new: true, upsert: true, session };
|
const options = { new: true, upsert: true, session };
|
||||||
return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
|
return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
|
||||||
};
|
};
|
||||||
@@ -52,7 +52,7 @@ const deleteAssistant = async (searchParams) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
updateAssistant,
|
updateAssistantDoc,
|
||||||
deleteAssistant,
|
deleteAssistant,
|
||||||
getAssistants,
|
getAssistants,
|
||||||
getAssistant,
|
getAssistant,
|
||||||
|
|||||||
61
api/models/Categories.js
Normal file
61
api/models/Categories.js
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
const { logger } = require('~/config');
|
||||||
|
// const { Categories } = require('./schema/categories');
|
||||||
|
const options = [
|
||||||
|
{
|
||||||
|
label: '',
|
||||||
|
value: '',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'idea',
|
||||||
|
value: 'idea',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'travel',
|
||||||
|
value: 'travel',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'teach_or_explain',
|
||||||
|
value: 'teach_or_explain',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'write',
|
||||||
|
value: 'write',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'shop',
|
||||||
|
value: 'shop',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'code',
|
||||||
|
value: 'code',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'misc',
|
||||||
|
value: 'misc',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'roleplay',
|
||||||
|
value: 'roleplay',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'finance',
|
||||||
|
value: 'finance',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
/**
|
||||||
|
* Retrieves the categories asynchronously.
|
||||||
|
* @returns {Promise<TGetCategoriesResponse>} An array of category objects.
|
||||||
|
* @throws {Error} If there is an error retrieving the categories.
|
||||||
|
*/
|
||||||
|
getCategories: async () => {
|
||||||
|
try {
|
||||||
|
// const categories = await Categories.find();
|
||||||
|
return options;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error getting categories', error);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
@@ -21,16 +21,18 @@ module.exports = {
|
|||||||
Conversation,
|
Conversation,
|
||||||
saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
|
saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
|
||||||
try {
|
try {
|
||||||
const messages = await getMessages({ conversationId });
|
const messages = await getMessages({ conversationId }, '_id');
|
||||||
const update = { ...convo, messages, user };
|
const update = { ...convo, messages, user };
|
||||||
if (newConversationId) {
|
if (newConversationId) {
|
||||||
update.conversationId = newConversationId;
|
update.conversationId = newConversationId;
|
||||||
}
|
}
|
||||||
|
|
||||||
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
const conversation = await Conversation.findOneAndUpdate({ conversationId, user }, update, {
|
||||||
new: true,
|
new: true,
|
||||||
upsert: true,
|
upsert: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
return conversation.toObject();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[saveConvo] Error saving conversation', error);
|
logger.error('[saveConvo] Error saving conversation', error);
|
||||||
return { message: 'Error saving conversation' };
|
return { message: 'Error saving conversation' };
|
||||||
|
|||||||
@@ -97,8 +97,12 @@ const deleteFileByFilter = async (filter) => {
|
|||||||
* @param {Array<string>} file_ids - The unique identifiers of the files to delete.
|
* @param {Array<string>} file_ids - The unique identifiers of the files to delete.
|
||||||
* @returns {Promise<Object>} A promise that resolves to the result of the deletion operation.
|
* @returns {Promise<Object>} A promise that resolves to the result of the deletion operation.
|
||||||
*/
|
*/
|
||||||
const deleteFiles = async (file_ids) => {
|
const deleteFiles = async (file_ids, user) => {
|
||||||
return await File.deleteMany({ file_id: { $in: file_ids } });
|
let deleteQuery = { file_id: { $in: file_ids } };
|
||||||
|
if (user) {
|
||||||
|
deleteQuery = { user: user };
|
||||||
|
}
|
||||||
|
return await File.deleteMany(deleteQuery);
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
|
|||||||
@@ -57,18 +57,13 @@ module.exports = {
|
|||||||
if (files) {
|
if (files) {
|
||||||
update.files = files;
|
update.files = files;
|
||||||
}
|
}
|
||||||
// may also need to update the conversation here
|
|
||||||
await Message.findOneAndUpdate({ messageId }, update, { upsert: true, new: true });
|
|
||||||
|
|
||||||
return {
|
const message = await Message.findOneAndUpdate({ messageId }, update, {
|
||||||
messageId,
|
upsert: true,
|
||||||
conversationId,
|
new: true,
|
||||||
parentMessageId,
|
});
|
||||||
sender,
|
|
||||||
text,
|
return message.toObject();
|
||||||
isCreatedByUser,
|
|
||||||
tokenCount,
|
|
||||||
};
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.error('Error saving message:', err);
|
logger.error('Error saving message:', err);
|
||||||
throw new Error('Failed to save message.');
|
throw new Error('Failed to save message.');
|
||||||
@@ -129,6 +124,14 @@ module.exports = {
|
|||||||
throw new Error('Failed to save message.');
|
throw new Error('Failed to save message.');
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
async updateMessageText({ messageId, text }) {
|
||||||
|
try {
|
||||||
|
await Message.updateOne({ messageId }, { text });
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('Error updating message text:', err);
|
||||||
|
throw new Error('Failed to update message text.');
|
||||||
|
}
|
||||||
|
},
|
||||||
async updateMessage(message) {
|
async updateMessage(message) {
|
||||||
try {
|
try {
|
||||||
const { messageId, ...update } = message;
|
const { messageId, ...update } = message;
|
||||||
@@ -171,8 +174,18 @@ module.exports = {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
async getMessages(filter) {
|
/**
|
||||||
|
* Retrieves messages from the database.
|
||||||
|
* @param {Record<string, unknown>} filter
|
||||||
|
* @param {string | undefined} [select]
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
async getMessages(filter, select) {
|
||||||
try {
|
try {
|
||||||
|
if (select) {
|
||||||
|
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||||
|
}
|
||||||
|
|
||||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.error('Error getting messages:', err);
|
logger.error('Error getting messages:', err);
|
||||||
|
|||||||
90
api/models/Project.js
Normal file
90
api/models/Project.js
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
const { model } = require('mongoose');
|
||||||
|
const projectSchema = require('~/models/schema/projectSchema');
|
||||||
|
|
||||||
|
const Project = model('Project', projectSchema);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieve a project by ID and convert the found project document to a plain object.
|
||||||
|
*
|
||||||
|
* @param {string} projectId - The ID of the project to find and return as a plain object.
|
||||||
|
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||||
|
* @returns {Promise<MongoProject>} A plain object representing the project document, or `null` if no project is found.
|
||||||
|
*/
|
||||||
|
const getProjectById = async function (projectId, fieldsToSelect = null) {
|
||||||
|
const query = Project.findById(projectId);
|
||||||
|
|
||||||
|
if (fieldsToSelect) {
|
||||||
|
query.select(fieldsToSelect);
|
||||||
|
}
|
||||||
|
|
||||||
|
return await query.lean();
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieve a project by name and convert the found project document to a plain object.
|
||||||
|
* If the project with the given name doesn't exist and the name is "instance", create it and return the lean version.
|
||||||
|
*
|
||||||
|
* @param {string} projectName - The name of the project to find or create.
|
||||||
|
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||||
|
* @returns {Promise<MongoProject>} A plain object representing the project document.
|
||||||
|
*/
|
||||||
|
const getProjectByName = async function (projectName, fieldsToSelect = null) {
|
||||||
|
const query = { name: projectName };
|
||||||
|
const update = { $setOnInsert: { name: projectName } };
|
||||||
|
const options = {
|
||||||
|
new: true,
|
||||||
|
upsert: projectName === 'instance',
|
||||||
|
lean: true,
|
||||||
|
select: fieldsToSelect,
|
||||||
|
};
|
||||||
|
|
||||||
|
return await Project.findOneAndUpdate(query, update, options);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add an array of prompt group IDs to a project's promptGroupIds array, ensuring uniqueness.
|
||||||
|
*
|
||||||
|
* @param {string} projectId - The ID of the project to update.
|
||||||
|
* @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project.
|
||||||
|
* @returns {Promise<MongoProject>} The updated project document.
|
||||||
|
*/
|
||||||
|
const addGroupIdsToProject = async function (projectId, promptGroupIds) {
|
||||||
|
return await Project.findByIdAndUpdate(
|
||||||
|
projectId,
|
||||||
|
{ $addToSet: { promptGroupIds: { $each: promptGroupIds } } },
|
||||||
|
{ new: true },
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove an array of prompt group IDs from a project's promptGroupIds array.
|
||||||
|
*
|
||||||
|
* @param {string} projectId - The ID of the project to update.
|
||||||
|
* @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project.
|
||||||
|
* @returns {Promise<MongoProject>} The updated project document.
|
||||||
|
*/
|
||||||
|
const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
|
||||||
|
return await Project.findByIdAndUpdate(
|
||||||
|
projectId,
|
||||||
|
{ $pull: { promptGroupIds: { $in: promptGroupIds } } },
|
||||||
|
{ new: true },
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a prompt group ID from all projects.
|
||||||
|
*
|
||||||
|
* @param {string} promptGroupId - The ID of the prompt group to remove from projects.
|
||||||
|
* @returns {Promise<void>}
|
||||||
|
*/
|
||||||
|
const removeGroupFromAllProjects = async (promptGroupId) => {
|
||||||
|
await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
getProjectById,
|
||||||
|
getProjectByName,
|
||||||
|
addGroupIdsToProject,
|
||||||
|
removeGroupIdsFromProject,
|
||||||
|
removeGroupFromAllProjects,
|
||||||
|
};
|
||||||
@@ -1,52 +1,528 @@
|
|||||||
const mongoose = require('mongoose');
|
const { ObjectId } = require('mongodb');
|
||||||
|
const { SystemRoles, SystemCategories } = require('librechat-data-provider');
|
||||||
|
const {
|
||||||
|
getProjectByName,
|
||||||
|
addGroupIdsToProject,
|
||||||
|
removeGroupIdsFromProject,
|
||||||
|
removeGroupFromAllProjects,
|
||||||
|
} = require('./Project');
|
||||||
|
const { Prompt, PromptGroup } = require('./schema/promptSchema');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const promptSchema = mongoose.Schema(
|
/**
|
||||||
{
|
* Create a pipeline for the aggregation to get prompt groups
|
||||||
title: {
|
* @param {Object} query
|
||||||
type: String,
|
* @param {number} skip
|
||||||
required: true,
|
* @param {number} limit
|
||||||
|
* @returns {[Object]} - The pipeline for the aggregation
|
||||||
|
*/
|
||||||
|
const createGroupPipeline = (query, skip, limit) => {
|
||||||
|
return [
|
||||||
|
{ $match: query },
|
||||||
|
{ $sort: { createdAt: -1 } },
|
||||||
|
{ $skip: skip },
|
||||||
|
{ $limit: limit },
|
||||||
|
{
|
||||||
|
$lookup: {
|
||||||
|
from: 'prompts',
|
||||||
|
localField: 'productionId',
|
||||||
|
foreignField: '_id',
|
||||||
|
as: 'productionPrompt',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
prompt: {
|
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||||
type: String,
|
{
|
||||||
required: true,
|
$project: {
|
||||||
|
name: 1,
|
||||||
|
numberOfGenerations: 1,
|
||||||
|
oneliner: 1,
|
||||||
|
category: 1,
|
||||||
|
projectIds: 1,
|
||||||
|
productionId: 1,
|
||||||
|
author: 1,
|
||||||
|
authorName: 1,
|
||||||
|
createdAt: 1,
|
||||||
|
updatedAt: 1,
|
||||||
|
'productionPrompt.prompt': 1,
|
||||||
|
// 'productionPrompt._id': 1,
|
||||||
|
// 'productionPrompt.type': 1,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
category: {
|
];
|
||||||
type: String,
|
};
|
||||||
},
|
|
||||||
},
|
|
||||||
{ timestamps: true },
|
|
||||||
);
|
|
||||||
|
|
||||||
const Prompt = mongoose.models.Prompt || mongoose.model('Prompt', promptSchema);
|
/**
|
||||||
|
* Create a pipeline for the aggregation to get all prompt groups
|
||||||
|
* @param {Object} query
|
||||||
|
* @param {Partial<MongoPromptGroup>} $project
|
||||||
|
* @returns {[Object]} - The pipeline for the aggregation
|
||||||
|
*/
|
||||||
|
const createAllGroupsPipeline = (
|
||||||
|
query,
|
||||||
|
$project = {
|
||||||
|
name: 1,
|
||||||
|
oneliner: 1,
|
||||||
|
category: 1,
|
||||||
|
author: 1,
|
||||||
|
authorName: 1,
|
||||||
|
createdAt: 1,
|
||||||
|
updatedAt: 1,
|
||||||
|
command: 1,
|
||||||
|
'productionPrompt.prompt': 1,
|
||||||
|
},
|
||||||
|
) => {
|
||||||
|
return [
|
||||||
|
{ $match: query },
|
||||||
|
{ $sort: { createdAt: -1 } },
|
||||||
|
{
|
||||||
|
$lookup: {
|
||||||
|
from: 'prompts',
|
||||||
|
localField: 'productionId',
|
||||||
|
foreignField: '_id',
|
||||||
|
as: 'productionPrompt',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||||
|
{
|
||||||
|
$project,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all prompt groups with filters
|
||||||
|
* @param {Object} req
|
||||||
|
* @param {TPromptGroupsWithFilterRequest} filter
|
||||||
|
* @returns {Promise<PromptGroupListResponse>}
|
||||||
|
*/
|
||||||
|
const getAllPromptGroups = async (req, filter) => {
|
||||||
|
try {
|
||||||
|
const { name, ...query } = filter;
|
||||||
|
|
||||||
|
if (!query.author) {
|
||||||
|
throw new Error('Author is required');
|
||||||
|
}
|
||||||
|
|
||||||
|
let searchShared = true;
|
||||||
|
let searchSharedOnly = false;
|
||||||
|
if (name) {
|
||||||
|
query.name = new RegExp(name, 'i');
|
||||||
|
}
|
||||||
|
if (!query.category) {
|
||||||
|
delete query.category;
|
||||||
|
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||||
|
searchShared = false;
|
||||||
|
delete query.category;
|
||||||
|
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||||
|
query.category = '';
|
||||||
|
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||||
|
searchSharedOnly = true;
|
||||||
|
delete query.category;
|
||||||
|
}
|
||||||
|
|
||||||
|
let combinedQuery = query;
|
||||||
|
|
||||||
|
if (searchShared) {
|
||||||
|
const project = await getProjectByName('instance', 'promptGroupIds');
|
||||||
|
if (project && project.promptGroupIds.length > 0) {
|
||||||
|
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||||
|
delete projectQuery.author;
|
||||||
|
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const promptGroupsPipeline = createAllGroupsPipeline(combinedQuery);
|
||||||
|
return await PromptGroup.aggregate(promptGroupsPipeline).exec();
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error getting all prompt groups', error);
|
||||||
|
return { message: 'Error getting all prompt groups' };
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get prompt groups with filters
|
||||||
|
* @param {Object} req
|
||||||
|
* @param {TPromptGroupsWithFilterRequest} filter
|
||||||
|
* @returns {Promise<PromptGroupListResponse>}
|
||||||
|
*/
|
||||||
|
const getPromptGroups = async (req, filter) => {
|
||||||
|
try {
|
||||||
|
const { pageNumber = 1, pageSize = 10, name, ...query } = filter;
|
||||||
|
|
||||||
|
const validatedPageNumber = Math.max(parseInt(pageNumber, 10), 1);
|
||||||
|
const validatedPageSize = Math.max(parseInt(pageSize, 10), 1);
|
||||||
|
|
||||||
|
if (!query.author) {
|
||||||
|
throw new Error('Author is required');
|
||||||
|
}
|
||||||
|
|
||||||
|
let searchShared = true;
|
||||||
|
let searchSharedOnly = false;
|
||||||
|
if (name) {
|
||||||
|
query.name = new RegExp(name, 'i');
|
||||||
|
}
|
||||||
|
if (!query.category) {
|
||||||
|
delete query.category;
|
||||||
|
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||||
|
searchShared = false;
|
||||||
|
delete query.category;
|
||||||
|
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||||
|
query.category = '';
|
||||||
|
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||||
|
searchSharedOnly = true;
|
||||||
|
delete query.category;
|
||||||
|
}
|
||||||
|
|
||||||
|
let combinedQuery = query;
|
||||||
|
|
||||||
|
if (searchShared) {
|
||||||
|
// const projects = req.user.projects || []; // TODO: handle multiple projects
|
||||||
|
const project = await getProjectByName('instance', 'promptGroupIds');
|
||||||
|
if (project && project.promptGroupIds.length > 0) {
|
||||||
|
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||||
|
delete projectQuery.author;
|
||||||
|
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const skip = (validatedPageNumber - 1) * validatedPageSize;
|
||||||
|
const limit = validatedPageSize;
|
||||||
|
|
||||||
|
const promptGroupsPipeline = createGroupPipeline(combinedQuery, skip, limit);
|
||||||
|
const totalPromptGroupsPipeline = [{ $match: combinedQuery }, { $count: 'total' }];
|
||||||
|
|
||||||
|
const [promptGroupsResults, totalPromptGroupsResults] = await Promise.all([
|
||||||
|
PromptGroup.aggregate(promptGroupsPipeline).exec(),
|
||||||
|
PromptGroup.aggregate(totalPromptGroupsPipeline).exec(),
|
||||||
|
]);
|
||||||
|
|
||||||
|
const promptGroups = promptGroupsResults;
|
||||||
|
const totalPromptGroups =
|
||||||
|
totalPromptGroupsResults.length > 0 ? totalPromptGroupsResults[0].total : 0;
|
||||||
|
|
||||||
|
return {
|
||||||
|
promptGroups,
|
||||||
|
pageNumber: validatedPageNumber.toString(),
|
||||||
|
pageSize: validatedPageSize.toString(),
|
||||||
|
pages: Math.ceil(totalPromptGroups / validatedPageSize).toString(),
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error getting prompt groups', error);
|
||||||
|
return { message: 'Error getting prompt groups' };
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
savePrompt: async ({ title, prompt }) => {
|
getPromptGroups,
|
||||||
|
getAllPromptGroups,
|
||||||
|
/**
|
||||||
|
* Create a prompt and its respective group
|
||||||
|
* @param {TCreatePromptRecord} saveData
|
||||||
|
* @returns {Promise<TCreatePromptResponse>}
|
||||||
|
*/
|
||||||
|
createPromptGroup: async (saveData) => {
|
||||||
try {
|
try {
|
||||||
await Prompt.create({
|
const { prompt, group, author, authorName } = saveData;
|
||||||
title,
|
|
||||||
prompt,
|
let newPromptGroup = await PromptGroup.findOneAndUpdate(
|
||||||
});
|
{ ...group, author, authorName, productionId: null },
|
||||||
return { title, prompt };
|
{ $setOnInsert: { ...group, author, authorName, productionId: null } },
|
||||||
|
{ new: true, upsert: true },
|
||||||
|
)
|
||||||
|
.lean()
|
||||||
|
.select('-__v')
|
||||||
|
.exec();
|
||||||
|
|
||||||
|
const newPrompt = await Prompt.findOneAndUpdate(
|
||||||
|
{ ...prompt, author, groupId: newPromptGroup._id },
|
||||||
|
{ $setOnInsert: { ...prompt, author, groupId: newPromptGroup._id } },
|
||||||
|
{ new: true, upsert: true },
|
||||||
|
)
|
||||||
|
.lean()
|
||||||
|
.select('-__v')
|
||||||
|
.exec();
|
||||||
|
|
||||||
|
newPromptGroup = await PromptGroup.findByIdAndUpdate(
|
||||||
|
newPromptGroup._id,
|
||||||
|
{ productionId: newPrompt._id },
|
||||||
|
{ new: true },
|
||||||
|
)
|
||||||
|
.lean()
|
||||||
|
.select('-__v')
|
||||||
|
.exec();
|
||||||
|
|
||||||
|
return {
|
||||||
|
prompt: newPrompt,
|
||||||
|
group: {
|
||||||
|
...newPromptGroup,
|
||||||
|
productionPrompt: { prompt: newPrompt.prompt },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error saving prompt group', error);
|
||||||
|
throw new Error('Error saving prompt group');
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Save a prompt
|
||||||
|
* @param {TCreatePromptRecord} saveData
|
||||||
|
* @returns {Promise<TCreatePromptResponse>}
|
||||||
|
*/
|
||||||
|
savePrompt: async (saveData) => {
|
||||||
|
try {
|
||||||
|
const { prompt, author } = saveData;
|
||||||
|
const newPromptData = {
|
||||||
|
...prompt,
|
||||||
|
author,
|
||||||
|
};
|
||||||
|
|
||||||
|
/** @type {TPrompt} */
|
||||||
|
let newPrompt;
|
||||||
|
try {
|
||||||
|
newPrompt = await Prompt.create(newPromptData);
|
||||||
|
} catch (error) {
|
||||||
|
if (error?.message?.includes('groupId_1_version_1')) {
|
||||||
|
await Prompt.db.collection('prompts').dropIndex('groupId_1_version_1');
|
||||||
|
} else {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
newPrompt = await Prompt.create(newPromptData);
|
||||||
|
}
|
||||||
|
|
||||||
|
return { prompt: newPrompt };
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Error saving prompt', error);
|
logger.error('Error saving prompt', error);
|
||||||
return { prompt: 'Error saving prompt' };
|
return { message: 'Error saving prompt' };
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
getPrompts: async (filter) => {
|
getPrompts: async (filter) => {
|
||||||
try {
|
try {
|
||||||
return await Prompt.find(filter).lean();
|
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Error getting prompts', error);
|
logger.error('Error getting prompts', error);
|
||||||
return { prompt: 'Error getting prompts' };
|
return { message: 'Error getting prompts' };
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
deletePrompts: async (filter) => {
|
getPrompt: async (filter) => {
|
||||||
try {
|
try {
|
||||||
return await Prompt.deleteMany(filter);
|
if (filter.groupId) {
|
||||||
|
filter.groupId = new ObjectId(filter.groupId);
|
||||||
|
}
|
||||||
|
return await Prompt.findOne(filter).lean();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Error deleting prompts', error);
|
logger.error('Error getting prompt', error);
|
||||||
return { prompt: 'Error deleting prompts' };
|
return { message: 'Error getting prompt' };
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Get prompt groups with filters
|
||||||
|
* @param {TGetRandomPromptsRequest} filter
|
||||||
|
* @returns {Promise<TGetRandomPromptsResponse>}
|
||||||
|
*/
|
||||||
|
getRandomPromptGroups: async (filter) => {
|
||||||
|
try {
|
||||||
|
const result = await PromptGroup.aggregate([
|
||||||
|
{
|
||||||
|
$match: {
|
||||||
|
category: { $ne: '' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$group: {
|
||||||
|
_id: '$category',
|
||||||
|
promptGroup: { $first: '$$ROOT' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$replaceRoot: { newRoot: '$promptGroup' },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$sample: { size: +filter.limit + +filter.skip },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$skip: +filter.skip,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
$limit: +filter.limit,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
return { prompts: result };
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error getting prompt groups', error);
|
||||||
|
return { message: 'Error getting prompt groups' };
|
||||||
|
}
|
||||||
|
},
|
||||||
|
getPromptGroupsWithPrompts: async (filter) => {
|
||||||
|
try {
|
||||||
|
return await PromptGroup.findOne(filter)
|
||||||
|
.populate({
|
||||||
|
path: 'prompts',
|
||||||
|
select: '-_id -__v -user',
|
||||||
|
})
|
||||||
|
.select('-_id -__v -user')
|
||||||
|
.lean();
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error getting prompt groups', error);
|
||||||
|
return { message: 'Error getting prompt groups' };
|
||||||
|
}
|
||||||
|
},
|
||||||
|
getPromptGroup: async (filter) => {
|
||||||
|
try {
|
||||||
|
return await PromptGroup.findOne(filter).lean();
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error getting prompt group', error);
|
||||||
|
return { message: 'Error getting prompt group' };
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Deletes a prompt and its corresponding prompt group if it is the last prompt in the group.
|
||||||
|
*
|
||||||
|
* @param {Object} options - The options for deleting the prompt.
|
||||||
|
* @param {ObjectId|string} options.promptId - The ID of the prompt to delete.
|
||||||
|
* @param {ObjectId|string} options.groupId - The ID of the prompt's group.
|
||||||
|
* @param {ObjectId|string} options.author - The ID of the prompt's author.
|
||||||
|
* @param {string} options.role - The role of the prompt's author.
|
||||||
|
* @return {Promise<TDeletePromptResponse>} An object containing the result of the deletion.
|
||||||
|
* If the prompt was deleted successfully, the object will have a property 'prompt' with the value 'Prompt deleted successfully'.
|
||||||
|
* If the prompt group was deleted successfully, the object will have a property 'promptGroup' with the message 'Prompt group deleted successfully' and id of the deleted group.
|
||||||
|
* If there was an error deleting the prompt, the object will have a property 'message' with the value 'Error deleting prompt'.
|
||||||
|
*/
|
||||||
|
deletePrompt: async ({ promptId, groupId, author, role }) => {
|
||||||
|
const query = { _id: promptId, groupId, author };
|
||||||
|
if (role === SystemRoles.ADMIN) {
|
||||||
|
delete query.author;
|
||||||
|
}
|
||||||
|
const { deletedCount } = await Prompt.deleteOne(query);
|
||||||
|
if (deletedCount === 0) {
|
||||||
|
throw new Error('Failed to delete the prompt');
|
||||||
|
}
|
||||||
|
|
||||||
|
const remainingPrompts = await Prompt.find({ groupId })
|
||||||
|
.select('_id')
|
||||||
|
.sort({ createdAt: 1 })
|
||||||
|
.lean();
|
||||||
|
|
||||||
|
if (remainingPrompts.length === 0) {
|
||||||
|
await PromptGroup.deleteOne({ _id: groupId });
|
||||||
|
await removeGroupFromAllProjects(groupId);
|
||||||
|
|
||||||
|
return {
|
||||||
|
prompt: 'Prompt deleted successfully',
|
||||||
|
promptGroup: {
|
||||||
|
message: 'Prompt group deleted successfully',
|
||||||
|
id: groupId,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
const promptGroup = await PromptGroup.findById(groupId).lean();
|
||||||
|
if (promptGroup.productionId.toString() === promptId.toString()) {
|
||||||
|
await PromptGroup.updateOne(
|
||||||
|
{ _id: groupId },
|
||||||
|
{ productionId: remainingPrompts[remainingPrompts.length - 1]._id },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return { prompt: 'Prompt deleted successfully' };
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Update prompt group
|
||||||
|
* @param {Partial<MongoPromptGroup>} filter - Filter to find prompt group
|
||||||
|
* @param {Partial<MongoPromptGroup>} data - Data to update
|
||||||
|
* @returns {Promise<TUpdatePromptGroupResponse>}
|
||||||
|
*/
|
||||||
|
updatePromptGroup: async (filter, data) => {
|
||||||
|
try {
|
||||||
|
const updateOps = {};
|
||||||
|
if (data.removeProjectIds) {
|
||||||
|
for (const projectId of data.removeProjectIds) {
|
||||||
|
await removeGroupIdsFromProject(projectId, [filter._id]);
|
||||||
|
}
|
||||||
|
|
||||||
|
updateOps.$pull = { projectIds: { $in: data.removeProjectIds } };
|
||||||
|
delete data.removeProjectIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data.projectIds) {
|
||||||
|
for (const projectId of data.projectIds) {
|
||||||
|
await addGroupIdsToProject(projectId, [filter._id]);
|
||||||
|
}
|
||||||
|
|
||||||
|
updateOps.$addToSet = { projectIds: { $each: data.projectIds } };
|
||||||
|
delete data.projectIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateData = { ...data, ...updateOps };
|
||||||
|
const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
|
||||||
|
new: true,
|
||||||
|
upsert: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!updatedDoc) {
|
||||||
|
throw new Error('Prompt group not found');
|
||||||
|
}
|
||||||
|
|
||||||
|
return updatedDoc;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error updating prompt group', error);
|
||||||
|
return { message: 'Error updating prompt group' };
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Function to make a prompt production based on its ID.
|
||||||
|
* @param {String} promptId - The ID of the prompt to make production.
|
||||||
|
* @returns {Object} The result of the production operation.
|
||||||
|
*/
|
||||||
|
makePromptProduction: async (promptId) => {
|
||||||
|
try {
|
||||||
|
const prompt = await Prompt.findById(promptId).lean();
|
||||||
|
|
||||||
|
if (!prompt) {
|
||||||
|
throw new Error('Prompt not found');
|
||||||
|
}
|
||||||
|
|
||||||
|
await PromptGroup.findByIdAndUpdate(
|
||||||
|
prompt.groupId,
|
||||||
|
{ productionId: prompt._id },
|
||||||
|
{ new: true },
|
||||||
|
)
|
||||||
|
.lean()
|
||||||
|
.exec();
|
||||||
|
|
||||||
|
return {
|
||||||
|
message: 'Prompt production made successfully',
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error making prompt production', error);
|
||||||
|
return { message: 'Error making prompt production' };
|
||||||
|
}
|
||||||
|
},
|
||||||
|
updatePromptLabels: async (_id, labels) => {
|
||||||
|
try {
|
||||||
|
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
|
||||||
|
if (response.matchedCount === 0) {
|
||||||
|
return { message: 'Prompt not found' };
|
||||||
|
}
|
||||||
|
return { message: 'Prompt labels updated successfully' };
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error updating prompt labels', error);
|
||||||
|
return { message: 'Error updating prompt labels' };
|
||||||
|
}
|
||||||
|
},
|
||||||
|
deletePromptGroup: async (_id) => {
|
||||||
|
try {
|
||||||
|
const response = await PromptGroup.deleteOne({ _id });
|
||||||
|
|
||||||
|
if (response.deletedCount === 0) {
|
||||||
|
return { promptGroup: 'Prompt group not found' };
|
||||||
|
}
|
||||||
|
|
||||||
|
await Prompt.deleteMany({ groupId: new ObjectId(_id) });
|
||||||
|
await removeGroupFromAllProjects(_id);
|
||||||
|
return { promptGroup: 'Prompt group deleted successfully' };
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error deleting prompt group', error);
|
||||||
|
return { message: 'Error deleting prompt group' };
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|||||||
86
api/models/Role.js
Normal file
86
api/models/Role.js
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
const { SystemRoles, CacheKeys, roleDefaults } = require('librechat-data-provider');
|
||||||
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
|
const Role = require('~/models/schema/roleSchema');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieve a role by name and convert the found role document to a plain object.
|
||||||
|
* If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version.
|
||||||
|
*
|
||||||
|
* @param {string} roleName - The name of the role to find or create.
|
||||||
|
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||||
|
* @returns {Promise<Object>} A plain object representing the role document.
|
||||||
|
*/
|
||||||
|
const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
||||||
|
try {
|
||||||
|
const cache = getLogStores(CacheKeys.ROLES);
|
||||||
|
const cachedRole = await cache.get(roleName);
|
||||||
|
if (cachedRole) {
|
||||||
|
return cachedRole;
|
||||||
|
}
|
||||||
|
let query = Role.findOne({ name: roleName });
|
||||||
|
if (fieldsToSelect) {
|
||||||
|
query = query.select(fieldsToSelect);
|
||||||
|
}
|
||||||
|
let role = await query.lean().exec();
|
||||||
|
|
||||||
|
if (!role && SystemRoles[roleName]) {
|
||||||
|
role = roleDefaults[roleName];
|
||||||
|
role = await new Role(role).save();
|
||||||
|
await cache.set(roleName, role);
|
||||||
|
return role.toObject();
|
||||||
|
}
|
||||||
|
await cache.set(roleName, role);
|
||||||
|
return role;
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Failed to retrieve or create role: ${error.message}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update role values by name.
|
||||||
|
*
|
||||||
|
* @param {string} roleName - The name of the role to update.
|
||||||
|
* @param {Partial<TRole>} updates - The fields to update.
|
||||||
|
* @returns {Promise<TRole>} Updated role document.
|
||||||
|
*/
|
||||||
|
const updateRoleByName = async function (roleName, updates) {
|
||||||
|
try {
|
||||||
|
const cache = getLogStores(CacheKeys.ROLES);
|
||||||
|
const role = await Role.findOneAndUpdate(
|
||||||
|
{ name: roleName },
|
||||||
|
{ $set: updates },
|
||||||
|
{ new: true, lean: true },
|
||||||
|
)
|
||||||
|
.select('-__v')
|
||||||
|
.lean()
|
||||||
|
.exec();
|
||||||
|
await cache.set(roleName, role);
|
||||||
|
return role;
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Failed to update role: ${error.message}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize default roles in the system.
|
||||||
|
* Creates the default roles (ADMIN, USER) if they don't exist in the database.
|
||||||
|
*
|
||||||
|
* @returns {Promise<void>}
|
||||||
|
*/
|
||||||
|
const initializeRoles = async function () {
|
||||||
|
const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER];
|
||||||
|
|
||||||
|
for (const roleName of defaultRoles) {
|
||||||
|
let role = await Role.findOne({ name: roleName }).select('name').lean();
|
||||||
|
if (!role) {
|
||||||
|
role = new Role(roleDefaults[roleName]);
|
||||||
|
await role.save();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
getRoleByName,
|
||||||
|
initializeRoles,
|
||||||
|
updateRoleByName,
|
||||||
|
};
|
||||||
@@ -22,7 +22,7 @@ module.exports = {
|
|||||||
return share;
|
return share;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[getShare] Error getting share link', error);
|
logger.error('[getShare] Error getting share link', error);
|
||||||
return { message: 'Error getting share link' };
|
throw new Error('Error getting share link');
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -41,17 +41,17 @@ module.exports = {
|
|||||||
return { sharedLinks: shares, pages: totalPages, pageNumber, pageSize };
|
return { sharedLinks: shares, pages: totalPages, pageNumber, pageSize };
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[getShareByPage] Error getting shares', error);
|
logger.error('[getShareByPage] Error getting shares', error);
|
||||||
return { message: 'Error getting shares' };
|
throw new Error('Error getting shares');
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
createSharedLink: async (user, { conversationId, ...shareData }) => {
|
createSharedLink: async (user, { conversationId, ...shareData }) => {
|
||||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
|
||||||
if (share) {
|
|
||||||
return share;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||||
|
if (share) {
|
||||||
|
return share;
|
||||||
|
}
|
||||||
|
|
||||||
const shareId = crypto.randomUUID();
|
const shareId = crypto.randomUUID();
|
||||||
const messages = await getMessages({ conversationId });
|
const messages = await getMessages({ conversationId });
|
||||||
const update = { ...shareData, shareId, messages, user };
|
const update = { ...shareData, shareId, messages, user };
|
||||||
@@ -60,30 +60,58 @@ module.exports = {
|
|||||||
upsert: true,
|
upsert: true,
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[saveShareMessage] Error saving conversation', error);
|
logger.error('[createSharedLink] Error creating shared link', error);
|
||||||
return { message: 'Error saving conversation' };
|
throw new Error('Error creating shared link');
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
updateSharedLink: async (user, { conversationId, ...shareData }) => {
|
updateSharedLink: async (user, { conversationId, ...shareData }) => {
|
||||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
try {
|
||||||
if (!share) {
|
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||||
return { message: 'Share not found' };
|
if (!share) {
|
||||||
|
return { message: 'Share not found' };
|
||||||
|
}
|
||||||
|
|
||||||
|
// update messages to the latest
|
||||||
|
const messages = await getMessages({ conversationId });
|
||||||
|
const update = { ...shareData, messages, user };
|
||||||
|
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||||
|
new: true,
|
||||||
|
upsert: false,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('[updateSharedLink] Error updating shared link', error);
|
||||||
|
throw new Error('Error updating shared link');
|
||||||
}
|
}
|
||||||
// update messages to the latest
|
|
||||||
const messages = await getMessages({ conversationId });
|
|
||||||
const update = { ...shareData, messages, user };
|
|
||||||
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
|
||||||
new: true,
|
|
||||||
upsert: false,
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
|
|
||||||
deleteSharedLink: async (user, { shareId }) => {
|
deleteSharedLink: async (user, { shareId }) => {
|
||||||
const share = await SharedLink.findOne({ shareId, user });
|
try {
|
||||||
if (!share) {
|
const share = await SharedLink.findOne({ shareId, user });
|
||||||
return { message: 'Share not found' };
|
if (!share) {
|
||||||
|
return { message: 'Share not found' };
|
||||||
|
}
|
||||||
|
return await SharedLink.findOneAndDelete({ shareId, user });
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('[deleteSharedLink] Error deleting shared link', error);
|
||||||
|
throw new Error('Error deleting shared link');
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Deletes all shared links for a specific user.
|
||||||
|
* @param {string} user - The user ID.
|
||||||
|
* @returns {Promise<{ message: string, deletedCount?: number }>} A result object indicating success or error message.
|
||||||
|
*/
|
||||||
|
deleteAllSharedLinks: async (user) => {
|
||||||
|
try {
|
||||||
|
const result = await SharedLink.deleteMany({ user });
|
||||||
|
return {
|
||||||
|
message: 'All shared links have been deleted successfully',
|
||||||
|
deletedCount: result.deletedCount,
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('[deleteAllSharedLinks] Error deleting shared links', error);
|
||||||
|
throw new Error('Error deleting shared links');
|
||||||
}
|
}
|
||||||
return await SharedLink.findOneAndDelete({ shareId, user });
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,61 +1,5 @@
|
|||||||
const mongoose = require('mongoose');
|
const mongoose = require('mongoose');
|
||||||
const bcrypt = require('bcryptjs');
|
const userSchema = require('~/models/schema/userSchema');
|
||||||
const signPayload = require('../server/services/signPayload');
|
|
||||||
const userSchema = require('./schema/userSchema.js');
|
|
||||||
const { SESSION_EXPIRY } = process.env ?? {};
|
|
||||||
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
|
|
||||||
|
|
||||||
userSchema.methods.toJSON = function () {
|
|
||||||
return {
|
|
||||||
id: this._id,
|
|
||||||
provider: this.provider,
|
|
||||||
email: this.email,
|
|
||||||
name: this.name,
|
|
||||||
username: this.username,
|
|
||||||
avatar: this.avatar,
|
|
||||||
role: this.role,
|
|
||||||
emailVerified: this.emailVerified,
|
|
||||||
plugins: this.plugins,
|
|
||||||
createdAt: this.createdAt,
|
|
||||||
updatedAt: this.updatedAt,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
userSchema.methods.generateToken = async function () {
|
|
||||||
return await signPayload({
|
|
||||||
payload: {
|
|
||||||
id: this._id,
|
|
||||||
username: this.username,
|
|
||||||
provider: this.provider,
|
|
||||||
email: this.email,
|
|
||||||
},
|
|
||||||
secret: process.env.JWT_SECRET,
|
|
||||||
expirationTime: expires / 1000,
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
userSchema.methods.comparePassword = function (candidatePassword, callback) {
|
|
||||||
bcrypt.compare(candidatePassword, this.password, (err, isMatch) => {
|
|
||||||
if (err) {
|
|
||||||
return callback(err);
|
|
||||||
}
|
|
||||||
callback(null, isMatch);
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
module.exports.hashPassword = async (password) => {
|
|
||||||
const hashedPassword = await new Promise((resolve, reject) => {
|
|
||||||
bcrypt.hash(password, 10, function (err, hash) {
|
|
||||||
if (err) {
|
|
||||||
reject(err);
|
|
||||||
} else {
|
|
||||||
resolve(hash);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
return hashedPassword;
|
|
||||||
};
|
|
||||||
|
|
||||||
const User = mongoose.model('User', userSchema);
|
const User = mongoose.model('User', userSchema);
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,18 @@ const {
|
|||||||
deleteMessagesSince,
|
deleteMessagesSince,
|
||||||
deleteMessages,
|
deleteMessages,
|
||||||
} = require('./Message');
|
} = require('./Message');
|
||||||
|
const {
|
||||||
|
comparePassword,
|
||||||
|
deleteUserById,
|
||||||
|
generateToken,
|
||||||
|
getUserById,
|
||||||
|
updateUser,
|
||||||
|
createUser,
|
||||||
|
countUsers,
|
||||||
|
findUser,
|
||||||
|
} = require('./userMethods');
|
||||||
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
|
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
|
||||||
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
||||||
const { hashPassword, getUser, updateUser } = require('./userMethods');
|
|
||||||
const {
|
const {
|
||||||
findFileById,
|
findFileById,
|
||||||
createFile,
|
createFile,
|
||||||
@@ -29,9 +38,14 @@ module.exports = {
|
|||||||
Session,
|
Session,
|
||||||
Balance,
|
Balance,
|
||||||
|
|
||||||
hashPassword,
|
comparePassword,
|
||||||
|
deleteUserById,
|
||||||
|
generateToken,
|
||||||
|
getUserById,
|
||||||
|
countUsers,
|
||||||
|
createUser,
|
||||||
updateUser,
|
updateUser,
|
||||||
getUser,
|
findUser,
|
||||||
|
|
||||||
getMessages,
|
getMessages,
|
||||||
saveMessage,
|
saveMessage,
|
||||||
|
|||||||
19
api/models/schema/categories.js
Normal file
19
api/models/schema/categories.js
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
const mongoose = require('mongoose');
|
||||||
|
const Schema = mongoose.Schema;
|
||||||
|
|
||||||
|
const categoriesSchema = new Schema({
|
||||||
|
label: {
|
||||||
|
type: String,
|
||||||
|
required: true,
|
||||||
|
unique: true,
|
||||||
|
},
|
||||||
|
value: {
|
||||||
|
type: String,
|
||||||
|
required: true,
|
||||||
|
unique: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const categories = mongoose.model('categories', categoriesSchema);
|
||||||
|
|
||||||
|
module.exports = { Categories: categories };
|
||||||
@@ -3,9 +3,9 @@ const mongoose = require('mongoose');
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @typedef {Object} MongoFile
|
* @typedef {Object} MongoFile
|
||||||
* @property {mongoose.Schema.Types.ObjectId} [_id] - MongoDB Document ID
|
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||||
* @property {number} [__v] - MongoDB Version Key
|
* @property {number} [__v] - MongoDB Version Key
|
||||||
* @property {mongoose.Schema.Types.ObjectId} user - User ID
|
* @property {ObjectId} user - User ID
|
||||||
* @property {string} [conversationId] - Optional conversation ID
|
* @property {string} [conversationId] - Optional conversation ID
|
||||||
* @property {string} file_id - File identifier
|
* @property {string} file_id - File identifier
|
||||||
* @property {string} [temp_file_id] - Temporary File identifier
|
* @property {string} [temp_file_id] - Temporary File identifier
|
||||||
@@ -14,17 +14,19 @@ const mongoose = require('mongoose');
|
|||||||
* @property {string} filepath - Location of the file
|
* @property {string} filepath - Location of the file
|
||||||
* @property {'file'} object - Type of object, always 'file'
|
* @property {'file'} object - Type of object, always 'file'
|
||||||
* @property {string} type - Type of file
|
* @property {string} type - Type of file
|
||||||
* @property {number} usage - Number of uses of the file
|
* @property {number} [usage=0] - Number of uses of the file
|
||||||
* @property {string} [context] - Context of the file origin
|
* @property {string} [context] - Context of the file origin
|
||||||
* @property {boolean} [embedded] - Whether or not the file is embedded in vector db
|
* @property {boolean} [embedded=false] - Whether or not the file is embedded in vector db
|
||||||
* @property {string} [model] - The model to identify the group region of the file (for Azure OpenAI hosting)
|
* @property {string} [model] - The model to identify the group region of the file (for Azure OpenAI hosting)
|
||||||
* @property {string} [source] - The source of the file
|
* @property {string} [source] - The source of the file (e.g., from FileSources)
|
||||||
* @property {number} [width] - Optional width of the file
|
* @property {number} [width] - Optional width of the file
|
||||||
* @property {number} [height] - Optional height of the file
|
* @property {number} [height] - Optional height of the file
|
||||||
* @property {Date} [expiresAt] - Optional height of the file
|
* @property {Date} [expiresAt] - Optional expiration date of the file
|
||||||
* @property {Date} [createdAt] - Date when the file was created
|
* @property {Date} [createdAt] - Date when the file was created
|
||||||
* @property {Date} [updatedAt] - Date when the file was updated
|
* @property {Date} [updatedAt] - Date when the file was updated
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
/** @type {MongooseSchema<MongoFile>} */
|
||||||
const fileSchema = mongoose.Schema(
|
const fileSchema = mongoose.Schema(
|
||||||
{
|
{
|
||||||
user: {
|
user: {
|
||||||
@@ -91,7 +93,7 @@ const fileSchema = mongoose.Schema(
|
|||||||
height: Number,
|
height: Number,
|
||||||
expiresAt: {
|
expiresAt: {
|
||||||
type: Date,
|
type: Date,
|
||||||
expires: 3600,
|
expires: 3600, // 1 hour in seconds
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ const messageSchema = mongoose.Schema(
|
|||||||
},
|
},
|
||||||
conversationId: {
|
conversationId: {
|
||||||
type: String,
|
type: String,
|
||||||
|
index: true,
|
||||||
required: true,
|
required: true,
|
||||||
meiliIndex: true,
|
meiliIndex: true,
|
||||||
},
|
},
|
||||||
|
|||||||
30
api/models/schema/projectSchema.js
Normal file
30
api/models/schema/projectSchema.js
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
const { Schema } = require('mongoose');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @typedef {Object} MongoProject
|
||||||
|
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||||
|
* @property {string} name - The name of the project
|
||||||
|
* @property {ObjectId[]} promptGroupIds - Array of PromptGroup IDs associated with the project
|
||||||
|
* @property {Date} [createdAt] - Date when the project was created (added by timestamps)
|
||||||
|
* @property {Date} [updatedAt] - Date when the project was last updated (added by timestamps)
|
||||||
|
*/
|
||||||
|
|
||||||
|
const projectSchema = new Schema(
|
||||||
|
{
|
||||||
|
name: {
|
||||||
|
type: String,
|
||||||
|
required: true,
|
||||||
|
index: true,
|
||||||
|
},
|
||||||
|
promptGroupIds: {
|
||||||
|
type: [Schema.Types.ObjectId],
|
||||||
|
ref: 'PromptGroup',
|
||||||
|
default: [],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
timestamps: true,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
module.exports = projectSchema;
|
||||||
118
api/models/schema/promptSchema.js
Normal file
118
api/models/schema/promptSchema.js
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
const mongoose = require('mongoose');
|
||||||
|
const { Constants } = require('librechat-data-provider');
|
||||||
|
const Schema = mongoose.Schema;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @typedef {Object} MongoPromptGroup
|
||||||
|
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||||
|
* @property {string} name - The name of the prompt group
|
||||||
|
* @property {ObjectId} author - The author of the prompt group
|
||||||
|
* @property {ObjectId} [projectId=null] - The project ID of the prompt group
|
||||||
|
* @property {ObjectId} [productionId=null] - The project ID of the prompt group
|
||||||
|
* @property {string} authorName - The name of the author of the prompt group
|
||||||
|
* @property {number} [numberOfGenerations=0] - Number of generations the prompt group has
|
||||||
|
* @property {string} [oneliner=''] - Oneliner description of the prompt group
|
||||||
|
* @property {string} [category=''] - Category of the prompt group
|
||||||
|
* @property {string} [command] - Command for the prompt group
|
||||||
|
* @property {Date} [createdAt] - Date when the prompt group was created (added by timestamps)
|
||||||
|
* @property {Date} [updatedAt] - Date when the prompt group was last updated (added by timestamps)
|
||||||
|
*/
|
||||||
|
|
||||||
|
const promptGroupSchema = new Schema(
|
||||||
|
{
|
||||||
|
name: {
|
||||||
|
type: String,
|
||||||
|
required: true,
|
||||||
|
index: true,
|
||||||
|
},
|
||||||
|
numberOfGenerations: {
|
||||||
|
type: Number,
|
||||||
|
default: 0,
|
||||||
|
},
|
||||||
|
oneliner: {
|
||||||
|
type: String,
|
||||||
|
default: '',
|
||||||
|
},
|
||||||
|
category: {
|
||||||
|
type: String,
|
||||||
|
default: '',
|
||||||
|
index: true,
|
||||||
|
},
|
||||||
|
projectIds: {
|
||||||
|
type: [Schema.Types.ObjectId],
|
||||||
|
ref: 'Project',
|
||||||
|
index: true,
|
||||||
|
},
|
||||||
|
productionId: {
|
||||||
|
type: Schema.Types.ObjectId,
|
||||||
|
ref: 'Prompt',
|
||||||
|
required: true,
|
||||||
|
index: true,
|
||||||
|
},
|
||||||
|
author: {
|
||||||
|
type: Schema.Types.ObjectId,
|
||||||
|
ref: 'User',
|
||||||
|
required: true,
|
||||||
|
index: true,
|
||||||
|
},
|
||||||
|
authorName: {
|
||||||
|
type: String,
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
command: {
|
||||||
|
type: String,
|
||||||
|
index: true,
|
||||||
|
validate: {
|
||||||
|
validator: function (v) {
|
||||||
|
return v === undefined || v === null || v === '' || /^[a-z0-9-]+$/.test(v);
|
||||||
|
},
|
||||||
|
message: (props) =>
|
||||||
|
`${props.value} is not a valid command. Only lowercase alphanumeric characters and highfins (') are allowed.`,
|
||||||
|
},
|
||||||
|
maxlength: [
|
||||||
|
Constants.COMMANDS_MAX_LENGTH,
|
||||||
|
`Command cannot be longer than ${Constants.COMMANDS_MAX_LENGTH} characters`,
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
timestamps: true,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema);
|
||||||
|
|
||||||
|
const promptSchema = new Schema(
|
||||||
|
{
|
||||||
|
groupId: {
|
||||||
|
type: Schema.Types.ObjectId,
|
||||||
|
ref: 'PromptGroup',
|
||||||
|
required: true,
|
||||||
|
index: true,
|
||||||
|
},
|
||||||
|
author: {
|
||||||
|
type: Schema.Types.ObjectId,
|
||||||
|
ref: 'User',
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
prompt: {
|
||||||
|
type: String,
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
type: {
|
||||||
|
type: String,
|
||||||
|
enum: ['text', 'chat'],
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
timestamps: true,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
const Prompt = mongoose.model('Prompt', promptSchema);
|
||||||
|
|
||||||
|
promptSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||||
|
promptGroupSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||||
|
|
||||||
|
module.exports = { Prompt, PromptGroup };
|
||||||
29
api/models/schema/roleSchema.js
Normal file
29
api/models/schema/roleSchema.js
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||||
|
const mongoose = require('mongoose');
|
||||||
|
|
||||||
|
const roleSchema = new mongoose.Schema({
|
||||||
|
name: {
|
||||||
|
type: String,
|
||||||
|
required: true,
|
||||||
|
unique: true,
|
||||||
|
index: true,
|
||||||
|
},
|
||||||
|
[PermissionTypes.PROMPTS]: {
|
||||||
|
[Permissions.SHARED_GLOBAL]: {
|
||||||
|
type: Boolean,
|
||||||
|
default: false,
|
||||||
|
},
|
||||||
|
[Permissions.USE]: {
|
||||||
|
type: Boolean,
|
||||||
|
default: true,
|
||||||
|
},
|
||||||
|
[Permissions.CREATE]: {
|
||||||
|
type: Boolean,
|
||||||
|
default: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const Role = mongoose.model('Role', roleSchema);
|
||||||
|
|
||||||
|
module.exports = Role;
|
||||||
@@ -7,6 +7,9 @@ const tokenSchema = new Schema({
|
|||||||
required: true,
|
required: true,
|
||||||
ref: 'user',
|
ref: 'user',
|
||||||
},
|
},
|
||||||
|
email: {
|
||||||
|
type: String,
|
||||||
|
},
|
||||||
token: {
|
token: {
|
||||||
type: String,
|
type: String,
|
||||||
required: true,
|
required: true,
|
||||||
|
|||||||
@@ -1,5 +1,36 @@
|
|||||||
const mongoose = require('mongoose');
|
const mongoose = require('mongoose');
|
||||||
|
const { SystemRoles } = require('librechat-data-provider');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @typedef {Object} MongoSession
|
||||||
|
* @property {string} [refreshToken] - The refresh token
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @typedef {Object} MongoUser
|
||||||
|
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||||
|
* @property {string} [name] - The user's name
|
||||||
|
* @property {string} [username] - The user's username, in lowercase
|
||||||
|
* @property {string} email - The user's email address
|
||||||
|
* @property {boolean} emailVerified - Whether the user's email is verified
|
||||||
|
* @property {string} [password] - The user's password, trimmed with 8-128 characters
|
||||||
|
* @property {string} [avatar] - The URL of the user's avatar
|
||||||
|
* @property {string} provider - The provider of the user's account (e.g., 'local', 'google')
|
||||||
|
* @property {string} [role='USER'] - The role of the user
|
||||||
|
* @property {string} [googleId] - Optional Google ID for the user
|
||||||
|
* @property {string} [facebookId] - Optional Facebook ID for the user
|
||||||
|
* @property {string} [openidId] - Optional OpenID ID for the user
|
||||||
|
* @property {string} [ldapId] - Optional LDAP ID for the user
|
||||||
|
* @property {string} [githubId] - Optional GitHub ID for the user
|
||||||
|
* @property {string} [discordId] - Optional Discord ID for the user
|
||||||
|
* @property {Array} [plugins=[]] - List of plugins used by the user
|
||||||
|
* @property {Array.<MongoSession>} [refreshToken] - List of sessions with refresh tokens
|
||||||
|
* @property {Date} [expiresAt] - Optional expiration date of the file
|
||||||
|
* @property {Date} [createdAt] - Date when the user was created (added by timestamps)
|
||||||
|
* @property {Date} [updatedAt] - Date when the user was last updated (added by timestamps)
|
||||||
|
*/
|
||||||
|
|
||||||
|
/** @type {MongooseSchema<MongoSession>} */
|
||||||
const Session = mongoose.Schema({
|
const Session = mongoose.Schema({
|
||||||
refreshToken: {
|
refreshToken: {
|
||||||
type: String,
|
type: String,
|
||||||
@@ -7,6 +38,7 @@ const Session = mongoose.Schema({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/** @type {MongooseSchema<MongoUser>} */
|
||||||
const userSchema = mongoose.Schema(
|
const userSchema = mongoose.Schema(
|
||||||
{
|
{
|
||||||
name: {
|
name: {
|
||||||
@@ -47,7 +79,7 @@ const userSchema = mongoose.Schema(
|
|||||||
},
|
},
|
||||||
role: {
|
role: {
|
||||||
type: String,
|
type: String,
|
||||||
default: 'USER',
|
default: SystemRoles.USER,
|
||||||
},
|
},
|
||||||
googleId: {
|
googleId: {
|
||||||
type: String,
|
type: String,
|
||||||
@@ -64,6 +96,11 @@ const userSchema = mongoose.Schema(
|
|||||||
unique: true,
|
unique: true,
|
||||||
sparse: true,
|
sparse: true,
|
||||||
},
|
},
|
||||||
|
ldapId: {
|
||||||
|
type: String,
|
||||||
|
unique: true,
|
||||||
|
sparse: true,
|
||||||
|
},
|
||||||
githubId: {
|
githubId: {
|
||||||
type: String,
|
type: String,
|
||||||
unique: true,
|
unique: true,
|
||||||
@@ -81,6 +118,10 @@ const userSchema = mongoose.Schema(
|
|||||||
refreshToken: {
|
refreshToken: {
|
||||||
type: [Session],
|
type: [Session],
|
||||||
},
|
},
|
||||||
|
expiresAt: {
|
||||||
|
type: Date,
|
||||||
|
expires: 604800, // 7 days in seconds
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{ timestamps: true },
|
{ timestamps: true },
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ const tokenValues = {
|
|||||||
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
|
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
|
||||||
'claude-3-opus': { prompt: 15, completion: 75 },
|
'claude-3-opus': { prompt: 15, completion: 75 },
|
||||||
'claude-3-sonnet': { prompt: 3, completion: 15 },
|
'claude-3-sonnet': { prompt: 3, completion: 15 },
|
||||||
|
'claude-3-5-sonnet': { prompt: 3, completion: 15 },
|
||||||
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
|
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
|
||||||
'claude-2.1': { prompt: 8, completion: 24 },
|
'claude-2.1': { prompt: 8, completion: 24 },
|
||||||
'claude-2': { prompt: 8, completion: 24 },
|
'claude-2': { prompt: 8, completion: 24 },
|
||||||
|
|||||||
@@ -48,6 +48,13 @@ describe('getValueKey', () => {
|
|||||||
expect(getValueKey('gpt-4o-turbo')).toBe('gpt-4o');
|
expect(getValueKey('gpt-4o-turbo')).toBe('gpt-4o');
|
||||||
expect(getValueKey('gpt-4o-0125')).toBe('gpt-4o');
|
expect(getValueKey('gpt-4o-0125')).toBe('gpt-4o');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => {
|
||||||
|
expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet');
|
||||||
|
expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet');
|
||||||
|
expect(getValueKey('claude-3-5-sonnet-turbo')).toBe('claude-3-5-sonnet');
|
||||||
|
expect(getValueKey('claude-3-5-sonnet-0125')).toBe('claude-3-5-sonnet');
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('getMultiplier', () => {
|
describe('getMultiplier', () => {
|
||||||
|
|||||||
@@ -1,28 +1,37 @@
|
|||||||
const bcrypt = require('bcryptjs');
|
const bcrypt = require('bcryptjs');
|
||||||
|
const signPayload = require('~/server/services/signPayload');
|
||||||
const User = require('./User');
|
const User = require('./User');
|
||||||
|
|
||||||
const hashPassword = async (password) => {
|
|
||||||
const hashedPassword = await new Promise((resolve, reject) => {
|
|
||||||
bcrypt.hash(password, 10, function (err, hash) {
|
|
||||||
if (err) {
|
|
||||||
reject(err);
|
|
||||||
} else {
|
|
||||||
resolve(hash);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
return hashedPassword;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retrieve a user by ID and convert the found user document to a plain object.
|
* Retrieve a user by ID and convert the found user document to a plain object.
|
||||||
*
|
*
|
||||||
* @param {string} userId - The ID of the user to find and return as a plain object.
|
* @param {string} userId - The ID of the user to find and return as a plain object.
|
||||||
* @returns {Promise<Object>} A plain object representing the user document, or `null` if no user is found.
|
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||||
|
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
|
||||||
*/
|
*/
|
||||||
const getUser = async function (userId) {
|
const getUserById = async function (userId, fieldsToSelect = null) {
|
||||||
return await User.findById(userId).lean();
|
const query = User.findById(userId);
|
||||||
|
|
||||||
|
if (fieldsToSelect) {
|
||||||
|
query.select(fieldsToSelect);
|
||||||
|
}
|
||||||
|
|
||||||
|
return await query.lean();
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Search for a single user based on partial data and return matching user document as plain object.
|
||||||
|
* @param {Partial<MongoUser>} searchCriteria - The partial data to use for searching the user.
|
||||||
|
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||||
|
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
|
||||||
|
*/
|
||||||
|
const findUser = async function (searchCriteria, fieldsToSelect = null) {
|
||||||
|
const query = User.findOne(searchCriteria);
|
||||||
|
if (fieldsToSelect) {
|
||||||
|
query.select(fieldsToSelect);
|
||||||
|
}
|
||||||
|
|
||||||
|
return await query.lean();
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -30,17 +39,127 @@ const getUser = async function (userId) {
|
|||||||
*
|
*
|
||||||
* @param {string} userId - The ID of the user to update.
|
* @param {string} userId - The ID of the user to update.
|
||||||
* @param {Object} updateData - An object containing the properties to update.
|
* @param {Object} updateData - An object containing the properties to update.
|
||||||
* @returns {Promise<Object>} The updated user document as a plain object, or `null` if no user is found.
|
* @returns {Promise<MongoUser>} The updated user document as a plain object, or `null` if no user is found.
|
||||||
*/
|
*/
|
||||||
const updateUser = async function (userId, updateData) {
|
const updateUser = async function (userId, updateData) {
|
||||||
return await User.findByIdAndUpdate(userId, updateData, {
|
const updateOperation = {
|
||||||
|
$set: updateData,
|
||||||
|
$unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
|
||||||
|
};
|
||||||
|
return await User.findByIdAndUpdate(userId, updateOperation, {
|
||||||
new: true,
|
new: true,
|
||||||
runValidators: true,
|
runValidators: true,
|
||||||
}).lean();
|
}).lean();
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = {
|
/**
|
||||||
hashPassword,
|
* Creates a new user, optionally with a TTL of 1 week.
|
||||||
updateUser,
|
* @param {MongoUser} data - The user data to be created, must contain user_id.
|
||||||
getUser,
|
* @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`.
|
||||||
|
* @param {boolean} [returnUser=false] - Whether to disable the TTL. Defaults to `true`.
|
||||||
|
* @returns {Promise<ObjectId>} A promise that resolves to the created user document ID.
|
||||||
|
* @throws {Error} If a user with the same user_id already exists.
|
||||||
|
*/
|
||||||
|
const createUser = async (data, disableTTL = true, returnUser = false) => {
|
||||||
|
const userData = {
|
||||||
|
...data,
|
||||||
|
expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds
|
||||||
|
};
|
||||||
|
|
||||||
|
if (disableTTL) {
|
||||||
|
delete userData.expiresAt;
|
||||||
|
}
|
||||||
|
|
||||||
|
const user = await User.create(userData);
|
||||||
|
if (returnUser) {
|
||||||
|
return user.toObject();
|
||||||
|
}
|
||||||
|
return user._id;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Count the number of user documents in the collection based on the provided filter.
|
||||||
|
*
|
||||||
|
* @param {Object} [filter={}] - The filter to apply when counting the documents.
|
||||||
|
* @returns {Promise<number>} The count of documents that match the filter.
|
||||||
|
*/
|
||||||
|
const countUsers = async function (filter = {}) {
|
||||||
|
return await User.countDocuments(filter);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Delete a user by their unique ID.
|
||||||
|
*
|
||||||
|
* @param {string} userId - The ID of the user to delete.
|
||||||
|
* @returns {Promise<{ deletedCount: number }>} An object indicating the number of deleted documents.
|
||||||
|
*/
|
||||||
|
const deleteUserById = async function (userId) {
|
||||||
|
try {
|
||||||
|
const result = await User.deleteOne({ _id: userId });
|
||||||
|
if (result.deletedCount === 0) {
|
||||||
|
return { deletedCount: 0, message: 'No user found with that ID.' };
|
||||||
|
}
|
||||||
|
return { deletedCount: result.deletedCount, message: 'User was deleted successfully.' };
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error('Error deleting user: ' + error.message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const { SESSION_EXPIRY } = process.env ?? {};
|
||||||
|
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates a JWT token for a given user.
|
||||||
|
*
|
||||||
|
* @param {MongoUser} user - ID of the user for whom the token is being generated.
|
||||||
|
* @returns {Promise<string>} A promise that resolves to a JWT token.
|
||||||
|
*/
|
||||||
|
const generateToken = async (user) => {
|
||||||
|
if (!user) {
|
||||||
|
throw new Error('No user provided');
|
||||||
|
}
|
||||||
|
|
||||||
|
return await signPayload({
|
||||||
|
payload: {
|
||||||
|
id: user._id,
|
||||||
|
username: user.username,
|
||||||
|
provider: user.provider,
|
||||||
|
email: user.email,
|
||||||
|
},
|
||||||
|
secret: process.env.JWT_SECRET,
|
||||||
|
expirationTime: expires / 1000,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compares the provided password with the user's password.
|
||||||
|
*
|
||||||
|
* @param {MongoUser} user - the user to compare password for.
|
||||||
|
* @param {string} candidatePassword - The password to test against the user's password.
|
||||||
|
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the password matches.
|
||||||
|
*/
|
||||||
|
const comparePassword = async (user, candidatePassword) => {
|
||||||
|
if (!user) {
|
||||||
|
throw new Error('No user provided');
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
bcrypt.compare(candidatePassword, user.password, (err, isMatch) => {
|
||||||
|
if (err) {
|
||||||
|
reject(err);
|
||||||
|
}
|
||||||
|
resolve(isMatch);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
comparePassword,
|
||||||
|
deleteUserById,
|
||||||
|
generateToken,
|
||||||
|
getUserById,
|
||||||
|
countUsers,
|
||||||
|
createUser,
|
||||||
|
updateUser,
|
||||||
|
findUser,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@librechat/backend",
|
"name": "@librechat/backend",
|
||||||
"version": "0.7.2",
|
"version": "0.7.4-rc1",
|
||||||
"description": "",
|
"description": "",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "echo 'please run this from the root directory'",
|
"start": "echo 'please run this from the root directory'",
|
||||||
@@ -40,8 +40,7 @@
|
|||||||
"@keyv/redis": "^2.8.1",
|
"@keyv/redis": "^2.8.1",
|
||||||
"@langchain/community": "^0.0.46",
|
"@langchain/community": "^0.0.46",
|
||||||
"@langchain/google-genai": "^0.0.11",
|
"@langchain/google-genai": "^0.0.11",
|
||||||
"@langchain/google-vertexai": "^0.0.5",
|
"@langchain/google-vertexai": "^0.0.17",
|
||||||
"agenda": "^5.0.0",
|
|
||||||
"axios": "^1.3.4",
|
"axios": "^1.3.4",
|
||||||
"bcryptjs": "^2.4.3",
|
"bcryptjs": "^2.4.3",
|
||||||
"cheerio": "^1.0.0-rc.12",
|
"cheerio": "^1.0.0-rc.12",
|
||||||
@@ -86,6 +85,7 @@
|
|||||||
"passport-github2": "^0.1.12",
|
"passport-github2": "^0.1.12",
|
||||||
"passport-google-oauth20": "^2.0.0",
|
"passport-google-oauth20": "^2.0.0",
|
||||||
"passport-jwt": "^4.0.1",
|
"passport-jwt": "^4.0.1",
|
||||||
|
"passport-ldapauth": "^3.0.1",
|
||||||
"passport-local": "^1.0.0",
|
"passport-local": "^1.0.0",
|
||||||
"pino": "^8.12.1",
|
"pino": "^8.12.1",
|
||||||
"sharp": "^0.32.6",
|
"sharp": "^0.32.6",
|
||||||
@@ -94,6 +94,7 @@
|
|||||||
"ua-parser-js": "^1.0.36",
|
"ua-parser-js": "^1.0.36",
|
||||||
"winston": "^3.11.0",
|
"winston": "^3.11.0",
|
||||||
"winston-daily-rotate-file": "^4.7.1",
|
"winston-daily-rotate-file": "^4.7.1",
|
||||||
|
"ws": "^8.17.0",
|
||||||
"zod": "^3.22.4"
|
"zod": "^3.22.4"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
|
|||||||
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
|
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
|
||||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||||
const { saveMessage, getConvo } = require('~/models');
|
const { saveMessage } = require('~/models');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const AskController = async (req, res, next, initializeClient, addTitle) => {
|
const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||||
@@ -18,6 +18,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||||||
logger.debug('[AskController]', { text, conversationId, ...endpointOption });
|
logger.debug('[AskController]', { text, conversationId, ...endpointOption });
|
||||||
|
|
||||||
let userMessage;
|
let userMessage;
|
||||||
|
let userMessagePromise;
|
||||||
let promptTokens;
|
let promptTokens;
|
||||||
let userMessageId;
|
let userMessageId;
|
||||||
let responseMessageId;
|
let responseMessageId;
|
||||||
@@ -34,6 +35,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||||||
if (key === 'userMessage') {
|
if (key === 'userMessage') {
|
||||||
userMessage = data[key];
|
userMessage = data[key];
|
||||||
userMessageId = data[key].messageId;
|
userMessageId = data[key].messageId;
|
||||||
|
} else if (key === 'userMessagePromise') {
|
||||||
|
userMessagePromise = data[key];
|
||||||
} else if (key === 'responseMessageId') {
|
} else if (key === 'responseMessageId') {
|
||||||
responseMessageId = data[key];
|
responseMessageId = data[key];
|
||||||
} else if (key === 'promptTokens') {
|
} else if (key === 'promptTokens') {
|
||||||
@@ -74,6 +77,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||||||
const getAbortData = () => ({
|
const getAbortData = () => ({
|
||||||
sender,
|
sender,
|
||||||
conversationId,
|
conversationId,
|
||||||
|
userMessagePromise,
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
text: getPartialText(),
|
text: getPartialText(),
|
||||||
@@ -81,7 +85,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||||||
promptTokens,
|
promptTokens,
|
||||||
});
|
});
|
||||||
|
|
||||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||||
|
|
||||||
res.on('close', () => {
|
res.on('close', () => {
|
||||||
logger.debug('[AskController] Request closed');
|
logger.debug('[AskController] Request closed');
|
||||||
@@ -105,11 +109,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||||||
getReqData,
|
getReqData,
|
||||||
onStart,
|
onStart,
|
||||||
abortController,
|
abortController,
|
||||||
onProgress: progressCallback.call(null, {
|
progressCallback,
|
||||||
|
progressOptions: {
|
||||||
res,
|
res,
|
||||||
text,
|
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
},
|
||||||
}),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = await client.sendMessage(text, messageOptions);
|
let response = await client.sendMessage(text, messageOptions);
|
||||||
@@ -120,7 +124,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||||||
|
|
||||||
response.endpoint = endpointOption.endpoint;
|
response.endpoint = endpointOption.endpoint;
|
||||||
|
|
||||||
const conversation = await getConvo(user, conversationId);
|
const { conversation = {} } = await client.responsePromise;
|
||||||
conversation.title =
|
conversation.title =
|
||||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||||
|
|
||||||
@@ -143,7 +147,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||||||
await saveMessage({ ...response, user });
|
await saveMessage({ ...response, user });
|
||||||
}
|
}
|
||||||
|
|
||||||
await saveMessage(userMessage);
|
if (!client.skipSaveUserMessage) {
|
||||||
|
await saveMessage(userMessage);
|
||||||
|
}
|
||||||
|
|
||||||
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||||
addTitle(req, {
|
addTitle(req, {
|
||||||
|
|||||||
@@ -1,45 +1,29 @@
|
|||||||
const crypto = require('crypto');
|
const crypto = require('crypto');
|
||||||
const cookies = require('cookie');
|
const cookies = require('cookie');
|
||||||
const jwt = require('jsonwebtoken');
|
const jwt = require('jsonwebtoken');
|
||||||
const { Session, User } = require('~/models');
|
|
||||||
const {
|
const {
|
||||||
registerUser,
|
registerUser,
|
||||||
resetPassword,
|
resetPassword,
|
||||||
setAuthTokens,
|
setAuthTokens,
|
||||||
requestPasswordReset,
|
requestPasswordReset,
|
||||||
} = require('~/server/services/AuthService');
|
} = require('~/server/services/AuthService');
|
||||||
|
const { Session, getUserById } = require('~/models');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const registrationController = async (req, res) => {
|
const registrationController = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const response = await registerUser(req.body);
|
const response = await registerUser(req.body);
|
||||||
if (response.status === 200) {
|
const { status, message } = response;
|
||||||
const { status, user } = response;
|
res.status(status).send({ message });
|
||||||
let newUser = await User.findOne({ _id: user._id });
|
|
||||||
if (!newUser) {
|
|
||||||
newUser = new User(user);
|
|
||||||
await newUser.save();
|
|
||||||
}
|
|
||||||
const token = await setAuthTokens(user._id, res);
|
|
||||||
res.setHeader('Authorization', `Bearer ${token}`);
|
|
||||||
res.status(status).send({ user });
|
|
||||||
} else {
|
|
||||||
const { status, message } = response;
|
|
||||||
res.status(status).send({ message });
|
|
||||||
}
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.error('[registrationController]', err);
|
logger.error('[registrationController]', err);
|
||||||
return res.status(500).json({ message: err.message });
|
return res.status(500).json({ message: err.message });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const getUserController = async (req, res) => {
|
|
||||||
return res.status(200).send(req.user);
|
|
||||||
};
|
|
||||||
|
|
||||||
const resetPasswordRequestController = async (req, res) => {
|
const resetPasswordRequestController = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const resetService = await requestPasswordReset(req.body.email);
|
const resetService = await requestPasswordReset(req);
|
||||||
if (resetService instanceof Error) {
|
if (resetService instanceof Error) {
|
||||||
return res.status(400).json(resetService);
|
return res.status(400).json(resetService);
|
||||||
} else {
|
} else {
|
||||||
@@ -77,7 +61,7 @@ const refreshController = async (req, res) => {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||||
const user = await User.findOne({ _id: payload.id });
|
const user = await getUserById(payload.id, '-password -__v');
|
||||||
if (!user) {
|
if (!user) {
|
||||||
return res.status(401).redirect('/login');
|
return res.status(401).redirect('/login');
|
||||||
}
|
}
|
||||||
@@ -86,8 +70,7 @@ const refreshController = async (req, res) => {
|
|||||||
|
|
||||||
if (process.env.NODE_ENV === 'CI') {
|
if (process.env.NODE_ENV === 'CI') {
|
||||||
const token = await setAuthTokens(userId, res);
|
const token = await setAuthTokens(userId, res);
|
||||||
const userObj = user.toJSON();
|
return res.status(200).send({ token, user });
|
||||||
return res.status(200).send({ token, user: userObj });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hash the refresh token
|
// Hash the refresh token
|
||||||
@@ -98,8 +81,7 @@ const refreshController = async (req, res) => {
|
|||||||
const session = await Session.findOne({ user: userId, refreshTokenHash: hashedToken });
|
const session = await Session.findOne({ user: userId, refreshTokenHash: hashedToken });
|
||||||
if (session && session.expiration > new Date()) {
|
if (session && session.expiration > new Date()) {
|
||||||
const token = await setAuthTokens(userId, res, session._id);
|
const token = await setAuthTokens(userId, res, session._id);
|
||||||
const userObj = user.toJSON();
|
res.status(200).send({ token, user });
|
||||||
res.status(200).send({ token, user: userObj });
|
|
||||||
} else if (req?.query?.retry) {
|
} else if (req?.query?.retry) {
|
||||||
// Retrying from a refresh token request that failed (401)
|
// Retrying from a refresh token request that failed (401)
|
||||||
res.status(403).send('No session found');
|
res.status(403).send('No session found');
|
||||||
@@ -115,7 +97,6 @@ const refreshController = async (req, res) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
getUserController,
|
|
||||||
refreshController,
|
refreshController,
|
||||||
registrationController,
|
registrationController,
|
||||||
resetPasswordController,
|
resetPasswordController,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
|
|||||||
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
|
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
|
||||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||||
const { saveMessage, getConvo } = require('~/models');
|
const { saveMessage } = require('~/models');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const EditController = async (req, res, next, initializeClient) => {
|
const EditController = async (req, res, next, initializeClient) => {
|
||||||
@@ -27,6 +27,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let userMessage;
|
let userMessage;
|
||||||
|
let userMessagePromise;
|
||||||
let promptTokens;
|
let promptTokens;
|
||||||
const sender = getResponseSender({
|
const sender = getResponseSender({
|
||||||
...endpointOption,
|
...endpointOption,
|
||||||
@@ -40,6 +41,8 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||||||
for (let key in data) {
|
for (let key in data) {
|
||||||
if (key === 'userMessage') {
|
if (key === 'userMessage') {
|
||||||
userMessage = data[key];
|
userMessage = data[key];
|
||||||
|
} else if (key === 'userMessagePromise') {
|
||||||
|
userMessagePromise = data[key];
|
||||||
} else if (key === 'responseMessageId') {
|
} else if (key === 'responseMessageId') {
|
||||||
responseMessageId = data[key];
|
responseMessageId = data[key];
|
||||||
} else if (key === 'promptTokens') {
|
} else if (key === 'promptTokens') {
|
||||||
@@ -73,6 +76,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||||||
|
|
||||||
const getAbortData = () => ({
|
const getAbortData = () => ({
|
||||||
conversationId,
|
conversationId,
|
||||||
|
userMessagePromise,
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
sender,
|
sender,
|
||||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
@@ -81,7 +85,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||||||
promptTokens,
|
promptTokens,
|
||||||
});
|
});
|
||||||
|
|
||||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||||
|
|
||||||
res.on('close', () => {
|
res.on('close', () => {
|
||||||
logger.debug('[EditController] Request closed');
|
logger.debug('[EditController] Request closed');
|
||||||
@@ -112,14 +116,14 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||||||
getReqData,
|
getReqData,
|
||||||
onStart,
|
onStart,
|
||||||
abortController,
|
abortController,
|
||||||
onProgress: progressCallback.call(null, {
|
progressCallback,
|
||||||
|
progressOptions: {
|
||||||
res,
|
res,
|
||||||
text,
|
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
},
|
||||||
}),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const conversation = await getConvo(user, conversationId);
|
const { conversation = {} } = await client.responsePromise;
|
||||||
conversation.title =
|
conversation.title =
|
||||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,37 @@
|
|||||||
const { updateUserPluginsService } = require('~/server/services/UserService');
|
const {
|
||||||
|
Session,
|
||||||
|
Balance,
|
||||||
|
getFiles,
|
||||||
|
deleteFiles,
|
||||||
|
deleteConvos,
|
||||||
|
deletePresets,
|
||||||
|
deleteMessages,
|
||||||
|
deleteUserById,
|
||||||
|
} = require('~/models');
|
||||||
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
||||||
|
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
|
||||||
|
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||||
|
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||||
|
const { deleteAllSharedLinks } = require('~/models/Share');
|
||||||
|
const { Transaction } = require('~/models/Transaction');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const getUserController = async (req, res) => {
|
const getUserController = async (req, res) => {
|
||||||
res.status(200).send(req.user);
|
res.status(200).send(req.user);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const deleteUserFiles = async (req) => {
|
||||||
|
try {
|
||||||
|
const userFiles = await getFiles({ user: req.user.id });
|
||||||
|
await processDeleteRequest({
|
||||||
|
req,
|
||||||
|
files: userFiles,
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('[deleteUserFiles]', error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const updateUserPluginsController = async (req, res) => {
|
const updateUserPluginsController = async (req, res) => {
|
||||||
const { user } = req;
|
const { user } = req;
|
||||||
const { pluginKey, action, auth, isAssistantTool } = req.body;
|
const { pluginKey, action, auth, isAssistantTool } = req.body;
|
||||||
@@ -49,11 +75,68 @@ const updateUserPluginsController = async (req, res) => {
|
|||||||
res.status(200).send();
|
res.status(200).send();
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.error('[updateUserPluginsController]', err);
|
logger.error('[updateUserPluginsController]', err);
|
||||||
res.status(500).json({ message: err.message });
|
return res.status(500).json({ message: 'Something went wrong.' });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const deleteUserController = async (req, res) => {
|
||||||
|
const { user } = req;
|
||||||
|
|
||||||
|
try {
|
||||||
|
await deleteMessages({ user: user.id }); // delete user messages
|
||||||
|
await Session.deleteMany({ user: user.id }); // delete user sessions
|
||||||
|
await Transaction.deleteMany({ user: user.id }); // delete user transactions
|
||||||
|
await deleteUserKey({ userId: user.id, all: true }); // delete user keys
|
||||||
|
await Balance.deleteMany({ user: user._id }); // delete user balances
|
||||||
|
await deletePresets(user.id); // delete user presets
|
||||||
|
/* TODO: Delete Assistant Threads */
|
||||||
|
await deleteConvos(user.id); // delete user convos
|
||||||
|
await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth
|
||||||
|
await deleteUserById(user.id); // delete user
|
||||||
|
await deleteAllSharedLinks(user.id); // delete user shared links
|
||||||
|
await deleteUserFiles(req); // delete user files
|
||||||
|
await deleteFiles(null, user.id); // delete database files in case of orphaned files from previous steps
|
||||||
|
/* TODO: queue job for cleaning actions and assistants of non-existant users */
|
||||||
|
logger.info(`User deleted account. Email: ${user.email} ID: ${user.id}`);
|
||||||
|
res.status(200).send({ message: 'User deleted' });
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('[deleteUserController]', err);
|
||||||
|
return res.status(500).json({ message: 'Something went wrong.' });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const verifyEmailController = async (req, res) => {
|
||||||
|
try {
|
||||||
|
const verifyEmailService = await verifyEmail(req);
|
||||||
|
if (verifyEmailService instanceof Error) {
|
||||||
|
return res.status(400).json(verifyEmailService);
|
||||||
|
} else {
|
||||||
|
return res.status(200).json(verifyEmailService);
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
logger.error('[verifyEmailController]', e);
|
||||||
|
return res.status(500).json({ message: 'Something went wrong.' });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const resendVerificationController = async (req, res) => {
|
||||||
|
try {
|
||||||
|
const result = await resendVerificationEmail(req);
|
||||||
|
if (result instanceof Error) {
|
||||||
|
return res.status(400).json(result);
|
||||||
|
} else {
|
||||||
|
return res.status(200).json(result);
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
logger.error('[verifyEmailController]', e);
|
||||||
|
return res.status(500).json({ message: 'Something went wrong.' });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
getUserController,
|
getUserController,
|
||||||
|
deleteUserController,
|
||||||
|
verifyEmailController,
|
||||||
updateUserPluginsController,
|
updateUserPluginsController,
|
||||||
|
resendVerificationController,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ const {
|
|||||||
} = require('~/server/services/Threads');
|
} = require('~/server/services/Threads');
|
||||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||||
|
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||||
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
|
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
|
||||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||||
@@ -31,15 +32,14 @@ const { getModelMaxTokens } = require('~/utils');
|
|||||||
const { getOpenAIClient } = require('./helpers');
|
const { getOpenAIClient } = require('./helpers');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const { handleAbortError } = require('~/server/middleware');
|
|
||||||
|
|
||||||
const ten_minutes = 1000 * 60 * 10;
|
const ten_minutes = 1000 * 60 * 10;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @route POST /
|
* @route POST /
|
||||||
* @desc Chat with an assistant
|
* @desc Chat with an assistant
|
||||||
* @access Public
|
* @access Public
|
||||||
* @param {Express.Request} req - The request object, containing the request data.
|
* @param {object} req - The request object, containing the request data.
|
||||||
|
* @param {object} req.body - The request payload.
|
||||||
* @param {Express.Response} res - The response object, used to send back a response.
|
* @param {Express.Response} res - The response object, used to send back a response.
|
||||||
* @returns {void}
|
* @returns {void}
|
||||||
*/
|
*/
|
||||||
@@ -60,30 +60,6 @@ const chatV1 = async (req, res) => {
|
|||||||
parentMessageId: _parentId = Constants.NO_PARENT,
|
parentMessageId: _parentId = Constants.NO_PARENT,
|
||||||
} = req.body;
|
} = req.body;
|
||||||
|
|
||||||
/** @type {Partial<TAssistantEndpoint>} */
|
|
||||||
const assistantsConfig = req.app.locals?.[endpoint];
|
|
||||||
|
|
||||||
if (assistantsConfig) {
|
|
||||||
const { supportedIds, excludedIds } = assistantsConfig;
|
|
||||||
const error = { message: 'Assistant not supported' };
|
|
||||||
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
|
||||||
return await handleAbortError(res, req, error, {
|
|
||||||
sender: 'System',
|
|
||||||
conversationId: convoId,
|
|
||||||
messageId: v4(),
|
|
||||||
parentMessageId: _messageId,
|
|
||||||
error,
|
|
||||||
});
|
|
||||||
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
|
|
||||||
return await handleAbortError(res, req, error, {
|
|
||||||
sender: 'System',
|
|
||||||
conversationId: convoId,
|
|
||||||
messageId: v4(),
|
|
||||||
parentMessageId: _messageId,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @type {OpenAIClient} */
|
/** @type {OpenAIClient} */
|
||||||
let openai;
|
let openai;
|
||||||
/** @type {string|undefined} - the current thread id */
|
/** @type {string|undefined} - the current thread id */
|
||||||
@@ -311,6 +287,7 @@ const chatV1 = async (req, res) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
openai = _openai;
|
openai = _openai;
|
||||||
|
await validateAuthor({ req, openai });
|
||||||
|
|
||||||
if (previousMessages.length) {
|
if (previousMessages.length) {
|
||||||
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ const {
|
|||||||
} = require('~/server/services/Threads');
|
} = require('~/server/services/Threads');
|
||||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||||
|
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||||
const { getTransactions } = require('~/models/Transaction');
|
const { getTransactions } = require('~/models/Transaction');
|
||||||
@@ -30,8 +31,6 @@ const { getModelMaxTokens } = require('~/utils');
|
|||||||
const { getOpenAIClient } = require('./helpers');
|
const { getOpenAIClient } = require('./helpers');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const { handleAbortError } = require('~/server/middleware');
|
|
||||||
|
|
||||||
const ten_minutes = 1000 * 60 * 10;
|
const ten_minutes = 1000 * 60 * 10;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -60,30 +59,6 @@ const chatV2 = async (req, res) => {
|
|||||||
parentMessageId: _parentId = Constants.NO_PARENT,
|
parentMessageId: _parentId = Constants.NO_PARENT,
|
||||||
} = req.body;
|
} = req.body;
|
||||||
|
|
||||||
/** @type {Partial<TAssistantEndpoint>} */
|
|
||||||
const assistantsConfig = req.app.locals?.[endpoint];
|
|
||||||
|
|
||||||
if (assistantsConfig) {
|
|
||||||
const { supportedIds, excludedIds } = assistantsConfig;
|
|
||||||
const error = { message: 'Assistant not supported' };
|
|
||||||
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
|
||||||
return await handleAbortError(res, req, error, {
|
|
||||||
sender: 'System',
|
|
||||||
conversationId: convoId,
|
|
||||||
messageId: v4(),
|
|
||||||
parentMessageId: _messageId,
|
|
||||||
error,
|
|
||||||
});
|
|
||||||
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
|
|
||||||
return await handleAbortError(res, req, error, {
|
|
||||||
sender: 'System',
|
|
||||||
conversationId: convoId,
|
|
||||||
messageId: v4(),
|
|
||||||
parentMessageId: _messageId,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @type {OpenAIClient} */
|
/** @type {OpenAIClient} */
|
||||||
let openai;
|
let openai;
|
||||||
/** @type {string|undefined} - the current thread id */
|
/** @type {string|undefined} - the current thread id */
|
||||||
@@ -309,6 +284,7 @@ const chatV2 = async (req, res) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
openai = _openai;
|
openai = _openai;
|
||||||
|
await validateAuthor({ req, openai });
|
||||||
|
|
||||||
if (previousMessages.length) {
|
if (previousMessages.length) {
|
||||||
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
||||||
@@ -520,6 +496,7 @@ const chatV2 = async (req, res) => {
|
|||||||
handlers,
|
handlers,
|
||||||
thread_id,
|
thread_id,
|
||||||
attachedFileIds,
|
attachedFileIds,
|
||||||
|
parentMessageId: userMessageId,
|
||||||
responseMessage: openai.responseMessage,
|
responseMessage: openai.responseMessage,
|
||||||
// streamOptions: {
|
// streamOptions: {
|
||||||
|
|
||||||
@@ -532,6 +509,7 @@ const chatV2 = async (req, res) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
response = streamRunManager;
|
response = streamRunManager;
|
||||||
|
response.text = streamRunManager.intermediateText;
|
||||||
};
|
};
|
||||||
|
|
||||||
await processRun();
|
await processRun();
|
||||||
@@ -554,6 +532,7 @@ const chatV2 = async (req, res) => {
|
|||||||
/** @type {ResponseMessage} */
|
/** @type {ResponseMessage} */
|
||||||
const responseMessage = {
|
const responseMessage = {
|
||||||
...(response.responseMessage ?? response.finalMessage),
|
...(response.responseMessage ?? response.finalMessage),
|
||||||
|
text: response.text,
|
||||||
parentMessageId: userMessageId,
|
parentMessageId: userMessageId,
|
||||||
conversationId,
|
conversationId,
|
||||||
user: req.user.id,
|
user: req.user.id,
|
||||||
|
|||||||
@@ -1,4 +1,10 @@
|
|||||||
const { EModelEndpoint, CacheKeys, defaultAssistantsVersion } = require('librechat-data-provider');
|
const {
|
||||||
|
CacheKeys,
|
||||||
|
SystemRoles,
|
||||||
|
EModelEndpoint,
|
||||||
|
defaultOrderQuery,
|
||||||
|
defaultAssistantsVersion,
|
||||||
|
} = require('librechat-data-provider');
|
||||||
const {
|
const {
|
||||||
initializeClient: initAzureClient,
|
initializeClient: initAzureClient,
|
||||||
} = require('~/server/services/Endpoints/azureAssistants');
|
} = require('~/server/services/Endpoints/azureAssistants');
|
||||||
@@ -35,6 +41,7 @@ const getCurrentVersion = async (req, endpoint) => {
|
|||||||
* Initializes the client with the current request and response objects and lists assistants
|
* Initializes the client with the current request and response objects and lists assistants
|
||||||
* according to the query parameters. This function abstracts the logic for non-Azure paths.
|
* according to the query parameters. This function abstracts the logic for non-Azure paths.
|
||||||
*
|
*
|
||||||
|
* @deprecated
|
||||||
* @async
|
* @async
|
||||||
* @param {object} params - The parameters object.
|
* @param {object} params - The parameters object.
|
||||||
* @param {object} params.req - The request object, used for initializing the client.
|
* @param {object} params.req - The request object, used for initializing the client.
|
||||||
@@ -43,11 +50,65 @@ const getCurrentVersion = async (req, endpoint) => {
|
|||||||
* @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
|
* @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||||
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||||
*/
|
*/
|
||||||
const listAssistants = async ({ req, res, version, query }) => {
|
const _listAssistants = async ({ req, res, version, query }) => {
|
||||||
const { openai } = await getOpenAIClient({ req, res, version });
|
const { openai } = await getOpenAIClient({ req, res, version });
|
||||||
return openai.beta.assistants.list(query);
|
return openai.beta.assistants.list(query);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetches all assistants based on provided query params, until `has_more` is `false`.
|
||||||
|
*
|
||||||
|
* @async
|
||||||
|
* @param {object} params - The parameters object.
|
||||||
|
* @param {object} params.req - The request object, used for initializing the client.
|
||||||
|
* @param {object} params.res - The response object, used for initializing the client.
|
||||||
|
* @param {string} params.version - The API version to use.
|
||||||
|
* @param {Omit<AssistantListParams, 'endpoint'>} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||||
|
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||||
|
*/
|
||||||
|
const listAllAssistants = async ({ req, res, version, query }) => {
|
||||||
|
/** @type {{ openai: OpenAIClient }} */
|
||||||
|
const { openai } = await getOpenAIClient({ req, res, version });
|
||||||
|
const allAssistants = [];
|
||||||
|
|
||||||
|
let first_id;
|
||||||
|
let last_id;
|
||||||
|
let afterToken = query.after;
|
||||||
|
let hasMore = true;
|
||||||
|
|
||||||
|
while (hasMore) {
|
||||||
|
const response = await openai.beta.assistants.list({
|
||||||
|
...query,
|
||||||
|
after: afterToken,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { body } = response;
|
||||||
|
|
||||||
|
allAssistants.push(...body.data);
|
||||||
|
hasMore = body.has_more;
|
||||||
|
|
||||||
|
if (!first_id) {
|
||||||
|
first_id = body.first_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasMore) {
|
||||||
|
afterToken = body.last_id;
|
||||||
|
} else {
|
||||||
|
last_id = body.last_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
data: allAssistants,
|
||||||
|
body: {
|
||||||
|
data: allAssistants,
|
||||||
|
has_more: false,
|
||||||
|
first_id,
|
||||||
|
last_id,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Asynchronously lists assistants for Azure configured groups.
|
* Asynchronously lists assistants for Azure configured groups.
|
||||||
*
|
*
|
||||||
@@ -82,7 +143,7 @@ const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, que
|
|||||||
/* The specified model is only necessary to
|
/* The specified model is only necessary to
|
||||||
fetch assistants for the shared instance */
|
fetch assistants for the shared instance */
|
||||||
req.body.model = currentModelTuples[0][0];
|
req.body.model = currentModelTuples[0][0];
|
||||||
promises.push(listAssistants({ req, res, version, query }));
|
promises.push(listAllAssistants({ req, res, version, query }));
|
||||||
}
|
}
|
||||||
|
|
||||||
const resolvedQueries = await Promise.all(promises);
|
const resolvedQueries = await Promise.all(promises);
|
||||||
@@ -133,8 +194,27 @@ async function getOpenAIClient({ req, res, endpointOption, initAppClient, overri
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
const fetchAssistants = async (req, res) => {
|
/**
|
||||||
const { limit = 100, order = 'desc', after, before, endpoint } = req.query;
|
* Returns a list of assistants.
|
||||||
|
* @param {object} params
|
||||||
|
* @param {object} params.req - Express Request
|
||||||
|
* @param {AssistantListParams} [params.req.query] - The assistant list parameters for pagination and sorting.
|
||||||
|
* @param {object} params.res - Express Response
|
||||||
|
* @param {string} [params.overrideEndpoint] - The endpoint to override the request endpoint.
|
||||||
|
* @returns {Promise<AssistantListResponse>} 200 - success response - application/json
|
||||||
|
*/
|
||||||
|
const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
|
||||||
|
const {
|
||||||
|
limit = 100,
|
||||||
|
order = 'desc',
|
||||||
|
after,
|
||||||
|
before,
|
||||||
|
endpoint,
|
||||||
|
} = req.query ?? {
|
||||||
|
endpoint: overrideEndpoint,
|
||||||
|
...defaultOrderQuery,
|
||||||
|
};
|
||||||
|
|
||||||
const version = await getCurrentVersion(req, endpoint);
|
const version = await getCurrentVersion(req, endpoint);
|
||||||
const query = { limit, order, after, before };
|
const query = { limit, order, after, before };
|
||||||
|
|
||||||
@@ -142,15 +222,47 @@ const fetchAssistants = async (req, res) => {
|
|||||||
let body;
|
let body;
|
||||||
|
|
||||||
if (endpoint === EModelEndpoint.assistants) {
|
if (endpoint === EModelEndpoint.assistants) {
|
||||||
({ body } = await listAssistants({ req, res, version, query }));
|
({ body } = await listAllAssistants({ req, res, version, query }));
|
||||||
} else if (endpoint === EModelEndpoint.azureAssistants) {
|
} else if (endpoint === EModelEndpoint.azureAssistants) {
|
||||||
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
|
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
|
||||||
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
|
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (req.user.role === SystemRoles.ADMIN) {
|
||||||
|
return body;
|
||||||
|
} else if (!req.app.locals[endpoint]) {
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
|
||||||
|
body.data = filterAssistants({
|
||||||
|
userId: req.user.id,
|
||||||
|
assistants: body.data,
|
||||||
|
assistantsConfig: req.app.locals[endpoint],
|
||||||
|
});
|
||||||
return body;
|
return body;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Filter assistants based on configuration.
|
||||||
|
*
|
||||||
|
* @param {object} params - The parameters object.
|
||||||
|
* @param {string} params.userId - The user ID to filter private assistants.
|
||||||
|
* @param {Assistant[]} params.assistants - The list of assistants to filter.
|
||||||
|
* @param {Partial<TAssistantEndpoint>} params.assistantsConfig - The assistant configuration.
|
||||||
|
* @returns {Assistant[]} - The filtered list of assistants.
|
||||||
|
*/
|
||||||
|
function filterAssistants({ assistants, userId, assistantsConfig }) {
|
||||||
|
const { supportedIds, excludedIds, privateAssistants } = assistantsConfig;
|
||||||
|
if (privateAssistants) {
|
||||||
|
return assistants.filter((assistant) => userId === assistant.metadata?.author);
|
||||||
|
} else if (supportedIds?.length) {
|
||||||
|
return assistants.filter((assistant) => supportedIds.includes(assistant.id));
|
||||||
|
} else if (excludedIds?.length) {
|
||||||
|
return assistants.filter((assistant) => !excludedIds.includes(assistant.id));
|
||||||
|
}
|
||||||
|
return assistants;
|
||||||
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
getOpenAIClient,
|
getOpenAIClient,
|
||||||
fetchAssistants,
|
fetchAssistants,
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
const { FileContext } = require('librechat-data-provider');
|
const { FileContext } = require('librechat-data-provider');
|
||||||
|
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||||
const { deleteAssistantActions } = require('~/server/services/ActionService');
|
const { deleteAssistantActions } = require('~/server/services/ActionService');
|
||||||
|
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
|
||||||
const { uploadImageBuffer } = require('~/server/services/Files/process');
|
const { uploadImageBuffer } = require('~/server/services/Files/process');
|
||||||
const { updateAssistant, getAssistants } = require('~/models/Assistant');
|
|
||||||
const { getOpenAIClient, fetchAssistants } = require('./helpers');
|
const { getOpenAIClient, fetchAssistants } = require('./helpers');
|
||||||
const { deleteFileByFilter } = require('~/models/File');
|
const { deleteFileByFilter } = require('~/models/File');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
@@ -40,9 +41,11 @@ const createAssistant = async (req, res) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const assistant = await openai.beta.assistants.create(assistantData);
|
const assistant = await openai.beta.assistants.create(assistantData);
|
||||||
|
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
|
||||||
if (azureModelIdentifier) {
|
if (azureModelIdentifier) {
|
||||||
assistant.model = azureModelIdentifier;
|
assistant.model = azureModelIdentifier;
|
||||||
}
|
}
|
||||||
|
await promise;
|
||||||
logger.debug('/assistants/', assistant);
|
logger.debug('/assistants/', assistant);
|
||||||
res.status(201).json(assistant);
|
res.status(201).json(assistant);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -61,7 +64,6 @@ const retrieveAssistant = async (req, res) => {
|
|||||||
try {
|
try {
|
||||||
/* NOTE: not actually being used right now */
|
/* NOTE: not actually being used right now */
|
||||||
const { openai } = await getOpenAIClient({ req, res });
|
const { openai } = await getOpenAIClient({ req, res });
|
||||||
|
|
||||||
const assistant_id = req.params.id;
|
const assistant_id = req.params.id;
|
||||||
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
||||||
res.json(assistant);
|
res.json(assistant);
|
||||||
@@ -83,6 +85,7 @@ const retrieveAssistant = async (req, res) => {
|
|||||||
const patchAssistant = async (req, res) => {
|
const patchAssistant = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const { openai } = await getOpenAIClient({ req, res });
|
const { openai } = await getOpenAIClient({ req, res });
|
||||||
|
await validateAuthor({ req, openai });
|
||||||
|
|
||||||
const assistant_id = req.params.id;
|
const assistant_id = req.params.id;
|
||||||
const { endpoint: _e, ...updateData } = req.body;
|
const { endpoint: _e, ...updateData } = req.body;
|
||||||
@@ -119,6 +122,7 @@ const patchAssistant = async (req, res) => {
|
|||||||
const deleteAssistant = async (req, res) => {
|
const deleteAssistant = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const { openai } = await getOpenAIClient({ req, res });
|
const { openai } = await getOpenAIClient({ req, res });
|
||||||
|
await validateAuthor({ req, openai });
|
||||||
|
|
||||||
const assistant_id = req.params.id;
|
const assistant_id = req.params.id;
|
||||||
const deletionStatus = await openai.beta.assistants.del(assistant_id);
|
const deletionStatus = await openai.beta.assistants.del(assistant_id);
|
||||||
@@ -141,19 +145,7 @@ const deleteAssistant = async (req, res) => {
|
|||||||
*/
|
*/
|
||||||
const listAssistants = async (req, res) => {
|
const listAssistants = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const body = await fetchAssistants(req, res);
|
const body = await fetchAssistants({ req, res });
|
||||||
|
|
||||||
if (req.app.locals?.[req.query.endpoint]) {
|
|
||||||
/** @type {Partial<TAssistantEndpoint>} */
|
|
||||||
const assistantsConfig = req.app.locals[req.query.endpoint];
|
|
||||||
const { supportedIds, excludedIds } = assistantsConfig;
|
|
||||||
if (supportedIds?.length) {
|
|
||||||
body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id));
|
|
||||||
} else if (excludedIds?.length) {
|
|
||||||
body.data = body.data.filter((assistant) => !excludedIds.includes(assistant.id));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res.json(body);
|
res.json(body);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('[/assistants] Error listing assistants', error);
|
logger.error('[/assistants] Error listing assistants', error);
|
||||||
@@ -195,6 +187,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
|||||||
|
|
||||||
let { metadata: _metadata = '{}' } = req.body;
|
let { metadata: _metadata = '{}' } = req.body;
|
||||||
const { openai } = await getOpenAIClient({ req, res });
|
const { openai } = await getOpenAIClient({ req, res });
|
||||||
|
await validateAuthor({ req, openai });
|
||||||
|
|
||||||
const image = await uploadImageBuffer({
|
const image = await uploadImageBuffer({
|
||||||
req,
|
req,
|
||||||
@@ -229,7 +222,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
|||||||
|
|
||||||
const promises = [];
|
const promises = [];
|
||||||
promises.push(
|
promises.push(
|
||||||
updateAssistant(
|
updateAssistantDoc(
|
||||||
{ assistant_id },
|
{ assistant_id },
|
||||||
{
|
{
|
||||||
avatar: {
|
avatar: {
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
const { ToolCallTypes } = require('librechat-data-provider');
|
const { ToolCallTypes } = require('librechat-data-provider');
|
||||||
|
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||||
const { validateAndUpdateTool } = require('~/server/services/ActionService');
|
const { validateAndUpdateTool } = require('~/server/services/ActionService');
|
||||||
|
const { updateAssistantDoc } = require('~/models/Assistant');
|
||||||
const { getOpenAIClient } = require('./helpers');
|
const { getOpenAIClient } = require('./helpers');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
@@ -37,9 +39,11 @@ const createAssistant = async (req, res) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const assistant = await openai.beta.assistants.create(assistantData);
|
const assistant = await openai.beta.assistants.create(assistantData);
|
||||||
|
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
|
||||||
if (azureModelIdentifier) {
|
if (azureModelIdentifier) {
|
||||||
assistant.model = azureModelIdentifier;
|
assistant.model = azureModelIdentifier;
|
||||||
}
|
}
|
||||||
|
await promise;
|
||||||
logger.debug('/assistants/', assistant);
|
logger.debug('/assistants/', assistant);
|
||||||
res.status(201).json(assistant);
|
res.status(201).json(assistant);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -58,6 +62,7 @@ const createAssistant = async (req, res) => {
|
|||||||
* @returns {Promise<Assistant>} The updated assistant.
|
* @returns {Promise<Assistant>} The updated assistant.
|
||||||
*/
|
*/
|
||||||
const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
|
const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
|
||||||
|
await validateAuthor({ req, openai });
|
||||||
const tools = [];
|
const tools = [];
|
||||||
|
|
||||||
let hasFileSearch = false;
|
let hasFileSearch = false;
|
||||||
|
|||||||
@@ -1,26 +1,22 @@
|
|||||||
const User = require('~/models/User');
|
|
||||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const loginController = async (req, res) => {
|
const loginController = async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const user = await User.findById(req.user._id);
|
if (!req.user) {
|
||||||
|
|
||||||
// If user doesn't exist, return error
|
|
||||||
if (!user) {
|
|
||||||
// typeof user !== User) { // this doesn't seem to resolve the User type ??
|
|
||||||
return res.status(400).json({ message: 'Invalid credentials' });
|
return res.status(400).json({ message: 'Invalid credentials' });
|
||||||
}
|
}
|
||||||
|
|
||||||
const token = await setAuthTokens(user._id, res);
|
const { password: _, __v, ...user } = req.user;
|
||||||
|
user.id = user._id.toString();
|
||||||
|
|
||||||
|
const token = await setAuthTokens(req.user._id, res);
|
||||||
|
|
||||||
return res.status(200).send({ token, user });
|
return res.status(200).send({ token, user });
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.error('[loginController]', err);
|
logger.error('[loginController]', err);
|
||||||
|
return res.status(500).json({ message: 'Something went wrong' });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generic error messages are safer
|
|
||||||
return res.status(500).json({ message: 'Something went wrong' });
|
|
||||||
};
|
};
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
|
|||||||
@@ -6,16 +6,16 @@ const axios = require('axios');
|
|||||||
const express = require('express');
|
const express = require('express');
|
||||||
const passport = require('passport');
|
const passport = require('passport');
|
||||||
const mongoSanitize = require('express-mongo-sanitize');
|
const mongoSanitize = require('express-mongo-sanitize');
|
||||||
|
const { jwtLogin, passportLogin } = require('~/strategies');
|
||||||
|
const { connectDb, indexSync } = require('~/lib/db');
|
||||||
|
const { isEnabled } = require('~/server/utils');
|
||||||
|
const { ldapLogin } = require('~/strategies');
|
||||||
|
const { logger } = require('~/config');
|
||||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||||
const errorController = require('./controllers/ErrorController');
|
const errorController = require('./controllers/ErrorController');
|
||||||
const { jwtLogin, passportLogin } = require('~/strategies');
|
|
||||||
const configureSocialLogins = require('./socialLogins');
|
const configureSocialLogins = require('./socialLogins');
|
||||||
const { connectDb, indexSync } = require('~/lib/db');
|
|
||||||
const AppService = require('./services/AppService');
|
const AppService = require('./services/AppService');
|
||||||
const noIndex = require('./middleware/noIndex');
|
const noIndex = require('./middleware/noIndex');
|
||||||
const { isEnabled } = require('~/server/utils');
|
|
||||||
const { logger } = require('~/config');
|
|
||||||
|
|
||||||
const routes = require('./routes');
|
const routes = require('./routes');
|
||||||
|
|
||||||
const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {};
|
const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {};
|
||||||
@@ -60,6 +60,11 @@ const startServer = async () => {
|
|||||||
passport.use(await jwtLogin());
|
passport.use(await jwtLogin());
|
||||||
passport.use(passportLogin());
|
passport.use(passportLogin());
|
||||||
|
|
||||||
|
// LDAP Auth
|
||||||
|
if (process.env.LDAP_URL && process.env.LDAP_USER_SEARCH_BASE) {
|
||||||
|
passport.use(ldapLogin);
|
||||||
|
}
|
||||||
|
|
||||||
if (isEnabled(ALLOW_SOCIAL_LOGIN)) {
|
if (isEnabled(ALLOW_SOCIAL_LOGIN)) {
|
||||||
configureSocialLogins(app);
|
configureSocialLogins(app);
|
||||||
}
|
}
|
||||||
@@ -76,6 +81,7 @@ const startServer = async () => {
|
|||||||
app.use('/api/convos', routes.convos);
|
app.use('/api/convos', routes.convos);
|
||||||
app.use('/api/presets', routes.presets);
|
app.use('/api/presets', routes.presets);
|
||||||
app.use('/api/prompts', routes.prompts);
|
app.use('/api/prompts', routes.prompts);
|
||||||
|
app.use('/api/categories', routes.categories);
|
||||||
app.use('/api/tokenizer', routes.tokenizer);
|
app.use('/api/tokenizer', routes.tokenizer);
|
||||||
app.use('/api/endpoints', routes.endpoints);
|
app.use('/api/endpoints', routes.endpoints);
|
||||||
app.use('/api/balance', routes.balance);
|
app.use('/api/balance', routes.balance);
|
||||||
@@ -86,9 +92,10 @@ const startServer = async () => {
|
|||||||
app.use('/api/files', await routes.files.initialize());
|
app.use('/api/files', await routes.files.initialize());
|
||||||
app.use('/images/', validateImageRequest, routes.staticRoute);
|
app.use('/images/', validateImageRequest, routes.staticRoute);
|
||||||
app.use('/api/share', routes.share);
|
app.use('/api/share', routes.share);
|
||||||
|
app.use('/api/roles', routes.roles);
|
||||||
|
|
||||||
app.use((req, res) => {
|
app.use((req, res) => {
|
||||||
res.status(404).sendFile(path.join(app.locals.paths.dist, 'index.html'));
|
res.sendFile(path.join(app.locals.paths.dist, 'index.html'));
|
||||||
});
|
});
|
||||||
|
|
||||||
app.listen(port, host, () => {
|
app.listen(port, host, () => {
|
||||||
|
|||||||
@@ -1,31 +1,36 @@
|
|||||||
const { isAssistantsEndpoint } = require('librechat-data-provider');
|
const { isAssistantsEndpoint } = require('librechat-data-provider');
|
||||||
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
|
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
|
||||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||||
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
|
|
||||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||||
const abortControllers = require('./abortControllers');
|
const abortControllers = require('./abortControllers');
|
||||||
|
const { saveMessage, getConvo } = require('~/models');
|
||||||
const spendTokens = require('~/models/spendTokens');
|
const spendTokens = require('~/models/spendTokens');
|
||||||
const { abortRun } = require('./abortRun');
|
const { abortRun } = require('./abortRun');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
async function abortMessage(req, res) {
|
async function abortMessage(req, res) {
|
||||||
let { abortKey, conversationId, endpoint } = req.body;
|
let { abortKey, endpoint } = req.body;
|
||||||
|
|
||||||
if (!abortKey && conversationId) {
|
|
||||||
abortKey = conversationId;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isAssistantsEndpoint(endpoint)) {
|
if (isAssistantsEndpoint(endpoint)) {
|
||||||
return await abortRun(req, res);
|
return await abortRun(req, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
|
||||||
|
|
||||||
|
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
|
||||||
|
abortKey = conversationId;
|
||||||
|
}
|
||||||
|
|
||||||
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
||||||
return res.status(204).send({ message: 'Request not found' });
|
return res.status(204).send({ message: 'Request not found' });
|
||||||
}
|
}
|
||||||
|
|
||||||
const { abortController } = abortControllers.get(abortKey);
|
const { abortController } = abortControllers.get(abortKey) ?? {};
|
||||||
|
if (!abortController) {
|
||||||
|
return res.status(204).send({ message: 'Request not found' });
|
||||||
|
}
|
||||||
const finalEvent = await abortController.abortCompletion();
|
const finalEvent = await abortController.abortCompletion();
|
||||||
logger.debug('[abortMessage] Aborted request', { abortKey });
|
logger.info('[abortMessage] Aborted request', { abortKey });
|
||||||
abortControllers.delete(abortKey);
|
abortControllers.delete(abortKey);
|
||||||
|
|
||||||
if (res.headersSent && finalEvent) {
|
if (res.headersSent && finalEvent) {
|
||||||
@@ -50,12 +55,35 @@ const handleAbort = () => {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
const createAbortController = (req, res, getAbortData) => {
|
const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
const { endpointOption } = req.body;
|
const { endpointOption } = req.body;
|
||||||
const onStart = (userMessage) => {
|
|
||||||
|
abortController.getAbortData = function () {
|
||||||
|
return getAbortData();
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param {TMessage} userMessage
|
||||||
|
* @param {string} responseMessageId
|
||||||
|
*/
|
||||||
|
const onStart = (userMessage, responseMessageId) => {
|
||||||
sendMessage(res, { message: userMessage, created: true });
|
sendMessage(res, { message: userMessage, created: true });
|
||||||
|
|
||||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||||
|
const prevRequest = abortControllers.get(abortKey);
|
||||||
|
|
||||||
|
if (prevRequest && prevRequest?.abortController) {
|
||||||
|
const data = prevRequest.abortController.getAbortData();
|
||||||
|
getReqData({ userMessage: data?.userMessage });
|
||||||
|
const addedAbortKey = `${abortKey}:${responseMessageId}`;
|
||||||
|
abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
|
||||||
|
res.on('finish', function () {
|
||||||
|
abortControllers.delete(addedAbortKey);
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
abortControllers.set(abortKey, { abortController, ...endpointOption });
|
abortControllers.set(abortKey, { abortController, ...endpointOption });
|
||||||
|
|
||||||
res.on('finish', function () {
|
res.on('finish', function () {
|
||||||
@@ -65,7 +93,8 @@ const createAbortController = (req, res, getAbortData) => {
|
|||||||
|
|
||||||
abortController.abortCompletion = async function () {
|
abortController.abortCompletion = async function () {
|
||||||
abortController.abort();
|
abortController.abort();
|
||||||
const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData();
|
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
|
||||||
|
getAbortData();
|
||||||
const completionTokens = await countTokens(responseData?.text ?? '');
|
const completionTokens = await countTokens(responseData?.text ?? '');
|
||||||
const user = req.user.id;
|
const user = req.user.id;
|
||||||
|
|
||||||
@@ -89,10 +118,20 @@ const createAbortController = (req, res, getAbortData) => {
|
|||||||
|
|
||||||
saveMessage({ ...responseMessage, user });
|
saveMessage({ ...responseMessage, user });
|
||||||
|
|
||||||
|
let conversation;
|
||||||
|
if (userMessagePromise) {
|
||||||
|
const resolved = await userMessagePromise;
|
||||||
|
conversation = resolved?.conversation;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!conversation) {
|
||||||
|
conversation = await getConvo(req.user.id, conversationId);
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
title: await getConvoTitle(user, conversationId),
|
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||||
final: true,
|
final: true,
|
||||||
conversation: await getConvo(user, conversationId),
|
conversation,
|
||||||
requestMessage: userMessage,
|
requestMessage: userMessage,
|
||||||
responseMessage: responseMessage,
|
responseMessage: responseMessage,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
|
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||||
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
|
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
|
||||||
|
const { deleteMessages } = require('~/models/Message');
|
||||||
const { getConvo } = require('~/models/Conversation');
|
const { getConvo } = require('~/models/Conversation');
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
const { sendMessage } = require('~/server/utils');
|
const { sendMessage } = require('~/server/utils');
|
||||||
@@ -66,13 +67,19 @@ async function abortRun(req, res) {
|
|||||||
logger.error('[abortRun] Error fetching or processing run', error);
|
logger.error('[abortRun] Error fetching or processing run', error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* TODO: a reconciling strategy between the existing intermediate message would be more optimal than deleting it */
|
||||||
|
await deleteMessages({
|
||||||
|
user: req.user.id,
|
||||||
|
unfinished: true,
|
||||||
|
conversationId,
|
||||||
|
});
|
||||||
runMessages = await checkMessageGaps({
|
runMessages = await checkMessageGaps({
|
||||||
openai,
|
openai,
|
||||||
|
run_id,
|
||||||
endpoint,
|
endpoint,
|
||||||
thread_id,
|
thread_id,
|
||||||
run_id,
|
|
||||||
latestMessageId,
|
|
||||||
conversationId,
|
conversationId,
|
||||||
|
latestMessageId,
|
||||||
});
|
});
|
||||||
|
|
||||||
const finalEvent = {
|
const finalEvent = {
|
||||||
|
|||||||
43
api/server/middleware/assistants/validate.js
Normal file
43
api/server/middleware/assistants/validate.js
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
const { v4 } = require('uuid');
|
||||||
|
const { handleAbortError } = require('~/server/middleware/abortMiddleware');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if the assistant is supported or excluded
|
||||||
|
* @param {object} req - Express Request
|
||||||
|
* @param {object} req.body - The request payload.
|
||||||
|
* @param {object} res - Express Response
|
||||||
|
* @param {function} next - Express next middleware function.
|
||||||
|
* @returns {Promise<void>}
|
||||||
|
*/
|
||||||
|
const validateAssistant = async (req, res, next) => {
|
||||||
|
const { endpoint, conversationId, assistant_id, messageId } = req.body;
|
||||||
|
|
||||||
|
/** @type {Partial<TAssistantEndpoint>} */
|
||||||
|
const assistantsConfig = req.app.locals?.[endpoint];
|
||||||
|
if (!assistantsConfig) {
|
||||||
|
return next();
|
||||||
|
}
|
||||||
|
|
||||||
|
const { supportedIds, excludedIds } = assistantsConfig;
|
||||||
|
const error = { message: 'Assistant not supported' };
|
||||||
|
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
||||||
|
return await handleAbortError(res, req, error, {
|
||||||
|
sender: 'System',
|
||||||
|
conversationId,
|
||||||
|
messageId: v4(),
|
||||||
|
parentMessageId: messageId,
|
||||||
|
error,
|
||||||
|
});
|
||||||
|
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
|
||||||
|
return await handleAbortError(res, req, error, {
|
||||||
|
sender: 'System',
|
||||||
|
conversationId,
|
||||||
|
messageId: v4(),
|
||||||
|
parentMessageId: messageId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return next();
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = validateAssistant;
|
||||||
43
api/server/middleware/assistants/validateAuthor.js
Normal file
43
api/server/middleware/assistants/validateAuthor.js
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
const { SystemRoles } = require('librechat-data-provider');
|
||||||
|
const { getAssistant } = require('~/models/Assistant');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if the assistant is supported or excluded
|
||||||
|
* @param {object} params
|
||||||
|
* @param {object} params.req - Express Request
|
||||||
|
* @param {object} params.req.body - The request payload.
|
||||||
|
* @param {string} params.overrideEndpoint - The override endpoint
|
||||||
|
* @param {string} params.overrideAssistantId - The override assistant ID
|
||||||
|
* @param {OpenAIClient} params.openai - OpenAI API Client
|
||||||
|
* @returns {Promise<void>}
|
||||||
|
*/
|
||||||
|
const validateAuthor = async ({ req, openai, overrideEndpoint, overrideAssistantId }) => {
|
||||||
|
if (req.user.role === SystemRoles.ADMIN) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const endpoint = overrideEndpoint ?? req.body.endpoint ?? req.query.endpoint;
|
||||||
|
const assistant_id =
|
||||||
|
overrideAssistantId ?? req.params.id ?? req.body.assistant_id ?? req.query.assistant_id;
|
||||||
|
|
||||||
|
/** @type {Partial<TAssistantEndpoint>} */
|
||||||
|
const assistantsConfig = req.app.locals?.[endpoint];
|
||||||
|
if (!assistantsConfig) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!assistantsConfig.privateAssistants) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const assistantDoc = await getAssistant({ assistant_id, user: req.user.id });
|
||||||
|
if (assistantDoc) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
||||||
|
if (req.user.id !== assistant?.metadata?.author) {
|
||||||
|
throw new Error(`Assistant ${assistant_id} is not authored by the user.`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = validateAuthor;
|
||||||
28
api/server/middleware/canDeleteAccount.js
Normal file
28
api/server/middleware/canDeleteAccount.js
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
const { SystemRoles } = require('librechat-data-provider');
|
||||||
|
const { isEnabled } = require('~/server/utils');
|
||||||
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if the user can delete their account
|
||||||
|
*
|
||||||
|
* @async
|
||||||
|
* @function
|
||||||
|
* @param {Object} req - Express request object
|
||||||
|
* @param {Object} res - Express response object
|
||||||
|
* @param {Function} next - Next middleware function
|
||||||
|
*
|
||||||
|
* @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if the user can delete their account
|
||||||
|
*/
|
||||||
|
|
||||||
|
const canDeleteAccount = async (req, res, next = () => {}) => {
|
||||||
|
const { user } = req;
|
||||||
|
const { ALLOW_ACCOUNT_DELETION = true } = process.env;
|
||||||
|
if (user?.role === SystemRoles.ADMIN || isEnabled(ALLOW_ACCOUNT_DELETION)) {
|
||||||
|
return next();
|
||||||
|
} else {
|
||||||
|
logger.error(`[User] [Delete Account] [User cannot delete account] [User: ${user?.id}]`);
|
||||||
|
return res.status(403).send({ message: 'You do not have permission to delete this account' });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = canDeleteAccount;
|
||||||
@@ -1,15 +1,13 @@
|
|||||||
const Keyv = require('keyv');
|
const Keyv = require('keyv');
|
||||||
const uap = require('ua-parser-js');
|
const uap = require('ua-parser-js');
|
||||||
const { ViolationTypes } = require('librechat-data-provider');
|
const { ViolationTypes } = require('librechat-data-provider');
|
||||||
const { isEnabled, removePorts } = require('../utils');
|
const { isEnabled, removePorts } = require('~/server/utils');
|
||||||
const keyvRedis = require('~/cache/keyvRedis');
|
const keyvMongo = require('~/cache/keyvMongo');
|
||||||
const denyRequest = require('./denyRequest');
|
const denyRequest = require('./denyRequest');
|
||||||
const { getLogStores } = require('~/cache');
|
const { getLogStores } = require('~/cache');
|
||||||
const User = require('~/models/User');
|
const { findUser } = require('~/models');
|
||||||
|
|
||||||
const banCache = isEnabled(process.env.USE_REDIS)
|
const banCache = new Keyv({ store: keyvMongo, namespace: ViolationTypes.BAN, ttl: 0 });
|
||||||
? new Keyv({ store: keyvRedis })
|
|
||||||
: new Keyv({ namespace: ViolationTypes.BAN, ttl: 0 });
|
|
||||||
const message = 'Your account has been temporarily banned due to violations of our service.';
|
const message = 'Your account has been temporarily banned due to violations of our service.';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -57,7 +55,7 @@ const checkBan = async (req, res, next = () => {}) => {
|
|||||||
let userId = req.user?.id ?? req.user?._id ?? null;
|
let userId = req.user?.id ?? req.user?._id ?? null;
|
||||||
|
|
||||||
if (!userId && req?.body?.email) {
|
if (!userId && req?.body?.email) {
|
||||||
const user = await User.findOne({ email: req.body.email }, '_id').lean();
|
const user = await findUser({ email: req.body.email }, '_id');
|
||||||
userId = user?._id ? user._id.toString() : userId;
|
userId = user?._id ? user._id.toString() : userId;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,45 +1,45 @@
|
|||||||
const abortMiddleware = require('./abortMiddleware');
|
const validatePasswordReset = require('./validatePasswordReset');
|
||||||
const checkBan = require('./checkBan');
|
|
||||||
const checkDomainAllowed = require('./checkDomainAllowed');
|
|
||||||
const uaParser = require('./uaParser');
|
|
||||||
const setHeaders = require('./setHeaders');
|
|
||||||
const loginLimiter = require('./loginLimiter');
|
|
||||||
const validateModel = require('./validateModel');
|
|
||||||
const requireJwtAuth = require('./requireJwtAuth');
|
|
||||||
const uploadLimiters = require('./uploadLimiters');
|
|
||||||
const registerLimiter = require('./registerLimiter');
|
|
||||||
const messageLimiters = require('./messageLimiters');
|
|
||||||
const requireLocalAuth = require('./requireLocalAuth');
|
|
||||||
const validateEndpoint = require('./validateEndpoint');
|
|
||||||
const concurrentLimiter = require('./concurrentLimiter');
|
|
||||||
const validateMessageReq = require('./validateMessageReq');
|
|
||||||
const buildEndpointOption = require('./buildEndpointOption');
|
|
||||||
const validateRegistration = require('./validateRegistration');
|
const validateRegistration = require('./validateRegistration');
|
||||||
const validateImageRequest = require('./validateImageRequest');
|
const validateImageRequest = require('./validateImageRequest');
|
||||||
|
const buildEndpointOption = require('./buildEndpointOption');
|
||||||
|
const validateMessageReq = require('./validateMessageReq');
|
||||||
|
const checkDomainAllowed = require('./checkDomainAllowed');
|
||||||
|
const concurrentLimiter = require('./concurrentLimiter');
|
||||||
|
const validateEndpoint = require('./validateEndpoint');
|
||||||
|
const requireLocalAuth = require('./requireLocalAuth');
|
||||||
|
const canDeleteAccount = require('./canDeleteAccount');
|
||||||
|
const requireLdapAuth = require('./requireLdapAuth');
|
||||||
|
const abortMiddleware = require('./abortMiddleware');
|
||||||
|
const requireJwtAuth = require('./requireJwtAuth');
|
||||||
|
const validateModel = require('./validateModel');
|
||||||
const moderateText = require('./moderateText');
|
const moderateText = require('./moderateText');
|
||||||
|
const setHeaders = require('./setHeaders');
|
||||||
|
const limiters = require('./limiters');
|
||||||
|
const uaParser = require('./uaParser');
|
||||||
|
const checkBan = require('./checkBan');
|
||||||
const noIndex = require('./noIndex');
|
const noIndex = require('./noIndex');
|
||||||
const importLimiters = require('./importLimiters');
|
const roles = require('./roles');
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
...uploadLimiters,
|
|
||||||
...abortMiddleware,
|
...abortMiddleware,
|
||||||
...messageLimiters,
|
...limiters,
|
||||||
|
...roles,
|
||||||
|
noIndex,
|
||||||
checkBan,
|
checkBan,
|
||||||
uaParser,
|
uaParser,
|
||||||
setHeaders,
|
setHeaders,
|
||||||
loginLimiter,
|
moderateText,
|
||||||
|
validateModel,
|
||||||
requireJwtAuth,
|
requireJwtAuth,
|
||||||
registerLimiter,
|
requireLdapAuth,
|
||||||
requireLocalAuth,
|
requireLocalAuth,
|
||||||
|
canDeleteAccount,
|
||||||
validateEndpoint,
|
validateEndpoint,
|
||||||
concurrentLimiter,
|
concurrentLimiter,
|
||||||
|
checkDomainAllowed,
|
||||||
validateMessageReq,
|
validateMessageReq,
|
||||||
buildEndpointOption,
|
buildEndpointOption,
|
||||||
validateRegistration,
|
validateRegistration,
|
||||||
validateImageRequest,
|
validateImageRequest,
|
||||||
validateModel,
|
validatePasswordReset,
|
||||||
moderateText,
|
|
||||||
noIndex,
|
|
||||||
...importLimiters,
|
|
||||||
checkDomainAllowed,
|
|
||||||
};
|
};
|
||||||
|
|||||||
22
api/server/middleware/limiters/index.js
Normal file
22
api/server/middleware/limiters/index.js
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
const createTTSLimiters = require('./ttsLimiters');
|
||||||
|
const createSTTLimiters = require('./sttLimiters');
|
||||||
|
|
||||||
|
const loginLimiter = require('./loginLimiter');
|
||||||
|
const importLimiters = require('./importLimiters');
|
||||||
|
const uploadLimiters = require('./uploadLimiters');
|
||||||
|
const registerLimiter = require('./registerLimiter');
|
||||||
|
const messageLimiters = require('./messageLimiters');
|
||||||
|
const verifyEmailLimiter = require('./verifyEmailLimiter');
|
||||||
|
const resetPasswordLimiter = require('./resetPasswordLimiter');
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
...uploadLimiters,
|
||||||
|
...importLimiters,
|
||||||
|
...messageLimiters,
|
||||||
|
loginLimiter,
|
||||||
|
registerLimiter,
|
||||||
|
createTTSLimiters,
|
||||||
|
createSTTLimiters,
|
||||||
|
verifyEmailLimiter,
|
||||||
|
resetPasswordLimiter,
|
||||||
|
};
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
const rateLimit = require('express-rate-limit');
|
const rateLimit = require('express-rate-limit');
|
||||||
const { logViolation } = require('../../cache');
|
const { removePorts } = require('~/server/utils');
|
||||||
const { removePorts } = require('../utils');
|
const { logViolation } = require('~/cache');
|
||||||
|
|
||||||
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
|
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
|
||||||
const windowMs = LOGIN_WINDOW * 60 * 1000;
|
const windowMs = LOGIN_WINDOW * 60 * 1000;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
const rateLimit = require('express-rate-limit');
|
const rateLimit = require('express-rate-limit');
|
||||||
const { logViolation } = require('../../cache');
|
const denyRequest = require('~/server/middleware/denyRequest');
|
||||||
const denyRequest = require('./denyRequest');
|
const { logViolation } = require('~/cache');
|
||||||
|
|
||||||
const {
|
const {
|
||||||
MESSAGE_IP_MAX = 40,
|
MESSAGE_IP_MAX = 40,
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
const rateLimit = require('express-rate-limit');
|
const rateLimit = require('express-rate-limit');
|
||||||
const { logViolation } = require('../../cache');
|
const { removePorts } = require('~/server/utils');
|
||||||
const { removePorts } = require('../utils');
|
const { logViolation } = require('~/cache');
|
||||||
|
|
||||||
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
|
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
|
||||||
const windowMs = REGISTER_WINDOW * 60 * 1000;
|
const windowMs = REGISTER_WINDOW * 60 * 1000;
|
||||||
35
api/server/middleware/limiters/resetPasswordLimiter.js
Normal file
35
api/server/middleware/limiters/resetPasswordLimiter.js
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
const rateLimit = require('express-rate-limit');
|
||||||
|
const { ViolationTypes } = require('librechat-data-provider');
|
||||||
|
const { removePorts } = require('~/server/utils');
|
||||||
|
const { logViolation } = require('~/cache');
|
||||||
|
|
||||||
|
const {
|
||||||
|
RESET_PASSWORD_WINDOW = 2,
|
||||||
|
RESET_PASSWORD_MAX = 2,
|
||||||
|
RESET_PASSWORD_VIOLATION_SCORE: score,
|
||||||
|
} = process.env;
|
||||||
|
const windowMs = RESET_PASSWORD_WINDOW * 60 * 1000;
|
||||||
|
const max = RESET_PASSWORD_MAX;
|
||||||
|
const windowInMinutes = windowMs / 60000;
|
||||||
|
const message = `Too many attempts, please try again after ${windowInMinutes} minute(s)`;
|
||||||
|
|
||||||
|
const handler = async (req, res) => {
|
||||||
|
const type = ViolationTypes.RESET_PASSWORD_LIMIT;
|
||||||
|
const errorMessage = {
|
||||||
|
type,
|
||||||
|
max,
|
||||||
|
windowInMinutes,
|
||||||
|
};
|
||||||
|
|
||||||
|
await logViolation(req, res, type, errorMessage, score);
|
||||||
|
return res.status(429).json({ message });
|
||||||
|
};
|
||||||
|
|
||||||
|
const resetPasswordLimiter = rateLimit({
|
||||||
|
windowMs,
|
||||||
|
max,
|
||||||
|
handler,
|
||||||
|
keyGenerator: removePorts,
|
||||||
|
});
|
||||||
|
|
||||||
|
module.exports = resetPasswordLimiter;
|
||||||
68
api/server/middleware/limiters/sttLimiters.js
Normal file
68
api/server/middleware/limiters/sttLimiters.js
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
const rateLimit = require('express-rate-limit');
|
||||||
|
const { ViolationTypes } = require('librechat-data-provider');
|
||||||
|
const logViolation = require('~/cache/logViolation');
|
||||||
|
|
||||||
|
const getEnvironmentVariables = () => {
|
||||||
|
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
|
||||||
|
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
|
||||||
|
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
|
||||||
|
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
|
||||||
|
|
||||||
|
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
|
||||||
|
const sttIpMax = STT_IP_MAX;
|
||||||
|
const sttIpWindowInMinutes = sttIpWindowMs / 60000;
|
||||||
|
|
||||||
|
const sttUserWindowMs = STT_USER_WINDOW * 60 * 1000;
|
||||||
|
const sttUserMax = STT_USER_MAX;
|
||||||
|
const sttUserWindowInMinutes = sttUserWindowMs / 60000;
|
||||||
|
|
||||||
|
return {
|
||||||
|
sttIpWindowMs,
|
||||||
|
sttIpMax,
|
||||||
|
sttIpWindowInMinutes,
|
||||||
|
sttUserWindowMs,
|
||||||
|
sttUserMax,
|
||||||
|
sttUserWindowInMinutes,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const createSTTHandler = (ip = true) => {
|
||||||
|
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
|
||||||
|
getEnvironmentVariables();
|
||||||
|
|
||||||
|
return async (req, res) => {
|
||||||
|
const type = ViolationTypes.STT_LIMIT;
|
||||||
|
const errorMessage = {
|
||||||
|
type,
|
||||||
|
max: ip ? sttIpMax : sttUserMax,
|
||||||
|
limiter: ip ? 'ip' : 'user',
|
||||||
|
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
|
||||||
|
};
|
||||||
|
|
||||||
|
await logViolation(req, res, type, errorMessage);
|
||||||
|
res.status(429).json({ message: 'Too many STT requests. Try again later' });
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const createSTTLimiters = () => {
|
||||||
|
const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables();
|
||||||
|
|
||||||
|
const sttIpLimiter = rateLimit({
|
||||||
|
windowMs: sttIpWindowMs,
|
||||||
|
max: sttIpMax,
|
||||||
|
handler: createSTTHandler(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const sttUserLimiter = rateLimit({
|
||||||
|
windowMs: sttUserWindowMs,
|
||||||
|
max: sttUserMax,
|
||||||
|
handler: createSTTHandler(false),
|
||||||
|
keyGenerator: function (req) {
|
||||||
|
return req.user?.id; // Use the user ID or NULL if not available
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return { sttIpLimiter, sttUserLimiter };
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = createSTTLimiters;
|
||||||
68
api/server/middleware/limiters/ttsLimiters.js
Normal file
68
api/server/middleware/limiters/ttsLimiters.js
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
const rateLimit = require('express-rate-limit');
|
||||||
|
const { ViolationTypes } = require('librechat-data-provider');
|
||||||
|
const logViolation = require('~/cache/logViolation');
|
||||||
|
|
||||||
|
const getEnvironmentVariables = () => {
|
||||||
|
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
|
||||||
|
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
|
||||||
|
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
|
||||||
|
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
|
||||||
|
|
||||||
|
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
|
||||||
|
const ttsIpMax = TTS_IP_MAX;
|
||||||
|
const ttsIpWindowInMinutes = ttsIpWindowMs / 60000;
|
||||||
|
|
||||||
|
const ttsUserWindowMs = TTS_USER_WINDOW * 60 * 1000;
|
||||||
|
const ttsUserMax = TTS_USER_MAX;
|
||||||
|
const ttsUserWindowInMinutes = ttsUserWindowMs / 60000;
|
||||||
|
|
||||||
|
return {
|
||||||
|
ttsIpWindowMs,
|
||||||
|
ttsIpMax,
|
||||||
|
ttsIpWindowInMinutes,
|
||||||
|
ttsUserWindowMs,
|
||||||
|
ttsUserMax,
|
||||||
|
ttsUserWindowInMinutes,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const createTTSHandler = (ip = true) => {
|
||||||
|
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
|
||||||
|
getEnvironmentVariables();
|
||||||
|
|
||||||
|
return async (req, res) => {
|
||||||
|
const type = ViolationTypes.TTS_LIMIT;
|
||||||
|
const errorMessage = {
|
||||||
|
type,
|
||||||
|
max: ip ? ttsIpMax : ttsUserMax,
|
||||||
|
limiter: ip ? 'ip' : 'user',
|
||||||
|
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
|
||||||
|
};
|
||||||
|
|
||||||
|
await logViolation(req, res, type, errorMessage);
|
||||||
|
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const createTTSLimiters = () => {
|
||||||
|
const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables();
|
||||||
|
|
||||||
|
const ttsIpLimiter = rateLimit({
|
||||||
|
windowMs: ttsIpWindowMs,
|
||||||
|
max: ttsIpMax,
|
||||||
|
handler: createTTSHandler(),
|
||||||
|
});
|
||||||
|
|
||||||
|
const ttsUserLimiter = rateLimit({
|
||||||
|
windowMs: ttsUserWindowMs,
|
||||||
|
max: ttsUserMax,
|
||||||
|
handler: createTTSHandler(false),
|
||||||
|
keyGenerator: function (req) {
|
||||||
|
return req.user?.id; // Use the user ID or NULL if not available
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return { ttsIpLimiter, ttsUserLimiter };
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = createTTSLimiters;
|
||||||
35
api/server/middleware/limiters/verifyEmailLimiter.js
Normal file
35
api/server/middleware/limiters/verifyEmailLimiter.js
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
const rateLimit = require('express-rate-limit');
|
||||||
|
const { ViolationTypes } = require('librechat-data-provider');
|
||||||
|
const { removePorts } = require('~/server/utils');
|
||||||
|
const { logViolation } = require('~/cache');
|
||||||
|
|
||||||
|
const {
|
||||||
|
VERIFY_EMAIL_WINDOW = 2,
|
||||||
|
VERIFY_EMAIL_MAX = 2,
|
||||||
|
VERIFY_EMAIL_VIOLATION_SCORE: score,
|
||||||
|
} = process.env;
|
||||||
|
const windowMs = VERIFY_EMAIL_WINDOW * 60 * 1000;
|
||||||
|
const max = VERIFY_EMAIL_MAX;
|
||||||
|
const windowInMinutes = windowMs / 60000;
|
||||||
|
const message = `Too many attempts, please try again after ${windowInMinutes} minute(s)`;
|
||||||
|
|
||||||
|
const handler = async (req, res) => {
|
||||||
|
const type = ViolationTypes.VERIFY_EMAIL_LIMIT;
|
||||||
|
const errorMessage = {
|
||||||
|
type,
|
||||||
|
max,
|
||||||
|
windowInMinutes,
|
||||||
|
};
|
||||||
|
|
||||||
|
await logViolation(req, res, type, errorMessage, score);
|
||||||
|
return res.status(429).json({ message });
|
||||||
|
};
|
||||||
|
|
||||||
|
const verifyEmailLimiter = rateLimit({
|
||||||
|
windowMs,
|
||||||
|
max,
|
||||||
|
handler,
|
||||||
|
keyGenerator: removePorts,
|
||||||
|
});
|
||||||
|
|
||||||
|
module.exports = verifyEmailLimiter;
|
||||||
22
api/server/middleware/requireLdapAuth.js
Normal file
22
api/server/middleware/requireLdapAuth.js
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
const passport = require('passport');
|
||||||
|
|
||||||
|
const requireLdapAuth = (req, res, next) => {
|
||||||
|
passport.authenticate('ldapauth', (err, user, info) => {
|
||||||
|
if (err) {
|
||||||
|
console.log({
|
||||||
|
title: '(requireLdapAuth) Error at passport.authenticate',
|
||||||
|
parameters: [{ name: 'error', value: err }],
|
||||||
|
});
|
||||||
|
return next(err);
|
||||||
|
}
|
||||||
|
if (!user) {
|
||||||
|
console.log({
|
||||||
|
title: '(requireLdapAuth) Error: No user',
|
||||||
|
});
|
||||||
|
return res.status(404).send(info);
|
||||||
|
}
|
||||||
|
req.user = user;
|
||||||
|
next();
|
||||||
|
})(req, res, next);
|
||||||
|
};
|
||||||
|
module.exports = requireLdapAuth;
|
||||||
@@ -21,7 +21,13 @@ const requireLocalAuth = (req, res, next) => {
|
|||||||
log({
|
log({
|
||||||
title: '(requireLocalAuth) Error: No user',
|
title: '(requireLocalAuth) Error: No user',
|
||||||
});
|
});
|
||||||
return res.status(422).send(info);
|
return res.status(404).send(info);
|
||||||
|
}
|
||||||
|
if (info && info.message) {
|
||||||
|
log({
|
||||||
|
title: '(requireLocalAuth) Error: ' + info.message,
|
||||||
|
});
|
||||||
|
return res.status(422).send({ message: info.message });
|
||||||
}
|
}
|
||||||
req.user = user;
|
req.user = user;
|
||||||
next();
|
next();
|
||||||
|
|||||||
14
api/server/middleware/roles/checkAdmin.js
Normal file
14
api/server/middleware/roles/checkAdmin.js
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
const { SystemRoles } = require('librechat-data-provider');
|
||||||
|
|
||||||
|
function checkAdmin(req, res, next) {
|
||||||
|
try {
|
||||||
|
if (req.user.role !== SystemRoles.ADMIN) {
|
||||||
|
return res.status(403).json({ message: 'Forbidden' });
|
||||||
|
}
|
||||||
|
next();
|
||||||
|
} catch (error) {
|
||||||
|
res.status(500).json({ message: 'Internal Server Error' });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = checkAdmin;
|
||||||
52
api/server/middleware/roles/generateCheckAccess.js
Normal file
52
api/server/middleware/roles/generateCheckAccess.js
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
const { SystemRoles } = require('librechat-data-provider');
|
||||||
|
const { getRoleByName } = require('~/models/Role');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
|
||||||
|
*
|
||||||
|
* @param {PermissionTypes} permissionType - The type of permission to check.
|
||||||
|
* @param {Permissions[]} permissions - The list of specific permissions to check.
|
||||||
|
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
|
||||||
|
* @returns {Function} Express middleware function.
|
||||||
|
*/
|
||||||
|
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
|
||||||
|
return async (req, res, next) => {
|
||||||
|
try {
|
||||||
|
const { user } = req;
|
||||||
|
if (!user) {
|
||||||
|
return res.status(401).json({ message: 'Authorization required' });
|
||||||
|
}
|
||||||
|
|
||||||
|
if (user.role === SystemRoles.ADMIN) {
|
||||||
|
return next();
|
||||||
|
}
|
||||||
|
|
||||||
|
const role = await getRoleByName(user.role);
|
||||||
|
if (role && role[permissionType]) {
|
||||||
|
const hasAnyPermission = permissions.some((permission) => {
|
||||||
|
if (role[permissionType][permission]) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bodyProps[permission] && req.body) {
|
||||||
|
return bodyProps[permission].some((prop) =>
|
||||||
|
Object.prototype.hasOwnProperty.call(req.body, prop),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (hasAnyPermission) {
|
||||||
|
return next();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
|
||||||
|
} catch (error) {
|
||||||
|
return res.status(500).json({ message: `Server error: ${error.message}` });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
module.exports = generateCheckAccess;
|
||||||
7
api/server/middleware/roles/index.js
Normal file
7
api/server/middleware/roles/index.js
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
const checkAdmin = require('./checkAdmin');
|
||||||
|
const generateCheckAccess = require('./generateCheckAccess');
|
||||||
|
|
||||||
|
module.exports = {
|
||||||
|
checkAdmin,
|
||||||
|
generateCheckAccess,
|
||||||
|
};
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
const { getConvo } = require('../../models');
|
const { getConvo } = require('~/models');
|
||||||
|
|
||||||
// Middleware to validate conversationId and user relationship
|
// Middleware to validate conversationId and user relationship
|
||||||
const validateMessageReq = async (req, res, next) => {
|
const validateMessageReq = async (req, res, next) => {
|
||||||
|
|||||||
13
api/server/middleware/validatePasswordReset.js
Normal file
13
api/server/middleware/validatePasswordReset.js
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
const { isEnabled } = require('~/server/utils');
|
||||||
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
|
function validatePasswordReset(req, res, next) {
|
||||||
|
if (isEnabled(process.env.ALLOW_PASSWORD_RESET)) {
|
||||||
|
next();
|
||||||
|
} else {
|
||||||
|
logger.warn(`Password reset attempt while not allowed. IP: ${req.ip}`);
|
||||||
|
res.status(403).send('Password reset is not allowed.');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.exports = validatePasswordReset;
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
|
const { isEnabled } = require('~/server/utils');
|
||||||
|
|
||||||
function validateRegistration(req, res, next) {
|
function validateRegistration(req, res, next) {
|
||||||
const setting = process.env.ALLOW_REGISTRATION?.toLowerCase();
|
if (isEnabled(process.env.ALLOW_REGISTRATION)) {
|
||||||
if (setting === 'true') {
|
|
||||||
next();
|
next();
|
||||||
} else {
|
} else {
|
||||||
res.status(403).send('Registration is not allowed.');
|
res.status(403).send('Registration is not allowed.');
|
||||||
|
|||||||
@@ -25,6 +25,12 @@ afterEach(() => {
|
|||||||
delete process.env.DOMAIN_SERVER;
|
delete process.env.DOMAIN_SERVER;
|
||||||
delete process.env.ALLOW_REGISTRATION;
|
delete process.env.ALLOW_REGISTRATION;
|
||||||
delete process.env.ALLOW_SOCIAL_LOGIN;
|
delete process.env.ALLOW_SOCIAL_LOGIN;
|
||||||
|
delete process.env.ALLOW_PASSWORD_RESET;
|
||||||
|
delete process.env.LDAP_URL;
|
||||||
|
delete process.env.LDAP_BIND_DN;
|
||||||
|
delete process.env.LDAP_BIND_CREDENTIALS;
|
||||||
|
delete process.env.LDAP_USER_SEARCH_BASE;
|
||||||
|
delete process.env.LDAP_SEARCH_FILTER;
|
||||||
});
|
});
|
||||||
|
|
||||||
//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why.
|
//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why.
|
||||||
@@ -50,6 +56,12 @@ describe.skip('GET /', () => {
|
|||||||
process.env.DOMAIN_SERVER = 'http://test-server.com';
|
process.env.DOMAIN_SERVER = 'http://test-server.com';
|
||||||
process.env.ALLOW_REGISTRATION = 'true';
|
process.env.ALLOW_REGISTRATION = 'true';
|
||||||
process.env.ALLOW_SOCIAL_LOGIN = 'true';
|
process.env.ALLOW_SOCIAL_LOGIN = 'true';
|
||||||
|
process.env.ALLOW_PASSWORD_RESET = 'true';
|
||||||
|
process.env.LDAP_URL = 'Test LDAP URL';
|
||||||
|
process.env.LDAP_BIND_DN = 'Test LDAP Bind DN';
|
||||||
|
process.env.LDAP_BIND_CREDENTIALS = 'Test LDAP Bind Credentials';
|
||||||
|
process.env.LDAP_USER_SEARCH_BASE = 'Test LDAP User Search Base';
|
||||||
|
process.env.LDAP_SEARCH_FILTER = 'Test LDAP Search Filter';
|
||||||
|
|
||||||
const response = await request(app).get('/');
|
const response = await request(app).get('/');
|
||||||
|
|
||||||
@@ -64,9 +76,11 @@ describe.skip('GET /', () => {
|
|||||||
openidLoginEnabled: true,
|
openidLoginEnabled: true,
|
||||||
openidLabel: 'Test OpenID',
|
openidLabel: 'Test OpenID',
|
||||||
openidImageUrl: 'http://test-server.com',
|
openidImageUrl: 'http://test-server.com',
|
||||||
|
ldapLoginEnabled: true,
|
||||||
serverDomain: 'http://test-server.com',
|
serverDomain: 'http://test-server.com',
|
||||||
emailLoginEnabled: 'true',
|
emailLoginEnabled: 'true',
|
||||||
registrationEnabled: 'true',
|
registrationEnabled: 'true',
|
||||||
|
passwordResetEnabled: 'true',
|
||||||
socialLoginEnabled: 'true',
|
socialLoginEnabled: 'true',
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
const express = require('express');
|
const express = require('express');
|
||||||
const AskController = require('~/server/controllers/AskController');
|
const AskController = require('~/server/controllers/AskController');
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/google');
|
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
|
||||||
const {
|
const {
|
||||||
setHeaders,
|
setHeaders,
|
||||||
handleAbort,
|
handleAbort,
|
||||||
@@ -20,7 +20,7 @@ router.post(
|
|||||||
buildEndpointOption,
|
buildEndpointOption,
|
||||||
setHeaders,
|
setHeaders,
|
||||||
async (req, res, next) => {
|
async (req, res, next) => {
|
||||||
await AskController(req, res, next, initializeClient);
|
await AskController(req, res, next, initializeClient, addTitle);
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ const express = require('express');
|
|||||||
const throttle = require('lodash/throttle');
|
const throttle = require('lodash/throttle');
|
||||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
const { getResponseSender, Constants } = require('librechat-data-provider');
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
||||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
|
||||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||||
const { addTitle } = require('~/server/services/Endpoints/openAI');
|
const { addTitle } = require('~/server/services/Endpoints/openAI');
|
||||||
|
const { saveMessage } = require('~/models');
|
||||||
const {
|
const {
|
||||||
handleAbort,
|
handleAbort,
|
||||||
createAbortController,
|
createAbortController,
|
||||||
@@ -41,6 +41,7 @@ router.post(
|
|||||||
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
|
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
|
||||||
|
|
||||||
let userMessage;
|
let userMessage;
|
||||||
|
let userMessagePromise;
|
||||||
let promptTokens;
|
let promptTokens;
|
||||||
let userMessageId;
|
let userMessageId;
|
||||||
let responseMessageId;
|
let responseMessageId;
|
||||||
@@ -58,6 +59,8 @@ router.post(
|
|||||||
if (key === 'userMessage') {
|
if (key === 'userMessage') {
|
||||||
userMessage = data[key];
|
userMessage = data[key];
|
||||||
userMessageId = data[key].messageId;
|
userMessageId = data[key].messageId;
|
||||||
|
} else if (key === 'userMessagePromise') {
|
||||||
|
userMessagePromise = data[key];
|
||||||
} else if (key === 'responseMessageId') {
|
} else if (key === 'responseMessageId') {
|
||||||
responseMessageId = data[key];
|
responseMessageId = data[key];
|
||||||
} else if (key === 'promptTokens') {
|
} else if (key === 'promptTokens') {
|
||||||
@@ -106,7 +109,11 @@ router.post(
|
|||||||
const pluginMap = new Map();
|
const pluginMap = new Map();
|
||||||
const onAgentAction = async (action, runId) => {
|
const onAgentAction = async (action, runId) => {
|
||||||
pluginMap.set(runId, action.tool);
|
pluginMap.set(runId, action.tool);
|
||||||
sendIntermediateMessage(res, { plugins });
|
sendIntermediateMessage(res, {
|
||||||
|
plugins,
|
||||||
|
parentMessageId: userMessage.messageId,
|
||||||
|
messageId: responseMessageId,
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
const onToolStart = async (tool, input, runId, parentRunId) => {
|
const onToolStart = async (tool, input, runId, parentRunId) => {
|
||||||
@@ -124,7 +131,11 @@ router.post(
|
|||||||
}
|
}
|
||||||
const extraTokens = ':::plugin:::\n';
|
const extraTokens = ':::plugin:::\n';
|
||||||
plugins.push(latestPlugin);
|
plugins.push(latestPlugin);
|
||||||
sendIntermediateMessage(res, { plugins }, extraTokens);
|
sendIntermediateMessage(
|
||||||
|
res,
|
||||||
|
{ plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId },
|
||||||
|
extraTokens,
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const onToolEnd = async (output, runId) => {
|
const onToolEnd = async (output, runId) => {
|
||||||
@@ -140,14 +151,10 @@ router.post(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const onChainEnd = () => {
|
|
||||||
saveMessage({ ...userMessage, user });
|
|
||||||
sendIntermediateMessage(res, { plugins });
|
|
||||||
};
|
|
||||||
|
|
||||||
const getAbortData = () => ({
|
const getAbortData = () => ({
|
||||||
sender,
|
sender,
|
||||||
conversationId,
|
conversationId,
|
||||||
|
userMessagePromise,
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
text: getPartialText(),
|
text: getPartialText(),
|
||||||
@@ -155,12 +162,23 @@ router.post(
|
|||||||
userMessage,
|
userMessage,
|
||||||
promptTokens,
|
promptTokens,
|
||||||
});
|
});
|
||||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||||
const { client } = await initializeClient({ req, res, endpointOption });
|
const { client } = await initializeClient({ req, res, endpointOption });
|
||||||
|
|
||||||
|
const onChainEnd = () => {
|
||||||
|
if (!client.skipSaveUserMessage) {
|
||||||
|
saveMessage({ ...userMessage, user });
|
||||||
|
}
|
||||||
|
sendIntermediateMessage(res, {
|
||||||
|
plugins,
|
||||||
|
parentMessageId: userMessage.messageId,
|
||||||
|
messageId: responseMessageId,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
let response = await client.sendMessage(text, {
|
let response = await client.sendMessage(text, {
|
||||||
user,
|
user,
|
||||||
conversationId,
|
conversationId,
|
||||||
@@ -174,12 +192,12 @@ router.post(
|
|||||||
onStart,
|
onStart,
|
||||||
getPartialText,
|
getPartialText,
|
||||||
...endpointOption,
|
...endpointOption,
|
||||||
onProgress: progressCallback.call(null, {
|
progressCallback,
|
||||||
|
progressOptions: {
|
||||||
res,
|
res,
|
||||||
text,
|
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
|
||||||
plugins,
|
plugins,
|
||||||
}),
|
},
|
||||||
abortController,
|
abortController,
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -192,10 +210,14 @@ router.post(
|
|||||||
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
|
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
|
||||||
await saveMessage({ ...response, user });
|
await saveMessage({ ...response, user });
|
||||||
|
|
||||||
|
const { conversation = {} } = await client.responsePromise;
|
||||||
|
conversation.title =
|
||||||
|
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||||
|
|
||||||
sendMessage(res, {
|
sendMessage(res, {
|
||||||
title: await getConvoTitle(user, conversationId),
|
title: conversation.title,
|
||||||
final: true,
|
final: true,
|
||||||
conversation: await getConvo(user, conversationId),
|
conversation,
|
||||||
requestMessage: userMessage,
|
requestMessage: userMessage,
|
||||||
responseMessage: response,
|
responseMessage: response,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ const { encryptMetadata, domainParser } = require('~/server/services/ActionServi
|
|||||||
const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider');
|
const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider');
|
||||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||||
const { updateAction, getActions, deleteAction } = require('~/models/Action');
|
const { updateAction, getActions, deleteAction } = require('~/models/Action');
|
||||||
const { updateAssistant, getAssistant } = require('~/models/Assistant');
|
const { updateAssistantDoc, getAssistant } = require('~/models/Assistant');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
@@ -109,7 +109,7 @@ router.post('/:assistant_id', async (req, res) => {
|
|||||||
let updatedAssistant = await openai.beta.assistants.update(assistant_id, { tools });
|
let updatedAssistant = await openai.beta.assistants.update(assistant_id, { tools });
|
||||||
const promises = [];
|
const promises = [];
|
||||||
promises.push(
|
promises.push(
|
||||||
updateAssistant(
|
updateAssistantDoc(
|
||||||
{ assistant_id },
|
{ assistant_id },
|
||||||
{
|
{
|
||||||
actions,
|
actions,
|
||||||
@@ -186,7 +186,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
|
|||||||
|
|
||||||
const promises = [];
|
const promises = [];
|
||||||
promises.push(
|
promises.push(
|
||||||
updateAssistant(
|
updateAssistantDoc(
|
||||||
{ assistant_id },
|
{ assistant_id },
|
||||||
{
|
{
|
||||||
actions: updatedActions,
|
actions: updatedActions,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ const {
|
|||||||
// validateEndpoint,
|
// validateEndpoint,
|
||||||
buildEndpointOption,
|
buildEndpointOption,
|
||||||
} = require('~/server/middleware');
|
} = require('~/server/middleware');
|
||||||
|
const validateAssistant = require('~/server/middleware/assistants/validate');
|
||||||
const chatController = require('~/server/controllers/assistants/chatV1');
|
const chatController = require('~/server/controllers/assistants/chatV1');
|
||||||
|
|
||||||
router.post('/abort', handleAbort());
|
router.post('/abort', handleAbort());
|
||||||
@@ -20,6 +21,6 @@ router.post('/abort', handleAbort());
|
|||||||
* @param {express.Response} res - The response object, used to send back a response.
|
* @param {express.Response} res - The response object, used to send back a response.
|
||||||
* @returns {void}
|
* @returns {void}
|
||||||
*/
|
*/
|
||||||
router.post('/', validateModel, buildEndpointOption, setHeaders, chatController);
|
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);
|
||||||
|
|
||||||
module.exports = router;
|
module.exports = router;
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ const {
|
|||||||
// validateEndpoint,
|
// validateEndpoint,
|
||||||
buildEndpointOption,
|
buildEndpointOption,
|
||||||
} = require('~/server/middleware');
|
} = require('~/server/middleware');
|
||||||
|
const validateAssistant = require('~/server/middleware/assistants/validate');
|
||||||
const chatController = require('~/server/controllers/assistants/chatV2');
|
const chatController = require('~/server/controllers/assistants/chatV2');
|
||||||
|
|
||||||
router.post('/abort', handleAbort());
|
router.post('/abort', handleAbort());
|
||||||
@@ -20,6 +21,6 @@ router.post('/abort', handleAbort());
|
|||||||
* @param {express.Response} res - The response object, used to send back a response.
|
* @param {express.Response} res - The response object, used to send back a response.
|
||||||
* @returns {void}
|
* @returns {void}
|
||||||
*/
|
*/
|
||||||
router.post('/', validateModel, buildEndpointOption, setHeaders, chatController);
|
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);
|
||||||
|
|
||||||
module.exports = router;
|
module.exports = router;
|
||||||
|
|||||||
@@ -1,29 +1,45 @@
|
|||||||
const express = require('express');
|
const express = require('express');
|
||||||
const {
|
const {
|
||||||
resetPasswordRequestController,
|
|
||||||
resetPasswordController,
|
|
||||||
refreshController,
|
refreshController,
|
||||||
registrationController,
|
registrationController,
|
||||||
} = require('../controllers/AuthController');
|
resetPasswordController,
|
||||||
const { loginController } = require('../controllers/auth/LoginController');
|
resetPasswordRequestController,
|
||||||
const { logoutController } = require('../controllers/auth/LogoutController');
|
} = require('~/server/controllers/AuthController');
|
||||||
|
const { loginController } = require('~/server/controllers/auth/LoginController');
|
||||||
|
const { logoutController } = require('~/server/controllers/auth/LogoutController');
|
||||||
const {
|
const {
|
||||||
checkBan,
|
checkBan,
|
||||||
loginLimiter,
|
loginLimiter,
|
||||||
registerLimiter,
|
|
||||||
requireJwtAuth,
|
requireJwtAuth,
|
||||||
|
registerLimiter,
|
||||||
|
requireLdapAuth,
|
||||||
requireLocalAuth,
|
requireLocalAuth,
|
||||||
|
resetPasswordLimiter,
|
||||||
validateRegistration,
|
validateRegistration,
|
||||||
} = require('../middleware');
|
validatePasswordReset,
|
||||||
|
} = require('~/server/middleware');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
|
|
||||||
|
const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||||
//Local
|
//Local
|
||||||
router.post('/logout', requireJwtAuth, logoutController);
|
router.post('/logout', requireJwtAuth, logoutController);
|
||||||
router.post('/login', loginLimiter, checkBan, requireLocalAuth, loginController);
|
router.post(
|
||||||
|
'/login',
|
||||||
|
loginLimiter,
|
||||||
|
checkBan,
|
||||||
|
ldapAuth ? requireLdapAuth : requireLocalAuth,
|
||||||
|
loginController,
|
||||||
|
);
|
||||||
router.post('/refresh', refreshController);
|
router.post('/refresh', refreshController);
|
||||||
router.post('/register', registerLimiter, checkBan, validateRegistration, registrationController);
|
router.post('/register', registerLimiter, checkBan, validateRegistration, registrationController);
|
||||||
router.post('/requestPasswordReset', resetPasswordRequestController);
|
router.post(
|
||||||
router.post('/resetPassword', resetPasswordController);
|
'/requestPasswordReset',
|
||||||
|
resetPasswordLimiter,
|
||||||
|
checkBan,
|
||||||
|
validatePasswordReset,
|
||||||
|
resetPasswordRequestController,
|
||||||
|
);
|
||||||
|
router.post('/resetPassword', checkBan, validatePasswordReset, resetPasswordController);
|
||||||
|
|
||||||
module.exports = router;
|
module.exports = router;
|
||||||
|
|||||||
15
api/server/routes/categories.js
Normal file
15
api/server/routes/categories.js
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
const express = require('express');
|
||||||
|
const router = express.Router();
|
||||||
|
const { requireJwtAuth } = require('~/server/middleware');
|
||||||
|
const { getCategories } = require('~/models/Categories');
|
||||||
|
|
||||||
|
router.get('/', requireJwtAuth, async (req, res) => {
|
||||||
|
try {
|
||||||
|
const categories = await getCategories();
|
||||||
|
res.status(200).send(categories);
|
||||||
|
} catch (error) {
|
||||||
|
res.status(500).send({ message: 'Failed to retrieve categories', error: error.message });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
module.exports = router;
|
||||||
@@ -1,18 +1,39 @@
|
|||||||
const express = require('express');
|
const express = require('express');
|
||||||
const { defaultSocialLogins } = require('librechat-data-provider');
|
const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider');
|
||||||
|
const { getProjectByName } = require('~/models/Project');
|
||||||
const { isEnabled } = require('~/server/utils');
|
const { isEnabled } = require('~/server/utils');
|
||||||
|
const { getLogStores } = require('~/cache');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const emailLoginEnabled =
|
const emailLoginEnabled =
|
||||||
process.env.ALLOW_EMAIL_LOGIN === undefined || isEnabled(process.env.ALLOW_EMAIL_LOGIN);
|
process.env.ALLOW_EMAIL_LOGIN === undefined || isEnabled(process.env.ALLOW_EMAIL_LOGIN);
|
||||||
|
const passwordResetEnabled = isEnabled(process.env.ALLOW_PASSWORD_RESET);
|
||||||
|
|
||||||
|
const sharedLinksEnabled =
|
||||||
|
process.env.ALLOW_SHARED_LINKS === undefined || isEnabled(process.env.ALLOW_SHARED_LINKS);
|
||||||
|
|
||||||
|
const publicSharedLinksEnabled =
|
||||||
|
sharedLinksEnabled &&
|
||||||
|
(process.env.ALLOW_SHARED_LINKS_PUBLIC === undefined ||
|
||||||
|
isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC));
|
||||||
|
|
||||||
router.get('/', async function (req, res) {
|
router.get('/', async function (req, res) {
|
||||||
|
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||||
|
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
|
||||||
|
if (cachedStartupConfig) {
|
||||||
|
res.send(cachedStartupConfig);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const isBirthday = () => {
|
const isBirthday = () => {
|
||||||
const today = new Date();
|
const today = new Date();
|
||||||
return today.getMonth() === 1 && today.getDate() === 11;
|
return today.getMonth() === 1 && today.getDate() === 11;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const instanceProject = await getProjectByName('instance', '_id');
|
||||||
|
|
||||||
|
const ldapLoginEnabled = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||||
try {
|
try {
|
||||||
/** @type {TStartupConfig} */
|
/** @type {TStartupConfig} */
|
||||||
const payload = {
|
const payload = {
|
||||||
@@ -30,15 +51,17 @@ router.get('/', async function (req, res) {
|
|||||||
!!process.env.OPENID_SESSION_SECRET,
|
!!process.env.OPENID_SESSION_SECRET,
|
||||||
openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
|
openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
|
||||||
openidImageUrl: process.env.OPENID_IMAGE_URL,
|
openidImageUrl: process.env.OPENID_IMAGE_URL,
|
||||||
|
ldapLoginEnabled,
|
||||||
serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
|
serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
|
||||||
emailLoginEnabled,
|
emailLoginEnabled,
|
||||||
registrationEnabled: isEnabled(process.env.ALLOW_REGISTRATION),
|
registrationEnabled: !ldapLoginEnabled && isEnabled(process.env.ALLOW_REGISTRATION),
|
||||||
socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN),
|
socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN),
|
||||||
emailEnabled:
|
emailEnabled:
|
||||||
(!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
|
(!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
|
||||||
!!process.env.EMAIL_USERNAME &&
|
!!process.env.EMAIL_USERNAME &&
|
||||||
!!process.env.EMAIL_PASSWORD &&
|
!!process.env.EMAIL_PASSWORD &&
|
||||||
!!process.env.EMAIL_FROM,
|
!!process.env.EMAIL_FROM,
|
||||||
|
passwordResetEnabled,
|
||||||
checkBalance: isEnabled(process.env.CHECK_BALANCE),
|
checkBalance: isEnabled(process.env.CHECK_BALANCE),
|
||||||
showBirthdayIcon:
|
showBirthdayIcon:
|
||||||
isBirthday() ||
|
isBirthday() ||
|
||||||
@@ -47,12 +70,17 @@ router.get('/', async function (req, res) {
|
|||||||
helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai',
|
helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai',
|
||||||
interface: req.app.locals.interfaceConfig,
|
interface: req.app.locals.interfaceConfig,
|
||||||
modelSpecs: req.app.locals.modelSpecs,
|
modelSpecs: req.app.locals.modelSpecs,
|
||||||
|
sharedLinksEnabled,
|
||||||
|
publicSharedLinksEnabled,
|
||||||
|
analyticsGtmId: process.env.ANALYTICS_GTM_ID,
|
||||||
|
instanceProjectId: instanceProject._id.toString(),
|
||||||
};
|
};
|
||||||
|
|
||||||
if (typeof process.env.CUSTOM_FOOTER === 'string') {
|
if (typeof process.env.CUSTOM_FOOTER === 'string') {
|
||||||
payload.customFooter = process.env.CUSTOM_FOOTER;
|
payload.customFooter = process.env.CUSTOM_FOOTER;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await cache.set(CacheKeys.STARTUP_CONFIG, payload);
|
||||||
return res.status(200).send(payload);
|
return res.status(200).send(payload);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
logger.error('Error in startup config', err);
|
logger.error('Error in startup config', err);
|
||||||
|
|||||||
@@ -3,12 +3,11 @@ const express = require('express');
|
|||||||
const { CacheKeys } = require('librechat-data-provider');
|
const { CacheKeys } = require('librechat-data-provider');
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||||
const { getConvosByPage, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
|
const { getConvosByPage, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation');
|
||||||
const { IMPORT_CONVERSATION_JOB_NAME } = require('~/server/utils/import/jobDefinition');
|
|
||||||
const { storage, importFileFilter } = require('~/server/routes/files/multer');
|
const { storage, importFileFilter } = require('~/server/routes/files/multer');
|
||||||
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||||
const { forkConversation } = require('~/server/utils/import/fork');
|
const { forkConversation } = require('~/server/utils/import/fork');
|
||||||
|
const { importConversations } = require('~/server/utils/import');
|
||||||
const { createImportLimiters } = require('~/server/middleware');
|
const { createImportLimiters } = require('~/server/middleware');
|
||||||
const jobScheduler = require('~/server/utils/jobScheduler');
|
|
||||||
const getLogStores = require('~/cache/getLogStores');
|
const getLogStores = require('~/cache/getLogStores');
|
||||||
const { sleep } = require('~/server/utils');
|
const { sleep } = require('~/server/utils');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
@@ -129,10 +128,9 @@ router.post(
|
|||||||
upload.single('file'),
|
upload.single('file'),
|
||||||
async (req, res) => {
|
async (req, res) => {
|
||||||
try {
|
try {
|
||||||
const filepath = req.file.path;
|
/* TODO: optimize to return imported conversations and add manually */
|
||||||
const job = await jobScheduler.now(IMPORT_CONVERSATION_JOB_NAME, filepath, req.user.id);
|
await importConversations({ filepath: req.file.path, requestUserId: req.user.id });
|
||||||
|
res.status(201).json({ message: 'Conversation(s) imported successfully' });
|
||||||
res.status(201).json({ message: 'Import started', jobId: job.id });
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Error processing file', error);
|
logger.error('Error processing file', error);
|
||||||
res.status(500).send('Error processing file');
|
res.status(500).send('Error processing file');
|
||||||
@@ -169,24 +167,4 @@ router.post('/fork', async (req, res) => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Get the status of an import job for polling
|
|
||||||
router.get('/import/jobs/:jobId', async (req, res) => {
|
|
||||||
try {
|
|
||||||
const { jobId } = req.params;
|
|
||||||
const { userId, ...jobStatus } = await jobScheduler.getJobStatus(jobId);
|
|
||||||
if (!jobStatus) {
|
|
||||||
return res.status(404).json({ message: 'Job not found.' });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (userId !== req.user.id) {
|
|
||||||
return res.status(403).json({ message: 'Unauthorized' });
|
|
||||||
}
|
|
||||||
|
|
||||||
res.json(jobStatus);
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('Error getting job details', error);
|
|
||||||
res.status(500).send('Error getting job details');
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
module.exports = router;
|
module.exports = router;
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ const {
|
|||||||
} = require('~/server/middleware');
|
} = require('~/server/middleware');
|
||||||
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
|
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
|
||||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
||||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
const { saveMessage } = require('~/models');
|
||||||
const { validateTools } = require('~/app');
|
const { validateTools } = require('~/app');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
@@ -49,6 +49,7 @@ router.post(
|
|||||||
});
|
});
|
||||||
|
|
||||||
let userMessage;
|
let userMessage;
|
||||||
|
let userMessagePromise;
|
||||||
let promptTokens;
|
let promptTokens;
|
||||||
const sender = getResponseSender({
|
const sender = getResponseSender({
|
||||||
...endpointOption,
|
...endpointOption,
|
||||||
@@ -68,6 +69,8 @@ router.post(
|
|||||||
for (let key in data) {
|
for (let key in data) {
|
||||||
if (key === 'userMessage') {
|
if (key === 'userMessage') {
|
||||||
userMessage = data[key];
|
userMessage = data[key];
|
||||||
|
} else if (key === 'userMessagePromise') {
|
||||||
|
userMessagePromise = data[key];
|
||||||
} else if (key === 'responseMessageId') {
|
} else if (key === 'responseMessageId') {
|
||||||
responseMessageId = data[key];
|
responseMessageId = data[key];
|
||||||
} else if (key === 'promptTokens') {
|
} else if (key === 'promptTokens') {
|
||||||
@@ -103,29 +106,23 @@ router.post(
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const onAgentAction = (action, start = false) => {
|
|
||||||
const formattedAction = formatAction(action);
|
|
||||||
plugin.inputs.push(formattedAction);
|
|
||||||
plugin.latest = formattedAction.plugin;
|
|
||||||
if (!start) {
|
|
||||||
saveMessage({ ...userMessage, user });
|
|
||||||
}
|
|
||||||
sendIntermediateMessage(res, { plugin });
|
|
||||||
// logger.debug('PLUGIN ACTION', formattedAction);
|
|
||||||
};
|
|
||||||
|
|
||||||
const onChainEnd = (data) => {
|
const onChainEnd = (data) => {
|
||||||
let { intermediateSteps: steps } = data;
|
let { intermediateSteps: steps } = data;
|
||||||
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
|
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
|
||||||
plugin.loading = false;
|
plugin.loading = false;
|
||||||
saveMessage({ ...userMessage, user });
|
saveMessage({ ...userMessage, user });
|
||||||
sendIntermediateMessage(res, { plugin });
|
sendIntermediateMessage(res, {
|
||||||
|
plugin,
|
||||||
|
parentMessageId: userMessage.messageId,
|
||||||
|
messageId: responseMessageId,
|
||||||
|
});
|
||||||
// logger.debug('CHAIN END', plugin.outputs);
|
// logger.debug('CHAIN END', plugin.outputs);
|
||||||
};
|
};
|
||||||
|
|
||||||
const getAbortData = () => ({
|
const getAbortData = () => ({
|
||||||
sender,
|
sender,
|
||||||
conversationId,
|
conversationId,
|
||||||
|
userMessagePromise,
|
||||||
messageId: responseMessageId,
|
messageId: responseMessageId,
|
||||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||||
text: getPartialText(),
|
text: getPartialText(),
|
||||||
@@ -133,12 +130,27 @@ router.post(
|
|||||||
userMessage,
|
userMessage,
|
||||||
promptTokens,
|
promptTokens,
|
||||||
});
|
});
|
||||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||||
const { client } = await initializeClient({ req, res, endpointOption });
|
const { client } = await initializeClient({ req, res, endpointOption });
|
||||||
|
|
||||||
|
const onAgentAction = (action, start = false) => {
|
||||||
|
const formattedAction = formatAction(action);
|
||||||
|
plugin.inputs.push(formattedAction);
|
||||||
|
plugin.latest = formattedAction.plugin;
|
||||||
|
if (!start && !client.skipSaveUserMessage) {
|
||||||
|
saveMessage({ ...userMessage, user });
|
||||||
|
}
|
||||||
|
sendIntermediateMessage(res, {
|
||||||
|
plugin,
|
||||||
|
parentMessageId: userMessage.messageId,
|
||||||
|
messageId: responseMessageId,
|
||||||
|
});
|
||||||
|
// logger.debug('PLUGIN ACTION', formattedAction);
|
||||||
|
};
|
||||||
|
|
||||||
let response = await client.sendMessage(text, {
|
let response = await client.sendMessage(text, {
|
||||||
user,
|
user,
|
||||||
generation,
|
generation,
|
||||||
@@ -153,12 +165,12 @@ router.post(
|
|||||||
onChainEnd,
|
onChainEnd,
|
||||||
onStart,
|
onStart,
|
||||||
...endpointOption,
|
...endpointOption,
|
||||||
onProgress: progressCallback.call(null, {
|
progressCallback,
|
||||||
|
progressOptions: {
|
||||||
res,
|
res,
|
||||||
text,
|
|
||||||
plugin,
|
plugin,
|
||||||
parentMessageId: overrideParentMessageId || userMessageId,
|
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||||
}),
|
},
|
||||||
abortController,
|
abortController,
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -170,10 +182,14 @@ router.post(
|
|||||||
response.plugin = { ...plugin, loading: false };
|
response.plugin = { ...plugin, loading: false };
|
||||||
await saveMessage({ ...response, user });
|
await saveMessage({ ...response, user });
|
||||||
|
|
||||||
|
const { conversation = {} } = await client.responsePromise;
|
||||||
|
conversation.title =
|
||||||
|
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||||
|
|
||||||
sendMessage(res, {
|
sendMessage(res, {
|
||||||
title: await getConvoTitle(user, conversationId),
|
title: conversation.title,
|
||||||
final: true,
|
final: true,
|
||||||
conversation: await getConvo(user, conversationId),
|
conversation,
|
||||||
requestMessage: userMessage,
|
requestMessage: userMessage,
|
||||||
responseMessage: response,
|
responseMessage: response,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ const { createMulterInstance } = require('./multer');
|
|||||||
const files = require('./files');
|
const files = require('./files');
|
||||||
const images = require('./images');
|
const images = require('./images');
|
||||||
const avatar = require('./avatar');
|
const avatar = require('./avatar');
|
||||||
|
const speech = require('./speech');
|
||||||
|
|
||||||
const initialize = async () => {
|
const initialize = async () => {
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
@@ -12,6 +13,9 @@ const initialize = async () => {
|
|||||||
router.use(checkBan);
|
router.use(checkBan);
|
||||||
router.use(uaParser);
|
router.use(uaParser);
|
||||||
|
|
||||||
|
/* Important: speech route must be added before the upload limiters */
|
||||||
|
router.use('/speech', speech);
|
||||||
|
|
||||||
const upload = await createMulterInstance();
|
const upload = await createMulterInstance();
|
||||||
const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters();
|
const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters();
|
||||||
router.post('*', fileUploadIpLimiter, fileUploadUserLimiter);
|
router.post('*', fileUploadIpLimiter, fileUploadUserLimiter);
|
||||||
|
|||||||
10
api/server/routes/files/speech/customConfigSpeech.js
Normal file
10
api/server/routes/files/speech/customConfigSpeech.js
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
const express = require('express');
|
||||||
|
const router = express.Router();
|
||||||
|
|
||||||
|
const { getCustomConfigSpeech } = require('~/server/services/Files/Audio');
|
||||||
|
|
||||||
|
router.get('/get', async (req, res) => {
|
||||||
|
await getCustomConfigSpeech(req, res);
|
||||||
|
});
|
||||||
|
|
||||||
|
module.exports = router;
|
||||||
17
api/server/routes/files/speech/index.js
Normal file
17
api/server/routes/files/speech/index.js
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
const express = require('express');
|
||||||
|
const { createTTSLimiters, createSTTLimiters } = require('~/server/middleware');
|
||||||
|
|
||||||
|
const stt = require('./stt');
|
||||||
|
const tts = require('./tts');
|
||||||
|
const customConfigSpeech = require('./customConfigSpeech');
|
||||||
|
|
||||||
|
const router = express.Router();
|
||||||
|
|
||||||
|
const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();
|
||||||
|
const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
|
||||||
|
router.use('/stt', sttIpLimiter, sttUserLimiter, stt);
|
||||||
|
router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);
|
||||||
|
|
||||||
|
router.use('/config', customConfigSpeech);
|
||||||
|
|
||||||
|
module.exports = router;
|
||||||
13
api/server/routes/files/speech/stt.js
Normal file
13
api/server/routes/files/speech/stt.js
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
const express = require('express');
|
||||||
|
const router = express.Router();
|
||||||
|
const multer = require('multer');
|
||||||
|
const { requireJwtAuth } = require('~/server/middleware/');
|
||||||
|
const { speechToText } = require('~/server/services/Files/Audio');
|
||||||
|
|
||||||
|
const upload = multer();
|
||||||
|
|
||||||
|
router.post('/', requireJwtAuth, upload.single('audio'), async (req, res) => {
|
||||||
|
await speechToText(req, res);
|
||||||
|
});
|
||||||
|
|
||||||
|
module.exports = router;
|
||||||
42
api/server/routes/files/speech/tts.js
Normal file
42
api/server/routes/files/speech/tts.js
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
const multer = require('multer');
|
||||||
|
const express = require('express');
|
||||||
|
const { CacheKeys } = require('librechat-data-provider');
|
||||||
|
const { getVoices, streamAudio, textToSpeech } = require('~/server/services/Files/Audio');
|
||||||
|
const { getLogStores } = require('~/cache');
|
||||||
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
|
const router = express.Router();
|
||||||
|
const upload = multer();
|
||||||
|
|
||||||
|
router.post('/manual', upload.none(), async (req, res) => {
|
||||||
|
await textToSpeech(req, res);
|
||||||
|
});
|
||||||
|
|
||||||
|
const logDebugMessage = (req, message) =>
|
||||||
|
logger.debug(`[streamAudio] user: ${req?.user?.id ?? 'UNDEFINED_USER'} | ${message}`);
|
||||||
|
|
||||||
|
// TODO: test caching
|
||||||
|
router.post('/', async (req, res) => {
|
||||||
|
try {
|
||||||
|
const audioRunsCache = getLogStores(CacheKeys.AUDIO_RUNS);
|
||||||
|
const audioRun = await audioRunsCache.get(req.body.runId);
|
||||||
|
logDebugMessage(req, 'start stream audio');
|
||||||
|
if (audioRun) {
|
||||||
|
logDebugMessage(req, 'stream audio already running');
|
||||||
|
return res.status(401).json({ error: 'Audio stream already running' });
|
||||||
|
}
|
||||||
|
audioRunsCache.set(req.body.runId, true);
|
||||||
|
await streamAudio(req, res);
|
||||||
|
logDebugMessage(req, 'end stream audio');
|
||||||
|
res.status(200).end();
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(`[streamAudio] user: ${req.user.id} | Failed to stream audio: ${error}`);
|
||||||
|
res.status(500).json({ error: 'Failed to stream audio' });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
router.get('/voices', async (req, res) => {
|
||||||
|
await getVoices(req, res);
|
||||||
|
});
|
||||||
|
|
||||||
|
module.exports = router;
|
||||||
@@ -19,6 +19,8 @@ const assistants = require('./assistants');
|
|||||||
const files = require('./files');
|
const files = require('./files');
|
||||||
const staticRoute = require('./static');
|
const staticRoute = require('./static');
|
||||||
const share = require('./share');
|
const share = require('./share');
|
||||||
|
const categories = require('./categories');
|
||||||
|
const roles = require('./roles');
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
search,
|
search,
|
||||||
@@ -42,4 +44,6 @@ module.exports = {
|
|||||||
files,
|
files,
|
||||||
staticRoute,
|
staticRoute,
|
||||||
share,
|
share,
|
||||||
|
categories,
|
||||||
|
roles,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ router.use(requireJwtAuth);
|
|||||||
|
|
||||||
router.get('/:conversationId', validateMessageReq, async (req, res) => {
|
router.get('/:conversationId', validateMessageReq, async (req, res) => {
|
||||||
const { conversationId } = req.params;
|
const { conversationId } = req.params;
|
||||||
res.status(200).send(await getMessages({ conversationId }));
|
res.status(200).send(await getMessages({ conversationId }, '-_id -__v -user'));
|
||||||
});
|
});
|
||||||
|
|
||||||
// CREATE
|
// CREATE
|
||||||
@@ -28,7 +28,7 @@ router.post('/:conversationId', validateMessageReq, async (req, res) => {
|
|||||||
// READ
|
// READ
|
||||||
router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
||||||
const { conversationId, messageId } = req.params;
|
const { conversationId, messageId } = req.params;
|
||||||
res.status(200).send(await getMessages({ conversationId, messageId }));
|
res.status(200).send(await getMessages({ conversationId, messageId }, '-_id -__v -user'));
|
||||||
});
|
});
|
||||||
|
|
||||||
// UPDATE
|
// UPDATE
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
|
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
|
||||||
|
|
||||||
const passport = require('passport');
|
|
||||||
const express = require('express');
|
const express = require('express');
|
||||||
const router = express.Router();
|
const passport = require('passport');
|
||||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
|
||||||
const { loginLimiter, checkBan, checkDomainAllowed } = require('~/server/middleware');
|
const { loginLimiter, checkBan, checkDomainAllowed } = require('~/server/middleware');
|
||||||
|
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||||
const { logger } = require('~/config');
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
|
const router = express.Router();
|
||||||
|
|
||||||
const domains = {
|
const domains = {
|
||||||
client: process.env.DOMAIN_CLIENT,
|
client: process.env.DOMAIN_CLIENT,
|
||||||
server: process.env.DOMAIN_SERVER,
|
server: process.env.DOMAIN_SERVER,
|
||||||
|
|||||||
@@ -1,14 +1,235 @@
|
|||||||
const express = require('express');
|
const express = require('express');
|
||||||
|
const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider');
|
||||||
|
const {
|
||||||
|
getPrompt,
|
||||||
|
getPrompts,
|
||||||
|
savePrompt,
|
||||||
|
deletePrompt,
|
||||||
|
getPromptGroup,
|
||||||
|
getPromptGroups,
|
||||||
|
updatePromptGroup,
|
||||||
|
deletePromptGroup,
|
||||||
|
createPromptGroup,
|
||||||
|
getAllPromptGroups,
|
||||||
|
// updatePromptLabels,
|
||||||
|
makePromptProduction,
|
||||||
|
} = require('~/models/Prompt');
|
||||||
|
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||||
|
const { logger } = require('~/config');
|
||||||
|
|
||||||
const router = express.Router();
|
const router = express.Router();
|
||||||
const { getPrompts } = require('../../models/Prompt');
|
|
||||||
|
const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]);
|
||||||
|
const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [
|
||||||
|
Permissions.USE,
|
||||||
|
Permissions.CREATE,
|
||||||
|
]);
|
||||||
|
const checkGlobalPromptShare = generateCheckAccess(
|
||||||
|
PermissionTypes.PROMPTS,
|
||||||
|
[Permissions.USE, Permissions.CREATE],
|
||||||
|
{
|
||||||
|
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
router.use(requireJwtAuth);
|
||||||
|
router.use(checkPromptAccess);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Route to get single prompt group by its ID
|
||||||
|
* GET /groups/:groupId
|
||||||
|
*/
|
||||||
|
router.get('/groups/:groupId', async (req, res) => {
|
||||||
|
let groupId = req.params.groupId;
|
||||||
|
const author = req.user.id;
|
||||||
|
|
||||||
|
const query = {
|
||||||
|
_id: groupId,
|
||||||
|
$or: [{ projectIds: { $exists: true, $ne: [], $not: { $size: 0 } } }, { author }],
|
||||||
|
};
|
||||||
|
|
||||||
|
if (req.user.role === SystemRoles.ADMIN) {
|
||||||
|
delete query.$or;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const group = await getPromptGroup(query);
|
||||||
|
|
||||||
|
if (!group) {
|
||||||
|
return res.status(404).send({ message: 'Prompt group not found' });
|
||||||
|
}
|
||||||
|
|
||||||
|
res.status(200).send(group);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error getting prompt group', error);
|
||||||
|
res.status(500).send({ message: 'Error getting prompt group' });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Route to fetch all prompt groups
|
||||||
|
* GET /groups
|
||||||
|
*/
|
||||||
|
router.get('/all', async (req, res) => {
|
||||||
|
try {
|
||||||
|
const groups = await getAllPromptGroups(req, {
|
||||||
|
author: req.user._id,
|
||||||
|
});
|
||||||
|
res.status(200).send(groups);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(error);
|
||||||
|
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Route to fetch paginated prompt groups with filters
|
||||||
|
* GET /groups
|
||||||
|
*/
|
||||||
|
router.get('/groups', async (req, res) => {
|
||||||
|
try {
|
||||||
|
const filter = req.query;
|
||||||
|
/* Note: The aggregation requires an ObjectId */
|
||||||
|
filter.author = req.user._id;
|
||||||
|
const groups = await getPromptGroups(req, filter);
|
||||||
|
res.status(200).send(groups);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(error);
|
||||||
|
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Updates or creates a prompt + promptGroup
|
||||||
|
* @param {object} req
|
||||||
|
* @param {TCreatePrompt} req.body
|
||||||
|
* @param {Express.Response} res
|
||||||
|
*/
|
||||||
|
const createPrompt = async (req, res) => {
|
||||||
|
try {
|
||||||
|
const { prompt, group } = req.body;
|
||||||
|
if (!prompt) {
|
||||||
|
return res.status(400).send({ error: 'Prompt is required' });
|
||||||
|
}
|
||||||
|
|
||||||
|
const saveData = {
|
||||||
|
prompt,
|
||||||
|
group,
|
||||||
|
author: req.user.id,
|
||||||
|
authorName: req.user.name,
|
||||||
|
};
|
||||||
|
|
||||||
|
/** @type {TCreatePromptResponse} */
|
||||||
|
let result;
|
||||||
|
if (group && group.name) {
|
||||||
|
result = await createPromptGroup(saveData);
|
||||||
|
} else {
|
||||||
|
result = await savePrompt(saveData);
|
||||||
|
}
|
||||||
|
res.status(200).send(result);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(error);
|
||||||
|
res.status(500).send({ error: 'Error saving prompt' });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
router.post('/', createPrompt);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Updates a prompt group
|
||||||
|
* @param {object} req
|
||||||
|
* @param {object} req.params - The request parameters
|
||||||
|
* @param {string} req.params.groupId - The group ID
|
||||||
|
* @param {TUpdatePromptGroupPayload} req.body - The request body
|
||||||
|
* @param {Express.Response} res
|
||||||
|
*/
|
||||||
|
const patchPromptGroup = async (req, res) => {
|
||||||
|
try {
|
||||||
|
const { groupId } = req.params;
|
||||||
|
const author = req.user.id;
|
||||||
|
const filter = { _id: groupId, author };
|
||||||
|
if (req.user.role === SystemRoles.ADMIN) {
|
||||||
|
delete filter.author;
|
||||||
|
}
|
||||||
|
const promptGroup = await updatePromptGroup(filter, req.body);
|
||||||
|
res.status(200).send(promptGroup);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(error);
|
||||||
|
res.status(500).send({ error: 'Error updating prompt group' });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
router.patch('/groups/:groupId', checkGlobalPromptShare, patchPromptGroup);
|
||||||
|
|
||||||
|
router.patch('/:promptId/tags/production', checkPromptCreate, async (req, res) => {
|
||||||
|
try {
|
||||||
|
const { promptId } = req.params;
|
||||||
|
const result = await makePromptProduction(promptId);
|
||||||
|
res.status(200).send(result);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(error);
|
||||||
|
res.status(500).send({ error: 'Error updating prompt production' });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
router.get('/:promptId', async (req, res) => {
|
||||||
|
const { promptId } = req.params;
|
||||||
|
const author = req.user.id;
|
||||||
|
const query = { _id: promptId, author };
|
||||||
|
if (req.user.role === SystemRoles.ADMIN) {
|
||||||
|
delete query.author;
|
||||||
|
}
|
||||||
|
const prompt = await getPrompt(query);
|
||||||
|
res.status(200).send(prompt);
|
||||||
|
});
|
||||||
|
|
||||||
router.get('/', async (req, res) => {
|
router.get('/', async (req, res) => {
|
||||||
let filter = {};
|
try {
|
||||||
// const { search } = req.body.arg;
|
const author = req.user.id;
|
||||||
// if (!!search) {
|
const { groupId } = req.query;
|
||||||
// filter = { conversationId };
|
const query = { groupId, author };
|
||||||
// }
|
if (req.user.role === SystemRoles.ADMIN) {
|
||||||
res.status(200).send(await getPrompts(filter));
|
delete query.author;
|
||||||
|
}
|
||||||
|
const prompts = await getPrompts(query);
|
||||||
|
res.status(200).send(prompts);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(error);
|
||||||
|
res.status(500).send({ error: 'Error getting prompts' });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Deletes a prompt
|
||||||
|
*
|
||||||
|
* @param {Express.Request} req - The request object.
|
||||||
|
* @param {TDeletePromptVariables} req.params - The request parameters
|
||||||
|
* @param {import('mongoose').ObjectId} req.params.promptId - The prompt ID
|
||||||
|
* @param {Express.Response} res - The response object.
|
||||||
|
* @return {TDeletePromptResponse} A promise that resolves when the prompt is deleted.
|
||||||
|
*/
|
||||||
|
const deletePromptController = async (req, res) => {
|
||||||
|
try {
|
||||||
|
const { promptId } = req.params;
|
||||||
|
const { groupId } = req.query;
|
||||||
|
const author = req.user.id;
|
||||||
|
const query = { promptId, groupId, author, role: req.user.role };
|
||||||
|
if (req.user.role === SystemRoles.ADMIN) {
|
||||||
|
delete query.author;
|
||||||
|
}
|
||||||
|
const result = await deletePrompt(query);
|
||||||
|
res.status(200).send(result);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(error);
|
||||||
|
res.status(500).send({ error: 'Error deleting prompt' });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
router.delete('/:promptId', checkPromptCreate, deletePromptController);
|
||||||
|
|
||||||
|
router.delete('/groups/:groupId', checkPromptCreate, async (req, res) => {
|
||||||
|
const { groupId } = req.params;
|
||||||
|
res.status(200).send(await deletePromptGroup(groupId));
|
||||||
});
|
});
|
||||||
|
|
||||||
module.exports = router;
|
module.exports = router;
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user