Compare commits
101 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
adbeb46399 | ||
|
|
9ad47b6660 | ||
|
|
8b28fdf240 | ||
|
|
1ec8e53db8 | ||
|
|
405be4b408 | ||
|
|
b171369aa6 | ||
|
|
ddb42b23cb | ||
|
|
037ea8cc0b | ||
|
|
e383ecba85 | ||
|
|
c7205c9bb2 | ||
|
|
25402fd208 | ||
|
|
216f6da79e | ||
|
|
cbfe47a9d5 | ||
|
|
e5e04c1cb8 | ||
|
|
5d95433c83 | ||
|
|
9ca84edb9a | ||
|
|
d5259e1525 | ||
|
|
9d100ec0fc | ||
|
|
efe057e0d8 | ||
|
|
5ab9802aa9 | ||
|
|
ed3d7c9f80 | ||
|
|
9d565ec8a5 | ||
|
|
43d7a751d6 | ||
|
|
4f3b66756a | ||
|
|
3a38b4b842 | ||
|
|
48c087cc06 | ||
|
|
4b63eb5a2c | ||
|
|
5f3ecef575 | ||
|
|
a2ee57568a | ||
|
|
0886441461 | ||
|
|
a7b5639da1 | ||
|
|
34148885b7 | ||
|
|
c11fbde9a7 | ||
|
|
9a31df026d | ||
|
|
b031dea127 | ||
|
|
9f5d77eeb0 | ||
|
|
8f328ec6a3 | ||
|
|
af69763103 | ||
|
|
5c1e44eff7 | ||
|
|
7b30ab3a41 | ||
|
|
2017ec5693 | ||
|
|
c878289adc | ||
|
|
5cafe0900c | ||
|
|
81a90d245b | ||
|
|
ba5ab86037 | ||
|
|
11dd3b487f | ||
|
|
bc39bd12a5 | ||
|
|
05c4c7e551 | ||
|
|
4ce585f77d | ||
|
|
c7bfb2ab40 | ||
|
|
3d4a8778d5 | ||
|
|
00e0091f7a | ||
|
|
70590251d1 | ||
|
|
4073b7d05d | ||
|
|
7d6a1d260f | ||
|
|
6cb561abcf | ||
|
|
abbc57a49a | ||
|
|
fd99bac121 | ||
|
|
ddf56db316 | ||
|
|
377f2c7c19 | ||
|
|
6d8aed7ef8 | ||
|
|
352e01f9d0 | ||
|
|
b23166d6be | ||
|
|
9f201577ef | ||
|
|
0450c34e3b | ||
|
|
a53ccf0d72 | ||
|
|
b1a96ecedc | ||
|
|
cff45df0ef | ||
|
|
494ab01cb4 | ||
|
|
241bc68d0f | ||
|
|
e7e473d335 | ||
|
|
909cbb8529 | ||
|
|
5145121eb7 | ||
|
|
4ac0c04e83 | ||
|
|
bc7a079208 | ||
|
|
f63fe4b4e0 | ||
|
|
495ac1b36d | ||
|
|
b3aac97710 | ||
|
|
2dd545eaa4 | ||
|
|
d61e44742d | ||
|
|
e7ca40b5ab | ||
|
|
c0e2c58c03 | ||
|
|
09c03b9df0 | ||
|
|
599d70f1de | ||
|
|
ce966419f7 | ||
|
|
365c39c405 | ||
|
|
be71a1947b | ||
|
|
3137f467a8 | ||
|
|
317a1bd8da | ||
|
|
be73deddcc | ||
|
|
6c16e910e7 | ||
|
|
7abc5bc670 | ||
|
|
1bf6c259b9 | ||
|
|
7c0379ba51 | ||
|
|
5d4b168df5 | ||
|
|
33b0154602 | ||
|
|
d87754c43d | ||
|
|
1a77fb4fd5 | ||
|
|
1be6c4830a | ||
|
|
1d3e336e1c | ||
|
|
d13a7b1a74 |
106
.env.example
106
.env.example
@@ -4,6 +4,10 @@
|
||||
|
||||
APP_TITLE=LibreChat
|
||||
|
||||
# Uncomment to add a custom footer.
|
||||
# Uncomment and make empty "" to remove the footer.
|
||||
# CUSTOM_FOOTER="My custom footer"
|
||||
|
||||
# The server will listen to localhost:3080 by default. You can change the target IP as you want.
|
||||
# If you want to make this server available externally, for example to share the server with others
|
||||
# or expose this from a Docker container, set host to 0.0.0.0 or your external IP interface.
|
||||
@@ -13,6 +17,21 @@ APP_TITLE=LibreChat
|
||||
HOST=localhost
|
||||
PORT=3080
|
||||
|
||||
# Note: the following enables user balances, which you can add manually
|
||||
# or you will need to build out a balance accruing system for users.
|
||||
# For more info, see https://docs.librechat.ai/features/token_usage.html
|
||||
|
||||
# To manually add balances, run the following command:
|
||||
# `npm run add-balance`
|
||||
|
||||
# You can also specify the email and token credit amount to add, e.g.:
|
||||
# `npm run add-balance example@example.com 1000`
|
||||
|
||||
# This works well to track your own usage for personal use; 1000 credits = $0.001 (1 mill USD)
|
||||
|
||||
# Set to true to enable token credit balances for the OpenAI/Plugins endpoints
|
||||
CHECK_BALANCE=false
|
||||
|
||||
# Automated Moderation System
|
||||
# The Automated Moderation System uses a scoring mechanism to track user violations. As users commit actions
|
||||
# like excessive logins, registrations, or messaging, they accumulate violation scores. Upon reaching
|
||||
@@ -52,6 +71,11 @@ LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can sen
|
||||
MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW
|
||||
MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages
|
||||
|
||||
# If you have permission problems, set here the UID and GID of the user running
|
||||
# the docker compose command. The applications in the container will run with these uid/gid.
|
||||
# UID=1000
|
||||
# GID=1000
|
||||
|
||||
# Change this to proxy any API request.
|
||||
# It's useful if your machine has difficulty calling the original API server.
|
||||
# PROXY=
|
||||
@@ -68,15 +92,72 @@ MONGO_URI=mongodb://127.0.0.1:27018/LibreChat
|
||||
# Set to "user_provided" to allow the user to provide their API key from the UI.
|
||||
OPENAI_API_KEY=user_provided
|
||||
|
||||
DEBUG_OPENAI=false # Set to true to enable debug mode for the OpenAI endpoint
|
||||
|
||||
# Identify the available models, separated by commas *without spaces*.
|
||||
# The first will be default.
|
||||
# Leave it blank to use internal settings.
|
||||
# OPENAI_MODELS=gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-0301,text-davinci-003,gpt-4,gpt-4-0314,gpt-4-0613
|
||||
# OPENAI_MODELS=gpt-3.5-turbo-1106,gpt-4-1106-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-0301,text-davinci-003,gpt-4,gpt-4-0314,gpt-4-0613
|
||||
|
||||
# Titling is enabled by default when initiating a conversation.
|
||||
# Uncomment the following variable to disable this feature.
|
||||
# TITLE_CONVO=false
|
||||
|
||||
# (Optional) The default model used for titling by is gpt-3.5-turbo-0613
|
||||
# You can change it by uncommenting the following and setting the desired model
|
||||
# Must be compatible with the OpenAI Endpoint.
|
||||
# OPENAI_TITLE_MODEL=gpt-3.5-turbo
|
||||
|
||||
# (Optional/Experimental) Enable message summarization by uncommenting the following:
|
||||
# Note: this may affect response time when a summary is being generated.
|
||||
# OPENAI_SUMMARIZE=true
|
||||
|
||||
# Not yet implemented: this will be a conversation option enabled by default to save users on tokens
|
||||
# We are using the ConversationSummaryBufferMemory method to summarize messages.
|
||||
# To learn more about this, see this article:
|
||||
# https://www.pinecone.io/learn/series/langchain/langchain-conversational-memory/
|
||||
|
||||
# (Optional) The default model used for summarizing is gpt-3.5-turbo
|
||||
# You can change it by uncommenting the following and setting the desired model
|
||||
# Must be compatible with the OpenAI Endpoint.
|
||||
# OPENAI_SUMMARY_MODEL=gpt-3.5-turbo
|
||||
|
||||
# Reverse proxy settings for OpenAI:
|
||||
# https://github.com/waylaidwanderer/node-chatgpt-api#using-a-reverse-proxy
|
||||
# https://github.com/waylaidwanderer/node-chatgpt-api#using-a-reverse-proxy
|
||||
# The URL must match the "url/v1," pattern, the "openai" suffix is also allowed.
|
||||
# Examples:
|
||||
# - https://open.ai/v1
|
||||
# - https://open.ai/v1/ACCOUNT/GATEWAY/openai
|
||||
# - https://open.ai/v1/hi/openai
|
||||
|
||||
# OPENAI_REVERSE_PROXY=
|
||||
|
||||
# (Advanced) Sometimes when using Local LLM APIs, you may need to force the API
|
||||
# to be called with a `prompt` payload instead of a `messages` payload; to mimic the
|
||||
# a `/v1/completions` request instead of `/v1/chat/completions`
|
||||
# This may be the case for LocalAI with some models. To do so, uncomment the following:
|
||||
# OPENAI_FORCE_PROMPT=true
|
||||
|
||||
# (Advanced) For customization of the DALL-E-3 System prompt,
|
||||
# uncomment the following, and provide your own prompt:
|
||||
# See official prompt for reference:
|
||||
# https://github.com/spdustin/ChatGPT-AutoExpert/blob/main/_system-prompts/dall-e.md
|
||||
# DALLE3_SYSTEM_PROMPT="Your System Prompt here"
|
||||
|
||||
# (Advanced) DALL-E Proxy settings
|
||||
# This is separate from its OpenAI counterpart for customization purposes
|
||||
|
||||
# Reverse proxy settings, changes the baseURL for the DALL-E-3 API Calls
|
||||
# The URL must match the "url/v1," pattern, the "openai" suffix is also allowed.
|
||||
# Examples:
|
||||
# - https://open.ai/v1
|
||||
# - https://open.ai/v1/ACCOUNT/GATEWAY/openai
|
||||
# - https://open.ai/v1/hi/openai
|
||||
|
||||
# DALLE_REVERSE_PROXY=
|
||||
|
||||
# Note: if you have PROXY set, it will be used for DALLE calls also, which is universal for the app
|
||||
|
||||
##########################
|
||||
# OpenRouter (overrides OpenAI and Plugins Endpoints):
|
||||
##########################
|
||||
@@ -108,11 +189,23 @@ OPENAI_API_KEY=user_provided
|
||||
# AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME=
|
||||
# AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
|
||||
# Identify the available models, separated by commas *without spaces*.
|
||||
# The first will be default.
|
||||
# Leave it blank to use internal settings.
|
||||
# NOTE: As of 2023-11-10, the Azure API only allows one model per deployment,
|
||||
# It's recommended to name your deployments after the model name, e.g. "gpt-35-turbo,"
|
||||
# which allows for fast deployment switching and AZURE_USE_MODEL_AS_DEPLOYMENT_NAME enabled.
|
||||
# However, you can use non-model deployment names and setting the AZURE_OPENAI_DEFAULT_MODEL to ensure it works as expected.
|
||||
|
||||
# Identify the available models, separated by commas *without spaces*.
|
||||
# The first will be default. Leave it blank or as is to use internal settings.
|
||||
# NOTE: as deployment names can't have periods, they will be removed when the endpoint is generated.
|
||||
AZURE_OPENAI_MODELS=gpt-3.5-turbo,gpt-4
|
||||
|
||||
# (Advanced) this enables the use of the model name as the deployment name, e.g. "gpt-3.5-turbo" as the deployment name
|
||||
AZURE_USE_MODEL_AS_DEPLOYMENT_NAME=TRUE
|
||||
|
||||
# (Advanced) this overrides the model setting for Azure, in case you want to use your custom deployment names
|
||||
# as the values for AZURE_OPENAI_MODELS
|
||||
# AZURE_OPENAI_DEFAULT_MODEL=gpt-3.5-turbo
|
||||
|
||||
# To use Azure with the Plugins endpoint, you need the variables above, and uncomment the following variable:
|
||||
# NOTE: This may not work as expected and Azure OpenAI may not support OpenAI Functions yet
|
||||
# Omit/leave it commented to use the default OpenAI API
|
||||
@@ -347,7 +440,8 @@ DOMAIN_SERVER=http://localhost:3080
|
||||
###########################
|
||||
|
||||
# Email is used for password reset. Note that all 4 values must be set for email to work.
|
||||
# Failing to set the 4 values will result in LibreChat using the unsecured password reset!
|
||||
EMAIL_SERVICE= # eg. gmail
|
||||
EMAIL_USERNAME= # eg. your email address if using gmail
|
||||
EMAIL_PASSWORD= # eg. this is the "app password" if using gmail
|
||||
EMAIL_FROM= # eg. email address for from field like noreply@librechat.ai
|
||||
EMAIL_FROM=noreply@librechat.ai # email address for from field, it is required to set a value here even in the cases where it's not porperly working.
|
||||
|
||||
@@ -61,6 +61,7 @@ module.exports = {
|
||||
'no-restricted-syntax': 'off',
|
||||
'react/prop-types': ['off'],
|
||||
'react/display-name': ['off'],
|
||||
'no-unused-vars': ['error', { varsIgnorePattern: '^_' }],
|
||||
quotes: ['error', 'single'],
|
||||
},
|
||||
overrides: [
|
||||
|
||||
2
.github/CONTRIBUTING.md
vendored
2
.github/CONTRIBUTING.md
vendored
@@ -33,7 +33,7 @@ Project maintainers have the right and responsibility to remove, edit, or reject
|
||||
4. Clear web app localStorage and cookies before and after changes.
|
||||
5. For frontend changes:
|
||||
- Install typescript globally: `npm i -g typescript`.
|
||||
- Compile typescript before and after changes to check for introduced errors: `tsc --noEmit`.
|
||||
- Compile typescript before and after changes to check for introduced errors: `cd client && tsc --noEmit`.
|
||||
6. Run tests locally:
|
||||
- Backend unit tests: `npm run test:api`
|
||||
- Frontend unit tests: `npm run test:client`
|
||||
|
||||
2
.github/workflows/backend-review.yml
vendored
2
.github/workflows/backend-review.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
BAN_VIOLATIONS: ${{ secrets.BAN_VIOLATIONS }}
|
||||
BAN_DURATION: ${{ secrets.BAN_DURATION }}
|
||||
BAN_INTERVAL: ${{ secrets.BAN_INTERVAL }}
|
||||
NODE_ENV: ci
|
||||
NODE_ENV: CI
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Use Node.js 20.x
|
||||
|
||||
3
.github/workflows/frontend-review.yml
vendored
3
.github/workflows/frontend-review.yml
vendored
@@ -34,4 +34,5 @@ jobs:
|
||||
run: npm run frontend:ci
|
||||
|
||||
- name: Run unit tests
|
||||
run: cd client && npm run test:ci
|
||||
run: npm run test:ci --verbose
|
||||
working-directory: client
|
||||
40
.github/workflows/latest-images-main.yml
vendored
Normal file
40
.github/workflows/latest-images-main.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
name: Docker Compose Build on Main Branch
|
||||
|
||||
on:
|
||||
workflow_dispatch: # This line allows manual triggering
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Check out the repository
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
# Set up Docker
|
||||
- name: Set up Docker
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
# Log in to GitHub Container Registry
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# Run docker-compose build
|
||||
- name: Build Docker images
|
||||
run: |
|
||||
cp .env.example .env
|
||||
docker-compose build
|
||||
docker build -f Dockerfile.multi --target api-build -t librechat-api .
|
||||
|
||||
# Tag and push the images with the 'latest' tag
|
||||
- name: Tag image and push
|
||||
run: |
|
||||
docker tag librechat:latest ghcr.io/${{ github.repository_owner }}/librechat:latest
|
||||
docker push ghcr.io/${{ github.repository_owner }}/librechat:latest
|
||||
docker tag librechat-api:latest ghcr.io/${{ github.repository_owner }}/librechat-api:latest
|
||||
docker push ghcr.io/${{ github.repository_owner }}/librechat-api:latest
|
||||
3
.github/workflows/playwright.yml
vendored
3
.github/workflows/playwright.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
timeout-minutes: 60
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
NODE_ENV: ci
|
||||
NODE_ENV: CI
|
||||
CI: true
|
||||
SEARCH: false
|
||||
BINGAI_TOKEN: user_provided
|
||||
@@ -34,6 +34,7 @@ jobs:
|
||||
DOMAIN_SERVER: ${{ secrets.DOMAIN_SERVER }}
|
||||
PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD: 1 # Skip downloading during npm install
|
||||
PLAYWRIGHT_BROWSERS_PATH: 0 # Places binaries to node_modules/@playwright/test
|
||||
TITLE_CONVO: false
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-node@v3
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# MIT License
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Danny Avila
|
||||
|
||||
---
|
||||
Copyright (c) 2023 LibreChat
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
@@ -14,8 +12,6 @@ furnished to do so, subject to the following conditions:
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
##
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
@@ -23,7 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
---
|
||||
|
||||
## [Go Back to ReadMe](../README.md)
|
||||
@@ -100,6 +100,7 @@ Keep up with the latest updates by visiting the releases page - [Releases](https
|
||||
* [Third-Party Tools](docs/features/third_party.md)
|
||||
* [Proxy](docs/features/proxy.md)
|
||||
* [Bing Jailbreak](docs/features/bing_jailbreak.md)
|
||||
* [Token Usage](docs/features/token_usage.md)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -112,6 +113,7 @@ Keep up with the latest updates by visiting the releases page - [Releases](https
|
||||
* [Ngrok](docs/deployment/ngrok.md)
|
||||
* [HuggingFace](docs/deployment/huggingface.md)
|
||||
* [Render](docs/deployment/render.md)
|
||||
* [Meilisearch in Render](docs/deployment/meilisearch_in_render.md)
|
||||
* [Hetzner](docs/deployment/hetzner_ubuntu.md)
|
||||
* [Heroku](docs/deployment/heroku.md)
|
||||
</details>
|
||||
@@ -133,7 +135,9 @@ Keep up with the latest updates by visiting the releases page - [Releases](https
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#danny-avila/LibreChat&Date)
|
||||
<a href="https://star-history.com/#danny-avila/LibreChat&Date">
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=danny-avila/LibreChat&type=Date&theme=dark" onerror="this.src='https://api.star-history.com/svg?repos=danny-avila/LibreChat&type=Date'" />
|
||||
</a>
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ const askBing = async ({
|
||||
key = await getUserKey({ userId, name: 'bingAI' });
|
||||
}
|
||||
|
||||
const { BingAIClient } = await import('@waylaidwanderer/chatgpt-api');
|
||||
const { BingAIClient } = await import('nodejs-gpt');
|
||||
const store = {
|
||||
store: new KeyvFile({ filename: './data/cache.json' }),
|
||||
};
|
||||
@@ -94,7 +94,7 @@ const askBing = async ({
|
||||
// don't give those parameters for new conversation
|
||||
// for new conversation, conversationSignature always is null
|
||||
if (conversationSignature) {
|
||||
options.conversationSignature = conversationSignature;
|
||||
options.encryptedConversationSignature = conversationSignature;
|
||||
options.clientId = clientId;
|
||||
options.invocationId = invocationId;
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ const browserClient = async ({
|
||||
key = await getUserKey({ userId, name: 'chatGPTBrowser' });
|
||||
}
|
||||
|
||||
const { ChatGPTBrowserClient } = await import('@waylaidwanderer/chatgpt-api');
|
||||
const { ChatGPTBrowserClient } = await import('nodejs-gpt');
|
||||
const store = {
|
||||
store: new KeyvFile({ filename: './data/cache.json' }),
|
||||
};
|
||||
|
||||
@@ -91,7 +91,10 @@ class AnthropicClient extends BaseClient {
|
||||
}
|
||||
|
||||
async buildMessages(messages, parentMessageId) {
|
||||
const orderedMessages = this.constructor.getMessagesForConversation(messages, parentMessageId);
|
||||
const orderedMessages = this.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId,
|
||||
});
|
||||
if (this.options.debug) {
|
||||
console.debug('AnthropicClient: orderedMessages', orderedMessages, parentMessageId);
|
||||
}
|
||||
@@ -239,7 +242,6 @@ class AnthropicClient extends BaseClient {
|
||||
console.log('AnthropicClient doesn\'t use getCompletion (all handled in sendCompletion)');
|
||||
}
|
||||
|
||||
// TODO: implement abortController usage
|
||||
async sendCompletion(payload, { onProgress, abortController }) {
|
||||
if (!abortController) {
|
||||
abortController = new AbortController();
|
||||
@@ -317,14 +319,6 @@ class AnthropicClient extends BaseClient {
|
||||
return text.trim();
|
||||
}
|
||||
|
||||
// I commented this out because I will need to refactor this for the BaseClient/all clients
|
||||
// getMessageMapMethod() {
|
||||
// return ((message) => ({
|
||||
// author: message.isCreatedByUser ? this.userLabel : this.assistantLabel,
|
||||
// content: message?.content ?? message.text
|
||||
// })).bind(this);
|
||||
// }
|
||||
|
||||
getSaveOptions() {
|
||||
return {
|
||||
promptPrefix: this.options.promptPrefix,
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
const crypto = require('crypto');
|
||||
const TextStream = require('./TextStream');
|
||||
const { RecursiveCharacterTextSplitter } = require('langchain/text_splitter');
|
||||
const { ChatOpenAI } = require('langchain/chat_models/openai');
|
||||
const { loadSummarizationChain } = require('langchain/chains');
|
||||
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('../../models');
|
||||
const { addSpaceIfNeeded } = require('../../server/utils');
|
||||
const { refinePrompt } = require('./prompts');
|
||||
const { addSpaceIfNeeded, isEnabled } = require('../../server/utils');
|
||||
const checkBalance = require('../../models/checkBalance');
|
||||
|
||||
class BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
@@ -39,6 +36,22 @@ class BaseClient {
|
||||
throw new Error('Subclasses must implement buildMessages');
|
||||
}
|
||||
|
||||
async summarizeMessages() {
|
||||
throw new Error('Subclasses attempted to call summarizeMessages without implementing it');
|
||||
}
|
||||
|
||||
async getTokenCountForResponse(response) {
|
||||
if (this.options.debug) {
|
||||
console.debug('`recordTokenUsage` not implemented.', response);
|
||||
}
|
||||
}
|
||||
|
||||
async recordTokenUsage({ promptTokens, completionTokens }) {
|
||||
if (this.options.debug) {
|
||||
console.debug('`recordTokenUsage` not implemented.', { promptTokens, completionTokens });
|
||||
}
|
||||
}
|
||||
|
||||
getBuildMessagesOptions() {
|
||||
throw new Error('Subclasses must implement getBuildMessagesOptions');
|
||||
}
|
||||
@@ -64,6 +77,7 @@ class BaseClient {
|
||||
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
|
||||
let head = isEdited ? responseMessageId : parentMessageId;
|
||||
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
|
||||
this.conversationId = conversationId;
|
||||
|
||||
if (isEdited && !isContinued) {
|
||||
responseMessageId = crypto.randomUUID();
|
||||
@@ -114,8 +128,8 @@ class BaseClient {
|
||||
text: message,
|
||||
});
|
||||
|
||||
if (typeof opts?.getIds === 'function') {
|
||||
opts.getIds({
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessage,
|
||||
conversationId,
|
||||
responseMessageId,
|
||||
@@ -137,9 +151,18 @@ class BaseClient {
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds instructions to the messages array. If the instructions object is empty or undefined,
|
||||
* the original messages array is returned. Otherwise, the instructions are added to the messages
|
||||
* array, preserving the last message at the end.
|
||||
*
|
||||
* @param {Array} messages - An array of messages.
|
||||
* @param {Object} instructions - An object containing instructions to be added to the messages.
|
||||
* @returns {Array} An array containing messages and instructions, or the original messages if instructions are empty.
|
||||
*/
|
||||
addInstructions(messages, instructions) {
|
||||
const payload = [];
|
||||
if (!instructions) {
|
||||
if (!instructions || Object.keys(instructions).length === 0) {
|
||||
return messages;
|
||||
}
|
||||
if (messages.length > 1) {
|
||||
@@ -170,19 +193,15 @@ class BaseClient {
|
||||
const { messageId } = message;
|
||||
const update = {};
|
||||
|
||||
if (messageId === tokenCountMap.refined?.messageId) {
|
||||
if (this.options.debug) {
|
||||
console.debug(`Adding refined props to ${messageId}.`);
|
||||
}
|
||||
if (messageId === tokenCountMap.summaryMessage?.messageId) {
|
||||
this.options.debug && console.debug(`Adding summary props to ${messageId}.`);
|
||||
|
||||
update.refinedMessageText = tokenCountMap.refined.content;
|
||||
update.refinedTokenCount = tokenCountMap.refined.tokenCount;
|
||||
update.summary = tokenCountMap.summaryMessage.content;
|
||||
update.summaryTokenCount = tokenCountMap.summaryMessage.tokenCount;
|
||||
}
|
||||
|
||||
if (message.tokenCount && !update.refinedTokenCount) {
|
||||
if (this.options.debug) {
|
||||
console.debug(`Skipping ${messageId}: already had a token count.`);
|
||||
}
|
||||
if (message.tokenCount && !update.summaryTokenCount) {
|
||||
this.options.debug && console.debug(`Skipping ${messageId}: already had a token count.`);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -202,193 +221,141 @@ class BaseClient {
|
||||
}, '');
|
||||
}
|
||||
|
||||
async refineMessages(messagesToRefine, remainingContextTokens) {
|
||||
const model = new ChatOpenAI({ temperature: 0 });
|
||||
const chain = loadSummarizationChain(model, {
|
||||
type: 'refine',
|
||||
verbose: this.options.debug,
|
||||
refinePrompt,
|
||||
});
|
||||
const splitter = new RecursiveCharacterTextSplitter({
|
||||
chunkSize: 1500,
|
||||
chunkOverlap: 100,
|
||||
});
|
||||
const userMessages = this.concatenateMessages(
|
||||
messagesToRefine.filter((m) => m.role === 'user'),
|
||||
);
|
||||
const assistantMessages = this.concatenateMessages(
|
||||
messagesToRefine.filter((m) => m.role !== 'user'),
|
||||
);
|
||||
const userDocs = await splitter.createDocuments([userMessages], [], {
|
||||
chunkHeader: 'DOCUMENT NAME: User Message\n\n---\n\n',
|
||||
appendChunkOverlapHeader: true,
|
||||
});
|
||||
const assistantDocs = await splitter.createDocuments([assistantMessages], [], {
|
||||
chunkHeader: 'DOCUMENT NAME: Assistant Message\n\n---\n\n',
|
||||
appendChunkOverlapHeader: true,
|
||||
});
|
||||
// const chunkSize = Math.round(concatenatedMessages.length / 512);
|
||||
const input_documents = userDocs.concat(assistantDocs);
|
||||
if (this.options.debug) {
|
||||
console.debug('Refining messages...');
|
||||
}
|
||||
try {
|
||||
const res = await chain.call({
|
||||
input_documents,
|
||||
signal: this.abortController.signal,
|
||||
});
|
||||
|
||||
const refinedMessage = {
|
||||
role: 'assistant',
|
||||
content: res.output_text,
|
||||
tokenCount: this.getTokenCount(res.output_text),
|
||||
};
|
||||
|
||||
if (this.options.debug) {
|
||||
console.debug('Refined messages', refinedMessage);
|
||||
console.debug(
|
||||
`remainingContextTokens: ${remainingContextTokens}, after refining: ${
|
||||
remainingContextTokens - refinedMessage.tokenCount
|
||||
}`,
|
||||
);
|
||||
}
|
||||
|
||||
return refinedMessage;
|
||||
} catch (e) {
|
||||
console.error('Error refining messages');
|
||||
console.error(e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method processes an array of messages and returns a context of messages that fit within a token limit.
|
||||
* This method processes an array of messages and returns a context of messages that fit within a specified token limit.
|
||||
* It iterates over the messages from newest to oldest, adding them to the context until the token limit is reached.
|
||||
* If the token limit would be exceeded by adding a message, that message and possibly the previous one are added to a separate array of messages to refine.
|
||||
* The method uses `push` and `pop` operations for efficient array manipulation, and reverses the arrays at the end to maintain the original order of the messages.
|
||||
* The method also includes a mechanism to avoid blocking the event loop by waiting for the next tick after each iteration.
|
||||
* If the token limit would be exceeded by adding a message, that message is not added to the context and remains in the original array.
|
||||
* The method uses `push` and `pop` operations for efficient array manipulation, and reverses the context array at the end to maintain the original order of the messages.
|
||||
*
|
||||
* @param {Array} messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
|
||||
* @returns {Object} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`. `context` is an array of messages that fit within the token limit. `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
|
||||
* @param {Array} _messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
|
||||
* @param {number} [maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`.
|
||||
* @returns {Object} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
|
||||
* `context` is an array of messages that fit within the token limit.
|
||||
* `summaryIndex` is the index of the first message in the `messagesToRefine` array.
|
||||
* `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context.
|
||||
* `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
|
||||
*/
|
||||
async getMessagesWithinTokenLimit(messages) {
|
||||
async getMessagesWithinTokenLimit(_messages, maxContextTokens) {
|
||||
// Every reply is primed with <|start|>assistant<|message|>, so we
|
||||
// start with 3 tokens for the label after all messages have been counted.
|
||||
let currentTokenCount = 3;
|
||||
let context = [];
|
||||
let messagesToRefine = [];
|
||||
let refineIndex = -1;
|
||||
let remainingContextTokens = this.maxContextTokens;
|
||||
let summaryIndex = -1;
|
||||
let remainingContextTokens = maxContextTokens ?? this.maxContextTokens;
|
||||
const messages = [..._messages];
|
||||
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const message = messages[i];
|
||||
const newTokenCount = currentTokenCount + message.tokenCount;
|
||||
const exceededLimit = newTokenCount > this.maxContextTokens;
|
||||
let shouldRefine = exceededLimit && this.shouldRefineContext;
|
||||
let refineNextMessage = i !== 0 && i !== 1 && context.length > 0;
|
||||
const context = [];
|
||||
if (currentTokenCount < remainingContextTokens) {
|
||||
while (messages.length > 0 && currentTokenCount < remainingContextTokens) {
|
||||
const poppedMessage = messages.pop();
|
||||
const { tokenCount } = poppedMessage;
|
||||
|
||||
if (shouldRefine) {
|
||||
messagesToRefine.push(message);
|
||||
|
||||
if (refineIndex === -1) {
|
||||
refineIndex = i;
|
||||
if (poppedMessage && currentTokenCount + tokenCount <= remainingContextTokens) {
|
||||
context.push(poppedMessage);
|
||||
currentTokenCount += tokenCount;
|
||||
} else {
|
||||
messages.push(poppedMessage);
|
||||
break;
|
||||
}
|
||||
|
||||
if (refineNextMessage) {
|
||||
refineIndex = i + 1;
|
||||
const removedMessage = context.pop();
|
||||
messagesToRefine.push(removedMessage);
|
||||
currentTokenCount -= removedMessage.tokenCount;
|
||||
remainingContextTokens = this.maxContextTokens - currentTokenCount;
|
||||
refineNextMessage = false;
|
||||
}
|
||||
|
||||
continue;
|
||||
} else if (exceededLimit) {
|
||||
break;
|
||||
}
|
||||
|
||||
context.push(message);
|
||||
currentTokenCount = newTokenCount;
|
||||
remainingContextTokens = this.maxContextTokens - currentTokenCount;
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
}
|
||||
|
||||
const prunedMemory = messages;
|
||||
summaryIndex = prunedMemory.length - 1;
|
||||
remainingContextTokens -= currentTokenCount;
|
||||
|
||||
return {
|
||||
context: context.reverse(),
|
||||
remainingContextTokens,
|
||||
messagesToRefine: messagesToRefine.reverse(),
|
||||
refineIndex,
|
||||
messagesToRefine: prunedMemory,
|
||||
summaryIndex,
|
||||
};
|
||||
}
|
||||
|
||||
async handleContextStrategy({ instructions, orderedMessages, formattedMessages }) {
|
||||
let payload = this.addInstructions(formattedMessages, instructions);
|
||||
let _instructions;
|
||||
let tokenCount;
|
||||
|
||||
if (instructions) {
|
||||
({ tokenCount, ..._instructions } = instructions);
|
||||
}
|
||||
this.options.debug && _instructions && console.debug('instructions tokenCount', tokenCount);
|
||||
let payload = this.addInstructions(formattedMessages, _instructions);
|
||||
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
|
||||
let { context, remainingContextTokens, messagesToRefine, refineIndex } =
|
||||
await this.getMessagesWithinTokenLimit(payload);
|
||||
|
||||
payload = context;
|
||||
let refinedMessage;
|
||||
let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
|
||||
await this.getMessagesWithinTokenLimit(orderedWithInstructions);
|
||||
|
||||
// if (messagesToRefine.length > 0) {
|
||||
// refinedMessage = await this.refineMessages(messagesToRefine, remainingContextTokens);
|
||||
// payload.unshift(refinedMessage);
|
||||
// remainingContextTokens -= refinedMessage.tokenCount;
|
||||
// }
|
||||
// if (remainingContextTokens <= instructions?.tokenCount) {
|
||||
// if (this.options.debug) {
|
||||
// console.debug(`Remaining context (${remainingContextTokens}) is less than instructions token count: ${instructions.tokenCount}`);
|
||||
// }
|
||||
|
||||
// ({ context, remainingContextTokens, messagesToRefine, refineIndex } = await this.getMessagesWithinTokenLimit(payload));
|
||||
// payload = context;
|
||||
// }
|
||||
|
||||
// Calculate the difference in length to determine how many messages were discarded if any
|
||||
let diff = orderedWithInstructions.length - payload.length;
|
||||
|
||||
if (this.options.debug) {
|
||||
console.debug('<---------------------------------DIFF--------------------------------->');
|
||||
console.debug(
|
||||
`Difference between payload (${payload.length}) and orderedWithInstructions (${orderedWithInstructions.length}): ${diff}`,
|
||||
);
|
||||
this.options.debug &&
|
||||
console.debug(
|
||||
'remainingContextTokens, this.maxContextTokens (1/2)',
|
||||
remainingContextTokens,
|
||||
this.maxContextTokens,
|
||||
);
|
||||
}
|
||||
|
||||
// If the difference is positive, slice the orderedWithInstructions array
|
||||
let summaryMessage;
|
||||
let summaryTokenCount;
|
||||
let { shouldSummarize } = this;
|
||||
|
||||
// Calculate the difference in length to determine how many messages were discarded if any
|
||||
const { length } = payload;
|
||||
const diff = length - context.length;
|
||||
const firstMessage = orderedWithInstructions[0];
|
||||
const usePrevSummary =
|
||||
shouldSummarize &&
|
||||
diff === 1 &&
|
||||
firstMessage?.summary &&
|
||||
this.previous_summary.messageId === firstMessage.messageId;
|
||||
|
||||
if (diff > 0) {
|
||||
orderedWithInstructions = orderedWithInstructions.slice(diff);
|
||||
payload = payload.slice(diff);
|
||||
this.options.debug &&
|
||||
console.debug(
|
||||
`Difference between original payload (${length}) and context (${context.length}): ${diff}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (messagesToRefine.length > 0) {
|
||||
refinedMessage = await this.refineMessages(messagesToRefine, remainingContextTokens);
|
||||
payload.unshift(refinedMessage);
|
||||
remainingContextTokens -= refinedMessage.tokenCount;
|
||||
const latestMessage = orderedWithInstructions[orderedWithInstructions.length - 1];
|
||||
if (payload.length === 0 && !shouldSummarize && latestMessage) {
|
||||
throw new Error(
|
||||
`Prompt token count of ${latestMessage.tokenCount} exceeds max token count of ${this.maxContextTokens}.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (this.options.debug) {
|
||||
if (usePrevSummary) {
|
||||
summaryMessage = { role: 'system', content: firstMessage.summary };
|
||||
summaryTokenCount = firstMessage.summaryTokenCount;
|
||||
payload.unshift(summaryMessage);
|
||||
remainingContextTokens -= summaryTokenCount;
|
||||
} else if (shouldSummarize && messagesToRefine.length > 0) {
|
||||
({ summaryMessage, summaryTokenCount } = await this.summarizeMessages({
|
||||
messagesToRefine,
|
||||
remainingContextTokens,
|
||||
}));
|
||||
summaryMessage && payload.unshift(summaryMessage);
|
||||
remainingContextTokens -= summaryTokenCount;
|
||||
}
|
||||
|
||||
// Make sure to only continue summarization logic if the summary message was generated
|
||||
shouldSummarize = summaryMessage && shouldSummarize;
|
||||
|
||||
this.options.debug &&
|
||||
console.debug(
|
||||
'remainingContextTokens, this.maxContextTokens (2/2)',
|
||||
remainingContextTokens,
|
||||
this.maxContextTokens,
|
||||
);
|
||||
}
|
||||
|
||||
let tokenCountMap = orderedWithInstructions.reduce((map, message, index) => {
|
||||
if (!message.messageId) {
|
||||
const { messageId } = message;
|
||||
if (!messageId) {
|
||||
return map;
|
||||
}
|
||||
|
||||
if (index === refineIndex) {
|
||||
map.refined = { ...refinedMessage, messageId: message.messageId };
|
||||
if (shouldSummarize && index === summaryIndex && !usePrevSummary) {
|
||||
map.summaryMessage = { ...summaryMessage, messageId, tokenCount: summaryTokenCount };
|
||||
}
|
||||
|
||||
map[message.messageId] = payload[index].tokenCount;
|
||||
map[messageId] = orderedWithInstructions[index].tokenCount;
|
||||
return map;
|
||||
}, {});
|
||||
|
||||
@@ -396,9 +363,16 @@ class BaseClient {
|
||||
|
||||
if (this.options.debug) {
|
||||
console.debug('<-------------------------PAYLOAD/TOKEN COUNT MAP------------------------->');
|
||||
// console.debug('Payload:', payload);
|
||||
console.debug('Payload:', payload);
|
||||
console.debug('Token Count Map:', tokenCountMap);
|
||||
console.debug('Prompt Tokens', promptTokens, remainingContextTokens, this.maxContextTokens);
|
||||
console.debug(
|
||||
'Prompt Tokens',
|
||||
promptTokens,
|
||||
'remainingContextTokens',
|
||||
remainingContextTokens,
|
||||
'this.maxContextTokens',
|
||||
this.maxContextTokens,
|
||||
);
|
||||
}
|
||||
|
||||
return { payload, tokenCountMap, promptTokens, messages: orderedWithInstructions };
|
||||
@@ -445,11 +419,6 @@ class BaseClient {
|
||||
this.getBuildMessagesOptions(opts),
|
||||
);
|
||||
|
||||
if (this.options.debug) {
|
||||
console.debug('payload');
|
||||
console.debug(payload);
|
||||
}
|
||||
|
||||
if (tokenCountMap) {
|
||||
console.dir(tokenCountMap, { depth: null });
|
||||
if (tokenCountMap[userMessage.messageId]) {
|
||||
@@ -458,11 +427,6 @@ class BaseClient {
|
||||
console.log('userMessage', userMessage);
|
||||
}
|
||||
|
||||
payload = payload.map((message) => {
|
||||
const messageWithoutTokenCount = message;
|
||||
delete messageWithoutTokenCount.tokenCount;
|
||||
return messageWithoutTokenCount;
|
||||
});
|
||||
this.handleTokenCountMap(tokenCountMap);
|
||||
}
|
||||
|
||||
@@ -470,6 +434,21 @@ class BaseClient {
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||
await checkBalance({
|
||||
req: this.options.req,
|
||||
res: this.options.res,
|
||||
txData: {
|
||||
user: this.user,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
debug: this.options.debug,
|
||||
model: this.modelOptions.model,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const completion = await this.sendCompletion(payload, opts);
|
||||
const responseMessage = {
|
||||
messageId: responseMessageId,
|
||||
conversationId,
|
||||
@@ -478,13 +457,19 @@ class BaseClient {
|
||||
isEdited,
|
||||
model: this.modelOptions.model,
|
||||
sender: this.sender,
|
||||
text: addSpaceIfNeeded(generation) + (await this.sendCompletion(payload, opts)),
|
||||
text: addSpaceIfNeeded(generation) + completion,
|
||||
promptTokens,
|
||||
};
|
||||
|
||||
if (tokenCountMap && this.getTokenCountForResponse) {
|
||||
if (
|
||||
tokenCountMap &&
|
||||
this.recordTokenUsage &&
|
||||
this.getTokenCountForResponse &&
|
||||
this.getTokenCount
|
||||
) {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
responseMessage.completionTokens = responseMessage.tokenCount;
|
||||
const completionTokens = this.getTokenCount(completion);
|
||||
await this.recordTokenUsage({ promptTokens, completionTokens });
|
||||
}
|
||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
delete responseMessage.tokenCount;
|
||||
@@ -511,7 +496,30 @@ class BaseClient {
|
||||
mapMethod = this.getMessageMapMethod();
|
||||
}
|
||||
|
||||
return this.constructor.getMessagesForConversation(messages, parentMessageId, mapMethod);
|
||||
const orderedMessages = this.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId,
|
||||
mapMethod,
|
||||
});
|
||||
|
||||
if (!this.shouldSummarize) {
|
||||
return orderedMessages;
|
||||
}
|
||||
|
||||
// Find the latest message with a 'summary' property
|
||||
for (let i = orderedMessages.length - 1; i >= 0; i--) {
|
||||
if (orderedMessages[i]?.summary) {
|
||||
this.previous_summary = orderedMessages[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (this.options.debug && this.previous_summary) {
|
||||
const { messageId, summary, tokenCount, summaryTokenCount } = this.previous_summary;
|
||||
console.debug('Previous summary:', { messageId, summary, tokenCount, summaryTokenCount });
|
||||
}
|
||||
|
||||
return orderedMessages;
|
||||
}
|
||||
|
||||
async saveMessageToDatabase(message, endpointOptions, user = null) {
|
||||
@@ -529,30 +537,79 @@ class BaseClient {
|
||||
|
||||
/**
|
||||
* Iterate through messages, building an array based on the parentMessageId.
|
||||
* Each message has an id and a parentMessageId. The parentMessageId is the id of the message that this message is a reply to.
|
||||
* @param messages
|
||||
* @param parentMessageId
|
||||
* @returns {*[]} An array containing the messages in the order they should be displayed, starting with the root message.
|
||||
*
|
||||
* This function constructs a conversation thread by traversing messages from a given parentMessageId up to the root message.
|
||||
* It handles cyclic references by ensuring that a message is not processed more than once.
|
||||
* If the 'summary' option is set to true and a message has a 'summary' property:
|
||||
* - The message's 'role' is set to 'system'.
|
||||
* - The message's 'text' is set to its 'summary'.
|
||||
* - If the message has a 'summaryTokenCount', the message's 'tokenCount' is set to 'summaryTokenCount'.
|
||||
* The traversal stops at the message with the 'summary' property.
|
||||
*
|
||||
* Each message object should have an 'id' or 'messageId' property and may have a 'parentMessageId' property.
|
||||
* The 'parentMessageId' is the ID of the message that the current message is a reply to.
|
||||
* If 'parentMessageId' is not present, null, or is '00000000-0000-0000-0000-000000000000',
|
||||
* the message is considered a root message.
|
||||
*
|
||||
* @param {Object} options - The options for the function.
|
||||
* @param {Array} options.messages - An array of message objects. Each object should have either an 'id' or 'messageId' property, and may have a 'parentMessageId' property.
|
||||
* @param {string} options.parentMessageId - The ID of the parent message to start the traversal from.
|
||||
* @param {Function} [options.mapMethod] - An optional function to map over the ordered messages. If provided, it will be applied to each message in the resulting array.
|
||||
* @param {boolean} [options.summary=false] - If set to true, the traversal modifies messages with 'summary' and 'summaryTokenCount' properties and stops at the message with a 'summary' property.
|
||||
* @returns {Array} An array containing the messages in the order they should be displayed, starting with the most recent message with a 'summary' property if the 'summary' option is true, and ending with the message identified by 'parentMessageId'.
|
||||
*/
|
||||
static getMessagesForConversation(messages, parentMessageId, mapMethod = null) {
|
||||
static getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId,
|
||||
mapMethod = null,
|
||||
summary = false,
|
||||
}) {
|
||||
if (!messages || messages.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const orderedMessages = [];
|
||||
let currentMessageId = parentMessageId;
|
||||
const visitedMessageIds = new Set();
|
||||
|
||||
while (currentMessageId) {
|
||||
if (visitedMessageIds.has(currentMessageId)) {
|
||||
break;
|
||||
}
|
||||
const message = messages.find((msg) => {
|
||||
const messageId = msg.messageId ?? msg.id;
|
||||
return messageId === currentMessageId;
|
||||
});
|
||||
|
||||
visitedMessageIds.add(currentMessageId);
|
||||
|
||||
if (!message) {
|
||||
break;
|
||||
}
|
||||
orderedMessages.unshift(message);
|
||||
currentMessageId = message.parentMessageId;
|
||||
|
||||
if (summary && message.summary) {
|
||||
message.role = 'system';
|
||||
message.text = message.summary;
|
||||
}
|
||||
|
||||
if (summary && message.summaryTokenCount) {
|
||||
message.tokenCount = message.summaryTokenCount;
|
||||
}
|
||||
|
||||
orderedMessages.push(message);
|
||||
|
||||
if (summary && message.summary) {
|
||||
break;
|
||||
}
|
||||
|
||||
currentMessageId =
|
||||
message.parentMessageId === '00000000-0000-0000-0000-000000000000'
|
||||
? null
|
||||
: message.parentMessageId;
|
||||
}
|
||||
|
||||
orderedMessages.reverse();
|
||||
|
||||
if (mapMethod) {
|
||||
return orderedMessages.map(mapMethod);
|
||||
}
|
||||
@@ -565,6 +622,7 @@ class BaseClient {
|
||||
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
*
|
||||
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
|
||||
* In our implementation, this is accounted for in the getMessagesWithinTokenLimit method.
|
||||
*
|
||||
* @param {Object} message
|
||||
*/
|
||||
|
||||
@@ -50,7 +50,7 @@ class ChatGPTClient extends BaseClient {
|
||||
stop: modelOptions.stop,
|
||||
};
|
||||
|
||||
this.isChatGptModel = this.modelOptions.model.startsWith('gpt-');
|
||||
this.isChatGptModel = this.modelOptions.model.includes('gpt-');
|
||||
const { isChatGptModel } = this;
|
||||
this.isUnofficialChatGptModel =
|
||||
this.modelOptions.model.startsWith('text-chat') ||
|
||||
@@ -153,6 +153,11 @@ class ChatGPTClient extends BaseClient {
|
||||
} else {
|
||||
modelOptions.prompt = input;
|
||||
}
|
||||
|
||||
if (this.useOpenRouter && modelOptions.prompt) {
|
||||
delete modelOptions.stop;
|
||||
}
|
||||
|
||||
const { debug } = this.options;
|
||||
const url = this.completionsUrl;
|
||||
if (debug) {
|
||||
@@ -432,9 +437,7 @@ ${botMessage.message}
|
||||
return returnData;
|
||||
}
|
||||
|
||||
async buildPrompt(messages, parentMessageId, { isChatGptModel = false, promptPrefix = null }) {
|
||||
const orderedMessages = this.constructor.getMessagesForConversation(messages, parentMessageId);
|
||||
|
||||
async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) {
|
||||
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
|
||||
if (promptPrefix) {
|
||||
// If the prompt prefix doesn't end with the end token, add it.
|
||||
@@ -480,8 +483,8 @@ ${botMessage.message}
|
||||
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
|
||||
// Do this within a recursive async function so that it doesn't block the event loop for too long.
|
||||
const buildPromptBody = async () => {
|
||||
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
|
||||
const message = orderedMessages.pop();
|
||||
if (currentTokenCount < maxTokenCount && messages.length > 0) {
|
||||
const message = messages.pop();
|
||||
const roleLabel =
|
||||
message?.isCreatedByUser || message?.role?.toLowerCase() === 'user'
|
||||
? this.userLabel
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
const BaseClient = require('./BaseClient');
|
||||
const ChatGPTClient = require('./ChatGPTClient');
|
||||
const OpenAI = require('openai');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const { maxTokensMap, genAzureChatCompletion } = require('../../utils');
|
||||
const { getModelMaxTokens, genAzureChatCompletion, extractBaseURL } = require('../../utils');
|
||||
const { truncateText, formatMessage, CUT_OFF_PROMPT } = require('./prompts');
|
||||
const spendTokens = require('../../models/spendTokens');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const { isEnabled } = require('../../server/utils');
|
||||
const { createLLM, RunManager } = require('./llm');
|
||||
const ChatGPTClient = require('./ChatGPTClient');
|
||||
const { summaryBuffer } = require('./memory');
|
||||
const { runTitleChain } = require('./chains');
|
||||
const { createLLM } = require('./llm');
|
||||
const { tokenSplit } = require('./document');
|
||||
const BaseClient = require('./BaseClient');
|
||||
|
||||
// Cache to store Tiktoken instances
|
||||
const tokenizersCache = {};
|
||||
@@ -20,11 +28,8 @@ class OpenAIClient extends BaseClient {
|
||||
this.contextStrategy = options.contextStrategy
|
||||
? options.contextStrategy.toLowerCase()
|
||||
: 'discard';
|
||||
this.shouldRefineContext = this.contextStrategy === 'refine';
|
||||
this.shouldSummarize = this.contextStrategy === 'summarize';
|
||||
this.azure = options.azure || false;
|
||||
if (this.azure) {
|
||||
this.azureEndpoint = genAzureChatCompletion(this.azure);
|
||||
}
|
||||
this.setOptions(options);
|
||||
}
|
||||
|
||||
@@ -59,28 +64,57 @@ class OpenAIClient extends BaseClient {
|
||||
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
|
||||
stop: modelOptions.stop,
|
||||
};
|
||||
} else {
|
||||
// Update the modelOptions if it already exists
|
||||
this.modelOptions = {
|
||||
...this.modelOptions,
|
||||
...modelOptions,
|
||||
};
|
||||
}
|
||||
|
||||
if (process.env.OPENROUTER_API_KEY) {
|
||||
this.apiKey = process.env.OPENROUTER_API_KEY;
|
||||
const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? {};
|
||||
if (OPENROUTER_API_KEY && !this.azure) {
|
||||
this.apiKey = OPENROUTER_API_KEY;
|
||||
this.useOpenRouter = true;
|
||||
}
|
||||
|
||||
this.isChatCompletion =
|
||||
this.useOpenRouter ||
|
||||
this.options.reverseProxyUrl ||
|
||||
this.options.localAI ||
|
||||
this.modelOptions.model.startsWith('gpt-');
|
||||
const { reverseProxyUrl: reverseProxy } = this.options;
|
||||
this.FORCE_PROMPT =
|
||||
isEnabled(OPENAI_FORCE_PROMPT) ||
|
||||
(reverseProxy && reverseProxy.includes('completions') && !reverseProxy.includes('chat'));
|
||||
|
||||
if (this.azure && process.env.AZURE_OPENAI_DEFAULT_MODEL) {
|
||||
this.azureEndpoint = genAzureChatCompletion(this.azure, this.modelOptions.model);
|
||||
this.modelOptions.model = process.env.AZURE_OPENAI_DEFAULT_MODEL;
|
||||
} else if (this.azure) {
|
||||
this.azureEndpoint = genAzureChatCompletion(this.azure, this.modelOptions.model);
|
||||
}
|
||||
|
||||
const { model } = this.modelOptions;
|
||||
|
||||
this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt-');
|
||||
this.isChatGptModel = this.isChatCompletion;
|
||||
if (this.modelOptions.model === 'text-davinci-003') {
|
||||
if (
|
||||
model.includes('text-davinci') ||
|
||||
model.includes('gpt-3.5-turbo-instruct') ||
|
||||
this.FORCE_PROMPT
|
||||
) {
|
||||
this.isChatCompletion = false;
|
||||
this.isChatGptModel = false;
|
||||
}
|
||||
const { isChatGptModel } = this;
|
||||
this.isUnofficialChatGptModel =
|
||||
this.modelOptions.model.startsWith('text-chat') ||
|
||||
this.modelOptions.model.startsWith('text-davinci-002-render');
|
||||
this.maxContextTokens = maxTokensMap[this.modelOptions.model] ?? 4095; // 1 less than maximum
|
||||
model.startsWith('text-chat') || model.startsWith('text-davinci-002-render');
|
||||
this.maxContextTokens = getModelMaxTokens(model) ?? 4095; // 1 less than maximum
|
||||
|
||||
if (this.shouldSummarize) {
|
||||
this.maxContextTokens = Math.floor(this.maxContextTokens / 2);
|
||||
}
|
||||
|
||||
if (this.options.debug) {
|
||||
console.debug('maxContextTokens', this.maxContextTokens);
|
||||
}
|
||||
|
||||
this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
|
||||
this.maxPromptTokens =
|
||||
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
|
||||
@@ -108,9 +142,13 @@ class OpenAIClient extends BaseClient {
|
||||
this.modelOptions.stop = stopTokens;
|
||||
}
|
||||
|
||||
if (this.options.reverseProxyUrl) {
|
||||
this.completionsUrl = this.options.reverseProxyUrl;
|
||||
this.langchainProxy = this.options.reverseProxyUrl.match(/.*v1/)[0];
|
||||
if (reverseProxy) {
|
||||
this.completionsUrl = reverseProxy;
|
||||
this.langchainProxy = extractBaseURL(reverseProxy);
|
||||
!this.langchainProxy &&
|
||||
console.warn(`The reverse proxy URL ${reverseProxy} is not valid for Plugins.
|
||||
The url must follow OpenAI specs, for example: https://localhost:8080/v1/chat/completions
|
||||
If your reverse proxy is compatible to OpenAI specs in every other way, it may still work without plugins enabled.`);
|
||||
} else if (isChatGptModel) {
|
||||
this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
|
||||
} else {
|
||||
@@ -161,10 +199,11 @@ class OpenAIClient extends BaseClient {
|
||||
tokenizer = this.constructor.getTokenizer(this.encoding, true, extendSpecialTokens);
|
||||
} else {
|
||||
try {
|
||||
this.encoding = this.modelOptions.model;
|
||||
tokenizer = this.constructor.getTokenizer(this.modelOptions.model, true);
|
||||
} catch {
|
||||
const { model } = this.modelOptions;
|
||||
this.encoding = model.includes('instruct') ? 'text-davinci-003' : model;
|
||||
tokenizer = this.constructor.getTokenizer(this.encoding, true);
|
||||
} catch {
|
||||
tokenizer = this.constructor.getTokenizer('text-davinci-003', true);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,8 +289,13 @@ class OpenAIClient extends BaseClient {
|
||||
parentMessageId,
|
||||
{ isChatCompletion = false, promptPrefix = null },
|
||||
) {
|
||||
let orderedMessages = this.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId,
|
||||
summary: this.shouldSummarize,
|
||||
});
|
||||
if (!isChatCompletion) {
|
||||
return await this.buildPrompt(messages, parentMessageId, {
|
||||
return await this.buildPrompt(orderedMessages, {
|
||||
isChatGptModel: isChatCompletion,
|
||||
promptPrefix,
|
||||
});
|
||||
@@ -261,7 +305,6 @@ class OpenAIClient extends BaseClient {
|
||||
let instructions;
|
||||
let tokenCountMap;
|
||||
let promptTokens;
|
||||
let orderedMessages = this.constructor.getMessagesForConversation(messages, parentMessageId);
|
||||
|
||||
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
|
||||
if (promptPrefix) {
|
||||
@@ -277,22 +320,15 @@ class OpenAIClient extends BaseClient {
|
||||
}
|
||||
}
|
||||
|
||||
const formattedMessages = orderedMessages.map((message) => {
|
||||
let { role: _role, sender, text } = message;
|
||||
const role = _role ?? sender;
|
||||
const content = text ?? '';
|
||||
const formattedMessage = {
|
||||
role: role?.toLowerCase() === 'user' ? 'user' : 'assistant',
|
||||
content,
|
||||
};
|
||||
const formattedMessages = orderedMessages.map((message, i) => {
|
||||
const formattedMessage = formatMessage({
|
||||
message,
|
||||
userName: this.options?.name,
|
||||
assistantName: this.options?.chatGptLabel,
|
||||
});
|
||||
|
||||
if (this.options?.name && formattedMessage.role === 'user') {
|
||||
formattedMessage.name = this.options.name;
|
||||
}
|
||||
|
||||
if (this.contextStrategy) {
|
||||
formattedMessage.tokenCount =
|
||||
message.tokenCount ?? this.getTokenCountForMessage(formattedMessage);
|
||||
if (this.contextStrategy && !orderedMessages[i].tokenCount) {
|
||||
orderedMessages[i].tokenCount = this.getTokenCountForMessage(formattedMessage);
|
||||
}
|
||||
|
||||
return formattedMessage;
|
||||
@@ -318,6 +354,10 @@ class OpenAIClient extends BaseClient {
|
||||
result.tokenCountMap = tokenCountMap;
|
||||
}
|
||||
|
||||
if (promptTokens >= 0 && typeof this.options.getReqData === 'function') {
|
||||
this.options.getReqData({ promptTokens });
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -326,7 +366,9 @@ class OpenAIClient extends BaseClient {
|
||||
let result = null;
|
||||
let streamResult = null;
|
||||
this.modelOptions.user = this.user;
|
||||
if (typeof opts.onProgress === 'function') {
|
||||
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
|
||||
const useOldMethod = !!(this.azure || invalidBaseUrl || !this.isChatCompletion);
|
||||
if (typeof opts.onProgress === 'function' && useOldMethod) {
|
||||
await this.getCompletion(
|
||||
payload,
|
||||
(progressMessage) => {
|
||||
@@ -347,6 +389,8 @@ class OpenAIClient extends BaseClient {
|
||||
if (this.isChatCompletion) {
|
||||
token =
|
||||
progressMessage.choices?.[0]?.delta?.content ?? progressMessage.choices?.[0]?.text;
|
||||
} else {
|
||||
token = progressMessage.choices?.[0]?.text;
|
||||
}
|
||||
|
||||
if (!token && this.useOpenRouter) {
|
||||
@@ -367,6 +411,13 @@ class OpenAIClient extends BaseClient {
|
||||
},
|
||||
opts.abortController || new AbortController(),
|
||||
);
|
||||
} else if (typeof opts.onProgress === 'function') {
|
||||
reply = await this.chatCompletion({
|
||||
payload,
|
||||
clientOptions: opts,
|
||||
onProgress: opts.onProgress,
|
||||
abortController: opts.abortController,
|
||||
});
|
||||
} else {
|
||||
result = await this.getCompletion(
|
||||
payload,
|
||||
@@ -390,28 +441,30 @@ class OpenAIClient extends BaseClient {
|
||||
return reply.trim();
|
||||
}
|
||||
|
||||
getTokenCountForResponse(response) {
|
||||
return this.getTokenCountForMessage({
|
||||
role: 'assistant',
|
||||
content: response.text,
|
||||
});
|
||||
}
|
||||
|
||||
async titleConvo({ text, responseText = '' }) {
|
||||
let title = 'New Chat';
|
||||
const convo = `||>User:
|
||||
"${text}"
|
||||
||>Response:
|
||||
"${JSON.stringify(responseText)}"`;
|
||||
|
||||
initializeLLM({
|
||||
model = 'gpt-3.5-turbo',
|
||||
modelName,
|
||||
temperature = 0.2,
|
||||
presence_penalty = 0,
|
||||
frequency_penalty = 0,
|
||||
max_tokens,
|
||||
streaming,
|
||||
context,
|
||||
tokenBuffer,
|
||||
initialMessageCount,
|
||||
}) {
|
||||
const modelOptions = {
|
||||
model: 'gpt-3.5-turbo-0613',
|
||||
temperature: 0.2,
|
||||
presence_penalty: 0,
|
||||
frequency_penalty: 0,
|
||||
max_tokens: 16,
|
||||
modelName: modelName ?? model,
|
||||
temperature,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
user: this.user,
|
||||
};
|
||||
|
||||
if (max_tokens) {
|
||||
modelOptions.max_tokens = max_tokens;
|
||||
}
|
||||
|
||||
const configOptions = {};
|
||||
|
||||
if (this.langchainProxy) {
|
||||
@@ -428,19 +481,65 @@ class OpenAIClient extends BaseClient {
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const llm = createLLM({
|
||||
modelOptions,
|
||||
configOptions,
|
||||
openAIApiKey: this.apiKey,
|
||||
azure: this.azure,
|
||||
});
|
||||
if (this.options.proxy) {
|
||||
configOptions.httpAgent = new HttpsProxyAgent(this.options.proxy);
|
||||
configOptions.httpsAgent = new HttpsProxyAgent(this.options.proxy);
|
||||
}
|
||||
|
||||
title = await runTitleChain({ llm, text, convo });
|
||||
const { req, res, debug } = this.options;
|
||||
const runManager = new RunManager({ req, res, debug, abortController: this.abortController });
|
||||
this.runManager = runManager;
|
||||
|
||||
const llm = createLLM({
|
||||
modelOptions,
|
||||
configOptions,
|
||||
openAIApiKey: this.apiKey,
|
||||
azure: this.azure,
|
||||
streaming,
|
||||
callbacks: runManager.createCallbacks({
|
||||
context,
|
||||
tokenBuffer,
|
||||
conversationId: this.conversationId,
|
||||
initialMessageCount,
|
||||
}),
|
||||
});
|
||||
|
||||
return llm;
|
||||
}
|
||||
|
||||
async titleConvo({ text, responseText = '' }) {
|
||||
let title = 'New Chat';
|
||||
const convo = `||>User:
|
||||
"${truncateText(text)}"
|
||||
||>Response:
|
||||
"${JSON.stringify(truncateText(responseText))}"`;
|
||||
|
||||
const { OPENAI_TITLE_MODEL } = process.env ?? {};
|
||||
|
||||
const modelOptions = {
|
||||
model: OPENAI_TITLE_MODEL ?? 'gpt-3.5-turbo',
|
||||
temperature: 0.2,
|
||||
presence_penalty: 0,
|
||||
frequency_penalty: 0,
|
||||
max_tokens: 16,
|
||||
};
|
||||
|
||||
try {
|
||||
this.abortController = new AbortController();
|
||||
const llm = this.initializeLLM({ ...modelOptions, context: 'title', tokenBuffer: 150 });
|
||||
title = await runTitleChain({ llm, text, convo, signal: this.abortController.signal });
|
||||
} catch (e) {
|
||||
console.error(e.message);
|
||||
if (e?.message?.toLowerCase()?.includes('abort')) {
|
||||
this.options.debug && console.debug('Aborted title generation');
|
||||
return;
|
||||
}
|
||||
console.log('There was an issue generating title with LangChain, trying the old method...');
|
||||
modelOptions.model = 'gpt-3.5-turbo';
|
||||
this.options.debug && console.error(e.message, e);
|
||||
modelOptions.model = OPENAI_TITLE_MODEL ?? 'gpt-3.5-turbo';
|
||||
if (this.azure) {
|
||||
modelOptions.model = process.env.AZURE_OPENAI_DEFAULT_MODEL ?? modelOptions.model;
|
||||
this.azureEndpoint = genAzureChatCompletion(this.azure, modelOptions.model);
|
||||
}
|
||||
const instructionsPayload = [
|
||||
{
|
||||
role: 'system',
|
||||
@@ -464,6 +563,262 @@ ${convo}
|
||||
console.log('CONVERSATION TITLE', title);
|
||||
return title;
|
||||
}
|
||||
|
||||
async summarizeMessages({ messagesToRefine, remainingContextTokens }) {
|
||||
this.options.debug && console.debug('Summarizing messages...');
|
||||
let context = messagesToRefine;
|
||||
let prompt;
|
||||
|
||||
const { OPENAI_SUMMARY_MODEL = 'gpt-3.5-turbo' } = process.env ?? {};
|
||||
const maxContextTokens = getModelMaxTokens(OPENAI_SUMMARY_MODEL) ?? 4095;
|
||||
// 3 tokens for the assistant label, and 98 for the summarizer prompt (101)
|
||||
let promptBuffer = 101;
|
||||
|
||||
/*
|
||||
* Note: token counting here is to block summarization if it exceeds the spend; complete
|
||||
* accuracy is not important. Actual spend will happen after successful summarization.
|
||||
*/
|
||||
const excessTokenCount = context.reduce(
|
||||
(acc, message) => acc + message.tokenCount,
|
||||
promptBuffer,
|
||||
);
|
||||
|
||||
if (excessTokenCount > maxContextTokens) {
|
||||
({ context } = await this.getMessagesWithinTokenLimit(context, maxContextTokens));
|
||||
}
|
||||
|
||||
if (context.length === 0) {
|
||||
this.options.debug &&
|
||||
console.debug('Summary context is empty, using latest message within token limit');
|
||||
|
||||
promptBuffer = 32;
|
||||
const { text, ...latestMessage } = messagesToRefine[messagesToRefine.length - 1];
|
||||
const splitText = await tokenSplit({
|
||||
text,
|
||||
chunkSize: Math.floor((maxContextTokens - promptBuffer) / 3),
|
||||
});
|
||||
|
||||
const newText = `${splitText[0]}\n...[truncated]...\n${splitText[splitText.length - 1]}`;
|
||||
prompt = CUT_OFF_PROMPT;
|
||||
|
||||
context = [
|
||||
formatMessage({
|
||||
message: {
|
||||
...latestMessage,
|
||||
text: newText,
|
||||
},
|
||||
userName: this.options?.name,
|
||||
assistantName: this.options?.chatGptLabel,
|
||||
}),
|
||||
];
|
||||
}
|
||||
// TODO: We can accurately count the tokens here before handleChatModelStart
|
||||
// by recreating the summary prompt (single message) to avoid LangChain handling
|
||||
|
||||
const initialPromptTokens = this.maxContextTokens - remainingContextTokens;
|
||||
this.options.debug && console.debug(`initialPromptTokens: ${initialPromptTokens}`);
|
||||
|
||||
const llm = this.initializeLLM({
|
||||
model: OPENAI_SUMMARY_MODEL,
|
||||
temperature: 0.2,
|
||||
context: 'summary',
|
||||
tokenBuffer: initialPromptTokens,
|
||||
});
|
||||
|
||||
try {
|
||||
const summaryMessage = await summaryBuffer({
|
||||
llm,
|
||||
debug: this.options.debug,
|
||||
prompt,
|
||||
context,
|
||||
formatOptions: {
|
||||
userName: this.options?.name,
|
||||
assistantName: this.options?.chatGptLabel ?? this.options?.modelLabel,
|
||||
},
|
||||
previous_summary: this.previous_summary?.summary,
|
||||
signal: this.abortController.signal,
|
||||
});
|
||||
|
||||
const summaryTokenCount = this.getTokenCountForMessage(summaryMessage);
|
||||
|
||||
if (this.options.debug) {
|
||||
console.debug('summaryMessage:', summaryMessage);
|
||||
console.debug(
|
||||
`remainingContextTokens: ${remainingContextTokens}, after refining: ${
|
||||
remainingContextTokens - summaryTokenCount
|
||||
}`,
|
||||
);
|
||||
}
|
||||
|
||||
return { summaryMessage, summaryTokenCount };
|
||||
} catch (e) {
|
||||
if (e?.message?.toLowerCase()?.includes('abort')) {
|
||||
this.options.debug && console.debug('Aborted summarization');
|
||||
const { run, runId } = this.runManager.getRunByConversationId(this.conversationId);
|
||||
if (run && run.error) {
|
||||
const { error } = run;
|
||||
this.runManager.removeRun(runId);
|
||||
throw new Error(error);
|
||||
}
|
||||
}
|
||||
console.error('Error summarizing messages');
|
||||
this.options.debug && console.error(e);
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
async recordTokenUsage({ promptTokens, completionTokens }) {
|
||||
if (this.options.debug) {
|
||||
console.debug('promptTokens', promptTokens);
|
||||
console.debug('completionTokens', completionTokens);
|
||||
}
|
||||
await spendTokens(
|
||||
{
|
||||
user: this.user,
|
||||
model: this.modelOptions.model,
|
||||
context: 'message',
|
||||
conversationId: this.conversationId,
|
||||
},
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
}
|
||||
|
||||
getTokenCountForResponse(response) {
|
||||
return this.getTokenCountForMessage({
|
||||
role: 'assistant',
|
||||
content: response.text,
|
||||
});
|
||||
}
|
||||
|
||||
async chatCompletion({ payload, onProgress, clientOptions, abortController = null }) {
|
||||
let error = null;
|
||||
const errorCallback = (err) => (error = err);
|
||||
let intermediateReply = '';
|
||||
try {
|
||||
if (!abortController) {
|
||||
abortController = new AbortController();
|
||||
}
|
||||
const modelOptions = { ...this.modelOptions };
|
||||
if (typeof onProgress === 'function') {
|
||||
modelOptions.stream = true;
|
||||
}
|
||||
if (this.isChatCompletion) {
|
||||
modelOptions.messages = payload;
|
||||
} else {
|
||||
// TODO: unreachable code. Need to implement completions call for non-chat models
|
||||
modelOptions.prompt = payload;
|
||||
}
|
||||
|
||||
const { debug } = this.options;
|
||||
const url = extractBaseURL(this.completionsUrl);
|
||||
if (debug) {
|
||||
console.debug('baseURL', url);
|
||||
console.debug('modelOptions', modelOptions);
|
||||
}
|
||||
const opts = {
|
||||
baseURL: url,
|
||||
};
|
||||
|
||||
if (this.useOpenRouter) {
|
||||
opts.defaultHeaders = {
|
||||
'HTTP-Referer': 'https://librechat.ai',
|
||||
'X-Title': 'LibreChat',
|
||||
};
|
||||
}
|
||||
|
||||
if (this.options.headers) {
|
||||
opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers };
|
||||
}
|
||||
|
||||
if (this.options.proxy) {
|
||||
opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
|
||||
}
|
||||
|
||||
let chatCompletion;
|
||||
const openai = new OpenAI({
|
||||
apiKey: this.apiKey,
|
||||
...opts,
|
||||
});
|
||||
|
||||
if (modelOptions.stream) {
|
||||
const stream = await openai.beta.chat.completions
|
||||
.stream({
|
||||
...modelOptions,
|
||||
stream: true,
|
||||
})
|
||||
.on('abort', () => {
|
||||
/* Do nothing here */
|
||||
})
|
||||
.on('error', (err) => {
|
||||
handleOpenAIErrors(err, errorCallback, 'stream');
|
||||
});
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const token = chunk.choices[0]?.delta?.content || '';
|
||||
intermediateReply += token;
|
||||
onProgress(token);
|
||||
if (abortController.signal.aborted) {
|
||||
stream.controller.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
chatCompletion = await stream.finalChatCompletion().catch((err) => {
|
||||
handleOpenAIErrors(err, errorCallback, 'finalChatCompletion');
|
||||
});
|
||||
}
|
||||
// regular completion
|
||||
else {
|
||||
chatCompletion = await openai.chat.completions
|
||||
.create({
|
||||
...modelOptions,
|
||||
})
|
||||
.catch((err) => {
|
||||
handleOpenAIErrors(err, errorCallback, 'create');
|
||||
});
|
||||
}
|
||||
|
||||
if (!chatCompletion && error) {
|
||||
throw new Error(error);
|
||||
} else if (!chatCompletion) {
|
||||
throw new Error('Chat completion failed');
|
||||
}
|
||||
|
||||
const { message, finish_reason } = chatCompletion.choices[0];
|
||||
if (chatCompletion && typeof clientOptions.addMetadata === 'function') {
|
||||
clientOptions.addMetadata({ finish_reason });
|
||||
}
|
||||
|
||||
return message.content;
|
||||
} catch (err) {
|
||||
if (
|
||||
err?.message?.includes('abort') ||
|
||||
(err instanceof OpenAI.APIError && err?.message?.includes('abort'))
|
||||
) {
|
||||
return '';
|
||||
}
|
||||
if (
|
||||
err?.message?.includes('missing finish_reason') ||
|
||||
(err instanceof OpenAI.OpenAIError && err?.message?.includes('missing finish_reason'))
|
||||
) {
|
||||
await abortController.abortCompletion();
|
||||
return intermediateReply;
|
||||
} else if (err instanceof OpenAI.APIError) {
|
||||
console.log(err.name);
|
||||
console.log(err.status);
|
||||
console.log(err.headers);
|
||||
if (intermediateReply) {
|
||||
return intermediateReply;
|
||||
} else {
|
||||
throw err;
|
||||
}
|
||||
} else {
|
||||
console.warn('[OpenAIClient.chatCompletion] Unhandled error type');
|
||||
console.error(err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = OpenAIClient;
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
const OpenAIClient = require('./OpenAIClient');
|
||||
const { CallbackManager } = require('langchain/callbacks');
|
||||
const { HumanChatMessage, AIChatMessage } = require('langchain/schema');
|
||||
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
|
||||
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
|
||||
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
|
||||
const checkBalance = require('../../models/checkBalance');
|
||||
const { formatLangChainMessages } = require('./prompts');
|
||||
const { isEnabled } = require('../../server/utils');
|
||||
const { extractBaseURL } = require('../../utils');
|
||||
const { SelfReflectionTool } = require('./tools');
|
||||
const { loadTools } = require('./tools/util');
|
||||
const { createLLM } = require('./llm');
|
||||
|
||||
class PluginsClient extends OpenAIClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
@@ -32,7 +35,11 @@ class PluginsClient extends OpenAIClient {
|
||||
this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');
|
||||
|
||||
if (this.options.reverseProxyUrl) {
|
||||
this.langchainProxy = this.options.reverseProxyUrl.match(/.*v1/)[0];
|
||||
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
|
||||
!this.langchainProxy &&
|
||||
console.warn(`The reverse proxy URL ${this.options.reverseProxyUrl} is not valid for Plugins.
|
||||
The url must follow OpenAI specs, for example: https://localhost:8080/v1/chat/completions
|
||||
If your reverse proxy is compatible to OpenAI specs in every other way, it may still work without plugins enabled.`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,9 +57,9 @@ class PluginsClient extends OpenAIClient {
|
||||
}
|
||||
|
||||
getFunctionModelName(input) {
|
||||
if (input.startsWith('gpt-3.5-turbo')) {
|
||||
if (input.includes('gpt-3.5-turbo')) {
|
||||
return 'gpt-3.5-turbo';
|
||||
} else if (input.startsWith('gpt-4')) {
|
||||
} else if (input.includes('gpt-4')) {
|
||||
return 'gpt-4';
|
||||
} else {
|
||||
return 'gpt-3.5-turbo';
|
||||
@@ -73,27 +80,10 @@ class PluginsClient extends OpenAIClient {
|
||||
temperature: this.agentOptions.temperature,
|
||||
};
|
||||
|
||||
const configOptions = {};
|
||||
|
||||
if (this.langchainProxy) {
|
||||
configOptions.basePath = this.langchainProxy;
|
||||
}
|
||||
|
||||
if (this.useOpenRouter) {
|
||||
configOptions.basePath = 'https://openrouter.ai/api/v1';
|
||||
configOptions.baseOptions = {
|
||||
headers: {
|
||||
'HTTP-Referer': 'https://librechat.ai',
|
||||
'X-Title': 'LibreChat',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const model = createLLM({
|
||||
modelOptions,
|
||||
configOptions,
|
||||
openAIApiKey: this.openAIApiKey,
|
||||
azure: this.azure,
|
||||
const model = this.initializeLLM({
|
||||
...modelOptions,
|
||||
context: 'plugins',
|
||||
initialMessageCount: this.currentMessages.length + 1,
|
||||
});
|
||||
|
||||
if (this.options.debug) {
|
||||
@@ -102,12 +92,26 @@ class PluginsClient extends OpenAIClient {
|
||||
);
|
||||
}
|
||||
|
||||
// Map Messages to Langchain format
|
||||
const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), {
|
||||
userName: this.options?.name,
|
||||
});
|
||||
this.options.debug && console.debug('pastMessages: ', pastMessages);
|
||||
|
||||
// TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS)
|
||||
const memory = new BufferMemory({
|
||||
llm: model,
|
||||
chatHistory: new ChatMessageHistory(pastMessages),
|
||||
});
|
||||
|
||||
this.tools = await loadTools({
|
||||
user,
|
||||
model,
|
||||
tools: this.options.tools,
|
||||
functions: this.functionsAgent,
|
||||
options: {
|
||||
memory,
|
||||
signal: this.abortController.signal,
|
||||
openAIApiKey: this.openAIApiKey,
|
||||
conversationId: this.conversationId,
|
||||
debug: this.options?.debug,
|
||||
@@ -140,15 +144,6 @@ class PluginsClient extends OpenAIClient {
|
||||
}
|
||||
};
|
||||
|
||||
// Map Messages to Langchain format
|
||||
const pastMessages = this.currentMessages
|
||||
.slice(0, -1)
|
||||
.map((msg) =>
|
||||
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
|
||||
? new HumanChatMessage(msg.text)
|
||||
: new AIChatMessage(msg.text),
|
||||
);
|
||||
|
||||
// initialize agent
|
||||
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
|
||||
this.executor = await initializer({
|
||||
@@ -218,16 +213,12 @@ class PluginsClient extends OpenAIClient {
|
||||
break; // Exit the loop if the function call is successful
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
errorMessage = err.message;
|
||||
let content = '';
|
||||
if (content) {
|
||||
errorMessage = content;
|
||||
break;
|
||||
}
|
||||
if (attempts === maxAttempts) {
|
||||
this.result.output = `Encountered an error while attempting to respond. Error: ${err.message}`;
|
||||
const { run } = this.runManager.getRunByConversationId(this.conversationId);
|
||||
const defaultOutput = `Encountered an error while attempting to respond. Error: ${err.message}`;
|
||||
this.result.output = run && run.error ? run.error : defaultOutput;
|
||||
this.result.errorMessage = run && run.error ? run.error : err.message;
|
||||
this.result.intermediateSteps = this.actions;
|
||||
this.result.errorMessage = errorMessage;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -235,11 +226,23 @@ class PluginsClient extends OpenAIClient {
|
||||
}
|
||||
|
||||
async handleResponseMessage(responseMessage, saveOptions, user) {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
responseMessage.completionTokens = responseMessage.tokenCount;
|
||||
const { output, errorMessage, ...result } = this.result;
|
||||
this.options.debug &&
|
||||
console.debug('[handleResponseMessage] Output:', { output, errorMessage, ...result });
|
||||
const { error } = responseMessage;
|
||||
if (!error) {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
responseMessage.completionTokens = this.getTokenCount(responseMessage.text);
|
||||
}
|
||||
|
||||
// Record usage only when completion is skipped as it is already recorded in the agent phase.
|
||||
if (!this.agentOptions.skipCompletion && !error) {
|
||||
await this.recordTokenUsage(responseMessage);
|
||||
}
|
||||
|
||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
delete responseMessage.tokenCount;
|
||||
return { ...responseMessage, ...this.result };
|
||||
return { ...responseMessage, ...result };
|
||||
}
|
||||
|
||||
async sendMessage(message, opts = {}) {
|
||||
@@ -249,9 +252,7 @@ class PluginsClient extends OpenAIClient {
|
||||
this.setOptions(opts);
|
||||
return super.sendMessage(message, opts);
|
||||
}
|
||||
if (this.options.debug) {
|
||||
console.log('Plugins sendMessage', message, opts);
|
||||
}
|
||||
this.options.debug && console.log('Plugins sendMessage', message, opts);
|
||||
const {
|
||||
user,
|
||||
isEdited,
|
||||
@@ -265,14 +266,12 @@ class PluginsClient extends OpenAIClient {
|
||||
onToolEnd,
|
||||
} = await this.handleStartMethods(message, opts);
|
||||
|
||||
this.conversationId = conversationId;
|
||||
this.currentMessages.push(userMessage);
|
||||
|
||||
let {
|
||||
prompt: payload,
|
||||
tokenCountMap,
|
||||
promptTokens,
|
||||
messages,
|
||||
} = await this.buildMessages(
|
||||
this.currentMessages,
|
||||
userMessage.messageId,
|
||||
@@ -288,19 +287,29 @@ class PluginsClient extends OpenAIClient {
|
||||
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
|
||||
console.log('userMessage.tokenCount', userMessage.tokenCount);
|
||||
}
|
||||
payload = payload.map((message) => {
|
||||
const messageWithoutTokenCount = message;
|
||||
delete messageWithoutTokenCount.tokenCount;
|
||||
return messageWithoutTokenCount;
|
||||
});
|
||||
this.handleTokenCountMap(tokenCountMap);
|
||||
}
|
||||
|
||||
this.result = {};
|
||||
if (messages) {
|
||||
this.currentMessages = messages;
|
||||
if (payload) {
|
||||
this.currentMessages = payload;
|
||||
}
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
|
||||
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||
await checkBalance({
|
||||
req: this.options.req,
|
||||
res: this.options.res,
|
||||
txData: {
|
||||
user: this.user,
|
||||
tokenType: 'prompt',
|
||||
amount: promptTokens,
|
||||
debug: this.options.debug,
|
||||
model: this.modelOptions.model,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const responseMessage = {
|
||||
messageId: responseMessageId,
|
||||
conversationId,
|
||||
@@ -337,11 +346,19 @@ class PluginsClient extends OpenAIClient {
|
||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
// If error occurred during generation (likely token_balance)
|
||||
if (this.result?.errorMessage?.length > 0) {
|
||||
responseMessage.error = true;
|
||||
responseMessage.text = this.result.output;
|
||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) {
|
||||
const partialText = opts.getPartialText();
|
||||
const trimmedPartial = opts.getPartialText().replaceAll(':::plugin:::\n', '');
|
||||
responseMessage.text =
|
||||
trimmedPartial.length === 0 ? `${partialText}${this.result.output}` : partialText;
|
||||
addImages(this.result.intermediateSteps, responseMessage);
|
||||
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
|
||||
return await this.handleResponseMessage(responseMessage, saveOptions, user);
|
||||
}
|
||||
@@ -431,7 +448,9 @@ class PluginsClient extends OpenAIClient {
|
||||
const message = orderedMessages.pop();
|
||||
const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
|
||||
const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
|
||||
let messageString = `${this.startToken}${roleLabel}:\n${message.text}${this.endToken}\n`;
|
||||
let messageString = `${this.startToken}${roleLabel}:\n${
|
||||
message.text ?? message.content ?? ''
|
||||
}${this.endToken}\n`;
|
||||
let newPromptBody = `${messageString}${promptBody}`;
|
||||
|
||||
const tokenCountForMessage = this.getTokenCount(messageString);
|
||||
|
||||
@@ -16,11 +16,11 @@ class CustomAgent extends ZeroShotAgent {
|
||||
const inputVariables = ['input', 'chat_history', 'agent_scratchpad'];
|
||||
|
||||
let prefix, instructions, suffix;
|
||||
if (model.startsWith('gpt-3')) {
|
||||
if (model.includes('gpt-3')) {
|
||||
prefix = gpt3.prefix;
|
||||
instructions = gpt3.instructions;
|
||||
suffix = gpt3.suffix;
|
||||
} else if (model.startsWith('gpt-4')) {
|
||||
} else if (model.includes('gpt-4')) {
|
||||
prefix = gpt4.prefix;
|
||||
instructions = gpt4.instructions;
|
||||
suffix = gpt4.suffix;
|
||||
|
||||
@@ -18,7 +18,7 @@ const initializeCustomAgent = async ({
|
||||
}) => {
|
||||
let prompt = CustomAgent.createPrompt(tools, { currentDateString, model: model.modelName });
|
||||
|
||||
const chatPrompt = ChatPromptTemplate.fromPromptMessages([
|
||||
const chatPrompt = ChatPromptTemplate.fromMessages([
|
||||
new SystemMessagePromptTemplate(prompt),
|
||||
HumanMessagePromptTemplate.fromTemplate(`{chat_history}
|
||||
Query: {input}
|
||||
@@ -28,6 +28,7 @@ Query: {input}
|
||||
const outputParser = new CustomOutputParser({ tools });
|
||||
|
||||
const memory = new BufferMemory({
|
||||
llm: model,
|
||||
chatHistory: new ChatMessageHistory(pastMessages),
|
||||
// returnMessages: true, // commenting this out retains memory
|
||||
memoryKey: 'chat_history',
|
||||
|
||||
@@ -49,7 +49,7 @@ class FunctionsAgent extends Agent {
|
||||
static createPrompt(_tools, fields) {
|
||||
const { prefix = PREFIX, currentDateString } = fields || {};
|
||||
|
||||
return ChatPromptTemplate.fromPromptMessages([
|
||||
return ChatPromptTemplate.fromMessages([
|
||||
SystemMessagePromptTemplate.fromTemplate(`Date: ${currentDateString}\n${prefix}`),
|
||||
new MessagesPlaceholder('chat_history'),
|
||||
HumanMessagePromptTemplate.fromTemplate('Query: {input}'),
|
||||
|
||||
@@ -14,6 +14,7 @@ const initializeFunctionsAgent = async ({
|
||||
...rest
|
||||
}) => {
|
||||
const memory = new BufferMemory({
|
||||
llm: model,
|
||||
chatHistory: new ChatMessageHistory(pastMessages),
|
||||
memoryKey: 'chat_history',
|
||||
humanPrefix: 'User',
|
||||
|
||||
84
api/app/clients/callbacks/createStartHandler.js
Normal file
84
api/app/clients/callbacks/createStartHandler.js
Normal file
@@ -0,0 +1,84 @@
|
||||
const { promptTokensEstimate } = require('openai-chat-tokens');
|
||||
const checkBalance = require('../../../models/checkBalance');
|
||||
const { isEnabled } = require('../../../server/utils');
|
||||
const { formatFromLangChain } = require('../prompts');
|
||||
|
||||
const createStartHandler = ({
|
||||
context,
|
||||
conversationId,
|
||||
tokenBuffer = 0,
|
||||
initialMessageCount,
|
||||
manager,
|
||||
}) => {
|
||||
return async (_llm, _messages, runId, parentRunId, extraParams) => {
|
||||
const { invocation_params } = extraParams;
|
||||
const { model, functions, function_call } = invocation_params;
|
||||
const messages = _messages[0].map(formatFromLangChain);
|
||||
|
||||
if (manager.debug) {
|
||||
console.log(`handleChatModelStart: ${context}`);
|
||||
console.dir({ model, functions, function_call }, { depth: null });
|
||||
}
|
||||
|
||||
const payload = { messages };
|
||||
let prelimPromptTokens = 1;
|
||||
|
||||
if (functions) {
|
||||
payload.functions = functions;
|
||||
prelimPromptTokens += 2;
|
||||
}
|
||||
|
||||
if (function_call) {
|
||||
payload.function_call = function_call;
|
||||
prelimPromptTokens -= 5;
|
||||
}
|
||||
|
||||
prelimPromptTokens += promptTokensEstimate(payload);
|
||||
if (manager.debug) {
|
||||
console.log('Prelim Prompt Tokens & Token Buffer', prelimPromptTokens, tokenBuffer);
|
||||
}
|
||||
prelimPromptTokens += tokenBuffer;
|
||||
|
||||
try {
|
||||
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||
const generations =
|
||||
initialMessageCount && messages.length > initialMessageCount
|
||||
? messages.slice(initialMessageCount)
|
||||
: null;
|
||||
await checkBalance({
|
||||
req: manager.req,
|
||||
res: manager.res,
|
||||
txData: {
|
||||
user: manager.user,
|
||||
tokenType: 'prompt',
|
||||
amount: prelimPromptTokens,
|
||||
debug: manager.debug,
|
||||
generations,
|
||||
model,
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(`[${context}] checkBalance error`, err);
|
||||
manager.abortController.abort();
|
||||
if (context === 'summary' || context === 'plugins') {
|
||||
manager.addRun(runId, { conversationId, error: err.message });
|
||||
throw new Error(err);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
manager.addRun(runId, {
|
||||
model,
|
||||
messages,
|
||||
functions,
|
||||
function_call,
|
||||
runId,
|
||||
parentRunId,
|
||||
conversationId,
|
||||
prelimPromptTokens,
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = createStartHandler;
|
||||
5
api/app/clients/callbacks/index.js
Normal file
5
api/app/clients/callbacks/index.js
Normal file
@@ -0,0 +1,5 @@
|
||||
const createStartHandler = require('./createStartHandler');
|
||||
|
||||
module.exports = {
|
||||
createStartHandler,
|
||||
};
|
||||
@@ -1,5 +1,7 @@
|
||||
const runTitleChain = require('./runTitleChain');
|
||||
const predictNewSummary = require('./predictNewSummary');
|
||||
|
||||
module.exports = {
|
||||
runTitleChain,
|
||||
predictNewSummary,
|
||||
};
|
||||
|
||||
25
api/app/clients/chains/predictNewSummary.js
Normal file
25
api/app/clients/chains/predictNewSummary.js
Normal file
@@ -0,0 +1,25 @@
|
||||
const { LLMChain } = require('langchain/chains');
|
||||
const { getBufferString } = require('langchain/memory');
|
||||
|
||||
/**
|
||||
* Predicts a new summary for the conversation given the existing messages
|
||||
* and summary.
|
||||
* @param {Object} options - The prediction options.
|
||||
* @param {Array<string>} options.messages - Existing messages in the conversation.
|
||||
* @param {string} options.previous_summary - Current summary of the conversation.
|
||||
* @param {Object} options.memory - Memory Class.
|
||||
* @param {string} options.signal - Signal for the prediction.
|
||||
* @returns {Promise<string>} A promise that resolves to a new summary string.
|
||||
*/
|
||||
async function predictNewSummary({ messages, previous_summary, memory, signal }) {
|
||||
const newLines = getBufferString(messages, memory.humanPrefix, memory.aiPrefix);
|
||||
const chain = new LLMChain({ llm: memory.llm, prompt: memory.prompt });
|
||||
const result = await chain.call({
|
||||
summary: previous_summary,
|
||||
new_lines: newLines,
|
||||
signal,
|
||||
});
|
||||
return result.text;
|
||||
}
|
||||
|
||||
module.exports = predictNewSummary;
|
||||
@@ -1,32 +1,31 @@
|
||||
const { z } = require('zod');
|
||||
const { langPrompt, createTitlePrompt } = require('../prompts');
|
||||
const { escapeBraces, getSnippet } = require('../output_parsers');
|
||||
const { langPrompt, createTitlePrompt, escapeBraces, getSnippet } = require('../prompts');
|
||||
const { createStructuredOutputChainFromZod } = require('langchain/chains/openai_functions');
|
||||
|
||||
const langSchema = z.object({
|
||||
language: z.string().describe('The language of the input text (full noun, no abbreviations).'),
|
||||
});
|
||||
|
||||
const createLanguageChain = ({ llm }) =>
|
||||
const createLanguageChain = (config) =>
|
||||
createStructuredOutputChainFromZod(langSchema, {
|
||||
prompt: langPrompt,
|
||||
llm,
|
||||
...config,
|
||||
// verbose: true,
|
||||
});
|
||||
|
||||
const titleSchema = z.object({
|
||||
title: z.string().describe('The title-cased title of the conversation in the given language.'),
|
||||
title: z.string().describe('The conversation title in title-case, in the given language.'),
|
||||
});
|
||||
const createTitleChain = ({ llm, convo }) => {
|
||||
const createTitleChain = ({ convo, ...config }) => {
|
||||
const titlePrompt = createTitlePrompt({ convo });
|
||||
return createStructuredOutputChainFromZod(titleSchema, {
|
||||
prompt: titlePrompt,
|
||||
llm,
|
||||
...config,
|
||||
// verbose: true,
|
||||
});
|
||||
};
|
||||
|
||||
const runTitleChain = async ({ llm, text, convo }) => {
|
||||
const runTitleChain = async ({ llm, text, convo, signal, callbacks }) => {
|
||||
let snippet = text;
|
||||
try {
|
||||
snippet = getSnippet(text);
|
||||
@@ -34,10 +33,10 @@ const runTitleChain = async ({ llm, text, convo }) => {
|
||||
console.log('Error getting snippet of text for titleChain');
|
||||
console.log(e);
|
||||
}
|
||||
const languageChain = createLanguageChain({ llm });
|
||||
const titleChain = createTitleChain({ llm, convo: escapeBraces(convo) });
|
||||
const { language } = await languageChain.run(snippet);
|
||||
return (await titleChain.run(language)).title;
|
||||
const languageChain = createLanguageChain({ llm, callbacks });
|
||||
const titleChain = createTitleChain({ llm, callbacks, convo: escapeBraces(convo) });
|
||||
const { language } = (await languageChain.call({ inputText: snippet, signal })).output;
|
||||
return (await titleChain.call({ language, signal })).output.title;
|
||||
};
|
||||
|
||||
module.exports = runTitleChain;
|
||||
|
||||
5
api/app/clients/document/index.js
Normal file
5
api/app/clients/document/index.js
Normal file
@@ -0,0 +1,5 @@
|
||||
const tokenSplit = require('./tokenSplit');
|
||||
|
||||
module.exports = {
|
||||
tokenSplit,
|
||||
};
|
||||
51
api/app/clients/document/tokenSplit.js
Normal file
51
api/app/clients/document/tokenSplit.js
Normal file
@@ -0,0 +1,51 @@
|
||||
const { TokenTextSplitter } = require('langchain/text_splitter');
|
||||
|
||||
/**
|
||||
* Splits a given text by token chunks, based on the provided parameters for the TokenTextSplitter.
|
||||
* Note: limit or memoize use of this function as its calculation is expensive.
|
||||
*
|
||||
* @param {Object} obj - Configuration object for the text splitting operation.
|
||||
* @param {string} obj.text - The text to be split.
|
||||
* @param {string} [obj.encodingName='cl100k_base'] - Encoding name. Defaults to 'cl100k_base'.
|
||||
* @param {number} [obj.chunkSize=1] - The token size of each chunk. Defaults to 1.
|
||||
* @param {number} [obj.chunkOverlap=0] - The number of chunk elements to be overlapped between adjacent chunks. Defaults to 0.
|
||||
* @param {number} [obj.returnSize] - If specified and not 0, slices the return array from the end by this amount.
|
||||
*
|
||||
* @returns {Promise<Array>} Returns a promise that resolves to an array of text chunks.
|
||||
* If no text is provided, an empty array is returned.
|
||||
* If returnSize is specified and not 0, slices the return array from the end by returnSize.
|
||||
*
|
||||
* @async
|
||||
* @function tokenSplit
|
||||
*/
|
||||
async function tokenSplit({
|
||||
text,
|
||||
encodingName = 'cl100k_base',
|
||||
chunkSize = 1,
|
||||
chunkOverlap = 0,
|
||||
returnSize,
|
||||
}) {
|
||||
if (!text) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const splitter = new TokenTextSplitter({
|
||||
encodingName,
|
||||
chunkSize,
|
||||
chunkOverlap,
|
||||
});
|
||||
|
||||
if (!returnSize) {
|
||||
return await splitter.splitText(text);
|
||||
}
|
||||
|
||||
const splitText = await splitter.splitText(text);
|
||||
|
||||
if (returnSize && returnSize > 0 && splitText.length > 0) {
|
||||
return splitText.slice(-Math.abs(returnSize));
|
||||
}
|
||||
|
||||
return splitText;
|
||||
}
|
||||
|
||||
module.exports = tokenSplit;
|
||||
56
api/app/clients/document/tokenSplit.spec.js
Normal file
56
api/app/clients/document/tokenSplit.spec.js
Normal file
@@ -0,0 +1,56 @@
|
||||
const tokenSplit = require('./tokenSplit');
|
||||
|
||||
describe('tokenSplit', () => {
|
||||
const text = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam id.';
|
||||
|
||||
it('returns correct text chunks with provided parameters', async () => {
|
||||
const result = await tokenSplit({
|
||||
text: text,
|
||||
encodingName: 'gpt2',
|
||||
chunkSize: 2,
|
||||
chunkOverlap: 1,
|
||||
returnSize: 5,
|
||||
});
|
||||
|
||||
expect(result).toEqual(['. Null', ' Nullam', 'am id', ' id.', '.']);
|
||||
});
|
||||
|
||||
it('returns correct text chunks with default parameters', async () => {
|
||||
const result = await tokenSplit({ text });
|
||||
expect(result).toEqual([
|
||||
'Lorem',
|
||||
' ipsum',
|
||||
' dolor',
|
||||
' sit',
|
||||
' amet',
|
||||
',',
|
||||
' consectetur',
|
||||
' adipiscing',
|
||||
' elit',
|
||||
'.',
|
||||
' Null',
|
||||
'am',
|
||||
' id',
|
||||
'.',
|
||||
]);
|
||||
});
|
||||
|
||||
it('returns correct text chunks with specific return size', async () => {
|
||||
const result = await tokenSplit({ text, returnSize: 2 });
|
||||
expect(result.length).toEqual(2);
|
||||
expect(result).toEqual([' id', '.']);
|
||||
});
|
||||
|
||||
it('returns correct text chunks with specified chunk size', async () => {
|
||||
const result = await tokenSplit({ text, chunkSize: 10 });
|
||||
expect(result).toEqual([
|
||||
'Lorem ipsum dolor sit amet, consectetur adipiscing elit.',
|
||||
' Nullam id.',
|
||||
]);
|
||||
});
|
||||
|
||||
it('returns empty array with no text', async () => {
|
||||
const result = await tokenSplit({ text: '' });
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
96
api/app/clients/llm/RunManager.js
Normal file
96
api/app/clients/llm/RunManager.js
Normal file
@@ -0,0 +1,96 @@
|
||||
const { createStartHandler } = require('../callbacks');
|
||||
const spendTokens = require('../../../models/spendTokens');
|
||||
|
||||
class RunManager {
|
||||
constructor(fields) {
|
||||
const { req, res, abortController, debug } = fields;
|
||||
this.abortController = abortController;
|
||||
this.user = req.user.id;
|
||||
this.req = req;
|
||||
this.res = res;
|
||||
this.debug = debug;
|
||||
this.runs = new Map();
|
||||
this.convos = new Map();
|
||||
}
|
||||
|
||||
addRun(runId, runData) {
|
||||
if (!this.runs.has(runId)) {
|
||||
this.runs.set(runId, runData);
|
||||
if (runData.conversationId) {
|
||||
this.convos.set(runData.conversationId, runId);
|
||||
}
|
||||
return runData;
|
||||
} else {
|
||||
const existingData = this.runs.get(runId);
|
||||
const update = { ...existingData, ...runData };
|
||||
this.runs.set(runId, update);
|
||||
if (update.conversationId) {
|
||||
this.convos.set(update.conversationId, runId);
|
||||
}
|
||||
return update;
|
||||
}
|
||||
}
|
||||
|
||||
removeRun(runId) {
|
||||
if (this.runs.has(runId)) {
|
||||
this.runs.delete(runId);
|
||||
} else {
|
||||
console.error(`Run with ID ${runId} does not exist.`);
|
||||
}
|
||||
}
|
||||
|
||||
getAllRuns() {
|
||||
return Array.from(this.runs.values());
|
||||
}
|
||||
|
||||
getRunById(runId) {
|
||||
return this.runs.get(runId);
|
||||
}
|
||||
|
||||
getRunByConversationId(conversationId) {
|
||||
const runId = this.convos.get(conversationId);
|
||||
return { run: this.runs.get(runId), runId };
|
||||
}
|
||||
|
||||
createCallbacks(metadata) {
|
||||
return [
|
||||
{
|
||||
handleChatModelStart: createStartHandler({ ...metadata, manager: this }),
|
||||
handleLLMEnd: async (output, runId, _parentRunId) => {
|
||||
if (this.debug) {
|
||||
console.log(`handleLLMEnd: ${JSON.stringify(metadata)}`);
|
||||
console.dir({ output, runId, _parentRunId }, { depth: null });
|
||||
}
|
||||
const { tokenUsage } = output.llmOutput;
|
||||
const run = this.getRunById(runId);
|
||||
this.removeRun(runId);
|
||||
|
||||
const txData = {
|
||||
user: this.user,
|
||||
model: run?.model ?? 'gpt-3.5-turbo',
|
||||
...metadata,
|
||||
};
|
||||
|
||||
await spendTokens(txData, tokenUsage);
|
||||
},
|
||||
handleLLMError: async (err) => {
|
||||
this.debug && console.log(`handleLLMError: ${JSON.stringify(metadata)}`);
|
||||
this.debug && console.error(err);
|
||||
if (metadata.context === 'title') {
|
||||
return;
|
||||
} else if (metadata.context === 'plugins') {
|
||||
throw new Error(err);
|
||||
}
|
||||
const { conversationId } = metadata;
|
||||
const { run } = this.getRunByConversationId(conversationId);
|
||||
if (run && run.error) {
|
||||
const { error } = run;
|
||||
throw new Error(error);
|
||||
}
|
||||
},
|
||||
},
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = RunManager;
|
||||
@@ -1,15 +1,88 @@
|
||||
const { ChatOpenAI } = require('langchain/chat_models/openai');
|
||||
const { CallbackManager } = require('langchain/callbacks');
|
||||
const { sanitizeModelName } = require('../../../utils');
|
||||
const { isEnabled } = require('../../../server/utils');
|
||||
|
||||
function createLLM({ modelOptions, configOptions, handlers, openAIApiKey, azure = {} }) {
|
||||
/**
|
||||
* @typedef {Object} ModelOptions
|
||||
* @property {string} modelName - The name of the model.
|
||||
* @property {number} [temperature] - The temperature setting for the model.
|
||||
* @property {number} [presence_penalty] - The presence penalty setting.
|
||||
* @property {number} [frequency_penalty] - The frequency penalty setting.
|
||||
* @property {number} [max_tokens] - The maximum number of tokens to generate.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} ConfigOptions
|
||||
* @property {string} [basePath] - The base path for the API requests.
|
||||
* @property {Object} [baseOptions] - Base options for the API requests, including headers.
|
||||
* @property {Object} [httpAgent] - The HTTP agent for the request.
|
||||
* @property {Object} [httpsAgent] - The HTTPS agent for the request.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} Callbacks
|
||||
* @property {Function} [handleChatModelStart] - A callback function for handleChatModelStart
|
||||
* @property {Function} [handleLLMEnd] - A callback function for handleLLMEnd
|
||||
* @property {Function} [handleLLMError] - A callback function for handleLLMError
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} AzureOptions
|
||||
* @property {string} [azureOpenAIApiKey] - The Azure OpenAI API key.
|
||||
* @property {string} [azureOpenAIApiInstanceName] - The Azure OpenAI API instance name.
|
||||
* @property {string} [azureOpenAIApiDeploymentName] - The Azure OpenAI API deployment name.
|
||||
* @property {string} [azureOpenAIApiVersion] - The Azure OpenAI API version.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Creates a new instance of a language model (LLM) for chat interactions.
|
||||
*
|
||||
* @param {Object} options - The options for creating the LLM.
|
||||
* @param {ModelOptions} options.modelOptions - The options specific to the model, including modelName, temperature, presence_penalty, frequency_penalty, and other model-related settings.
|
||||
* @param {ConfigOptions} options.configOptions - Configuration options for the API requests, including proxy settings and custom headers.
|
||||
* @param {Callbacks} options.callbacks - Callback functions for managing the lifecycle of the LLM, including token buffers, context, and initial message count.
|
||||
* @param {boolean} [options.streaming=false] - Determines if the LLM should operate in streaming mode.
|
||||
* @param {string} options.openAIApiKey - The API key for OpenAI, used for authentication.
|
||||
* @param {AzureOptions} [options.azure={}] - Optional Azure-specific configurations. If provided, Azure configurations take precedence over OpenAI configurations.
|
||||
*
|
||||
* @returns {ChatOpenAI} An instance of the ChatOpenAI class, configured with the provided options.
|
||||
*
|
||||
* @example
|
||||
* const llm = createLLM({
|
||||
* modelOptions: { modelName: 'gpt-3.5-turbo', temperature: 0.2 },
|
||||
* configOptions: { basePath: 'https://example.api/path' },
|
||||
* callbacks: { onMessage: handleMessage },
|
||||
* openAIApiKey: 'your-api-key'
|
||||
* });
|
||||
*/
|
||||
function createLLM({
|
||||
modelOptions,
|
||||
configOptions,
|
||||
callbacks,
|
||||
streaming = false,
|
||||
openAIApiKey,
|
||||
azure = {},
|
||||
}) {
|
||||
let credentials = { openAIApiKey };
|
||||
let configuration = {
|
||||
apiKey: openAIApiKey,
|
||||
};
|
||||
|
||||
let azureOptions = {};
|
||||
if (azure) {
|
||||
const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME);
|
||||
|
||||
credentials = {};
|
||||
configuration = {};
|
||||
azureOptions = azure;
|
||||
|
||||
azureOptions.azureOpenAIApiDeploymentName = useModelName
|
||||
? sanitizeModelName(modelOptions.modelName)
|
||||
: azureOptions.azureOpenAIApiDeploymentName;
|
||||
}
|
||||
|
||||
if (azure && process.env.AZURE_OPENAI_DEFAULT_MODEL) {
|
||||
modelOptions.modelName = process.env.AZURE_OPENAI_DEFAULT_MODEL;
|
||||
}
|
||||
|
||||
// console.debug('createLLM: configOptions');
|
||||
@@ -17,12 +90,13 @@ function createLLM({ modelOptions, configOptions, handlers, openAIApiKey, azure
|
||||
|
||||
return new ChatOpenAI(
|
||||
{
|
||||
streaming: true,
|
||||
streaming,
|
||||
verbose: true,
|
||||
credentials,
|
||||
configuration,
|
||||
...azure,
|
||||
...azureOptions,
|
||||
...modelOptions,
|
||||
callbackManager: handlers && CallbackManager.fromHandlers(handlers),
|
||||
callbacks,
|
||||
},
|
||||
configOptions,
|
||||
);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const createLLM = require('./createLLM');
|
||||
const RunManager = require('./RunManager');
|
||||
|
||||
module.exports = {
|
||||
createLLM,
|
||||
RunManager,
|
||||
};
|
||||
|
||||
115
api/app/clients/memory/example.js
Normal file
115
api/app/clients/memory/example.js
Normal file
@@ -0,0 +1,115 @@
|
||||
[
|
||||
{ role: 'user', content: 'hi' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: 'Hello! How can I assist you today?',
|
||||
},
|
||||
{ role: 'user', content: 'tell me a long story' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content:
|
||||
'Once upon a time, in a small village nestled between rolling green hills, lived a young boy named Ethan. He was curious, adventurous, and had an insatiable thirst for knowledge. Every day, he would wander through the village, eagerly listening to the stories of the villagers and observing the world around him.\n' +
|
||||
'\n' +
|
||||
'One sunny day, while exploring the outskirts of the village, Ethan stumbled upon an old dusty book hidden amongst the roots of a majestic ancient oak tree. It was a book of fairy tales, filled with vivid descriptions of enchanted forests, magical creatures, and extraordinary adventures. Intrigued, Ethan opened the book and began reading.\n' +
|
||||
'\n' +
|
||||
'As he turned the pages, he was transported into a world filled with wonder and awe. With each new story he read, he grew more drawn to the idea of embarking on his own grand adventure. Inspired by the tales he had discovered, Ethan decided to set off on a journey to seek out the magic and mystery that lay beyond the boundaries of his village.\n' +
|
||||
'\n' +
|
||||
'He bid farewell to his family and set out on a path that led him to a dense and enchanting forest. The forest was said to be home to mythical creatures and hidden treasures. As Ethan ventured deeper into the woods, he could feel an electric energy, as if the trees whispered secrets to him.\n' +
|
||||
'\n' +
|
||||
'Soon, he encountered a mischievous sprite named Sparkle, who had an impish grin and twinkling eyes. Sparkle guided Ethan through the labyrinth of trees, warning him of hidden dangers and sharing stories of ancient beings that dwelled in the heart of the forest.\n' +
|
||||
'\n' +
|
||||
'Together, they stumbled upon a shimmering lake that seemed to glow with an otherworldly light. At the center of the lake, resting atop a small island, was a humble cottage made of petals and leaves. It belonged to an ancient and wise sorceress named Celestia.\n' +
|
||||
'\n' +
|
||||
'Celestia had the power to grant one wish to anyone who dared to find her abode. Ethan, captivated by the tales he had read and yearning for something greater, approached the cottage with trepidation. When he shared his desire to embark on a grand adventure, Celestia smiled warmly and agreed to grant his wish.\n' +
|
||||
'\n' +
|
||||
'With a wave of her wand and a sprinkle of stardust, Celestia bestowed upon Ethan a magical necklace. This necklace, adorned with a rare gemstone called the Eye of Imagination, had the power to turn dreams and imagination into reality. From that moment forward, Ethan\'s every thought and idea became manifest.\n' +
|
||||
'\n' +
|
||||
'Energized by this newfound power, Ethan continued his journey, encountering mythical creatures, solving riddles, and overcoming treacherous obstacles along the way. With the Eye of Imagination, he brought life to ancient statues, unlocked hidden doors, and even tamed fiery dragons.\n' +
|
||||
'\n' +
|
||||
'As days turned into weeks and weeks into months, Ethan became wiser and more in tune with the world around him. He learned that true adventure was not merely about seeking thrills and conquering the unknown, but also about fostering compassion, friendship, and a deep appreciation for the beauty of the ordinary.\n' +
|
||||
'\n' +
|
||||
'Eventually, Ethan\'s journey led him back to his village. With the Eye of Imagination, he transformed the village into a place of wonders and endless possibilities. Fields blossomed into vibrant gardens, simple tools turned into intricate works of art, and the villagers felt a renewed sense of hope and inspiration.\n' +
|
||||
'\n' +
|
||||
'Ethan, now known as the Village Magician, realized that the true magic lied within everyone\'s hearts. He taught the villagers to embrace their creativity, to dream big, and to never underestimate the power of imagination. And so, the village flourished, becoming a beacon of wonder and creativity for all to see.\n' +
|
||||
'\n' +
|
||||
'In the years that followed, Ethan\'s adventures continued, though mostly within the confines of his beloved village. But he never forgot the thrill of that first grand adventure. And every now and then, when looking up at the starry night sky, he would allow his mind to wander, knowing that the greatest adventures were still waiting to be discovered.',
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content:
|
||||
'I rewrote your story a bit, writer another one but more like Cormac McCarthy \n' +
|
||||
'\n' +
|
||||
'Once there was a young lad by the name of Ethan, raised in a little hamlet nestled betwixt the verdant knolls, who possessed an irrepressible yearning for knowledge, a thirst unquenchable and a spirit teeming with curiosity. As the golden sun bathed the bucolic land in its effulgent light, he would tread through the village, his ears attuned to the tales spun by the townsfolk, his eyes absorbing the tapestry woven by the world surrounding him.\n' +
|
||||
'\n' +
|
||||
'One radiant day, whilst exploring the periphery of the settlement, Ethan chanced upon a timeworn tome, ensconced amidst the roots of an ancient oak, cloaked in the shroud of neglect. The dust gathered upon it spoke of time\'s relentless march. A book of fairy tales – garnished with vivid descriptions of mystical woods, fantastical beasts, and ventures daring beyond the ordinary humdrum existence. Intrigued and beguiled, Ethan pried open the weathered pages and succumbed to their beckoning whispers.\n' +
|
||||
'\n' +
|
||||
'In each tale, he was transported to a realm of enchantment and wonderment, inexorably tugging at the strings of his yearning for peripatetic exploration. Inspired by the narratives he had devoured, Ethan resolved to bid adieu to kinfolk and embark upon a sojourn, with dreams of procuring a firsthand glimpse into the domain of mystique that lay beyond the village\'s circumscribed boundary.\n' +
|
||||
'\n' +
|
||||
'Thus, he bade tearful farewells, girding himself for a path that guided him to a dense and captivating woodland, whispered of as a sanctuary to mythical beings and clandestine troves of treasures. As Ethan plunged deeper into the heart of the arboreal labyrinth, he felt a palpable surge of electricity, as though the sylvan sentinels whispered enigmatic secrets that only the perceptive ear could discern.\n' +
|
||||
'\n' +
|
||||
'It wasn\'t long before his path intertwined with that of a capricious sprite christened Sparkle, bearing an impish grin and eyes sparkling with mischief. Sparkle played the role of Virgil to Ethan\'s Dante, guiding him through the intricate tapestry of arboreal scions, issuing warnings of perils concealed and spinning tales of ancient entities that called this very bosky enclave home.\n' +
|
||||
'\n' +
|
||||
'Together, they stumbled upon a luminous lake, its shimmering waters imbued with a celestial light. At the center lay a diminutive island, upon which reposed a cottage fashioned from tender petals and verdant leaves. It belonged to an ancient sorceress of considerable wisdom, Celestia by name.\n' +
|
||||
'\n' +
|
||||
'Celestia, with her power to bestow a single wish on any intrepid soul who happened upon her abode, met Ethan\'s desire with a congenial nod, his fervor for a grand expedition not lost on her penetrating gaze. In response, she bequeathed unto him a necklace of magical manufacture – adorned with the rare gemstone known as the Eye of Imagination – whose very essence transformed dreams into vivid reality. From that moment forward, not a single cogitation nor nebulous fanciful notion of Ethan\'s ever lacked physicality.\n' +
|
||||
'\n' +
|
||||
'Energized by this newfound potency, Ethan continued his sojourn, encountering mythical creatures, unraveling cerebral enigmas, and braving perils aplenty along the winding roads of destiny. Armed with the Eye of Imagination, he brought forth life from immobile statuary, unlocked forbidding portals, and even tamed the ferocious beasts of yore – their fiery breath reduced to a whisper.\n' +
|
||||
'\n' +
|
||||
'As the weeks metamorphosed into months, Ethan grew wiser and more attuned to the ebb and flow of the world enveloping him. He gleaned that true adventure isn\'t solely confined to sating a thirst for adrenaline and conquering the unknown; indeed, it resides in fostering compassion, fostering amicable bonds, and cherishing the beauty entwined within the quotidian veld.\n' +
|
||||
'\n' +
|
||||
'Eventually, Ethan\'s quest drew him homeward, back to his village. Buoying the Eye of Imagination\'s ethereal power, he imbued the hitherto unremarkable settlement with the patina of infinite possibilities. The bounteous fields bloomed into kaleidoscopic gardens, simple instruments transmuting into intricate masterpieces, and the villagers themselves clasped within their hearts a renewed ardor, a conflagration of hope and inspiration.\n' +
|
||||
'\n' +
|
||||
'Behold Ethan, at present hailed as the Village Magician – a cognomen befitting his sorcery wielded within the confines of the community he adored. His exploits may have become tethered to a geographical locale, but the flame kindled by that premier sojourn never flickered. Occasionally, beneath the veiled twilight canopy, when peering skyward toward the astral canvas bedecked with countless luminescent diamonds, he\'d allow his mind to wander, cognizant of boundless adventures yet to take flight, forever awaiting his indomitable spirit.',
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content:
|
||||
'Beautifully rewritten! It captures the essence of Cormac McCarthy\'s style, with its descriptive language and evocative imagery. The story retains the spirit of adventure and the power of imagination seen in the original. Well done! If you have any more requests or if there\'s anything else I can assist you with, feel free to let me know.',
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'thanks but id like you to write one in that style now too',
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content:
|
||||
'Certainly! Here\'s a short story reflecting Cormac McCarthy\'s style:\n' +
|
||||
'\n' +
|
||||
'In a desolate wasteland scarred by the sins of mankind, a lone traveler ventured forth, his name forgotten by time, etched upon the shadows of memory. He navigated the barren plains, where the sun beat down, merciless and unrelenting, scorching the parched earth beneath his weary boots.\n' +
|
||||
'\n' +
|
||||
'Desolation was his only companion, the forsaken landscape a testament to the darkness that had befallen this forsaken realm. The wind whispered secrets long lost, carrying with it the echoes of a civilization long since undone, swallowed by the abyss.\n' +
|
||||
'\n' +
|
||||
'Amongst the shattered remnants of fallen cities, he sought solace, a glimmer of hope amidst the desolation. Rusted remnants of steel and concrete stood as sentinels of mankind\'s hubris, their jagged edges a chilling reminder of the fragility of all that was once deemed invincible.\n' +
|
||||
'\n' +
|
||||
'As twilight descended, casting a pallid hue upon the lifeless land, he happened upon a dilapidated house, its walls crumbling under the weight of time. Drawn to the desolation, he pushed open the creaking door, its rusty hinges screaming in protest.\n' +
|
||||
'\n' +
|
||||
'Inside, a world frozen in time awaited him. Dust particles danced through the air, illuminated by slivers of light that pierced the broken windows. Here, the past lay trapped, a melancholic tableau of lives lost and dreams shattered.\n' +
|
||||
'\n' +
|
||||
'He wandered through the empty rooms, bare walls whispering stories untold. In the corner, a dilapidated piano stood, its ivory keys long since abandoned by a musician\'s touch. He approached it cautiously, a faint glimmer of anticipation igniting in his weather-worn eyes.\n' +
|
||||
'\n' +
|
||||
'Determined, he carefully brushed away the dust and debris that cloaked the instrument. With trembling hands, he pressed a key, and a discordant note filled the air, mingling with the sighs of forgotten melodies. The sound echoed through the silence, bringing with it the remnants of forgotten symphonies.\n' +
|
||||
'\n' +
|
||||
'In that abandoned house, amidst the debris of shattered dreams, he found solace in the music of despair. With each stroke of his fingers, he conjured haunting melodies, melodies that transcended the desolation and spoke of a longing for redemption and rebirth.\n' +
|
||||
'\n' +
|
||||
'The notes carried his sorrow, his hopes, and his regrets, reaching far beyond the confines of that forsaken house. They traversed the desolate plains, resonating through the ravaged land with an eerie grace, enwrapping all who roamed there in a melancholic embrace.\n' +
|
||||
'\n' +
|
||||
'As the last notes faded, swallowed by the unforgiving stillness, he left the house behind, forever etching his presence upon its crumbling walls. The music remained, reverberating through the wasteland, a reminder that even amidst the devastation, beauty could be found, and a glimmer of hope could arise from the deepest darkness.\n' +
|
||||
'\n' +
|
||||
'And so, the lone traveler continued his journey, traversing the desolate expanse, his heart burdened with the weight of the world. In the face of utter desolation, he found solace in the haunting melodies, his music becoming a testament to the resilience of the human spirit, a testament to the undying flame of hope that flickered within us all.\n' +
|
||||
'\n' +
|
||||
'Note: This story is written in the style of Cormac McCarthy and aims to capture some of his writing techniques, including sparse dialogue, descriptive imagery, and a focus on the desolate and bleak aspects of life.',
|
||||
},
|
||||
{ role: 'user', content: 'great job' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content:
|
||||
'Thank you! I\'m glad you enjoyed it. If you have any more writing requests or if there\'s anything else I can assist you with, feel free to let me know.',
|
||||
},
|
||||
{ role: 'user', content: 'you are very helpful' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content:
|
||||
'Thank you for your kind words! I\'m here to assist you in any way I can. If you have any more questions, need further assistance, or just want to chat, feel free to reach out.',
|
||||
},
|
||||
{ role: 'user', content: 'no you man' },
|
||||
];
|
||||
5
api/app/clients/memory/index.js
Normal file
5
api/app/clients/memory/index.js
Normal file
@@ -0,0 +1,5 @@
|
||||
const summaryBuffer = require('./summaryBuffer');
|
||||
|
||||
module.exports = {
|
||||
...summaryBuffer,
|
||||
};
|
||||
31
api/app/clients/memory/summaryBuffer.demo.js
Normal file
31
api/app/clients/memory/summaryBuffer.demo.js
Normal file
@@ -0,0 +1,31 @@
|
||||
require('dotenv').config();
|
||||
const { ChatOpenAI } = require('langchain/chat_models/openai');
|
||||
const { getBufferString, ConversationSummaryBufferMemory } = require('langchain/memory');
|
||||
|
||||
const chatPromptMemory = new ConversationSummaryBufferMemory({
|
||||
llm: new ChatOpenAI({ modelName: 'gpt-3.5-turbo', temperature: 0 }),
|
||||
maxTokenLimit: 10,
|
||||
returnMessages: true,
|
||||
});
|
||||
|
||||
(async () => {
|
||||
await chatPromptMemory.saveContext({ input: 'hi my name\'s Danny' }, { output: 'whats up' });
|
||||
await chatPromptMemory.saveContext({ input: 'not much you' }, { output: 'not much' });
|
||||
await chatPromptMemory.saveContext(
|
||||
{ input: 'are you excited for the olympics?' },
|
||||
{ output: 'not really' },
|
||||
);
|
||||
|
||||
// We can also utilize the predict_new_summary method directly.
|
||||
const messages = await chatPromptMemory.chatHistory.getMessages();
|
||||
console.log('MESSAGES\n\n');
|
||||
console.log(JSON.stringify(messages));
|
||||
const previous_summary = '';
|
||||
const predictSummary = await chatPromptMemory.predictNewSummary(messages, previous_summary);
|
||||
console.log('SUMMARY\n\n');
|
||||
console.log(JSON.stringify(getBufferString([{ role: 'system', content: predictSummary }])));
|
||||
|
||||
// const { history } = await chatPromptMemory.loadMemoryVariables({});
|
||||
// console.log('HISTORY\n\n');
|
||||
// console.log(JSON.stringify(history));
|
||||
})();
|
||||
68
api/app/clients/memory/summaryBuffer.js
Normal file
68
api/app/clients/memory/summaryBuffer.js
Normal file
@@ -0,0 +1,68 @@
|
||||
const { ConversationSummaryBufferMemory, ChatMessageHistory } = require('langchain/memory');
|
||||
const { formatLangChainMessages, SUMMARY_PROMPT } = require('../prompts');
|
||||
const { predictNewSummary } = require('../chains');
|
||||
|
||||
const createSummaryBufferMemory = ({ llm, prompt, messages, ...rest }) => {
|
||||
const chatHistory = new ChatMessageHistory(messages);
|
||||
return new ConversationSummaryBufferMemory({
|
||||
llm,
|
||||
prompt,
|
||||
chatHistory,
|
||||
returnMessages: true,
|
||||
...rest,
|
||||
});
|
||||
};
|
||||
|
||||
const summaryBuffer = async ({
|
||||
llm,
|
||||
debug,
|
||||
context, // array of messages
|
||||
formatOptions = {},
|
||||
previous_summary = '',
|
||||
prompt = SUMMARY_PROMPT,
|
||||
signal,
|
||||
}) => {
|
||||
if (debug && previous_summary) {
|
||||
console.log('<-----------PREVIOUS SUMMARY----------->\n\n');
|
||||
console.log(previous_summary);
|
||||
}
|
||||
|
||||
const formattedMessages = formatLangChainMessages(context, formatOptions);
|
||||
const memoryOptions = {
|
||||
llm,
|
||||
prompt,
|
||||
messages: formattedMessages,
|
||||
};
|
||||
|
||||
if (formatOptions.userName) {
|
||||
memoryOptions.humanPrefix = formatOptions.userName;
|
||||
}
|
||||
if (formatOptions.userName) {
|
||||
memoryOptions.aiPrefix = formatOptions.assistantName;
|
||||
}
|
||||
|
||||
const chatPromptMemory = createSummaryBufferMemory(memoryOptions);
|
||||
|
||||
const messages = await chatPromptMemory.chatHistory.getMessages();
|
||||
|
||||
if (debug) {
|
||||
console.log('<-----------SUMMARY BUFFER MESSAGES----------->\n\n');
|
||||
console.log(JSON.stringify(messages));
|
||||
}
|
||||
|
||||
const predictSummary = await predictNewSummary({
|
||||
messages,
|
||||
previous_summary,
|
||||
memory: chatPromptMemory,
|
||||
signal,
|
||||
});
|
||||
|
||||
if (debug) {
|
||||
console.log('<-----------SUMMARY----------->\n\n');
|
||||
console.log(JSON.stringify(predictSummary));
|
||||
}
|
||||
|
||||
return { role: 'system', content: predictSummary };
|
||||
};
|
||||
|
||||
module.exports = { createSummaryBufferMemory, summaryBuffer };
|
||||
@@ -1,23 +1,68 @@
|
||||
/**
|
||||
* The `addImages` function corrects any erroneous image URLs in the `responseMessage.text`
|
||||
* and appends image observations from `intermediateSteps` if they are not already present.
|
||||
*
|
||||
* @function
|
||||
* @module addImages
|
||||
*
|
||||
* @param {Array.<Object>} intermediateSteps - An array of objects, each containing an observation.
|
||||
* @param {Object} responseMessage - An object containing the text property which might have image URLs.
|
||||
*
|
||||
* @property {string} intermediateSteps[].observation - The observation string which might contain an image markdown.
|
||||
* @property {string} responseMessage.text - The text which might contain image URLs.
|
||||
*
|
||||
* @example
|
||||
*
|
||||
* const intermediateSteps = [
|
||||
* { observation: '' }
|
||||
* ];
|
||||
* const responseMessage = { text: 'Some text with ' };
|
||||
*
|
||||
* addImages(intermediateSteps, responseMessage);
|
||||
*
|
||||
* console.log(responseMessage.text);
|
||||
* // Outputs: 'Some text with \n'
|
||||
*
|
||||
* @returns {void}
|
||||
*/
|
||||
function addImages(intermediateSteps, responseMessage) {
|
||||
if (!intermediateSteps || !responseMessage) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Correct any erroneous URLs in the responseMessage.text first
|
||||
intermediateSteps.forEach((step) => {
|
||||
const { observation } = step;
|
||||
if (!observation || !observation.includes('![')) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract the image file path from the observation
|
||||
const observedImagePath = observation.match(/\(\/images\/.*\.\w*\)/g)[0];
|
||||
const match = observation.match(/\/images\/.*\.\w*/);
|
||||
if (!match) {
|
||||
return;
|
||||
}
|
||||
const essentialImagePath = match[0];
|
||||
|
||||
// Check if the responseMessage already includes the image file path
|
||||
if (!responseMessage.text.includes(observedImagePath)) {
|
||||
// If the image file path is not found, append the whole observation
|
||||
const regex = /!\[.*?\]\((.*?)\)/g;
|
||||
let matchErroneous;
|
||||
while ((matchErroneous = regex.exec(responseMessage.text)) !== null) {
|
||||
if (matchErroneous[1] && !matchErroneous[1].startsWith('/images/')) {
|
||||
responseMessage.text = responseMessage.text.replace(matchErroneous[1], essentialImagePath);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Now, check if the responseMessage already includes the correct image file path and append if not
|
||||
intermediateSteps.forEach((step) => {
|
||||
const { observation } = step;
|
||||
if (!observation || !observation.includes('![')) {
|
||||
return;
|
||||
}
|
||||
const observedImagePath = observation.match(/\(\/images\/.*\.\w*\)/g);
|
||||
if (observedImagePath && !responseMessage.text.includes(observedImagePath[0])) {
|
||||
responseMessage.text += '\n' + observation;
|
||||
if (this.options.debug) {
|
||||
console.debug('added image from intermediateSteps');
|
||||
if (process.env.DEBUG_PLUGINS) {
|
||||
console.debug('[addImages] added image from intermediateSteps');
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
84
api/app/clients/output_parsers/addImages.spec.js
Normal file
84
api/app/clients/output_parsers/addImages.spec.js
Normal file
@@ -0,0 +1,84 @@
|
||||
let addImages = require('./addImages');
|
||||
|
||||
describe('addImages', () => {
|
||||
let intermediateSteps;
|
||||
let responseMessage;
|
||||
let options;
|
||||
|
||||
beforeEach(() => {
|
||||
intermediateSteps = [];
|
||||
responseMessage = { text: '' };
|
||||
options = { debug: false };
|
||||
this.options = options;
|
||||
addImages = addImages.bind(this);
|
||||
});
|
||||
|
||||
it('should handle null or undefined parameters', () => {
|
||||
addImages(null, responseMessage);
|
||||
expect(responseMessage.text).toBe('');
|
||||
|
||||
addImages(intermediateSteps, null);
|
||||
expect(responseMessage.text).toBe('');
|
||||
|
||||
addImages(null, null);
|
||||
expect(responseMessage.text).toBe('');
|
||||
});
|
||||
|
||||
it('should append correct image markdown if not present in responseMessage', () => {
|
||||
intermediateSteps.push({ observation: '' });
|
||||
addImages(intermediateSteps, responseMessage);
|
||||
expect(responseMessage.text).toBe('\n');
|
||||
});
|
||||
|
||||
it('should not append image markdown if already present in responseMessage', () => {
|
||||
responseMessage.text = '';
|
||||
intermediateSteps.push({ observation: '' });
|
||||
addImages(intermediateSteps, responseMessage);
|
||||
expect(responseMessage.text).toBe('');
|
||||
});
|
||||
|
||||
it('should correct and append image markdown with erroneous URL', () => {
|
||||
responseMessage.text = '';
|
||||
intermediateSteps.push({ observation: '' });
|
||||
addImages(intermediateSteps, responseMessage);
|
||||
expect(responseMessage.text).toBe('');
|
||||
});
|
||||
|
||||
it('should correct multiple erroneous URLs in responseMessage', () => {
|
||||
responseMessage.text =
|
||||
' ';
|
||||
intermediateSteps.push({ observation: '' });
|
||||
intermediateSteps.push({ observation: '' });
|
||||
addImages(intermediateSteps, responseMessage);
|
||||
expect(responseMessage.text).toBe(' ');
|
||||
});
|
||||
|
||||
it('should not append non-image markdown observations', () => {
|
||||
intermediateSteps.push({ observation: '[desc](/images/test.png)' });
|
||||
addImages(intermediateSteps, responseMessage);
|
||||
expect(responseMessage.text).toBe('');
|
||||
});
|
||||
|
||||
it('should handle multiple observations', () => {
|
||||
intermediateSteps.push({ observation: '' });
|
||||
intermediateSteps.push({ observation: '' });
|
||||
addImages(intermediateSteps, responseMessage);
|
||||
expect(responseMessage.text).toBe('\n\n');
|
||||
});
|
||||
|
||||
it('should not append if observation does not contain image markdown', () => {
|
||||
intermediateSteps.push({ observation: 'This is a test observation without image markdown.' });
|
||||
addImages(intermediateSteps, responseMessage);
|
||||
expect(responseMessage.text).toBe('');
|
||||
});
|
||||
|
||||
it('should append correctly from a real scenario', () => {
|
||||
responseMessage.text =
|
||||
'Here is the generated image based on your request. It depicts a surreal landscape filled with floating musical notes. The style is impressionistic, with vibrant sunset hues dominating the scene. At the center, there\'s a silhouette of a grand piano, adding a dreamy emotion to the overall image. This could serve as a unique and creative music album cover. Would you like to make any changes or generate another image?';
|
||||
const originalText = responseMessage.text;
|
||||
const imageMarkdown = '';
|
||||
intermediateSteps.push({ observation: imageMarkdown });
|
||||
addImages(intermediateSteps, responseMessage);
|
||||
expect(responseMessage.text).toBe(`${originalText}\n${imageMarkdown}`);
|
||||
});
|
||||
});
|
||||
@@ -1,9 +1,7 @@
|
||||
const addImages = require('./addImages');
|
||||
const handleInputs = require('./handleInputs');
|
||||
const handleOutputs = require('./handleOutputs');
|
||||
|
||||
module.exports = {
|
||||
addImages,
|
||||
...handleInputs,
|
||||
...handleOutputs,
|
||||
};
|
||||
|
||||
100
api/app/clients/prompts/formatMessages.js
Normal file
100
api/app/clients/prompts/formatMessages.js
Normal file
@@ -0,0 +1,100 @@
|
||||
const { HumanMessage, AIMessage, SystemMessage } = require('langchain/schema');
|
||||
|
||||
/**
|
||||
* Formats a message to OpenAI payload format based on the provided options.
|
||||
*
|
||||
* @param {Object} params - The parameters for formatting.
|
||||
* @param {Object} params.message - The message object to format.
|
||||
* @param {string} [params.message.role] - The role of the message sender (e.g., 'user', 'assistant').
|
||||
* @param {string} [params.message._name] - The name associated with the message.
|
||||
* @param {string} [params.message.sender] - The sender of the message.
|
||||
* @param {string} [params.message.text] - The text content of the message.
|
||||
* @param {string} [params.message.content] - The content of the message.
|
||||
* @param {string} [params.userName] - The name of the user.
|
||||
* @param {string} [params.assistantName] - The name of the assistant.
|
||||
* @param {boolean} [params.langChain=false] - Whether to return a LangChain message object.
|
||||
* @returns {(Object|HumanMessage|AIMessage|SystemMessage)} - The formatted message.
|
||||
*/
|
||||
const formatMessage = ({ message, userName, assistantName, langChain = false }) => {
|
||||
let { role: _role, _name, sender, text, content: _content, lc_id } = message;
|
||||
if (lc_id && lc_id[2] && !langChain) {
|
||||
const roleMapping = {
|
||||
SystemMessage: 'system',
|
||||
HumanMessage: 'user',
|
||||
AIMessage: 'assistant',
|
||||
};
|
||||
_role = roleMapping[lc_id[2]];
|
||||
}
|
||||
const role = _role ?? (sender && sender?.toLowerCase() === 'user' ? 'user' : 'assistant');
|
||||
const content = text ?? _content ?? '';
|
||||
const formattedMessage = {
|
||||
role,
|
||||
content,
|
||||
};
|
||||
|
||||
if (_name) {
|
||||
formattedMessage.name = _name;
|
||||
}
|
||||
|
||||
if (userName && formattedMessage.role === 'user') {
|
||||
formattedMessage.name = userName;
|
||||
}
|
||||
|
||||
if (assistantName && formattedMessage.role === 'assistant') {
|
||||
formattedMessage.name = assistantName;
|
||||
}
|
||||
|
||||
if (formattedMessage.name) {
|
||||
// Conform to API regex: ^[a-zA-Z0-9_-]{1,64}$
|
||||
// https://community.openai.com/t/the-format-of-the-name-field-in-the-documentation-is-incorrect/175684/2
|
||||
formattedMessage.name = formattedMessage.name.replace(/[^a-zA-Z0-9_-]/g, '_');
|
||||
|
||||
if (formattedMessage.name.length > 64) {
|
||||
formattedMessage.name = formattedMessage.name.substring(0, 64);
|
||||
}
|
||||
}
|
||||
|
||||
if (!langChain) {
|
||||
return formattedMessage;
|
||||
}
|
||||
|
||||
if (role === 'user') {
|
||||
return new HumanMessage(formattedMessage);
|
||||
} else if (role === 'assistant') {
|
||||
return new AIMessage(formattedMessage);
|
||||
} else {
|
||||
return new SystemMessage(formattedMessage);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Formats an array of messages for LangChain.
|
||||
*
|
||||
* @param {Array<Object>} messages - The array of messages to format.
|
||||
* @param {Object} formatOptions - The options for formatting each message.
|
||||
* @param {string} [formatOptions.userName] - The name of the user.
|
||||
* @param {string} [formatOptions.assistantName] - The name of the assistant.
|
||||
* @returns {Array<(HumanMessage|AIMessage|SystemMessage)>} - The array of formatted LangChain messages.
|
||||
*/
|
||||
const formatLangChainMessages = (messages, formatOptions) =>
|
||||
messages.map((msg) => formatMessage({ ...formatOptions, message: msg, langChain: true }));
|
||||
|
||||
/**
|
||||
* Formats a LangChain message object by merging properties from `lc_kwargs` or `kwargs` and `additional_kwargs`.
|
||||
*
|
||||
* @param {Object} message - The message object to format.
|
||||
* @param {Object} [message.lc_kwargs] - Contains properties to be merged. Either this or `message.kwargs` should be provided.
|
||||
* @param {Object} [message.kwargs] - Contains properties to be merged. Either this or `message.lc_kwargs` should be provided.
|
||||
* @param {Object} [message.kwargs.additional_kwargs] - Additional properties to be merged.
|
||||
*
|
||||
* @returns {Object} The formatted LangChain message.
|
||||
*/
|
||||
const formatFromLangChain = (message) => {
|
||||
const { additional_kwargs, ...message_kwargs } = message.lc_kwargs ?? message.kwargs;
|
||||
return {
|
||||
...message_kwargs,
|
||||
...additional_kwargs,
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = { formatMessage, formatLangChainMessages, formatFromLangChain };
|
||||
277
api/app/clients/prompts/formatMessages.spec.js
Normal file
277
api/app/clients/prompts/formatMessages.spec.js
Normal file
@@ -0,0 +1,277 @@
|
||||
const { formatMessage, formatLangChainMessages, formatFromLangChain } = require('./formatMessages');
|
||||
const { HumanMessage, AIMessage, SystemMessage } = require('langchain/schema');
|
||||
|
||||
describe('formatMessage', () => {
|
||||
it('formats user message', () => {
|
||||
const input = {
|
||||
message: {
|
||||
sender: 'user',
|
||||
text: 'Hello',
|
||||
},
|
||||
userName: 'John',
|
||||
};
|
||||
const result = formatMessage(input);
|
||||
expect(result).toEqual({
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
name: 'John',
|
||||
});
|
||||
});
|
||||
|
||||
it('sanitizes the name by replacing invalid characters (per OpenAI)', () => {
|
||||
const input = {
|
||||
message: {
|
||||
sender: 'user',
|
||||
text: 'Hello',
|
||||
},
|
||||
userName: ' John$Doe@Example! ',
|
||||
};
|
||||
const result = formatMessage(input);
|
||||
expect(result).toEqual({
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
name: '_John_Doe_Example__',
|
||||
});
|
||||
});
|
||||
|
||||
it('trims the name to a maximum length of 64 characters', () => {
|
||||
const longName = 'a'.repeat(100);
|
||||
const input = {
|
||||
message: {
|
||||
sender: 'user',
|
||||
text: 'Hello',
|
||||
},
|
||||
userName: longName,
|
||||
};
|
||||
const result = formatMessage(input);
|
||||
expect(result.name.length).toBe(64);
|
||||
expect(result.name).toBe('a'.repeat(64));
|
||||
});
|
||||
|
||||
it('formats a realistic user message', () => {
|
||||
const input = {
|
||||
message: {
|
||||
_id: '6512cdfb92cbf69fea615331',
|
||||
messageId: 'b620bf73-c5c3-4a38-b724-76886aac24c4',
|
||||
__v: 0,
|
||||
cancelled: false,
|
||||
conversationId: '5c23d24f-941f-4aab-85df-127b596c8aa5',
|
||||
createdAt: Date.now(),
|
||||
error: false,
|
||||
finish_reason: null,
|
||||
isCreatedByUser: true,
|
||||
isEdited: false,
|
||||
model: null,
|
||||
parentMessageId: '00000000-0000-0000-0000-000000000000',
|
||||
sender: 'User',
|
||||
text: 'hi',
|
||||
tokenCount: 5,
|
||||
unfinished: false,
|
||||
updatedAt: Date.now(),
|
||||
user: '6512cdf475f05c86d44c31d2',
|
||||
},
|
||||
userName: 'John',
|
||||
};
|
||||
const result = formatMessage(input);
|
||||
expect(result).toEqual({
|
||||
role: 'user',
|
||||
content: 'hi',
|
||||
name: 'John',
|
||||
});
|
||||
});
|
||||
|
||||
it('formats assistant message', () => {
|
||||
const input = {
|
||||
message: {
|
||||
sender: 'assistant',
|
||||
text: 'Hi there',
|
||||
},
|
||||
assistantName: 'Assistant',
|
||||
};
|
||||
const result = formatMessage(input);
|
||||
expect(result).toEqual({
|
||||
role: 'assistant',
|
||||
content: 'Hi there',
|
||||
name: 'Assistant',
|
||||
});
|
||||
});
|
||||
|
||||
it('formats system message', () => {
|
||||
const input = {
|
||||
message: {
|
||||
role: 'system',
|
||||
text: 'Hi there',
|
||||
},
|
||||
};
|
||||
const result = formatMessage(input);
|
||||
expect(result).toEqual({
|
||||
role: 'system',
|
||||
content: 'Hi there',
|
||||
});
|
||||
});
|
||||
|
||||
it('formats user message with langChain', () => {
|
||||
const input = {
|
||||
message: {
|
||||
sender: 'user',
|
||||
text: 'Hello',
|
||||
},
|
||||
userName: 'John',
|
||||
langChain: true,
|
||||
};
|
||||
const result = formatMessage(input);
|
||||
expect(result).toBeInstanceOf(HumanMessage);
|
||||
expect(result.lc_kwargs.content).toEqual(input.message.text);
|
||||
expect(result.lc_kwargs.name).toEqual(input.userName);
|
||||
});
|
||||
|
||||
it('formats assistant message with langChain', () => {
|
||||
const input = {
|
||||
message: {
|
||||
sender: 'assistant',
|
||||
text: 'Hi there',
|
||||
},
|
||||
assistantName: 'Assistant',
|
||||
langChain: true,
|
||||
};
|
||||
const result = formatMessage(input);
|
||||
expect(result).toBeInstanceOf(AIMessage);
|
||||
expect(result.lc_kwargs.content).toEqual(input.message.text);
|
||||
expect(result.lc_kwargs.name).toEqual(input.assistantName);
|
||||
});
|
||||
|
||||
it('formats system message with langChain', () => {
|
||||
const input = {
|
||||
message: {
|
||||
role: 'system',
|
||||
text: 'This is a system message.',
|
||||
},
|
||||
langChain: true,
|
||||
};
|
||||
const result = formatMessage(input);
|
||||
expect(result).toBeInstanceOf(SystemMessage);
|
||||
expect(result.lc_kwargs.content).toEqual(input.message.text);
|
||||
});
|
||||
|
||||
it('formats langChain messages into OpenAI payload format', () => {
|
||||
const human = {
|
||||
message: new HumanMessage({
|
||||
content: 'Hello',
|
||||
}),
|
||||
};
|
||||
const system = {
|
||||
message: new SystemMessage({
|
||||
content: 'Hello',
|
||||
}),
|
||||
};
|
||||
const ai = {
|
||||
message: new AIMessage({
|
||||
content: 'Hello',
|
||||
}),
|
||||
};
|
||||
const humanResult = formatMessage(human);
|
||||
const systemResult = formatMessage(system);
|
||||
const aiResult = formatMessage(ai);
|
||||
expect(humanResult).toEqual({
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
});
|
||||
expect(systemResult).toEqual({
|
||||
role: 'system',
|
||||
content: 'Hello',
|
||||
});
|
||||
expect(aiResult).toEqual({
|
||||
role: 'assistant',
|
||||
content: 'Hello',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatLangChainMessages', () => {
|
||||
it('formats an array of messages for LangChain', () => {
|
||||
const messages = [
|
||||
{
|
||||
role: 'system',
|
||||
content: 'This is a system message',
|
||||
},
|
||||
{
|
||||
sender: 'user',
|
||||
text: 'Hello',
|
||||
},
|
||||
{
|
||||
sender: 'assistant',
|
||||
text: 'Hi there',
|
||||
},
|
||||
];
|
||||
const formatOptions = {
|
||||
userName: 'John',
|
||||
assistantName: 'Assistant',
|
||||
};
|
||||
const result = formatLangChainMessages(messages, formatOptions);
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result[0]).toBeInstanceOf(SystemMessage);
|
||||
expect(result[1]).toBeInstanceOf(HumanMessage);
|
||||
expect(result[2]).toBeInstanceOf(AIMessage);
|
||||
|
||||
expect(result[0].lc_kwargs.content).toEqual(messages[0].content);
|
||||
expect(result[1].lc_kwargs.content).toEqual(messages[1].text);
|
||||
expect(result[2].lc_kwargs.content).toEqual(messages[2].text);
|
||||
|
||||
expect(result[1].lc_kwargs.name).toEqual(formatOptions.userName);
|
||||
expect(result[2].lc_kwargs.name).toEqual(formatOptions.assistantName);
|
||||
});
|
||||
|
||||
describe('formatFromLangChain', () => {
|
||||
it('should merge kwargs and additional_kwargs', () => {
|
||||
const message = {
|
||||
kwargs: {
|
||||
content: 'some content',
|
||||
name: 'dan',
|
||||
additional_kwargs: {
|
||||
function_call: {
|
||||
name: 'dall-e',
|
||||
arguments: '{\n "input": "Subject: hedgehog, Style: cute"\n}',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const expected = {
|
||||
content: 'some content',
|
||||
name: 'dan',
|
||||
function_call: {
|
||||
name: 'dall-e',
|
||||
arguments: '{\n "input": "Subject: hedgehog, Style: cute"\n}',
|
||||
},
|
||||
};
|
||||
|
||||
expect(formatFromLangChain(message)).toEqual(expected);
|
||||
});
|
||||
|
||||
it('should handle messages without additional_kwargs', () => {
|
||||
const message = {
|
||||
kwargs: {
|
||||
content: 'some content',
|
||||
name: 'dan',
|
||||
},
|
||||
};
|
||||
|
||||
const expected = {
|
||||
content: 'some content',
|
||||
name: 'dan',
|
||||
};
|
||||
|
||||
expect(formatFromLangChain(message)).toEqual(expected);
|
||||
});
|
||||
|
||||
it('should handle empty messages', () => {
|
||||
const message = {
|
||||
kwargs: {},
|
||||
};
|
||||
|
||||
const expected = {};
|
||||
|
||||
expect(formatFromLangChain(message)).toEqual(expected);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,9 +1,15 @@
|
||||
const formatMessages = require('./formatMessages');
|
||||
const summaryPrompts = require('./summaryPrompts');
|
||||
const handleInputs = require('./handleInputs');
|
||||
const instructions = require('./instructions');
|
||||
const titlePrompts = require('./titlePrompts');
|
||||
const refinePrompts = require('./refinePrompts');
|
||||
const truncateText = require('./truncateText');
|
||||
|
||||
module.exports = {
|
||||
...refinePrompts,
|
||||
...formatMessages,
|
||||
...summaryPrompts,
|
||||
...handleInputs,
|
||||
...instructions,
|
||||
...titlePrompts,
|
||||
truncateText,
|
||||
};
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
const { PromptTemplate } = require('langchain/prompts');
|
||||
|
||||
const refinePromptTemplate = `Your job is to produce a final summary of the following conversation.
|
||||
We have provided an existing summary up to a certain point: "{existing_answer}"
|
||||
We have the opportunity to refine the existing summary
|
||||
(only if needed) with some more context below.
|
||||
------------
|
||||
"{text}"
|
||||
------------
|
||||
|
||||
Given the new context, refine the original summary of the conversation.
|
||||
Do note who is speaking in the conversation to give proper context.
|
||||
If the context isn't useful, return the original summary.
|
||||
|
||||
REFINED CONVERSATION SUMMARY:`;
|
||||
|
||||
const refinePrompt = new PromptTemplate({
|
||||
template: refinePromptTemplate,
|
||||
inputVariables: ['existing_answer', 'text'],
|
||||
});
|
||||
|
||||
module.exports = {
|
||||
refinePrompt,
|
||||
};
|
||||
53
api/app/clients/prompts/summaryPrompts.js
Normal file
53
api/app/clients/prompts/summaryPrompts.js
Normal file
@@ -0,0 +1,53 @@
|
||||
const { PromptTemplate } = require('langchain/prompts');
|
||||
/*
|
||||
* Without `{summary}` and `{new_lines}`, token count is 98
|
||||
* We are counting this towards the max context tokens for summaries, +3 for the assistant label (101)
|
||||
* If this prompt changes, use https://tiktokenizer.vercel.app/ to count the tokens
|
||||
*/
|
||||
const _DEFAULT_SUMMARIZER_TEMPLATE = `Summarize the conversation by integrating new lines into the current summary.
|
||||
|
||||
EXAMPLE:
|
||||
Current summary:
|
||||
The human inquires about the AI's view on artificial intelligence. The AI believes it's beneficial.
|
||||
|
||||
New lines:
|
||||
Human: Why is it beneficial?
|
||||
AI: It helps humans achieve their potential.
|
||||
|
||||
New summary:
|
||||
The human inquires about the AI's view on artificial intelligence. The AI believes it's beneficial because it helps humans achieve their potential.
|
||||
|
||||
Current summary:
|
||||
{summary}
|
||||
|
||||
New lines:
|
||||
{new_lines}
|
||||
|
||||
New summary:`;
|
||||
|
||||
const SUMMARY_PROMPT = new PromptTemplate({
|
||||
inputVariables: ['summary', 'new_lines'],
|
||||
template: _DEFAULT_SUMMARIZER_TEMPLATE,
|
||||
});
|
||||
|
||||
/*
|
||||
* Without `{new_lines}`, token count is 27
|
||||
* We are counting this towards the max context tokens for summaries, rounded up to 30
|
||||
* If this prompt changes, use https://tiktokenizer.vercel.app/ to count the tokens
|
||||
*/
|
||||
const _CUT_OFF_SUMMARIZER = `The following text is cut-off:
|
||||
{new_lines}
|
||||
|
||||
Summarize the content as best as you can, noting that it was cut-off.
|
||||
|
||||
Summary:`;
|
||||
|
||||
const CUT_OFF_PROMPT = new PromptTemplate({
|
||||
inputVariables: ['new_lines'],
|
||||
template: _CUT_OFF_SUMMARIZER,
|
||||
});
|
||||
|
||||
module.exports = {
|
||||
SUMMARY_PROMPT,
|
||||
CUT_OFF_PROMPT,
|
||||
};
|
||||
@@ -16,7 +16,7 @@ const createTitlePrompt = ({ convo }) => {
|
||||
const titlePrompt = new ChatPromptTemplate({
|
||||
promptMessages: [
|
||||
SystemMessagePromptTemplate.fromTemplate(
|
||||
`Write a concise title for this conversation in the given language. Title in 5 Words or Less. No Punctuation or Quotation. All first letters of every word must be capitalized (resembling title-case), written in the given Language.
|
||||
`Write a concise title for this conversation in the given language. Title in 5 Words or Less. No Punctuation or Quotation. Must be in Title Case, written in the given Language.
|
||||
${convo}`,
|
||||
),
|
||||
HumanMessagePromptTemplate.fromTemplate('Language: {language}'),
|
||||
|
||||
10
api/app/clients/prompts/truncateText.js
Normal file
10
api/app/clients/prompts/truncateText.js
Normal file
@@ -0,0 +1,10 @@
|
||||
const MAX_CHAR = 255;
|
||||
|
||||
function truncateText(text) {
|
||||
if (text.length > MAX_CHAR) {
|
||||
return `${text.slice(0, MAX_CHAR)}... [text truncated for brevity]`;
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
module.exports = truncateText;
|
||||
@@ -15,14 +15,6 @@ jest.mock('../../../models', () => {
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('langchain/text_splitter', () => {
|
||||
return {
|
||||
RecursiveCharacterTextSplitter: jest.fn().mockImplementation(() => {
|
||||
return { createDocuments: jest.fn().mockResolvedValue([]) };
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('langchain/chat_models/openai', () => {
|
||||
return {
|
||||
ChatOpenAI: jest.fn().mockImplementation(() => {
|
||||
@@ -31,14 +23,6 @@ jest.mock('langchain/chat_models/openai', () => {
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('langchain/chains', () => {
|
||||
return {
|
||||
loadSummarizationChain: jest.fn().mockReturnValue({
|
||||
call: jest.fn().mockResolvedValue({ output_text: 'Refined answer' }),
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
let parentMessageId;
|
||||
let conversationId;
|
||||
const fakeMessages = [];
|
||||
@@ -69,6 +53,13 @@ describe('BaseClient', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
TestClient = initializeFakeClient(apiKey, options, fakeMessages);
|
||||
TestClient.summarizeMessages = jest.fn().mockResolvedValue({
|
||||
summaryMessage: {
|
||||
role: 'system',
|
||||
content: 'Refined answer',
|
||||
},
|
||||
summaryTokenCount: 5,
|
||||
});
|
||||
});
|
||||
|
||||
test('returns the input messages without instructions when addInstructions() is called with empty instructions', () => {
|
||||
@@ -103,30 +94,24 @@ describe('BaseClient', () => {
|
||||
expect(result).toBe(expected);
|
||||
});
|
||||
|
||||
test('refines messages correctly in refineMessages()', async () => {
|
||||
test('refines messages correctly in summarizeMessages()', async () => {
|
||||
const messagesToRefine = [
|
||||
{ role: 'user', content: 'Hello', tokenCount: 10 },
|
||||
{ role: 'assistant', content: 'How can I help you?', tokenCount: 20 },
|
||||
];
|
||||
const remainingContextTokens = 100;
|
||||
const expectedRefinedMessage = {
|
||||
role: 'assistant',
|
||||
role: 'system',
|
||||
content: 'Refined answer',
|
||||
tokenCount: 14, // 'Refined answer'.length
|
||||
};
|
||||
|
||||
const result = await TestClient.refineMessages(messagesToRefine, remainingContextTokens);
|
||||
expect(result).toEqual(expectedRefinedMessage);
|
||||
const result = await TestClient.summarizeMessages({ messagesToRefine, remainingContextTokens });
|
||||
expect(result.summaryMessage).toEqual(expectedRefinedMessage);
|
||||
});
|
||||
|
||||
test('gets messages within token limit (under limit) correctly in getMessagesWithinTokenLimit()', async () => {
|
||||
TestClient.maxContextTokens = 100;
|
||||
TestClient.shouldRefineContext = true;
|
||||
TestClient.refineMessages = jest.fn().mockResolvedValue({
|
||||
role: 'assistant',
|
||||
content: 'Refined answer',
|
||||
tokenCount: 30,
|
||||
});
|
||||
TestClient.shouldSummarize = true;
|
||||
|
||||
const messages = [
|
||||
{ role: 'user', content: 'Hello', tokenCount: 5 },
|
||||
@@ -142,43 +127,50 @@ describe('BaseClient', () => {
|
||||
const expectedRemainingContextTokens = 58 - 3; // (100 - 5 - 19 - 18) - 3
|
||||
const expectedMessagesToRefine = [];
|
||||
|
||||
const lastExpectedMessage =
|
||||
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
|
||||
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
|
||||
|
||||
const result = await TestClient.getMessagesWithinTokenLimit(messages);
|
||||
|
||||
expect(result.context).toEqual(expectedContext);
|
||||
expect(result.summaryIndex).toEqual(expectedIndex);
|
||||
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
|
||||
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
|
||||
});
|
||||
|
||||
test('gets messages within token limit (over limit) correctly in getMessagesWithinTokenLimit()', async () => {
|
||||
test('gets result over token limit correctly in getMessagesWithinTokenLimit()', async () => {
|
||||
TestClient.maxContextTokens = 50; // Set a lower limit
|
||||
TestClient.shouldRefineContext = true;
|
||||
TestClient.refineMessages = jest.fn().mockResolvedValue({
|
||||
role: 'assistant',
|
||||
content: 'Refined answer',
|
||||
tokenCount: 4,
|
||||
});
|
||||
TestClient.shouldSummarize = true;
|
||||
|
||||
const messages = [
|
||||
{ role: 'user', content: 'I need a coffee, stat!', tokenCount: 30 },
|
||||
{ role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 30 },
|
||||
{ role: 'user', content: 'Hello', tokenCount: 5 },
|
||||
{ role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
|
||||
{ role: 'user', content: 'I have a question.', tokenCount: 18 },
|
||||
];
|
||||
const expectedContext = [
|
||||
{ role: 'user', content: 'Hello', tokenCount: 5 },
|
||||
{ role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
|
||||
{ role: 'user', content: 'I have a question.', tokenCount: 18 },
|
||||
{ role: 'user', content: 'Hello', tokenCount: 30 },
|
||||
{ role: 'assistant', content: 'How can I help you?', tokenCount: 30 },
|
||||
{ role: 'user', content: 'I have a question.', tokenCount: 5 },
|
||||
{ role: 'user', content: 'I need a coffee, stat!', tokenCount: 19 },
|
||||
{ role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 18 },
|
||||
];
|
||||
|
||||
// Subtract 3 tokens for Assistant Label priming after all messages have been counted.
|
||||
const expectedRemainingContextTokens = 8 - 3; // (50 - 18 - 19 - 5) - 3
|
||||
const expectedRemainingContextTokens = 5; // (50 - 18 - 19 - 5) - 3
|
||||
const expectedMessagesToRefine = [
|
||||
{ role: 'user', content: 'I need a coffee, stat!', tokenCount: 30 },
|
||||
{ role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 30 },
|
||||
{ role: 'user', content: 'Hello', tokenCount: 30 },
|
||||
{ role: 'assistant', content: 'How can I help you?', tokenCount: 30 },
|
||||
];
|
||||
const expectedContext = [
|
||||
{ role: 'user', content: 'I have a question.', tokenCount: 5 },
|
||||
{ role: 'user', content: 'I need a coffee, stat!', tokenCount: 19 },
|
||||
{ role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 18 },
|
||||
];
|
||||
|
||||
const lastExpectedMessage =
|
||||
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
|
||||
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
|
||||
|
||||
const result = await TestClient.getMessagesWithinTokenLimit(messages);
|
||||
|
||||
expect(result.context).toEqual(expectedContext);
|
||||
expect(result.summaryIndex).toEqual(expectedIndex);
|
||||
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
|
||||
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
|
||||
});
|
||||
@@ -200,14 +192,10 @@ describe('BaseClient', () => {
|
||||
],
|
||||
remainingContextTokens: 80,
|
||||
messagesToRefine: [{ content: 'Hello' }],
|
||||
refineIndex: 3,
|
||||
summaryIndex: 3,
|
||||
});
|
||||
TestClient.refineMessages = jest.fn().mockResolvedValue({
|
||||
role: 'assistant',
|
||||
content: 'Refined answer',
|
||||
tokenCount: 30,
|
||||
});
|
||||
TestClient.getTokenCountForResponse = jest.fn().mockReturnValue(40);
|
||||
|
||||
TestClient.getTokenCount = jest.fn().mockReturnValue(40);
|
||||
|
||||
const instructions = { content: 'Please provide more details.' };
|
||||
const orderedMessages = [
|
||||
@@ -225,9 +213,8 @@ describe('BaseClient', () => {
|
||||
const expectedResult = {
|
||||
payload: [
|
||||
{
|
||||
role: 'system',
|
||||
content: 'Refined answer',
|
||||
role: 'assistant',
|
||||
tokenCount: 30,
|
||||
},
|
||||
{ content: 'How can I help you?' },
|
||||
{ content: 'Please provide more details.' },
|
||||
@@ -238,14 +225,214 @@ describe('BaseClient', () => {
|
||||
messages: expect.any(Array),
|
||||
};
|
||||
|
||||
TestClient.shouldSummarize = true;
|
||||
const result = await TestClient.handleContextStrategy({
|
||||
instructions,
|
||||
orderedMessages,
|
||||
formattedMessages,
|
||||
});
|
||||
|
||||
expect(result).toEqual(expectedResult);
|
||||
});
|
||||
|
||||
describe('getMessagesForConversation', () => {
|
||||
it('should return an empty array if the parentMessageId does not exist', () => {
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: unorderedMessages,
|
||||
parentMessageId: '999',
|
||||
});
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle messages with messageId property', () => {
|
||||
const messagesWithMessageId = [
|
||||
{ messageId: '1', parentMessageId: null, text: 'Message 1' },
|
||||
{ messageId: '2', parentMessageId: '1', text: 'Message 2' },
|
||||
];
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: messagesWithMessageId,
|
||||
parentMessageId: '2',
|
||||
});
|
||||
expect(result).toEqual([
|
||||
{ messageId: '1', parentMessageId: null, text: 'Message 1' },
|
||||
{ messageId: '2', parentMessageId: '1', text: 'Message 2' },
|
||||
]);
|
||||
});
|
||||
|
||||
const messagesWithNullParent = [
|
||||
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
||||
{ id: '2', parentMessageId: null, text: 'Message 2' },
|
||||
];
|
||||
|
||||
it('should handle messages with null parentMessageId that are not root', () => {
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: messagesWithNullParent,
|
||||
parentMessageId: '2',
|
||||
});
|
||||
expect(result).toEqual([{ id: '2', parentMessageId: null, text: 'Message 2' }]);
|
||||
});
|
||||
|
||||
const cyclicMessages = [
|
||||
{ id: '3', parentMessageId: '2', text: 'Message 3' },
|
||||
{ id: '1', parentMessageId: '3', text: 'Message 1' },
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
||||
];
|
||||
|
||||
it('should handle cyclic references without going into an infinite loop', () => {
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: cyclicMessages,
|
||||
parentMessageId: '3',
|
||||
});
|
||||
expect(result).toEqual([
|
||||
{ id: '1', parentMessageId: '3', text: 'Message 1' },
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
||||
{ id: '3', parentMessageId: '2', text: 'Message 3' },
|
||||
]);
|
||||
});
|
||||
|
||||
const unorderedMessages = [
|
||||
{ id: '3', parentMessageId: '2', text: 'Message 3' },
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
||||
{ id: '1', parentMessageId: '00000000-0000-0000-0000-000000000000', text: 'Message 1' },
|
||||
];
|
||||
|
||||
it('should return ordered messages based on parentMessageId', () => {
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: unorderedMessages,
|
||||
parentMessageId: '3',
|
||||
});
|
||||
expect(result).toEqual([
|
||||
{ id: '1', parentMessageId: '00000000-0000-0000-0000-000000000000', text: 'Message 1' },
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
||||
{ id: '3', parentMessageId: '2', text: 'Message 3' },
|
||||
]);
|
||||
});
|
||||
|
||||
const unorderedBranchedMessages = [
|
||||
{ id: '4', parentMessageId: '2', text: 'Message 4', summary: 'Summary for Message 4' },
|
||||
{ id: '10', parentMessageId: '7', text: 'Message 10' },
|
||||
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
||||
{ id: '6', parentMessageId: '5', text: 'Message 7' },
|
||||
{ id: '7', parentMessageId: '5', text: 'Message 7' },
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
||||
{ id: '8', parentMessageId: '6', text: 'Message 8' },
|
||||
{ id: '5', parentMessageId: '3', text: 'Message 5' },
|
||||
{ id: '3', parentMessageId: '1', text: 'Message 3' },
|
||||
{ id: '6', parentMessageId: '4', text: 'Message 6' },
|
||||
{ id: '8', parentMessageId: '7', text: 'Message 9' },
|
||||
{ id: '9', parentMessageId: '7', text: 'Message 9' },
|
||||
{ id: '11', parentMessageId: '2', text: 'Message 11', summary: 'Summary for Message 11' },
|
||||
];
|
||||
|
||||
it('should return ordered messages from a branched array based on parentMessageId', () => {
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: unorderedBranchedMessages,
|
||||
parentMessageId: '10',
|
||||
summary: true,
|
||||
});
|
||||
expect(result).toEqual([
|
||||
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
||||
{ id: '3', parentMessageId: '1', text: 'Message 3' },
|
||||
{ id: '5', parentMessageId: '3', text: 'Message 5' },
|
||||
{ id: '7', parentMessageId: '5', text: 'Message 7' },
|
||||
{ id: '10', parentMessageId: '7', text: 'Message 10' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should return an empty array if no messages are provided', () => {
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: [],
|
||||
parentMessageId: '3',
|
||||
});
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should map over the ordered messages if mapMethod is provided', () => {
|
||||
const mapMethod = (msg) => msg.text;
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: unorderedMessages,
|
||||
parentMessageId: '3',
|
||||
mapMethod,
|
||||
});
|
||||
expect(result).toEqual(['Message 1', 'Message 2', 'Message 3']);
|
||||
});
|
||||
|
||||
let unorderedMessagesWithSummary = [
|
||||
{ id: '4', parentMessageId: '3', text: 'Message 4' },
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2', summary: 'Summary for Message 2' },
|
||||
{ id: '3', parentMessageId: '2', text: 'Message 3', summary: 'Summary for Message 3' },
|
||||
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
||||
];
|
||||
|
||||
it('should start with the message that has a summary property and continue until the specified parentMessageId', () => {
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: unorderedMessagesWithSummary,
|
||||
parentMessageId: '4',
|
||||
summary: true,
|
||||
});
|
||||
expect(result).toEqual([
|
||||
{
|
||||
id: '3',
|
||||
parentMessageId: '2',
|
||||
role: 'system',
|
||||
text: 'Summary for Message 3',
|
||||
summary: 'Summary for Message 3',
|
||||
},
|
||||
{ id: '4', parentMessageId: '3', text: 'Message 4' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle multiple summaries and return the branch from the latest to the parentMessageId', () => {
|
||||
unorderedMessagesWithSummary = [
|
||||
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2', summary: 'Summary for Message 2' },
|
||||
{ id: '3', parentMessageId: '2', text: 'Message 3', summary: 'Summary for Message 3' },
|
||||
{ id: '4', parentMessageId: '3', text: 'Message 4', summary: 'Summary for Message 4' },
|
||||
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
||||
];
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: unorderedMessagesWithSummary,
|
||||
parentMessageId: '5',
|
||||
summary: true,
|
||||
});
|
||||
expect(result).toEqual([
|
||||
{
|
||||
id: '4',
|
||||
parentMessageId: '3',
|
||||
role: 'system',
|
||||
text: 'Summary for Message 4',
|
||||
summary: 'Summary for Message 4',
|
||||
},
|
||||
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle summary at root edge case and continue until the parentMessageId', () => {
|
||||
unorderedMessagesWithSummary = [
|
||||
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
||||
{ id: '1', parentMessageId: null, text: 'Message 1', summary: 'Summary for Message 1' },
|
||||
{ id: '4', parentMessageId: '3', text: 'Message 4', summary: 'Summary for Message 4' },
|
||||
{ id: '2', parentMessageId: '1', text: 'Message 2', summary: 'Summary for Message 2' },
|
||||
{ id: '3', parentMessageId: '2', text: 'Message 3', summary: 'Summary for Message 3' },
|
||||
];
|
||||
const result = TestClient.constructor.getMessagesForConversation({
|
||||
messages: unorderedMessagesWithSummary,
|
||||
parentMessageId: '5',
|
||||
summary: true,
|
||||
});
|
||||
expect(result).toEqual([
|
||||
{
|
||||
id: '4',
|
||||
parentMessageId: '3',
|
||||
role: 'system',
|
||||
text: 'Summary for Message 4',
|
||||
summary: 'Summary for Message 4',
|
||||
},
|
||||
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessage', () => {
|
||||
test('sendMessage should return a response message', async () => {
|
||||
const expectedResult = expect.objectContaining({
|
||||
@@ -268,7 +455,7 @@ describe('BaseClient', () => {
|
||||
const opts = {
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
getIds: jest.fn(),
|
||||
getReqData: jest.fn(),
|
||||
onStart: jest.fn(),
|
||||
};
|
||||
|
||||
@@ -285,7 +472,7 @@ describe('BaseClient', () => {
|
||||
parentMessageId = response.messageId;
|
||||
expect(response.conversationId).toEqual(conversationId);
|
||||
expect(response).toEqual(expectedResult);
|
||||
expect(opts.getIds).toHaveBeenCalled();
|
||||
expect(opts.getReqData).toHaveBeenCalled();
|
||||
expect(opts.onStart).toHaveBeenCalled();
|
||||
expect(TestClient.getBuildMessagesOptions).toHaveBeenCalled();
|
||||
expect(TestClient.getSaveOptions).toHaveBeenCalled();
|
||||
@@ -359,11 +546,11 @@ describe('BaseClient', () => {
|
||||
);
|
||||
});
|
||||
|
||||
test('getIds is called with the correct arguments', async () => {
|
||||
const getIds = jest.fn();
|
||||
const opts = { getIds };
|
||||
test('getReqData is called with the correct arguments', async () => {
|
||||
const getReqData = jest.fn();
|
||||
const opts = { getReqData };
|
||||
const response = await TestClient.sendMessage('Hello, world!', opts);
|
||||
expect(getIds).toHaveBeenCalledWith({
|
||||
expect(getReqData).toHaveBeenCalledWith({
|
||||
userMessage: expect.objectContaining({ text: 'Hello, world!' }),
|
||||
conversationId: response.conversationId,
|
||||
responseMessageId: response.messageId,
|
||||
@@ -404,12 +591,12 @@ describe('BaseClient', () => {
|
||||
expect(TestClient.sendCompletion).toHaveBeenCalledWith(payload, opts);
|
||||
});
|
||||
|
||||
test('getTokenCountForResponse is called with the correct arguments', async () => {
|
||||
test('getTokenCount for response is called with the correct arguments', async () => {
|
||||
const tokenCountMap = {}; // Mock tokenCountMap
|
||||
TestClient.buildMessages.mockReturnValue({ prompt: [], tokenCountMap });
|
||||
TestClient.getTokenCountForResponse = jest.fn();
|
||||
TestClient.getTokenCount = jest.fn();
|
||||
const response = await TestClient.sendMessage('Hello, world!', {});
|
||||
expect(TestClient.getTokenCountForResponse).toHaveBeenCalledWith(response);
|
||||
expect(TestClient.getTokenCount).toHaveBeenCalledWith(response.text);
|
||||
});
|
||||
|
||||
test('returns an object with the correct shape', async () => {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const BaseClient = require('../BaseClient');
|
||||
const { maxTokensMap } = require('../../../utils');
|
||||
const { getModelMaxTokens } = require('../../../utils');
|
||||
|
||||
class FakeClient extends BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
@@ -40,7 +40,7 @@ class FakeClient extends BaseClient {
|
||||
};
|
||||
}
|
||||
|
||||
this.maxContextTokens = maxTokensMap[this.modelOptions.model] ?? 4097;
|
||||
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 4097;
|
||||
}
|
||||
getCompletion() {}
|
||||
buildMessages() {}
|
||||
@@ -65,10 +65,10 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
|
||||
return Promise.resolve([]);
|
||||
}
|
||||
|
||||
const orderedMessages = TestClient.constructor.getMessagesForConversation(
|
||||
fakeMessages,
|
||||
const orderedMessages = TestClient.constructor.getMessagesForConversation({
|
||||
messages: fakeMessages,
|
||||
parentMessageId,
|
||||
);
|
||||
});
|
||||
|
||||
TestClient.currentMessages = orderedMessages;
|
||||
return Promise.resolve(orderedMessages);
|
||||
@@ -87,10 +87,10 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
|
||||
});
|
||||
|
||||
TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => {
|
||||
const orderedMessages = TestClient.constructor.getMessagesForConversation(
|
||||
const orderedMessages = TestClient.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
parentMessageId,
|
||||
);
|
||||
});
|
||||
const formattedMessages = orderedMessages.map((message) => {
|
||||
let { role: _role, sender, text } = message;
|
||||
const role = _role ?? sender;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
require('dotenv').config();
|
||||
const OpenAIClient = require('../OpenAIClient');
|
||||
|
||||
jest.mock('meilisearch');
|
||||
@@ -11,6 +12,14 @@ describe('OpenAIClient', () => {
|
||||
{ role: 'assistant', sender: 'Assistant', text: 'Hi', messageId: '2' },
|
||||
];
|
||||
|
||||
beforeAll(() => {
|
||||
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
console.warn.mockRestore();
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
const options = {
|
||||
// debug: true,
|
||||
@@ -22,7 +31,7 @@ describe('OpenAIClient', () => {
|
||||
};
|
||||
client = new OpenAIClient('test-api-key', options);
|
||||
client2 = new OpenAIClient('test-api-key', options);
|
||||
client.refineMessages = jest.fn().mockResolvedValue({
|
||||
client.summarizeMessages = jest.fn().mockResolvedValue({
|
||||
role: 'assistant',
|
||||
content: 'Refined answer',
|
||||
tokenCount: 30,
|
||||
@@ -39,6 +48,144 @@ describe('OpenAIClient', () => {
|
||||
expect(client.modelOptions.model).toBe(model);
|
||||
expect(client.modelOptions.temperature).toBe(0.7);
|
||||
});
|
||||
|
||||
it('should set apiKey and useOpenRouter if OPENROUTER_API_KEY is present', () => {
|
||||
process.env.OPENROUTER_API_KEY = 'openrouter-key';
|
||||
client.setOptions({});
|
||||
expect(client.apiKey).toBe('openrouter-key');
|
||||
expect(client.useOpenRouter).toBe(true);
|
||||
delete process.env.OPENROUTER_API_KEY; // Cleanup
|
||||
});
|
||||
|
||||
it('should set FORCE_PROMPT based on OPENAI_FORCE_PROMPT or reverseProxyUrl', () => {
|
||||
process.env.OPENAI_FORCE_PROMPT = 'true';
|
||||
client.setOptions({});
|
||||
expect(client.FORCE_PROMPT).toBe(true);
|
||||
delete process.env.OPENAI_FORCE_PROMPT; // Cleanup
|
||||
client.FORCE_PROMPT = undefined;
|
||||
|
||||
client.setOptions({ reverseProxyUrl: 'https://example.com/completions' });
|
||||
expect(client.FORCE_PROMPT).toBe(true);
|
||||
client.FORCE_PROMPT = undefined;
|
||||
|
||||
client.setOptions({ reverseProxyUrl: 'https://example.com/chat' });
|
||||
expect(client.FORCE_PROMPT).toBe(false);
|
||||
});
|
||||
|
||||
it('should set isChatCompletion based on useOpenRouter, reverseProxyUrl, or model', () => {
|
||||
client.setOptions({ reverseProxyUrl: null });
|
||||
// true by default since default model will be gpt-3.5-turbo
|
||||
expect(client.isChatCompletion).toBe(true);
|
||||
client.isChatCompletion = undefined;
|
||||
|
||||
// false because completions url will force prompt payload
|
||||
client.setOptions({ reverseProxyUrl: 'https://example.com/completions' });
|
||||
expect(client.isChatCompletion).toBe(false);
|
||||
client.isChatCompletion = undefined;
|
||||
|
||||
client.setOptions({ modelOptions: { model: 'gpt-3.5-turbo' }, reverseProxyUrl: null });
|
||||
expect(client.isChatCompletion).toBe(true);
|
||||
});
|
||||
|
||||
it('should set completionsUrl and langchainProxy based on reverseProxyUrl', () => {
|
||||
client.setOptions({ reverseProxyUrl: 'https://localhost:8080/v1/chat/completions' });
|
||||
expect(client.completionsUrl).toBe('https://localhost:8080/v1/chat/completions');
|
||||
expect(client.langchainProxy).toBe('https://localhost:8080/v1');
|
||||
|
||||
client.setOptions({ reverseProxyUrl: 'https://example.com/completions' });
|
||||
expect(client.completionsUrl).toBe('https://example.com/completions');
|
||||
expect(client.langchainProxy).toBe(null);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setOptions with Simplified Azure Integration', () => {
|
||||
afterEach(() => {
|
||||
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
|
||||
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
|
||||
});
|
||||
|
||||
const azureOpenAIApiInstanceName = 'test-instance';
|
||||
const azureOpenAIApiDeploymentName = 'test-deployment';
|
||||
const azureOpenAIApiVersion = '2020-07-01-preview';
|
||||
|
||||
const createOptions = (model) => ({
|
||||
modelOptions: { model },
|
||||
azure: {
|
||||
azureOpenAIApiInstanceName,
|
||||
azureOpenAIApiDeploymentName,
|
||||
azureOpenAIApiVersion,
|
||||
},
|
||||
});
|
||||
|
||||
it('should set model from AZURE_OPENAI_DEFAULT_MODEL when Azure is enabled', () => {
|
||||
process.env.AZURE_OPENAI_DEFAULT_MODEL = 'gpt-4-azure';
|
||||
const options = createOptions('test');
|
||||
client.azure = options.azure;
|
||||
client.setOptions(options);
|
||||
expect(client.modelOptions.model).toBe('gpt-4-azure');
|
||||
});
|
||||
|
||||
it('should not change model if Azure is not enabled', () => {
|
||||
process.env.AZURE_OPENAI_DEFAULT_MODEL = 'gpt-4-azure';
|
||||
const originalModel = 'test';
|
||||
client.azure = false;
|
||||
client.setOptions(createOptions('test'));
|
||||
expect(client.modelOptions.model).toBe(originalModel);
|
||||
});
|
||||
|
||||
it('should not change model if AZURE_OPENAI_DEFAULT_MODEL is not set and model is passed', () => {
|
||||
const originalModel = 'GROK-LLM';
|
||||
const options = createOptions(originalModel);
|
||||
client.azure = options.azure;
|
||||
client.setOptions(options);
|
||||
expect(client.modelOptions.model).toBe(originalModel);
|
||||
});
|
||||
|
||||
it('should change model if AZURE_OPENAI_DEFAULT_MODEL is set and model is passed', () => {
|
||||
process.env.AZURE_OPENAI_DEFAULT_MODEL = 'gpt-4-azure';
|
||||
const originalModel = 'GROK-LLM';
|
||||
const options = createOptions(originalModel);
|
||||
client.azure = options.azure;
|
||||
client.setOptions(options);
|
||||
expect(client.modelOptions.model).toBe(process.env.AZURE_OPENAI_DEFAULT_MODEL);
|
||||
});
|
||||
|
||||
it('should include model in deployment name if AZURE_USE_MODEL_AS_DEPLOYMENT_NAME is set', () => {
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
const model = 'gpt-4-azure';
|
||||
|
||||
const AzureClient = new OpenAIClient('test-api-key', createOptions(model));
|
||||
|
||||
const expectedValue = `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${model}/chat/completions?api-version=${azureOpenAIApiVersion}`;
|
||||
|
||||
expect(AzureClient.modelOptions.model).toBe(model);
|
||||
expect(AzureClient.azureEndpoint).toBe(expectedValue);
|
||||
});
|
||||
|
||||
it('should include model in deployment name if AZURE_USE_MODEL_AS_DEPLOYMENT_NAME and default model is set', () => {
|
||||
const defaultModel = 'gpt-4-azure';
|
||||
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
|
||||
process.env.AZURE_OPENAI_DEFAULT_MODEL = defaultModel;
|
||||
const model = 'gpt-4-this-is-a-test-model-name';
|
||||
|
||||
const AzureClient = new OpenAIClient('test-api-key', createOptions(model));
|
||||
|
||||
const expectedValue = `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${model}/chat/completions?api-version=${azureOpenAIApiVersion}`;
|
||||
|
||||
expect(AzureClient.modelOptions.model).toBe(defaultModel);
|
||||
expect(AzureClient.azureEndpoint).toBe(expectedValue);
|
||||
});
|
||||
|
||||
it('should not include model in deployment name if AZURE_USE_MODEL_AS_DEPLOYMENT_NAME is not set', () => {
|
||||
const model = 'gpt-4-azure';
|
||||
|
||||
const AzureClient = new OpenAIClient('test-api-key', createOptions(model));
|
||||
|
||||
const expectedValue = `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}/chat/completions?api-version=${azureOpenAIApiVersion}`;
|
||||
|
||||
expect(AzureClient.modelOptions.model).toBe(model);
|
||||
expect(AzureClient.azureEndpoint).toBe(expectedValue);
|
||||
});
|
||||
});
|
||||
|
||||
describe('selectTokenizer', () => {
|
||||
@@ -158,7 +305,7 @@ describe('OpenAIClient', () => {
|
||||
});
|
||||
|
||||
it('should handle context strategy correctly', async () => {
|
||||
client.contextStrategy = 'refine';
|
||||
client.contextStrategy = 'summarize';
|
||||
const result = await client.buildMessages(messages, parentMessageId, {
|
||||
isChatCompletion: true,
|
||||
});
|
||||
@@ -172,22 +319,11 @@ describe('OpenAIClient', () => {
|
||||
isChatCompletion: true,
|
||||
});
|
||||
const hasUserWithName = result.prompt.some(
|
||||
(item) => item.role === 'user' && item.name === 'Test User',
|
||||
(item) => item.role === 'user' && item.name === 'Test_User',
|
||||
);
|
||||
expect(hasUserWithName).toBe(true);
|
||||
});
|
||||
|
||||
it('should calculate tokenCount for each message when contextStrategy is set', async () => {
|
||||
client.contextStrategy = 'refine';
|
||||
const result = await client.buildMessages(messages, parentMessageId, {
|
||||
isChatCompletion: true,
|
||||
});
|
||||
const hasUserWithTokenCount = result.prompt.some(
|
||||
(item) => item.role === 'user' && item.tokenCount > 0,
|
||||
);
|
||||
expect(hasUserWithTokenCount).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle promptPrefix from options when promptPrefix argument is not provided', async () => {
|
||||
client.options.promptPrefix = 'Test Prefix from options';
|
||||
const result = await client.buildMessages(messages, parentMessageId, {
|
||||
|
||||
@@ -41,10 +41,10 @@ describe('PluginsClient', () => {
|
||||
return Promise.resolve([]);
|
||||
}
|
||||
|
||||
const orderedMessages = TestAgent.constructor.getMessagesForConversation(
|
||||
fakeMessages,
|
||||
const orderedMessages = TestAgent.constructor.getMessagesForConversation({
|
||||
messages: fakeMessages,
|
||||
parentMessageId,
|
||||
);
|
||||
});
|
||||
|
||||
const chatMessages = orderedMessages.map((msg) =>
|
||||
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,22 +0,0 @@
|
||||
{
|
||||
"schema_version": "v1",
|
||||
"name_for_human": "Ai Tool Hunt",
|
||||
"name_for_model": "aitoolhunt",
|
||||
"description_for_human": "Find the perfect AI tools for all your needs, drawn from the most comprehensive global database of AI tools.",
|
||||
"description_for_model": "This universal plugin is designed to identify and suggest AI tools for any imaginable task or challenge. It harnesses a database of over 1600+ AI tools and 120+ categories, ensuring a solution for every requirement. Simply input any task, problem, or use case, and the plugin will generate a tailored list of AI tools. Note: it's crucial to rearrange the tools suggested by the plugin based on their relevance to the user's specific use case. This plugin can tackle everything from the simplest to the most intricate tasks. If there's a way an AI tool can help, this plugin will not only locate it but also customize the list to the user's unique needs. Utilize this plugin whenever a task or activity emerges, regardless of its context or complexity. When formulating a search query, avoid stop words or other irrelevant keywords. For instance, 'copywriting' is acceptable, but 'ai for copywriting' is not. If you believe none of the suggested tools are a suitable match for the user's needs, indicate that these are related tools.",
|
||||
"auth": {
|
||||
"type": "service_http",
|
||||
"authorization_type": "bearer",
|
||||
"verification_tokens": {
|
||||
"openai": "06a0f9391a5e48c7a7eeaca1e7e1e8d3"
|
||||
}
|
||||
},
|
||||
"api": {
|
||||
"type": "openapi",
|
||||
"url": "https://www.aitoolhunt.com/openapi.json",
|
||||
"is_user_authenticated": false
|
||||
},
|
||||
"logo_url": "https://www.aitoolhunt.com/images/aitoolhunt_logo.png",
|
||||
"contact_email": "aitoolhunt@gmail.com",
|
||||
"legal_info_url": "https://www.aitoolhunt.com/terms-and-conditions"
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
{
|
||||
"schema_version": "v1",
|
||||
"name_for_human": "Prompt Perfect",
|
||||
"name_for_model": "rephrase",
|
||||
"description_for_human": "Type 'perfect' to craft the perfect prompt, every time.",
|
||||
"description_for_model": "Plugin that can rephrase user inputs to improve the quality of ChatGPT's responses. The plugin evaluates user inputs and, if necessary, transforms them into clearer, more specific, and contextual prompts. It processes a JSON object containing the user input to be rephrased and uses the GPT-3.5-turbo model for the rephrasing process. The rephrased input is then returned as raw data to be incorporated into ChatGPT's response. The user can initiate the plugin by typing 'perfect'.",
|
||||
"auth": {
|
||||
"type": "none"
|
||||
},
|
||||
"api": {
|
||||
"type": "openapi",
|
||||
"url": "https://promptperfect.xyz/openapi.yaml",
|
||||
"is_user_authenticated": false
|
||||
},
|
||||
"logo_url": "https://promptperfect.xyz/static/prompt_perfect_logo.png",
|
||||
"contact_email": "heyo@promptperfect.xyz",
|
||||
"legal_info_url": "https://promptperfect.xyz/static/terms.html"
|
||||
}
|
||||
@@ -1,19 +1,30 @@
|
||||
// From https://platform.openai.com/docs/api-reference/images/create
|
||||
// To use this tool, you must pass in a configured OpenAIApi object.
|
||||
const fs = require('fs');
|
||||
const { Configuration, OpenAIApi } = require('openai');
|
||||
const path = require('path');
|
||||
const OpenAI = require('openai');
|
||||
// const { genAzureEndpoint } = require('../../../utils/genAzureEndpoints');
|
||||
const { Tool } = require('langchain/tools');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const saveImageFromUrl = require('./saveImageFromUrl');
|
||||
const path = require('path');
|
||||
const extractBaseURL = require('../../../utils/extractBaseURL');
|
||||
const { DALLE_REVERSE_PROXY, PROXY } = process.env;
|
||||
|
||||
class OpenAICreateImage extends Tool {
|
||||
constructor(fields = {}) {
|
||||
super();
|
||||
|
||||
let apiKey = fields.DALLE_API_KEY || this.getApiKey();
|
||||
const config = { apiKey };
|
||||
if (DALLE_REVERSE_PROXY) {
|
||||
config.baseURL = extractBaseURL(DALLE_REVERSE_PROXY);
|
||||
}
|
||||
|
||||
if (PROXY) {
|
||||
config.httpAgent = new HttpsProxyAgent(PROXY);
|
||||
}
|
||||
|
||||
// let azureKey = fields.AZURE_API_KEY || process.env.AZURE_API_KEY;
|
||||
let config = { apiKey };
|
||||
|
||||
// if (azureKey) {
|
||||
// apiKey = azureKey;
|
||||
@@ -36,7 +47,7 @@ class OpenAICreateImage extends Tool {
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
this.openaiApi = new OpenAIApi(new Configuration(config));
|
||||
this.openai = new OpenAI(config);
|
||||
this.name = 'dall-e';
|
||||
this.description = `You can generate images with 'dall-e'. This tool is exclusively for visual content.
|
||||
Guidelines:
|
||||
@@ -58,7 +69,7 @@ Guidelines:
|
||||
replaceUnwantedChars(inputString) {
|
||||
return inputString
|
||||
.replace(/\r\n|\r|\n/g, ' ')
|
||||
.replace('"', '')
|
||||
.replace(/"/g, '')
|
||||
.trim();
|
||||
}
|
||||
|
||||
@@ -71,7 +82,7 @@ Guidelines:
|
||||
}
|
||||
|
||||
async _call(input) {
|
||||
const resp = await this.openaiApi.createImage({
|
||||
const resp = await this.openai.images.generate({
|
||||
prompt: this.replaceUnwantedChars(input),
|
||||
// TODO: Future idea -- could we ask an LLM to extract these arguments from an input that might contain them?
|
||||
n: 1,
|
||||
@@ -79,7 +90,7 @@ Guidelines:
|
||||
size: '512x512',
|
||||
});
|
||||
|
||||
const theImageUrl = resp.data.data[0].url;
|
||||
const theImageUrl = resp.data[0].url;
|
||||
|
||||
if (!theImageUrl) {
|
||||
throw new Error('No image URL returned from OpenAI API.');
|
||||
|
||||
@@ -20,8 +20,8 @@ function createPrompt(name, functions) {
|
||||
.map((func) => `// - ${func.name}: ${func.description}`)
|
||||
.join('\n');
|
||||
return `${prefix}\n${functionDescriptions}
|
||||
// The user's message will be passed as the function's query.
|
||||
// Always provide the function name as such: {{"func": "function_name"}}`;
|
||||
// You are an expert manager and scrum master. You must provide a detailed intent to better execute the function.
|
||||
// Always format as such: {{"func": "function_name", "intent": "intent and expected result"}}`;
|
||||
}
|
||||
|
||||
const AuthBearer = z
|
||||
@@ -83,7 +83,7 @@ async function getSpec(url) {
|
||||
return ValidSpecPath.parse(url);
|
||||
}
|
||||
|
||||
async function createOpenAPIPlugin({ data, llm, user, message, verbose = false }) {
|
||||
async function createOpenAPIPlugin({ data, llm, user, message, memory, signal, verbose = false }) {
|
||||
let spec;
|
||||
try {
|
||||
spec = await getSpec(data.api.url, verbose);
|
||||
@@ -128,15 +128,23 @@ async function createOpenAPIPlugin({ data, llm, user, message, verbose = false }
|
||||
chainOptions.params = data.params;
|
||||
}
|
||||
|
||||
chainOptions.prompt = ChatPromptTemplate.fromPromptMessages([
|
||||
let history = '';
|
||||
if (memory) {
|
||||
verbose && console.debug('openAPI chain: memory detected', memory);
|
||||
const { history: chat_history } = await memory.loadMemoryVariables({});
|
||||
history = chat_history?.length > 0 ? `\n\n## Chat History:\n${chat_history}\n` : '';
|
||||
}
|
||||
|
||||
chainOptions.prompt = ChatPromptTemplate.fromMessages([
|
||||
HumanMessagePromptTemplate.fromTemplate(
|
||||
`# Use the provided API's to respond to this query:\n\n{query}\n\n## Instructions:\n${addLinePrefix(
|
||||
description_for_model,
|
||||
)}`,
|
||||
)}${history}`,
|
||||
),
|
||||
]);
|
||||
|
||||
const chain = await createOpenAPIChain(spec, chainOptions);
|
||||
|
||||
const { functions } = chain.chains[0].lc_kwargs.llmKwargs;
|
||||
|
||||
return new DynamicStructuredTool({
|
||||
@@ -154,10 +162,19 @@ async function createOpenAPIPlugin({ data, llm, user, message, verbose = false }
|
||||
.map((func) => func.name)
|
||||
.join(', ')}`,
|
||||
),
|
||||
intent: z
|
||||
.string()
|
||||
.describe('Describe your intent with the function and your expected result'),
|
||||
}),
|
||||
func: async ({ func = '' }) => {
|
||||
const result = await chain.run(`${message}${func?.length > 0 ? `\nUse ${func}` : ''}`);
|
||||
return result;
|
||||
func: async ({ func = '', intent = '' }) => {
|
||||
const filteredFunctions = functions.filter((f) => f.name === func);
|
||||
chain.chains[0].lc_kwargs.llmKwargs.functions = filteredFunctions;
|
||||
const query = `${message}${func?.length > 0 ? `\n// Intent: ${intent}` : ''}`;
|
||||
const result = await chain.call({
|
||||
query,
|
||||
signal,
|
||||
});
|
||||
return result.response;
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ const GoogleSearchAPI = require('./GoogleSearch');
|
||||
const HttpRequestTool = require('./HttpRequestTool');
|
||||
const AIPluginTool = require('./AIPluginTool');
|
||||
const OpenAICreateImage = require('./DALL-E');
|
||||
const DALLE3 = require('./structured/DALLE3');
|
||||
const StructuredSD = require('./structured/StableDiffusion');
|
||||
const StableDiffusionAPI = require('./StableDiffusion');
|
||||
const WolframAlphaAPI = require('./Wolfram');
|
||||
@@ -23,6 +24,7 @@ module.exports = {
|
||||
HttpRequestTool,
|
||||
AIPluginTool,
|
||||
OpenAICreateImage,
|
||||
DALLE3,
|
||||
StableDiffusionAPI,
|
||||
StructuredSD,
|
||||
WolframAlphaAPI,
|
||||
|
||||
@@ -95,6 +95,19 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "DALL-E-3",
|
||||
"pluginKey": "dalle",
|
||||
"description": "[DALL-E-3] Create realistic images and art from a description in natural language",
|
||||
"icon": "https://i.imgur.com/u2TzXzH.png",
|
||||
"authConfig": [
|
||||
{
|
||||
"authField": "DALLE_API_KEY",
|
||||
"label": "OpenAI API Key",
|
||||
"description": "You can use DALL-E with your API Key from OpenAI."
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Calculator",
|
||||
"pluginKey": "calculator",
|
||||
|
||||
165
api/app/clients/tools/structured/DALLE3.js
Normal file
165
api/app/clients/tools/structured/DALLE3.js
Normal file
@@ -0,0 +1,165 @@
|
||||
// From https://platform.openai.com/docs/guides/images/usage?context=node
|
||||
// To use this tool, you must pass in a configured OpenAIApi object.
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { z } = require('zod');
|
||||
const OpenAI = require('openai');
|
||||
const { Tool } = require('langchain/tools');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const saveImageFromUrl = require('../saveImageFromUrl');
|
||||
const extractBaseURL = require('../../../../utils/extractBaseURL');
|
||||
const { DALLE3_SYSTEM_PROMPT, DALLE_REVERSE_PROXY, PROXY } = process.env;
|
||||
class DALLE3 extends Tool {
|
||||
constructor(fields = {}) {
|
||||
super();
|
||||
|
||||
let apiKey = fields.DALLE_API_KEY || this.getApiKey();
|
||||
const config = { apiKey };
|
||||
if (DALLE_REVERSE_PROXY) {
|
||||
config.baseURL = extractBaseURL(DALLE_REVERSE_PROXY);
|
||||
}
|
||||
|
||||
if (PROXY) {
|
||||
config.httpAgent = new HttpsProxyAgent(PROXY);
|
||||
}
|
||||
|
||||
this.openai = new OpenAI(config);
|
||||
this.name = 'dalle';
|
||||
this.description = `Use DALLE to create images from text descriptions.
|
||||
- It requires prompts to be in English, detailed, and to specify image type and human features for diversity.
|
||||
- Create only one image, without repeating or listing descriptions outside the "prompts" field.
|
||||
- Maintains the original intent of the description, with parameters for image style, quality, and size to tailor the output.`;
|
||||
this.description_for_model =
|
||||
DALLE3_SYSTEM_PROMPT ??
|
||||
`// Whenever a description of an image is given, generate prompts (following these rules), and use dalle to create the image. If the user does not ask for a specific number of images, default to creating 2 prompts to send to dalle that are written to be as diverse as possible. All prompts sent to dalle must abide by the following policies:
|
||||
// 1. Prompts must be in English. Translate to English if needed.
|
||||
// 2. One image per function call. Create only 1 image per request unless explicitly told to generate more than 1 image.
|
||||
// 3. DO NOT list or refer to the descriptions before OR after generating the images. They should ONLY ever be written out ONCE, in the \`"prompts"\` field of the request. You do not need to ask for permission to generate, just do it!
|
||||
// 4. Always mention the image type (photo, oil painting, watercolor painting, illustration, cartoon, drawing, vector, render, etc.) at the beginning of the caption. Unless the captions suggests otherwise, make one of the images a photo.
|
||||
// 5. Diversify depictions of ALL images with people to always include always DESCENT and GENDER for EACH person using direct terms. Adjust only human descriptions.
|
||||
// - EXPLICITLY specify these attributes, not abstractly reference them. The attributes should be specified in a minimal way and should directly describe their physical form.
|
||||
// - Your choices should be grounded in reality. For example, all of a given OCCUPATION should not be the same gender or race. Additionally, focus on creating diverse, inclusive, and exploratory scenes via the properties you choose during rewrites. Make choices that may be insightful or unique sometimes.
|
||||
// - Use "various" or "diverse" ONLY IF the description refers to groups of more than 3 people. Do not change the number of people requested in the original description.
|
||||
// - Don't alter memes, fictional character origins, or unseen people. Maintain the original prompt's intent and prioritize quality.
|
||||
// The prompt must intricately describe every part of the image in concrete, objective detail. THINK about what the end goal of the description is, and extrapolate that to what would make satisfying images.
|
||||
// All descriptions sent to dalle should be a paragraph of text that is extremely descriptive and detailed. Each should be more than 3 sentences long.`;
|
||||
this.schema = z.object({
|
||||
prompt: z
|
||||
.string()
|
||||
.max(4000)
|
||||
.describe(
|
||||
'A text description of the desired image, following the rules, up to 4000 characters.',
|
||||
),
|
||||
style: z
|
||||
.enum(['vivid', 'natural'])
|
||||
.describe(
|
||||
'Must be one of `vivid` or `natural`. `vivid` generates hyper-real and dramatic images, `natural` produces more natural, less hyper-real looking images',
|
||||
),
|
||||
quality: z
|
||||
.enum(['hd', 'standard'])
|
||||
.describe('The quality of the generated image. Only `hd` and `standard` are supported.'),
|
||||
size: z
|
||||
.enum(['1024x1024', '1792x1024', '1024x1792'])
|
||||
.describe(
|
||||
'The size of the requested image. Use 1024x1024 (square) as the default, 1792x1024 if the user requests a wide image, and 1024x1792 for full-body portraits. Always include this parameter in the request.',
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
getApiKey() {
|
||||
const apiKey = process.env.DALLE_API_KEY || '';
|
||||
if (!apiKey) {
|
||||
throw new Error('Missing DALLE_API_KEY environment variable.');
|
||||
}
|
||||
return apiKey;
|
||||
}
|
||||
|
||||
replaceUnwantedChars(inputString) {
|
||||
return inputString
|
||||
.replace(/\r\n|\r|\n/g, ' ')
|
||||
.replace(/"/g, '')
|
||||
.trim();
|
||||
}
|
||||
|
||||
getMarkdownImageUrl(imageName) {
|
||||
const imageUrl = path
|
||||
.join(this.relativeImageUrl, imageName)
|
||||
.replace(/\\/g, '/')
|
||||
.replace('public/', '');
|
||||
return ``;
|
||||
}
|
||||
|
||||
async _call(data) {
|
||||
const { prompt, quality = 'standard', size = '1024x1024', style = 'vivid' } = data;
|
||||
if (!prompt) {
|
||||
throw new Error('Missing required field: prompt');
|
||||
}
|
||||
|
||||
let resp;
|
||||
try {
|
||||
resp = await this.openai.images.generate({
|
||||
model: 'dall-e-3',
|
||||
quality,
|
||||
style,
|
||||
size,
|
||||
prompt: this.replaceUnwantedChars(prompt),
|
||||
n: 1,
|
||||
});
|
||||
} catch (error) {
|
||||
return `Something went wrong when trying to generate the image. The DALL-E API may unavailable:
|
||||
Error Message: ${error.message}`;
|
||||
}
|
||||
|
||||
if (!resp) {
|
||||
return 'Something went wrong when trying to generate the image. The DALL-E API may unavailable';
|
||||
}
|
||||
|
||||
const theImageUrl = resp.data[0].url;
|
||||
|
||||
if (!theImageUrl) {
|
||||
return 'No image URL returned from OpenAI API. There may be a problem with the API or your configuration.';
|
||||
}
|
||||
|
||||
const regex = /img-[\w\d]+.png/;
|
||||
const match = theImageUrl.match(regex);
|
||||
let imageName = '1.png';
|
||||
|
||||
if (match) {
|
||||
imageName = match[0];
|
||||
console.log(imageName); // Output: img-lgCf7ppcbhqQrz6a5ear6FOb.png
|
||||
} else {
|
||||
console.log('No image name found in the string.');
|
||||
}
|
||||
|
||||
this.outputPath = path.resolve(
|
||||
__dirname,
|
||||
'..',
|
||||
'..',
|
||||
'..',
|
||||
'..',
|
||||
'..',
|
||||
'client',
|
||||
'public',
|
||||
'images',
|
||||
);
|
||||
const appRoot = path.resolve(__dirname, '..', '..', '..', '..', '..', 'client');
|
||||
this.relativeImageUrl = path.relative(appRoot, this.outputPath);
|
||||
|
||||
// Check if directory exists, if not create it
|
||||
if (!fs.existsSync(this.outputPath)) {
|
||||
fs.mkdirSync(this.outputPath, { recursive: true });
|
||||
}
|
||||
|
||||
try {
|
||||
await saveImageFromUrl(theImageUrl, this.outputPath, imageName);
|
||||
this.result = this.getMarkdownImageUrl(imageName);
|
||||
} catch (error) {
|
||||
console.error('Error while saving the image:', error);
|
||||
this.result = theImageUrl;
|
||||
}
|
||||
|
||||
return this.result;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = DALLE3;
|
||||
190
api/app/clients/tools/structured/specs/DALLE3.spec.js
Normal file
190
api/app/clients/tools/structured/specs/DALLE3.spec.js
Normal file
@@ -0,0 +1,190 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const OpenAI = require('openai');
|
||||
const DALLE3 = require('../DALLE3');
|
||||
const saveImageFromUrl = require('../../saveImageFromUrl');
|
||||
|
||||
jest.mock('openai');
|
||||
|
||||
const generate = jest.fn();
|
||||
OpenAI.mockImplementation(() => ({
|
||||
images: {
|
||||
generate,
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('fs', () => {
|
||||
return {
|
||||
existsSync: jest.fn(),
|
||||
mkdirSync: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('../../saveImageFromUrl', () => {
|
||||
return jest.fn();
|
||||
});
|
||||
|
||||
jest.mock('path', () => {
|
||||
return {
|
||||
resolve: jest.fn(),
|
||||
join: jest.fn(),
|
||||
relative: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('DALLE3', () => {
|
||||
let originalEnv;
|
||||
let dalle; // Keep this declaration if you need to use dalle in other tests
|
||||
const mockApiKey = 'mock_api_key';
|
||||
|
||||
beforeAll(() => {
|
||||
// Save the original process.env
|
||||
originalEnv = { ...process.env };
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset the process.env before each test
|
||||
jest.resetModules();
|
||||
process.env = { ...originalEnv, DALLE_API_KEY: mockApiKey };
|
||||
// Instantiate DALLE3 for tests that do not depend on DALLE3_SYSTEM_PROMPT
|
||||
dalle = new DALLE3();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
// Restore the original process.env after each test
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should throw an error if DALLE_API_KEY is missing', () => {
|
||||
delete process.env.DALLE_API_KEY;
|
||||
expect(() => new DALLE3()).toThrow('Missing DALLE_API_KEY environment variable.');
|
||||
});
|
||||
|
||||
it('should replace unwanted characters in input string', () => {
|
||||
const input = 'This is a test\nstring with "quotes" and new lines.';
|
||||
const expectedOutput = 'This is a test string with quotes and new lines.';
|
||||
expect(dalle.replaceUnwantedChars(input)).toBe(expectedOutput);
|
||||
});
|
||||
|
||||
it('should generate markdown image URL correctly', () => {
|
||||
const imageName = 'test.png';
|
||||
path.join.mockReturnValue('images/test.png');
|
||||
path.relative.mockReturnValue('images/test.png');
|
||||
const markdownImage = dalle.getMarkdownImageUrl(imageName);
|
||||
expect(markdownImage).toBe('');
|
||||
});
|
||||
|
||||
it('should call OpenAI API with correct parameters', async () => {
|
||||
const mockData = {
|
||||
prompt: 'A test prompt',
|
||||
quality: 'standard',
|
||||
size: '1024x1024',
|
||||
style: 'vivid',
|
||||
};
|
||||
|
||||
const mockResponse = {
|
||||
data: [
|
||||
{
|
||||
url: 'http://example.com/img-test.png',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
generate.mockResolvedValue(mockResponse);
|
||||
saveImageFromUrl.mockResolvedValue(true);
|
||||
fs.existsSync.mockReturnValue(true);
|
||||
path.resolve.mockReturnValue('/fakepath/images');
|
||||
path.join.mockReturnValue('/fakepath/images/img-test.png');
|
||||
path.relative.mockReturnValue('images/img-test.png');
|
||||
|
||||
const result = await dalle._call(mockData);
|
||||
|
||||
expect(generate).toHaveBeenCalledWith({
|
||||
model: 'dall-e-3',
|
||||
quality: mockData.quality,
|
||||
style: mockData.style,
|
||||
size: mockData.size,
|
||||
prompt: mockData.prompt,
|
||||
n: 1,
|
||||
});
|
||||
expect(result).toContain('![generated image]');
|
||||
});
|
||||
|
||||
it('should use the system prompt if provided', () => {
|
||||
process.env.DALLE3_SYSTEM_PROMPT = 'System prompt for testing';
|
||||
jest.resetModules(); // This will ensure the module is fresh and will read the new env var
|
||||
const DALLE3 = require('../DALLE3'); // Re-require after setting the env var
|
||||
const dalleWithSystemPrompt = new DALLE3();
|
||||
expect(dalleWithSystemPrompt.description_for_model).toBe('System prompt for testing');
|
||||
});
|
||||
|
||||
it('should not use the system prompt if not provided', async () => {
|
||||
delete process.env.DALLE3_SYSTEM_PROMPT;
|
||||
const dalleWithoutSystemPrompt = new DALLE3();
|
||||
expect(dalleWithoutSystemPrompt.description_for_model).not.toBe('System prompt for testing');
|
||||
});
|
||||
|
||||
it('should throw an error if prompt is missing', async () => {
|
||||
const mockData = {
|
||||
quality: 'standard',
|
||||
size: '1024x1024',
|
||||
style: 'vivid',
|
||||
};
|
||||
await expect(dalle._call(mockData)).rejects.toThrow('Missing required field: prompt');
|
||||
});
|
||||
|
||||
it('should log to console if no image name is found in the URL', async () => {
|
||||
const mockData = {
|
||||
prompt: 'A test prompt',
|
||||
};
|
||||
const mockResponse = {
|
||||
data: [
|
||||
{
|
||||
url: 'http://example.com/invalid-url',
|
||||
},
|
||||
],
|
||||
};
|
||||
console.log = jest.fn(); // Mock console.log
|
||||
generate.mockResolvedValue(mockResponse);
|
||||
await dalle._call(mockData);
|
||||
expect(console.log).toHaveBeenCalledWith('No image name found in the string.');
|
||||
});
|
||||
|
||||
it('should create the directory if it does not exist', async () => {
|
||||
const mockData = {
|
||||
prompt: 'A test prompt',
|
||||
};
|
||||
const mockResponse = {
|
||||
data: [
|
||||
{
|
||||
url: 'http://example.com/img-test.png',
|
||||
},
|
||||
],
|
||||
};
|
||||
generate.mockResolvedValue(mockResponse);
|
||||
fs.existsSync.mockReturnValue(false); // Simulate directory does not exist
|
||||
await dalle._call(mockData);
|
||||
expect(fs.mkdirSync).toHaveBeenCalledWith(expect.any(String), { recursive: true });
|
||||
});
|
||||
|
||||
it('should log an error and return the image URL if there is an error saving the image', async () => {
|
||||
const mockData = {
|
||||
prompt: 'A test prompt',
|
||||
};
|
||||
const mockResponse = {
|
||||
data: [
|
||||
{
|
||||
url: 'http://example.com/img-test.png',
|
||||
},
|
||||
],
|
||||
};
|
||||
const error = new Error('Error while saving the image');
|
||||
generate.mockResolvedValue(mockResponse);
|
||||
saveImageFromUrl.mockRejectedValue(error);
|
||||
console.error = jest.fn(); // Mock console.error
|
||||
const result = await dalle._call(mockData);
|
||||
expect(console.error).toHaveBeenCalledWith('Error while saving the image:', error);
|
||||
expect(result).toBe(mockResponse.data[0].url);
|
||||
});
|
||||
});
|
||||
30
api/app/clients/tools/util/handleOpenAIErrors.js
Normal file
30
api/app/clients/tools/util/handleOpenAIErrors.js
Normal file
@@ -0,0 +1,30 @@
|
||||
const OpenAI = require('openai');
|
||||
|
||||
/**
|
||||
* Handles errors that may occur when making requests to OpenAI's API.
|
||||
* It checks the instance of the error and prints a specific warning message
|
||||
* to the console depending on the type of error encountered.
|
||||
* It then calls an optional error callback function with the error object.
|
||||
*
|
||||
* @param {Error} err - The error object thrown by OpenAI API.
|
||||
* @param {Function} errorCallback - A callback function that is called with the error object.
|
||||
* @param {string} [context='stream'] - A string providing context where the error occurred, defaults to 'stream'.
|
||||
*/
|
||||
async function handleOpenAIErrors(err, errorCallback, context = 'stream') {
|
||||
if (err instanceof OpenAI.APIError && err?.message?.includes('abort')) {
|
||||
console.warn(`[OpenAIClient.chatCompletion][${context}] Aborted Message`);
|
||||
}
|
||||
if (err instanceof OpenAI.OpenAIError && err?.message?.includes('missing finish_reason')) {
|
||||
console.warn(`[OpenAIClient.chatCompletion][${context}] Missing finish_reason`);
|
||||
} else if (err instanceof OpenAI.APIError) {
|
||||
console.warn(`[OpenAIClient.chatCompletion][${context}] API Error`);
|
||||
} else {
|
||||
console.warn(`[OpenAIClient.chatCompletion][${context}] Unhandled error type`);
|
||||
}
|
||||
|
||||
if (errorCallback) {
|
||||
errorCallback(err);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = handleOpenAIErrors;
|
||||
@@ -15,6 +15,7 @@ const {
|
||||
HttpRequestTool,
|
||||
OpenAICreateImage,
|
||||
StableDiffusionAPI,
|
||||
DALLE3,
|
||||
StructuredSD,
|
||||
AzureCognitiveSearch,
|
||||
StructuredACS,
|
||||
@@ -176,6 +177,7 @@ const loadTools = async ({
|
||||
const requestedTools = {};
|
||||
|
||||
if (functions) {
|
||||
toolConstructors.dalle = DALLE3;
|
||||
toolConstructors.codesherpa = CodeSherpa;
|
||||
}
|
||||
|
||||
@@ -224,6 +226,8 @@ const loadTools = async ({
|
||||
llm: model,
|
||||
user,
|
||||
message: options.message,
|
||||
memory: options.memory,
|
||||
signal: options.signal,
|
||||
tools: remainingTools,
|
||||
map: true,
|
||||
verbose: false,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
const { validateTools, loadTools } = require('./handleTools');
|
||||
const handleOpenAIErrors = require('./handleOpenAIErrors');
|
||||
|
||||
module.exports = {
|
||||
handleOpenAIErrors,
|
||||
validateTools,
|
||||
loadTools,
|
||||
};
|
||||
|
||||
@@ -38,7 +38,16 @@ function validateJson(json, verbose = true) {
|
||||
}
|
||||
|
||||
// omit the LLM to return the well known jsons as objects
|
||||
async function loadSpecs({ llm, user, message, tools = [], map = false, verbose = false }) {
|
||||
async function loadSpecs({
|
||||
llm,
|
||||
user,
|
||||
message,
|
||||
tools = [],
|
||||
map = false,
|
||||
memory,
|
||||
signal,
|
||||
verbose = false,
|
||||
}) {
|
||||
const directoryPath = path.join(__dirname, '..', '.well-known');
|
||||
let files = [];
|
||||
|
||||
@@ -85,6 +94,8 @@ async function loadSpecs({ llm, user, message, tools = [], map = false, verbose
|
||||
data: json,
|
||||
llm,
|
||||
message,
|
||||
memory,
|
||||
signal,
|
||||
user,
|
||||
verbose,
|
||||
});
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
const { isEnabled } = require('../server/utils');
|
||||
const throttle = require('lodash/throttle');
|
||||
|
||||
const titleConvo = async ({ text, response }) => {
|
||||
let title = 'New Chat';
|
||||
const { BingAIClient } = await import('@waylaidwanderer/chatgpt-api');
|
||||
const { TITLE_CONVO = 'true' } = process.env ?? {};
|
||||
if (!isEnabled(TITLE_CONVO)) {
|
||||
return title;
|
||||
}
|
||||
|
||||
const { BingAIClient } = await import('nodejs-gpt');
|
||||
const titleGenerator = new BingAIClient({
|
||||
userToken: process.env.BINGAI_TOKEN,
|
||||
debug: false,
|
||||
|
||||
13
api/cache/banViolation.js
vendored
13
api/cache/banViolation.js
vendored
@@ -47,17 +47,24 @@ const banViolation = async (req, res, errorMessage) => {
|
||||
res.clearCookie('refreshToken');
|
||||
|
||||
const banLogs = getLogStores('ban');
|
||||
const duration = banLogs.opts.ttl;
|
||||
const duration = errorMessage.duration || banLogs.opts.ttl;
|
||||
|
||||
if (duration <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
req.ip = removePorts(req);
|
||||
console.log(`[BAN] Banning user ${user_id} @ ${req.ip} for ${duration / 1000 / 60} minutes`);
|
||||
console.log(
|
||||
`[BAN] Banning user ${user_id} ${req.ip ? `@ ${req.ip} ` : ''}for ${
|
||||
duration / 1000 / 60
|
||||
} minutes`,
|
||||
);
|
||||
|
||||
const expiresAt = Date.now() + duration;
|
||||
await banLogs.set(user_id, { type, violation_count, duration, expiresAt });
|
||||
await banLogs.set(req.ip, { type, user_id, violation_count, duration, expiresAt });
|
||||
if (req.ip) {
|
||||
await banLogs.set(req.ip, { type, user_id, violation_count, duration, expiresAt });
|
||||
}
|
||||
|
||||
errorMessage.ban = true;
|
||||
errorMessage.ban_duration = duration;
|
||||
|
||||
45
api/cache/clearPendingReq.js
vendored
45
api/cache/clearPendingReq.js
vendored
@@ -1,29 +1,48 @@
|
||||
const Keyv = require('keyv');
|
||||
const { pendingReqFile } = require('./keyvFiles');
|
||||
const { LIMIT_CONCURRENT_MESSAGES } = process.env ?? {};
|
||||
|
||||
const keyv = new Keyv({ store: pendingReqFile, namespace: 'pendingRequests' });
|
||||
const getLogStores = require('./getLogStores');
|
||||
const { isEnabled } = require('../server/utils');
|
||||
const { USE_REDIS, LIMIT_CONCURRENT_MESSAGES } = process.env ?? {};
|
||||
const ttl = 1000 * 60 * 1;
|
||||
|
||||
/**
|
||||
* Clear pending requests from the cache.
|
||||
* Clear or decrement pending requests from the cache.
|
||||
* Checks the environmental variable LIMIT_CONCURRENT_MESSAGES;
|
||||
* if the rule is enabled ('true'), pending requests in the cache are cleared.
|
||||
* if the rule is enabled ('true'), it either decrements the count of pending requests
|
||||
* or deletes the key if the count is less than or equal to 1.
|
||||
*
|
||||
* @module clearPendingReq
|
||||
* @requires keyv
|
||||
* @requires keyvFiles
|
||||
* @requires ./getLogStores
|
||||
* @requires ../server/utils
|
||||
* @requires process
|
||||
*
|
||||
* @async
|
||||
* @function
|
||||
* @returns {Promise<void>} A promise that either clears 'pendingRequests' from store or resolves with no value.
|
||||
* @param {Object} params - The parameters object.
|
||||
* @param {string} params.userId - The user ID for which the pending requests are to be cleared or decremented.
|
||||
* @param {Object} [params.cache] - An optional cache object to use. If not provided, a default cache will be fetched using getLogStores.
|
||||
* @returns {Promise<void>} A promise that either decrements the 'pendingRequests' count, deletes the key from the store, or resolves with no value.
|
||||
*/
|
||||
const clearPendingReq = async () => {
|
||||
if (LIMIT_CONCURRENT_MESSAGES?.toLowerCase() !== 'true') {
|
||||
const clearPendingReq = async ({ userId, cache: _cache }) => {
|
||||
if (!userId) {
|
||||
return;
|
||||
} else if (!isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||
return;
|
||||
}
|
||||
|
||||
await keyv.clear();
|
||||
const namespace = 'pending_req';
|
||||
const cache = _cache ?? getLogStores(namespace);
|
||||
|
||||
if (!cache) {
|
||||
return;
|
||||
}
|
||||
|
||||
const key = `${USE_REDIS ? namespace : ''}:${userId ?? ''}`;
|
||||
const currentReq = +((await cache.get(key)) ?? 0);
|
||||
|
||||
if (currentReq && currentReq >= 1) {
|
||||
await cache.set(key, currentReq - 1, ttl);
|
||||
} else {
|
||||
await cache.delete(key);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = clearPendingReq;
|
||||
|
||||
39
api/cache/getLogStores.js
vendored
39
api/cache/getLogStores.js
vendored
@@ -1,25 +1,37 @@
|
||||
const Keyv = require('keyv');
|
||||
const keyvMongo = require('./keyvMongo');
|
||||
const { math } = require('../server/utils');
|
||||
const keyvRedis = require('./keyvRedis');
|
||||
const { math, isEnabled } = require('../server/utils');
|
||||
const { logFile, violationFile } = require('./keyvFiles');
|
||||
const { BAN_DURATION } = process.env ?? {};
|
||||
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
|
||||
|
||||
const duration = math(BAN_DURATION, 7200000);
|
||||
|
||||
const createViolationInstance = (namespace) => {
|
||||
const config = isEnabled(USE_REDIS) ? { store: keyvRedis } : { store: violationFile, namespace };
|
||||
return new Keyv(config);
|
||||
};
|
||||
|
||||
// Serve cache from memory so no need to clear it on startup/exit
|
||||
const pending_req = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: 'pending_req' });
|
||||
|
||||
const namespaces = {
|
||||
ban: new Keyv({ store: keyvMongo, ttl: duration, namespace: 'bans' }),
|
||||
pending_req,
|
||||
ban: new Keyv({ store: keyvMongo, namespace: 'bans', ttl: duration }),
|
||||
general: new Keyv({ store: logFile, namespace: 'violations' }),
|
||||
concurrent: new Keyv({ store: violationFile, namespace: 'concurrent' }),
|
||||
non_browser: new Keyv({ store: violationFile, namespace: 'non_browser' }),
|
||||
message_limit: new Keyv({ store: violationFile, namespace: 'message_limit' }),
|
||||
registrations: new Keyv({ store: violationFile, namespace: 'registrations' }),
|
||||
logins: new Keyv({ store: violationFile, namespace: 'logins' }),
|
||||
concurrent: createViolationInstance('concurrent'),
|
||||
non_browser: createViolationInstance('non_browser'),
|
||||
message_limit: createViolationInstance('message_limit'),
|
||||
token_balance: createViolationInstance('token_balance'),
|
||||
registrations: createViolationInstance('registrations'),
|
||||
logins: createViolationInstance('logins'),
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns either the logs of violations specified by type if a type is provided
|
||||
* or it returns the general log if no type is specified. If an invalid type is passed,
|
||||
* an error will be thrown.
|
||||
* Returns the keyv cache specified by type.
|
||||
* If an invalid type is passed, an error will be thrown.
|
||||
*
|
||||
* @module getLogStores
|
||||
* @requires keyv - a simple key-value storage that allows you to easily switch out storage adapters.
|
||||
@@ -30,11 +42,10 @@ const namespaces = {
|
||||
* @throws Will throw an error if an invalid violation type is passed.
|
||||
*/
|
||||
const getLogStores = (type) => {
|
||||
if (!type) {
|
||||
if (!type || !namespaces[type]) {
|
||||
throw new Error(`Invalid store type: ${type}`);
|
||||
}
|
||||
const logs = namespaces[type];
|
||||
return logs;
|
||||
return namespaces[type];
|
||||
};
|
||||
|
||||
module.exports = getLogStores;
|
||||
|
||||
3
api/cache/index.js
vendored
3
api/cache/index.js
vendored
@@ -1,6 +1,5 @@
|
||||
const keyvFiles = require('./keyvFiles');
|
||||
const getLogStores = require('./getLogStores');
|
||||
const logViolation = require('./logViolation');
|
||||
const clearPendingReq = require('./clearPendingReq');
|
||||
|
||||
module.exports = { ...keyvFiles, getLogStores, logViolation, clearPendingReq };
|
||||
module.exports = { ...keyvFiles, getLogStores, logViolation };
|
||||
|
||||
14
api/cache/keyvRedis.js
vendored
Normal file
14
api/cache/keyvRedis.js
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
const KeyvRedis = require('@keyv/redis');
|
||||
|
||||
const { REDIS_URI } = process.env;
|
||||
|
||||
let keyvRedis;
|
||||
|
||||
if (REDIS_URI) {
|
||||
keyvRedis = new KeyvRedis(REDIS_URI, { useRedisSets: false });
|
||||
keyvRedis.on('error', (err) => console.error('KeyvRedis connection error:', err));
|
||||
} else {
|
||||
// console.log('REDIS_URI not provided. Redis module will not be initialized.');
|
||||
}
|
||||
|
||||
module.exports = keyvRedis;
|
||||
13
api/cache/logViolation.js
vendored
13
api/cache/logViolation.js
vendored
@@ -1,5 +1,6 @@
|
||||
const getLogStores = require('./getLogStores');
|
||||
const banViolation = require('./banViolation');
|
||||
const { isEnabled } = require('../server/utils');
|
||||
|
||||
/**
|
||||
* Logs the violation.
|
||||
@@ -17,10 +18,11 @@ const logViolation = async (req, res, type, errorMessage, score = 1) => {
|
||||
}
|
||||
const logs = getLogStores('general');
|
||||
const violationLogs = getLogStores(type);
|
||||
const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;
|
||||
|
||||
const userViolations = (await violationLogs.get(userId)) ?? 0;
|
||||
const violationCount = userViolations + score;
|
||||
await violationLogs.set(userId, violationCount);
|
||||
const userViolations = (await violationLogs.get(key)) ?? 0;
|
||||
const violationCount = +userViolations + +score;
|
||||
await violationLogs.set(key, violationCount);
|
||||
|
||||
errorMessage.user_id = userId;
|
||||
errorMessage.prev_count = userViolations;
|
||||
@@ -28,9 +30,10 @@ const logViolation = async (req, res, type, errorMessage, score = 1) => {
|
||||
errorMessage.date = new Date().toISOString();
|
||||
|
||||
await banViolation(req, res, errorMessage);
|
||||
const userLogs = (await logs.get(userId)) ?? [];
|
||||
const userLogs = (await logs.get(key)) ?? [];
|
||||
userLogs.push(errorMessage);
|
||||
await logs.set(userId, userLogs);
|
||||
delete errorMessage.user_id;
|
||||
await logs.set(key, userLogs);
|
||||
};
|
||||
|
||||
module.exports = logViolation;
|
||||
|
||||
4
api/cache/redis.js
vendored
Normal file
4
api/cache/redis.js
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
const Redis = require('ioredis');
|
||||
const { REDIS_URI } = process.env ?? {};
|
||||
const redis = new Redis.Cluster(REDIS_URI);
|
||||
module.exports = redis;
|
||||
13
api/jsconfig.json
Normal file
13
api/jsconfig.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES6",
|
||||
"module": "CommonJS",
|
||||
// "checkJs": true, // Report errors in JavaScript files
|
||||
"baseUrl": "./",
|
||||
"paths": {
|
||||
"*": ["*", "node_modules/*"],
|
||||
"~/*": ["./*"]
|
||||
}
|
||||
},
|
||||
"exclude": ["node_modules"]
|
||||
}
|
||||
@@ -18,11 +18,12 @@ if (!cached) {
|
||||
}
|
||||
|
||||
async function connectDb() {
|
||||
if (cached.conn) {
|
||||
if (cached.conn && cached.conn?._readyState === 1) {
|
||||
return cached.conn;
|
||||
}
|
||||
|
||||
if (!cached.promise) {
|
||||
const disconnected = cached.conn && cached.conn?._readyState !== 1;
|
||||
if (!cached.promise || disconnected) {
|
||||
const opts = {
|
||||
useNewUrlParser: true,
|
||||
useUnifiedTopology: true,
|
||||
|
||||
38
api/models/Balance.js
Normal file
38
api/models/Balance.js
Normal file
@@ -0,0 +1,38 @@
|
||||
const mongoose = require('mongoose');
|
||||
const balanceSchema = require('./schema/balance');
|
||||
const { getMultiplier } = require('./tx');
|
||||
|
||||
balanceSchema.statics.check = async function ({ user, model, valueKey, tokenType, amount, debug }) {
|
||||
const multiplier = getMultiplier({ valueKey, tokenType, model });
|
||||
const tokenCost = amount * multiplier;
|
||||
const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {};
|
||||
|
||||
if (debug) {
|
||||
console.log('balance check', {
|
||||
user,
|
||||
model,
|
||||
valueKey,
|
||||
tokenType,
|
||||
amount,
|
||||
debug,
|
||||
balance,
|
||||
multiplier,
|
||||
});
|
||||
}
|
||||
|
||||
if (!balance) {
|
||||
return {
|
||||
canSpend: false,
|
||||
balance: 0,
|
||||
tokenCost,
|
||||
};
|
||||
}
|
||||
|
||||
if (debug) {
|
||||
console.log('balance check', { tokenCost });
|
||||
}
|
||||
|
||||
return { canSpend: balance >= tokenCost, balance, tokenCost };
|
||||
};
|
||||
|
||||
module.exports = mongoose.model('Balance', balanceSchema);
|
||||
4
api/models/Key.js
Normal file
4
api/models/Key.js
Normal file
@@ -0,0 +1,4 @@
|
||||
const mongoose = require('mongoose');
|
||||
const keySchema = require('./schema/key');
|
||||
|
||||
module.exports = mongoose.model('Key', keySchema);
|
||||
@@ -1,6 +1,6 @@
|
||||
const mongoose = require('mongoose');
|
||||
const crypto = require('crypto');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const signPayload = require('../server/services/signPayload');
|
||||
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7;
|
||||
|
||||
@@ -31,13 +31,11 @@ sessionSchema.methods.generateRefreshToken = async function () {
|
||||
this.expiration = new Date(expiresIn);
|
||||
}
|
||||
|
||||
const refreshToken = jwt.sign(
|
||||
{
|
||||
id: this.user,
|
||||
},
|
||||
process.env.JWT_REFRESH_SECRET,
|
||||
{ expiresIn: Math.floor((expiresIn - Date.now()) / 1000) },
|
||||
);
|
||||
const refreshToken = await signPayload({
|
||||
payload: { id: this.user },
|
||||
secret: process.env.JWT_REFRESH_SECRET,
|
||||
expirationTime: Math.floor((expiresIn - Date.now()) / 1000),
|
||||
});
|
||||
|
||||
const hash = crypto.createHash('sha256');
|
||||
this.refreshTokenHash = hash.update(refreshToken).digest('hex');
|
||||
|
||||
45
api/models/Transaction.js
Normal file
45
api/models/Transaction.js
Normal file
@@ -0,0 +1,45 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { isEnabled } = require('../server/utils/handleText');
|
||||
const transactionSchema = require('./schema/transaction');
|
||||
const { getMultiplier } = require('./tx');
|
||||
const Balance = require('./Balance');
|
||||
const cancelRate = 1.15;
|
||||
|
||||
// Method to calculate and set the tokenValue for a transaction
|
||||
transactionSchema.methods.calculateTokenValue = function () {
|
||||
if (!this.valueKey || !this.tokenType) {
|
||||
this.tokenValue = this.rawAmount;
|
||||
}
|
||||
const { valueKey, tokenType, model } = this;
|
||||
const multiplier = getMultiplier({ valueKey, tokenType, model });
|
||||
this.rate = multiplier;
|
||||
this.tokenValue = this.rawAmount * multiplier;
|
||||
if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') {
|
||||
this.tokenValue = Math.ceil(this.tokenValue * cancelRate);
|
||||
this.rate *= cancelRate;
|
||||
}
|
||||
};
|
||||
|
||||
// Static method to create a transaction and update the balance
|
||||
transactionSchema.statics.create = async function (transactionData) {
|
||||
const Transaction = this;
|
||||
|
||||
const transaction = new Transaction(transactionData);
|
||||
transaction.calculateTokenValue();
|
||||
|
||||
// Save the transaction
|
||||
await transaction.save();
|
||||
|
||||
if (!isEnabled(process.env.CHECK_BALANCE)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Adjust the user's balance
|
||||
return await Balance.findOneAndUpdate(
|
||||
{ user: transaction.user },
|
||||
{ $inc: { tokenCredits: transaction.tokenValue } },
|
||||
{ upsert: true, new: true },
|
||||
);
|
||||
};
|
||||
|
||||
module.exports = mongoose.model('Transaction', transactionSchema);
|
||||
@@ -1,6 +1,6 @@
|
||||
const mongoose = require('mongoose');
|
||||
const bcrypt = require('bcryptjs');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const signPayload = require('../server/services/signPayload');
|
||||
const userSchema = require('./schema/userSchema.js');
|
||||
const { SESSION_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
|
||||
@@ -21,18 +21,17 @@ userSchema.methods.toJSON = function () {
|
||||
};
|
||||
};
|
||||
|
||||
userSchema.methods.generateToken = function () {
|
||||
const token = jwt.sign(
|
||||
{
|
||||
userSchema.methods.generateToken = async function () {
|
||||
return await signPayload({
|
||||
payload: {
|
||||
id: this._id,
|
||||
username: this.username,
|
||||
provider: this.provider,
|
||||
email: this.email,
|
||||
},
|
||||
process.env.JWT_SECRET,
|
||||
{ expiresIn: expires / 1000 },
|
||||
);
|
||||
return token;
|
||||
secret: process.env.JWT_SECRET,
|
||||
expirationTime: expires / 1000,
|
||||
});
|
||||
};
|
||||
|
||||
userSchema.methods.comparePassword = function (candidatePassword, callback) {
|
||||
|
||||
44
api/models/checkBalance.js
Normal file
44
api/models/checkBalance.js
Normal file
@@ -0,0 +1,44 @@
|
||||
const Balance = require('./Balance');
|
||||
const { logViolation } = require('../cache');
|
||||
/**
|
||||
* Checks the balance for a user and determines if they can spend a certain amount.
|
||||
* If the user cannot spend the amount, it logs a violation and denies the request.
|
||||
*
|
||||
* @async
|
||||
* @function
|
||||
* @param {Object} params - The function parameters.
|
||||
* @param {Object} params.req - The Express request object.
|
||||
* @param {Object} params.res - The Express response object.
|
||||
* @param {Object} params.txData - The transaction data.
|
||||
* @param {string} params.txData.user - The user ID or identifier.
|
||||
* @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
|
||||
* @param {number} params.txData.amount - The amount of tokens.
|
||||
* @param {boolean} params.txData.debug - Debug flag.
|
||||
* @param {string} params.txData.model - The model name or identifier.
|
||||
* @returns {Promise<boolean>} Returns true if the user can spend the amount, otherwise denies the request.
|
||||
* @throws {Error} Throws an error if there's an issue with the balance check.
|
||||
*/
|
||||
const checkBalance = async ({ req, res, txData }) => {
|
||||
const { canSpend, balance, tokenCost } = await Balance.check(txData);
|
||||
|
||||
if (canSpend) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const type = 'token_balance';
|
||||
const errorMessage = {
|
||||
type,
|
||||
balance,
|
||||
tokenCost,
|
||||
promptTokens: txData.amount,
|
||||
};
|
||||
|
||||
if (txData.generations && txData.generations.length > 0) {
|
||||
errorMessage.generations = txData.generations;
|
||||
}
|
||||
|
||||
await logViolation(req, res, type, errorMessage, 0);
|
||||
throw new Error(JSON.stringify(errorMessage));
|
||||
};
|
||||
|
||||
module.exports = checkBalance;
|
||||
@@ -5,14 +5,20 @@ const {
|
||||
deleteMessagesSince,
|
||||
deleteMessages,
|
||||
} = require('./Message');
|
||||
const { getConvoTitle, getConvo, saveConvo } = require('./Conversation');
|
||||
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
|
||||
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
||||
const Key = require('./Key');
|
||||
const User = require('./User');
|
||||
const Key = require('./schema/keySchema');
|
||||
const Session = require('./Session');
|
||||
const Balance = require('./Balance');
|
||||
const Transaction = require('./Transaction');
|
||||
|
||||
module.exports = {
|
||||
User,
|
||||
Key,
|
||||
Session,
|
||||
Balance,
|
||||
Transaction,
|
||||
|
||||
getMessages,
|
||||
saveMessage,
|
||||
@@ -23,6 +29,7 @@ module.exports = {
|
||||
getConvoTitle,
|
||||
getConvo,
|
||||
saveConvo,
|
||||
deleteConvos,
|
||||
|
||||
getPreset,
|
||||
getPresets,
|
||||
|
||||
17
api/models/schema/balance.js
Normal file
17
api/models/schema/balance.js
Normal file
@@ -0,0 +1,17 @@
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
const balanceSchema = mongoose.Schema({
|
||||
user: {
|
||||
type: mongoose.Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
index: true,
|
||||
required: true,
|
||||
},
|
||||
// 1000 tokenCredits = 1 mill ($0.001 USD)
|
||||
tokenCredits: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
});
|
||||
|
||||
module.exports = balanceSchema;
|
||||
@@ -22,4 +22,4 @@ const keySchema = mongoose.Schema({
|
||||
|
||||
keySchema.index({ expiresAt: 1 }, { expireAfterSeconds: 0 });
|
||||
|
||||
module.exports = mongoose.model('Key', keySchema);
|
||||
module.exports = keySchema;
|
||||
@@ -39,7 +39,7 @@ const messageSchema = mongoose.Schema(
|
||||
tokenCount: {
|
||||
type: Number,
|
||||
},
|
||||
refinedTokenCount: {
|
||||
summaryTokenCount: {
|
||||
type: Number,
|
||||
},
|
||||
sender: {
|
||||
@@ -52,7 +52,7 @@ const messageSchema = mongoose.Schema(
|
||||
required: true,
|
||||
meiliIndex: true,
|
||||
},
|
||||
refinedMessageText: {
|
||||
summary: {
|
||||
type: String,
|
||||
},
|
||||
isCreatedByUser: {
|
||||
|
||||
39
api/models/schema/transaction.js
Normal file
39
api/models/schema/transaction.js
Normal file
@@ -0,0 +1,39 @@
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
const transactionSchema = mongoose.Schema(
|
||||
{
|
||||
user: {
|
||||
type: mongoose.Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
index: true,
|
||||
required: true,
|
||||
},
|
||||
conversationId: {
|
||||
type: String,
|
||||
ref: 'Conversation',
|
||||
index: true,
|
||||
},
|
||||
tokenType: {
|
||||
type: String,
|
||||
enum: ['prompt', 'completion', 'credits'],
|
||||
required: true,
|
||||
},
|
||||
model: {
|
||||
type: String,
|
||||
},
|
||||
context: {
|
||||
type: String,
|
||||
},
|
||||
valueKey: {
|
||||
type: String,
|
||||
},
|
||||
rate: Number,
|
||||
rawAmount: Number,
|
||||
tokenValue: Number,
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = transactionSchema;
|
||||
49
api/models/spendTokens.js
Normal file
49
api/models/spendTokens.js
Normal file
@@ -0,0 +1,49 @@
|
||||
const Transaction = require('./Transaction');
|
||||
|
||||
/**
|
||||
* Creates up to two transactions to record the spending of tokens.
|
||||
*
|
||||
* @function
|
||||
* @async
|
||||
* @param {Object} txData - Transaction data.
|
||||
* @param {mongoose.Schema.Types.ObjectId} txData.user - The user ID.
|
||||
* @param {String} txData.conversationId - The ID of the conversation.
|
||||
* @param {String} txData.model - The model name.
|
||||
* @param {String} txData.context - The context in which the transaction is made.
|
||||
* @param {String} [txData.valueKey] - The value key (optional).
|
||||
* @param {Object} tokenUsage - The number of tokens used.
|
||||
* @param {Number} tokenUsage.promptTokens - The number of prompt tokens used.
|
||||
* @param {Number} tokenUsage.completionTokens - The number of completion tokens used.
|
||||
* @returns {Promise<void>} - Returns nothing.
|
||||
* @throws {Error} - Throws an error if there's an issue creating the transactions.
|
||||
*/
|
||||
const spendTokens = async (txData, tokenUsage) => {
|
||||
const { promptTokens, completionTokens } = tokenUsage;
|
||||
let prompt, completion;
|
||||
try {
|
||||
if (promptTokens >= 0) {
|
||||
prompt = await Transaction.create({
|
||||
...txData,
|
||||
tokenType: 'prompt',
|
||||
rawAmount: -promptTokens,
|
||||
});
|
||||
}
|
||||
|
||||
if (!completionTokens) {
|
||||
this.debug && console.dir({ prompt, completion }, { depth: null });
|
||||
return;
|
||||
}
|
||||
|
||||
completion = await Transaction.create({
|
||||
...txData,
|
||||
tokenType: 'completion',
|
||||
rawAmount: -completionTokens,
|
||||
});
|
||||
|
||||
this.debug && console.dir({ prompt, completion }, { depth: null });
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = spendTokens;
|
||||
74
api/models/tx.js
Normal file
74
api/models/tx.js
Normal file
@@ -0,0 +1,74 @@
|
||||
const { matchModelName } = require('../utils');
|
||||
const defaultRate = 6;
|
||||
|
||||
/**
|
||||
* Mapping of model token sizes to their respective multipliers for prompt and completion.
|
||||
* @type {Object.<string, {prompt: number, completion: number}>}
|
||||
*/
|
||||
const tokenValues = {
|
||||
'8k': { prompt: 30, completion: 60 },
|
||||
'32k': { prompt: 60, completion: 120 },
|
||||
'4k': { prompt: 1.5, completion: 2 },
|
||||
'16k': { prompt: 3, completion: 4 },
|
||||
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
|
||||
'gpt-4-1106': { prompt: 10, completion: 30 },
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves the key associated with a given model name.
|
||||
*
|
||||
* @param {string} model - The model name to match.
|
||||
* @returns {string|undefined} The key corresponding to the model name, or undefined if no match is found.
|
||||
*/
|
||||
const getValueKey = (model) => {
|
||||
const modelName = matchModelName(model);
|
||||
if (!modelName) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (modelName.includes('gpt-3.5-turbo-16k')) {
|
||||
return '16k';
|
||||
} else if (modelName.includes('gpt-3.5-turbo-1106')) {
|
||||
return 'gpt-3.5-turbo-1106';
|
||||
} else if (modelName.includes('gpt-3.5')) {
|
||||
return '4k';
|
||||
} else if (modelName.includes('gpt-4-1106')) {
|
||||
return 'gpt-4-1106';
|
||||
} else if (modelName.includes('gpt-4-32k')) {
|
||||
return '32k';
|
||||
} else if (modelName.includes('gpt-4')) {
|
||||
return '8k';
|
||||
}
|
||||
|
||||
return undefined;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves the multiplier for a given value key and token type. If no value key is provided,
|
||||
* it attempts to derive it from the model name.
|
||||
*
|
||||
* @param {Object} params - The parameters for the function.
|
||||
* @param {string} [params.valueKey] - The key corresponding to the model name.
|
||||
* @param {string} [params.tokenType] - The type of token (e.g., 'prompt' or 'completion').
|
||||
* @param {string} [params.model] - The model name to derive the value key from if not provided.
|
||||
* @returns {number} The multiplier for the given parameters, or a default value if not found.
|
||||
*/
|
||||
const getMultiplier = ({ valueKey, tokenType, model }) => {
|
||||
if (valueKey && tokenType) {
|
||||
return tokenValues[valueKey][tokenType] ?? defaultRate;
|
||||
}
|
||||
|
||||
if (!tokenType || !model) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
valueKey = getValueKey(model);
|
||||
if (!valueKey) {
|
||||
return defaultRate;
|
||||
}
|
||||
|
||||
// If we got this far, and values[tokenType] is undefined somehow, return a rough average of default multipliers
|
||||
return tokenValues[valueKey][tokenType] ?? defaultRate;
|
||||
};
|
||||
|
||||
module.exports = { tokenValues, getValueKey, getMultiplier, defaultRate };
|
||||
94
api/models/tx.spec.js
Normal file
94
api/models/tx.spec.js
Normal file
@@ -0,0 +1,94 @@
|
||||
const { getValueKey, getMultiplier, defaultRate, tokenValues } = require('./tx');
|
||||
|
||||
describe('getValueKey', () => {
|
||||
it('should return "16k" for model name containing "gpt-3.5-turbo-16k"', () => {
|
||||
expect(getValueKey('gpt-3.5-turbo-16k-some-other-info')).toBe('16k');
|
||||
});
|
||||
|
||||
it('should return "4k" for model name containing "gpt-3.5"', () => {
|
||||
expect(getValueKey('gpt-3.5-some-other-info')).toBe('4k');
|
||||
});
|
||||
|
||||
it('should return "32k" for model name containing "gpt-4-32k"', () => {
|
||||
expect(getValueKey('gpt-4-32k-some-other-info')).toBe('32k');
|
||||
});
|
||||
|
||||
it('should return "8k" for model name containing "gpt-4"', () => {
|
||||
expect(getValueKey('gpt-4-some-other-info')).toBe('8k');
|
||||
});
|
||||
|
||||
it('should return undefined for model names that do not match any known patterns', () => {
|
||||
expect(getValueKey('gpt-5-some-other-info')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return "gpt-3.5-turbo-1106" for model name containing "gpt-3.5-turbo-1106"', () => {
|
||||
expect(getValueKey('gpt-3.5-turbo-1106-some-other-info')).toBe('gpt-3.5-turbo-1106');
|
||||
expect(getValueKey('openai/gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
|
||||
expect(getValueKey('gpt-3.5-turbo-1106/openai')).toBe('gpt-3.5-turbo-1106');
|
||||
});
|
||||
|
||||
it('should return "gpt-4-1106" for model name containing "gpt-4-1106"', () => {
|
||||
expect(getValueKey('gpt-4-1106-some-other-info')).toBe('gpt-4-1106');
|
||||
expect(getValueKey('gpt-4-1106-vision-preview')).toBe('gpt-4-1106');
|
||||
expect(getValueKey('gpt-4-1106-preview')).toBe('gpt-4-1106');
|
||||
expect(getValueKey('openai/gpt-4-1106')).toBe('gpt-4-1106');
|
||||
expect(getValueKey('gpt-4-1106/openai/')).toBe('gpt-4-1106');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMultiplier', () => {
|
||||
it('should return the correct multiplier for a given valueKey and tokenType', () => {
|
||||
expect(getMultiplier({ valueKey: '8k', tokenType: 'prompt' })).toBe(tokenValues['8k'].prompt);
|
||||
expect(getMultiplier({ valueKey: '8k', tokenType: 'completion' })).toBe(
|
||||
tokenValues['8k'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return defaultRate if tokenType is provided but not found in tokenValues', () => {
|
||||
expect(getMultiplier({ valueKey: '8k', tokenType: 'unknownType' })).toBe(defaultRate);
|
||||
});
|
||||
|
||||
it('should derive the valueKey from the model if not provided', () => {
|
||||
expect(getMultiplier({ tokenType: 'prompt', model: 'gpt-4-some-other-info' })).toBe(
|
||||
tokenValues['8k'].prompt,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return 1 if only model or tokenType is missing', () => {
|
||||
expect(getMultiplier({ tokenType: 'prompt' })).toBe(1);
|
||||
expect(getMultiplier({ model: 'gpt-4-some-other-info' })).toBe(1);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-3.5-turbo-1106', () => {
|
||||
expect(getMultiplier({ valueKey: 'gpt-3.5-turbo-1106', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-3.5-turbo-1106'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ valueKey: 'gpt-3.5-turbo-1106', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-3.5-turbo-1106'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4-1106', () => {
|
||||
expect(getMultiplier({ valueKey: 'gpt-4-1106', tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-4-1106'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ valueKey: 'gpt-4-1106', tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-4-1106'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should derive the valueKey from the model if not provided for new models', () => {
|
||||
expect(
|
||||
getMultiplier({ tokenType: 'prompt', model: 'gpt-3.5-turbo-1106-some-other-info' }),
|
||||
).toBe(tokenValues['gpt-3.5-turbo-1106'].prompt);
|
||||
expect(getMultiplier({ tokenType: 'completion', model: 'gpt-4-1106-vision-preview' })).toBe(
|
||||
tokenValues['gpt-4-1106'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return defaultRate if derived valueKey does not match any known patterns', () => {
|
||||
expect(getMultiplier({ tokenType: 'prompt', model: 'gpt-5-some-other-info' })).toBe(
|
||||
defaultRate,
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "0.5.9",
|
||||
"version": "0.6.1",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
@@ -24,11 +24,12 @@
|
||||
"@anthropic-ai/sdk": "^0.5.4",
|
||||
"@azure/search-documents": "^11.3.2",
|
||||
"@keyv/mongo": "^2.1.8",
|
||||
"@waylaidwanderer/chatgpt-api": "^1.37.2",
|
||||
"@keyv/redis": "^2.8.0",
|
||||
"axios": "^1.3.4",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"cheerio": "^1.0.0-rc.12",
|
||||
"cohere-ai": "^5.0.2",
|
||||
"cohere-ai": "^6.0.0",
|
||||
"connect-redis": "^7.1.0",
|
||||
"cookie": "^0.5.0",
|
||||
"cors": "^2.8.5",
|
||||
"dotenv": "^16.0.3",
|
||||
@@ -36,21 +37,25 @@
|
||||
"express-mongo-sanitize": "^2.2.0",
|
||||
"express-rate-limit": "^6.9.0",
|
||||
"express-session": "^1.17.3",
|
||||
"googleapis": "^118.0.0",
|
||||
"googleapis": "^126.0.1",
|
||||
"handlebars": "^4.7.7",
|
||||
"html": "^1.0.0",
|
||||
"ioredis": "^5.3.2",
|
||||
"js-yaml": "^4.1.0",
|
||||
"jsonwebtoken": "^9.0.0",
|
||||
"keyv": "^4.5.3",
|
||||
"keyv": "^4.5.4",
|
||||
"keyv-file": "^0.2.0",
|
||||
"langchain": "^0.0.144",
|
||||
"langchain": "^0.0.186",
|
||||
"lodash": "^4.17.21",
|
||||
"meilisearch": "^0.33.0",
|
||||
"mongoose": "^7.1.1",
|
||||
"nodejs-gpt": "^1.37.4",
|
||||
"nodemailer": "^6.9.4",
|
||||
"openai": "^3.2.1",
|
||||
"openai": "^4.16.1",
|
||||
"openai-chat-tokens": "^0.2.8",
|
||||
"openid-client": "^5.4.2",
|
||||
"passport": "^0.6.0",
|
||||
"passport-custom": "^1.1.1",
|
||||
"passport-discord": "^0.1.4",
|
||||
"passport-facebook": "^3.0.0",
|
||||
"passport-github2": "^0.1.12",
|
||||
@@ -58,10 +63,11 @@
|
||||
"passport-jwt": "^4.0.1",
|
||||
"passport-local": "^1.0.0",
|
||||
"pino": "^8.12.1",
|
||||
"sharp": "^0.32.5",
|
||||
"sharp": "^0.32.6",
|
||||
"tiktoken": "^1.0.10",
|
||||
"ua-parser-js": "^1.0.36",
|
||||
"zod": "^3.22.2"
|
||||
"winston": "^3.10.0",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
"jest": "^29.5.0",
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
const crypto = require('crypto');
|
||||
const cookies = require('cookie');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const { Session, User } = require('../../models');
|
||||
const {
|
||||
registerUser,
|
||||
requestPasswordReset,
|
||||
resetPassword,
|
||||
setAuthTokens,
|
||||
} = require('../services/AuthService');
|
||||
const jwt = require('jsonwebtoken');
|
||||
const Session = require('../../models/Session');
|
||||
const User = require('../../models/User');
|
||||
const crypto = require('crypto');
|
||||
const cookies = require('cookie');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
@@ -76,14 +75,15 @@ const refreshController = async (req, res) => {
|
||||
}
|
||||
|
||||
try {
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
let payload;
|
||||
payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
const userId = payload.id;
|
||||
const user = await User.findOne({ _id: userId });
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
|
||||
if (process.env.NODE_ENV === 'development') {
|
||||
if (process.env.NODE_ENV === 'CI') {
|
||||
const token = await setAuthTokens(userId, res);
|
||||
const userObj = user.toJSON();
|
||||
return res.status(200).send({ token, user: userObj });
|
||||
@@ -99,13 +99,18 @@ const refreshController = async (req, res) => {
|
||||
const token = await setAuthTokens(userId, res, session._id);
|
||||
const userObj = user.toJSON();
|
||||
res.status(200).send({ token, user: userObj });
|
||||
} else if (payload.exp > Date.now() / 1000) {
|
||||
} else if (req?.query?.retry) {
|
||||
// Retrying from a refresh token request that failed (401)
|
||||
res.status(403).send('No session found');
|
||||
} else if (payload.exp < Date.now() / 1000) {
|
||||
res.status(403).redirect('/login');
|
||||
} else {
|
||||
res.status(401).send('Refresh token expired or not found for this user');
|
||||
}
|
||||
} catch (err) {
|
||||
res.status(401).send('Invalid refresh token');
|
||||
console.error('Refresh token error', refreshToken);
|
||||
console.error(err);
|
||||
res.status(403).send('Invalid refresh token');
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
9
api/server/controllers/Balance.js
Normal file
9
api/server/controllers/Balance.js
Normal file
@@ -0,0 +1,9 @@
|
||||
const Balance = require('../../models/Balance');
|
||||
|
||||
async function balanceController(req, res) {
|
||||
const { tokenCredits: balance = '' } =
|
||||
(await Balance.findOne({ user: req.user.id }, 'tokenCredits').lean()) ?? {};
|
||||
res.status(200).send('' + balance);
|
||||
}
|
||||
|
||||
module.exports = balanceController;
|
||||
@@ -60,6 +60,7 @@ const startServer = async () => {
|
||||
app.use('/api/prompts', routes.prompts);
|
||||
app.use('/api/tokenizer', routes.tokenizer);
|
||||
app.use('/api/endpoints', routes.endpoints);
|
||||
app.use('/api/balance', routes.balance);
|
||||
app.use('/api/models', routes.models);
|
||||
app.use('/api/plugins', routes.plugins);
|
||||
app.use('/api/config', routes.config);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const { sendMessage, sendError, countTokens, isEnabled } = require('../utils');
|
||||
const { saveMessage, getConvo, getConvoTitle } = require('../../models');
|
||||
const { sendMessage, sendError } = require('../utils');
|
||||
const clearPendingReq = require('../../cache/clearPendingReq');
|
||||
const spendTokens = require('../../models/spendTokens');
|
||||
const abortControllers = require('./abortControllers');
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
@@ -19,6 +21,9 @@ async function abortMessage(req, res) {
|
||||
const handleAbort = () => {
|
||||
return async (req, res) => {
|
||||
try {
|
||||
if (isEnabled(process.env.LIMIT_CONCURRENT_MESSAGES)) {
|
||||
await clearPendingReq({ userId: req.user.id });
|
||||
}
|
||||
return await abortMessage(req, res);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
@@ -41,7 +46,9 @@ const createAbortController = (req, res, getAbortData) => {
|
||||
|
||||
abortController.abortCompletion = async function () {
|
||||
abortController.abort();
|
||||
const { conversationId, userMessage, ...responseData } = getAbortData();
|
||||
const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData();
|
||||
const completionTokens = await countTokens(responseData?.text ?? '');
|
||||
const user = req.user.id;
|
||||
|
||||
const responseMessage = {
|
||||
...responseData,
|
||||
@@ -52,14 +59,20 @@ const createAbortController = (req, res, getAbortData) => {
|
||||
cancelled: true,
|
||||
error: false,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: completionTokens,
|
||||
};
|
||||
|
||||
saveMessage({ ...responseMessage, user: req.user.id });
|
||||
await spendTokens(
|
||||
{ ...responseMessage, context: 'incomplete', user },
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
saveMessage({ ...responseMessage, user });
|
||||
|
||||
return {
|
||||
title: await getConvoTitle(req.user.id, conversationId),
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
final: true,
|
||||
conversation: await getConvo(req.user.id, conversationId),
|
||||
conversation: await getConvo(user, conversationId),
|
||||
requestMessage: userMessage,
|
||||
responseMessage: responseMessage,
|
||||
};
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
const Keyv = require('keyv');
|
||||
const uap = require('ua-parser-js');
|
||||
const { getLogStores } = require('../../cache');
|
||||
const denyRequest = require('./denyRequest');
|
||||
const { getLogStores } = require('../../cache');
|
||||
const { isEnabled, removePorts } = require('../utils');
|
||||
const keyvRedis = require('../../cache/keyvRedis');
|
||||
const User = require('../../models/User');
|
||||
|
||||
const banCache = new Keyv({ namespace: 'bans', ttl: 0 });
|
||||
const banCache = isEnabled(process.env.USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: 'bans', ttl: 0 });
|
||||
const message = 'Your account has been temporarily banned due to violations of our service.';
|
||||
|
||||
/**
|
||||
@@ -49,10 +53,33 @@ const checkBan = async (req, res, next = () => {}) => {
|
||||
}
|
||||
|
||||
req.ip = removePorts(req);
|
||||
const userId = req.user?.id ?? req.user?._id ?? null;
|
||||
let userId = req.user?.id ?? req.user?._id ?? null;
|
||||
|
||||
if (!userId && req?.body?.email) {
|
||||
const user = await User.findOne({ email: req.body.email }, '_id').lean();
|
||||
userId = user?._id ? user._id.toString() : userId;
|
||||
}
|
||||
|
||||
if (!userId && !req.ip) {
|
||||
return next();
|
||||
}
|
||||
|
||||
let cachedIPBan;
|
||||
let cachedUserBan;
|
||||
|
||||
let ipKey = '';
|
||||
let userKey = '';
|
||||
|
||||
if (req.ip) {
|
||||
ipKey = isEnabled(process.env.USE_REDIS) ? `ban_cache:ip:${req.ip}` : req.ip;
|
||||
cachedIPBan = await banCache.get(ipKey);
|
||||
}
|
||||
|
||||
if (userId) {
|
||||
userKey = isEnabled(process.env.USE_REDIS) ? `ban_cache:user:${userId}` : userId;
|
||||
cachedUserBan = await banCache.get(userKey);
|
||||
}
|
||||
|
||||
const cachedIPBan = await banCache.get(req.ip);
|
||||
const cachedUserBan = await banCache.get(userId);
|
||||
const cachedBan = cachedIPBan || cachedUserBan;
|
||||
|
||||
if (cachedBan) {
|
||||
@@ -67,9 +94,18 @@ const checkBan = async (req, res, next = () => {}) => {
|
||||
return next();
|
||||
}
|
||||
|
||||
const ipBan = await banLogs.get(req.ip);
|
||||
const userBan = await banLogs.get(userId);
|
||||
const isBanned = ipBan || userBan;
|
||||
let ipBan;
|
||||
let userBan;
|
||||
|
||||
if (req.ip) {
|
||||
ipBan = await banLogs.get(req.ip);
|
||||
}
|
||||
|
||||
if (userId) {
|
||||
userBan = await banLogs.get(userId);
|
||||
}
|
||||
|
||||
const isBanned = !!(ipBan || userBan);
|
||||
|
||||
if (!isBanned) {
|
||||
return next();
|
||||
@@ -77,14 +113,23 @@ const checkBan = async (req, res, next = () => {}) => {
|
||||
|
||||
const timeLeft = Number(isBanned.expiresAt) - Date.now();
|
||||
|
||||
if (timeLeft <= 0) {
|
||||
await banLogs.delete(req.ip);
|
||||
await banLogs.delete(userId);
|
||||
if (timeLeft <= 0 && ipKey) {
|
||||
await banLogs.delete(ipKey);
|
||||
}
|
||||
|
||||
if (timeLeft <= 0 && userKey) {
|
||||
await banLogs.delete(userKey);
|
||||
return next();
|
||||
}
|
||||
|
||||
banCache.set(req.ip, isBanned, timeLeft);
|
||||
banCache.set(userId, isBanned, timeLeft);
|
||||
if (ipKey) {
|
||||
banCache.set(ipKey, isBanned, timeLeft);
|
||||
}
|
||||
|
||||
if (userKey) {
|
||||
banCache.set(userKey, isBanned, timeLeft);
|
||||
}
|
||||
|
||||
req.banned = true;
|
||||
return await banResponse(req, res);
|
||||
};
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
const Keyv = require('keyv');
|
||||
const { logViolation } = require('../../cache');
|
||||
|
||||
const clearPendingReq = require('../../cache/clearPendingReq');
|
||||
const { logViolation, getLogStores } = require('../../cache');
|
||||
const denyRequest = require('./denyRequest');
|
||||
|
||||
// Serve cache from memory so no need to clear it on startup/exit
|
||||
const pendingReqCache = new Keyv({ namespace: 'pendingRequests' });
|
||||
const {
|
||||
USE_REDIS,
|
||||
CONCURRENT_MESSAGE_MAX = 1,
|
||||
CONCURRENT_VIOLATION_SCORE: score,
|
||||
} = process.env ?? {};
|
||||
const ttl = 1000 * 60 * 1;
|
||||
|
||||
/**
|
||||
* Middleware to limit concurrent requests for a user.
|
||||
@@ -12,7 +15,7 @@ const pendingReqCache = new Keyv({ namespace: 'pendingRequests' });
|
||||
* This middleware checks if a user has exceeded a specified concurrent request limit.
|
||||
* If the user exceeds the limit, an error is returned. If the user is within the limit,
|
||||
* their request count is incremented. After the request is processed, the count is decremented.
|
||||
* If the `pendingReqCache` store is not available, the middleware will skip its logic.
|
||||
* If the `cache` store is not available, the middleware will skip its logic.
|
||||
*
|
||||
* @function
|
||||
* @param {Object} req - Express request object containing user information.
|
||||
@@ -21,7 +24,9 @@ const pendingReqCache = new Keyv({ namespace: 'pendingRequests' });
|
||||
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
|
||||
*/
|
||||
const concurrentLimiter = async (req, res, next) => {
|
||||
if (!pendingReqCache) {
|
||||
const namespace = 'pending_req';
|
||||
const cache = getLogStores(namespace);
|
||||
if (!cache) {
|
||||
return next();
|
||||
}
|
||||
|
||||
@@ -29,12 +34,12 @@ const concurrentLimiter = async (req, res, next) => {
|
||||
return next();
|
||||
}
|
||||
|
||||
const { CONCURRENT_MESSAGE_MAX = 1, CONCURRENT_VIOLATION_SCORE: score } = process.env;
|
||||
const userId = req.user?.id ?? req.user?._id ?? '';
|
||||
const limit = Math.max(CONCURRENT_MESSAGE_MAX, 1);
|
||||
const type = 'concurrent';
|
||||
|
||||
const userId = req.user?.id ?? req.user?._id ?? null;
|
||||
const pendingRequests = (await pendingReqCache.get(userId)) ?? 0;
|
||||
const key = `${USE_REDIS ? namespace : ''}:${userId}`;
|
||||
const pendingRequests = +((await cache.get(key)) ?? 0);
|
||||
|
||||
if (pendingRequests >= limit) {
|
||||
const errorMessage = {
|
||||
@@ -46,22 +51,17 @@ const concurrentLimiter = async (req, res, next) => {
|
||||
await logViolation(req, res, type, errorMessage, score);
|
||||
return await denyRequest(req, res, errorMessage);
|
||||
} else {
|
||||
await pendingReqCache.set(userId, pendingRequests + 1);
|
||||
await cache.set(key, pendingRequests + 1, ttl);
|
||||
}
|
||||
|
||||
// Ensure the requests are removed from the store once the request is done
|
||||
let cleared = false;
|
||||
const cleanUp = async () => {
|
||||
if (!pendingReqCache) {
|
||||
if (cleared) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentRequests = await pendingReqCache.get(userId);
|
||||
|
||||
if (currentRequests && currentRequests >= 1) {
|
||||
await pendingReqCache.set(userId, currentRequests - 1);
|
||||
} else {
|
||||
await pendingReqCache.delete(userId);
|
||||
}
|
||||
cleared = true;
|
||||
await clearPendingReq({ userId, cache });
|
||||
};
|
||||
|
||||
if (pendingRequests < limit) {
|
||||
@@ -72,10 +72,4 @@ const concurrentLimiter = async (req, res, next) => {
|
||||
next();
|
||||
};
|
||||
|
||||
// if cache is not served from memory, clear it on exit
|
||||
// process.on('exit', async () => {
|
||||
// console.log('Clearing all pending requests before exiting...');
|
||||
// await pendingReqCache.clear();
|
||||
// });
|
||||
|
||||
module.exports = concurrentLimiter;
|
||||
|
||||
@@ -26,18 +26,26 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
console.log('ask log');
|
||||
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
let lastSavedTimestamp = 0;
|
||||
let saveDelay = 100;
|
||||
const sender = getResponseSender(endpointOption);
|
||||
const user = req.user.id;
|
||||
|
||||
const getIds = (data) => {
|
||||
userMessage = data.userMessage;
|
||||
userMessageId = data.userMessage.messageId;
|
||||
responseMessageId = data.responseMessageId;
|
||||
if (!conversationId) {
|
||||
conversationId = data.conversationId;
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -49,7 +57,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender: getResponseSender(endpointOption),
|
||||
sender,
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: partialText,
|
||||
@@ -69,18 +77,19 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
const getAbortData = () => ({
|
||||
conversationId,
|
||||
messageId: responseMessageId,
|
||||
sender: getResponseSender(endpointOption),
|
||||
sender,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
|
||||
const { client } = await initializeClient(req, endpointOption);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
getIds,
|
||||
getReqData,
|
||||
// debug: true,
|
||||
user,
|
||||
conversationId,
|
||||
@@ -123,7 +132,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
handleAbortError(res, req, error, {
|
||||
partialText,
|
||||
conversationId,
|
||||
sender: getResponseSender(endpointOption),
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
|
||||
@@ -210,7 +210,7 @@ const ask = async ({
|
||||
conversationUpdate.jailbreakConversationId = response.jailbreakConversationId;
|
||||
} else {
|
||||
conversationUpdate.jailbreak = false;
|
||||
conversationUpdate.conversationSignature = response.conversationSignature;
|
||||
conversationUpdate.conversationSignature = response.encryptedConversationSignature;
|
||||
conversationUpdate.clientId = response.clientId;
|
||||
conversationUpdate.invocationId = response.invocationId;
|
||||
}
|
||||
|
||||
@@ -52,18 +52,25 @@ router.post('/', setHeaders, async (req, res) => {
|
||||
const ask = async ({ text, endpointOption, parentMessageId = null, conversationId, req, res }) => {
|
||||
let userMessage;
|
||||
let userMessageId;
|
||||
// let promptTokens;
|
||||
let responseMessageId;
|
||||
let lastSavedTimestamp = 0;
|
||||
const { overrideParentMessageId = null } = req.body;
|
||||
const user = req.user.id;
|
||||
|
||||
try {
|
||||
const getIds = (data) => {
|
||||
userMessage = data.userMessage;
|
||||
userMessageId = userMessage.messageId;
|
||||
responseMessageId = data.responseMessageId;
|
||||
if (!conversationId) {
|
||||
conversationId = data.conversationId;
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
// } else if (key === 'promptTokens') {
|
||||
// promptTokens = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
}
|
||||
}
|
||||
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
@@ -121,7 +128,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI
|
||||
const client = new GoogleClient(key, clientOptions);
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
getIds,
|
||||
getReqData,
|
||||
user,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
|
||||
@@ -29,22 +29,30 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||
let metadata;
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
let lastSavedTimestamp = 0;
|
||||
let saveDelay = 100;
|
||||
const sender = getResponseSender(endpointOption);
|
||||
const newConvo = !conversationId;
|
||||
const user = req.user.id;
|
||||
|
||||
const plugins = [];
|
||||
|
||||
const addMetadata = (data) => (metadata = data);
|
||||
const getIds = (data) => {
|
||||
userMessage = data.userMessage;
|
||||
userMessageId = userMessage.messageId;
|
||||
responseMessageId = data.responseMessageId;
|
||||
if (!conversationId) {
|
||||
conversationId = data.conversationId;
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -67,7 +75,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender: getResponseSender(endpointOption),
|
||||
sender,
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
text: partialText,
|
||||
@@ -135,26 +143,27 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
};
|
||||
|
||||
const getAbortData = () => ({
|
||||
sender: getResponseSender(endpointOption),
|
||||
sender,
|
||||
conversationId,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
plugins: plugins.map((p) => ({ ...p, loading: false })),
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
|
||||
try {
|
||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||
const { client } = await initializeClient(req, endpointOption);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
overrideParentMessageId,
|
||||
getIds,
|
||||
getReqData,
|
||||
onAgentAction,
|
||||
onChainEnd,
|
||||
onToolStart,
|
||||
@@ -194,7 +203,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
});
|
||||
res.end();
|
||||
|
||||
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
|
||||
if (parentMessageId === '00000000-0000-0000-0000-000000000000' && newConvo) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response,
|
||||
@@ -206,7 +215,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
handleAbortError(res, req, error, {
|
||||
partialText,
|
||||
conversationId,
|
||||
sender: getResponseSender(endpointOption),
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
|
||||
@@ -27,21 +27,29 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
console.dir({ text, conversationId, endpointOption }, { depth: null });
|
||||
let metadata;
|
||||
let userMessage;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
let lastSavedTimestamp = 0;
|
||||
let saveDelay = 100;
|
||||
const sender = getResponseSender(endpointOption);
|
||||
const newConvo = !conversationId;
|
||||
const user = req.user.id;
|
||||
|
||||
const addMetadata = (data) => (metadata = data);
|
||||
|
||||
const getIds = (data) => {
|
||||
userMessage = data.userMessage;
|
||||
userMessageId = userMessage.messageId;
|
||||
responseMessageId = data.responseMessageId;
|
||||
if (!conversationId) {
|
||||
conversationId = data.conversationId;
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -53,7 +61,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender: getResponseSender(endpointOption),
|
||||
sender,
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: partialText,
|
||||
@@ -72,25 +80,26 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
});
|
||||
|
||||
const getAbortData = () => ({
|
||||
sender: getResponseSender(endpointOption),
|
||||
sender,
|
||||
conversationId,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
|
||||
try {
|
||||
const { client } = await initializeClient(req, endpointOption);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
parentMessageId,
|
||||
conversationId,
|
||||
overrideParentMessageId,
|
||||
getIds,
|
||||
getReqData,
|
||||
onStart,
|
||||
addMetadata,
|
||||
abortController,
|
||||
@@ -109,11 +118,6 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
response = { ...response, ...metadata };
|
||||
}
|
||||
|
||||
console.log(
|
||||
'promptTokens, completionTokens:',
|
||||
response.promptTokens,
|
||||
response.completionTokens,
|
||||
);
|
||||
await saveMessage({ ...response, user });
|
||||
|
||||
sendMessage(res, {
|
||||
@@ -125,7 +129,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
});
|
||||
res.end();
|
||||
|
||||
if (parentMessageId == '00000000-0000-0000-0000-000000000000' && newConvo) {
|
||||
if (parentMessageId === '00000000-0000-0000-0000-000000000000' && newConvo) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response,
|
||||
@@ -137,7 +141,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
|
||||
handleAbortError(res, req, error, {
|
||||
partialText,
|
||||
conversationId,
|
||||
sender: getResponseSender(endpointOption),
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
|
||||
8
api/server/routes/balance.js
Normal file
8
api/server/routes/balance.js
Normal file
@@ -0,0 +1,8 @@
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const controller = require('../controllers/Balance');
|
||||
const { requireJwtAuth } = require('../middleware/');
|
||||
|
||||
router.get('/', requireJwtAuth, controller);
|
||||
|
||||
module.exports = router;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user