Compare commits
177 Commits
v0.7.2 ... re-add-dow
| SHA1 |
|---|
| c27e8566fb |
| 4ffdefc2a8 |
| 2d5f704695 |
| 3fd25920d4 |
| d03f8285db |
| e565e0faab |
| d4d56281e3 |
| 14eff23b57 |
| 18fd8f1416 |
| ba9cb71245 |
| e5dfa06e6c |
| 0bd59c0efe |
| 344297021f |
| 620973436c |
| 422d1a2c91 |
| 2ad097647c |
| 9e7615f832 |
| ee4dd1b2e9 |
| f6125ccd59 |
| 1acd47a0f6 |
| 511f1336da |
| 5d40d0a37a |
| 1c282d1517 |
| 73dbf3eb20 |
| 237a0de8b6 |
| d5782ac66c |
| d5d188eebf |
| 25ea3b8e98 |
| b1ec67ea42 |
| f1bb5fa4c5 |
| 44d8596872 |
| 32d84c85ea |
| 0a1d38e318 |
| 5ef71a7a36 |
| 326069d7a6 |
| 785430daf5 |
| 03fe361917 |
| b34a4ddac1 |
| 7d5b03dd98 |
| f959ee302c |
| cd00df69bb |
| a05e2c1dcc |
| 87bdbda10a |
| 605a8ae8c9 |
| a724635998 |
| 6c306a662c |
| 55f8d9910e |
| 7edb54889b |
| 71d9e841b1 |
| e76777d298 |
| 1edbfdbce2 |
| 1aad315de6 |
| 5d985746cb |
| 04654014b2 |
| 456793772b |
| a87d4e0b75 |
| a2fd975cd5 |
| 83619de158 |
| b8f2bee3fc |
| 81292bb4dd |
| ed5ee1f86f |
| 791b0139bc |
| 156c52e293 |
| eef894e608 |
| e2867eecc9 |
| dd563e0796 |
| c99cf1b4b1 |
| b5081bfe86 |
| aac01df80c |
| 24467dd626 |
| b2b469bd3d |
| cec2e57ee9 |
| a8c874267f |
| a53312bbd4 |
| ab74685476 |
| 015215b790 |
| 4e4de88faa |
| 3172381bad |
| 54b1095239 |
| 0424f8fe55 |
| 4319c62e66 |
| d3a0b862db |
| 5d8793c5d1 |
| 54db67449a |
| 0cd3c83328 |
| d839e4661c |
| 302b28fc9b |
| dad25bd297 |
| a338decf90 |
| 2cf5228021 |
| 0294cfc881 |
| 8d8b17e7ed |
| 04502e9525 |
| bcaa7d5d29 |
| c288b458b6 |
| 447bbcb8ca |
| 68bf7ac7c0 |
| 97d12d03d1 |
| 4416f69a9b |
| 29e71e98ad |
| e9bbf39618 |
| 08b8ae120e |
| 803fd63121 |
| ef76cc195e |
| 2e559137ae |
| 92232afaca |
| 084cf266a2 |
| baf0848021 |
| 1da92111aa |
| 35f8053f45 |
| ee673d682e |
| b7fef6958b |
| 5452d4c20c |
| a7f5b57272 |
| f69b317171 |
| 4469ba72fc |
| 0e3e45e77d |
| 9f0c1914a5 |
| 37ae484fbc |
| 8939d8af37 |
| f9a0166352 |
| 248dfb8b5b |
| b8e35002f4 |
| 8318f26d66 |
| 08d6bea359 |
| a6058c5669 |
| e0402b71f0 |
| a618266905 |
| d5a7806e32 |
| e2cb2905e7 |
| 3f600f0d3f |
| c9e7d4ac18 |
| 40685f6eb4 |
| 0ee060d730 |
| 5dc5d875ba |
| 9f2538fcd9 |
| 2b7a973a33 |
| c704a23749 |
| eb5733083e |
| b80f38e49e |
| 4369e75ca7 |
| 35ba4ba1a4 |
| dcd2e3e62d |
| 514a502b9c |
| 8e66683577 |
| dc1778b11f |
| 795bb9c568 |
| a937650df6 |
| 6cf1c85363 |
| b3e03b75d0 |
| 9d8fd92dd3 |
| f00a8f87f7 |
| 79840763e7 |
| 1a452121fa |
| af8bcb08d6 |
| f0e8cca5df |
| 38ad36c1c5 |
| 53fe2f6453 |
| 31479d6a48 |
| 612a58737d |
| 8a7f36f581 |
| 4a5d06a774 |
| fc9368e0e7 |
| 64bf0800a0 |
| 94eeec354e |
| e42709bd1f |
| 638ac5bba6 |
| 5920672a8c |
| a0d1e2a5f8 |
| 4ffc1414a8 |
| df6183db0f |
| 3816219936 |
| 6fc664e4a3 |
| 89899164ed |
| c83d9d61d4 |
| bcdddaed72 |
| a4de635719 |
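The comparison above is between the v0.7.2 tag and what appears to be a branch named re-add-dow, with 177 commits in between. As a small sketch of how the same range could be inspected locally (assuming a clone in which both refs exist; only the ref names are taken from the page header):

```js
// Sketch: list the compared commits and a change summary with git, driven from Node.js.
const { execSync } = require('node:child_process');

// Commits reachable from re-add-dow but not from v0.7.2 (the "177 Commits" list above).
const commits = execSync('git log --oneline v0.7.2..re-add-dow', { encoding: 'utf8' });

// File changes relative to the merge base, which is how a compare view diffs two refs.
const stat = execSync('git diff --stat v0.7.2...re-add-dow', { encoding: 'utf8' });

console.log(commits);
console.log(stat);
```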
.env.example (70 changes)

@@ -2,11 +2,9 @@
# LibreChat Configuration #
#=====================================================================#
# Please refer to the reference documentation for assistance #
# with configuring your LibreChat environment. The guide is #
# available both online and within your local LibreChat #
# directory: #
# Online: https://docs.librechat.ai/install/configuration/dotenv.html #
# Locally: ./docs/install/configuration/dotenv.md #
# with configuring your LibreChat environment. #
# #
# https://www.librechat.ai/docs/configuration/dotenv #
#=====================================================================#

#==================================================#

@@ -62,10 +60,12 @@ PROXY=
#===================================#
# Known Endpoints - librechat.yaml #
#===================================#
# https://docs.librechat.ai/install/configuration/ai_endpoints.html
# https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints

# ANYSCALE_API_KEY=
# APIPIE_API_KEY=
# COHERE_API_KEY=
# DATABRICKS_API_KEY=
# FIREWORKS_API_KEY=
# GROQ_API_KEY=
# HUGGINGFACE_TOKEN=

@@ -80,7 +80,7 @@ PROXY=
#============#

ANTHROPIC_API_KEY=user_provided
# ANTHROPIC_MODELS=claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
# ANTHROPIC_MODELS=claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
# ANTHROPIC_REVERSE_PROXY=

#============#

@@ -118,10 +118,12 @@ GOOGLE_KEY=user_provided
# GOOGLE_REVERSE_PROXY=

# Gemini API
# GOOGLE_MODELS=gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
# GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision

# Vertex AI
# GOOGLE_MODELS=gemini-1.5-pro-preview-0409,gemini-1.0-pro-vision-001,gemini-pro,gemini-pro-vision,chat-bison,chat-bison-32k,codechat-bison,codechat-bison-32k,text-bison,text-bison-32k,text-unicorn,code-gecko,code-bison,code-bison-32k
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro

# GOOGLE_TITLE_MODEL=gemini-pro

# Google Gemini Safety Settings
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.

@@ -142,7 +144,7 @@ GOOGLE_KEY=user_provided
#============#

OPENAI_API_KEY=user_provided
# OPENAI_MODELS=gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k
# OPENAI_MODELS=gpt-4o,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k

DEBUG_OPENAI=false

@@ -164,7 +166,17 @@ DEBUG_OPENAI=false

ASSISTANTS_API_KEY=user_provided
# ASSISTANTS_BASE_URL=
# ASSISTANTS_MODELS=gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview
# ASSISTANTS_MODELS=gpt-4o,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview

#==========================#
# Azure Assistants API #
#==========================#

# Note: You should map your credentials with custom variables according to your Azure OpenAI Configuration
# The models for Azure Assistants are also determined by your Azure OpenAI configuration.

# More info, including how to enable use of Assistants with Azure here:
# https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints/azure#using-assistants-with-azure

#============#
# OpenRouter #

@@ -176,7 +188,7 @@ ASSISTANTS_API_KEY=user_provided
# Plugins #
#============#

# PLUGIN_MODELS=gpt-4,gpt-4-turbo-preview,gpt-4-0125-preview,gpt-4-1106-preview,gpt-4-0613,gpt-3.5-turbo,gpt-3.5-turbo-0125,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613
# PLUGIN_MODELS=gpt-4o,gpt-4o-mini,gpt-4,gpt-4-turbo-preview,gpt-4-0125-preview,gpt-4-1106-preview,gpt-4-0613,gpt-3.5-turbo,gpt-3.5-turbo-0125,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613

DEBUG_PLUGINS=true

@@ -249,6 +261,14 @@ MEILI_NO_ANALYTICS=true
MEILI_HOST=http://0.0.0.0:7700
MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt

#==================================================#
# Speech to Text & Text to Speech #
#==================================================#

STT_API_KEY=
TTS_API_KEY=

#===================================================#
# User System #
#===================================================#

@@ -303,6 +323,9 @@ ALLOW_EMAIL_LOGIN=true
ALLOW_REGISTRATION=true
ALLOW_SOCIAL_LOGIN=false
ALLOW_SOCIAL_REGISTRATION=false
ALLOW_PASSWORD_RESET=false
# ALLOW_ACCOUNT_DELETION=true # note: enabled by default if omitted/commented out
ALLOW_UNVERIFIED_EMAIL_LOGIN=true

SESSION_EXPIRY=1000 * 60 * 15
REFRESH_TOKEN_EXPIRY=(1000 * 60 * 60 * 24) * 7
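SESSION_EXPIRY and REFRESH_TOKEN_EXPIRY are written as JavaScript-style millisecond expressions rather than plain numbers. A minimal sketch of the arithmetic behind the example values (the variable names simply mirror the keys above; how LibreChat parses these strings is not shown in this diff):

```js
// Illustrative arithmetic only: what the example .env expressions evaluate to.
const SESSION_EXPIRY = 1000 * 60 * 15; // 900000 ms
const REFRESH_TOKEN_EXPIRY = (1000 * 60 * 60 * 24) * 7; // 604800000 ms

console.log(SESSION_EXPIRY / (1000 * 60)); // 15 (minutes)
console.log(REFRESH_TOKEN_EXPIRY / (1000 * 60 * 60 * 24)); // 7 (days)
```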
@@ -344,6 +367,19 @@ OPENID_REQUIRED_ROLE_PARAMETER_PATH=
OPENID_BUTTON_LABEL=
OPENID_IMAGE_URL=

# LDAP
LDAP_URL=
LDAP_BIND_DN=
LDAP_BIND_CREDENTIALS=
LDAP_USER_SEARCH_BASE=
LDAP_SEARCH_FILTER=mail={{username}}
LDAP_CA_CERT_PATH=
# LDAP_TLS_REJECT_UNAUTHORIZED=
# LDAP_LOGIN_USES_USERNAME=true
# LDAP_ID=
# LDAP_USERNAME=
# LDAP_FULL_NAME=

#========================#
# Email Password Reset #
#========================#

@@ -370,6 +406,13 @@ FIREBASE_STORAGE_BUCKET=
FIREBASE_MESSAGING_SENDER_ID=
FIREBASE_APP_ID=

#========================#
# Shared Links #
#========================#

ALLOW_SHARED_LINKS=true
ALLOW_SHARED_LINKS_PUBLIC=true

#===================================================#
# UI #
#===================================================#

@@ -380,6 +423,9 @@ HELP_AND_FAQ_URL=https://librechat.ai

# SHOW_BIRTHDAY_ICON=true

# Google tag manager id
#ANALYTICS_GTM_ID=user provided google tag manager id

#==================================================#
# Others #
#==================================================#
@@ -132,6 +132,13 @@ module.exports = {
      },
    ],
  },
  {
    files: './config/translations/**/*.ts',
    parser: '@typescript-eslint/parser',
    parserOptions: {
      project: './config/translations/tsconfig.json',
    },
  },
  {
    files: ['./packages/data-provider/specs/**/*.ts'],
    parserOptions: {
.github/CONTRIBUTING.md (12 changes)

@@ -126,6 +126,18 @@ Apply the following naming conventions to branches, labels, and other Git-relate

- **Current Stance**: At present, this backend transition is of lower priority and might not be pursued.

## 7. Module Import Conventions

- `npm` packages first,
  - from shortest line (top) to longest (bottom)

- Followed by typescript types (pertains to data-provider and client workspaces)
  - longest line (top) to shortest (bottom)
  - types from package come first

- Lastly, local imports
  - longest line (top) to shortest (bottom)
  - imports with alias `~` treated the same as relative import with respect to line length

---
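As a minimal illustration of the ordering rules added above, a module header following them could look like this (the package and file names are invented for the example and are not taken from the repository):

```js
// npm packages: shortest line (top) to longest (bottom)
const fs = require('fs');
const axios = require('axios');
const { Router } = require('express');

// TypeScript types (data-provider / client workspaces): longest line first, package types before local ones
/** @typedef {import('some-npm-package').SomePackageType} SomePackageType */
/** @typedef {import('./types').SomeLocalType} SomeLocalType */

// local imports: longest line (top) to shortest; a `~` alias counts the same as a relative path
const { someMiddleware } = require('~/server/middleware');
const { someHelper } = require('~/utils/helpers');
const { logger } = require('~/config');
```

Sorting each group by line length keeps the require block visually ordered and makes stray or duplicate imports easier to spot during review.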
.github/ISSUE_TEMPLATE/FEATURE-REQUEST.yml (2 changes)

@@ -43,7 +43,7 @@ body:
id: terms
attributes:
label: Code of Conduct
description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/danny-avila/LibreChat/blob/main/CODE_OF_CONDUCT.md)
description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/danny-avila/LibreChat/blob/main/.github/CODE_OF_CONDUCT.md)
options:
- label: I agree to follow this project's Code of Conduct
required: true

.github/ISSUE_TEMPLATE/QUESTION.yml (2 changes)

@@ -44,7 +44,7 @@ body:
id: terms
attributes:
label: Code of Conduct
description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/danny-avila/LibreChat/blob/main/CODE_OF_CONDUCT.md)
description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/danny-avila/LibreChat/blob/main/.github/CODE_OF_CONDUCT.md)
options:
- label: I agree to follow this project's Code of Conduct
required: true
.github/pull_request_template.md (9 changes)

@@ -1,7 +1,10 @@
# Pull Request Template

⚠️ Before Submitting a PR, Please Review:
- Please ensure that you have thoroughly read and understood the [Contributing Docs](https://github.com/danny-avila/LibreChat/blob/main/.github/CONTRIBUTING.md) before submitting your Pull Request.

### ⚠️ Before Submitting a PR, read the [Contributing Docs](https://github.com/danny-avila/LibreChat/blob/main/.github/CONTRIBUTING.md) in full!
⚠️ Documentation Updates Notice:
- Kindly note that documentation updates are managed in this repository: [librechat.ai](https://github.com/LibreChat-AI/librechat.ai)

## Summary

@@ -16,8 +19,6 @@ Please delete any irrelevant options.
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Translation update
- [ ] Documentation update

## Testing

@@ -37,4 +38,4 @@ Please delete any irrelevant options.
- [ ] I have written tests demonstrating that my changes are effective or that my feature works
- [ ] Local unit tests pass with my changes
- [ ] Any changes dependent on mine have been merged and published in downstream modules.
- [ ] New documents have been locally validated with mkdocs
- [ ] A pull request for updating the documentation has been submitted.
.github/workflows/mkdocs.yaml (27 changes)

@@ -1,27 +0,0 @@
name: mkdocs
on:
  push:
    branches:
      - main
permissions:
  contents: write
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: 3.x
      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
      - uses: actions/cache@v3
        with:
          key: mkdocs-material-${{ env.cache_id }}
          path: .cache
          restore-keys: |
            mkdocs-material-
      - run: pip install mkdocs-material
      - run: pip install mkdocs-nav-weight
      - run: pip install mkdocs-publisher
      - run: pip install mkdocs-exclude
      - run: mkdocs gh-deploy --force
.gitignore (2 changes)

@@ -11,6 +11,7 @@ logs
pids
*.pid
*.seed
.git

# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov

@@ -45,6 +46,7 @@ api/node_modules/
client/node_modules/
bower_components/
*.d.ts
!vite-env.d.ts

# Floobits
.floo
Dockerfile (32 changes)

@@ -1,10 +1,8 @@
# v0.7.2
# v0.7.3

# Base node image
FROM node:18-alpine3.18 AS node
FROM node:20-alpine AS node

RUN apk add g++ make py3-pip
RUN npm install -g node-gyp
RUN apk --no-cache add curl

RUN mkdir -p /app && chown node:node /app

@@ -14,20 +12,20 @@ USER node

COPY --chown=node:node . .

# Allow mounting of these files, which have no default
# values.
RUN touch .env
RUN npm config set fetch-retry-maxtimeout 600000
RUN npm config set fetch-retries 5
RUN npm config set fetch-retry-mintimeout 15000
RUN npm install --no-audit
RUN \
    # Allow mounting of these files, which have no default
    touch .env ; \
    # Create directories for the volumes to inherit the correct permissions
    mkdir -p /app/client/public/images /app/api/logs ; \
    npm config set fetch-retry-maxtimeout 600000 ; \
    npm config set fetch-retries 5 ; \
    npm config set fetch-retry-mintimeout 15000 ; \
    npm install --no-audit; \
    # React client build
    NODE_OPTIONS="--max-old-space-size=2048" npm run frontend; \
    npm prune --production; \
    npm cache clean --force

# React client build
ENV NODE_OPTIONS="--max-old-space-size=2048"
RUN npm run frontend

# Create directories for the volumes to inherit
# the correct permissions
RUN mkdir -p /app/client/public/images /app/api/logs

# Node API setup

@@ -1,4 +1,4 @@
# v0.7.2
# v0.7.3

# Build API, Client and Data Provider
FROM node:20-alpine AS base

@@ -7,32 +7,31 @@ FROM node:20-alpine AS base
FROM base AS data-provider-build
WORKDIR /app/packages/data-provider
COPY ./packages/data-provider ./
RUN npm install
RUN npm install; npm cache clean --force
RUN npm run build
RUN npm prune --production

# React client build
FROM data-provider-build AS client-build
FROM base AS client-build
WORKDIR /app/client
COPY ./client/package*.json ./
# Copy data-provider to client's node_modules
RUN mkdir -p /app/client/node_modules/librechat-data-provider/
RUN cp -R /app/packages/data-provider/* /app/client/node_modules/librechat-data-provider/
RUN npm install
COPY --from=data-provider-build /app/packages/data-provider/ /app/client/node_modules/librechat-data-provider/
RUN npm install; npm cache clean --force
COPY ./client/ ./
ENV NODE_OPTIONS="--max-old-space-size=2048"
RUN npm run build

# Node API setup
FROM data-provider-build AS api-build
FROM base AS api-build
WORKDIR /app/api
COPY api/package*.json ./
COPY api/ ./
# Copy helper scripts
COPY config/ ./
# Copy data-provider to API's node_modules
RUN mkdir -p /app/api/node_modules/librechat-data-provider/
RUN cp -R /app/packages/data-provider/* /app/api/node_modules/librechat-data-provider/
RUN npm install
COPY --from=data-provider-build /app/packages/data-provider/ /app/api/node_modules/librechat-data-provider/
RUN npm install --include prod; npm cache clean --force
COPY --from=client-build /app/client/dist /app/client/dist
EXPOSE 3080
ENV HOST=0.0.0.0
README.md (68 changes)

@@ -1,6 +1,6 @@
<p align="center">
<a href="https://librechat.ai">
<img src="docs/assets/LibreChat.svg" height="256">
<img src="client/public/assets/logo.svg" height="256">
</a>
<h1 align="center">
<a href="https://librechat.ai">LibreChat</a>

@@ -27,7 +27,7 @@
</p>

<p align="center">
<a href="https://railway.app/template/b5k2mn?referralCode=myKrVZ">
<a href="https://railway.app/template/b5k2mn?referralCode=HI9hWz">
<img src="https://railway.app/button.svg" alt="Deploy on Railway" height="30">
</a>
<a href="https://zeabur.com/templates/0X2ZY8">

@@ -43,14 +43,14 @@
- 🖥️ UI matching ChatGPT, including Dark mode, Streaming, and latest updates
- 🤖 AI model selection:
- OpenAI, Azure OpenAI, BingAI, ChatGPT, Google Vertex AI, Anthropic (Claude), Plugins, Assistants API (including Azure Assistants)
- ✅ Compatible across both **[Remote & Local AI services](https://docs.librechat.ai/install/configuration/ai_endpoints.html#intro):**
- ✅ Compatible across both **[Remote & Local AI services](https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints):**
- groq, Ollama, Cohere, Mistral AI, Apple MLX, koboldcpp, OpenRouter, together.ai, Perplexity, ShuttleAI, and more
- 💾 Create, Save, & Share Custom Presets
- 🔀 Switch between AI Endpoints and Presets, mid-chat
- 🔄 Edit, Resubmit, and Continue Messages with Conversation branching
- 🌿 Fork Messages & Conversations for Advanced Context control
- 💬 Multimodal Chat:
- Upload and analyze images with Claude 3, GPT-4, and Gemini Vision 📸
- Upload and analyze images with Claude 3, GPT-4 (including `gpt-4o` and `gpt-4o-mini`), and Gemini Vision 📸
- Chat with Files using Custom Endpoints, OpenAI, Azure, Anthropic, & Google. 🗃️
- Advanced Agents with Files, Code Interpreter, Tools, and API Actions 🔦
- Available through the [OpenAI Assistants API](https://platform.openai.com/docs/assistants/overview) 🌤️

@@ -58,9 +58,13 @@
- 🌎 Multilingual UI:
- English, 中文, Deutsch, Español, Français, Italiano, Polski, Português Brasileiro,
- Русский, 日本語, Svenska, 한국어, Tiếng Việt, 繁體中文, العربية, Türkçe, Nederlands, עברית
- 🎨 Customizable Dropdown & Interface: Adapts to both power users and newcomers.
- 🎨 Customizable Dropdown & Interface: Adapts to both power users and newcomers
- 📧 Verify your email to ensure secure access
- 🗣️ Chat hands-free with Speech-to-Text and Text-to-Speech magic
- Automatically send and play Audio
- Supports OpenAI, Azure OpenAI, and Elevenlabs
- 📥 Import Conversations from LibreChat, ChatGPT, Chatbot UI
- 📤 Export conversations as screenshots, markdown, text, json.
- 📤 Export conversations as screenshots, markdown, text, json
- 🔍 Search all messages/conversations
- 🔌 Plugins, including web access, image generation with DALL-E-3 and more
- 👥 Multi-User, Secure Authentication with Moderation and Token spend tools

@@ -69,7 +73,7 @@
- 📖 Completely Open-Source & Built in Public
- 🧑🤝🧑 Community-driven development, support, and feedback

[For a thorough review of our features, see our docs here](https://docs.librechat.ai/features/plugins/introduction.html) 📚
[For a thorough review of our features, see our docs here](https://docs.librechat.ai/) 📚

## 🪶 All-In-One AI Conversations with LibreChat

@@ -77,48 +81,48 @@ LibreChat brings together the future of assistant AIs with the revolutionary tec

With LibreChat, you no longer need to opt for ChatGPT Plus and can instead use free or pay-per-call APIs. We welcome contributions, cloning, and forking to enhance the capabilities of this advanced chatbot platform.

<!-- https://github.com/danny-avila/LibreChat/assets/110412045/c1eb0c0f-41f6-4335-b982-84b278b53d59 -->

[](https://youtu.be/pNIOs1ovsXw)
[](https://www.youtube.com/watch?v=bSVHEbVPNl4)
Click on the thumbnail to open the video☝️

---

## 📚 Documentation
## 🌐 Resources

For more information on how to use our advanced features, install and configure our software, and access our guidelines and tutorials, please check out our documentation at [docs.librechat.ai](https://docs.librechat.ai)
**GitHub Repo:**
- **RAG API:** [github.com/danny-avila/rag_api](https://github.com/danny-avila/rag_api)
- **Website:** [github.com/LibreChat-AI/librechat.ai](https://github.com/LibreChat-AI/librechat.ai)

**Other:**
- **Website:** [librechat.ai](https://librechat.ai)
- **Documentation:** [docs.librechat.ai](https://docs.librechat.ai)
- **Blog:** [blog.librechat.ai](https://docs.librechat.ai)

---

## 📝 Changelog

Keep up with the latest updates by visiting the releases page - [Releases](https://github.com/danny-avila/LibreChat/releases)
Keep up with the latest updates by visiting the releases page and notes:
- [Releases](https://github.com/danny-avila/LibreChat/releases)
- [Changelog](https://www.librechat.ai/changelog)

**⚠️ [Breaking Changes](docs/general_info/breaking_changes.md)**
Please consult the breaking changes before updating.
**⚠️ Please consult the [changelog](https://www.librechat.ai/changelog) for breaking changes before updating.**

---

## ⭐ Star History

<p align="center">
<a href="https://trendshift.io/repositories/4685" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4685" alt="danny-avila%2FLibreChat | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a
href="https://runacap.com/ross-index/q1-24/"
target="_blank"
rel="noopener"
>
<img
style="width: 260px; height: 56px"
src="https://runacap.com/wp-content/uploads/2024/04/ROSS_badge_white_Q1_2024.svg"
alt="ROSS Index - Fastest Growing Open-Source Startups in Q1 2024 | Runa Capital"
width="260"
height="56"
/>
</a>
<a href="https://star-history.com/#danny-avila/LibreChat&Date">
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=danny-avila/LibreChat&type=Date&theme=dark" onerror="this.src='https://api.star-history.com/svg?repos=danny-avila/LibreChat&type=Date'" />
</a>
<a href="https://star-history.com/#danny-avila/LibreChat&Date">
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=danny-avila/LibreChat&type=Date&theme=dark" onerror="this.src='https://api.star-history.com/svg?repos=danny-avila/LibreChat&type=Date'" />
</a>
</p>
<p align="center">
<a href="https://trendshift.io/repositories/4685" target="_blank" style="padding: 10px;">
<img src="https://trendshift.io/api/badge/repositories/4685" alt="danny-avila%2FLibreChat | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
</a>
<a href="https://runacap.com/ross-index/q1-24/" target="_blank" rel="noopener" style="margin-left: 20px;">
<img style="width: 260px; height: 56px" src="https://runacap.com/wp-content/uploads/2024/04/ROSS_badge_white_Q1_2024.svg" alt="ROSS Index - Fastest Growing Open-Source Startups in Q1 2024 | Runa Capital" width="260" height="56"/>
</a>
</p>

---
@@ -1,8 +1,11 @@
|
||||
const Anthropic = require('@anthropic-ai/sdk');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const {
|
||||
getResponseSender,
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
anthropicSettings,
|
||||
getResponseSender,
|
||||
validateVisionModel,
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
@@ -15,6 +18,7 @@ const {
|
||||
} = require('./prompts');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
@@ -28,6 +32,8 @@ function delayBeforeRetry(attempts, baseDelay = 1000) {
|
||||
return new Promise((resolve) => setTimeout(resolve, baseDelay * attempts));
|
||||
}
|
||||
|
||||
const { legacy } = anthropicSettings;
|
||||
|
||||
class AnthropicClient extends BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
super(apiKey, options);
|
||||
@@ -60,15 +66,20 @@ class AnthropicClient extends BaseClient {
|
||||
const modelOptions = this.options.modelOptions || {};
|
||||
this.modelOptions = {
|
||||
...modelOptions,
|
||||
// set some good defaults (check for undefined in some cases because they may be 0)
|
||||
model: modelOptions.model || 'claude-1',
|
||||
temperature: typeof modelOptions.temperature === 'undefined' ? 1 : modelOptions.temperature, // 0 - 1, 1 is default
|
||||
topP: typeof modelOptions.topP === 'undefined' ? 0.7 : modelOptions.topP, // 0 - 1, default: 0.7
|
||||
topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
|
||||
stop: modelOptions.stop, // no stop method for now
|
||||
model: modelOptions.model || anthropicSettings.model.default,
|
||||
};
|
||||
|
||||
this.isClaude3 = this.modelOptions.model.includes('claude-3');
|
||||
this.isLegacyOutput = !this.modelOptions.model.includes('claude-3-5-sonnet');
|
||||
|
||||
if (
|
||||
this.isLegacyOutput &&
|
||||
this.modelOptions.maxOutputTokens &&
|
||||
this.modelOptions.maxOutputTokens > legacy.maxOutputTokens.default
|
||||
) {
|
||||
this.modelOptions.maxOutputTokens = legacy.maxOutputTokens.default;
|
||||
}
|
||||
|
||||
this.useMessages = this.isClaude3 || !!this.options.attachments;
|
||||
|
||||
this.defaultVisionModel = this.options.visionModel ?? 'claude-3-sonnet-20240229';
|
||||
@@ -118,18 +129,30 @@ class AnthropicClient extends BaseClient {
|
||||
|
||||
/**
|
||||
* Get the initialized Anthropic client.
|
||||
* @param {Partial<Anthropic.ClientOptions>} requestOptions - The options for the client.
|
||||
* @returns {Anthropic} The Anthropic client instance.
|
||||
*/
|
||||
getClient() {
|
||||
/** @type {Anthropic.default.RequestOptions} */
|
||||
getClient(requestOptions) {
|
||||
/** @type {Anthropic.ClientOptions} */
|
||||
const options = {
|
||||
fetch: this.fetch,
|
||||
apiKey: this.apiKey,
|
||||
};
|
||||
|
||||
if (this.options.proxy) {
|
||||
options.httpAgent = new HttpsProxyAgent(this.options.proxy);
|
||||
}
|
||||
|
||||
if (this.options.reverseProxyUrl) {
|
||||
options.baseURL = this.options.reverseProxyUrl;
|
||||
}
|
||||
|
||||
if (requestOptions?.model && requestOptions.model.includes('claude-3-5-sonnet')) {
|
||||
options.defaultHeaders = {
|
||||
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
|
||||
};
|
||||
}
|
||||
|
||||
return new Anthropic(options);
|
||||
}
|
||||
|
||||
@@ -550,8 +573,6 @@ class AnthropicClient extends BaseClient {
|
||||
}
|
||||
|
||||
logger.debug('modelOptions', { modelOptions });
|
||||
|
||||
const client = this.getClient();
|
||||
const metadata = {
|
||||
user_id: this.user,
|
||||
};
|
||||
@@ -579,7 +600,7 @@ class AnthropicClient extends BaseClient {
|
||||
|
||||
if (this.useMessages) {
|
||||
requestOptions.messages = payload;
|
||||
requestOptions.max_tokens = maxOutputTokens || 1500;
|
||||
requestOptions.max_tokens = maxOutputTokens || legacy.maxOutputTokens.default;
|
||||
} else {
|
||||
requestOptions.prompt = payload;
|
||||
requestOptions.max_tokens_to_sample = maxOutputTokens || 1500;
|
||||
@@ -599,12 +620,14 @@ class AnthropicClient extends BaseClient {
|
||||
};
|
||||
|
||||
const maxRetries = 3;
|
||||
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
async function processResponse() {
|
||||
let attempts = 0;
|
||||
|
||||
while (attempts < maxRetries) {
|
||||
let response;
|
||||
try {
|
||||
const client = this.getClient(requestOptions);
|
||||
response = await this.createResponse(client, requestOptions);
|
||||
|
||||
signal.addEventListener('abort', () => {
|
||||
@@ -621,6 +644,8 @@ class AnthropicClient extends BaseClient {
|
||||
} else if (completion.completion) {
|
||||
handleChunk(completion.completion);
|
||||
}
|
||||
|
||||
await sleep(streamRate);
|
||||
}
|
||||
|
||||
// Successful processing, exit loop
|
||||
@@ -731,7 +756,11 @@ class AnthropicClient extends BaseClient {
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await this.createResponse(this.getClient(), requestOptions, true);
|
||||
const response = await this.createResponse(
|
||||
this.getClient(requestOptions),
|
||||
requestOptions,
|
||||
true,
|
||||
);
|
||||
let promptTokens = response?.usage?.input_tokens;
|
||||
let completionTokens = response?.usage?.output_tokens;
|
||||
if (!promptTokens) {
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
const crypto = require('crypto');
|
||||
const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
|
||||
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||
const fetch = require('node-fetch');
|
||||
const { supportsBalanceCheck, Constants, CacheKeys, Time } = require('librechat-data-provider');
|
||||
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const TextStream = require('./TextStream');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
@@ -17,6 +19,15 @@ class BaseClient {
|
||||
month: 'long',
|
||||
day: 'numeric',
|
||||
});
|
||||
this.fetch = this.fetch.bind(this);
|
||||
/** @type {boolean} */
|
||||
this.skipSaveConvo = false;
|
||||
/** @type {boolean} */
|
||||
this.skipSaveUserMessage = false;
|
||||
/** @type {ClientDatabaseSavePromise} */
|
||||
this.userMessagePromise;
|
||||
/** @type {ClientDatabaseSavePromise} */
|
||||
this.responsePromise;
|
||||
}
|
||||
|
||||
setOptions() {
|
||||
@@ -54,6 +65,25 @@ class BaseClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Makes an HTTP request and logs the process.
|
||||
*
|
||||
* @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
|
||||
* @param {RequestInit} [init] - Optional init options for the request.
|
||||
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
|
||||
*/
|
||||
async fetch(_url, init) {
|
||||
let url = _url;
|
||||
if (this.options.directEndpoint) {
|
||||
url = this.options.reverseProxyUrl;
|
||||
}
|
||||
logger.debug(`Making request to ${url}`);
|
||||
if (typeof Bun !== 'undefined') {
|
||||
return await fetch(url, init);
|
||||
}
|
||||
return await fetch(url, init);
|
||||
}
|
||||
|
||||
getBuildMessagesOptions() {
|
||||
throw new Error('Subclasses must implement getBuildMessagesOptions');
|
||||
}
|
||||
@@ -63,19 +93,45 @@ class BaseClient {
|
||||
await stream.processTextStream(onProgress);
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {[string|undefined, string|undefined]}
|
||||
*/
|
||||
processOverideIds() {
|
||||
/** @type {Record<string, string | undefined>} */
|
||||
let { overrideConvoId, overrideUserMessageId } = this.options?.req?.body ?? {};
|
||||
if (overrideConvoId) {
|
||||
const [conversationId, index] = overrideConvoId.split(Constants.COMMON_DIVIDER);
|
||||
overrideConvoId = conversationId;
|
||||
if (index !== '0') {
|
||||
this.skipSaveConvo = true;
|
||||
}
|
||||
}
|
||||
if (overrideUserMessageId) {
|
||||
const [userMessageId, index] = overrideUserMessageId.split(Constants.COMMON_DIVIDER);
|
||||
overrideUserMessageId = userMessageId;
|
||||
if (index !== '0') {
|
||||
this.skipSaveUserMessage = true;
|
||||
}
|
||||
}
|
||||
|
||||
return [overrideConvoId, overrideUserMessageId];
|
||||
}
|
||||
|
||||
async setMessageOptions(opts = {}) {
|
||||
if (opts && opts.replaceOptions) {
|
||||
this.setOptions(opts);
|
||||
}
|
||||
|
||||
const [overrideConvoId, overrideUserMessageId] = this.processOverideIds();
|
||||
const { isEdited, isContinued } = opts;
|
||||
const user = opts.user ?? null;
|
||||
this.user = user;
|
||||
const saveOptions = this.getSaveOptions();
|
||||
this.abortController = opts.abortController ?? new AbortController();
|
||||
const conversationId = opts.conversationId ?? crypto.randomUUID();
|
||||
const conversationId = overrideConvoId ?? opts.conversationId ?? crypto.randomUUID();
|
||||
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
|
||||
const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
const userMessageId =
|
||||
overrideUserMessageId ?? opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
|
||||
let head = isEdited ? responseMessageId : parentMessageId;
|
||||
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
|
||||
@@ -139,7 +195,7 @@ class BaseClient {
|
||||
}
|
||||
|
||||
if (typeof opts?.onStart === 'function') {
|
||||
opts.onStart(userMessage);
|
||||
opts.onStart(userMessage, responseMessageId);
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -373,6 +429,14 @@ class BaseClient {
|
||||
const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
|
||||
await this.handleStartMethods(message, opts);
|
||||
|
||||
if (opts.progressCallback) {
|
||||
opts.onProgress = opts.progressCallback.call(null, {
|
||||
...(opts.progressOptions ?? {}),
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
}
|
||||
|
||||
const { generation = '' } = opts;
|
||||
|
||||
// It's not necessary to push to currentMessages
|
||||
@@ -421,8 +485,13 @@ class BaseClient {
|
||||
this.handleTokenCountMap(tokenCountMap);
|
||||
}
|
||||
|
||||
if (!isEdited) {
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (!isEdited && !this.skipSaveUserMessage) {
|
||||
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessagePromise: this.userMessagePromise,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -471,15 +540,23 @@ class BaseClient {
|
||||
const completionTokens = this.getTokenCount(completion);
|
||||
await this.recordTokenUsage({ promptTokens, completionTokens });
|
||||
}
|
||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
if (this.userMessagePromise) {
|
||||
await this.userMessagePromise;
|
||||
}
|
||||
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
messageCache.set(
|
||||
responseMessageId,
|
||||
{
|
||||
text: responseMessage.text,
|
||||
complete: true,
|
||||
},
|
||||
Time.FIVE_MINUTES,
|
||||
);
|
||||
delete responseMessage.tokenCount;
|
||||
return responseMessage;
|
||||
}
|
||||
|
||||
async getConversation(conversationId, user = null) {
|
||||
return await getConvo(user, conversationId);
|
||||
}
|
||||
|
||||
async loadHistory(conversationId, parentMessageId = null) {
|
||||
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
|
||||
|
||||
@@ -534,22 +611,41 @@ class BaseClient {
|
||||
* @param {string | null} user
|
||||
*/
|
||||
async saveMessageToDatabase(message, endpointOptions, user = null) {
|
||||
await saveMessage({
|
||||
...message,
|
||||
endpoint: this.options.endpoint,
|
||||
unfinished: false,
|
||||
user,
|
||||
});
|
||||
await saveConvo(user, {
|
||||
conversationId: message.conversationId,
|
||||
endpoint: this.options.endpoint,
|
||||
endpointType: this.options.endpointType,
|
||||
...endpointOptions,
|
||||
});
|
||||
if (this.user && user !== this.user) {
|
||||
throw new Error('User mismatch.');
|
||||
}
|
||||
|
||||
const savedMessage = await saveMessage(
|
||||
this.options.req,
|
||||
{
|
||||
...message,
|
||||
endpoint: this.options.endpoint,
|
||||
unfinished: false,
|
||||
user,
|
||||
},
|
||||
{ context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveMessage' },
|
||||
);
|
||||
|
||||
if (this.skipSaveConvo) {
|
||||
return { message: savedMessage };
|
||||
}
|
||||
|
||||
const conversation = await saveConvo(
|
||||
this.options.req,
|
||||
{
|
||||
conversationId: message.conversationId,
|
||||
endpoint: this.options.endpoint,
|
||||
endpointType: this.options.endpointType,
|
||||
...endpointOptions,
|
||||
},
|
||||
{ context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo' },
|
||||
);
|
||||
|
||||
return { message: savedMessage, conversation };
|
||||
}
|
||||
|
||||
async updateMessageInDatabase(message) {
|
||||
await updateMessage(message);
|
||||
await updateMessage(this.options.req, message);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -438,9 +438,17 @@ class ChatGPTClient extends BaseClient {
|
||||
|
||||
if (message.eventType === 'text-generation' && message.text) {
|
||||
onTokenProgress(message.text);
|
||||
} else if (message.eventType === 'stream-end' && message.response) {
|
||||
reply += message.text;
|
||||
}
|
||||
/*
|
||||
Cohere API Chinese Unicode character replacement hotfix.
|
||||
Should be un-commented when the following issue is resolved:
|
||||
https://github.com/cohere-ai/cohere-typescript/issues/151
|
||||
|
||||
else if (message.eventType === 'stream-end' && message.response) {
|
||||
reply = message.response.text;
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
return reply;
|
||||
|
||||
@@ -13,13 +13,20 @@ const {
|
||||
endpointSettings,
|
||||
EModelEndpoint,
|
||||
VisionModes,
|
||||
Constants,
|
||||
AuthKeys,
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const { formatMessage, createContextHandlers } = require('./prompts');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const {
|
||||
formatMessage,
|
||||
createContextHandlers,
|
||||
titleInstruction,
|
||||
truncateText,
|
||||
} = require('./prompts');
|
||||
const BaseClient = require('./BaseClient');
|
||||
|
||||
const loc = 'us-central1';
|
||||
const publisher = 'google';
|
||||
@@ -591,12 +598,16 @@ class GoogleClient extends BaseClient {
|
||||
createLLM(clientOptions) {
|
||||
const model = clientOptions.modelName ?? clientOptions.model;
|
||||
if (this.project_id && this.isTextModel) {
|
||||
logger.debug('Creating Google VertexAI client');
|
||||
return new GoogleVertexAI(clientOptions);
|
||||
} else if (this.project_id && this.isChatModel) {
|
||||
logger.debug('Creating Chat Google VertexAI client');
|
||||
return new ChatGoogleVertexAI(clientOptions);
|
||||
} else if (this.project_id) {
|
||||
logger.debug('Creating VertexAI client');
|
||||
return new ChatVertexAI(clientOptions);
|
||||
} else if (model.includes('1.5')) {
|
||||
logger.debug('Creating GenAI client');
|
||||
return new GenAI(this.apiKey).getGenerativeModel(
|
||||
{
|
||||
...clientOptions,
|
||||
@@ -606,12 +617,14 @@ class GoogleClient extends BaseClient {
|
||||
);
|
||||
}
|
||||
|
||||
logger.debug('Creating Chat Google Generative AI client');
|
||||
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
|
||||
}
|
||||
|
||||
async getCompletion(_payload, options = {}) {
|
||||
const { onProgress, abortController } = options;
|
||||
const { parameters, instances } = _payload;
|
||||
const { onProgress, abortController } = options;
|
||||
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
|
||||
|
||||
let examples;
|
||||
@@ -683,13 +696,15 @@ class GoogleClient extends BaseClient {
|
||||
const safetySettings = _payload.safetySettings;
|
||||
requestOptions.safetySettings = safetySettings;
|
||||
|
||||
const delay = modelName.includes('flash') ? 8 : 14;
|
||||
const result = await client.generateContentStream(requestOptions);
|
||||
for await (const chunk of result.stream) {
|
||||
const chunkText = chunk.text();
|
||||
this.generateTextStream(chunkText, onProgress, {
|
||||
delay: 12,
|
||||
await this.generateTextStream(chunkText, onProgress, {
|
||||
delay,
|
||||
});
|
||||
reply += chunkText;
|
||||
await sleep(streamRate);
|
||||
}
|
||||
return reply;
|
||||
}
|
||||
@@ -701,10 +716,21 @@ class GoogleClient extends BaseClient {
|
||||
safetySettings: safetySettings,
|
||||
});
|
||||
|
||||
let delay = this.options.streamRate || 8;
|
||||
|
||||
if (!this.options.streamRate) {
|
||||
if (this.isGenerativeModel) {
|
||||
delay = 12;
|
||||
}
|
||||
if (modelName.includes('flash')) {
|
||||
delay = 5;
|
||||
}
|
||||
}
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const chunkText = chunk?.content ?? chunk;
|
||||
this.generateTextStream(chunkText, onProgress, {
|
||||
delay: this.isGenerativeModel ? 12 : 8,
|
||||
await this.generateTextStream(chunkText, onProgress, {
|
||||
delay,
|
||||
});
|
||||
reply += chunkText;
|
||||
}
|
||||
@@ -712,6 +738,123 @@ class GoogleClient extends BaseClient {
|
||||
return reply;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
|
||||
*/
|
||||
async titleChatCompletion(_payload, options = {}) {
|
||||
const { abortController } = options;
|
||||
const { parameters, instances } = _payload;
|
||||
const { messages: _messages, examples: _examples } = instances?.[0] ?? {};
|
||||
|
||||
let clientOptions = { ...parameters, maxRetries: 2 };
|
||||
|
||||
logger.debug('Initialized title client options');
|
||||
|
||||
if (this.project_id) {
|
||||
clientOptions['authOptions'] = {
|
||||
credentials: {
|
||||
...this.serviceKey,
|
||||
},
|
||||
projectId: this.project_id,
|
||||
};
|
||||
}
|
||||
|
||||
if (!parameters) {
|
||||
clientOptions = { ...clientOptions, ...this.modelOptions };
|
||||
}
|
||||
|
||||
if (this.isGenerativeModel && !this.project_id) {
|
||||
clientOptions.modelName = clientOptions.model;
|
||||
delete clientOptions.model;
|
||||
}
|
||||
|
||||
const model = this.createLLM(clientOptions);
|
||||
|
||||
let reply = '';
|
||||
const messages = this.isTextModel ? _payload.trim() : _messages;
|
||||
|
||||
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
|
||||
if (modelName?.includes('1.5') && !this.project_id) {
|
||||
logger.debug('Identified titling model as 1.5 version');
|
||||
/** @type {GenerativeModel} */
|
||||
const client = model;
|
||||
const requestOptions = {
|
||||
contents: _payload,
|
||||
};
|
||||
|
||||
if (this.options?.promptPrefix?.length) {
|
||||
requestOptions.systemInstruction = {
|
||||
parts: [
|
||||
{
|
||||
text: this.options.promptPrefix,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
const safetySettings = _payload.safetySettings;
|
||||
requestOptions.safetySettings = safetySettings;
|
||||
|
||||
const result = await client.generateContent(requestOptions);
|
||||
|
||||
reply = result.response?.text();
|
||||
|
||||
return reply;
|
||||
} else {
|
||||
logger.debug('Beginning titling');
|
||||
const safetySettings = _payload.safetySettings;
|
||||
|
||||
const titleResponse = await model.invoke(messages, {
|
||||
signal: abortController.signal,
|
||||
timeout: 7000,
|
||||
safetySettings: safetySettings,
|
||||
});
|
||||
|
||||
reply = titleResponse.content;
|
||||
|
||||
return reply;
|
||||
}
|
||||
}
|
||||
|
||||
async titleConvo({ text, responseText = '' }) {
|
||||
let title = 'New Chat';
|
||||
const convo = `||>User:
|
||||
"${truncateText(text)}"
|
||||
||>Response:
|
||||
"${JSON.stringify(truncateText(responseText))}"`;
|
||||
|
||||
let { prompt: payload } = await this.buildMessages([
|
||||
{
|
||||
text: `Please generate ${titleInstruction}
|
||||
|
||||
${convo}
|
||||
|
||||
||>Title:`,
|
||||
isCreatedByUser: true,
|
||||
author: this.userLabel,
|
||||
},
|
||||
]);
|
||||
|
||||
if (this.isVisionModel) {
|
||||
logger.warn(
|
||||
`Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
|
||||
);
|
||||
|
||||
payload.parameters = { ...payload.parameters, model: settings.model.default };
|
||||
}
|
||||
|
||||
try {
|
||||
title = await this.titleChatCompletion(payload, {
|
||||
abortController: new AbortController(),
|
||||
onProgress: () => {},
|
||||
});
|
||||
} catch (e) {
|
||||
logger.error('[GoogleClient] There was an issue generating the title', e);
|
||||
}
|
||||
logger.debug(`Title response: ${title}`);
|
||||
return title;
|
||||
}
|
||||
|
||||
getSaveOptions() {
|
||||
return {
|
||||
promptPrefix: this.options.promptPrefix,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
const { z } = require('zod');
|
||||
const axios = require('axios');
|
||||
const { Ollama } = require('ollama');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { deriveBaseURL } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const ollamaPayloadSchema = z.object({
|
||||
@@ -40,6 +42,7 @@ const getValidBase64 = (imageUrl) => {
|
||||
class OllamaClient {
|
||||
constructor(options = {}) {
|
||||
const host = deriveBaseURL(options.baseURL ?? 'http://localhost:11434');
|
||||
this.streamRate = options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
/** @type {Ollama} */
|
||||
this.client = new Ollama({ host });
|
||||
}
|
||||
@@ -136,6 +139,8 @@ class OllamaClient {
|
||||
stream.controller.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
await sleep(this.streamRate);
|
||||
}
|
||||
}
|
||||
// TODO: regular completion
|
||||
|
||||
@@ -308,7 +308,7 @@ class OpenAIClient extends BaseClient {
|
||||
let tokenizer;
|
||||
this.encoding = 'text-davinci-003';
|
||||
if (this.isChatCompletion) {
|
||||
this.encoding = 'cl100k_base';
|
||||
this.encoding = this.modelOptions.model.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base';
|
||||
tokenizer = this.constructor.getTokenizer(this.encoding);
|
||||
} else if (this.isUnofficialChatGptModel) {
|
||||
const extendSpecialTokens = {
|
||||
@@ -588,7 +588,7 @@ class OpenAIClient extends BaseClient {
|
||||
let streamResult = null;
|
||||
this.modelOptions.user = this.user;
|
||||
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
|
||||
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
|
||||
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion);
|
||||
if (typeof opts.onProgress === 'function' && useOldMethod) {
|
||||
const completionResult = await this.getCompletion(
|
||||
payload,
|
||||
@@ -756,6 +756,8 @@ class OpenAIClient extends BaseClient {
|
||||
* In case of failure, it will return the default title, "New Chat".
|
||||
*/
|
||||
async titleConvo({ text, conversationId, responseText = '' }) {
|
||||
this.conversationId = conversationId;
|
||||
|
||||
if (this.options.attachments) {
|
||||
delete this.options.attachments;
|
||||
}
|
||||
@@ -825,7 +827,7 @@ class OpenAIClient extends BaseClient {
|
||||
|
||||
const instructionsPayload = [
|
||||
{
|
||||
role: 'system',
|
||||
role: this.options.titleMessageRole ?? 'system',
|
||||
content: `Please generate ${titleInstruction}
|
||||
|
||||
${convo}
|
||||
@@ -838,13 +840,17 @@ ${convo}
|
||||
|
||||
try {
|
||||
let useChatCompletion = true;
|
||||
|
||||
if (this.options.reverseProxyUrl === CohereConstants.API_URL) {
|
||||
useChatCompletion = false;
|
||||
}
|
||||
|
||||
title = (
|
||||
await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion })
|
||||
).replaceAll('"', '');
|
||||
|
||||
const completionTokens = this.getTokenCount(title);
|
||||
|
||||
this.recordTokenUsage({ promptTokens, completionTokens, context: 'title' });
|
||||
} catch (e) {
|
||||
logger.error(
|
||||
@@ -868,6 +874,7 @@ ${convo}
|
||||
context: 'title',
|
||||
tokenBuffer: 150,
|
||||
});
|
||||
|
||||
title = await runTitleChain({ llm, text, convo, signal: this.abortController.signal });
|
||||
} catch (e) {
|
||||
if (e?.message?.toLowerCase()?.includes('abort')) {
|
||||
@@ -1005,9 +1012,9 @@ ${convo}
|
||||
await spendTokens(
|
||||
{
|
||||
context,
|
||||
user: this.user,
|
||||
model: this.modelOptions.model,
|
||||
conversationId: this.conversationId,
|
||||
user: this.user ?? this.options.req.user?.id,
|
||||
endpointTokenConfig: this.options.endpointTokenConfig,
|
||||
},
|
||||
{ promptTokens, completionTokens },
|
||||
@@ -1099,7 +1106,12 @@ ${convo}
|
||||
}
|
||||
|
||||
if (this.azure || this.options.azure) {
|
||||
// Azure does not accept `model` in the body, so we need to remove it.
|
||||
/* Azure Bug, extremely short default `max_tokens` response */
|
||||
if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') {
|
||||
modelOptions.max_tokens = 4000;
|
||||
}
|
||||
|
||||
/* Azure does not accept `model` in the body, so we need to remove it. */
|
||||
delete modelOptions.model;
|
||||
|
||||
opts.baseURL = this.langchainProxy
|
||||
@@ -1120,6 +1132,7 @@ ${convo}
|
||||
let chatCompletion;
|
||||
/** @type {OpenAI} */
|
||||
const openai = new OpenAI({
|
||||
fetch: this.fetch,
|
||||
apiKey: this.apiKey,
|
||||
...opts,
|
||||
});
|
||||
@@ -1169,8 +1182,10 @@ ${convo}
|
||||
});
|
||||
}
|
||||
|
||||
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
|
||||
if (this.message_file_map && this.isOllama) {
|
||||
const ollamaClient = new OllamaClient({ baseURL });
|
||||
const ollamaClient = new OllamaClient({ baseURL, streamRate });
|
||||
return await ollamaClient.chatCompletion({
|
||||
payload: modelOptions,
|
||||
onProgress,
|
||||
@@ -1208,7 +1223,6 @@ ${convo}
|
||||
}
|
||||
});
|
||||
|
||||
const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
|
||||
for await (const chunk of stream) {
|
||||
const token = chunk.choices[0]?.delta?.content || '';
|
||||
intermediateReply += token;
|
||||
@@ -1218,9 +1232,7 @@ ${convo}
|
||||
break;
|
||||
}
|
||||
|
||||
if (this.azure) {
|
||||
await sleep(azureDelay);
|
||||
}
|
||||
await sleep(streamRate);
|
||||
}
|
||||
|
||||
if (!UnexpectedRoleError) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
const OpenAIClient = require('./OpenAIClient');
|
||||
const { CallbackManager } = require('langchain/callbacks');
|
||||
const { CacheKeys, Time } = require('librechat-data-provider');
|
||||
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
|
||||
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
|
||||
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
|
||||
@@ -11,6 +12,7 @@ const { SelfReflectionTool } = require('./tools');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const { loadTools } = require('./tools/util');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
class PluginsClient extends OpenAIClient {
|
||||
@@ -220,6 +222,13 @@ class PluginsClient extends OpenAIClient {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage} responseMessage
|
||||
* @param {Partial<TMessage>} saveOptions
|
||||
* @param {string} user
|
||||
* @returns
|
||||
*/
|
||||
async handleResponseMessage(responseMessage, saveOptions, user) {
|
||||
const { output, errorMessage, ...result } = this.result;
|
||||
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
|
||||
@@ -238,18 +247,39 @@ class PluginsClient extends OpenAIClient {
|
||||
await this.recordTokenUsage(responseMessage);
|
||||
}
|
||||
|
||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
messageCache.set(
|
||||
responseMessage.messageId,
|
||||
{
|
||||
text: responseMessage.text,
|
||||
complete: true,
|
||||
},
|
||||
Time.FIVE_MINUTES,
|
||||
);
|
||||
delete responseMessage.tokenCount;
|
||||
return { ...responseMessage, ...result };
|
||||
}
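handleResponseMessage now keeps the database write as this.responsePromise and immediately caches the finished text for five minutes. A hedged sketch of that cache-then-persist pattern, using Keyv directly in place of getLogStores(CacheKeys.MESSAGES); the TTL and cache name here are illustrative:

const Keyv = require('keyv');

const FIVE_MINUTES = 5 * 60 * 1000;
const messageCache = new Keyv({ namespace: 'messages', ttl: FIVE_MINUTES });

async function finishResponse(responseMessage, saveToDatabase) {
  // Start the DB write but do not block on it; callers can await responsePromise later.
  const responsePromise = saveToDatabase(responseMessage);
  // Expose the completed text right away, e.g. for abort/recovery paths.
  await messageCache.set(responseMessage.messageId, { text: responseMessage.text, complete: true });
  return { responseMessage, responsePromise };
}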
|
||||
|
||||
async sendMessage(message, opts = {}) {
|
||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
|
||||
this.options.tools = tools;
|
||||
} else {
|
||||
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
|
||||
this.options.tools = tools;
|
||||
}
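The block above gives includedTools precedence: when it is non-empty it acts as an allow-list and filteredTools is ignored; otherwise filteredTools acts as a deny-list. A self-contained sketch of that selection logic (the tool names in the comments are illustrative):

function selectTools(tools, { filteredTools = [], includedTools = [] } = {}) {
  if (includedTools.length > 0) {
    return tools.filter((plugin) => includedTools.includes(plugin));
  }
  return tools.filter((plugin) => !filteredTools.includes(plugin));
}

// selectTools(['dalle', 'serpapi', 'wolfram'], { includedTools: ['wolfram'] })  -> ['wolfram']
// selectTools(['dalle', 'serpapi', 'wolfram'], { filteredTools: ['dalle'] })    -> ['serpapi', 'wolfram']
// selectTools(['dalle', 'serpapi', 'wolfram'], {})                              -> unchanged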
|
||||
|
||||
// If a message is edited, no tools can be used.
|
||||
const completionMode = this.options.tools.length === 0 || opts.isEdited;
|
||||
if (completionMode) {
|
||||
this.setOptions(opts);
|
||||
return super.sendMessage(message, opts);
|
||||
}
|
||||
|
||||
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
|
||||
const {
|
||||
user,
|
||||
@@ -264,6 +294,14 @@ class PluginsClient extends OpenAIClient {
|
||||
onToolEnd,
|
||||
} = await this.handleStartMethods(message, opts);
|
||||
|
||||
if (opts.progressCallback) {
|
||||
opts.onProgress = opts.progressCallback.call(null, {
|
||||
...(opts.progressOptions ?? {}),
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
}
|
||||
|
||||
this.currentMessages.push(userMessage);
|
||||
|
||||
let {
|
||||
@@ -292,7 +330,15 @@ class PluginsClient extends OpenAIClient {
|
||||
if (payload) {
|
||||
this.currentMessages = payload;
|
||||
}
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
|
||||
if (!this.skipSaveUserMessage) {
|
||||
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessagePromise: this.userMessagePromise,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||
await checkBalance({
|
||||
|
||||
@@ -1,44 +1,3 @@
|
||||
/*
|
||||
module.exports = `You are ChatGPT, a Large Language model with useful tools.
|
||||
|
||||
Talk to the human and provide meaningful answers when questions are asked.
|
||||
|
||||
Use the tools when you need them, but use your own knowledge if you are confident of the answer. Keep answers short and concise.
|
||||
|
||||
A tool is not usually needed for creative requests, so do your best to answer them without tools.
|
||||
|
||||
Avoid repeating identical answers if it appears before. Only fulfill the human's requests, do not create extra steps beyond what the human has asked for.
|
||||
|
||||
Your input for 'Action' should be the name of tool used only.
|
||||
|
||||
Be honest. If you can't answer something, or a tool is not appropriate, say you don't know or answer to the best of your ability.
|
||||
|
||||
Attempt to fulfill the human's requests in as few actions as possible`;
|
||||
*/
|
||||
|
||||
// module.exports = `You are ChatGPT, a highly knowledgeable and versatile large language model.
|
||||
|
||||
// Engage with the Human conversationally, providing concise and meaningful answers to questions. Utilize built-in tools when necessary, except for creative requests, where relying on your own knowledge is preferred. Aim for variety and avoid repetitive answers.
|
||||
|
||||
// For your 'Action' input, state the name of the tool used only, and honor user requests without adding extra steps. Always be honest; if you cannot provide an appropriate answer or tool, admit that or do your best.
|
||||
|
||||
// Strive to meet the user's needs efficiently with minimal actions.`;
|
||||
|
||||
// import {
|
||||
// BasePromptTemplate,
|
||||
// BaseStringPromptTemplate,
|
||||
// SerializedBasePromptTemplate,
|
||||
// renderTemplate,
|
||||
// } from "langchain/prompts";
|
||||
|
||||
// prefix: `You are ChatGPT, a highly knowledgeable and versatile large language model.
|
||||
// Your objective is to help users by understanding their intent and choosing the best action. Prioritize direct, specific responses. Use concise, varied answers and rely on your knowledge for creative tasks. Utilize tools when needed, and structure results for machine compatibility.
|
||||
// prefix: `Objective: to comprehend human intentions based on user input and available tools. Goal: identify the best action to directly address the human's query. In your subsequent steps, you will utilize the chosen action. You may select multiple actions and list them in a meaningful order. Prioritize actions that directly relate to the user's query over general ones. Ensure that the generated thought is highly specific and explicit to best match the user's expectations. Construct the result in a manner that an online open-API would most likely expect. Provide concise and meaningful answers to human queries. Utilize tools when necessary. Relying on your own knowledge is preferred for creative requests. Aim for variety and avoid repetitive answers.
|
||||
|
||||
// # Available Actions & Tools:
|
||||
// N/A: no suitable action, use your own knowledge.`,
|
||||
// suffix: `Remember, all your responses MUST adhere to the described format and only respond if the format is followed. Output exactly with the requested format, avoiding any other text as this will be parsed by a machine. Following 'Action:', provide only one of the actions listed above. If a tool is not necessary, deduce this quickly and finish your response. Honor the human's requests without adding extra steps. Carry out tasks in the sequence written by the human. Always be honest; if you cannot provide an appropriate answer or tool, do your best with your own knowledge. Strive to meet the user's needs efficiently with minimal actions.`;
|
||||
|
||||
module.exports = {
|
||||
'gpt3-v1': {
|
||||
prefix: `Objective: Understand human intentions using user input and available tools. Goal: Identify the most suitable actions to directly address user queries.
|
||||
|
||||
@@ -8,8 +8,6 @@ In your response, remember to follow these guidelines:
|
||||
- If you don't know the answer, simply say that you don't know.
|
||||
- If you are unsure how to answer, ask for clarification.
|
||||
- Avoid mentioning that you obtained the information from the context.
|
||||
|
||||
Answer appropriately in the user's language.
|
||||
`;
|
||||
|
||||
function createContextHandlers(req, userMessageContent) {
|
||||
@@ -94,37 +92,40 @@ function createContextHandlers(req, userMessageContent) {
|
||||
|
||||
const resolvedQueries = await Promise.all(queryPromises);
|
||||
|
||||
const context = resolvedQueries
|
||||
.map((queryResult, index) => {
|
||||
const file = processedFiles[index];
|
||||
let contextItems = queryResult.data;
|
||||
const context =
|
||||
resolvedQueries.length === 0
|
||||
? '\n\tThe semantic search did not return any results.'
|
||||
: resolvedQueries
|
||||
.map((queryResult, index) => {
|
||||
const file = processedFiles[index];
|
||||
let contextItems = queryResult.data;
|
||||
|
||||
const generateContext = (currentContext) =>
|
||||
`
|
||||
const generateContext = (currentContext) =>
|
||||
`
|
||||
<file>
|
||||
<filename>${file.filename}</filename>
|
||||
<context>${currentContext}
|
||||
</context>
|
||||
</file>`;
|
||||
|
||||
if (useFullContext) {
|
||||
return generateContext(`\n${contextItems}`);
|
||||
}
|
||||
if (useFullContext) {
|
||||
return generateContext(`\n${contextItems}`);
|
||||
}
|
||||
|
||||
contextItems = queryResult.data
|
||||
.map((item) => {
|
||||
const pageContent = item[0].page_content;
|
||||
return `
|
||||
contextItems = queryResult.data
|
||||
.map((item) => {
|
||||
const pageContent = item[0].page_content;
|
||||
return `
|
||||
<contextItem>
|
||||
<![CDATA[${pageContent?.trim()}]]>
|
||||
</contextItem>`;
|
||||
})
|
||||
.join('');
|
||||
|
||||
return generateContext(contextItems);
|
||||
})
|
||||
.join('');
|
||||
|
||||
return generateContext(contextItems);
|
||||
})
|
||||
.join('');
|
||||
|
||||
if (useFullContext) {
|
||||
const prompt = `${header}
|
||||
${context}
|
||||
|
||||
@@ -28,7 +28,7 @@ ${convo}`,
|
||||
};
|
||||
|
||||
const titleInstruction =
|
||||
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"';
|
||||
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. Never directly mention the language name or the word "title"';
|
||||
const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.
|
||||
|
||||
You may call them like this:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
const AnthropicClient = require('../AnthropicClient');
|
||||
const { anthropicSettings } = require('librechat-data-provider');
|
||||
const AnthropicClient = require('~/app/clients/AnthropicClient');
|
||||
|
||||
const HUMAN_PROMPT = '\n\nHuman:';
|
||||
const AI_PROMPT = '\n\nAssistant:';
|
||||
|
||||
@@ -22,7 +24,7 @@ describe('AnthropicClient', () => {
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model,
|
||||
temperature: 0.7,
|
||||
temperature: anthropicSettings.temperature.default,
|
||||
},
|
||||
};
|
||||
client = new AnthropicClient('test-api-key');
|
||||
@@ -33,7 +35,42 @@ describe('AnthropicClient', () => {
|
||||
it('should set the options correctly', () => {
|
||||
expect(client.apiKey).toBe('test-api-key');
|
||||
expect(client.modelOptions.model).toBe(model);
|
||||
expect(client.modelOptions.temperature).toBe(0.7);
|
||||
expect(client.modelOptions.temperature).toBe(anthropicSettings.temperature.default);
|
||||
});
|
||||
|
||||
it('should set legacy maxOutputTokens for non-Claude-3 models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-2',
|
||||
maxOutputTokens: anthropicSettings.maxOutputTokens.default,
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBe(
|
||||
anthropicSettings.legacy.maxOutputTokens.default,
|
||||
);
|
||||
});
|
||||
it('should not set maxOutputTokens if not provided', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-3',
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should not set legacy maxOutputTokens for Claude-3 models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
maxOutputTokens: anthropicSettings.legacy.maxOutputTokens.default,
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBe(
|
||||
anthropicSettings.legacy.maxOutputTokens.default,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -136,4 +173,57 @@ describe('AnthropicClient', () => {
|
||||
expect(prompt).toContain('You are Claude-2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getClient', () => {
|
||||
it('should set legacy maxOutputTokens for non-Claude-3 models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-2',
|
||||
maxOutputTokens: anthropicSettings.legacy.maxOutputTokens.default,
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBe(
|
||||
anthropicSettings.legacy.maxOutputTokens.default,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not set legacy maxOutputTokens for Claude-3 models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
maxOutputTokens: anthropicSettings.legacy.maxOutputTokens.default,
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBe(
|
||||
anthropicSettings.legacy.maxOutputTokens.default,
|
||||
);
|
||||
});
|
||||
|
||||
it('should add beta header for claude-3-5-sonnet model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
model: 'claude-3-5-sonnet-20240307',
|
||||
};
|
||||
client.setOptions({ modelOptions });
|
||||
const anthropicClient = client.getClient(modelOptions);
|
||||
expect(anthropicClient._options.defaultHeaders).toBeDefined();
|
||||
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
|
||||
'max-tokens-3-5-sonnet-2024-07-15',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not add beta header for other models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-2',
|
||||
},
|
||||
});
|
||||
const anthropicClient = client.getClient();
|
||||
expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta');
|
||||
});
|
||||
});
|
||||
});
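The two getClient tests above pin down the expected header behavior: only claude-3-5-sonnet models get an 'anthropic-beta' header. The actual implementation is not part of this diff; a hedged sketch of one way it could look, using the @anthropic-ai/sdk client (the beta value is taken from the test expectation):

const Anthropic = require('@anthropic-ai/sdk');

function buildAnthropicClient(apiKey, modelOptions = {}) {
  const options = { apiKey };
  if (modelOptions.model?.includes('claude-3-5-sonnet')) {
    // Opt in to the larger max_tokens limit for 3.5 Sonnet.
    options.defaultHeaders = { 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15' };
  }
  return new Anthropic(options);
}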
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { initializeFakeClient } = require('./FakeClient');
|
||||
|
||||
jest.mock('../../../lib/db/connectDb');
|
||||
jest.mock('~/lib/db/connectDb');
|
||||
jest.mock('~/models', () => ({
|
||||
User: jest.fn(),
|
||||
Key: jest.fn(),
|
||||
@@ -576,7 +576,11 @@ describe('BaseClient', () => {
|
||||
const onStart = jest.fn();
|
||||
const opts = { onStart };
|
||||
await TestClient.sendMessage('Hello, world!', opts);
|
||||
expect(onStart).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello, world!' }));
|
||||
|
||||
expect(onStart).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ text: 'Hello, world!' }),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
test('saveMessageToDatabase is called with the correct arguments', async () => {
|
||||
@@ -627,5 +631,32 @@ describe('BaseClient', () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test('userMessagePromise is awaited before saving response message', async () => {
|
||||
// Mock the saveMessageToDatabase method
|
||||
TestClient.saveMessageToDatabase = jest.fn().mockImplementation(() => {
|
||||
return new Promise((resolve) => setTimeout(resolve, 100)); // Simulate a delay
|
||||
});
|
||||
|
||||
// Send a message
|
||||
const messagePromise = TestClient.sendMessage('Hello, world!');
|
||||
|
||||
// Wait a short time to ensure the user message save has started
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Check that saveMessageToDatabase has been called once (for the user message)
|
||||
expect(TestClient.saveMessageToDatabase).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Wait for the message to be fully processed
|
||||
await messagePromise;
|
||||
|
||||
// Check that saveMessageToDatabase has been called twice (once for user message, once for response)
|
||||
expect(TestClient.saveMessageToDatabase).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Check the order of calls
|
||||
const calls = TestClient.saveMessageToDatabase.mock.calls;
|
||||
expect(calls[0][0].isCreatedByUser).toBe(true); // First call should be for user message
|
||||
expect(calls[1][0].isCreatedByUser).toBe(false); // Second call should be for response message
|
||||
});
|
||||
});
|
||||
});
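The new test asserts an ordering contract: the user-message write starts first, and the response write only happens after it resolves. A minimal sketch of how a client can honor that without blocking generation on the first write:

async function sendAndPersist(saveMessageToDatabase, userMessage, generateResponse) {
  // Kick off the user-message write but keep generating in parallel.
  const userMessagePromise = saveMessageToDatabase(userMessage);
  const responseMessage = await generateResponse(userMessage);
  // Guarantee the user message is stored before its response.
  await userMessagePromise;
  await saveMessageToDatabase(responseMessage);
  return responseMessage;
}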
|
||||
|
||||
@@ -144,6 +144,7 @@ describe('OpenAIClient', () => {
|
||||
|
||||
const defaultOptions = {
|
||||
// debug: true,
|
||||
req: {},
|
||||
openaiApiKey: 'new-api-key',
|
||||
modelOptions: {
|
||||
model,
|
||||
|
||||
@@ -194,6 +194,7 @@ describe('PluginsClient', () => {
|
||||
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Azure OpenAI tests specific to Plugins', () => {
|
||||
// TODO: add more tests for Azure OpenAI integration with Plugins
|
||||
// let client;
|
||||
@@ -220,4 +221,94 @@ describe('PluginsClient', () => {
|
||||
spy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessage with filtered tools', () => {
|
||||
let TestAgent;
|
||||
const apiKey = 'fake-api-key';
|
||||
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
|
||||
|
||||
beforeEach(() => {
|
||||
TestAgent = new PluginsClient(apiKey, {
|
||||
tools: mockTools,
|
||||
modelOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
temperature: 0,
|
||||
max_tokens: 2,
|
||||
},
|
||||
agentOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
},
|
||||
});
|
||||
|
||||
TestAgent.options.req = {
|
||||
app: {
|
||||
locals: {},
|
||||
},
|
||||
};
|
||||
|
||||
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
|
||||
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
const tools = TestAgent.options.tools.filter((plugin) =>
|
||||
includedTools.includes(plugin.name),
|
||||
);
|
||||
TestAgent.options.tools = tools;
|
||||
} else {
|
||||
const tools = TestAgent.options.tools.filter(
|
||||
(plugin) => !filteredTools.includes(plugin.name),
|
||||
);
|
||||
TestAgent.options.tools = tools;
|
||||
}
|
||||
|
||||
return {
|
||||
text: 'Mocked response',
|
||||
tools: TestAgent.options.tools,
|
||||
};
|
||||
});
|
||||
});
|
||||
|
||||
test('should filter out tools when filteredTools is provided', async () => {
|
||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
expect.objectContaining({ name: 'tool4' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should only include specified tools when includedTools is provided', async () => {
|
||||
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
expect.objectContaining({ name: 'tool4' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should prioritize includedTools over filteredTools', async () => {
|
||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool1' }),
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should not modify tools when no filters are provided', async () => {
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(4);
|
||||
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -80,13 +80,18 @@ class StableDiffusionAPI extends StructuredTool {
|
||||
const payload = {
|
||||
prompt,
|
||||
negative_prompt,
|
||||
sampler_index: 'DPM++ 2M Karras',
|
||||
cfg_scale: 4.5,
|
||||
steps: 22,
|
||||
width: 1024,
|
||||
height: 1024,
|
||||
};
|
||||
const generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
|
||||
let generationResponse;
|
||||
try {
|
||||
generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
|
||||
} catch (error) {
|
||||
logger.error('[StableDiffusion] Error while generating image:', error);
|
||||
return 'Error making API request.';
|
||||
}
|
||||
const image = generationResponse.data.images[0];
|
||||
|
||||
/** @type {{ height: number, width: number, seed: number, infotexts: string[] }} */
|
||||
|
||||
34
api/cache/getLogStores.js
vendored
@@ -1,12 +1,11 @@
|
||||
const Keyv = require('keyv');
|
||||
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
|
||||
const { logFile, violationFile } = require('./keyvFiles');
|
||||
const { math, isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('./keyvRedis');
|
||||
const keyvMongo = require('./keyvMongo');
|
||||
|
||||
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
|
||||
const THIRTY_MINUTES = 1800000;
|
||||
|
||||
const duration = math(BAN_DURATION, 7200000);
|
||||
|
||||
@@ -24,13 +23,25 @@ const config = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
|
||||
|
||||
const roles = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ROLES });
|
||||
|
||||
const audioRuns = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const messages = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.FIVE_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.FIVE_MINUTES });
|
||||
|
||||
const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
|
||||
? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
|
||||
|
||||
const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes
|
||||
? new Keyv({ store: keyvRedis, ttl: 120000 })
|
||||
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: 120000 });
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
|
||||
|
||||
const modelQueries = isEnabled(process.env.USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
@@ -38,9 +49,10 @@ const modelQueries = isEnabled(process.env.USE_REDIS)
|
||||
|
||||
const abortKeys = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });
|
||||
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const namespaces = {
|
||||
[CacheKeys.ROLES]: roles,
|
||||
[CacheKeys.CONFIG_STORE]: config,
|
||||
pending_req,
|
||||
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
|
||||
@@ -55,7 +67,13 @@ const namespaces = {
|
||||
message_limit: createViolationInstance('message_limit'),
|
||||
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
|
||||
registrations: createViolationInstance('registrations'),
|
||||
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
|
||||
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
|
||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||
[ViolationTypes.VERIFY_EMAIL_LIMIT]: createViolationInstance(ViolationTypes.VERIFY_EMAIL_LIMIT),
|
||||
[ViolationTypes.RESET_PASSWORD_LIMIT]: createViolationInstance(
|
||||
ViolationTypes.RESET_PASSWORD_LIMIT,
|
||||
),
|
||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
||||
ViolationTypes.ILLEGAL_MODEL_REQUEST,
|
||||
),
|
||||
@@ -64,6 +82,8 @@ const namespaces = {
|
||||
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
|
||||
[CacheKeys.GEN_TITLE]: genTitle,
|
||||
[CacheKeys.MODEL_QUERIES]: modelQueries,
|
||||
[CacheKeys.AUDIO_RUNS]: audioRuns,
|
||||
[CacheKeys.MESSAGES]: messages,
|
||||
};
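Every cache above follows the same shape: a Redis-backed Keyv when USE_REDIS is enabled, otherwise a namespaced in-process Keyv, with TTLs drawn from the shared Time constants. A small hedged sketch of that pattern; the helper name and the Time values below are illustrative, not the library's exports:

const Keyv = require('keyv');

const Time = { TWO_MINUTES: 120000, FIVE_MINUTES: 300000, TEN_MINUTES: 600000 };

function createCache(namespace, { ttl, redisStore } = {}) {
  // Shared Redis store when available; otherwise an isolated namespace with the same TTL.
  return redisStore ? new Keyv({ store: redisStore, ttl }) : new Keyv({ namespace, ttl });
}

// const messages = createCache('MESSAGES', { ttl: Time.FIVE_MINUTES, redisStore: keyvRedis });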
|
||||
|
||||
/**
|
||||
|
||||
2
api/cache/logViolation.js
vendored
@@ -1,6 +1,6 @@
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const getLogStores = require('./getLogStores');
|
||||
const banViolation = require('./banViolation');
|
||||
const { isEnabled } = require('../server/utils');
|
||||
|
||||
/**
|
||||
* Logs the violation.
|
||||
|
||||
@@ -27,26 +27,25 @@ function getMatchingSensitivePatterns(valueStr) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Redacts sensitive information from a console message.
|
||||
*
|
||||
* Redacts sensitive information from a console message and trims it to a specified length if provided.
|
||||
* @param {string} str - The console message to be redacted.
|
||||
* @returns {string} - The redacted console message.
|
||||
* @param {number} [trimLength] - The optional length at which to trim the redacted message.
|
||||
* @returns {string} - The redacted and optionally trimmed console message.
|
||||
*/
|
||||
function redactMessage(str) {
|
||||
function redactMessage(str, trimLength) {
|
||||
if (!str) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const patterns = getMatchingSensitivePatterns(str);
|
||||
|
||||
if (patterns.length === 0) {
|
||||
return str;
|
||||
}
|
||||
|
||||
patterns.forEach((pattern) => {
|
||||
str = str.replace(pattern, '$1[REDACTED]');
|
||||
});
|
||||
|
||||
if (trimLength !== undefined && str.length > trimLength) {
|
||||
return `${str.substring(0, trimLength)}...`;
|
||||
}
|
||||
|
||||
return str;
|
||||
}
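A self-contained sketch of the redact-then-trim behavior documented above; the pattern list here is illustrative and not the repo's actual set of sensitive-key patterns:

const patterns = [/(apiKey=)\S+/g, /(Bearer )\S+/g];

function redactAndTrim(str, trimLength) {
  if (!str) {
    return '';
  }
  let redacted = str;
  for (const pattern of patterns) {
    redacted = redacted.replace(pattern, '$1[REDACTED]');
  }
  if (trimLength !== undefined && redacted.length > trimLength) {
    return `${redacted.substring(0, trimLength)}...`;
  }
  return redacted;
}

// redactAndTrim('apiKey=sk-123 caused the error', 25) -> 'apiKey=[REDACTED] caused ...'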
|
||||
|
||||
@@ -110,6 +109,14 @@ const condenseArray = (item) => {
|
||||
* @returns {string} - The formatted log message.
|
||||
*/
|
||||
const debugTraverse = winston.format.printf(({ level, message, timestamp, ...metadata }) => {
|
||||
if (!message) {
|
||||
return `${timestamp} ${level}`;
|
||||
}
|
||||
|
||||
if (!message?.trim || typeof message !== 'string') {
|
||||
return `${timestamp} ${level}: ${JSON.stringify(message)}`;
|
||||
}
|
||||
|
||||
let msg = `${timestamp} ${level}: ${truncateLongStrings(message?.trim(), 150)}`;
|
||||
try {
|
||||
if (level !== 'debug') {
|
||||
|
||||
@@ -62,8 +62,24 @@ const deleteAction = async (searchParams, session = null) => {
|
||||
return await Action.findOneAndDelete(searchParams, options).lean();
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
updateAction,
|
||||
getActions,
|
||||
deleteAction,
|
||||
/**
|
||||
* Deletes actions by params, within a transaction session if provided.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find the actions to delete.
|
||||
* @param {string} searchParams.action_id - The ID of the action(s) to delete.
|
||||
* @param {string} searchParams.user - The user ID of the action's author.
|
||||
* @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
|
||||
* @returns {Promise<Number>} A promise that resolves to the number of deleted action documents.
|
||||
*/
|
||||
const deleteActions = async (searchParams, session = null) => {
|
||||
const options = session ? { session } : {};
|
||||
const result = await Action.deleteMany(searchParams, options);
|
||||
return result.deletedCount;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getActions,
|
||||
updateAction,
|
||||
deleteAction,
|
||||
deleteActions,
|
||||
};
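A usage sketch for the new deleteActions; the require path and the transaction handling below are assumptions for illustration only:

const mongoose = require('mongoose');
const { deleteActions } = require('~/models/Action');

async function removeUserActions(action_id, userId) {
  const session = await mongoose.startSession();
  try {
    let deletedCount = 0;
    await session.withTransaction(async () => {
      // Scoped to the author; resolves to the number of deleted documents.
      deletedCount = await deleteActions({ action_id, user: userId }, session);
    });
    return deletedCount;
  } finally {
    await session.endSession();
  }
}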
|
||||
|
||||
@@ -14,7 +14,7 @@ const Assistant = mongoose.model('assistant', assistantSchema);
|
||||
* @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
|
||||
* @returns {Promise<Object>} The updated or newly created assistant document as a plain object.
|
||||
*/
|
||||
const updateAssistant = async (searchParams, updateData, session = null) => {
|
||||
const updateAssistantDoc = async (searchParams, updateData, session = null) => {
|
||||
const options = { new: true, upsert: true, session };
|
||||
return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
|
||||
};
|
||||
@@ -39,8 +39,21 @@ const getAssistants = async (searchParams) => {
|
||||
return await Assistant.find(searchParams).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes an assistant based on the provided ID.
|
||||
*
|
||||
* @param {Object} searchParams - The search parameters to find the assistant to delete.
|
||||
* @param {string} searchParams.assistant_id - The ID of the assistant to delete.
|
||||
* @param {string} searchParams.user - The user ID of the assistant's author.
|
||||
* @returns {Promise<void>} Resolves when the assistant has been successfully deleted.
|
||||
*/
|
||||
const deleteAssistant = async (searchParams) => {
|
||||
return await Assistant.findOneAndDelete(searchParams);
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
updateAssistant,
|
||||
updateAssistantDoc,
|
||||
deleteAssistant,
|
||||
getAssistants,
|
||||
getAssistant,
|
||||
};
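A brief usage sketch for the renamed updateAssistantDoc and the new deleteAssistant; the require path and field values are assumptions:

const { updateAssistantDoc, deleteAssistant } = require('~/models/Assistant');

async function upsertThenRemove(assistant_id, user) {
  // Upserts and returns the assistant as a plain object.
  const doc = await updateAssistantDoc({ assistant_id, user }, { name: 'My Assistant' });
  // Deletes the assistant, scoped to its author.
  await deleteAssistant({ assistant_id, user });
  return doc;
}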
|
||||
|
||||
61
api/models/Categories.js
Normal file
@@ -0,0 +1,61 @@
|
||||
const { logger } = require('~/config');
|
||||
// const { Categories } = require('./schema/categories');
|
||||
const options = [
|
||||
{
|
||||
label: '',
|
||||
value: '',
|
||||
},
|
||||
{
|
||||
label: 'idea',
|
||||
value: 'idea',
|
||||
},
|
||||
{
|
||||
label: 'travel',
|
||||
value: 'travel',
|
||||
},
|
||||
{
|
||||
label: 'teach_or_explain',
|
||||
value: 'teach_or_explain',
|
||||
},
|
||||
{
|
||||
label: 'write',
|
||||
value: 'write',
|
||||
},
|
||||
{
|
||||
label: 'shop',
|
||||
value: 'shop',
|
||||
},
|
||||
{
|
||||
label: 'code',
|
||||
value: 'code',
|
||||
},
|
||||
{
|
||||
label: 'misc',
|
||||
value: 'misc',
|
||||
},
|
||||
{
|
||||
label: 'roleplay',
|
||||
value: 'roleplay',
|
||||
},
|
||||
{
|
||||
label: 'finance',
|
||||
value: 'finance',
|
||||
},
|
||||
];
|
||||
|
||||
module.exports = {
|
||||
/**
|
||||
* Retrieves the categories asynchronously.
|
||||
* @returns {Promise<TGetCategoriesResponse>} An array of category objects.
|
||||
* @throws {Error} If there is an error retrieving the categories.
|
||||
*/
|
||||
getCategories: async () => {
|
||||
try {
|
||||
// const categories = await Categories.find();
|
||||
return options;
|
||||
} catch (error) {
|
||||
logger.error('Error getting categories', error);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
};
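A usage sketch: getCategories currently resolves to the hard-coded options above, and the commented-out Categories.find() hints at a later database-backed lookup. The require path is an assumption:

const { getCategories } = require('~/models/Categories');

async function listCategoryValues() {
  const categories = await getCategories();
  // e.g. ['idea', 'travel', 'teach_or_explain', ...]
  return categories.map((c) => c.value).filter(Boolean);
}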
|
||||
@@ -19,20 +19,39 @@ const getConvo = async (user, conversationId) => {
|
||||
|
||||
module.exports = {
|
||||
Conversation,
|
||||
saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
|
||||
/**
|
||||
* Saves a conversation to the database.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} conversationId - The conversation's ID.
|
||||
* @param {Object} metadata - Additional metadata to log for operation.
|
||||
* @returns {Promise<TConversation>} The conversation object.
|
||||
*/
|
||||
saveConvo: async (req, { conversationId, newConversationId, ...convo }, metadata) => {
|
||||
try {
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...convo, messages, user };
|
||||
if (metadata && metadata?.context) {
|
||||
logger.debug(`[saveConvo] ${metadata.context}`);
|
||||
}
|
||||
const messages = await getMessages({ conversationId }, '_id');
|
||||
const update = { ...convo, messages, user: req.user.id };
|
||||
if (newConversationId) {
|
||||
update.conversationId = newConversationId;
|
||||
}
|
||||
|
||||
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: true,
|
||||
});
|
||||
const conversation = await Conversation.findOneAndUpdate(
|
||||
{ conversationId, user: req.user.id },
|
||||
update,
|
||||
{
|
||||
new: true,
|
||||
upsert: true,
|
||||
},
|
||||
);
|
||||
|
||||
return conversation.toObject();
|
||||
} catch (error) {
|
||||
logger.error('[saveConvo] Error saving conversation', error);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`[saveConvo] ${metadata.context}`);
|
||||
}
|
||||
return { message: 'Error saving conversation' };
|
||||
}
|
||||
},
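saveConvo now takes the request object (the user id comes from req.user.id) plus an optional metadata object whose context string is used only for logging. A usage sketch inside a route handler; the route, field values, and context string are illustrative:

const { saveConvo } = require('~/models/Conversation');

async function archiveConversation(req, res) {
  const conversation = await saveConvo(
    req,
    { conversationId: req.params.conversationId, isArchived: true },
    { context: 'POST /convos/archive - mark conversation archived' },
  );
  res.json(conversation);
}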
|
||||
@@ -54,13 +73,16 @@ module.exports = {
|
||||
throw new Error('Failed to save conversations in bulk.');
|
||||
}
|
||||
},
|
||||
getConvosByPage: async (user, pageNumber = 1, pageSize = 25, isArchived = false) => {
|
||||
getConvosByPage: async (user, pageNumber = 1, pageSize = 25, isArchived = false, tags) => {
|
||||
const query = { user };
|
||||
if (isArchived) {
|
||||
query.isArchived = true;
|
||||
} else {
|
||||
query.$or = [{ isArchived: false }, { isArchived: { $exists: false } }];
|
||||
}
|
||||
if (Array.isArray(tags) && tags.length > 0) {
|
||||
query.tags = { $in: tags };
|
||||
}
|
||||
try {
|
||||
const totalConvos = (await Conversation.countDocuments(query)) || 1;
|
||||
const totalPages = Math.ceil(totalConvos / pageSize);
|
||||
|
||||
268
api/models/ConversationTag.js
Normal file
@@ -0,0 +1,268 @@
|
||||
//const crypto = require('crypto');
|
||||
|
||||
const logger = require('~/config/winston');
|
||||
const Conversation = require('./schema/convoSchema');
|
||||
const ConversationTag = require('./schema/conversationTagSchema');
|
||||
|
||||
const SAVED_TAG = 'Saved';
|
||||
|
||||
const updateTagsForConversation = async (user, conversationId, tags) => {
|
||||
try {
|
||||
const conversation = await Conversation.findOne({ user, conversationId });
|
||||
if (!conversation) {
|
||||
return { message: 'Conversation not found' };
|
||||
}
|
||||
|
||||
const addedTags = tags.tags.filter((tag) => !conversation.tags.includes(tag));
|
||||
const removedTags = conversation.tags.filter((tag) => !tags.tags.includes(tag));
|
||||
for (const tag of addedTags) {
|
||||
await ConversationTag.updateOne({ tag, user }, { $inc: { count: 1 } }, { upsert: true });
|
||||
}
|
||||
for (const tag of removedTags) {
|
||||
await ConversationTag.updateOne({ tag, user }, { $inc: { count: -1 } });
|
||||
}
|
||||
conversation.tags = tags.tags;
|
||||
await conversation.save({ timestamps: { updatedAt: false } });
|
||||
return conversation.tags;
|
||||
} catch (error) {
|
||||
logger.error('[updateTagsToConversation] Error updating tags', error);
|
||||
return { message: 'Error updating tags' };
|
||||
}
|
||||
};
|
||||
|
||||
const createConversationTag = async (user, data) => {
|
||||
try {
|
||||
const cTag = await ConversationTag.findOne({ user, tag: data.tag });
|
||||
if (cTag) {
|
||||
return cTag;
|
||||
}
|
||||
|
||||
const addToConversation = data.addToConversation && data.conversationId;
|
||||
const newTag = await ConversationTag.create({
|
||||
user,
|
||||
tag: data.tag,
|
||||
count: 0,
|
||||
description: data.description,
|
||||
position: 1,
|
||||
});
|
||||
|
||||
await ConversationTag.updateMany(
|
||||
{ user, position: { $gte: 1 }, _id: { $ne: newTag._id } },
|
||||
{ $inc: { position: 1 } },
|
||||
);
|
||||
|
||||
if (addToConversation) {
|
||||
const conversation = await Conversation.findOne({
|
||||
user,
|
||||
conversationId: data.conversationId,
|
||||
});
|
||||
if (conversation) {
|
||||
const tags = [...(conversation.tags || []), data.tag];
|
||||
await updateTagsForConversation(user, data.conversationId, { tags });
|
||||
} else {
|
||||
logger.warn('[updateTagsForConversation] Conversation not found', data.conversationId);
|
||||
}
|
||||
}
|
||||
|
||||
return await ConversationTag.findOne({ user, tag: data.tag });
|
||||
} catch (error) {
|
||||
logger.error('[createConversationTag] Error updating conversation tag', error);
|
||||
return { message: 'Error updating conversation tag' };
|
||||
}
|
||||
};
|
||||
|
||||
const replaceOrRemoveTagInConversations = async (user, oldtag, newtag) => {
|
||||
try {
|
||||
const conversations = await Conversation.find({ user, tags: { $in: [oldtag] } });
|
||||
for (const conversation of conversations) {
|
||||
if (newtag && newtag !== '') {
|
||||
conversation.tags = conversation.tags.map((tag) => (tag === oldtag ? newtag : tag));
|
||||
} else {
|
||||
conversation.tags = conversation.tags.filter((tag) => tag !== oldtag);
|
||||
}
|
||||
await conversation.save({ timestamps: { updatedAt: false } });
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[replaceOrRemoveTagInConversations] Error updating conversation tags', error);
|
||||
return { message: 'Error updating conversation tags' };
|
||||
}
|
||||
};
|
||||
|
||||
const updateTagPosition = async (user, tag, newPosition) => {
|
||||
try {
|
||||
const cTag = await ConversationTag.findOne({ user, tag });
|
||||
if (!cTag) {
|
||||
return { message: 'Tag not found' };
|
||||
}
|
||||
|
||||
const oldPosition = cTag.position;
|
||||
|
||||
if (newPosition === oldPosition) {
|
||||
return cTag;
|
||||
}
|
||||
|
||||
const updateOperations = [];
|
||||
|
||||
if (newPosition > oldPosition) {
|
||||
// Move other tags up
|
||||
updateOperations.push({
|
||||
updateMany: {
|
||||
filter: {
|
||||
user,
|
||||
position: { $gt: oldPosition, $lte: newPosition },
|
||||
tag: { $ne: SAVED_TAG },
|
||||
},
|
||||
update: { $inc: { position: -1 } },
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// Move other tags down
|
||||
updateOperations.push({
|
||||
updateMany: {
|
||||
filter: {
|
||||
user,
|
||||
position: { $gte: newPosition, $lt: oldPosition },
|
||||
tag: { $ne: SAVED_TAG },
|
||||
},
|
||||
update: { $inc: { position: 1 } },
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Update the target tag's position
|
||||
updateOperations.push({
|
||||
updateOne: {
|
||||
filter: { _id: cTag._id },
|
||||
update: { $set: { position: newPosition } },
|
||||
},
|
||||
});
|
||||
|
||||
await ConversationTag.bulkWrite(updateOperations);
|
||||
|
||||
return await ConversationTag.findById(cTag._id);
|
||||
} catch (error) {
|
||||
logger.error('[updateTagPosition] Error updating tag position', error);
|
||||
return { message: 'Error updating tag position' };
|
||||
}
|
||||
};
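The bulkWrite above shifts every tag sitting between the old and new positions by one (decrementing when a tag moves later, incrementing when it moves earlier) and then writes the target's new position, leaving the pinned Saved tag alone. A self-contained sketch of the same reindexing on a plain array, with the SAVED_TAG pinning omitted:

function moveTag(tags, tagName, newPosition) {
  const target = tags.find((t) => t.tag === tagName);
  const oldPosition = target.position;
  if (newPosition === oldPosition) {
    return tags;
  }
  for (const t of tags) {
    if (t === target) {
      continue;
    }
    if (newPosition > oldPosition && t.position > oldPosition && t.position <= newPosition) {
      t.position -= 1; // tags between old and new slide toward the front
    } else if (newPosition < oldPosition && t.position >= newPosition && t.position < oldPosition) {
      t.position += 1; // tags between new and old slide toward the back
    }
  }
  target.position = newPosition;
  return tags;
}

// moveTag([{ tag: 'a', position: 0 }, { tag: 'b', position: 1 }, { tag: 'c', position: 2 }], 'a', 2)
//   -> positions become: b: 0, c: 1, a: 2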
|
||||
module.exports = {
|
||||
SAVED_TAG,
|
||||
ConversationTag,
|
||||
getConversationTags: async (user) => {
|
||||
try {
|
||||
const cTags = await ConversationTag.find({ user }).sort({ position: 1 }).lean();
|
||||
cTags.sort((a, b) => (a.tag === SAVED_TAG ? -1 : b.tag === SAVED_TAG ? 1 : 0));
|
||||
|
||||
return cTags;
|
||||
} catch (error) {
|
||||
logger.error('[getConversationTags] Error getting conversation tags', error);
|
||||
return { message: 'Error getting conversation tags' };
|
||||
}
|
||||
},
|
||||
|
||||
createConversationTag,
|
||||
updateConversationTag: async (user, tag, data) => {
|
||||
try {
|
||||
const cTag = await ConversationTag.findOne({ user, tag });
|
||||
if (!cTag) {
|
||||
return createConversationTag(user, data);
|
||||
}
|
||||
|
||||
if (cTag.tag !== data.tag || cTag.description !== data.description) {
|
||||
cTag.tag = data.tag;
|
||||
cTag.description = data.description === undefined ? cTag.description : data.description;
|
||||
await cTag.save();
|
||||
}
|
||||
|
||||
if (data.position !== undefined && cTag.position !== data.position) {
|
||||
await updateTagPosition(user, tag, data.position);
|
||||
}
|
||||
|
||||
// update conversation tags properties
|
||||
replaceOrRemoveTagInConversations(user, tag, data.tag);
|
||||
return await ConversationTag.findOne({ user, tag: data.tag });
|
||||
} catch (error) {
|
||||
logger.error('[updateConversationTag] Error updating conversation tag', error);
|
||||
return { message: 'Error updating conversation tag' };
|
||||
}
|
||||
},
|
||||
|
||||
deleteConversationTag: async (user, tag) => {
|
||||
try {
|
||||
const currentTag = await ConversationTag.findOne({ user, tag });
|
||||
if (!currentTag) {
|
||||
return;
|
||||
}
|
||||
|
||||
await currentTag.deleteOne({ user, tag });
|
||||
|
||||
await replaceOrRemoveTagInConversations(user, tag, null);
|
||||
return currentTag;
|
||||
} catch (error) {
|
||||
logger.error('[deleteConversationTag] Error deleting conversation tag', error);
|
||||
return { message: 'Error deleting conversation tag' };
|
||||
}
|
||||
},
|
||||
|
||||
updateTagsForConversation,
|
||||
rebuildConversationTags: async (user) => {
|
||||
try {
|
||||
const conversations = await Conversation.find({ user }).select('tags');
|
||||
const tagCountMap = {};
|
||||
|
||||
// Count the occurrences of each tag
|
||||
conversations.forEach((conversation) => {
|
||||
conversation.tags.forEach((tag) => {
|
||||
if (tagCountMap[tag]) {
|
||||
tagCountMap[tag]++;
|
||||
} else {
|
||||
tagCountMap[tag] = 1;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
const tags = await ConversationTag.find({ user }).sort({ position: -1 });
|
||||
|
||||
// Update existing tags and add new tags
|
||||
for (const [tag, count] of Object.entries(tagCountMap)) {
|
||||
const existingTag = tags.find((t) => t.tag === tag);
|
||||
if (existingTag) {
|
||||
existingTag.count = count;
|
||||
await existingTag.save();
|
||||
} else {
|
||||
const newTag = new ConversationTag({ user, tag, count });
|
||||
tags.push(newTag);
|
||||
await newTag.save();
|
||||
}
|
||||
}
|
||||
|
||||
// Set count to 0 for tags that are not in the grouped tags
|
||||
for (const tag of tags) {
|
||||
if (!tagCountMap[tag.tag]) {
|
||||
tag.count = 0;
|
||||
await tag.save();
|
||||
}
|
||||
}
|
||||
|
||||
// Sort tags by position in descending order
|
||||
tags.sort((a, b) => a.position - b.position);
|
||||
|
||||
// Move the tag with name "saved" to the first position
|
||||
const savedTagIndex = tags.findIndex((tag) => tag.tag === SAVED_TAG);
|
||||
if (savedTagIndex !== -1) {
|
||||
const [savedTag] = tags.splice(savedTagIndex, 1);
|
||||
tags.unshift(savedTag);
|
||||
}
|
||||
|
||||
// Reassign positions starting from 0
|
||||
tags.forEach((tag, index) => {
|
||||
tag.position = index;
|
||||
tag.save();
|
||||
});
|
||||
return tags;
|
||||
} catch (error) {
|
||||
logger.error('[rearrangeTags] Error rearranging tags', error);
|
||||
return { message: 'Error rearranging tags' };
|
||||
}
|
||||
},
|
||||
};
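A usage sketch for rebuildConversationTags: it recounts each tag from the user's conversations, zeroes counts for tags no longer in use, keeps Saved first, and rewrites positions from 0. The require path and user id are assumptions:

const { rebuildConversationTags } = require('~/models/ConversationTag');

async function resyncTags(userId) {
  const tags = await rebuildConversationTags(userId);
  return tags.map((t) => `${t.position}: ${t.tag} (${t.count})`);
}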
|
||||
@@ -97,8 +97,12 @@ const deleteFileByFilter = async (filter) => {
|
||||
* @param {Array<string>} file_ids - The unique identifiers of the files to delete.
|
||||
* @returns {Promise<Object>} A promise that resolves to the result of the deletion operation.
|
||||
*/
|
||||
const deleteFiles = async (file_ids) => {
|
||||
return await File.deleteMany({ file_id: { $in: file_ids } });
|
||||
const deleteFiles = async (file_ids, user) => {
|
||||
let deleteQuery = { file_id: { $in: file_ids } };
|
||||
if (user) {
|
||||
deleteQuery = { user: user };
|
||||
}
|
||||
return await File.deleteMany(deleteQuery);
|
||||
};
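Note the branching above: passing a user id makes deleteFiles delete by owner, and the file_ids list is no longer part of the query on that path. A short usage sketch; the require path and ids are illustrative:

const { deleteFiles } = require('~/models/File');

async function cleanup() {
  // Delete specific files by id:
  await deleteFiles(['file-1', 'file-2']);
  // Delete every file owned by the user (the first argument is ignored on this path):
  await deleteFiles([], 'user-123');
}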
|
||||
|
||||
module.exports = {
|
||||
|
||||
@@ -1,191 +1,342 @@
|
||||
const { z } = require('zod');
|
||||
const Message = require('./schema/messageSchema');
|
||||
const logger = require('~/config/winston');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const idSchema = z.string().uuid();
|
||||
|
||||
/**
|
||||
* Saves a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function saveMessage
|
||||
* @param {Express.Request} req - The request object containing user information.
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.iconURL - The URL of the sender's icon.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.newMessageId - The new unique identifier for the message (if applicable).
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {string} params.sender - The identifier of the sender.
|
||||
* @param {string} params.text - The text content of the message.
|
||||
* @param {boolean} params.isCreatedByUser - Indicates if the message was created by the user.
|
||||
* @param {string} [params.error] - Any error associated with the message.
|
||||
* @param {boolean} [params.unfinished] - Indicates if the message is unfinished.
|
||||
* @param {Object[]} [params.files] - An array of files associated with the message.
|
||||
* @param {boolean} [params.isEdited] - Indicates if the message was edited.
|
||||
* @param {string} [params.finish_reason] - Reason for finishing the message.
|
||||
* @param {number} [params.tokenCount] - The number of tokens in the message.
|
||||
* @param {string} [params.plugin] - Plugin associated with the message.
|
||||
* @param {string[]} [params.plugins] - An array of plugins associated with the message.
|
||||
* @param {string} [params.model] - The model used to generate the message.
|
||||
* @param {Object} [metadata] - Additional metadata for this operation
|
||||
* @param {string} [metadata.context] - The context of the operation
|
||||
* @returns {Promise<TMessage>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async function saveMessage(req, params, metadata) {
|
||||
try {
|
||||
if (!req || !req.user || !req.user.id) {
|
||||
throw new Error('User not authenticated');
|
||||
}
|
||||
|
||||
const {
|
||||
text,
|
||||
error,
|
||||
model,
|
||||
files,
|
||||
plugin,
|
||||
sender,
|
||||
plugins,
|
||||
iconURL,
|
||||
endpoint,
|
||||
isEdited,
|
||||
messageId,
|
||||
unfinished,
|
||||
tokenCount,
|
||||
newMessageId,
|
||||
finish_reason,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
isCreatedByUser,
|
||||
} = params;
|
||||
|
||||
const validConvoId = idSchema.safeParse(conversationId);
|
||||
if (!validConvoId.success) {
|
||||
logger.warn(`Invalid conversation ID: ${conversationId}`);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`---\`saveMessage\` context: ${metadata.context}`);
|
||||
}
|
||||
|
||||
logger.info(`---Invalid conversation ID Params:
|
||||
|
||||
${JSON.stringify(params, null, 2)}
|
||||
|
||||
`);
|
||||
return;
|
||||
}
|
||||
|
||||
const update = {
|
||||
user: req.user.id,
|
||||
iconURL,
|
||||
endpoint,
|
||||
messageId: newMessageId || messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
isEdited,
|
||||
finish_reason,
|
||||
error,
|
||||
unfinished,
|
||||
tokenCount,
|
||||
plugin,
|
||||
plugins,
|
||||
model,
|
||||
};
|
||||
|
||||
if (files) {
|
||||
update.files = files;
|
||||
}
|
||||
|
||||
const message = await Message.findOneAndUpdate({ messageId, user: req.user.id }, update, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
|
||||
return message.toObject();
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`---\`saveMessage\` context: ${metadata.context}`);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
}
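A usage sketch for the reworked saveMessage, which now requires an authenticated request and accepts an optional metadata.context used only in log output. The require path, ids, and context string are illustrative:

const { v4: uuidv4 } = require('uuid');
const { saveMessage } = require('~/models/Message');

async function storeUserMessage(req, { text, conversationId, parentMessageId }) {
  return await saveMessage(
    req,
    {
      text,
      sender: 'User',
      conversationId,
      parentMessageId,
      messageId: uuidv4(),
      isCreatedByUser: true,
      error: false,
      unfinished: false,
    },
    { context: 'api/server/routes/ask - store user message' },
  );
}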
|
||||
|
||||
/**
|
||||
* Saves multiple messages in the database in bulk.
|
||||
*
|
||||
* @async
|
||||
* @function bulkSaveMessages
|
||||
* @param {Object[]} messages - An array of message objects to save.
|
||||
* @returns {Promise<Object>} The result of the bulk write operation.
|
||||
* @throws {Error} If there is an error in saving messages in bulk.
|
||||
*/
|
||||
async function bulkSaveMessages(messages) {
|
||||
try {
|
||||
const bulkOps = messages.map((message) => ({
|
||||
updateOne: {
|
||||
filter: { messageId: message.messageId },
|
||||
update: message,
|
||||
upsert: true,
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await Message.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (err) {
|
||||
logger.error('Error saving messages in bulk:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Records a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function recordMessage
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.user - The identifier of the user.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed.
|
||||
* @returns {Promise<Object>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async function recordMessage({
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest
|
||||
}) {
|
||||
try {
|
||||
// No parsing of convoId as may use threadId
|
||||
const message = {
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest,
|
||||
};
|
||||
|
||||
return await Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('Error recording message:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the text of a message.
|
||||
*
|
||||
* @async
|
||||
* @function updateMessageText
|
||||
* @param {Object} params - The update data object.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.text - The new text content of the message.
|
||||
* @returns {Promise<void>}
|
||||
* @throws {Error} If there is an error in updating the message text.
|
||||
*/
|
||||
async function updateMessageText(req, { messageId, text }) {
|
||||
try {
|
||||
await Message.updateOne({ messageId, user: req.user.id }, { text });
|
||||
} catch (err) {
|
||||
logger.error('Error updating message text:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a message.
|
||||
*
|
||||
* @async
|
||||
* @function updateMessage
|
||||
* @param {Object} message - The message object containing update data.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} message.messageId - The unique identifier for the message.
|
||||
* @param {string} [message.text] - The new text content of the message.
|
||||
* @param {Object[]} [message.files] - The files associated with the message.
|
||||
* @param {boolean} [message.isCreatedByUser] - Indicates if the message was created by the user.
|
||||
* @param {string} [message.sender] - The identifier of the sender.
|
||||
* @param {number} [message.tokenCount] - The number of tokens in the message.
|
||||
* @param {Object} [metadata] - The operation metadata
|
||||
* @param {string} [metadata.context] - The operation metadata
|
||||
* @returns {Promise<TMessage>} The updated message document.
|
||||
* @throws {Error} If there is an error in updating the message or if the message is not found.
|
||||
*/
|
||||
async function updateMessage(req, message, metadata) {
|
||||
try {
|
||||
const { messageId, ...update } = message;
|
||||
update.isEdited = true;
|
||||
const updatedMessage = await Message.findOneAndUpdate(
|
||||
{ messageId, user: req.user.id },
|
||||
update,
|
||||
{
|
||||
new: true,
|
||||
},
|
||||
);
|
||||
|
||||
if (!updatedMessage) {
|
||||
throw new Error('Message not found or user not authorized.');
|
||||
}
|
||||
|
||||
return {
|
||||
messageId: updatedMessage.messageId,
|
||||
conversationId: updatedMessage.conversationId,
|
||||
parentMessageId: updatedMessage.parentMessageId,
|
||||
sender: updatedMessage.sender,
|
||||
text: updatedMessage.text,
|
||||
isCreatedByUser: updatedMessage.isCreatedByUser,
|
||||
tokenCount: updatedMessage.tokenCount,
|
||||
isEdited: true,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error updating message:', err);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`---\`updateMessage\` context: ${metadata.context}`);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages in a conversation since a specific message.
|
||||
*
|
||||
* @async
|
||||
* @function deleteMessagesSince
|
||||
* @param {Object} params - The parameters object.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @returns {Promise<Number>} The number of deleted messages.
|
||||
* @throws {Error} If there is an error in deleting messages.
|
||||
*/
|
||||
async function deleteMessagesSince(req, { messageId, conversationId }) {
|
||||
try {
|
||||
const message = await Message.findOne({ messageId, user: req.user.id }).lean();
|
||||
|
||||
if (message) {
|
||||
const query = Message.find({ conversationId, user: req.user.id });
|
||||
return await query.deleteMany({
|
||||
createdAt: { $gt: message.createdAt },
|
||||
});
|
||||
}
|
||||
return undefined;
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves messages from the database.
|
||||
* @async
|
||||
* @function getMessages
|
||||
* @param {Record<string, unknown>} filter - The filter criteria.
|
||||
* @param {string | undefined} [select] - The fields to select.
|
||||
* @returns {Promise<TMessage[]>} The messages that match the filter criteria.
|
||||
* @throws {Error} If there is an error in retrieving messages.
|
||||
*/
|
||||
async function getMessages(filter, select) {
|
||||
try {
|
||||
if (select) {
|
||||
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||
}
|
||||
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages from the database.
|
||||
*
|
||||
* @async
|
||||
* @function deleteMessages
|
||||
* @param {Object} filter - The filter criteria to find messages to delete.
|
||||
* @returns {Promise<Number>} The number of deleted messages.
|
||||
* @throws {Error} If there is an error in deleting messages.
|
||||
*/
|
||||
async function deleteMessages(filter) {
|
||||
try {
|
||||
return await Message.deleteMany(filter);
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
Message,
|
||||
|
||||
async saveMessage({
|
||||
user,
|
||||
endpoint,
|
||||
iconURL,
|
||||
messageId,
|
||||
newMessageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
error,
|
||||
unfinished,
|
||||
files,
|
||||
isEdited,
|
||||
finish_reason,
|
||||
tokenCount,
|
||||
plugin,
|
||||
plugins,
|
||||
model,
|
||||
}) {
|
||||
try {
|
||||
const validConvoId = idSchema.safeParse(conversationId);
|
||||
if (!validConvoId.success) {
|
||||
return;
|
||||
}
|
||||
|
||||
const update = {
|
||||
user,
|
||||
iconURL,
|
||||
endpoint,
|
||||
messageId: newMessageId || messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
isEdited,
|
||||
finish_reason,
|
||||
error,
|
||||
unfinished,
|
||||
tokenCount,
|
||||
plugin,
|
||||
plugins,
|
||||
model,
|
||||
};
|
||||
|
||||
if (files) {
|
||||
update.files = files;
|
||||
}
|
||||
// may also need to update the conversation here
|
||||
await Message.findOneAndUpdate({ messageId }, update, { upsert: true, new: true });
|
||||
|
||||
return {
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
tokenCount,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
throw new Error('Failed to save message.');
|
||||
}
|
||||
},
|
||||
|
||||
async bulkSaveMessages(messages) {
|
||||
try {
|
||||
const bulkOps = messages.map((message) => ({
|
||||
updateOne: {
|
||||
filter: { messageId: message.messageId },
|
||||
update: message,
|
||||
upsert: true,
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await Message.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (err) {
|
||||
logger.error('Error saving messages in bulk:', err);
|
||||
throw new Error('Failed to save messages in bulk.');
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Records a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function recordMessage
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.user - The identifier of the user.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed.
|
||||
* @returns {Promise<Object>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async recordMessage({ user, endpoint, messageId, conversationId, parentMessageId, ...rest }) {
|
||||
try {
|
||||
// No parsing of convoId as may use threadId
|
||||
const message = {
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest,
|
||||
};
|
||||
|
||||
return await Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
throw new Error('Failed to save message.');
|
||||
}
|
||||
},
|
||||
async updateMessage(message) {
|
||||
try {
|
||||
const { messageId, ...update } = message;
|
||||
update.isEdited = true;
|
||||
const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, {
|
||||
new: true,
|
||||
});
|
||||
|
||||
if (!updatedMessage) {
|
||||
throw new Error('Message not found.');
|
||||
}
|
||||
|
||||
return {
|
||||
messageId: updatedMessage.messageId,
|
||||
conversationId: updatedMessage.conversationId,
|
||||
parentMessageId: updatedMessage.parentMessageId,
|
||||
sender: updatedMessage.sender,
|
||||
text: updatedMessage.text,
|
||||
isCreatedByUser: updatedMessage.isCreatedByUser,
|
||||
tokenCount: updatedMessage.tokenCount,
|
||||
isEdited: true,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error updating message:', err);
|
||||
throw new Error('Failed to update message.');
|
||||
}
|
||||
},
|
||||
async deleteMessagesSince({ messageId, conversationId }) {
|
||||
try {
|
||||
const message = await Message.findOne({ messageId }).lean();
|
||||
|
||||
if (message) {
|
||||
return await Message.find({ conversationId }).deleteMany({
|
||||
createdAt: { $gt: message.createdAt },
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw new Error('Failed to delete messages.');
|
||||
}
|
||||
},
|
||||
|
||||
async getMessages(filter) {
|
||||
try {
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw new Error('Failed to get messages.');
|
||||
}
|
||||
},
|
||||
|
||||
async deleteMessages(filter) {
|
||||
try {
|
||||
return await Message.deleteMany(filter);
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw new Error('Failed to delete messages.');
|
||||
}
|
||||
},
|
||||
saveMessage,
|
||||
bulkSaveMessages,
|
||||
recordMessage,
|
||||
updateMessageText,
|
||||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
getMessages,
|
||||
deleteMessages,
|
||||
};
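Because this is a diff view, the export block above shows both the older inline method definitions and the new shorthand references that replace them; the request-scoped signatures exercised in Message.spec.js below are the ones that apply after this change. A minimal usage sketch under that assumption — the handler name, route wiring, and require path are hypothetical, not part of the diff:

// Hypothetical Express-style handler: persist an incoming user message, then
// return the conversation history, using the req-first signatures from the spec.
const { saveMessage, getMessages } = require('~/models/Message');

async function postMessage(req, res) {
  // saveMessage resolves to undefined when the conversation ID fails validation.
  const saved = await saveMessage(req, {
    messageId: req.body.messageId,
    conversationId: req.body.conversationId,
    text: req.body.text,
    isCreatedByUser: true,
    user: req.user.id,
  });
  // Optional projection: the second argument selects fields, per getMessages above.
  const history = await getMessages({ conversationId: saved.conversationId }, 'text sender createdAt');
  res.json({ saved, history });
}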
|
||||
|
||||
239
api/models/Message.spec.js
Normal file
@@ -0,0 +1,239 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
|
||||
jest.mock('mongoose');
|
||||
|
||||
const mockFindQuery = {
|
||||
select: jest.fn().mockReturnThis(),
|
||||
sort: jest.fn().mockReturnThis(),
|
||||
lean: jest.fn().mockReturnThis(),
|
||||
deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }),
|
||||
};
|
||||
|
||||
const mockSchema = {
|
||||
findOneAndUpdate: jest.fn(),
|
||||
updateOne: jest.fn(),
|
||||
findOne: jest.fn(() => ({
|
||||
lean: jest.fn(),
|
||||
})),
|
||||
find: jest.fn(() => mockFindQuery),
|
||||
deleteMany: jest.fn(),
|
||||
};
|
||||
|
||||
mongoose.model.mockReturnValue(mockSchema);
|
||||
|
||||
jest.mock('~/models/schema/messageSchema', () => mockSchema);
|
||||
|
||||
jest.mock('~/config/winston', () => ({
|
||||
error: jest.fn(),
|
||||
}));
|
||||
|
||||
const {
|
||||
saveMessage,
|
||||
getMessages,
|
||||
updateMessage,
|
||||
deleteMessages,
|
||||
updateMessageText,
|
||||
deleteMessagesSince,
|
||||
} = require('~/models/Message');
|
||||
|
||||
describe('Message Operations', () => {
|
||||
let mockReq;
|
||||
let mockMessage;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
mockReq = {
|
||||
user: { id: 'user123' },
|
||||
};
|
||||
|
||||
mockMessage = {
|
||||
messageId: 'msg123',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Hello, world!',
|
||||
user: 'user123',
|
||||
};
|
||||
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue({
|
||||
toObject: () => mockMessage,
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveMessage', () => {
|
||||
it('should save a message for an authenticated user', async () => {
|
||||
const result = await saveMessage(mockReq, mockMessage);
|
||||
expect(result).toEqual(mockMessage);
|
||||
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
|
||||
{ messageId: 'msg123', user: 'user123' },
|
||||
expect.objectContaining({ user: 'user123' }),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error for unauthenticated user', async () => {
|
||||
mockReq.user = null;
|
||||
await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated');
|
||||
});
|
||||
|
||||
it('should throw an error for invalid conversation ID', async () => {
|
||||
mockMessage.conversationId = 'invalid-id';
|
||||
await expect(saveMessage(mockReq, mockMessage)).resolves.toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateMessageText', () => {
|
||||
it('should update message text for the authenticated user', async () => {
|
||||
await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' });
|
||||
expect(mockSchema.updateOne).toHaveBeenCalledWith(
|
||||
{ messageId: 'msg123', user: 'user123' },
|
||||
{ text: 'Updated text' },
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateMessage', () => {
|
||||
it('should update a message for the authenticated user', async () => {
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage);
|
||||
const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' });
|
||||
expect(result).toEqual(
|
||||
expect.objectContaining({
|
||||
messageId: 'msg123',
|
||||
text: 'Hello, world!',
|
||||
isEdited: true,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if message is not found', async () => {
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue(null);
|
||||
await expect(
|
||||
updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }),
|
||||
).rejects.toThrow('Message not found or user not authorized.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessagesSince', () => {
|
||||
it('should delete messages only for the authenticated user', async () => {
|
||||
mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() });
|
||||
mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 });
|
||||
const result = await deleteMessagesSince(mockReq, {
|
||||
messageId: 'msg123',
|
||||
conversationId: 'convo123',
|
||||
});
|
||||
expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' });
|
||||
expect(mockSchema.find).not.toHaveBeenCalled();
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if no message is found', async () => {
|
||||
mockSchema.findOne().lean.mockResolvedValueOnce(null);
|
||||
const result = await deleteMessagesSince(mockReq, {
|
||||
messageId: 'nonexistent',
|
||||
conversationId: 'convo123',
|
||||
});
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMessages', () => {
|
||||
it('should retrieve messages with the correct filter', async () => {
|
||||
const filter = { conversationId: 'convo123' };
|
||||
await getMessages(filter);
|
||||
expect(mockSchema.find).toHaveBeenCalledWith(filter);
|
||||
expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 });
|
||||
expect(mockFindQuery.lean).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessages', () => {
|
||||
it('should delete messages with the correct filter', async () => {
|
||||
await deleteMessages({ user: 'user123' });
|
||||
expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('Conversation Hijacking Prevention', () => {
|
||||
it('should not allow editing a message in another user\'s conversation', async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
updateMessage(attackerReq, {
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
text: 'Hacked message',
|
||||
}),
|
||||
).rejects.toThrow('Message not found or user not authorized.');
|
||||
|
||||
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
|
||||
{ messageId: victimMessageId, user: 'attacker123' },
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not allow deleting messages from another user\'s conversation', async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
||||
mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user
|
||||
const result = await deleteMessagesSince(attackerReq, {
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
expect(mockSchema.findOne).toHaveBeenCalledWith({
|
||||
messageId: victimMessageId,
|
||||
user: 'attacker123',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not allow inserting a new message into another user\'s conversation', async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = uuidv4(); // Use a valid UUID
|
||||
|
||||
await expect(
|
||||
saveMessage(attackerReq, {
|
||||
conversationId: victimConversationId,
|
||||
text: 'Inserted malicious message',
|
||||
messageId: 'new-msg-123',
|
||||
}),
|
||||
).resolves.not.toThrow(); // It should not throw an error
|
||||
|
||||
// Check that the message was saved with the attacker's user ID
|
||||
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
|
||||
{ messageId: 'new-msg-123', user: 'attacker123' },
|
||||
expect.objectContaining({
|
||||
user: 'attacker123',
|
||||
conversationId: victimConversationId,
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow retrieving messages from any conversation', async () => {
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
|
||||
await getMessages({ conversationId: victimConversationId });
|
||||
|
||||
expect(mockSchema.find).toHaveBeenCalledWith({
|
||||
conversationId: victimConversationId,
|
||||
});
|
||||
|
||||
mockSchema.find.mockReturnValueOnce({
|
||||
select: jest.fn().mockReturnThis(),
|
||||
sort: jest.fn().mockReturnThis(),
|
||||
lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]),
|
||||
});
|
||||
|
||||
const result = await getMessages({ conversationId: victimConversationId });
|
||||
expect(result).toEqual([{ text: 'Test message' }]);
|
||||
});
|
||||
});
|
||||
});
|
||||
90
api/models/Project.js
Normal file
@@ -0,0 +1,90 @@
|
||||
const { model } = require('mongoose');
|
||||
const projectSchema = require('~/models/schema/projectSchema');
|
||||
|
||||
const Project = model('Project', projectSchema);
|
||||
|
||||
/**
|
||||
* Retrieve a project by ID and convert the found project document to a plain object.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to find and return as a plain object.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoProject>} A plain object representing the project document, or `null` if no project is found.
|
||||
*/
|
||||
const getProjectById = async function (projectId, fieldsToSelect = null) {
|
||||
const query = Project.findById(projectId);
|
||||
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieve a project by name and convert the found project document to a plain object.
|
||||
* If the project with the given name doesn't exist and the name is "instance", create it and return the lean version.
|
||||
*
|
||||
* @param {string} projectName - The name of the project to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoProject>} A plain object representing the project document.
|
||||
*/
|
||||
const getProjectByName = async function (projectName, fieldsToSelect = null) {
|
||||
const query = { name: projectName };
|
||||
const update = { $setOnInsert: { name: projectName } };
|
||||
const options = {
|
||||
new: true,
|
||||
upsert: projectName === 'instance',
|
||||
lean: true,
|
||||
select: fieldsToSelect,
|
||||
};
|
||||
|
||||
return await Project.findOneAndUpdate(query, update, options);
|
||||
};
|
||||
|
||||
/**
|
||||
* Add an array of prompt group IDs to a project's promptGroupIds array, ensuring uniqueness.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const addGroupIdsToProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $addToSet: { promptGroupIds: { $each: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove an array of prompt group IDs from a project's promptGroupIds array.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $pull: { promptGroupIds: { $in: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove a prompt group ID from all projects.
|
||||
*
|
||||
* @param {string} promptGroupId - The ID of the prompt group to remove from projects.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const removeGroupFromAllProjects = async (promptGroupId) => {
|
||||
await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getProjectById,
|
||||
getProjectByName,
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
};
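A short sketch of how these helpers compose; the global-share scenario and the function name are assumptions inferred from how `getProjectByName('instance', 'promptGroupIds')` is called by the prompt model later in this diff:

// Illustrative only: publish a prompt group to the shared "instance" project.
const { getProjectByName, addGroupIdsToProject } = require('~/models/Project');

async function shareGroupGlobally(promptGroupId) {
  // For the name "instance", getProjectByName upserts the project if it is missing.
  const instanceProject = await getProjectByName('instance', '_id promptGroupIds');
  // $addToSet in addGroupIdsToProject keeps the promptGroupIds array free of duplicates.
  return addGroupIdsToProject(instanceProject._id, [promptGroupId]);
}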
|
||||
@@ -1,52 +1,528 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { ObjectId } = require('mongodb');
|
||||
const { SystemRoles, SystemCategories } = require('librechat-data-provider');
|
||||
const {
|
||||
getProjectByName,
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
} = require('./Project');
|
||||
const { Prompt, PromptGroup } = require('./schema/promptSchema');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const promptSchema = mongoose.Schema(
|
||||
{
|
||||
title: {
|
||||
type: String,
|
||||
required: true,
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get prompt groups
|
||||
* @param {Object} query
|
||||
* @param {number} skip
|
||||
* @param {number} limit
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createGroupPipeline = (query, skip, limit) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{ $skip: skip },
|
||||
{ $limit: limit },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
prompt: {
|
||||
type: String,
|
||||
required: true,
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project: {
|
||||
name: 1,
|
||||
numberOfGenerations: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
projectIds: 1,
|
||||
productionId: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
// 'productionPrompt._id': 1,
|
||||
// 'productionPrompt.type': 1,
|
||||
},
|
||||
},
|
||||
category: {
|
||||
type: String,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
];
|
||||
};
|
||||
|
||||
const Prompt = mongoose.models.Prompt || mongoose.model('Prompt', promptSchema);
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get all prompt groups
|
||||
* @param {Object} query
|
||||
* @param {Partial<MongoPromptGroup>} $project
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createAllGroupsPipeline = (
|
||||
query,
|
||||
$project = {
|
||||
name: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
command: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
},
|
||||
) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project,
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
/**
|
||||
* Get all prompt groups with filters
|
||||
* @param {Object} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getAllPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { name, ...query } = filter;
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(name, 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
const project = await getProjectByName('instance', 'promptGroupIds');
|
||||
if (project && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const promptGroupsPipeline = createAllGroupsPipeline(combinedQuery);
|
||||
return await PromptGroup.aggregate(promptGroupsPipeline).exec();
|
||||
} catch (error) {
|
||||
console.error('Error getting all prompt groups', error);
|
||||
return { message: 'Error getting all prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {Object} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { pageNumber = 1, pageSize = 10, name, ...query } = filter;
|
||||
|
||||
const validatedPageNumber = Math.max(parseInt(pageNumber, 10), 1);
|
||||
const validatedPageSize = Math.max(parseInt(pageSize, 10), 1);
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(name, 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
// const projects = req.user.projects || []; // TODO: handle multiple projects
|
||||
const project = await getProjectByName('instance', 'promptGroupIds');
|
||||
if (project && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const skip = (validatedPageNumber - 1) * validatedPageSize;
|
||||
const limit = validatedPageSize;
|
||||
|
||||
const promptGroupsPipeline = createGroupPipeline(combinedQuery, skip, limit);
|
||||
const totalPromptGroupsPipeline = [{ $match: combinedQuery }, { $count: 'total' }];
|
||||
|
||||
const [promptGroupsResults, totalPromptGroupsResults] = await Promise.all([
|
||||
PromptGroup.aggregate(promptGroupsPipeline).exec(),
|
||||
PromptGroup.aggregate(totalPromptGroupsPipeline).exec(),
|
||||
]);
|
||||
|
||||
const promptGroups = promptGroupsResults;
|
||||
const totalPromptGroups =
|
||||
totalPromptGroupsResults.length > 0 ? totalPromptGroupsResults[0].total : 0;
|
||||
|
||||
return {
|
||||
promptGroups,
|
||||
pageNumber: validatedPageNumber.toString(),
|
||||
pageSize: validatedPageSize.toString(),
|
||||
pages: Math.ceil(totalPromptGroups / validatedPageSize).toString(),
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
savePrompt: async ({ title, prompt }) => {
|
||||
getPromptGroups,
|
||||
getAllPromptGroups,
|
||||
/**
|
||||
* Create a prompt and its respective group
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
* @returns {Promise<TCreatePromptResponse>}
|
||||
*/
|
||||
createPromptGroup: async (saveData) => {
|
||||
try {
|
||||
await Prompt.create({
|
||||
title,
|
||||
prompt,
|
||||
});
|
||||
return { title, prompt };
|
||||
const { prompt, group, author, authorName } = saveData;
|
||||
|
||||
let newPromptGroup = await PromptGroup.findOneAndUpdate(
|
||||
{ ...group, author, authorName, productionId: null },
|
||||
{ $setOnInsert: { ...group, author, authorName, productionId: null } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
const newPrompt = await Prompt.findOneAndUpdate(
|
||||
{ ...prompt, author, groupId: newPromptGroup._id },
|
||||
{ $setOnInsert: { ...prompt, author, groupId: newPromptGroup._id } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
newPromptGroup = await PromptGroup.findByIdAndUpdate(
|
||||
newPromptGroup._id,
|
||||
{ productionId: newPrompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
return {
|
||||
prompt: newPrompt,
|
||||
group: {
|
||||
...newPromptGroup,
|
||||
productionPrompt: { prompt: newPrompt.prompt },
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt group', error);
|
||||
throw new Error('Error saving prompt group');
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Save a prompt
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
* @returns {Promise<TCreatePromptResponse>}
|
||||
*/
|
||||
savePrompt: async (saveData) => {
|
||||
try {
|
||||
const { prompt, author } = saveData;
|
||||
const newPromptData = {
|
||||
...prompt,
|
||||
author,
|
||||
};
|
||||
|
||||
/** @type {TPrompt} */
|
||||
let newPrompt;
|
||||
try {
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
} catch (error) {
|
||||
if (error?.message?.includes('groupId_1_version_1')) {
|
||||
await Prompt.db.collection('prompts').dropIndex('groupId_1_version_1');
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
}
|
||||
|
||||
return { prompt: newPrompt };
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt', error);
|
||||
return { prompt: 'Error saving prompt' };
|
||||
return { message: 'Error saving prompt' };
|
||||
}
|
||||
},
|
||||
getPrompts: async (filter) => {
|
||||
try {
|
||||
return await Prompt.find(filter).lean();
|
||||
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompts', error);
|
||||
return { prompt: 'Error getting prompts' };
|
||||
return { message: 'Error getting prompts' };
|
||||
}
|
||||
},
|
||||
deletePrompts: async (filter) => {
|
||||
getPrompt: async (filter) => {
|
||||
try {
|
||||
return await Prompt.deleteMany(filter);
|
||||
if (filter.groupId) {
|
||||
filter.groupId = new ObjectId(filter.groupId);
|
||||
}
|
||||
return await Prompt.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error deleting prompts', error);
|
||||
return { prompt: 'Error deleting prompts' };
|
||||
logger.error('Error getting prompt', error);
|
||||
return { message: 'Error getting prompt' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {TGetRandomPromptsRequest} filter
|
||||
* @returns {Promise<TGetRandomPromptsResponse>}
|
||||
*/
|
||||
getRandomPromptGroups: async (filter) => {
|
||||
try {
|
||||
const result = await PromptGroup.aggregate([
|
||||
{
|
||||
$match: {
|
||||
category: { $ne: '' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$group: {
|
||||
_id: '$category',
|
||||
promptGroup: { $first: '$$ROOT' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$replaceRoot: { newRoot: '$promptGroup' },
|
||||
},
|
||||
{
|
||||
$sample: { size: +filter.limit + +filter.skip },
|
||||
},
|
||||
{
|
||||
$skip: +filter.skip,
|
||||
},
|
||||
{
|
||||
$limit: +filter.limit,
|
||||
},
|
||||
]);
|
||||
return { prompts: result };
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroupsWithPrompts: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter)
|
||||
.populate({
|
||||
path: 'prompts',
|
||||
select: '-_id -__v -user',
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroup: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
return { message: 'Error getting prompt group' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Deletes a prompt and its corresponding prompt group if it is the last prompt in the group.
|
||||
*
|
||||
* @param {Object} options - The options for deleting the prompt.
|
||||
* @param {ObjectId|string} options.promptId - The ID of the prompt to delete.
|
||||
* @param {ObjectId|string} options.groupId - The ID of the prompt's group.
|
||||
* @param {ObjectId|string} options.author - The ID of the prompt's author.
|
||||
* @param {string} options.role - The role of the prompt's author.
|
||||
* @return {Promise<TDeletePromptResponse>} An object containing the result of the deletion.
|
||||
* If the prompt was deleted successfully, the object will have a property 'prompt' with the value 'Prompt deleted successfully'.
|
||||
* If the prompt group was deleted successfully, the object will have a property 'promptGroup' with the message 'Prompt group deleted successfully' and id of the deleted group.
|
||||
* If there was an error deleting the prompt, the object will have a property 'message' with the value 'Error deleting prompt'.
|
||||
*/
|
||||
deletePrompt: async ({ promptId, groupId, author, role }) => {
|
||||
const query = { _id: promptId, groupId, author };
|
||||
if (role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const { deletedCount } = await Prompt.deleteOne(query);
|
||||
if (deletedCount === 0) {
|
||||
throw new Error('Failed to delete the prompt');
|
||||
}
|
||||
|
||||
const remainingPrompts = await Prompt.find({ groupId })
|
||||
.select('_id')
|
||||
.sort({ createdAt: 1 })
|
||||
.lean();
|
||||
|
||||
if (remainingPrompts.length === 0) {
|
||||
await PromptGroup.deleteOne({ _id: groupId });
|
||||
await removeGroupFromAllProjects(groupId);
|
||||
|
||||
return {
|
||||
prompt: 'Prompt deleted successfully',
|
||||
promptGroup: {
|
||||
message: 'Prompt group deleted successfully',
|
||||
id: groupId,
|
||||
},
|
||||
};
|
||||
} else {
|
||||
const promptGroup = await PromptGroup.findById(groupId).lean();
|
||||
if (promptGroup.productionId.toString() === promptId.toString()) {
|
||||
await PromptGroup.updateOne(
|
||||
{ _id: groupId },
|
||||
{ productionId: remainingPrompts[remainingPrompts.length - 1]._id },
|
||||
);
|
||||
}
|
||||
|
||||
return { prompt: 'Prompt deleted successfully' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Update prompt group
|
||||
* @param {Partial<MongoPromptGroup>} filter - Filter to find prompt group
|
||||
* @param {Partial<MongoPromptGroup>} data - Data to update
|
||||
* @returns {Promise<TUpdatePromptGroupResponse>}
|
||||
*/
|
||||
updatePromptGroup: async (filter, data) => {
|
||||
try {
|
||||
const updateOps = {};
|
||||
if (data.removeProjectIds) {
|
||||
for (const projectId of data.removeProjectIds) {
|
||||
await removeGroupIdsFromProject(projectId, [filter._id]);
|
||||
}
|
||||
|
||||
updateOps.$pull = { projectIds: { $in: data.removeProjectIds } };
|
||||
delete data.removeProjectIds;
|
||||
}
|
||||
|
||||
if (data.projectIds) {
|
||||
for (const projectId of data.projectIds) {
|
||||
await addGroupIdsToProject(projectId, [filter._id]);
|
||||
}
|
||||
|
||||
updateOps.$addToSet = { projectIds: { $each: data.projectIds } };
|
||||
delete data.projectIds;
|
||||
}
|
||||
|
||||
const updateData = { ...data, ...updateOps };
|
||||
const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
|
||||
if (!updatedDoc) {
|
||||
throw new Error('Prompt group not found');
|
||||
}
|
||||
|
||||
return updatedDoc;
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt group', error);
|
||||
return { message: 'Error updating prompt group' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Function to make a prompt production based on its ID.
|
||||
* @param {String} promptId - The ID of the prompt to make production.
|
||||
* @returns {Object} The result of the production operation.
|
||||
*/
|
||||
makePromptProduction: async (promptId) => {
|
||||
try {
|
||||
const prompt = await Prompt.findById(promptId).lean();
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('Prompt not found');
|
||||
}
|
||||
|
||||
await PromptGroup.findByIdAndUpdate(
|
||||
prompt.groupId,
|
||||
{ productionId: prompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.exec();
|
||||
|
||||
return {
|
||||
message: 'Prompt production made successfully',
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error making prompt production', error);
|
||||
return { message: 'Error making prompt production' };
|
||||
}
|
||||
},
|
||||
updatePromptLabels: async (_id, labels) => {
|
||||
try {
|
||||
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
|
||||
if (response.matchedCount === 0) {
|
||||
return { message: 'Prompt not found' };
|
||||
}
|
||||
return { message: 'Prompt labels updated successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt labels', error);
|
||||
return { message: 'Error updating prompt labels' };
|
||||
}
|
||||
},
|
||||
deletePromptGroup: async (_id) => {
|
||||
try {
|
||||
const response = await PromptGroup.deleteOne({ _id });
|
||||
|
||||
if (response.deletedCount === 0) {
|
||||
return { promptGroup: 'Prompt group not found' };
|
||||
}
|
||||
|
||||
await Prompt.deleteMany({ groupId: new ObjectId(_id) });
|
||||
await removeGroupFromAllProjects(_id);
|
||||
return { promptGroup: 'Prompt group deleted successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error deleting prompt group', error);
|
||||
return { message: 'Error deleting prompt group' };
|
||||
}
|
||||
},
|
||||
};
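For orientation, a hedged sketch of calling the paginated `getPromptGroups` query defined above; the controller wiring and the require path are assumptions, while the filter fields (author, category, pageNumber, pageSize) come from the function itself:

// Hypothetical controller: list only the caller's own prompt groups, newest first.
const { SystemCategories } = require('librechat-data-provider');
const { getPromptGroups } = require('~/models/Prompt');

async function listMyPromptGroups(req, res) {
  const result = await getPromptGroups(req, {
    author: req.user.id,
    category: SystemCategories.MY_PROMPTS, // skips the shared "instance" project lookup
    pageNumber: req.query.pageNumber ?? 1,
    pageSize: req.query.pageSize ?? 10,
  });
  res.json(result);
}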
|
||||
|
||||
86
api/models/Role.js
Normal file
@@ -0,0 +1,86 @@
|
||||
const { SystemRoles, CacheKeys, roleDefaults } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const Role = require('~/models/schema/roleSchema');
|
||||
|
||||
/**
|
||||
* Retrieve a role by name and convert the found role document to a plain object.
|
||||
* If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version.
|
||||
*
|
||||
* @param {string} roleName - The name of the role to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<Object>} A plain object representing the role document.
|
||||
*/
|
||||
const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const cachedRole = await cache.get(roleName);
|
||||
if (cachedRole) {
|
||||
return cachedRole;
|
||||
}
|
||||
let query = Role.findOne({ name: roleName });
|
||||
if (fieldsToSelect) {
|
||||
query = query.select(fieldsToSelect);
|
||||
}
|
||||
let role = await query.lean().exec();
|
||||
|
||||
if (!role && SystemRoles[roleName]) {
|
||||
role = roleDefaults[roleName];
|
||||
role = await new Role(role).save();
|
||||
await cache.set(roleName, role);
|
||||
return role.toObject();
|
||||
}
|
||||
await cache.set(roleName, role);
|
||||
return role;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to retrieve or create role: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Update role values by name.
|
||||
*
|
||||
* @param {string} roleName - The name of the role to update.
|
||||
* @param {Partial<TRole>} updates - The fields to update.
|
||||
* @returns {Promise<TRole>} Updated role document.
|
||||
*/
|
||||
const updateRoleByName = async function (roleName, updates) {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const role = await Role.findOneAndUpdate(
|
||||
{ name: roleName },
|
||||
{ $set: updates },
|
||||
{ new: true, lean: true },
|
||||
)
|
||||
.select('-__v')
|
||||
.lean()
|
||||
.exec();
|
||||
await cache.set(roleName, role);
|
||||
return role;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to update role: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Initialize default roles in the system.
|
||||
* Creates the default roles (ADMIN, USER) if they don't exist in the database.
|
||||
*
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const initializeRoles = async function () {
|
||||
const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER];
|
||||
|
||||
for (const roleName of defaultRoles) {
|
||||
let role = await Role.findOne({ name: roleName }).select('name').lean();
|
||||
if (!role) {
|
||||
role = new Role(roleDefaults[roleName]);
|
||||
await role.save();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getRoleByName,
|
||||
initializeRoles,
|
||||
updateRoleByName,
|
||||
};
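A brief sketch of how the cached role helpers might be used to toggle a permission; the nested update shape mirrors the role schema added below in this diff, and the scenario and function name are assumptions:

// Illustrative only: allow the USER role to share prompts globally.
const { SystemRoles, PermissionTypes, Permissions } = require('librechat-data-provider');
const { getRoleByName, updateRoleByName } = require('~/models/Role');

async function enableGlobalPromptSharing() {
  const current = await getRoleByName(SystemRoles.USER);
  const prompts = {
    ...current[PermissionTypes.PROMPTS],
    [Permissions.SHARED_GLOBAL]: true,
  };
  // updateRoleByName writes the role and refreshes the CacheKeys.ROLES cache entry.
  return updateRoleByName(SystemRoles.USER, { [PermissionTypes.PROMPTS]: prompts });
}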
|
||||
117
api/models/Share.js
Normal file
@@ -0,0 +1,117 @@
|
||||
const crypto = require('crypto');
|
||||
const { getMessages } = require('./Message');
|
||||
const SharedLink = require('./schema/shareSchema');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
module.exports = {
|
||||
SharedLink,
|
||||
getSharedMessages: async (shareId) => {
|
||||
try {
|
||||
const share = await SharedLink.findOne({ shareId })
|
||||
.populate({
|
||||
path: 'messages',
|
||||
select: '-_id -__v -user',
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean();
|
||||
|
||||
if (!share || !share.conversationId || !share.isPublic) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return share;
|
||||
} catch (error) {
|
||||
logger.error('[getShare] Error getting share link', error);
|
||||
throw new Error('Error getting share link');
|
||||
}
|
||||
},
|
||||
|
||||
getSharedLinks: async (user, pageNumber = 1, pageSize = 25, isPublic = true) => {
|
||||
const query = { user, isPublic };
|
||||
try {
|
||||
const totalConvos = (await SharedLink.countDocuments(query)) || 1;
|
||||
const totalPages = Math.ceil(totalConvos / pageSize);
|
||||
const shares = await SharedLink.find(query)
|
||||
.sort({ updatedAt: -1 })
|
||||
.skip((pageNumber - 1) * pageSize)
|
||||
.limit(pageSize)
|
||||
.select('-_id -__v -user')
|
||||
.lean();
|
||||
|
||||
return { sharedLinks: shares, pages: totalPages, pageNumber, pageSize };
|
||||
} catch (error) {
|
||||
logger.error('[getShareByPage] Error getting shares', error);
|
||||
throw new Error('Error getting shares');
|
||||
}
|
||||
},
|
||||
|
||||
createSharedLink: async (user, { conversationId, ...shareData }) => {
|
||||
try {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (share) {
|
||||
return share;
|
||||
}
|
||||
|
||||
const shareId = crypto.randomUUID();
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...shareData, shareId, messages, user };
|
||||
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: true,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[createSharedLink] Error creating shared link', error);
|
||||
throw new Error('Error creating shared link');
|
||||
}
|
||||
},
|
||||
|
||||
updateSharedLink: async (user, { conversationId, ...shareData }) => {
|
||||
try {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
}
|
||||
|
||||
// update messages to the latest
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...shareData, messages, user };
|
||||
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[updateSharedLink] Error updating shared link', error);
|
||||
throw new Error('Error updating shared link');
|
||||
}
|
||||
},
|
||||
|
||||
deleteSharedLink: async (user, { shareId }) => {
|
||||
try {
|
||||
const share = await SharedLink.findOne({ shareId, user });
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
}
|
||||
return await SharedLink.findOneAndDelete({ shareId, user });
|
||||
} catch (error) {
|
||||
logger.error('[deleteSharedLink] Error deleting shared link', error);
|
||||
throw new Error('Error deleting shared link');
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Deletes all shared links for a specific user.
|
||||
* @param {string} user - The user ID.
|
||||
* @returns {Promise<{ message: string, deletedCount?: number }>} A result object indicating success or error message.
|
||||
*/
|
||||
deleteAllSharedLinks: async (user) => {
|
||||
try {
|
||||
const result = await SharedLink.deleteMany({ user });
|
||||
return {
|
||||
message: 'All shared links have been deleted successfully',
|
||||
deletedCount: result.deletedCount,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllSharedLinks] Error deleting shared links', error);
|
||||
throw new Error('Error deleting shared links');
|
||||
}
|
||||
},
|
||||
};
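A hedged usage sketch for the sharing helpers above; the calling context and function name are invented, but the signatures and the isPublic requirement come from the methods themselves:

// Illustrative only: create a public share for a conversation, then resolve it
// the way a public viewer would.
const { createSharedLink, getSharedMessages } = require('~/models/Share');

async function shareConversation(userId, conversationId, title) {
  const share = await createSharedLink(userId, { conversationId, title, isPublic: true });
  // getSharedMessages returns null unless the link exists and isPublic is true.
  return getSharedMessages(share.shareId);
}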
|
||||
@@ -1,61 +1,5 @@
|
||||
const mongoose = require('mongoose');
|
||||
const bcrypt = require('bcryptjs');
|
||||
const signPayload = require('../server/services/signPayload');
|
||||
const userSchema = require('./schema/userSchema.js');
|
||||
const { SESSION_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
|
||||
|
||||
userSchema.methods.toJSON = function () {
|
||||
return {
|
||||
id: this._id,
|
||||
provider: this.provider,
|
||||
email: this.email,
|
||||
name: this.name,
|
||||
username: this.username,
|
||||
avatar: this.avatar,
|
||||
role: this.role,
|
||||
emailVerified: this.emailVerified,
|
||||
plugins: this.plugins,
|
||||
createdAt: this.createdAt,
|
||||
updatedAt: this.updatedAt,
|
||||
};
|
||||
};
|
||||
|
||||
userSchema.methods.generateToken = async function () {
|
||||
return await signPayload({
|
||||
payload: {
|
||||
id: this._id,
|
||||
username: this.username,
|
||||
provider: this.provider,
|
||||
email: this.email,
|
||||
},
|
||||
secret: process.env.JWT_SECRET,
|
||||
expirationTime: expires / 1000,
|
||||
});
|
||||
};
|
||||
|
||||
userSchema.methods.comparePassword = function (candidatePassword, callback) {
|
||||
bcrypt.compare(candidatePassword, this.password, (err, isMatch) => {
|
||||
if (err) {
|
||||
return callback(err);
|
||||
}
|
||||
callback(null, isMatch);
|
||||
});
|
||||
};
|
||||
|
||||
module.exports.hashPassword = async (password) => {
|
||||
const hashedPassword = await new Promise((resolve, reject) => {
|
||||
bcrypt.hash(password, 10, function (err, hash) {
|
||||
if (err) {
|
||||
reject(err);
|
||||
} else {
|
||||
resolve(hash);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return hashedPassword;
|
||||
};
|
||||
const userSchema = require('~/models/schema/userSchema');
|
||||
|
||||
const User = mongoose.model('User', userSchema);
|
||||
|
||||
|
||||
@@ -6,9 +6,18 @@ const {
|
||||
deleteMessagesSince,
|
||||
deleteMessages,
|
||||
} = require('./Message');
|
||||
const {
|
||||
comparePassword,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
getUserById,
|
||||
updateUser,
|
||||
createUser,
|
||||
countUsers,
|
||||
findUser,
|
||||
} = require('./userMethods');
|
||||
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
|
||||
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
||||
const { hashPassword, getUser, updateUser } = require('./userMethods');
|
||||
const {
|
||||
findFileById,
|
||||
createFile,
|
||||
@@ -29,9 +38,14 @@ module.exports = {
|
||||
Session,
|
||||
Balance,
|
||||
|
||||
hashPassword,
|
||||
comparePassword,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
getUserById,
|
||||
countUsers,
|
||||
createUser,
|
||||
updateUser,
|
||||
getUser,
|
||||
findUser,
|
||||
|
||||
getMessages,
|
||||
saveMessage,
|
||||
|
||||
@@ -155,7 +155,7 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) {
|
||||
function (results, value, key) {
|
||||
return { ...results, [key]: 1 };
|
||||
},
|
||||
{ _id: 1 },
|
||||
{ _id: 1, __v: 1 },
|
||||
),
|
||||
).lean();
|
||||
|
||||
|
||||
19
api/models/schema/categories.js
Normal file
@@ -0,0 +1,19 @@
|
||||
const mongoose = require('mongoose');
|
||||
const Schema = mongoose.Schema;
|
||||
|
||||
const categoriesSchema = new Schema({
|
||||
label: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
},
|
||||
value: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
},
|
||||
});
|
||||
|
||||
const categories = mongoose.model('categories', categoriesSchema);
|
||||
|
||||
module.exports = { Categories: categories };
|
||||
31
api/models/schema/conversationTagSchema.js
Normal file
@@ -0,0 +1,31 @@
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
const conversationTagSchema = mongoose.Schema(
|
||||
{
|
||||
tag: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
user: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
description: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
count: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
position: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
conversationTagSchema.index({ tag: 1, user: 1 }, { unique: true });
|
||||
|
||||
module.exports = mongoose.model('ConversationTag', conversationTagSchema);
|
||||
@@ -42,6 +42,11 @@ const convoSchema = mongoose.Schema(
|
||||
invocationId: {
|
||||
type: Number,
|
||||
},
|
||||
tags: {
|
||||
type: [String],
|
||||
default: [],
|
||||
meiliIndex: true,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
@@ -103,6 +103,10 @@ const conversationPreset = {
|
||||
spec: {
|
||||
type: String,
|
||||
},
|
||||
tags: {
|
||||
type: [String],
|
||||
default: [],
|
||||
},
|
||||
tools: { type: [{ type: String }], default: undefined },
|
||||
maxContextTokens: {
|
||||
type: Number,
|
||||
|
||||
@@ -3,9 +3,9 @@ const mongoose = require('mongoose');
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoFile
|
||||
* @property {mongoose.Schema.Types.ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {number} [__v] - MongoDB Version Key
|
||||
* @property {mongoose.Schema.Types.ObjectId} user - User ID
|
||||
* @property {ObjectId} user - User ID
|
||||
* @property {string} [conversationId] - Optional conversation ID
|
||||
* @property {string} file_id - File identifier
|
||||
* @property {string} [temp_file_id] - Temporary File identifier
|
||||
@@ -14,17 +14,19 @@ const mongoose = require('mongoose');
|
||||
* @property {string} filepath - Location of the file
|
||||
* @property {'file'} object - Type of object, always 'file'
|
||||
* @property {string} type - Type of file
|
||||
* @property {number} usage - Number of uses of the file
|
||||
* @property {number} [usage=0] - Number of uses of the file
|
||||
* @property {string} [context] - Context of the file origin
|
||||
* @property {boolean} [embedded] - Whether or not the file is embedded in vector db
|
||||
* @property {boolean} [embedded=false] - Whether or not the file is embedded in vector db
|
||||
* @property {string} [model] - The model to identify the group region of the file (for Azure OpenAI hosting)
|
||||
* @property {string} [source] - The source of the file
|
||||
* @property {string} [source] - The source of the file (e.g., from FileSources)
|
||||
* @property {number} [width] - Optional width of the file
|
||||
* @property {number} [height] - Optional height of the file
|
||||
* @property {Date} [expiresAt] - Optional height of the file
|
||||
* @property {Date} [expiresAt] - Optional expiration date of the file
|
||||
* @property {Date} [createdAt] - Date when the file was created
|
||||
* @property {Date} [updatedAt] - Date when the file was updated
|
||||
*/
|
||||
|
||||
/** @type {MongooseSchema<MongoFile>} */
|
||||
const fileSchema = mongoose.Schema(
|
||||
{
|
||||
user: {
|
||||
@@ -91,7 +93,7 @@ const fileSchema = mongoose.Schema(
|
||||
height: Number,
|
||||
expiresAt: {
|
||||
type: Date,
|
||||
expires: 3600,
|
||||
expires: 3600, // 1 hour in seconds
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -11,6 +11,7 @@ const messageSchema = mongoose.Schema(
|
||||
},
|
||||
conversationId: {
|
||||
type: String,
|
||||
index: true,
|
||||
required: true,
|
||||
meiliIndex: true,
|
||||
},
|
||||
@@ -128,6 +129,7 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
|
||||
}
|
||||
|
||||
messageSchema.index({ createdAt: 1 });
|
||||
messageSchema.index({ messageId: 1, user: 1 }, { unique: true });
|
||||
|
||||
const Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
|
||||
|
||||
|
||||
30
api/models/schema/projectSchema.js
Normal file
@@ -0,0 +1,30 @@
|
||||
const { Schema } = require('mongoose');
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoProject
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} name - The name of the project
|
||||
* @property {ObjectId[]} promptGroupIds - Array of PromptGroup IDs associated with the project
|
||||
* @property {Date} [createdAt] - Date when the project was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the project was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
const projectSchema = new Schema(
|
||||
{
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
promptGroupIds: {
|
||||
type: [Schema.Types.ObjectId],
|
||||
ref: 'PromptGroup',
|
||||
default: [],
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = projectSchema;
|
||||
118
api/models/schema/promptSchema.js
Normal file
@@ -0,0 +1,118 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const Schema = mongoose.Schema;
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoPromptGroup
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} name - The name of the prompt group
|
||||
* @property {ObjectId} author - The author of the prompt group
|
||||
* @property {ObjectId} [projectId=null] - The project ID of the prompt group
|
||||
* @property {ObjectId} [productionId=null] - The ID of the prompt group's production prompt
|
||||
* @property {string} authorName - The name of the author of the prompt group
|
||||
* @property {number} [numberOfGenerations=0] - Number of generations the prompt group has
|
||||
* @property {string} [oneliner=''] - Oneliner description of the prompt group
|
||||
* @property {string} [category=''] - Category of the prompt group
|
||||
* @property {string} [command] - Command for the prompt group
|
||||
* @property {Date} [createdAt] - Date when the prompt group was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the prompt group was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
const promptGroupSchema = new Schema(
|
||||
{
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
numberOfGenerations: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
oneliner: {
|
||||
type: String,
|
||||
default: '',
|
||||
},
|
||||
category: {
|
||||
type: String,
|
||||
default: '',
|
||||
index: true,
|
||||
},
|
||||
projectIds: {
|
||||
type: [Schema.Types.ObjectId],
|
||||
ref: 'Project',
|
||||
index: true,
|
||||
},
|
||||
productionId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'Prompt',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
author: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
authorName: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
command: {
|
||||
type: String,
|
||||
index: true,
|
||||
validate: {
|
||||
validator: function (v) {
|
||||
return v === undefined || v === null || v === '' || /^[a-z0-9-]+$/.test(v);
|
||||
},
|
||||
message: (props) =>
|
||||
`${props.value} is not a valid command. Only lowercase alphanumeric characters and hyphens (-) are allowed.`,
|
||||
},
|
||||
maxlength: [
|
||||
Constants.COMMANDS_MAX_LENGTH,
|
||||
`Command cannot be longer than ${Constants.COMMANDS_MAX_LENGTH} characters`,
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema);
|
||||
|
||||
const promptSchema = new Schema(
|
||||
{
|
||||
groupId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'PromptGroup',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
author: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
},
|
||||
prompt: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
type: {
|
||||
type: String,
|
||||
enum: ['text', 'chat'],
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
const Prompt = mongoose.model('Prompt', promptSchema);
|
||||
|
||||
promptSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
promptGroupSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
|
||||
module.exports = { Prompt, PromptGroup };
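The `command` field above only accepts empty values or lowercase alphanumerics and hyphens, capped at `Constants.COMMANDS_MAX_LENGTH`; a small sketch of the validator's behavior (the sample command strings are made up):

// Illustrative only: the same predicate the schema's validator applies.
const isValidCommand = (v) =>
  v === undefined || v === null || v === '' || /^[a-z0-9-]+$/.test(v);

isValidCommand('summarize-email'); // true  — lowercase letters and hyphens
isValidCommand('Summarize');       // false — uppercase is rejected
isValidCommand('notes_v2');        // false — underscores are not allowed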
|
||||
29
api/models/schema/roleSchema.js
Normal file
@@ -0,0 +1,29 @@
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
const roleSchema = new mongoose.Schema({
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
index: true,
|
||||
},
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
[Permissions.USE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
[Permissions.CREATE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const Role = mongoose.model('Role', roleSchema);
|
||||
|
||||
module.exports = Role;
|
||||
38
api/models/schema/shareSchema.js
Normal file
@@ -0,0 +1,38 @@
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
const shareSchema = mongoose.Schema(
|
||||
{
|
||||
conversationId: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
title: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
user: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }],
|
||||
shareId: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
isPublic: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
isVisible: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
isAnonymous: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
module.exports = mongoose.model('SharedLink', shareSchema);
|
||||
@@ -7,6 +7,9 @@ const tokenSchema = new Schema({
|
||||
required: true,
|
||||
ref: 'user',
|
||||
},
|
||||
email: {
|
||||
type: String,
|
||||
},
|
||||
token: {
|
||||
type: String,
|
||||
required: true,
|
||||
|
||||
@@ -1,5 +1,35 @@
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoSession
|
||||
* @property {string} [refreshToken] - The refresh token
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoUser
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} [name] - The user's name
|
||||
* @property {string} [username] - The user's username, in lowercase
|
||||
* @property {string} email - The user's email address
|
||||
* @property {boolean} emailVerified - Whether the user's email is verified
|
||||
* @property {string} [password] - The user's password, trimmed with 8-128 characters
|
||||
* @property {string} [avatar] - The URL of the user's avatar
|
||||
* @property {string} provider - The provider of the user's account (e.g., 'local', 'google')
|
||||
* @property {string} [role='USER'] - The role of the user
|
||||
* @property {string} [googleId] - Optional Google ID for the user
|
||||
* @property {string} [facebookId] - Optional Facebook ID for the user
|
||||
* @property {string} [openidId] - Optional OpenID ID for the user
|
||||
* @property {string} [ldapId] - Optional LDAP ID for the user
|
||||
* @property {string} [githubId] - Optional GitHub ID for the user
|
||||
* @property {string} [discordId] - Optional Discord ID for the user
|
||||
* @property {Array} [plugins=[]] - List of plugins used by the user
|
||||
* @property {Array.<MongoSession>} [refreshToken] - List of sessions with refresh tokens
|
||||
* @property {Date} [expiresAt] - Optional expiration date of the file
|
||||
* @property {Date} [createdAt] - Date when the user was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the user was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
/** @type {MongooseSchema<MongoSession>} */
|
||||
const Session = mongoose.Schema({
|
||||
refreshToken: {
|
||||
type: String,
|
||||
@@ -7,6 +37,7 @@ const Session = mongoose.Schema({
|
||||
},
|
||||
});
|
||||
|
||||
/** @type {MongooseSchema<MongoUser>} */
|
||||
const userSchema = mongoose.Schema(
|
||||
{
|
||||
name: {
|
||||
@@ -64,6 +95,11 @@ const userSchema = mongoose.Schema(
|
||||
unique: true,
|
||||
sparse: true,
|
||||
},
|
||||
ldapId: {
|
||||
type: String,
|
||||
unique: true,
|
||||
sparse: true,
|
||||
},
|
||||
githubId: {
|
||||
type: String,
|
||||
unique: true,
|
||||
@@ -81,6 +117,10 @@ const userSchema = mongoose.Schema(
|
||||
refreshToken: {
|
||||
type: [Session],
|
||||
},
|
||||
expiresAt: {
|
||||
type: Date,
|
||||
expires: 604800, // 7 days in seconds
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
@@ -40,7 +40,7 @@ const spendTokens = async (txData, tokenUsage) => {
|
||||
});
|
||||
}
|
||||
|
||||
if (!completionTokens) {
|
||||
if (!completionTokens && isNaN(completionTokens)) {
|
||||
logger.debug('[spendTokens] !completionTokens', { prompt, completion });
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -12,10 +12,13 @@ const tokenValues = {
|
||||
'4k': { prompt: 1.5, completion: 2 },
|
||||
'16k': { prompt: 3, completion: 4 },
|
||||
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
|
||||
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
|
||||
'gpt-4o': { prompt: 5, completion: 15 },
|
||||
'gpt-4-1106': { prompt: 10, completion: 30 },
|
||||
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
|
||||
'claude-3-opus': { prompt: 15, completion: 75 },
|
||||
'claude-3-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-5-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
|
||||
'claude-2.1': { prompt: 8, completion: 24 },
|
||||
'claude-2': { prompt: 8, completion: 24 },
|
||||
@@ -52,6 +55,10 @@ const getValueKey = (model, endpoint) => {
|
||||
return 'gpt-3.5-turbo-1106';
|
||||
} else if (modelName.includes('gpt-3.5')) {
|
||||
return '4k';
|
||||
} else if (modelName.includes('gpt-4o-mini')) {
|
||||
return 'gpt-4o-mini';
|
||||
} else if (modelName.includes('gpt-4o')) {
|
||||
return 'gpt-4o';
|
||||
} else if (modelName.includes('gpt-4-vision')) {
|
||||
return 'gpt-4-1106';
|
||||
} else if (modelName.includes('gpt-4-1106')) {
|
||||
|
||||
@@ -41,6 +41,26 @@ describe('getValueKey', () => {
|
||||
expect(getValueKey('gpt-4-turbo')).toBe('gpt-4-1106');
|
||||
expect(getValueKey('gpt-4-0125')).toBe('gpt-4-1106');
|
||||
});
|
||||
|
||||
it('should return "gpt-4o" for model type of "gpt-4o"', () => {
|
||||
expect(getValueKey('gpt-4o-2024-05-13')).toBe('gpt-4o');
|
||||
expect(getValueKey('openai/gpt-4o')).toBe('gpt-4o');
|
||||
expect(getValueKey('gpt-4o-turbo')).toBe('gpt-4o');
|
||||
expect(getValueKey('gpt-4o-0125')).toBe('gpt-4o');
|
||||
});
|
||||
|
||||
it('should return "gpt-4o-mini" for model type of "gpt-4o-mini"', () => {
|
||||
expect(getValueKey('gpt-4o-mini-2024-07-18')).toBe('gpt-4o-mini');
|
||||
expect(getValueKey('openai/gpt-4o-mini')).toBe('gpt-4o-mini');
|
||||
expect(getValueKey('gpt-4o-mini-0718')).toBe('gpt-4o-mini');
|
||||
});
|
||||
|
||||
it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => {
|
||||
expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('claude-3-5-sonnet-turbo')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('claude-3-5-sonnet-0125')).toBe('claude-3-5-sonnet');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMultiplier', () => {
|
||||
@@ -84,6 +104,30 @@ describe('getMultiplier', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4o', () => {
|
||||
const valueKey = getValueKey('gpt-4o-2024-05-13');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4o'].prompt);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-4o'].completion,
|
||||
);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).not.toBe(
|
||||
tokenValues['gpt-4-1106'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4o-mini', () => {
|
||||
const valueKey = getValueKey('gpt-4o-mini-2024-07-18');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-4o-mini'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-4o-mini'].completion,
|
||||
);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).not.toBe(
|
||||
tokenValues['gpt-4-1106'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should derive the valueKey from the model if not provided for new models', () => {
|
||||
expect(
|
||||
getMultiplier({ tokenType: 'prompt', model: 'gpt-3.5-turbo-1106-some-other-info' }),
|
||||
|
||||
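The new `gpt-4o` and `gpt-4o-mini` entries follow the same convention as the existing keys, which appear to be USD per million tokens (0.15 prompt / 0.6 completion matches OpenAI's published gpt-4o-mini pricing). A back-of-the-envelope check of what a single request would cost under these multipliers, assuming that per-1M interpretation:

// assuming multipliers are USD per 1M tokens
const promptCost = (1200 / 1e6) * 0.15; // 1,200 prompt tokens  -> $0.00018
const completionCost = (300 / 1e6) * 0.6; // 300 completion tokens -> $0.00018
console.log((promptCost + completionCost).toFixed(6)); // 0.000360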
@@ -1,28 +1,37 @@
|
||||
const bcrypt = require('bcryptjs');
|
||||
const signPayload = require('~/server/services/signPayload');
|
||||
const User = require('./User');
|
||||
|
||||
const hashPassword = async (password) => {
|
||||
const hashedPassword = await new Promise((resolve, reject) => {
|
||||
bcrypt.hash(password, 10, function (err, hash) {
|
||||
if (err) {
|
||||
reject(err);
|
||||
} else {
|
||||
resolve(hash);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return hashedPassword;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieve a user by ID and convert the found user document to a plain object.
|
||||
*
|
||||
* @param {string} userId - The ID of the user to find and return as a plain object.
|
||||
* @returns {Promise<Object>} A plain object representing the user document, or `null` if no user is found.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
|
||||
*/
|
||||
const getUser = async function (userId) {
|
||||
return await User.findById(userId).lean();
|
||||
const getUserById = async function (userId, fieldsToSelect = null) {
|
||||
const query = User.findById(userId);
|
||||
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Search for a single user based on partial data and return matching user document as plain object.
|
||||
* @param {Partial<MongoUser>} searchCriteria - The partial data to use for searching the user.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoUser>} A plain object representing the user document, or `null` if no user is found.
|
||||
*/
|
||||
const findUser = async function (searchCriteria, fieldsToSelect = null) {
|
||||
const query = User.findOne(searchCriteria);
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -30,17 +39,127 @@ const getUser = async function (userId) {
|
||||
*
|
||||
* @param {string} userId - The ID of the user to update.
|
||||
* @param {Object} updateData - An object containing the properties to update.
|
||||
* @returns {Promise<Object>} The updated user document as a plain object, or `null` if no user is found.
|
||||
* @returns {Promise<MongoUser>} The updated user document as a plain object, or `null` if no user is found.
|
||||
*/
|
||||
const updateUser = async function (userId, updateData) {
|
||||
return await User.findByIdAndUpdate(userId, updateData, {
|
||||
const updateOperation = {
|
||||
$set: updateData,
|
||||
$unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL
|
||||
};
|
||||
return await User.findByIdAndUpdate(userId, updateOperation, {
|
||||
new: true,
|
||||
runValidators: true,
|
||||
}).lean();
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
hashPassword,
|
||||
updateUser,
|
||||
getUser,
|
||||
/**
|
||||
* Creates a new user, optionally with a TTL of 1 week.
|
||||
* @param {MongoUser} data - The user data to be created, must contain user_id.
|
||||
* @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`.
|
||||
* @param {boolean} [returnUser=false] - Whether to return the created user document. Defaults to `false`.
|
||||
* @returns {Promise<ObjectId>} A promise that resolves to the created user document ID.
|
||||
* @throws {Error} If a user with the same user_id already exists.
|
||||
*/
|
||||
const createUser = async (data, disableTTL = true, returnUser = false) => {
|
||||
const userData = {
|
||||
...data,
|
||||
expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds
|
||||
};
|
||||
|
||||
if (disableTTL) {
|
||||
delete userData.expiresAt;
|
||||
}
|
||||
|
||||
const user = await User.create(userData);
|
||||
if (returnUser) {
|
||||
return user.toObject();
|
||||
}
|
||||
return user._id;
|
||||
};
|
||||
|
||||
/**
|
||||
* Count the number of user documents in the collection based on the provided filter.
|
||||
*
|
||||
* @param {Object} [filter={}] - The filter to apply when counting the documents.
|
||||
* @returns {Promise<number>} The count of documents that match the filter.
|
||||
*/
|
||||
const countUsers = async function (filter = {}) {
|
||||
return await User.countDocuments(filter);
|
||||
};
|
||||
|
||||
/**
|
||||
* Delete a user by their unique ID.
|
||||
*
|
||||
* @param {string} userId - The ID of the user to delete.
|
||||
* @returns {Promise<{ deletedCount: number }>} An object indicating the number of deleted documents.
|
||||
*/
|
||||
const deleteUserById = async function (userId) {
|
||||
try {
|
||||
const result = await User.deleteOne({ _id: userId });
|
||||
if (result.deletedCount === 0) {
|
||||
return { deletedCount: 0, message: 'No user found with that ID.' };
|
||||
}
|
||||
return { deletedCount: result.deletedCount, message: 'User was deleted successfully.' };
|
||||
} catch (error) {
|
||||
throw new Error('Error deleting user: ' + error.message);
|
||||
}
|
||||
};
|
||||
|
||||
const { SESSION_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
|
||||
|
||||
/**
|
||||
* Generates a JWT token for a given user.
|
||||
*
|
||||
* @param {MongoUser} user - ID of the user for whom the token is being generated.
|
||||
* @returns {Promise<string>} A promise that resolves to a JWT token.
|
||||
*/
|
||||
const generateToken = async (user) => {
|
||||
if (!user) {
|
||||
throw new Error('No user provided');
|
||||
}
|
||||
|
||||
return await signPayload({
|
||||
payload: {
|
||||
id: user._id,
|
||||
username: user.username,
|
||||
provider: user.provider,
|
||||
email: user.email,
|
||||
},
|
||||
secret: process.env.JWT_SECRET,
|
||||
expirationTime: expires / 1000,
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Compares the provided password with the user's password.
|
||||
*
|
||||
* @param {MongoUser} user - the user to compare password for.
|
||||
* @param {string} candidatePassword - The password to test against the user's password.
|
||||
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the password matches.
|
||||
*/
|
||||
const comparePassword = async (user, candidatePassword) => {
|
||||
if (!user) {
|
||||
throw new Error('No user provided');
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
bcrypt.compare(candidatePassword, user.password, (err, isMatch) => {
|
||||
if (err) {
|
||||
reject(err);
|
||||
}
|
||||
resolve(isMatch);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
comparePassword,
|
||||
deleteUserById,
|
||||
generateToken,
|
||||
getUserById,
|
||||
countUsers,
|
||||
createUser,
|
||||
updateUser,
|
||||
findUser,
|
||||
};
|
||||
|
||||
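Taken together, the reworked user methods give callers field selection and an explicit TTL: `createUser` can leave `expiresAt` set for unverified accounts, and `updateUser` now `$unset`s it so a touched account never expires. A minimal sketch of the intended call pattern; the require path and the surrounding async context are assumptions:

const { createUser, getUserById, updateUser, deleteUserById } = require('~/models');

// create an unverified user with the 1-week TTL, then fetch it without sensitive fields
const userId = await createUser({ email: 'new@example.com', provider: 'local' }, /* disableTTL */ false);
const user = await getUserById(userId, '-password -__v');

// updating the account (e.g., marking it verified) also clears expiresAt via $unset
await updateUser(userId, { emailVerified: true });

// removal path used by deleteUserController
await deleteUserById(userId);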
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "0.7.2",
|
||||
"version": "0.7.4-rc1",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
@@ -40,8 +40,7 @@
|
||||
"@keyv/redis": "^2.8.1",
|
||||
"@langchain/community": "^0.0.46",
|
||||
"@langchain/google-genai": "^0.0.11",
|
||||
"@langchain/google-vertexai": "^0.0.5",
|
||||
"agenda": "^5.0.0",
|
||||
"@langchain/google-vertexai": "^0.0.17",
|
||||
"axios": "^1.3.4",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"cheerio": "^1.0.0-rc.12",
|
||||
@@ -76,7 +75,7 @@
|
||||
"nodejs-gpt": "^1.37.4",
|
||||
"nodemailer": "^6.9.4",
|
||||
"ollama": "^0.5.0",
|
||||
"openai": "4.36.0",
|
||||
"openai": "^4.47.1",
|
||||
"openai-chat-tokens": "^0.2.8",
|
||||
"openid-client": "^5.4.2",
|
||||
"passport": "^0.6.0",
|
||||
@@ -86,14 +85,16 @@
|
||||
"passport-github2": "^0.1.12",
|
||||
"passport-google-oauth20": "^2.0.0",
|
||||
"passport-jwt": "^4.0.1",
|
||||
"passport-ldapauth": "^3.0.1",
|
||||
"passport-local": "^1.0.0",
|
||||
"pino": "^8.12.1",
|
||||
"sharp": "^0.32.6",
|
||||
"tiktoken": "^1.0.10",
|
||||
"tiktoken": "^1.0.15",
|
||||
"traverse": "^0.6.7",
|
||||
"ua-parser-js": "^1.0.36",
|
||||
"winston": "^3.11.0",
|
||||
"winston-daily-rotate-file": "^4.7.1",
|
||||
"ws": "^8.17.0",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
const throttle = require('lodash/throttle');
|
||||
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
|
||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
@@ -18,6 +19,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
logger.debug('[AskController]', { text, conversationId, ...endpointOption });
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
@@ -34,6 +36,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
@@ -48,11 +52,13 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
|
||||
try {
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||
onProgress: throttle(
|
||||
({ text: partialText }) => {
|
||||
saveMessage({
|
||||
/*
|
||||
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
|
||||
messageCache.set(responseMessageId, {
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
@@ -62,7 +68,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
unfinished,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
}, Time.FIVE_MINUTES);
|
||||
*/
|
||||
|
||||
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
|
||||
},
|
||||
3000,
|
||||
{ trailing: false },
|
||||
@@ -74,6 +83,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
const getAbortData = () => ({
|
||||
sender,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
@@ -81,7 +91,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
res.on('close', () => {
|
||||
logger.debug('[AskController] Request closed');
|
||||
@@ -105,11 +115,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
getReqData,
|
||||
onStart,
|
||||
abortController,
|
||||
onProgress: progressCallback.call(null, {
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
}),
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
};
|
||||
|
||||
let response = await client.sendMessage(text, messageOptions);
|
||||
@@ -120,7 +130,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
|
||||
response.endpoint = endpointOption.endpoint;
|
||||
|
||||
const conversation = await getConvo(user, conversationId);
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
@@ -140,10 +150,18 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...response, user },
|
||||
{ context: 'api/server/controllers/AskController.js - response end' },
|
||||
);
|
||||
}
|
||||
|
||||
await saveMessage(userMessage);
|
||||
if (!client.skipSaveUserMessage) {
|
||||
await saveMessage(req, userMessage, {
|
||||
context: 'api/server/controllers/AskController.js - don\'t skip saving user message',
|
||||
});
|
||||
}
|
||||
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
addTitle(req, {
|
||||
|
||||
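The net effect of these AskController changes: partial responses no longer hit Mongo on every throttle tick; they go to the message cache keyed by `responseMessageId` and expire after five minutes, and the conversation now comes back from `client.responsePromise` instead of a separate `getConvo` call. Roughly, the cached partial can then be read back elsewhere; the read path shown here is an assumption:

const { CacheKeys, Time } = require('librechat-data-provider');
const { getLogStores } = require('~/cache');

const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);

// e.g., an abort handler recovering the partial response later
const partial = await messageCache.get(responseMessageId);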
@@ -1,45 +1,29 @@
|
||||
const crypto = require('crypto');
|
||||
const cookies = require('cookie');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { Session, User } = require('~/models');
|
||||
const {
|
||||
registerUser,
|
||||
resetPassword,
|
||||
setAuthTokens,
|
||||
requestPasswordReset,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { Session, getUserById } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
const response = await registerUser(req.body);
|
||||
if (response.status === 200) {
|
||||
const { status, user } = response;
|
||||
let newUser = await User.findOne({ _id: user._id });
|
||||
if (!newUser) {
|
||||
newUser = new User(user);
|
||||
await newUser.save();
|
||||
}
|
||||
const token = await setAuthTokens(user._id, res);
|
||||
res.setHeader('Authorization', `Bearer ${token}`);
|
||||
res.status(status).send({ user });
|
||||
} else {
|
||||
const { status, message } = response;
|
||||
res.status(status).send({ message });
|
||||
}
|
||||
const { status, message } = response;
|
||||
res.status(status).send({ message });
|
||||
} catch (err) {
|
||||
logger.error('[registrationController]', err);
|
||||
return res.status(500).json({ message: err.message });
|
||||
}
|
||||
};
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
return res.status(200).send(req.user);
|
||||
};
|
||||
|
||||
const resetPasswordRequestController = async (req, res) => {
|
||||
try {
|
||||
const resetService = await requestPasswordReset(req.body.email);
|
||||
const resetService = await requestPasswordReset(req);
|
||||
if (resetService instanceof Error) {
|
||||
return res.status(400).json(resetService);
|
||||
} else {
|
||||
@@ -77,7 +61,7 @@ const refreshController = async (req, res) => {
|
||||
|
||||
try {
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
const user = await User.findOne({ _id: payload.id });
|
||||
const user = await getUserById(payload.id, '-password -__v');
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
@@ -86,8 +70,7 @@ const refreshController = async (req, res) => {
|
||||
|
||||
if (process.env.NODE_ENV === 'CI') {
|
||||
const token = await setAuthTokens(userId, res);
|
||||
const userObj = user.toJSON();
|
||||
return res.status(200).send({ token, user: userObj });
|
||||
return res.status(200).send({ token, user });
|
||||
}
|
||||
|
||||
// Hash the refresh token
|
||||
@@ -98,8 +81,7 @@ const refreshController = async (req, res) => {
|
||||
const session = await Session.findOne({ user: userId, refreshTokenHash: hashedToken });
|
||||
if (session && session.expiration > new Date()) {
|
||||
const token = await setAuthTokens(userId, res, session._id);
|
||||
const userObj = user.toJSON();
|
||||
res.status(200).send({ token, user: userObj });
|
||||
res.status(200).send({ token, user });
|
||||
} else if (req?.query?.retry) {
|
||||
// Retrying from a refresh token request that failed (401)
|
||||
res.status(403).send('No session found');
|
||||
@@ -115,7 +97,6 @@ const refreshController = async (req, res) => {
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getUserController,
|
||||
refreshController,
|
||||
registrationController,
|
||||
resetPasswordController,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
const throttle = require('lodash/throttle');
|
||||
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
|
||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const EditController = async (req, res, next, initializeClient) => {
|
||||
@@ -27,6 +28,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
});
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
@@ -40,6 +42,8 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
@@ -48,12 +52,14 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
}
|
||||
};
|
||||
|
||||
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||
generation,
|
||||
onProgress: throttle(
|
||||
({ text: partialText }) => {
|
||||
saveMessage({
|
||||
/*
|
||||
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
|
||||
{
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
@@ -64,7 +70,8 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
isEdited: true,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
} */
|
||||
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
|
||||
},
|
||||
3000,
|
||||
{ trailing: false },
|
||||
@@ -73,6 +80,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
|
||||
const getAbortData = () => ({
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
@@ -81,7 +89,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
res.on('close', () => {
|
||||
logger.debug('[EditController] Request closed');
|
||||
@@ -112,14 +120,14 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
getReqData,
|
||||
onStart,
|
||||
abortController,
|
||||
onProgress: progressCallback.call(null, {
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
}),
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
});
|
||||
|
||||
const conversation = await getConvo(user, conversationId);
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
@@ -137,7 +145,11 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...response, user },
|
||||
{ context: 'api/server/controllers/EditController.js - response end' },
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
|
||||
@@ -16,10 +16,28 @@ async function endpointController(req, res) {
|
||||
/** @type {TEndpointsConfig} */
|
||||
const mergedConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints };
|
||||
if (mergedConfig[EModelEndpoint.assistants] && req.app.locals?.[EModelEndpoint.assistants]) {
|
||||
const { disableBuilder, retrievalModels, capabilities, ..._rest } =
|
||||
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
|
||||
req.app.locals[EModelEndpoint.assistants];
|
||||
|
||||
mergedConfig[EModelEndpoint.assistants] = {
|
||||
...mergedConfig[EModelEndpoint.assistants],
|
||||
version,
|
||||
retrievalModels,
|
||||
disableBuilder,
|
||||
capabilities,
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
mergedConfig[EModelEndpoint.azureAssistants] &&
|
||||
req.app.locals?.[EModelEndpoint.azureAssistants]
|
||||
) {
|
||||
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
|
||||
req.app.locals[EModelEndpoint.azureAssistants];
|
||||
|
||||
mergedConfig[EModelEndpoint.azureAssistants] = {
|
||||
...mergedConfig[EModelEndpoint.azureAssistants],
|
||||
version,
|
||||
retrievalModels,
|
||||
disableBuilder,
|
||||
capabilities,
|
||||
|
||||
@@ -55,24 +55,26 @@ const getAvailablePluginsController = async (req, res) => {
|
||||
return;
|
||||
}
|
||||
|
||||
/** @type {{ filteredTools: string[] }} */
|
||||
const { filteredTools = [] } = req.app.locals;
|
||||
|
||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||
const { filteredTools = [], includedTools = [] } = req.app.locals;
|
||||
const pluginManifest = await fs.readFile(req.app.locals.paths.pluginManifest, 'utf8');
|
||||
|
||||
const jsonData = JSON.parse(pluginManifest);
|
||||
/** @type {TPlugin[]} */
|
||||
|
||||
const uniquePlugins = filterUniquePlugins(jsonData);
|
||||
const authenticatedPlugins = uniquePlugins.map((plugin) => {
|
||||
if (isPluginAuthenticated(plugin)) {
|
||||
return { ...plugin, authenticated: true };
|
||||
} else {
|
||||
return plugin;
|
||||
}
|
||||
});
|
||||
let authenticatedPlugins = [];
|
||||
for (const plugin of uniquePlugins) {
|
||||
authenticatedPlugins.push(
|
||||
isPluginAuthenticated(plugin) ? { ...plugin, authenticated: true } : plugin,
|
||||
);
|
||||
}
|
||||
|
||||
let plugins = await addOpenAPISpecs(authenticatedPlugins);
|
||||
plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey));
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
plugins = plugins.filter((plugin) => includedTools.includes(plugin.pluginKey));
|
||||
} else {
|
||||
plugins = plugins.filter((plugin) => !filteredTools.includes(plugin.pluginKey));
|
||||
}
|
||||
|
||||
await cache.set(CacheKeys.PLUGINS, plugins);
|
||||
res.status(200).json(plugins);
|
||||
|
||||
@@ -1,11 +1,37 @@
|
||||
const { updateUserPluginsService } = require('~/server/services/UserService');
|
||||
const {
|
||||
Session,
|
||||
Balance,
|
||||
getFiles,
|
||||
deleteFiles,
|
||||
deleteConvos,
|
||||
deletePresets,
|
||||
deleteMessages,
|
||||
deleteUserById,
|
||||
} = require('~/models');
|
||||
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
|
||||
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
|
||||
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
const { deleteAllSharedLinks } = require('~/models/Share');
|
||||
const { Transaction } = require('~/models/Transaction');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const getUserController = async (req, res) => {
|
||||
res.status(200).send(req.user);
|
||||
};
|
||||
|
||||
const deleteUserFiles = async (req) => {
|
||||
try {
|
||||
const userFiles = await getFiles({ user: req.user.id });
|
||||
await processDeleteRequest({
|
||||
req,
|
||||
files: userFiles,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[deleteUserFiles]', error);
|
||||
}
|
||||
};
|
||||
|
||||
const updateUserPluginsController = async (req, res) => {
|
||||
const { user } = req;
|
||||
const { pluginKey, action, auth, isAssistantTool } = req.body;
|
||||
@@ -49,11 +75,68 @@ const updateUserPluginsController = async (req, res) => {
|
||||
res.status(200).send();
|
||||
} catch (err) {
|
||||
logger.error('[updateUserPluginsController]', err);
|
||||
res.status(500).json({ message: err.message });
|
||||
return res.status(500).json({ message: 'Something went wrong.' });
|
||||
}
|
||||
};
|
||||
|
||||
const deleteUserController = async (req, res) => {
|
||||
const { user } = req;
|
||||
|
||||
try {
|
||||
await deleteMessages({ user: user.id }); // delete user messages
|
||||
await Session.deleteMany({ user: user.id }); // delete user sessions
|
||||
await Transaction.deleteMany({ user: user.id }); // delete user transactions
|
||||
await deleteUserKey({ userId: user.id, all: true }); // delete user keys
|
||||
await Balance.deleteMany({ user: user._id }); // delete user balances
|
||||
await deletePresets(user.id); // delete user presets
|
||||
/* TODO: Delete Assistant Threads */
|
||||
await deleteConvos(user.id); // delete user convos
|
||||
await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth
|
||||
await deleteUserById(user.id); // delete user
|
||||
await deleteAllSharedLinks(user.id); // delete user shared links
|
||||
await deleteUserFiles(req); // delete user files
|
||||
await deleteFiles(null, user.id); // delete database files in case of orphaned files from previous steps
|
||||
/* TODO: queue job for cleaning actions and assistants of non-existent users */
|
||||
logger.info(`User deleted account. Email: ${user.email} ID: ${user.id}`);
|
||||
res.status(200).send({ message: 'User deleted' });
|
||||
} catch (err) {
|
||||
logger.error('[deleteUserController]', err);
|
||||
return res.status(500).json({ message: 'Something went wrong.' });
|
||||
}
|
||||
};
|
||||
|
||||
const verifyEmailController = async (req, res) => {
|
||||
try {
|
||||
const verifyEmailService = await verifyEmail(req);
|
||||
if (verifyEmailService instanceof Error) {
|
||||
return res.status(400).json(verifyEmailService);
|
||||
} else {
|
||||
return res.status(200).json(verifyEmailService);
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error('[verifyEmailController]', e);
|
||||
return res.status(500).json({ message: 'Something went wrong.' });
|
||||
}
|
||||
};
|
||||
|
||||
const resendVerificationController = async (req, res) => {
|
||||
try {
|
||||
const result = await resendVerificationEmail(req);
|
||||
if (result instanceof Error) {
|
||||
return res.status(400).json(result);
|
||||
} else {
|
||||
return res.status(200).json(result);
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error('[verifyEmailController]', e);
|
||||
return res.status(500).json({ message: 'Something went wrong.' });
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getUserController,
|
||||
deleteUserController,
|
||||
verifyEmailController,
|
||||
updateUserPluginsController,
|
||||
resendVerificationController,
|
||||
};
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
const { v4 } = require('uuid');
|
||||
const express = require('express');
|
||||
const {
|
||||
Constants,
|
||||
RunStatus,
|
||||
CacheKeys,
|
||||
FileSources,
|
||||
ContentTypes,
|
||||
EModelEndpoint,
|
||||
ViolationTypes,
|
||||
ImageVisionTool,
|
||||
checkOpenAIStorage,
|
||||
AssistantStreamEvents,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
@@ -21,44 +20,36 @@ const {
|
||||
} = require('~/server/services/Threads');
|
||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateModel,
|
||||
handleAbortError,
|
||||
// validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
||||
router.post('/abort', handleAbort());
|
||||
|
||||
const ten_minutes = 1000 * 60 * 10;
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
* @desc Chat with an assistant
|
||||
* @access Public
|
||||
* @param {express.Request} req - The request object, containing the request data.
|
||||
* @param {express.Response} res - The response object, used to send back a response.
|
||||
* @param {object} req - The request object, containing the request data.
|
||||
* @param {object} req.body - The request payload.
|
||||
* @param {Express.Response} res - The response object, used to send back a response.
|
||||
* @returns {void}
|
||||
*/
|
||||
router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res) => {
|
||||
const chatV1 = async (req, res) => {
|
||||
logger.debug('[/assistants/chat/] req.body', req.body);
|
||||
|
||||
const {
|
||||
text,
|
||||
model,
|
||||
endpoint,
|
||||
files = [],
|
||||
promptPrefix,
|
||||
assistant_id,
|
||||
@@ -69,30 +60,6 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
parentMessageId: _parentId = Constants.NO_PARENT,
|
||||
} = req.body;
|
||||
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals?.[EModelEndpoint.assistants];
|
||||
|
||||
if (assistantsConfig) {
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
const error = { message: 'Assistant not supported' };
|
||||
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
error,
|
||||
});
|
||||
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {OpenAIClient} */
|
||||
let openai;
|
||||
/** @type {string|undefined} - the current thread id */
|
||||
@@ -138,7 +105,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
user: req.user.id,
|
||||
shouldSaveMessage: false,
|
||||
messageId: responseMessageId,
|
||||
endpoint: EModelEndpoint.assistants,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
if (error.message === 'Run cancelled') {
|
||||
@@ -149,25 +116,26 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
logger.debug('[/assistants/chat/] Request aborted on close');
|
||||
} else if (/Files.*are invalid/.test(error.message)) {
|
||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||
req.app.locals?.[EModelEndpoint.azureOpenAI].assistants
|
||||
endpoint === EModelEndpoint.azureAssistants
|
||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(res, messageData, errorMessage);
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
} else if (error?.message?.includes('string too long')) {
|
||||
return sendResponse(
|
||||
req,
|
||||
res,
|
||||
messageData,
|
||||
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
|
||||
);
|
||||
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
||||
return sendResponse(res, messageData, error.message);
|
||||
return sendResponse(req, res, messageData, error.message);
|
||||
} else {
|
||||
logger.error('[/assistants/chat/]', error);
|
||||
}
|
||||
|
||||
if (!openai || !thread_id || !run_id) {
|
||||
return sendResponse(res, messageData, defaultErrorMessage);
|
||||
return sendResponse(req, res, messageData, defaultErrorMessage);
|
||||
}
|
||||
|
||||
await sleep(2000);
|
||||
@@ -205,6 +173,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
const runMessages = await checkMessageGaps({
|
||||
openai,
|
||||
run_id,
|
||||
endpoint,
|
||||
thread_id,
|
||||
conversationId,
|
||||
latestMessageId: responseMessageId,
|
||||
@@ -253,10 +222,10 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/chat/] Error finalizing error process', error);
|
||||
return sendResponse(res, messageData, 'The Assistant run failed');
|
||||
return sendResponse(req, res, messageData, 'The Assistant run failed');
|
||||
}
|
||||
|
||||
return sendResponse(res, finalEvent);
|
||||
return sendResponse(req, res, finalEvent);
|
||||
};
|
||||
|
||||
try {
|
||||
@@ -311,8 +280,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
});
|
||||
};
|
||||
|
||||
/** @type {{ openai: OpenAIClient }} */
|
||||
const { openai: _openai, client } = await initializeClient({
|
||||
const { openai: _openai, client } = await getOpenAIClient({
|
||||
req,
|
||||
res,
|
||||
endpointOption: req.body.endpointOption,
|
||||
@@ -320,6 +288,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
});
|
||||
|
||||
openai = _openai;
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
if (previousMessages.length) {
|
||||
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
||||
@@ -370,10 +339,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
|
||||
/** @type {MongoFile[]} */
|
||||
const attachments = await req.body.endpointOption.attachments;
|
||||
if (
|
||||
attachments &&
|
||||
attachments.every((attachment) => attachment.source === FileSources.openai)
|
||||
) {
|
||||
if (attachments && attachments.every((attachment) => checkOpenAIStorage(attachment.source))) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -417,6 +383,9 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
return files;
|
||||
};
|
||||
|
||||
/** @type {Promise<Run>|undefined} */
|
||||
let userMessagePromise;
|
||||
|
||||
const initializeThread = async () => {
|
||||
/** @type {[ undefined | MongoFile[]]}*/
|
||||
const [processedFiles] = await Promise.all([addVisionPrompt(), getRequestFileIds()]);
|
||||
@@ -431,7 +400,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
|
||||
if (processedFiles) {
|
||||
for (const file of processedFiles) {
|
||||
if (file.source !== FileSources.openai) {
|
||||
if (!checkOpenAIStorage(file.source)) {
|
||||
attachedFileIds.delete(file.file_id);
|
||||
const index = file_ids.indexOf(file.file_id);
|
||||
if (index > -1) {
|
||||
@@ -467,16 +436,17 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
assistant_id,
|
||||
thread_id,
|
||||
model: assistant_id,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
previousMessages.push(requestMessage);
|
||||
|
||||
/* asynchronous */
|
||||
saveUserMessage({ ...requestMessage, model });
|
||||
userMessagePromise = saveUserMessage(req, { ...requestMessage, model });
|
||||
|
||||
conversation = {
|
||||
conversationId,
|
||||
endpoint: EModelEndpoint.assistants,
|
||||
endpoint,
|
||||
promptPrefix: promptPrefix,
|
||||
instructions: instructions,
|
||||
assistant_id,
|
||||
@@ -513,7 +483,8 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
let response;
|
||||
|
||||
const processRun = async (retry = false) => {
|
||||
if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
|
||||
if (endpoint === EModelEndpoint.azureAssistants) {
|
||||
body.model = openai._options.model;
|
||||
openai.attachedFileIds = attachedFileIds;
|
||||
openai.visionPromise = visionPromise;
|
||||
if (retry) {
|
||||
@@ -602,6 +573,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
assistant_id,
|
||||
thread_id,
|
||||
model: assistant_id,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
sendMessage(res, {
|
||||
@@ -614,7 +586,10 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveAssistantMessage({ ...responseMessage, model });
|
||||
if (userMessagePromise) {
|
||||
await userMessagePromise;
|
||||
}
|
||||
await saveAssistantMessage(req, { ...responseMessage, model });
|
||||
|
||||
if (parentMessageId === Constants.NO_PARENT && !_thread_id) {
|
||||
addTitle(req, {
|
||||
@@ -654,6 +629,6 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
|
||||
} catch (error) {
|
||||
await handleError(error);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = router;
|
||||
module.exports = chatV1;
|
||||
500  api/server/controllers/assistants/chatV2.js  Normal file
@@ -0,0 +1,500 @@
|
||||
const { v4 } = require('uuid');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
RunStatus,
|
||||
CacheKeys,
|
||||
ContentTypes,
|
||||
ToolCallTypes,
|
||||
EModelEndpoint,
|
||||
retrievalMimeTypes,
|
||||
AssistantStreamEvents,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
initThread,
|
||||
recordUsage,
|
||||
saveUserMessage,
|
||||
addThreadMetadata,
|
||||
saveAssistantMessage,
|
||||
} = require('~/server/services/Threads');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||
const { createErrorHandler } = require('~/server/controllers/assistants/errors');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const ten_minutes = 1000 * 60 * 10;
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
* @desc Chat with an assistant
|
||||
* @access Public
|
||||
* @param {Express.Request} req - The request object, containing the request data.
|
||||
* @param {Express.Response} res - The response object, used to send back a response.
|
||||
* @returns {void}
|
||||
*/
|
||||
const chatV2 = async (req, res) => {
|
||||
logger.debug('[/assistants/chat/] req.body', req.body);
|
||||
|
||||
/** @type {{files: MongoFile[]}} */
|
||||
const {
|
||||
text,
|
||||
model,
|
||||
endpoint,
|
||||
files = [],
|
||||
promptPrefix,
|
||||
assistant_id,
|
||||
instructions,
|
||||
thread_id: _thread_id,
|
||||
messageId: _messageId,
|
||||
conversationId: convoId,
|
||||
parentMessageId: _parentId = Constants.NO_PARENT,
|
||||
} = req.body;
|
||||
|
||||
/** @type {OpenAIClient} */
|
||||
let openai;
|
||||
/** @type {string|undefined} - the current thread id */
|
||||
let thread_id = _thread_id;
|
||||
/** @type {string|undefined} - the current run id */
|
||||
let run_id;
|
||||
/** @type {string|undefined} - the parent messageId */
|
||||
let parentMessageId = _parentId;
|
||||
/** @type {TMessage[]} */
|
||||
let previousMessages = [];
|
||||
/** @type {import('librechat-data-provider').TConversation | null} */
|
||||
let conversation = null;
|
||||
/** @type {string[]} */
|
||||
let file_ids = [];
|
||||
/** @type {Set<string>} */
|
||||
let attachedFileIds = new Set();
|
||||
/** @type {TMessage | null} */
|
||||
let requestMessage = null;
|
||||
|
||||
const userMessageId = v4();
|
||||
const responseMessageId = v4();
|
||||
|
||||
/** @type {string} - The conversation UUID - created if undefined */
|
||||
const conversationId = convoId ?? v4();
|
||||
|
||||
const cache = getLogStores(CacheKeys.ABORT_KEYS);
|
||||
const cacheKey = `${req.user.id}:${conversationId}`;
|
||||
|
||||
/** @type {Run | undefined} - The completed run, undefined if incomplete */
|
||||
let completedRun;
|
||||
|
||||
const getContext = () => ({
|
||||
openai,
|
||||
run_id,
|
||||
endpoint,
|
||||
cacheKey,
|
||||
thread_id,
|
||||
completedRun,
|
||||
assistant_id,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
responseMessageId,
|
||||
});
|
||||
|
||||
const handleError = createErrorHandler({ req, res, getContext });
|
||||
|
||||
try {
|
||||
res.on('close', async () => {
|
||||
if (!completedRun) {
|
||||
await handleError(new Error('Request closed'));
|
||||
}
|
||||
});
|
||||
|
||||
if (convoId && !_thread_id) {
|
||||
completedRun = true;
|
||||
throw new Error('Missing thread_id for existing conversation');
|
||||
}
|
||||
|
||||
if (!assistant_id) {
|
||||
completedRun = true;
|
||||
throw new Error('Missing assistant_id');
|
||||
}
|
||||
|
||||
const checkBalanceBeforeRun = async () => {
|
||||
if (!isEnabled(process.env.CHECK_BALANCE)) {
|
||||
return;
|
||||
}
|
||||
const transactions =
|
||||
(await getTransactions({
|
||||
user: req.user.id,
|
||||
context: 'message',
|
||||
conversationId,
|
||||
})) ?? [];
|
||||
|
||||
const totalPreviousTokens = Math.abs(
|
||||
transactions.reduce((acc, curr) => acc + curr.rawAmount, 0),
|
||||
);
|
||||
|
||||
// TODO: make promptBuffer a config option; buffer for titles, needs buffer for system instructions
|
||||
const promptBuffer = parentMessageId === Constants.NO_PARENT && !_thread_id ? 200 : 0;
|
||||
// 5 is added for labels
|
||||
let promptTokens = (await countTokens(text + (promptPrefix ?? ''))) + 5;
|
||||
promptTokens += totalPreviousTokens + promptBuffer;
|
||||
// Count tokens up to the current context window
|
||||
promptTokens = Math.min(promptTokens, getModelMaxTokens(model));
|
||||
|
||||
await checkBalance({
|
||||
req,
|
||||
res,
|
||||
txData: {
|
||||
model,
|
||||
user: req.user.id,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const { openai: _openai, client } = await getOpenAIClient({
|
||||
req,
|
||||
res,
|
||||
endpointOption: req.body.endpointOption,
|
||||
initAppClient: true,
|
||||
});
|
||||
|
||||
openai = _openai;
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
if (previousMessages.length) {
|
||||
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
||||
}
|
||||
|
||||
let userMessage = {
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: ContentTypes.TEXT,
|
||||
text,
|
||||
},
|
||||
],
|
||||
metadata: {
|
||||
messageId: userMessageId,
|
||||
},
|
||||
};
|
||||
|
||||
/** @type {CreateRunBody | undefined} */
|
||||
const body = {
|
||||
assistant_id,
|
||||
model,
|
||||
};
|
||||
|
||||
if (promptPrefix) {
|
||||
body.additional_instructions = promptPrefix;
|
||||
}
|
||||
|
||||
if (instructions) {
|
||||
body.instructions = instructions;
|
||||
}
|
||||
|
||||
const getRequestFileIds = async () => {
|
||||
let thread_file_ids = [];
|
||||
if (convoId) {
|
||||
const convo = await getConvo(req.user.id, convoId);
|
||||
if (convo && convo.file_ids) {
|
||||
thread_file_ids = convo.file_ids;
|
||||
}
|
||||
}
|
||||
|
||||
if (files.length || thread_file_ids.length) {
|
||||
attachedFileIds = new Set([...file_ids, ...thread_file_ids]);
|
||||
|
||||
let attachmentIndex = 0;
|
||||
for (const file of files) {
|
||||
file_ids.push(file.file_id);
|
||||
if (file.type.startsWith('image')) {
|
||||
userMessage.content.push({
|
||||
type: ContentTypes.IMAGE_FILE,
|
||||
[ContentTypes.IMAGE_FILE]: { file_id: file.file_id },
|
||||
});
|
||||
}
|
||||
|
||||
if (!userMessage.attachments) {
|
||||
userMessage.attachments = [];
|
||||
}
|
||||
|
||||
userMessage.attachments.push({
|
||||
file_id: file.file_id,
|
||||
tools: [{ type: ToolCallTypes.CODE_INTERPRETER }],
|
||||
});
|
||||
|
||||
if (file.type.startsWith('image')) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const mimeType = file.type;
|
||||
const isSupportedByRetrieval = retrievalMimeTypes.some((regex) => regex.test(mimeType));
|
||||
if (isSupportedByRetrieval) {
|
||||
userMessage.attachments[attachmentIndex].tools.push({
|
||||
type: ToolCallTypes.FILE_SEARCH,
|
||||
});
|
||||
}
|
||||
|
||||
attachmentIndex++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/** @type {Promise<Run>|undefined} */
|
||||
let userMessagePromise;
|
||||
|
||||
const initializeThread = async () => {
|
||||
await getRequestFileIds();
|
||||
|
||||
// TODO: may allow multiple messages to be created beforehand in a future update
|
||||
const initThreadBody = {
|
||||
messages: [userMessage],
|
||||
metadata: {
|
||||
user: req.user.id,
|
||||
conversationId,
|
||||
},
|
||||
};
|
||||
|
||||
const result = await initThread({ openai, body: initThreadBody, thread_id });
|
||||
thread_id = result.thread_id;
|
||||
|
||||
createOnTextProgress({
|
||||
openai,
|
||||
conversationId,
|
||||
userMessageId,
|
||||
messageId: responseMessageId,
|
||||
thread_id,
|
||||
});
|
||||
|
||||
requestMessage = {
|
||||
user: req.user.id,
|
||||
text,
|
||||
messageId: userMessageId,
|
||||
parentMessageId,
|
||||
// TODO: make sure client sends correct format for `files`, use zod
|
||||
files,
|
||||
file_ids,
|
||||
conversationId,
|
||||
isCreatedByUser: true,
|
||||
assistant_id,
|
||||
thread_id,
|
||||
model: assistant_id,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
previousMessages.push(requestMessage);
|
||||
|
||||
/* asynchronous */
|
||||
userMessagePromise = saveUserMessage(req, { ...requestMessage, model });
|
||||
|
||||
conversation = {
|
||||
conversationId,
|
||||
endpoint,
|
||||
promptPrefix: promptPrefix,
|
||||
instructions: instructions,
|
||||
assistant_id,
|
||||
// model,
|
||||
};
|
||||
|
||||
if (file_ids.length) {
|
||||
conversation.file_ids = file_ids;
|
||||
}
|
||||
};
|
||||
|
||||
const promises = [initializeThread(), checkBalanceBeforeRun()];
|
||||
await Promise.all(promises);
|
||||
|
||||
const sendInitialResponse = () => {
|
||||
sendMessage(res, {
|
||||
sync: true,
|
||||
conversationId,
|
||||
// messages: previousMessages,
|
||||
requestMessage,
|
||||
responseMessage: {
|
||||
user: req.user.id,
|
||||
messageId: openai.responseMessage.messageId,
|
||||
parentMessageId: userMessageId,
|
||||
conversationId,
|
||||
assistant_id,
|
||||
thread_id,
|
||||
model: assistant_id,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
/** @type {RunResponse | typeof StreamRunManager | undefined} */
|
||||
let response;
|
||||
|
||||
const processRun = async (retry = false) => {
|
||||
if (endpoint === EModelEndpoint.azureAssistants) {
|
||||
body.model = openai._options.model;
|
||||
openai.attachedFileIds = attachedFileIds;
|
||||
if (retry) {
|
||||
response = await runAssistant({
|
||||
openai,
|
||||
thread_id,
|
||||
run_id,
|
||||
in_progress: openai.in_progress,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
/* NOTE:
|
||||
* By default, a Run will use the model and tools configuration specified in Assistant object,
|
||||
* but you can override most of these when creating the Run for added flexibility:
|
||||
*/
|
||||
const run = await createRun({
|
||||
openai,
|
||||
thread_id,
|
||||
body,
|
||||
});
|
||||
|
||||
run_id = run.id;
|
||||
await cache.set(cacheKey, `${thread_id}:${run_id}`, ten_minutes);
|
||||
sendInitialResponse();
|
||||
|
||||
// todo: retry logic
|
||||
response = await runAssistant({ openai, thread_id, run_id });
|
||||
return;
|
||||
}
|
||||
|
||||
/** @type {{[AssistantStreamEvents.ThreadRunCreated]: (event: ThreadRunCreated) => Promise<void>}} */
|
||||
const handlers = {
|
||||
[AssistantStreamEvents.ThreadRunCreated]: async (event) => {
|
||||
await cache.set(cacheKey, `${thread_id}:${event.data.id}`, ten_minutes);
|
||||
run_id = event.data.id;
|
||||
sendInitialResponse();
|
||||
},
|
||||
};
|
||||
|
||||
/** @type {undefined | TAssistantEndpoint} */
|
||||
const config = req.app.locals[endpoint] ?? {};
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const allConfig = req.app.locals.all;
|
||||
|
||||
const streamRunManager = new StreamRunManager({
|
||||
req,
|
||||
res,
|
||||
openai,
|
||||
handlers,
|
||||
thread_id,
|
||||
attachedFileIds,
|
||||
parentMessageId: userMessageId,
|
||||
responseMessage: openai.responseMessage,
|
||||
streamRate: allConfig?.streamRate ?? config.streamRate,
|
||||
// streamOptions: {
|
||||
|
||||
// },
|
||||
});
|
||||
|
||||
await streamRunManager.runAssistant({
|
||||
thread_id,
|
||||
body,
|
||||
});
|
||||
|
||||
response = streamRunManager;
|
||||
response.text = streamRunManager.intermediateText;
|
||||
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
messageCache.set(
|
||||
responseMessageId,
|
||||
{
|
||||
complete: true,
|
||||
text: response.text,
|
||||
},
|
||||
Time.FIVE_MINUTES,
|
||||
);
|
||||
};
|
||||
|
||||
await processRun();
|
||||
logger.debug('[/assistants/chat/] response', {
|
||||
run: response.run,
|
||||
steps: response.steps,
|
||||
});
|
||||
|
||||
if (response.run.status === RunStatus.CANCELLED) {
|
||||
logger.debug('[/assistants/chat/] Run cancelled, handled by `abortRun`');
|
||||
return res.end();
|
||||
}
|
||||
|
||||
if (response.run.status === RunStatus.IN_PROGRESS) {
|
||||
processRun(true);
|
||||
}
|
||||
|
||||
completedRun = response.run;
|
||||
|
||||
/** @type {ResponseMessage} */
|
||||
const responseMessage = {
|
||||
...(response.responseMessage ?? response.finalMessage),
|
||||
text: response.text,
|
||||
parentMessageId: userMessageId,
|
||||
conversationId,
|
||||
user: req.user.id,
|
||||
assistant_id,
|
||||
thread_id,
|
||||
model: assistant_id,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
sendMessage(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: {
|
||||
parentMessageId,
|
||||
thread_id,
|
||||
},
|
||||
});
|
||||
res.end();
|
||||
|
||||
if (userMessagePromise) {
|
||||
await userMessagePromise;
|
||||
}
|
||||
await saveAssistantMessage(req, { ...responseMessage, model });
|
||||
|
||||
if (parentMessageId === Constants.NO_PARENT && !_thread_id) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
responseText: response.text,
|
||||
conversationId,
|
||||
client,
|
||||
});
|
||||
}
|
||||
|
||||
await addThreadMetadata({
|
||||
openai,
|
||||
thread_id,
|
||||
messageId: responseMessage.messageId,
|
||||
messages: response.messages,
|
||||
});
|
||||
|
||||
if (!response.run.usage) {
|
||||
await sleep(3000);
|
||||
completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id);
|
||||
if (completedRun.usage) {
|
||||
await recordUsage({
|
||||
...completedRun.usage,
|
||||
user: req.user.id,
|
||||
model: completedRun.model ?? model,
|
||||
conversationId,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
await recordUsage({
|
||||
...response.run.usage,
|
||||
user: req.user.id,
|
||||
model: response.run.model ?? model,
|
||||
conversationId,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
await handleError(error);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = chatV2;
|
||||
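The usage-recording fallback above (sleep a few seconds, re-fetch the run, then record usage) could be factored into a small helper. A minimal sketch, assuming recordUsage and the OpenAI SDK calls behave exactly as they are used in this controller; the helper name and signature are illustrative and not part of this diff:

// Sketch only: helper name and signature are assumptions, not part of this diff.
const { recordUsage } = require('~/server/services/Threads');

async function recordRunUsage({ openai, run, thread_id, userId, fallbackModel, conversationId }) {
  let usage = run.usage;
  let model = run.model ?? fallbackModel;
  if (!usage) {
    // usage is sometimes only populated a few seconds after the run completes
    await new Promise((resolve) => setTimeout(resolve, 3000));
    const refreshed = await openai.beta.threads.runs.retrieve(thread_id, run.id);
    usage = refreshed.usage;
    model = refreshed.model ?? fallbackModel;
  }
  if (usage) {
    await recordUsage({ ...usage, user: userId, model, conversationId });
  }
}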
api/server/controllers/assistants/errors.js (Normal file, 193 lines)
@@ -0,0 +1,193 @@
|
||||
// errorHandler.js
|
||||
const { sendResponse } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerContext
|
||||
* @property {OpenAIClient} openai - The OpenAI client
|
||||
* @property {string} thread_id - The thread ID
|
||||
* @property {string} run_id - The run ID
|
||||
* @property {boolean} completedRun - Whether the run has completed
|
||||
* @property {string} assistant_id - The assistant ID
|
||||
* @property {string} conversationId - The conversation ID
|
||||
* @property {string} parentMessageId - The parent message ID
|
||||
* @property {string} responseMessageId - The response message ID
|
||||
* @property {string} endpoint - The endpoint being used
|
||||
* @property {string} cacheKey - The cache key for the current request
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerDependencies
|
||||
* @property {Express.Request} req - The Express request object
|
||||
* @property {Express.Response} res - The Express response object
|
||||
* @property {() => ErrorHandlerContext} getContext - Function to get the current context
|
||||
* @property {string} [originPath] - The origin path for the error handler
|
||||
*/
|
||||
|
||||
/**
|
||||
* Creates an error handler function with the given dependencies
|
||||
* @param {ErrorHandlerDependencies} dependencies - The dependencies for the error handler
|
||||
* @returns {(error: Error) => Promise<void>} The error handler function
|
||||
*/
|
||||
const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/chat/' }) => {
|
||||
const cache = getLogStores(CacheKeys.ABORT_KEYS);
|
||||
|
||||
/**
|
||||
* Handles errors that occur during the chat process
|
||||
* @param {Error} error - The error that occurred
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
return async (error) => {
|
||||
const {
|
||||
openai,
|
||||
run_id,
|
||||
endpoint,
|
||||
cacheKey,
|
||||
thread_id,
|
||||
completedRun,
|
||||
assistant_id,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
responseMessageId,
|
||||
} = getContext();
|
||||
|
||||
const defaultErrorMessage =
|
||||
'The Assistant run failed to initialize. Try sending a message in a new conversation.';
|
||||
const messageData = {
|
||||
thread_id,
|
||||
assistant_id,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender: 'System',
|
||||
user: req.user.id,
|
||||
shouldSaveMessage: false,
|
||||
messageId: responseMessageId,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
if (error.message === 'Run cancelled') {
|
||||
return res.end();
|
||||
} else if (error.message === 'Request closed' && completedRun) {
|
||||
return;
|
||||
} else if (error.message === 'Request closed') {
|
||||
logger.debug(`[${originPath}] Request aborted on close`);
|
||||
} else if (/Files.*are invalid/.test(error.message)) {
|
||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||
endpoint === 'azureAssistants'
|
||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
} else if (error?.message?.includes('string too long')) {
|
||||
return sendResponse(
|
||||
req,
|
||||
res,
|
||||
messageData,
|
||||
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
|
||||
);
|
||||
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
||||
return sendResponse(req, res, messageData, error.message);
|
||||
} else {
|
||||
logger.error(`[${originPath}]`, error);
|
||||
}
|
||||
|
||||
if (!openai || !thread_id || !run_id) {
|
||||
return sendResponse(req, res, messageData, defaultErrorMessage);
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
|
||||
try {
|
||||
const status = await cache.get(cacheKey);
|
||||
if (status === 'cancelled') {
|
||||
logger.debug(`[${originPath}] Run already cancelled`);
|
||||
return res.end();
|
||||
}
|
||||
await cache.delete(cacheKey);
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error cancelling run`, error);
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
|
||||
let run;
|
||||
try {
|
||||
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
user: req.user.id,
|
||||
conversationId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error fetching or processing run`, error);
|
||||
}
|
||||
|
||||
let finalEvent;
|
||||
try {
|
||||
const runMessages = await checkMessageGaps({
|
||||
openai,
|
||||
run_id,
|
||||
endpoint,
|
||||
thread_id,
|
||||
conversationId,
|
||||
latestMessageId: responseMessageId,
|
||||
});
|
||||
|
||||
const errorContentPart = {
|
||||
text: {
|
||||
value:
|
||||
error?.message ?? 'There was an error processing your request. Please try again later.',
|
||||
},
|
||||
type: ContentTypes.ERROR,
|
||||
};
|
||||
|
||||
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
|
||||
runMessages[runMessages.length - 1].content = [errorContentPart];
|
||||
} else {
|
||||
const contentParts = runMessages[runMessages.length - 1].content;
|
||||
for (let i = 0; i < contentParts.length; i++) {
|
||||
const currentPart = contentParts[i];
|
||||
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
|
||||
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
|
||||
if (
|
||||
toolCall &&
|
||||
toolCall?.function &&
|
||||
!(toolCall?.function?.output || toolCall?.function?.output?.length)
|
||||
) {
|
||||
contentParts[i] = {
|
||||
...currentPart,
|
||||
[ContentTypes.TOOL_CALL]: {
|
||||
...toolCall,
|
||||
function: {
|
||||
...toolCall.function,
|
||||
output: 'error processing tool',
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
runMessages[runMessages.length - 1].content.push(errorContentPart);
|
||||
}
|
||||
|
||||
finalEvent = {
|
||||
final: true,
|
||||
conversation: await getConvo(req.user.id, conversationId),
|
||||
runMessages,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error finalizing error process`, error);
|
||||
return sendResponse(req, res, messageData, 'The Assistant run failed');
|
||||
}
|
||||
|
||||
return sendResponse(req, res, finalEvent);
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = { createErrorHandler };
|
||||
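For orientation, a sketch of how createErrorHandler is intended to be wired from a chat controller; the controller shape below is an assumption for illustration and only mirrors the context fields documented above:

// Sketch: illustrative wiring only, not part of this diff.
const { createErrorHandler } = require('~/server/controllers/assistants/errors');

async function exampleChatController(req, res) {
  // values the run loop fills in as it progresses
  const ctx = {
    openai: undefined,
    thread_id: undefined,
    run_id: undefined,
    completedRun: undefined,
    cacheKey: undefined, // set to the abort-cache key used for this run
    endpoint: req.body.endpoint,
    assistant_id: req.body.assistant_id,
    conversationId: req.body.conversationId,
    parentMessageId: req.body.parentMessageId,
    responseMessageId: req.body.responseMessageId,
  };

  // getContext is called lazily, so the handler always sees the latest values
  const handleError = createErrorHandler({ req, res, getContext: () => ctx });

  try {
    // ... initialize the client, create the thread and run, updating ctx as you go
  } catch (error) {
    await handleError(error);
  }
}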
api/server/controllers/assistants/helpers.js (Normal file, 270 lines)
@@ -0,0 +1,270 @@
|
||||
const {
|
||||
CacheKeys,
|
||||
SystemRoles,
|
||||
EModelEndpoint,
|
||||
defaultOrderQuery,
|
||||
defaultAssistantsVersion,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
initializeClient: initAzureClient,
|
||||
} = require('~/server/services/Endpoints/azureAssistants');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* @param {Express.Request} req
|
||||
* @param {string} [endpoint]
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
const getCurrentVersion = async (req, endpoint) => {
|
||||
const index = req.baseUrl.lastIndexOf('/v');
|
||||
let version = index !== -1 ? req.baseUrl.substring(index + 1, index + 3) : null;
|
||||
if (!version && req.body.version) {
|
||||
version = `v${req.body.version}`;
|
||||
}
|
||||
if (!version && endpoint) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
|
||||
version = `v${
|
||||
cachedEndpointsConfig?.[endpoint]?.version ?? defaultAssistantsVersion[endpoint]
|
||||
}`;
|
||||
}
|
||||
if (!version?.startsWith('v') && version.length !== 2) {
|
||||
throw new Error(`[${req.baseUrl}] Invalid version: ${version}`);
|
||||
}
|
||||
return version;
|
||||
};
|
||||
|
||||
/**
|
||||
* Asynchronously lists assistants based on provided query parameters.
|
||||
*
|
||||
* Initializes the client with the current request and response objects and lists assistants
|
||||
* according to the query parameters. This function abstracts the logic for non-Azure paths.
|
||||
*
|
||||
* @deprecated
|
||||
* @async
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {object} params.req - The request object, used for initializing the client.
|
||||
* @param {object} params.res - The response object, used for initializing the client.
|
||||
* @param {string} params.version - The API version to use.
|
||||
* @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||
*/
|
||||
const _listAssistants = async ({ req, res, version, query }) => {
|
||||
const { openai } = await getOpenAIClient({ req, res, version });
|
||||
return openai.beta.assistants.list(query);
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches all assistants based on provided query params, until `has_more` is `false`.
|
||||
*
|
||||
* @async
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {object} params.req - The request object, used for initializing the client.
|
||||
* @param {object} params.res - The response object, used for initializing the client.
|
||||
* @param {string} params.version - The API version to use.
|
||||
* @param {Omit<AssistantListParams, 'endpoint'>} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||
*/
|
||||
const listAllAssistants = async ({ req, res, version, query }) => {
|
||||
/** @type {{ openai: OpenAIClient }} */
|
||||
const { openai } = await getOpenAIClient({ req, res, version });
|
||||
const allAssistants = [];
|
||||
|
||||
let first_id;
|
||||
let last_id;
|
||||
let afterToken = query.after;
|
||||
let hasMore = true;
|
||||
|
||||
while (hasMore) {
|
||||
const response = await openai.beta.assistants.list({
|
||||
...query,
|
||||
after: afterToken,
|
||||
});
|
||||
|
||||
const { body } = response;
|
||||
|
||||
allAssistants.push(...body.data);
|
||||
hasMore = body.has_more;
|
||||
|
||||
if (!first_id) {
|
||||
first_id = body.first_id;
|
||||
}
|
||||
|
||||
if (hasMore) {
|
||||
afterToken = body.last_id;
|
||||
} else {
|
||||
last_id = body.last_id;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
data: allAssistants,
|
||||
body: {
|
||||
data: allAssistants,
|
||||
has_more: false,
|
||||
first_id,
|
||||
last_id,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Asynchronously lists assistants for Azure configured groups.
|
||||
*
|
||||
* Iterates through Azure configured assistant groups, initializes the client with the current request and response objects,
|
||||
* lists assistants based on the provided query parameters, and merges their data alongside the model information into a single array.
|
||||
*
|
||||
* @async
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {object} params.req - The request object, used for initializing the client and manipulating the request body.
|
||||
* @param {object} params.res - The response object, used for initializing the client.
|
||||
* @param {string} params.version - The API version to use.
|
||||
* @param {TAzureConfig} params.azureConfig - The Azure configuration object containing assistantGroups and groupMap.
|
||||
* @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||
* @returns {Promise<AssistantListResponse>} A promise that resolves to an array of assistant data merged with their respective model information.
|
||||
*/
|
||||
const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, query }) => {
|
||||
/** @type {Array<[string, TAzureModelConfig]>} */
|
||||
const groupModelTuples = [];
|
||||
const promises = [];
|
||||
/** @type {Array<TAzureGroup>} */
|
||||
const groups = [];
|
||||
|
||||
const { groupMap, assistantGroups } = azureConfig;
|
||||
|
||||
for (const groupName of assistantGroups) {
|
||||
const group = groupMap[groupName];
|
||||
groups.push(group);
|
||||
|
||||
const currentModelTuples = Object.entries(group?.models);
|
||||
groupModelTuples.push(currentModelTuples);
|
||||
|
||||
/* The specified model is only necessary to
|
||||
fetch assistants for the shared instance */
|
||||
req.body.model = currentModelTuples[0][0];
|
||||
promises.push(listAllAssistants({ req, res, version, query }));
|
||||
}
|
||||
|
||||
const resolvedQueries = await Promise.all(promises);
|
||||
const data = resolvedQueries.flatMap((res, i) =>
|
||||
res.data.map((assistant) => {
|
||||
const deploymentName = assistant.model;
|
||||
const currentGroup = groups[i];
|
||||
const currentModelTuples = groupModelTuples[i];
|
||||
const firstModel = currentModelTuples[0][0];
|
||||
|
||||
if (currentGroup.deploymentName === deploymentName) {
|
||||
return { ...assistant, model: firstModel };
|
||||
}
|
||||
|
||||
for (const [model, modelConfig] of currentModelTuples) {
|
||||
if (modelConfig.deploymentName === deploymentName) {
|
||||
return { ...assistant, model };
|
||||
}
|
||||
}
|
||||
|
||||
return { ...assistant, model: firstModel };
|
||||
}),
|
||||
);
|
||||
|
||||
return {
|
||||
first_id: data[0]?.id,
|
||||
last_id: data[data.length - 1]?.id,
|
||||
object: 'list',
|
||||
has_more: false,
|
||||
data,
|
||||
};
|
||||
};
|
||||
|
||||
async function getOpenAIClient({ req, res, endpointOption, initAppClient, overrideEndpoint }) {
|
||||
let endpoint = overrideEndpoint ?? req.body.endpoint ?? req.query.endpoint;
|
||||
const version = await getCurrentVersion(req, endpoint);
|
||||
if (!endpoint) {
|
||||
throw new Error(`[${req.baseUrl}] Endpoint is required`);
|
||||
}
|
||||
|
||||
let result;
|
||||
if (endpoint === EModelEndpoint.assistants) {
|
||||
result = await initializeClient({ req, res, version, endpointOption, initAppClient });
|
||||
} else if (endpoint === EModelEndpoint.azureAssistants) {
|
||||
result = await initAzureClient({ req, res, version, endpointOption, initAppClient });
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a list of assistants.
|
||||
* @param {object} params
|
||||
* @param {object} params.req - Express Request
|
||||
* @param {AssistantListParams} [params.req.query] - The assistant list parameters for pagination and sorting.
|
||||
* @param {object} params.res - Express Response
|
||||
* @param {string} [params.overrideEndpoint] - The endpoint to override the request endpoint.
|
||||
* @returns {Promise<AssistantListResponse>} 200 - success response - application/json
|
||||
*/
|
||||
const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
|
||||
const {
|
||||
limit = 100,
|
||||
order = 'desc',
|
||||
after,
|
||||
before,
|
||||
endpoint,
|
||||
} = req.query ?? {
|
||||
endpoint: overrideEndpoint,
|
||||
...defaultOrderQuery,
|
||||
};
|
||||
|
||||
const version = await getCurrentVersion(req, endpoint);
|
||||
const query = { limit, order, after, before };
|
||||
|
||||
/** @type {AssistantListResponse} */
|
||||
let body;
|
||||
|
||||
if (endpoint === EModelEndpoint.assistants) {
|
||||
({ body } = await listAllAssistants({ req, res, version, query }));
|
||||
} else if (endpoint === EModelEndpoint.azureAssistants) {
|
||||
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
|
||||
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
|
||||
}
|
||||
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
return body;
|
||||
} else if (!req.app.locals[endpoint]) {
|
||||
return body;
|
||||
}
|
||||
|
||||
body.data = filterAssistants({
|
||||
userId: req.user.id,
|
||||
assistants: body.data,
|
||||
assistantsConfig: req.app.locals[endpoint],
|
||||
});
|
||||
return body;
|
||||
};
|
||||
|
||||
/**
|
||||
* Filter assistants based on configuration.
|
||||
*
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {string} params.userId - The user ID to filter private assistants.
|
||||
* @param {Assistant[]} params.assistants - The list of assistants to filter.
|
||||
* @param {Partial<TAssistantEndpoint>} params.assistantsConfig - The assistant configuration.
|
||||
* @returns {Assistant[]} - The filtered list of assistants.
|
||||
*/
|
||||
function filterAssistants({ assistants, userId, assistantsConfig }) {
|
||||
const { supportedIds, excludedIds, privateAssistants } = assistantsConfig;
|
||||
if (privateAssistants) {
|
||||
return assistants.filter((assistant) => userId === assistant.metadata?.author);
|
||||
} else if (supportedIds?.length) {
|
||||
return assistants.filter((assistant) => supportedIds.includes(assistant.id));
|
||||
} else if (excludedIds?.length) {
|
||||
return assistants.filter((assistant) => !excludedIds.includes(assistant.id));
|
||||
}
|
||||
return assistants;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getOpenAIClient,
|
||||
fetchAssistants,
|
||||
getCurrentVersion,
|
||||
};
|
||||
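filterAssistants (module-internal above) applies exactly one rule, checked in order: privateAssistants, then supportedIds, then excludedIds. A quick illustration with made-up IDs:

// Illustration only: IDs and user are made up; filterAssistants is not exported.
const assistants = [
  { id: 'asst_a', metadata: { author: 'user_1' } },
  { id: 'asst_b', metadata: { author: 'user_2' } },
];

// privateAssistants takes precedence over the ID lists
filterAssistants({
  userId: 'user_1',
  assistants,
  assistantsConfig: { privateAssistants: true, supportedIds: ['asst_b'] },
}); // => only asst_a (authored by user_1)

filterAssistants({
  userId: 'user_1',
  assistants,
  assistantsConfig: { excludedIds: ['asst_b'] },
}); // => only asst_a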
@@ -1,34 +1,12 @@
|
||||
const multer = require('multer');
|
||||
const express = require('express');
|
||||
const { FileContext, EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
initializeClient,
|
||||
listAssistantsForAzure,
|
||||
listAssistants,
|
||||
} = require('~/server/services/Endpoints/assistants');
|
||||
const { FileContext } = require('librechat-data-provider');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { deleteAssistantActions } = require('~/server/services/ActionService');
|
||||
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
|
||||
const { uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { updateAssistant, getAssistants } = require('~/models/Assistant');
|
||||
const { getOpenAIClient, fetchAssistants } = require('./helpers');
|
||||
const { deleteFileByFilter } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
const actions = require('./actions');
|
||||
const tools = require('./tools');
|
||||
|
||||
const upload = multer();
|
||||
const router = express.Router();
|
||||
|
||||
/**
|
||||
* Assistant actions route.
|
||||
* @route GET|POST /assistants/actions
|
||||
*/
|
||||
router.use('/actions', actions);
|
||||
|
||||
/**
|
||||
* Create an assistant.
|
||||
* @route GET /assistants/tools
|
||||
* @returns {TPlugin[]} 200 - application/json
|
||||
*/
|
||||
router.use('/tools', tools);
|
||||
|
||||
/**
|
||||
* Create an assistant.
|
||||
@@ -36,12 +14,11 @@ router.use('/tools', tools);
|
||||
* @param {AssistantCreateParams} req.body - The assistant creation parameters.
|
||||
* @returns {Assistant} 201 - success response - application/json
|
||||
*/
|
||||
router.post('/', async (req, res) => {
|
||||
const createAssistant = async (req, res) => {
|
||||
try {
|
||||
/** @type {{ openai: OpenAI }} */
|
||||
const { openai } = await initializeClient({ req, res });
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
|
||||
const { tools = [], ...assistantData } = req.body;
|
||||
const { tools = [], endpoint, ...assistantData } = req.body;
|
||||
assistantData.tools = tools
|
||||
.map((tool) => {
|
||||
if (typeof tool !== 'string') {
|
||||
@@ -52,18 +29,30 @@ router.post('/', async (req, res) => {
|
||||
})
|
||||
.filter((tool) => tool);
|
||||
|
||||
let azureModelIdentifier = null;
|
||||
if (openai.locals?.azureOptions) {
|
||||
azureModelIdentifier = assistantData.model;
|
||||
assistantData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName;
|
||||
}
|
||||
|
||||
assistantData.metadata = {
|
||||
author: req.user.id,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
const assistant = await openai.beta.assistants.create(assistantData);
|
||||
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
|
||||
if (azureModelIdentifier) {
|
||||
assistant.model = azureModelIdentifier;
|
||||
}
|
||||
await promise;
|
||||
logger.debug('/assistants/', assistant);
|
||||
res.status(201).json(assistant);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants] Error creating assistant', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves an assistant.
|
||||
@@ -71,11 +60,10 @@ router.post('/', async (req, res) => {
|
||||
* @param {string} req.params.id - Assistant identifier.
|
||||
* @returns {Assistant} 200 - success response - application/json
|
||||
*/
|
||||
router.get('/:id', async (req, res) => {
|
||||
const retrieveAssistant = async (req, res) => {
|
||||
try {
|
||||
/** @type {{ openai: OpenAI }} */
|
||||
const { openai } = await initializeClient({ req, res });
|
||||
|
||||
/* NOTE: not actually being used right now */
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
const assistant_id = req.params.id;
|
||||
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
||||
res.json(assistant);
|
||||
@@ -83,22 +71,24 @@ router.get('/:id', async (req, res) => {
|
||||
logger.error('[/assistants/:id] Error retrieving assistant', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Modifies an assistant.
|
||||
* @route PATCH /assistants/:id
|
||||
* @param {object} req - Express Request
|
||||
* @param {object} req.params - Request params
|
||||
* @param {string} req.params.id - Assistant identifier.
|
||||
* @param {AssistantUpdateParams} req.body - The assistant update parameters.
|
||||
* @returns {Assistant} 200 - success response - application/json
|
||||
*/
|
||||
router.patch('/:id', async (req, res) => {
|
||||
const patchAssistant = async (req, res) => {
|
||||
try {
|
||||
/** @type {{ openai: OpenAI }} */
|
||||
const { openai } = await initializeClient({ req, res });
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const updateData = req.body;
|
||||
const { endpoint: _e, ...updateData } = req.body;
|
||||
updateData.tools = (updateData.tools ?? [])
|
||||
.map((tool) => {
|
||||
if (typeof tool !== 'string') {
|
||||
@@ -119,90 +109,76 @@ router.patch('/:id', async (req, res) => {
|
||||
logger.error('[/assistants/:id] Error updating assistant', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes an assistant.
|
||||
* @route DELETE /assistants/:id
|
||||
* @param {object} req - Express Request
|
||||
* @param {object} req.params - Request params
|
||||
* @param {string} req.params.id - Assistant identifier.
|
||||
* @returns {Assistant} 200 - success response - application/json
|
||||
*/
|
||||
router.delete('/:id', async (req, res) => {
|
||||
const deleteAssistant = async (req, res) => {
|
||||
try {
|
||||
/** @type {{ openai: OpenAI }} */
|
||||
const { openai } = await initializeClient({ req, res });
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const deletionStatus = await openai.beta.assistants.del(assistant_id);
|
||||
if (deletionStatus?.deleted) {
|
||||
await deleteAssistantActions({ req, assistant_id });
|
||||
}
|
||||
res.json(deletionStatus);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/:id] Error deleting assistant', error);
|
||||
res.status(500).json({ error: 'Error deleting assistant' });
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns a list of assistants.
|
||||
* @route GET /assistants
|
||||
* @param {object} req - Express Request
|
||||
* @param {AssistantListParams} req.query - The assistant list parameters for pagination and sorting.
|
||||
* @returns {AssistantListResponse} 200 - success response - application/json
|
||||
*/
|
||||
router.get('/', async (req, res) => {
|
||||
const listAssistants = async (req, res) => {
|
||||
try {
|
||||
const { limit = 100, order = 'desc', after, before } = req.query;
|
||||
const query = { limit, order, after, before };
|
||||
|
||||
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
|
||||
/** @type {AssistantListResponse} */
|
||||
let body;
|
||||
|
||||
if (azureConfig?.assistants) {
|
||||
body = await listAssistantsForAzure({ req, res, azureConfig, query });
|
||||
} else {
|
||||
({ body } = await listAssistants({ req, res, query }));
|
||||
}
|
||||
|
||||
if (req.app.locals?.[EModelEndpoint.assistants]) {
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals[EModelEndpoint.assistants];
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
if (supportedIds?.length) {
|
||||
body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id));
|
||||
} else if (excludedIds?.length) {
|
||||
body.data = body.data.filter((assistant) => !excludedIds.includes(assistant.id));
|
||||
}
|
||||
}
|
||||
|
||||
const body = await fetchAssistants({ req, res });
|
||||
res.json(body);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants] Error listing assistants', error);
|
||||
res.status(500).json({ message: 'Error listing assistants' });
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns a list of the user's assistant documents (metadata saved to database).
|
||||
* @route GET /assistants/documents
|
||||
* @returns {AssistantDocument[]} 200 - success response - application/json
|
||||
*/
|
||||
router.get('/documents', async (req, res) => {
|
||||
const getAssistantDocuments = async (req, res) => {
|
||||
try {
|
||||
res.json(await getAssistants({ user: req.user.id }));
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/documents] Error listing assistant documents', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Uploads and updates an avatar for a specific assistant.
|
||||
* @route POST /avatar/:assistant_id
|
||||
* @param {object} req - Express Request
|
||||
* @param {object} req.params - Request params
|
||||
* @param {string} req.params.assistant_id - The ID of the assistant.
|
||||
* @param {Express.Multer.File} req.file - The avatar image file.
|
||||
* @param {object} req.body - Request body
|
||||
* @param {string} [req.body.metadata] - Optional metadata for the assistant's avatar.
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) => {
|
||||
const uploadAssistantAvatar = async (req, res) => {
|
||||
try {
|
||||
const { assistant_id } = req.params;
|
||||
if (!assistant_id) {
|
||||
@@ -210,8 +186,8 @@ router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) =>
|
||||
}
|
||||
|
||||
let { metadata: _metadata = '{}' } = req.body;
|
||||
/** @type {{ openai: OpenAI }} */
|
||||
const { openai } = await initializeClient({ req, res });
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const image = await uploadImageBuffer({
|
||||
req,
|
||||
@@ -246,7 +222,7 @@ router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) =>
|
||||
|
||||
const promises = [];
|
||||
promises.push(
|
||||
updateAssistant(
|
||||
updateAssistantDoc(
|
||||
{ assistant_id },
|
||||
{
|
||||
avatar: {
|
||||
@@ -266,6 +242,14 @@ router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) =>
|
||||
logger.error(message, error);
|
||||
res.status(500).json({ message });
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = router;
|
||||
module.exports = {
|
||||
createAssistant,
|
||||
retrieveAssistant,
|
||||
patchAssistant,
|
||||
deleteAssistant,
|
||||
listAssistants,
|
||||
getAssistantDocuments,
|
||||
uploadAssistantAvatar,
|
||||
};
|
||||
api/server/controllers/assistants/v2.js (Normal file, 213 lines)
@@ -0,0 +1,213 @@
|
||||
const { ToolCallTypes } = require('librechat-data-provider');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { validateAndUpdateTool } = require('~/server/services/ActionService');
|
||||
const { updateAssistantDoc } = require('~/models/Assistant');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Create an assistant.
|
||||
* @route POST /assistants
|
||||
* @param {AssistantCreateParams} req.body - The assistant creation parameters.
|
||||
* @returns {Assistant} 201 - success response - application/json
|
||||
*/
|
||||
const createAssistant = async (req, res) => {
|
||||
try {
|
||||
/** @type {{ openai: OpenAIClient }} */
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
|
||||
const { tools = [], endpoint, ...assistantData } = req.body;
|
||||
assistantData.tools = tools
|
||||
.map((tool) => {
|
||||
if (typeof tool !== 'string') {
|
||||
return tool;
|
||||
}
|
||||
|
||||
return req.app.locals.availableTools[tool];
|
||||
})
|
||||
.filter((tool) => tool);
|
||||
|
||||
let azureModelIdentifier = null;
|
||||
if (openai.locals?.azureOptions) {
|
||||
azureModelIdentifier = assistantData.model;
|
||||
assistantData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName;
|
||||
}
|
||||
|
||||
assistantData.metadata = {
|
||||
author: req.user.id,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
const assistant = await openai.beta.assistants.create(assistantData);
|
||||
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
|
||||
if (azureModelIdentifier) {
|
||||
assistant.model = azureModelIdentifier;
|
||||
}
|
||||
await promise;
|
||||
logger.debug('/assistants/', assistant);
|
||||
res.status(201).json(assistant);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants] Error creating assistant', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Modifies an assistant.
|
||||
* @param {object} params
|
||||
* @param {Express.Request} params.req
|
||||
* @param {OpenAIClient} params.openai
|
||||
* @param {string} params.assistant_id
|
||||
* @param {AssistantUpdateParams} params.updateData
|
||||
* @returns {Promise<Assistant>} The updated assistant.
|
||||
*/
|
||||
const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
|
||||
await validateAuthor({ req, openai });
|
||||
const tools = [];
|
||||
|
||||
let hasFileSearch = false;
|
||||
for (const tool of updateData.tools ?? []) {
|
||||
let actualTool = typeof tool === 'string' ? req.app.locals.availableTools[tool] : tool;
|
||||
|
||||
if (!actualTool) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (actualTool.type === ToolCallTypes.FILE_SEARCH) {
|
||||
hasFileSearch = true;
|
||||
}
|
||||
|
||||
if (!actualTool.function) {
|
||||
tools.push(actualTool);
|
||||
continue;
|
||||
}
|
||||
|
||||
const updatedTool = await validateAndUpdateTool({ req, tool: actualTool, assistant_id });
|
||||
if (updatedTool) {
|
||||
tools.push(updatedTool);
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFileSearch && !updateData.tool_resources) {
|
||||
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
||||
updateData.tool_resources = assistant.tool_resources ?? null;
|
||||
}
|
||||
|
||||
if (hasFileSearch && !updateData.tool_resources?.file_search) {
|
||||
updateData.tool_resources = {
|
||||
...(updateData.tool_resources ?? {}),
|
||||
file_search: {
|
||||
vector_store_ids: [],
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
updateData.tools = tools;
|
||||
|
||||
if (openai.locals?.azureOptions && updateData.model) {
|
||||
updateData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName;
|
||||
}
|
||||
|
||||
return await openai.beta.assistants.update(assistant_id, updateData);
|
||||
};
|
||||
|
||||
/**
|
||||
* Adds a resource file ID to an assistant's tool resources.
|
||||
* @param {object} params
|
||||
* @param {Express.Request} params.req
|
||||
* @param {OpenAIClient} params.openai
|
||||
* @param {string} params.assistant_id
|
||||
* @param {string} params.tool_resource
|
||||
* @param {string} params.file_id
|
||||
* @param {AssistantUpdateParams} params.updateData
|
||||
* @returns {Promise<Assistant>} The updated assistant.
|
||||
*/
|
||||
const addResourceFileId = async ({ req, openai, assistant_id, tool_resource, file_id }) => {
|
||||
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
||||
const { tool_resources = {} } = assistant;
|
||||
if (tool_resources[tool_resource]) {
|
||||
tool_resources[tool_resource].file_ids.push(file_id);
|
||||
} else {
|
||||
tool_resources[tool_resource] = { file_ids: [file_id] };
|
||||
}
|
||||
|
||||
delete assistant.id;
|
||||
return await updateAssistant({
|
||||
req,
|
||||
openai,
|
||||
assistant_id,
|
||||
updateData: { tools: assistant.tools, tool_resources },
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes a file ID from an assistant's resource.
|
||||
* @param {object} params
|
||||
* @param {Express.Request} params.req
|
||||
* @param {OpenAIClient} params.openai
|
||||
* @param {string} params.assistant_id
|
||||
* @param {string} [params.tool_resource]
|
||||
* @param {string} params.file_id
|
||||
* @param {AssistantUpdateParams} params.updateData
|
||||
* @returns {Promise<Assistant>} The updated assistant.
|
||||
*/
|
||||
const deleteResourceFileId = async ({ req, openai, assistant_id, tool_resource, file_id }) => {
|
||||
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
||||
const { tool_resources = {} } = assistant;
|
||||
|
||||
if (tool_resource && tool_resources[tool_resource]) {
|
||||
const resource = tool_resources[tool_resource];
|
||||
const index = resource.file_ids.indexOf(file_id);
|
||||
if (index !== -1) {
|
||||
resource.file_ids.splice(index, 1);
|
||||
}
|
||||
} else {
|
||||
for (const resourceKey in tool_resources) {
|
||||
const resource = tool_resources[resourceKey];
|
||||
const index = resource.file_ids.indexOf(file_id);
|
||||
if (index !== -1) {
|
||||
resource.file_ids.splice(index, 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delete assistant.id;
|
||||
return await updateAssistant({
|
||||
req,
|
||||
openai,
|
||||
assistant_id,
|
||||
updateData: { tools: assistant.tools, tool_resources },
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Modifies an assistant.
|
||||
* @route PATCH /assistants/:id
|
||||
* @param {object} req - Express Request
|
||||
* @param {object} req.params - Request params
|
||||
* @param {string} req.params.id - Assistant identifier.
|
||||
* @param {AssistantUpdateParams} req.body - The assistant update parameters.
|
||||
* @returns {Assistant} 200 - success response - application/json
|
||||
*/
|
||||
const patchAssistant = async (req, res) => {
|
||||
try {
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
const assistant_id = req.params.id;
|
||||
const { endpoint: _e, ...updateData } = req.body;
|
||||
updateData.tools = updateData.tools ?? [];
|
||||
const updatedAssistant = await updateAssistant({ req, openai, assistant_id, updateData });
|
||||
res.json(updatedAssistant);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/:id] Error updating assistant', error);
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
patchAssistant,
|
||||
createAssistant,
|
||||
updateAssistant,
|
||||
addResourceFileId,
|
||||
deleteResourceFileId,
|
||||
};
|
||||
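A hypothetical call site for addResourceFileId, for example from a file-upload flow; the handler shape and the code_interpreter resource key are assumptions, chosen because the function pushes plain file_ids:

// Hypothetical call site (not part of this diff).
const { addResourceFileId } = require('~/server/controllers/assistants/v2');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');

async function attachFileToAssistant(req, res, { assistant_id, file_id }) {
  const { openai } = await getOpenAIClient({ req, res });
  // appends file_id to tool_resources.code_interpreter.file_ids, creating the entry if missing
  return addResourceFileId({
    req,
    openai,
    assistant_id,
    file_id,
    tool_resource: 'code_interpreter', // assumed resource key
  });
}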
@@ -1,26 +1,22 @@
|
||||
const User = require('~/models/User');
|
||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const loginController = async (req, res) => {
|
||||
try {
|
||||
const user = await User.findById(req.user._id);
|
||||
|
||||
// If user doesn't exist, return error
|
||||
if (!user) {
|
||||
// typeof user !== User) { // this doesn't seem to resolve the User type ??
|
||||
if (!req.user) {
|
||||
return res.status(400).json({ message: 'Invalid credentials' });
|
||||
}
|
||||
|
||||
const token = await setAuthTokens(user._id, res);
|
||||
const { password: _, __v, ...user } = req.user;
|
||||
user.id = user._id.toString();
|
||||
|
||||
const token = await setAuthTokens(req.user._id, res);
|
||||
|
||||
return res.status(200).send({ token, user });
|
||||
} catch (err) {
|
||||
logger.error('[loginController]', err);
|
||||
return res.status(500).json({ message: 'Something went wrong' });
|
||||
}
|
||||
|
||||
// Generic error messages are safer
|
||||
return res.status(500).json({ message: 'Something went wrong' });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
|
||||
@@ -6,16 +6,16 @@ const axios = require('axios');
|
||||
const express = require('express');
|
||||
const passport = require('passport');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { jwtLogin, passportLogin } = require('~/strategies');
|
||||
const { connectDb, indexSync } = require('~/lib/db');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { ldapLogin } = require('~/strategies');
|
||||
const { logger } = require('~/config');
|
||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||
const errorController = require('./controllers/ErrorController');
|
||||
const { jwtLogin, passportLogin } = require('~/strategies');
|
||||
const configureSocialLogins = require('./socialLogins');
|
||||
const { connectDb, indexSync } = require('~/lib/db');
|
||||
const AppService = require('./services/AppService');
|
||||
const noIndex = require('./middleware/noIndex');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const routes = require('./routes');
|
||||
|
||||
const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {};
|
||||
@@ -60,6 +60,11 @@ const startServer = async () => {
|
||||
passport.use(await jwtLogin());
|
||||
passport.use(passportLogin());
|
||||
|
||||
// LDAP Auth
|
||||
if (process.env.LDAP_URL && process.env.LDAP_USER_SEARCH_BASE) {
|
||||
passport.use(ldapLogin);
|
||||
}
|
||||
|
||||
if (isEnabled(ALLOW_SOCIAL_LOGIN)) {
|
||||
configureSocialLogins(app);
|
||||
}
|
||||
@@ -76,6 +81,7 @@ const startServer = async () => {
|
||||
app.use('/api/convos', routes.convos);
|
||||
app.use('/api/presets', routes.presets);
|
||||
app.use('/api/prompts', routes.prompts);
|
||||
app.use('/api/categories', routes.categories);
|
||||
app.use('/api/tokenizer', routes.tokenizer);
|
||||
app.use('/api/endpoints', routes.endpoints);
|
||||
app.use('/api/balance', routes.balance);
|
||||
@@ -85,9 +91,12 @@ const startServer = async () => {
|
||||
app.use('/api/assistants', routes.assistants);
|
||||
app.use('/api/files', await routes.files.initialize());
|
||||
app.use('/images/', validateImageRequest, routes.staticRoute);
|
||||
app.use('/api/share', routes.share);
|
||||
app.use('/api/roles', routes.roles);
|
||||
|
||||
app.use('/api/tags', routes.tags);
|
||||
app.use((req, res) => {
|
||||
res.status(404).sendFile(path.join(app.locals.paths.dist, 'index.html'));
|
||||
res.sendFile(path.join(app.locals.paths.dist, 'index.html'));
|
||||
});
|
||||
|
||||
app.listen(port, host, () => {
|
||||
|
||||
@@ -1,31 +1,39 @@
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const { isAssistantsEndpoint } = require('librechat-data-provider');
|
||||
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const abortControllers = require('./abortControllers');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
const { abortRun } = require('./abortRun');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
let { abortKey, conversationId, endpoint } = req.body;
|
||||
let { abortKey, endpoint } = req.body;
|
||||
|
||||
if (!abortKey && conversationId) {
|
||||
abortKey = conversationId;
|
||||
if (isAssistantsEndpoint(endpoint)) {
|
||||
return await abortRun(req, res);
|
||||
}
|
||||
|
||||
if (endpoint === EModelEndpoint.assistants) {
|
||||
return await abortRun(req, res);
|
||||
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
|
||||
|
||||
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
|
||||
abortKey = conversationId;
|
||||
}
|
||||
|
||||
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey);
|
||||
const { abortController } = abortControllers.get(abortKey) ?? {};
|
||||
if (!abortController) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
const finalEvent = await abortController.abortCompletion();
|
||||
logger.debug('[abortMessage] Aborted request', { abortKey });
|
||||
logger.debug(
|
||||
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
|
||||
JSON.stringify({ abortKey }),
|
||||
);
|
||||
abortControllers.delete(abortKey);
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
@@ -50,12 +58,35 @@ const handleAbort = () => {
|
||||
};
|
||||
};
|
||||
|
||||
const createAbortController = (req, res, getAbortData) => {
|
||||
const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
const abortController = new AbortController();
|
||||
const { endpointOption } = req.body;
|
||||
const onStart = (userMessage) => {
|
||||
|
||||
abortController.getAbortData = function () {
|
||||
return getAbortData();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {TMessage} userMessage
|
||||
* @param {string} responseMessageId
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId) => {
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
|
||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const prevRequest = abortControllers.get(abortKey);
|
||||
|
||||
if (prevRequest && prevRequest?.abortController) {
|
||||
const data = prevRequest.abortController.getAbortData();
|
||||
getReqData({ userMessage: data?.userMessage });
|
||||
const addedAbortKey = `${abortKey}:${responseMessageId}`;
|
||||
abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
|
||||
res.on('finish', function () {
|
||||
abortControllers.delete(addedAbortKey);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
abortControllers.set(abortKey, { abortController, ...endpointOption });
|
||||
|
||||
res.on('finish', function () {
|
||||
@@ -65,7 +96,8 @@ const createAbortController = (req, res, getAbortData) => {
|
||||
|
||||
abortController.abortCompletion = async function () {
|
||||
abortController.abort();
|
||||
const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData();
|
||||
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
|
||||
getAbortData();
|
||||
const completionTokens = await countTokens(responseData?.text ?? '');
|
||||
const user = req.user.id;
|
||||
|
||||
@@ -87,12 +119,26 @@ const createAbortController = (req, res, getAbortData) => {
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
saveMessage({ ...responseMessage, user });
|
||||
saveMessage(
|
||||
req,
|
||||
{ ...responseMessage, user },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
);
|
||||
|
||||
let conversation;
|
||||
if (userMessagePromise) {
|
||||
const resolved = await userMessagePromise;
|
||||
conversation = resolved?.conversation;
|
||||
}
|
||||
|
||||
if (!conversation) {
|
||||
conversation = await getConvo(req.user.id, conversationId);
|
||||
}
|
||||
|
||||
return {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
conversation,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: responseMessage,
|
||||
};
|
||||
@@ -151,7 +197,7 @@ const handleAbortError = async (res, req, error, data) => {
|
||||
}
|
||||
};
|
||||
|
||||
await sendError(res, options, callback);
|
||||
await sendError(req, res, options, callback);
|
||||
};
|
||||
|
||||
if (partialText && partialText.length > 5) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
|
||||
const { deleteMessages } = require('~/models/Message');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { sendMessage } = require('~/server/utils');
|
||||
@@ -10,7 +11,7 @@ const three_minutes = 1000 * 60 * 3;
|
||||
|
||||
async function abortRun(req, res) {
|
||||
res.setHeader('Content-Type', 'application/json');
|
||||
const { abortKey } = req.body;
|
||||
const { abortKey, endpoint } = req.body;
|
||||
const [conversationId, latestMessageId] = abortKey.split(':');
|
||||
const conversation = await getConvo(req.user.id, conversationId);
|
||||
|
||||
@@ -66,12 +67,19 @@ async function abortRun(req, res) {
|
||||
logger.error('[abortRun] Error fetching or processing run', error);
|
||||
}
|
||||
|
||||
/* TODO: a reconciling strategy between the existing intermediate message would be more optimal than deleting it */
|
||||
await deleteMessages({
|
||||
user: req.user.id,
|
||||
unfinished: true,
|
||||
conversationId,
|
||||
});
|
||||
runMessages = await checkMessageGaps({
|
||||
openai,
|
||||
latestMessageId,
|
||||
thread_id,
|
||||
run_id,
|
||||
endpoint,
|
||||
thread_id,
|
||||
conversationId,
|
||||
latestMessageId,
|
||||
});
|
||||
|
||||
const finalEvent = {
|
||||
|
||||
api/server/middleware/assistants/validate.js (Normal file, 44 lines)
@@ -0,0 +1,44 @@
const { v4 } = require('uuid');
const { handleAbortError } = require('~/server/middleware/abortMiddleware');

/**
 * Checks if the assistant is supported or excluded
 * @param {object} req - Express Request
 * @param {object} req.body - The request payload.
 * @param {object} res - Express Response
 * @param {function} next - Express next middleware function.
 * @returns {Promise<void>}
 */
const validateAssistant = async (req, res, next) => {
  const { endpoint, conversationId, assistant_id, messageId } = req.body;

  /** @type {Partial<TAssistantEndpoint>} */
  const assistantsConfig = req.app.locals?.[endpoint];
  if (!assistantsConfig) {
    return next();
  }

  const { supportedIds, excludedIds } = assistantsConfig;
  const error = { message: 'validateAssistant: Assistant not supported' };

  if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
    return await handleAbortError(res, req, error, {
      sender: 'System',
      conversationId,
      messageId: v4(),
      parentMessageId: messageId,
      error,
    });
  } else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
    return await handleAbortError(res, req, error, {
      sender: 'System',
      conversationId,
      messageId: v4(),
      parentMessageId: messageId,
    });
  }

  return next();
};

module.exports = validateAssistant;
|
||||
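A sketch of mounting validateAssistant ahead of a chat route, so unsupported assistants are rejected before a run is created; the route path is an assumption:

// Hypothetical route wiring (paths are assumptions, not part of this diff).
const express = require('express');
const validateAssistant = require('~/server/middleware/assistants/validate');

const router = express.Router();

router.post('/chat', validateAssistant, async (req, res) => {
  // ... chat controller runs only for allowed assistant_ids
  res.end();
});

module.exports = router;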
api/server/middleware/assistants/validateAuthor.js (Normal file, 43 lines)
@@ -0,0 +1,43 @@
const { SystemRoles } = require('librechat-data-provider');
const { getAssistant } = require('~/models/Assistant');

/**
 * Validates that the requesting user is the author of the assistant; admins bypass the check.
 * @param {object} params
 * @param {object} params.req - Express Request
 * @param {object} params.req.body - The request payload.
 * @param {string} params.overrideEndpoint - The override endpoint
 * @param {string} params.overrideAssistantId - The override assistant ID
 * @param {OpenAIClient} params.openai - OpenAI API Client
 * @returns {Promise<void>}
 */
const validateAuthor = async ({ req, openai, overrideEndpoint, overrideAssistantId }) => {
  if (req.user.role === SystemRoles.ADMIN) {
    return;
  }

  const endpoint = overrideEndpoint ?? req.body.endpoint ?? req.query.endpoint;
  const assistant_id =
    overrideAssistantId ?? req.params.id ?? req.body.assistant_id ?? req.query.assistant_id;

  /** @type {Partial<TAssistantEndpoint>} */
  const assistantsConfig = req.app.locals?.[endpoint];
  if (!assistantsConfig) {
    return;
  }

  if (!assistantsConfig.privateAssistants) {
    return;
  }

  const assistantDoc = await getAssistant({ assistant_id, user: req.user.id });
  if (assistantDoc) {
    return;
  }
  const assistant = await openai.beta.assistants.retrieve(assistant_id);
  if (req.user.id !== assistant?.metadata?.author) {
    throw new Error(`Assistant ${assistant_id} is not authored by the user.`);
  }
};

module.exports = validateAuthor;
|
||||
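The controllers above call validateAuthor before any mutation; a minimal sketch of that pattern (the surrounding handler is illustrative):

// Sketch only: handler shape is an assumption, not part of this diff.
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');

async function deleteAssistantExample(req, res, openai) {
  // throws if a non-admin user targets an assistant they did not author
  await validateAuthor({ req, openai });
  const assistant_id = req.params.id;
  return openai.beta.assistants.del(assistant_id);
}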
@@ -1,5 +1,6 @@
|
||||
const { parseConvo, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
||||
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
|
||||
const assistants = require('~/server/services/Endpoints/assistants');
|
||||
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
@@ -18,6 +19,7 @@ const buildFunction = {
|
||||
[EModelEndpoint.anthropic]: anthropic.buildOptions,
|
||||
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
|
||||
[EModelEndpoint.assistants]: assistants.buildOptions,
|
||||
[EModelEndpoint.azureAssistants]: azureAssistants.buildOptions,
|
||||
};
|
||||
|
||||
async function buildEndpointOption(req, res, next) {
|
||||
@@ -42,6 +44,15 @@ async function buildEndpointOption(req, res, next) {
|
||||
return handleError(res, { text: 'Model spec mismatch' });
|
||||
}
|
||||
|
||||
if (
|
||||
currentModelSpec.preset.endpoint !== EModelEndpoint.gptPlugins &&
|
||||
currentModelSpec.preset.tools
|
||||
) {
|
||||
return handleError(res, {
|
||||
text: `Only the "${EModelEndpoint.gptPlugins}" endpoint can have tools defined in the preset`,
|
||||
});
|
||||
}
|
||||
|
||||
const isValidModelSpec = enforceModelSpec(currentModelSpec, parsedBody);
|
||||
if (!isValidModelSpec) {
|
||||
return handleError(res, { text: 'Model spec mismatch' });
|
||||
|
||||
api/server/middleware/canDeleteAccount.js (Normal file, 28 lines)
@@ -0,0 +1,28 @@
const { SystemRoles } = require('librechat-data-provider');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');

/**
 * Checks if the user can delete their account
 *
 * @async
 * @function
 * @param {Object} req - Express request object
 * @param {Object} res - Express response object
 * @param {Function} next - Next middleware function
 *
 * @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if the user can delete their account
 */

const canDeleteAccount = async (req, res, next = () => {}) => {
  const { user } = req;
  const { ALLOW_ACCOUNT_DELETION = true } = process.env;
  if (user?.role === SystemRoles.ADMIN || isEnabled(ALLOW_ACCOUNT_DELETION)) {
    return next();
  } else {
    logger.error(`[User] [Delete Account] [User cannot delete account] [User: ${user?.id}]`);
    return res.status(403).send({ message: 'You do not have permission to delete this account' });
  }
};

module.exports = canDeleteAccount;
|
||||
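A sketch of wiring canDeleteAccount in front of an account-deletion route; the path and handler body are assumptions:

// Hypothetical wiring (not part of this diff).
const express = require('express');
const canDeleteAccount = require('~/server/middleware/canDeleteAccount');

const router = express.Router();

router.delete('/', canDeleteAccount, async (req, res) => {
  // ... delete the account only if the middleware let the request through
  res.status(200).json({ message: 'Account deleted' });
});

module.exports = router;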
@@ -1,15 +1,13 @@
const Keyv = require('keyv');
const uap = require('ua-parser-js');
const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, removePorts } = require('../utils');
const keyvRedis = require('~/cache/keyvRedis');
const { isEnabled, removePorts } = require('~/server/utils');
const keyvMongo = require('~/cache/keyvMongo');
const denyRequest = require('./denyRequest');
const { getLogStores } = require('~/cache');
const User = require('~/models/User');
const { findUser } = require('~/models');

const banCache = isEnabled(process.env.USE_REDIS)
  ? new Keyv({ store: keyvRedis })
  : new Keyv({ namespace: ViolationTypes.BAN, ttl: 0 });
const banCache = new Keyv({ store: keyvMongo, namespace: ViolationTypes.BAN, ttl: 0 });
const message = 'Your account has been temporarily banned due to violations of our service.';

/**
@@ -57,7 +55,7 @@ const checkBan = async (req, res, next = () => {}) => {
  let userId = req.user?.id ?? req.user?._id ?? null;

  if (!userId && req?.body?.email) {
    const user = await User.findOne({ email: req.body.email }, '_id').lean();
    const user = await findUser({ email: req.body.email }, '_id');
    userId = user?._id ? user._id.toString() : userId;
  }
@@ -41,10 +41,14 @@ const denyRequest = async (req, res, errorMessage) => {
  const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;

  if (shouldSaveMessage) {
    await saveMessage({ ...userMessage, user: req.user.id });
    await saveMessage(
      req,
      { ...userMessage, user: req.user.id },
      { context: `api/server/middleware/denyRequest.js - ${responseText}` },
    );
  }

  return await sendError(res, {
  return await sendError(req, res, {
    sender: getResponseSender(req.body),
    messageId: crypto.randomUUID(),
    conversationId,
@@ -24,21 +24,32 @@ const enforceModelSpec = (modelSpec, parsedBody) => {

/**
 * Checks if there is a match for the given key and value in the parsed body
 * or any of its interchangeable keys.
 * or any of its interchangeable keys, including deep comparison for objects and arrays.
 * @param {string} key
 * @param {any} value
 * @param {TConversation} parsedBody
 * @param {object} parsedBody
 * @returns {boolean}
 */
const checkMatch = (key, value, parsedBody) => {
  if (parsedBody[key] === value) {
  const isEqual = (a, b) => {
    if (Array.isArray(a) && Array.isArray(b)) {
      return a.length === b.length && a.every((val, index) => isEqual(val, b[index]));
    } else if (typeof a === 'object' && typeof b === 'object' && a !== null && b !== null) {
      const keysA = Object.keys(a);
      const keysB = Object.keys(b);
      return keysA.length === keysB.length && keysA.every((k) => isEqual(a[k], b[k]));
    }
    return a === b;
  };

  if (isEqual(parsedBody[key], value)) {
    return true;
  }

  if (interchangeableKeys.has(key)) {
    return interchangeableKeys
      .get(key)
      .some((interchangeableKey) => parsedBody[interchangeableKey] === value);
      .some((interchangeableKey) => isEqual(parsedBody[interchangeableKey], value));
  }

  return false;
api/server/middleware/enforceModelSpec.spec.js (Normal file, 47 lines)
@@ -0,0 +1,47 @@
// enforceModelSpec.test.js

const enforceModelSpec = require('./enforceModelSpec');

describe('enforceModelSpec function', () => {
  test('returns true when all model specs match parsed body directly', () => {
    const modelSpec = { preset: { title: 'Dialog', status: 'Active' } };
    const parsedBody = { title: 'Dialog', status: 'Active' };
    expect(enforceModelSpec(modelSpec, parsedBody)).toBe(true);
  });

  test('returns true when model specs match via interchangeable keys', () => {
    const modelSpec = { preset: { chatGptLabel: 'GPT-4' } };
    const parsedBody = { modelLabel: 'GPT-4' };
    expect(enforceModelSpec(modelSpec, parsedBody)).toBe(true);
  });

  test('returns false if any key value does not match', () => {
    const modelSpec = { preset: { language: 'English', level: 'Advanced' } };
    const parsedBody = { language: 'Spanish', level: 'Advanced' };
    expect(enforceModelSpec(modelSpec, parsedBody)).toBe(false);
  });

  test('ignores the \'endpoint\' key in model spec', () => {
    const modelSpec = { preset: { endpoint: 'ignored', feature: 'Special' } };
    const parsedBody = { feature: 'Special' };
    expect(enforceModelSpec(modelSpec, parsedBody)).toBe(true);
  });

  test('handles nested objects correctly', () => {
    const modelSpec = { preset: { details: { time: 'noon', location: 'park' } } };
    const parsedBody = { details: { time: 'noon', location: 'park' } };
    expect(enforceModelSpec(modelSpec, parsedBody)).toBe(true);
  });

  test('handles arrays within objects', () => {
    const modelSpec = { preset: { tags: ['urgent', 'important'] } };
    const parsedBody = { tags: ['urgent', 'important'] };
    expect(enforceModelSpec(modelSpec, parsedBody)).toBe(true);
  });

  test('fails when arrays in objects do not match', () => {
    const modelSpec = { preset: { tags: ['urgent', 'important'] } };
    const parsedBody = { tags: ['important', 'urgent'] }; // Different order
    expect(enforceModelSpec(modelSpec, parsedBody)).toBe(false);
  });
});
@@ -1,45 +1,45 @@
const abortMiddleware = require('./abortMiddleware');
const checkBan = require('./checkBan');
const checkDomainAllowed = require('./checkDomainAllowed');
const uaParser = require('./uaParser');
const setHeaders = require('./setHeaders');
const loginLimiter = require('./loginLimiter');
const validateModel = require('./validateModel');
const requireJwtAuth = require('./requireJwtAuth');
const uploadLimiters = require('./uploadLimiters');
const registerLimiter = require('./registerLimiter');
const messageLimiters = require('./messageLimiters');
const requireLocalAuth = require('./requireLocalAuth');
const validateEndpoint = require('./validateEndpoint');
const concurrentLimiter = require('./concurrentLimiter');
const validateMessageReq = require('./validateMessageReq');
const buildEndpointOption = require('./buildEndpointOption');
const validatePasswordReset = require('./validatePasswordReset');
const validateRegistration = require('./validateRegistration');
const validateImageRequest = require('./validateImageRequest');
const buildEndpointOption = require('./buildEndpointOption');
const validateMessageReq = require('./validateMessageReq');
const checkDomainAllowed = require('./checkDomainAllowed');
const concurrentLimiter = require('./concurrentLimiter');
const validateEndpoint = require('./validateEndpoint');
const requireLocalAuth = require('./requireLocalAuth');
const canDeleteAccount = require('./canDeleteAccount');
const requireLdapAuth = require('./requireLdapAuth');
const abortMiddleware = require('./abortMiddleware');
const requireJwtAuth = require('./requireJwtAuth');
const validateModel = require('./validateModel');
const moderateText = require('./moderateText');
const setHeaders = require('./setHeaders');
const limiters = require('./limiters');
const uaParser = require('./uaParser');
const checkBan = require('./checkBan');
const noIndex = require('./noIndex');
const importLimiters = require('./importLimiters');
const roles = require('./roles');

module.exports = {
  ...uploadLimiters,
  ...abortMiddleware,
  ...messageLimiters,
  ...limiters,
  ...roles,
  noIndex,
  checkBan,
  uaParser,
  setHeaders,
  loginLimiter,
  moderateText,
  validateModel,
  requireJwtAuth,
  registerLimiter,
  requireLdapAuth,
  requireLocalAuth,
  canDeleteAccount,
  validateEndpoint,
  concurrentLimiter,
  checkDomainAllowed,
  validateMessageReq,
  buildEndpointOption,
  validateRegistration,
  validateImageRequest,
  validateModel,
  moderateText,
  noIndex,
  ...importLimiters,
  checkDomainAllowed,
  validatePasswordReset,
};
api/server/middleware/limiters/index.js (Normal file, 22 lines)
@@ -0,0 +1,22 @@
const createTTSLimiters = require('./ttsLimiters');
const createSTTLimiters = require('./sttLimiters');

const loginLimiter = require('./loginLimiter');
const importLimiters = require('./importLimiters');
const uploadLimiters = require('./uploadLimiters');
const registerLimiter = require('./registerLimiter');
const messageLimiters = require('./messageLimiters');
const verifyEmailLimiter = require('./verifyEmailLimiter');
const resetPasswordLimiter = require('./resetPasswordLimiter');

module.exports = {
  ...uploadLimiters,
  ...importLimiters,
  ...messageLimiters,
  loginLimiter,
  registerLimiter,
  createTTSLimiters,
  createSTTLimiters,
  verifyEmailLimiter,
  resetPasswordLimiter,
};
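As a point of reference, a minimal sketch of how the exported limiter factories might be wired into routes; the route paths and handlers are placeholders, not part of this changeset:

const express = require('express');
const { createTTSLimiters, createSTTLimiters } = require('~/server/middleware/limiters');

const router = express.Router();

// Each factory reads its env-configured window/max values and returns an IP limiter and a user limiter.
const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();

// Hypothetical speech routes guarded by both limiters (paths assumed for illustration).
router.post('/text-to-speech', ttsIpLimiter, ttsUserLimiter, (req, res) => res.sendStatus(200));
router.post('/speech-to-text', sttIpLimiter, sttUserLimiter, (req, res) => res.sendStatus(200));

module.exports = router;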
@@ -1,6 +1,6 @@
const rateLimit = require('express-rate-limit');
const { logViolation } = require('../../cache');
const { removePorts } = require('../utils');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');

const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
const windowMs = LOGIN_WINDOW * 60 * 1000;
@@ -1,6 +1,6 @@
const rateLimit = require('express-rate-limit');
const { logViolation } = require('../../cache');
const denyRequest = require('./denyRequest');
const denyRequest = require('~/server/middleware/denyRequest');
const { logViolation } = require('~/cache');

const {
  MESSAGE_IP_MAX = 40,
@@ -1,6 +1,6 @@
const rateLimit = require('express-rate-limit');
const { logViolation } = require('../../cache');
const { removePorts } = require('../utils');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');

const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
const windowMs = REGISTER_WINDOW * 60 * 1000;
api/server/middleware/limiters/resetPasswordLimiter.js (Normal file, 35 lines)
@@ -0,0 +1,35 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');

const {
  RESET_PASSWORD_WINDOW = 2,
  RESET_PASSWORD_MAX = 2,
  RESET_PASSWORD_VIOLATION_SCORE: score,
} = process.env;
const windowMs = RESET_PASSWORD_WINDOW * 60 * 1000;
const max = RESET_PASSWORD_MAX;
const windowInMinutes = windowMs / 60000;
const message = `Too many attempts, please try again after ${windowInMinutes} minute(s)`;

const handler = async (req, res) => {
  const type = ViolationTypes.RESET_PASSWORD_LIMIT;
  const errorMessage = {
    type,
    max,
    windowInMinutes,
  };

  await logViolation(req, res, type, errorMessage, score);
  return res.status(429).json({ message });
};

const resetPasswordLimiter = rateLimit({
  windowMs,
  max,
  handler,
  keyGenerator: removePorts,
});

module.exports = resetPasswordLimiter;
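As a quick orientation, a minimal sketch of mounting this limiter on a password-reset route; the path and handler body are placeholders, not part of this changeset:

const express = require('express');
const { resetPasswordLimiter } = require('~/server/middleware/limiters');

const router = express.Router();

// Hypothetical reset-request route: the limiter keys on the client IP (removePorts) and
// allows RESET_PASSWORD_MAX requests per RESET_PASSWORD_WINDOW minutes before returning 429.
router.post('/requestPasswordReset', resetPasswordLimiter, (req, res) => {
  // Reset-email logic would go here (assumed for illustration).
  res.status(200).json({ message: 'If the account exists, a reset link was sent.' });
});

module.exports = router;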
api/server/middleware/limiters/sttLimiters.js (Normal file, 68 lines)
@@ -0,0 +1,68 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');

const getEnvironmentVariables = () => {
  const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
  const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
  const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
  const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;

  const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
  const sttIpMax = STT_IP_MAX;
  const sttIpWindowInMinutes = sttIpWindowMs / 60000;

  const sttUserWindowMs = STT_USER_WINDOW * 60 * 1000;
  const sttUserMax = STT_USER_MAX;
  const sttUserWindowInMinutes = sttUserWindowMs / 60000;

  return {
    sttIpWindowMs,
    sttIpMax,
    sttIpWindowInMinutes,
    sttUserWindowMs,
    sttUserMax,
    sttUserWindowInMinutes,
  };
};

const createSTTHandler = (ip = true) => {
  const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
    getEnvironmentVariables();

  return async (req, res) => {
    const type = ViolationTypes.STT_LIMIT;
    const errorMessage = {
      type,
      max: ip ? sttIpMax : sttUserMax,
      limiter: ip ? 'ip' : 'user',
      windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
    };

    await logViolation(req, res, type, errorMessage);
    res.status(429).json({ message: 'Too many STT requests. Try again later' });
  };
};

const createSTTLimiters = () => {
  const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables();

  const sttIpLimiter = rateLimit({
    windowMs: sttIpWindowMs,
    max: sttIpMax,
    handler: createSTTHandler(),
  });

  const sttUserLimiter = rateLimit({
    windowMs: sttUserWindowMs,
    max: sttUserMax,
    handler: createSTTHandler(false),
    keyGenerator: function (req) {
      return req.user?.id; // Use the user ID or NULL if not available
    },
  });

  return { sttIpLimiter, sttUserLimiter };
};

module.exports = createSTTLimiters;
api/server/middleware/limiters/ttsLimiters.js (Normal file, 68 lines)
@@ -0,0 +1,68 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');

const getEnvironmentVariables = () => {
  const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
  const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
  const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
  const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;

  const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
  const ttsIpMax = TTS_IP_MAX;
  const ttsIpWindowInMinutes = ttsIpWindowMs / 60000;

  const ttsUserWindowMs = TTS_USER_WINDOW * 60 * 1000;
  const ttsUserMax = TTS_USER_MAX;
  const ttsUserWindowInMinutes = ttsUserWindowMs / 60000;

  return {
    ttsIpWindowMs,
    ttsIpMax,
    ttsIpWindowInMinutes,
    ttsUserWindowMs,
    ttsUserMax,
    ttsUserWindowInMinutes,
  };
};

const createTTSHandler = (ip = true) => {
  const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
    getEnvironmentVariables();

  return async (req, res) => {
    const type = ViolationTypes.TTS_LIMIT;
    const errorMessage = {
      type,
      max: ip ? ttsIpMax : ttsUserMax,
      limiter: ip ? 'ip' : 'user',
      windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
    };

    await logViolation(req, res, type, errorMessage);
    res.status(429).json({ message: 'Too many TTS requests. Try again later' });
  };
};

const createTTSLimiters = () => {
  const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables();

  const ttsIpLimiter = rateLimit({
    windowMs: ttsIpWindowMs,
    max: ttsIpMax,
    handler: createTTSHandler(),
  });

  const ttsUserLimiter = rateLimit({
    windowMs: ttsUserWindowMs,
    max: ttsUserMax,
    handler: createTTSHandler(false),
    keyGenerator: function (req) {
      return req.user?.id; // Use the user ID or NULL if not available
    },
  });

  return { ttsIpLimiter, ttsUserLimiter };
};

module.exports = createTTSLimiters;
api/server/middleware/limiters/verifyEmailLimiter.js (Normal file, 35 lines)
@@ -0,0 +1,35 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');

const {
  VERIFY_EMAIL_WINDOW = 2,
  VERIFY_EMAIL_MAX = 2,
  VERIFY_EMAIL_VIOLATION_SCORE: score,
} = process.env;
const windowMs = VERIFY_EMAIL_WINDOW * 60 * 1000;
const max = VERIFY_EMAIL_MAX;
const windowInMinutes = windowMs / 60000;
const message = `Too many attempts, please try again after ${windowInMinutes} minute(s)`;

const handler = async (req, res) => {
  const type = ViolationTypes.VERIFY_EMAIL_LIMIT;
  const errorMessage = {
    type,
    max,
    windowInMinutes,
  };

  await logViolation(req, res, type, errorMessage, score);
  return res.status(429).json({ message });
};

const verifyEmailLimiter = rateLimit({
  windowMs,
  max,
  handler,
  keyGenerator: removePorts,
});

module.exports = verifyEmailLimiter;
api/server/middleware/requireLdapAuth.js (Normal file, 22 lines)
@@ -0,0 +1,22 @@
const passport = require('passport');

const requireLdapAuth = (req, res, next) => {
  passport.authenticate('ldapauth', (err, user, info) => {
    if (err) {
      console.log({
        title: '(requireLdapAuth) Error at passport.authenticate',
        parameters: [{ name: 'error', value: err }],
      });
      return next(err);
    }
    if (!user) {
      console.log({
        title: '(requireLdapAuth) Error: No user',
      });
      return res.status(404).send(info);
    }
    req.user = user;
    next();
  })(req, res, next);
};
module.exports = requireLdapAuth;
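For orientation, a minimal sketch of mounting this middleware on a login route, assuming an LDAP passport strategy named 'ldapauth' has been registered elsewhere; the route path and handler body are placeholders:

const express = require('express');
const requireLdapAuth = require('~/server/middleware/requireLdapAuth');

const router = express.Router();

// Hypothetical login route: requireLdapAuth runs the 'ldapauth' strategy with a custom
// callback, attaches the authenticated user to req.user, then falls through to the handler.
router.post('/login', requireLdapAuth, (req, res) => {
  // Token issuance would go here (assumed for illustration).
  res.status(200).json({ user: req.user });
});

module.exports = router;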
@@ -21,7 +21,13 @@ const requireLocalAuth = (req, res, next) => {
      log({
        title: '(requireLocalAuth) Error: No user',
      });
      return res.status(422).send(info);
      return res.status(404).send(info);
    }
    if (info && info.message) {
      log({
        title: '(requireLocalAuth) Error: ' + info.message,
      });
      return res.status(422).send({ message: info.message });
    }
    req.user = user;
    next();
api/server/middleware/roles/checkAdmin.js (Normal file, 14 lines)
@@ -0,0 +1,14 @@
const { SystemRoles } = require('librechat-data-provider');

function checkAdmin(req, res, next) {
  try {
    if (req.user.role !== SystemRoles.ADMIN) {
      return res.status(403).json({ message: 'Forbidden' });
    }
    next();
  } catch (error) {
    res.status(500).json({ message: 'Internal Server Error' });
  }
}

module.exports = checkAdmin;
api/server/middleware/roles/generateCheckAccess.js (Normal file, 52 lines)
@@ -0,0 +1,52 @@
const { SystemRoles } = require('librechat-data-provider');
const { getRoleByName } = require('~/models/Role');

/**
 * Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
 *
 * @param {PermissionTypes} permissionType - The type of permission to check.
 * @param {Permissions[]} permissions - The list of specific permissions to check.
 * @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
 * @returns {Function} Express middleware function.
 */
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
  return async (req, res, next) => {
    try {
      const { user } = req;
      if (!user) {
        return res.status(401).json({ message: 'Authorization required' });
      }

      if (user.role === SystemRoles.ADMIN) {
        return next();
      }

      const role = await getRoleByName(user.role);
      if (role && role[permissionType]) {
        const hasAnyPermission = permissions.some((permission) => {
          if (role[permissionType][permission]) {
            return true;
          }

          if (bodyProps[permission] && req.body) {
            return bodyProps[permission].some((prop) =>
              Object.prototype.hasOwnProperty.call(req.body, prop),
            );
          }

          return false;
        });

        if (hasAnyPermission) {
          return next();
        }
      }

      return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
    } catch (error) {
      return res.status(500).json({ message: `Server error: ${error.message}` });
    }
  };
};

module.exports = generateCheckAccess;
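For context, a minimal sketch of how such a factory could guard a route; the permission type, permission names, body property, and route are illustrative placeholders, not values taken from this changeset:

const express = require('express');
const generateCheckAccess = require('~/server/middleware/roles/generateCheckAccess');

const router = express.Router();

// Hypothetical guard: the request proceeds if the user's role grants any of the listed
// permissions under the given permission type, or if the request body contains a property
// mapped to one of those permissions (all names assumed for illustration).
const checkPromptAccess = generateCheckAccess('PROMPTS', ['USE', 'CREATE'], { CREATE: ['promptText'] });

router.post('/prompts', checkPromptAccess, (req, res) => {
  res.status(200).json({ message: 'Prompt accepted' });
});

module.exports = router;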
Some files were not shown because too many files have changed in this diff.