Compare commits
90 Commits
v0.7.3
...
re-add-dow
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c27e8566fb | ||
|
|
4ffdefc2a8 | ||
|
|
2d5f704695 | ||
|
|
3fd25920d4 | ||
|
|
d03f8285db | ||
|
|
e565e0faab | ||
|
|
d4d56281e3 | ||
|
|
14eff23b57 | ||
|
|
18fd8f1416 | ||
|
|
ba9cb71245 | ||
|
|
e5dfa06e6c | ||
|
|
0bd59c0efe | ||
|
|
344297021f | ||
|
|
620973436c | ||
|
|
422d1a2c91 | ||
|
|
2ad097647c | ||
|
|
9e7615f832 | ||
|
|
ee4dd1b2e9 | ||
|
|
f6125ccd59 | ||
|
|
1acd47a0f6 | ||
|
|
511f1336da | ||
|
|
5d40d0a37a | ||
|
|
1c282d1517 | ||
|
|
73dbf3eb20 | ||
|
|
237a0de8b6 | ||
|
|
d5782ac66c | ||
|
|
d5d188eebf | ||
|
|
25ea3b8e98 | ||
|
|
b1ec67ea42 | ||
|
|
f1bb5fa4c5 | ||
|
|
44d8596872 | ||
|
|
32d84c85ea | ||
|
|
0a1d38e318 | ||
|
|
5ef71a7a36 | ||
|
|
326069d7a6 | ||
|
|
785430daf5 | ||
|
|
03fe361917 | ||
|
|
b34a4ddac1 | ||
|
|
7d5b03dd98 | ||
|
|
f959ee302c | ||
|
|
cd00df69bb | ||
|
|
a05e2c1dcc | ||
|
|
87bdbda10a | ||
|
|
605a8ae8c9 | ||
|
|
a724635998 | ||
|
|
6c306a662c | ||
|
|
55f8d9910e | ||
|
|
7edb54889b | ||
|
|
71d9e841b1 | ||
|
|
e76777d298 | ||
|
|
1edbfdbce2 | ||
|
|
1aad315de6 | ||
|
|
5d985746cb | ||
|
|
04654014b2 | ||
|
|
456793772b | ||
|
|
a87d4e0b75 | ||
|
|
a2fd975cd5 | ||
|
|
83619de158 | ||
|
|
b8f2bee3fc | ||
|
|
81292bb4dd | ||
|
|
ed5ee1f86f | ||
|
|
791b0139bc | ||
|
|
156c52e293 | ||
|
|
eef894e608 | ||
|
|
e2867eecc9 | ||
|
|
dd563e0796 | ||
|
|
c99cf1b4b1 | ||
|
|
b5081bfe86 | ||
|
|
aac01df80c | ||
|
|
24467dd626 | ||
|
|
b2b469bd3d | ||
|
|
cec2e57ee9 | ||
|
|
a8c874267f | ||
|
|
a53312bbd4 | ||
|
|
ab74685476 | ||
|
|
015215b790 | ||
|
|
4e4de88faa | ||
|
|
3172381bad | ||
|
|
54b1095239 | ||
|
|
0424f8fe55 | ||
|
|
4319c62e66 | ||
|
|
d3a0b862db | ||
|
|
5d8793c5d1 | ||
|
|
54db67449a | ||
|
|
0cd3c83328 | ||
|
|
d839e4661c | ||
|
|
302b28fc9b | ||
|
|
dad25bd297 | ||
|
|
a338decf90 | ||
|
|
2cf5228021 |
15
.env.example
15
.env.example
@@ -80,7 +80,7 @@ PROXY=
|
||||
#============#
|
||||
|
||||
ANTHROPIC_API_KEY=user_provided
|
||||
# ANTHROPIC_MODELS=claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
|
||||
# ANTHROPIC_MODELS=claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
|
||||
# ANTHROPIC_REVERSE_PROXY=
|
||||
|
||||
#============#
|
||||
@@ -123,6 +123,8 @@ GOOGLE_KEY=user_provided
|
||||
# Vertex AI
|
||||
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
|
||||
|
||||
# GOOGLE_TITLE_MODEL=gemini-pro
|
||||
|
||||
# Google Gemini Safety Settings
|
||||
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
|
||||
# To use this restricted HarmBlockThreshold setting, you will need to either:
|
||||
@@ -142,7 +144,7 @@ GOOGLE_KEY=user_provided
|
||||
#============#
|
||||
|
||||
OPENAI_API_KEY=user_provided
|
||||
# OPENAI_MODELS=gpt-4o,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k
|
||||
# OPENAI_MODELS=gpt-4o,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k
|
||||
|
||||
DEBUG_OPENAI=false
|
||||
|
||||
@@ -164,7 +166,7 @@ DEBUG_OPENAI=false
|
||||
|
||||
ASSISTANTS_API_KEY=user_provided
|
||||
# ASSISTANTS_BASE_URL=
|
||||
# ASSISTANTS_MODELS=gpt-4o,gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview
|
||||
# ASSISTANTS_MODELS=gpt-4o,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview
|
||||
|
||||
#==========================#
|
||||
# Azure Assistants API #
|
||||
@@ -186,7 +188,7 @@ ASSISTANTS_API_KEY=user_provided
|
||||
# Plugins #
|
||||
#============#
|
||||
|
||||
# PLUGIN_MODELS=gpt-4o,gpt-4,gpt-4-turbo-preview,gpt-4-0125-preview,gpt-4-1106-preview,gpt-4-0613,gpt-3.5-turbo,gpt-3.5-turbo-0125,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613
|
||||
# PLUGIN_MODELS=gpt-4o,gpt-4o-mini,gpt-4,gpt-4-turbo-preview,gpt-4-0125-preview,gpt-4-1106-preview,gpt-4-0613,gpt-3.5-turbo,gpt-3.5-turbo-0125,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613
|
||||
|
||||
DEBUG_PLUGINS=true
|
||||
|
||||
@@ -372,6 +374,11 @@ LDAP_BIND_CREDENTIALS=
|
||||
LDAP_USER_SEARCH_BASE=
|
||||
LDAP_SEARCH_FILTER=mail={{username}}
|
||||
LDAP_CA_CERT_PATH=
|
||||
# LDAP_TLS_REJECT_UNAUTHORIZED=
|
||||
# LDAP_LOGIN_USES_USERNAME=true
|
||||
# LDAP_ID=
|
||||
# LDAP_USERNAME=
|
||||
# LDAP_FULL_NAME=
|
||||
|
||||
#========================#
|
||||
# Email Password Reset #
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -11,6 +11,7 @@ logs
|
||||
pids
|
||||
*.pid
|
||||
*.seed
|
||||
.git
|
||||
|
||||
# Directory for instrumented libs generated by jscoverage/JSCover
|
||||
lib-cov
|
||||
@@ -45,6 +46,7 @@ api/node_modules/
|
||||
client/node_modules/
|
||||
bower_components/
|
||||
*.d.ts
|
||||
!vite-env.d.ts
|
||||
|
||||
# Floobits
|
||||
.floo
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://railway.app/template/b5k2mn?referralCode=myKrVZ">
|
||||
<a href="https://railway.app/template/b5k2mn?referralCode=HI9hWz">
|
||||
<img src="https://railway.app/button.svg" alt="Deploy on Railway" height="30">
|
||||
</a>
|
||||
<a href="https://zeabur.com/templates/0X2ZY8">
|
||||
@@ -50,7 +50,7 @@
|
||||
- 🔄 Edit, Resubmit, and Continue Messages with Conversation branching
|
||||
- 🌿 Fork Messages & Conversations for Advanced Context control
|
||||
- 💬 Multimodal Chat:
|
||||
- Upload and analyze images with Claude 3, GPT-4 (including `gpt-4o`), and Gemini Vision 📸
|
||||
- Upload and analyze images with Claude 3, GPT-4 (including `gpt-4o` and `gpt-4o-mini`), and Gemini Vision 📸
|
||||
- Chat with Files using Custom Endpoints, OpenAI, Azure, Anthropic, & Google. 🗃️
|
||||
- Advanced Agents with Files, Code Interpreter, Tools, and API Actions 🔦
|
||||
- Available through the [OpenAI Assistants API](https://platform.openai.com/docs/assistants/overview) 🌤️
|
||||
@@ -81,7 +81,7 @@ LibreChat brings together the future of assistant AIs with the revolutionary tec
|
||||
|
||||
With LibreChat, you no longer need to opt for ChatGPT Plus and can instead use free or pay-per-call APIs. We welcome contributions, cloning, and forking to enhance the capabilities of this advanced chatbot platform.
|
||||
|
||||
[](https://www.youtube.com/watch?v=YLVUW5UP9N0)
|
||||
[](https://www.youtube.com/watch?v=bSVHEbVPNl4)
|
||||
Click on the thumbnail to open the video☝️
|
||||
|
||||
---
|
||||
|
||||
@@ -2,8 +2,10 @@ const Anthropic = require('@anthropic-ai/sdk');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const {
|
||||
getResponseSender,
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
anthropicSettings,
|
||||
getResponseSender,
|
||||
validateVisionModel,
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
@@ -16,6 +18,7 @@ const {
|
||||
} = require('./prompts');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
@@ -29,6 +32,8 @@ function delayBeforeRetry(attempts, baseDelay = 1000) {
|
||||
return new Promise((resolve) => setTimeout(resolve, baseDelay * attempts));
|
||||
}
|
||||
|
||||
const { legacy } = anthropicSettings;
|
||||
|
||||
class AnthropicClient extends BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
super(apiKey, options);
|
||||
@@ -61,15 +66,20 @@ class AnthropicClient extends BaseClient {
|
||||
const modelOptions = this.options.modelOptions || {};
|
||||
this.modelOptions = {
|
||||
...modelOptions,
|
||||
// set some good defaults (check for undefined in some cases because they may be 0)
|
||||
model: modelOptions.model || 'claude-1',
|
||||
temperature: typeof modelOptions.temperature === 'undefined' ? 1 : modelOptions.temperature, // 0 - 1, 1 is default
|
||||
topP: typeof modelOptions.topP === 'undefined' ? 0.7 : modelOptions.topP, // 0 - 1, default: 0.7
|
||||
topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
|
||||
stop: modelOptions.stop, // no stop method for now
|
||||
model: modelOptions.model || anthropicSettings.model.default,
|
||||
};
|
||||
|
||||
this.isClaude3 = this.modelOptions.model.includes('claude-3');
|
||||
this.isLegacyOutput = !this.modelOptions.model.includes('claude-3-5-sonnet');
|
||||
|
||||
if (
|
||||
this.isLegacyOutput &&
|
||||
this.modelOptions.maxOutputTokens &&
|
||||
this.modelOptions.maxOutputTokens > legacy.maxOutputTokens.default
|
||||
) {
|
||||
this.modelOptions.maxOutputTokens = legacy.maxOutputTokens.default;
|
||||
}
|
||||
|
||||
this.useMessages = this.isClaude3 || !!this.options.attachments;
|
||||
|
||||
this.defaultVisionModel = this.options.visionModel ?? 'claude-3-sonnet-20240229';
|
||||
@@ -119,10 +129,11 @@ class AnthropicClient extends BaseClient {
|
||||
|
||||
/**
|
||||
* Get the initialized Anthropic client.
|
||||
* @param {Partial<Anthropic.ClientOptions>} requestOptions - The options for the client.
|
||||
* @returns {Anthropic} The Anthropic client instance.
|
||||
*/
|
||||
getClient() {
|
||||
/** @type {Anthropic.default.RequestOptions} */
|
||||
getClient(requestOptions) {
|
||||
/** @type {Anthropic.ClientOptions} */
|
||||
const options = {
|
||||
fetch: this.fetch,
|
||||
apiKey: this.apiKey,
|
||||
@@ -136,6 +147,12 @@ class AnthropicClient extends BaseClient {
|
||||
options.baseURL = this.options.reverseProxyUrl;
|
||||
}
|
||||
|
||||
if (requestOptions?.model && requestOptions.model.includes('claude-3-5-sonnet')) {
|
||||
options.defaultHeaders = {
|
||||
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
|
||||
};
|
||||
}
|
||||
|
||||
return new Anthropic(options);
|
||||
}
|
||||
|
||||
@@ -556,8 +573,6 @@ class AnthropicClient extends BaseClient {
|
||||
}
|
||||
|
||||
logger.debug('modelOptions', { modelOptions });
|
||||
|
||||
const client = this.getClient();
|
||||
const metadata = {
|
||||
user_id: this.user,
|
||||
};
|
||||
@@ -585,7 +600,7 @@ class AnthropicClient extends BaseClient {
|
||||
|
||||
if (this.useMessages) {
|
||||
requestOptions.messages = payload;
|
||||
requestOptions.max_tokens = maxOutputTokens || 1500;
|
||||
requestOptions.max_tokens = maxOutputTokens || legacy.maxOutputTokens.default;
|
||||
} else {
|
||||
requestOptions.prompt = payload;
|
||||
requestOptions.max_tokens_to_sample = maxOutputTokens || 1500;
|
||||
@@ -605,12 +620,14 @@ class AnthropicClient extends BaseClient {
|
||||
};
|
||||
|
||||
const maxRetries = 3;
|
||||
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
async function processResponse() {
|
||||
let attempts = 0;
|
||||
|
||||
while (attempts < maxRetries) {
|
||||
let response;
|
||||
try {
|
||||
const client = this.getClient(requestOptions);
|
||||
response = await this.createResponse(client, requestOptions);
|
||||
|
||||
signal.addEventListener('abort', () => {
|
||||
@@ -627,6 +644,8 @@ class AnthropicClient extends BaseClient {
|
||||
} else if (completion.completion) {
|
||||
handleChunk(completion.completion);
|
||||
}
|
||||
|
||||
await sleep(streamRate);
|
||||
}
|
||||
|
||||
// Successful processing, exit loop
|
||||
@@ -737,7 +756,11 @@ class AnthropicClient extends BaseClient {
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await this.createResponse(this.getClient(), requestOptions, true);
|
||||
const response = await this.createResponse(
|
||||
this.getClient(requestOptions),
|
||||
requestOptions,
|
||||
true,
|
||||
);
|
||||
let promptTokens = response?.usage?.input_tokens;
|
||||
let completionTokens = response?.usage?.output_tokens;
|
||||
if (!promptTokens) {
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
const crypto = require('crypto');
|
||||
const fetch = require('node-fetch');
|
||||
const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
|
||||
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||
const { supportsBalanceCheck, Constants, CacheKeys, Time } = require('librechat-data-provider');
|
||||
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const TextStream = require('./TextStream');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
@@ -19,6 +20,14 @@ class BaseClient {
|
||||
day: 'numeric',
|
||||
});
|
||||
this.fetch = this.fetch.bind(this);
|
||||
/** @type {boolean} */
|
||||
this.skipSaveConvo = false;
|
||||
/** @type {boolean} */
|
||||
this.skipSaveUserMessage = false;
|
||||
/** @type {ClientDatabaseSavePromise} */
|
||||
this.userMessagePromise;
|
||||
/** @type {ClientDatabaseSavePromise} */
|
||||
this.responsePromise;
|
||||
}
|
||||
|
||||
setOptions() {
|
||||
@@ -84,19 +93,45 @@ class BaseClient {
|
||||
await stream.processTextStream(onProgress);
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {[string|undefined, string|undefined]}
|
||||
*/
|
||||
processOverideIds() {
|
||||
/** @type {Record<string, string | undefined>} */
|
||||
let { overrideConvoId, overrideUserMessageId } = this.options?.req?.body ?? {};
|
||||
if (overrideConvoId) {
|
||||
const [conversationId, index] = overrideConvoId.split(Constants.COMMON_DIVIDER);
|
||||
overrideConvoId = conversationId;
|
||||
if (index !== '0') {
|
||||
this.skipSaveConvo = true;
|
||||
}
|
||||
}
|
||||
if (overrideUserMessageId) {
|
||||
const [userMessageId, index] = overrideUserMessageId.split(Constants.COMMON_DIVIDER);
|
||||
overrideUserMessageId = userMessageId;
|
||||
if (index !== '0') {
|
||||
this.skipSaveUserMessage = true;
|
||||
}
|
||||
}
|
||||
|
||||
return [overrideConvoId, overrideUserMessageId];
|
||||
}
|
||||
|
||||
async setMessageOptions(opts = {}) {
|
||||
if (opts && opts.replaceOptions) {
|
||||
this.setOptions(opts);
|
||||
}
|
||||
|
||||
const [overrideConvoId, overrideUserMessageId] = this.processOverideIds();
|
||||
const { isEdited, isContinued } = opts;
|
||||
const user = opts.user ?? null;
|
||||
this.user = user;
|
||||
const saveOptions = this.getSaveOptions();
|
||||
this.abortController = opts.abortController ?? new AbortController();
|
||||
const conversationId = opts.conversationId ?? crypto.randomUUID();
|
||||
const conversationId = overrideConvoId ?? opts.conversationId ?? crypto.randomUUID();
|
||||
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
|
||||
const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
const userMessageId =
|
||||
overrideUserMessageId ?? opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
|
||||
let head = isEdited ? responseMessageId : parentMessageId;
|
||||
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
|
||||
@@ -160,7 +195,7 @@ class BaseClient {
|
||||
}
|
||||
|
||||
if (typeof opts?.onStart === 'function') {
|
||||
opts.onStart(userMessage);
|
||||
opts.onStart(userMessage, responseMessageId);
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -450,8 +485,13 @@ class BaseClient {
|
||||
this.handleTokenCountMap(tokenCountMap);
|
||||
}
|
||||
|
||||
if (!isEdited) {
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (!isEdited && !this.skipSaveUserMessage) {
|
||||
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessagePromise: this.userMessagePromise,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -500,15 +540,23 @@ class BaseClient {
|
||||
const completionTokens = this.getTokenCount(completion);
|
||||
await this.recordTokenUsage({ promptTokens, completionTokens });
|
||||
}
|
||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
if (this.userMessagePromise) {
|
||||
await this.userMessagePromise;
|
||||
}
|
||||
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
messageCache.set(
|
||||
responseMessageId,
|
||||
{
|
||||
text: responseMessage.text,
|
||||
complete: true,
|
||||
},
|
||||
Time.FIVE_MINUTES,
|
||||
);
|
||||
delete responseMessage.tokenCount;
|
||||
return responseMessage;
|
||||
}
|
||||
|
||||
async getConversation(conversationId, user = null) {
|
||||
return await getConvo(user, conversationId);
|
||||
}
|
||||
|
||||
async loadHistory(conversationId, parentMessageId = null) {
|
||||
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
|
||||
|
||||
@@ -563,22 +611,41 @@ class BaseClient {
|
||||
* @param {string | null} user
|
||||
*/
|
||||
async saveMessageToDatabase(message, endpointOptions, user = null) {
|
||||
await saveMessage({
|
||||
...message,
|
||||
endpoint: this.options.endpoint,
|
||||
unfinished: false,
|
||||
user,
|
||||
});
|
||||
await saveConvo(user, {
|
||||
conversationId: message.conversationId,
|
||||
endpoint: this.options.endpoint,
|
||||
endpointType: this.options.endpointType,
|
||||
...endpointOptions,
|
||||
});
|
||||
if (this.user && user !== this.user) {
|
||||
throw new Error('User mismatch.');
|
||||
}
|
||||
|
||||
const savedMessage = await saveMessage(
|
||||
this.options.req,
|
||||
{
|
||||
...message,
|
||||
endpoint: this.options.endpoint,
|
||||
unfinished: false,
|
||||
user,
|
||||
},
|
||||
{ context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveMessage' },
|
||||
);
|
||||
|
||||
if (this.skipSaveConvo) {
|
||||
return { message: savedMessage };
|
||||
}
|
||||
|
||||
const conversation = await saveConvo(
|
||||
this.options.req,
|
||||
{
|
||||
conversationId: message.conversationId,
|
||||
endpoint: this.options.endpoint,
|
||||
endpointType: this.options.endpointType,
|
||||
...endpointOptions,
|
||||
},
|
||||
{ context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo' },
|
||||
);
|
||||
|
||||
return { message: savedMessage, conversation };
|
||||
}
|
||||
|
||||
async updateMessageInDatabase(message) {
|
||||
await updateMessage(message);
|
||||
await updateMessage(this.options.req, message);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -13,13 +13,20 @@ const {
|
||||
endpointSettings,
|
||||
EModelEndpoint,
|
||||
VisionModes,
|
||||
Constants,
|
||||
AuthKeys,
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const { formatMessage, createContextHandlers } = require('./prompts');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const {
|
||||
formatMessage,
|
||||
createContextHandlers,
|
||||
titleInstruction,
|
||||
truncateText,
|
||||
} = require('./prompts');
|
||||
const BaseClient = require('./BaseClient');
|
||||
|
||||
const loc = 'us-central1';
|
||||
const publisher = 'google';
|
||||
@@ -591,12 +598,16 @@ class GoogleClient extends BaseClient {
|
||||
createLLM(clientOptions) {
|
||||
const model = clientOptions.modelName ?? clientOptions.model;
|
||||
if (this.project_id && this.isTextModel) {
|
||||
logger.debug('Creating Google VertexAI client');
|
||||
return new GoogleVertexAI(clientOptions);
|
||||
} else if (this.project_id && this.isChatModel) {
|
||||
logger.debug('Creating Chat Google VertexAI client');
|
||||
return new ChatGoogleVertexAI(clientOptions);
|
||||
} else if (this.project_id) {
|
||||
logger.debug('Creating VertexAI client');
|
||||
return new ChatVertexAI(clientOptions);
|
||||
} else if (model.includes('1.5')) {
|
||||
logger.debug('Creating GenAI client');
|
||||
return new GenAI(this.apiKey).getGenerativeModel(
|
||||
{
|
||||
...clientOptions,
|
||||
@@ -606,12 +617,14 @@ class GoogleClient extends BaseClient {
|
||||
);
|
||||
}
|
||||
|
||||
logger.debug('Creating Chat Google Generative AI client');
|
||||
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
|
||||
}
|
||||
|
||||
async getCompletion(_payload, options = {}) {
|
||||
const { onProgress, abortController } = options;
|
||||
const { parameters, instances } = _payload;
|
||||
const { onProgress, abortController } = options;
|
||||
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {};
|
||||
|
||||
let examples;
|
||||
@@ -691,6 +704,7 @@ class GoogleClient extends BaseClient {
|
||||
delay,
|
||||
});
|
||||
reply += chunkText;
|
||||
await sleep(streamRate);
|
||||
}
|
||||
return reply;
|
||||
}
|
||||
@@ -702,10 +716,17 @@ class GoogleClient extends BaseClient {
|
||||
safetySettings: safetySettings,
|
||||
});
|
||||
|
||||
let delay = this.isGenerativeModel ? 12 : 8;
|
||||
if (modelName.includes('flash')) {
|
||||
delay = 5;
|
||||
let delay = this.options.streamRate || 8;
|
||||
|
||||
if (!this.options.streamRate) {
|
||||
if (this.isGenerativeModel) {
|
||||
delay = 12;
|
||||
}
|
||||
if (modelName.includes('flash')) {
|
||||
delay = 5;
|
||||
}
|
||||
}
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const chunkText = chunk?.content ?? chunk;
|
||||
await this.generateTextStream(chunkText, onProgress, {
|
||||
@@ -717,6 +738,123 @@ class GoogleClient extends BaseClient {
|
||||
return reply;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
|
||||
*/
|
||||
async titleChatCompletion(_payload, options = {}) {
|
||||
const { abortController } = options;
|
||||
const { parameters, instances } = _payload;
|
||||
const { messages: _messages, examples: _examples } = instances?.[0] ?? {};
|
||||
|
||||
let clientOptions = { ...parameters, maxRetries: 2 };
|
||||
|
||||
logger.debug('Initialized title client options');
|
||||
|
||||
if (this.project_id) {
|
||||
clientOptions['authOptions'] = {
|
||||
credentials: {
|
||||
...this.serviceKey,
|
||||
},
|
||||
projectId: this.project_id,
|
||||
};
|
||||
}
|
||||
|
||||
if (!parameters) {
|
||||
clientOptions = { ...clientOptions, ...this.modelOptions };
|
||||
}
|
||||
|
||||
if (this.isGenerativeModel && !this.project_id) {
|
||||
clientOptions.modelName = clientOptions.model;
|
||||
delete clientOptions.model;
|
||||
}
|
||||
|
||||
const model = this.createLLM(clientOptions);
|
||||
|
||||
let reply = '';
|
||||
const messages = this.isTextModel ? _payload.trim() : _messages;
|
||||
|
||||
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
|
||||
if (modelName?.includes('1.5') && !this.project_id) {
|
||||
logger.debug('Identified titling model as 1.5 version');
|
||||
/** @type {GenerativeModel} */
|
||||
const client = model;
|
||||
const requestOptions = {
|
||||
contents: _payload,
|
||||
};
|
||||
|
||||
if (this.options?.promptPrefix?.length) {
|
||||
requestOptions.systemInstruction = {
|
||||
parts: [
|
||||
{
|
||||
text: this.options.promptPrefix,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
const safetySettings = _payload.safetySettings;
|
||||
requestOptions.safetySettings = safetySettings;
|
||||
|
||||
const result = await client.generateContent(requestOptions);
|
||||
|
||||
reply = result.response?.text();
|
||||
|
||||
return reply;
|
||||
} else {
|
||||
logger.debug('Beginning titling');
|
||||
const safetySettings = _payload.safetySettings;
|
||||
|
||||
const titleResponse = await model.invoke(messages, {
|
||||
signal: abortController.signal,
|
||||
timeout: 7000,
|
||||
safetySettings: safetySettings,
|
||||
});
|
||||
|
||||
reply = titleResponse.content;
|
||||
|
||||
return reply;
|
||||
}
|
||||
}
|
||||
|
||||
async titleConvo({ text, responseText = '' }) {
|
||||
let title = 'New Chat';
|
||||
const convo = `||>User:
|
||||
"${truncateText(text)}"
|
||||
||>Response:
|
||||
"${JSON.stringify(truncateText(responseText))}"`;
|
||||
|
||||
let { prompt: payload } = await this.buildMessages([
|
||||
{
|
||||
text: `Please generate ${titleInstruction}
|
||||
|
||||
${convo}
|
||||
|
||||
||>Title:`,
|
||||
isCreatedByUser: true,
|
||||
author: this.userLabel,
|
||||
},
|
||||
]);
|
||||
|
||||
if (this.isVisionModel) {
|
||||
logger.warn(
|
||||
`Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
|
||||
);
|
||||
|
||||
payload.parameters = { ...payload.parameters, model: settings.model.default };
|
||||
}
|
||||
|
||||
try {
|
||||
title = await this.titleChatCompletion(payload, {
|
||||
abortController: new AbortController(),
|
||||
onProgress: () => {},
|
||||
});
|
||||
} catch (e) {
|
||||
logger.error('[GoogleClient] There was an issue generating the title', e);
|
||||
}
|
||||
logger.debug(`Title response: ${title}`);
|
||||
return title;
|
||||
}
|
||||
|
||||
getSaveOptions() {
|
||||
return {
|
||||
promptPrefix: this.options.promptPrefix,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
const { z } = require('zod');
|
||||
const axios = require('axios');
|
||||
const { Ollama } = require('ollama');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { deriveBaseURL } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const ollamaPayloadSchema = z.object({
|
||||
@@ -40,6 +42,7 @@ const getValidBase64 = (imageUrl) => {
|
||||
class OllamaClient {
|
||||
constructor(options = {}) {
|
||||
const host = deriveBaseURL(options.baseURL ?? 'http://localhost:11434');
|
||||
this.streamRate = options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
/** @type {Ollama} */
|
||||
this.client = new Ollama({ host });
|
||||
}
|
||||
@@ -136,6 +139,8 @@ class OllamaClient {
|
||||
stream.controller.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
await sleep(this.streamRate);
|
||||
}
|
||||
}
|
||||
// TODO: regular completion
|
||||
|
||||
@@ -27,7 +27,6 @@ const {
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { updateTokenWebsocket } = require('~/server/services/Files/Audio');
|
||||
const { isEnabled, sleep } = require('~/server/utils');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
@@ -595,7 +594,6 @@ class OpenAIClient extends BaseClient {
|
||||
payload,
|
||||
(progressMessage) => {
|
||||
if (progressMessage === '[DONE]') {
|
||||
updateTokenWebsocket('[DONE]');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1184,8 +1182,10 @@ ${convo}
|
||||
});
|
||||
}
|
||||
|
||||
const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
|
||||
|
||||
if (this.message_file_map && this.isOllama) {
|
||||
const ollamaClient = new OllamaClient({ baseURL });
|
||||
const ollamaClient = new OllamaClient({ baseURL, streamRate });
|
||||
return await ollamaClient.chatCompletion({
|
||||
payload: modelOptions,
|
||||
onProgress,
|
||||
@@ -1223,8 +1223,6 @@ ${convo}
|
||||
}
|
||||
});
|
||||
|
||||
const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const token = chunk.choices[0]?.delta?.content || '';
|
||||
intermediateReply += token;
|
||||
@@ -1234,9 +1232,7 @@ ${convo}
|
||||
break;
|
||||
}
|
||||
|
||||
if (this.azure) {
|
||||
await sleep(azureDelay);
|
||||
}
|
||||
await sleep(streamRate);
|
||||
}
|
||||
|
||||
if (!UnexpectedRoleError) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
const OpenAIClient = require('./OpenAIClient');
|
||||
const { CallbackManager } = require('langchain/callbacks');
|
||||
const { CacheKeys, Time } = require('librechat-data-provider');
|
||||
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
|
||||
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
|
||||
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
|
||||
@@ -11,6 +12,7 @@ const { SelfReflectionTool } = require('./tools');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { extractBaseURL } = require('~/utils');
|
||||
const { loadTools } = require('./tools/util');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
class PluginsClient extends OpenAIClient {
|
||||
@@ -220,6 +222,13 @@ class PluginsClient extends OpenAIClient {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TMessage} responseMessage
|
||||
* @param {Partial<TMessage>} saveOptions
|
||||
* @param {string} user
|
||||
* @returns
|
||||
*/
|
||||
async handleResponseMessage(responseMessage, saveOptions, user) {
|
||||
const { output, errorMessage, ...result } = this.result;
|
||||
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
|
||||
@@ -238,12 +247,32 @@ class PluginsClient extends OpenAIClient {
|
||||
await this.recordTokenUsage(responseMessage);
|
||||
}
|
||||
|
||||
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
messageCache.set(
|
||||
responseMessage.messageId,
|
||||
{
|
||||
text: responseMessage.text,
|
||||
complete: true,
|
||||
},
|
||||
Time.FIVE_MINUTES,
|
||||
);
|
||||
delete responseMessage.tokenCount;
|
||||
return { ...responseMessage, ...result };
|
||||
}
|
||||
|
||||
async sendMessage(message, opts = {}) {
|
||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
|
||||
this.options.tools = tools;
|
||||
} else {
|
||||
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
|
||||
this.options.tools = tools;
|
||||
}
|
||||
|
||||
// If a message is edited, no tools can be used.
|
||||
const completionMode = this.options.tools.length === 0 || opts.isEdited;
|
||||
if (completionMode) {
|
||||
@@ -301,7 +330,15 @@ class PluginsClient extends OpenAIClient {
|
||||
if (payload) {
|
||||
this.currentMessages = payload;
|
||||
}
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
|
||||
if (!this.skipSaveUserMessage) {
|
||||
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
userMessagePromise: this.userMessagePromise,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||
await checkBalance({
|
||||
|
||||
@@ -1,44 +1,3 @@
|
||||
/*
|
||||
module.exports = `You are ChatGPT, a Large Language model with useful tools.
|
||||
|
||||
Talk to the human and provide meaningful answers when questions are asked.
|
||||
|
||||
Use the tools when you need them, but use your own knowledge if you are confident of the answer. Keep answers short and concise.
|
||||
|
||||
A tool is not usually needed for creative requests, so do your best to answer them without tools.
|
||||
|
||||
Avoid repeating identical answers if it appears before. Only fulfill the human's requests, do not create extra steps beyond what the human has asked for.
|
||||
|
||||
Your input for 'Action' should be the name of tool used only.
|
||||
|
||||
Be honest. If you can't answer something, or a tool is not appropriate, say you don't know or answer to the best of your ability.
|
||||
|
||||
Attempt to fulfill the human's requests in as few actions as possible`;
|
||||
*/
|
||||
|
||||
// module.exports = `You are ChatGPT, a highly knowledgeable and versatile large language model.
|
||||
|
||||
// Engage with the Human conversationally, providing concise and meaningful answers to questions. Utilize built-in tools when necessary, except for creative requests, where relying on your own knowledge is preferred. Aim for variety and avoid repetitive answers.
|
||||
|
||||
// For your 'Action' input, state the name of the tool used only, and honor user requests without adding extra steps. Always be honest; if you cannot provide an appropriate answer or tool, admit that or do your best.
|
||||
|
||||
// Strive to meet the user's needs efficiently with minimal actions.`;
|
||||
|
||||
// import {
|
||||
// BasePromptTemplate,
|
||||
// BaseStringPromptTemplate,
|
||||
// SerializedBasePromptTemplate,
|
||||
// renderTemplate,
|
||||
// } from "langchain/prompts";
|
||||
|
||||
// prefix: `You are ChatGPT, a highly knowledgeable and versatile large language model.
|
||||
// Your objective is to help users by understanding their intent and choosing the best action. Prioritize direct, specific responses. Use concise, varied answers and rely on your knowledge for creative tasks. Utilize tools when needed, and structure results for machine compatibility.
|
||||
// prefix: `Objective: to comprehend human intentions based on user input and available tools. Goal: identify the best action to directly address the human's query. In your subsequent steps, you will utilize the chosen action. You may select multiple actions and list them in a meaningful order. Prioritize actions that directly relate to the user's query over general ones. Ensure that the generated thought is highly specific and explicit to best match the user's expectations. Construct the result in a manner that an online open-API would most likely expect. Provide concise and meaningful answers to human queries. Utilize tools when necessary. Relying on your own knowledge is preferred for creative requests. Aim for variety and avoid repetitive answers.
|
||||
|
||||
// # Available Actions & Tools:
|
||||
// N/A: no suitable action, use your own knowledge.`,
|
||||
// suffix: `Remember, all your responses MUST adhere to the described format and only respond if the format is followed. Output exactly with the requested format, avoiding any other text as this will be parsed by a machine. Following 'Action:', provide only one of the actions listed above. If a tool is not necessary, deduce this quickly and finish your response. Honor the human's requests without adding extra steps. Carry out tasks in the sequence written by the human. Always be honest; if you cannot provide an appropriate answer or tool, do your best with your own knowledge. Strive to meet the user's needs efficiently with minimal actions.`;
|
||||
|
||||
module.exports = {
|
||||
'gpt3-v1': {
|
||||
prefix: `Objective: Understand human intentions using user input and available tools. Goal: Identify the most suitable actions to directly address user queries.
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
const AnthropicClient = require('../AnthropicClient');
|
||||
const { anthropicSettings } = require('librechat-data-provider');
|
||||
const AnthropicClient = require('~/app/clients/AnthropicClient');
|
||||
|
||||
const HUMAN_PROMPT = '\n\nHuman:';
|
||||
const AI_PROMPT = '\n\nAssistant:';
|
||||
|
||||
@@ -22,7 +24,7 @@ describe('AnthropicClient', () => {
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model,
|
||||
temperature: 0.7,
|
||||
temperature: anthropicSettings.temperature.default,
|
||||
},
|
||||
};
|
||||
client = new AnthropicClient('test-api-key');
|
||||
@@ -33,7 +35,42 @@ describe('AnthropicClient', () => {
|
||||
it('should set the options correctly', () => {
|
||||
expect(client.apiKey).toBe('test-api-key');
|
||||
expect(client.modelOptions.model).toBe(model);
|
||||
expect(client.modelOptions.temperature).toBe(0.7);
|
||||
expect(client.modelOptions.temperature).toBe(anthropicSettings.temperature.default);
|
||||
});
|
||||
|
||||
it('should set legacy maxOutputTokens for non-Claude-3 models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-2',
|
||||
maxOutputTokens: anthropicSettings.maxOutputTokens.default,
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBe(
|
||||
anthropicSettings.legacy.maxOutputTokens.default,
|
||||
);
|
||||
});
|
||||
it('should not set maxOutputTokens if not provided', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-3',
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should not set legacy maxOutputTokens for Claude-3 models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
maxOutputTokens: anthropicSettings.legacy.maxOutputTokens.default,
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBe(
|
||||
anthropicSettings.legacy.maxOutputTokens.default,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -136,4 +173,57 @@ describe('AnthropicClient', () => {
|
||||
expect(prompt).toContain('You are Claude-2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getClient', () => {
|
||||
it('should set legacy maxOutputTokens for non-Claude-3 models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-2',
|
||||
maxOutputTokens: anthropicSettings.legacy.maxOutputTokens.default,
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBe(
|
||||
anthropicSettings.legacy.maxOutputTokens.default,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not set legacy maxOutputTokens for Claude-3 models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
maxOutputTokens: anthropicSettings.legacy.maxOutputTokens.default,
|
||||
},
|
||||
});
|
||||
expect(client.modelOptions.maxOutputTokens).toBe(
|
||||
anthropicSettings.legacy.maxOutputTokens.default,
|
||||
);
|
||||
});
|
||||
|
||||
it('should add beta header for claude-3-5-sonnet model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
model: 'claude-3-5-sonnet-20240307',
|
||||
};
|
||||
client.setOptions({ modelOptions });
|
||||
const anthropicClient = client.getClient(modelOptions);
|
||||
expect(anthropicClient._options.defaultHeaders).toBeDefined();
|
||||
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
|
||||
'max-tokens-3-5-sonnet-2024-07-15',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not add beta header for other models', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
client.setOptions({
|
||||
modelOptions: {
|
||||
model: 'claude-2',
|
||||
},
|
||||
});
|
||||
const anthropicClient = client.getClient();
|
||||
expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const { initializeFakeClient } = require('./FakeClient');
|
||||
|
||||
jest.mock('../../../lib/db/connectDb');
|
||||
jest.mock('~/lib/db/connectDb');
|
||||
jest.mock('~/models', () => ({
|
||||
User: jest.fn(),
|
||||
Key: jest.fn(),
|
||||
@@ -576,7 +576,11 @@ describe('BaseClient', () => {
|
||||
const onStart = jest.fn();
|
||||
const opts = { onStart };
|
||||
await TestClient.sendMessage('Hello, world!', opts);
|
||||
expect(onStart).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello, world!' }));
|
||||
|
||||
expect(onStart).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ text: 'Hello, world!' }),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
test('saveMessageToDatabase is called with the correct arguments', async () => {
|
||||
@@ -627,5 +631,32 @@ describe('BaseClient', () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test('userMessagePromise is awaited before saving response message', async () => {
|
||||
// Mock the saveMessageToDatabase method
|
||||
TestClient.saveMessageToDatabase = jest.fn().mockImplementation(() => {
|
||||
return new Promise((resolve) => setTimeout(resolve, 100)); // Simulate a delay
|
||||
});
|
||||
|
||||
// Send a message
|
||||
const messagePromise = TestClient.sendMessage('Hello, world!');
|
||||
|
||||
// Wait a short time to ensure the user message save has started
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Check that saveMessageToDatabase has been called once (for the user message)
|
||||
expect(TestClient.saveMessageToDatabase).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Wait for the message to be fully processed
|
||||
await messagePromise;
|
||||
|
||||
// Check that saveMessageToDatabase has been called twice (once for user message, once for response)
|
||||
expect(TestClient.saveMessageToDatabase).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Check the order of calls
|
||||
const calls = TestClient.saveMessageToDatabase.mock.calls;
|
||||
expect(calls[0][0].isCreatedByUser).toBe(true); // First call should be for user message
|
||||
expect(calls[1][0].isCreatedByUser).toBe(false); // Second call should be for response message
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -194,6 +194,7 @@ describe('PluginsClient', () => {
|
||||
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Azure OpenAI tests specific to Plugins', () => {
|
||||
// TODO: add more tests for Azure OpenAI integration with Plugins
|
||||
// let client;
|
||||
@@ -220,4 +221,94 @@ describe('PluginsClient', () => {
|
||||
spy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessage with filtered tools', () => {
|
||||
let TestAgent;
|
||||
const apiKey = 'fake-api-key';
|
||||
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
|
||||
|
||||
beforeEach(() => {
|
||||
TestAgent = new PluginsClient(apiKey, {
|
||||
tools: mockTools,
|
||||
modelOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
temperature: 0,
|
||||
max_tokens: 2,
|
||||
},
|
||||
agentOptions: {
|
||||
model: 'gpt-3.5-turbo',
|
||||
},
|
||||
});
|
||||
|
||||
TestAgent.options.req = {
|
||||
app: {
|
||||
locals: {},
|
||||
},
|
||||
};
|
||||
|
||||
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
|
||||
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
|
||||
|
||||
if (includedTools.length > 0) {
|
||||
const tools = TestAgent.options.tools.filter((plugin) =>
|
||||
includedTools.includes(plugin.name),
|
||||
);
|
||||
TestAgent.options.tools = tools;
|
||||
} else {
|
||||
const tools = TestAgent.options.tools.filter(
|
||||
(plugin) => !filteredTools.includes(plugin.name),
|
||||
);
|
||||
TestAgent.options.tools = tools;
|
||||
}
|
||||
|
||||
return {
|
||||
text: 'Mocked response',
|
||||
tools: TestAgent.options.tools,
|
||||
};
|
||||
});
|
||||
});
|
||||
|
||||
test('should filter out tools when filteredTools is provided', async () => {
|
||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
expect.objectContaining({ name: 'tool4' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should only include specified tools when includedTools is provided', async () => {
|
||||
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
expect.objectContaining({ name: 'tool4' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should prioritize includedTools over filteredTools', async () => {
|
||||
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
|
||||
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(2);
|
||||
expect(response.tools).toEqual(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'tool1' }),
|
||||
expect.objectContaining({ name: 'tool2' }),
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
test('should not modify tools when no filters are provided', async () => {
|
||||
const response = await TestAgent.sendMessage('Test message');
|
||||
expect(response.tools).toHaveLength(4);
|
||||
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
30
api/cache/getLogStores.js
vendored
30
api/cache/getLogStores.js
vendored
@@ -1,13 +1,11 @@
|
||||
const Keyv = require('keyv');
|
||||
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
|
||||
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
|
||||
const { logFile, violationFile } = require('./keyvFiles');
|
||||
const { math, isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('./keyvRedis');
|
||||
const keyvMongo = require('./keyvMongo');
|
||||
|
||||
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
|
||||
const THIRTY_MINUTES = 1800000;
|
||||
const TEN_MINUTES = 600000;
|
||||
|
||||
const duration = math(BAN_DURATION, 7200000);
|
||||
|
||||
@@ -25,17 +23,25 @@ const config = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
|
||||
|
||||
const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes
|
||||
? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES });
|
||||
const roles = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ROLES });
|
||||
|
||||
const audioRuns = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const messages = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.FIVE_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.FIVE_MINUTES });
|
||||
|
||||
const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
|
||||
? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
|
||||
|
||||
const genTitle = isEnabled(USE_REDIS) // ttl: 2 minutes
|
||||
? new Keyv({ store: keyvRedis, ttl: 120000 })
|
||||
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: 120000 });
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
|
||||
|
||||
const modelQueries = isEnabled(process.env.USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
@@ -43,9 +49,10 @@ const modelQueries = isEnabled(process.env.USE_REDIS)
|
||||
|
||||
const abortKeys = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: 600000 });
|
||||
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const namespaces = {
|
||||
[CacheKeys.ROLES]: roles,
|
||||
[CacheKeys.CONFIG_STORE]: config,
|
||||
pending_req,
|
||||
[ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }),
|
||||
@@ -76,6 +83,7 @@ const namespaces = {
|
||||
[CacheKeys.GEN_TITLE]: genTitle,
|
||||
[CacheKeys.MODEL_QUERIES]: modelQueries,
|
||||
[CacheKeys.AUDIO_RUNS]: audioRuns,
|
||||
[CacheKeys.MESSAGES]: messages,
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -109,6 +109,14 @@ const condenseArray = (item) => {
|
||||
* @returns {string} - The formatted log message.
|
||||
*/
|
||||
const debugTraverse = winston.format.printf(({ level, message, timestamp, ...metadata }) => {
|
||||
if (!message) {
|
||||
return `${timestamp} ${level}`;
|
||||
}
|
||||
|
||||
if (!message?.trim || typeof message !== 'string') {
|
||||
return `${timestamp} ${level}: ${JSON.stringify(message)}`;
|
||||
}
|
||||
|
||||
let msg = `${timestamp} ${level}: ${truncateLongStrings(message?.trim(), 150)}`;
|
||||
try {
|
||||
if (level !== 'debug') {
|
||||
|
||||
61
api/models/Categories.js
Normal file
61
api/models/Categories.js
Normal file
@@ -0,0 +1,61 @@
|
||||
const { logger } = require('~/config');
|
||||
// const { Categories } = require('./schema/categories');
|
||||
const options = [
|
||||
{
|
||||
label: '',
|
||||
value: '',
|
||||
},
|
||||
{
|
||||
label: 'idea',
|
||||
value: 'idea',
|
||||
},
|
||||
{
|
||||
label: 'travel',
|
||||
value: 'travel',
|
||||
},
|
||||
{
|
||||
label: 'teach_or_explain',
|
||||
value: 'teach_or_explain',
|
||||
},
|
||||
{
|
||||
label: 'write',
|
||||
value: 'write',
|
||||
},
|
||||
{
|
||||
label: 'shop',
|
||||
value: 'shop',
|
||||
},
|
||||
{
|
||||
label: 'code',
|
||||
value: 'code',
|
||||
},
|
||||
{
|
||||
label: 'misc',
|
||||
value: 'misc',
|
||||
},
|
||||
{
|
||||
label: 'roleplay',
|
||||
value: 'roleplay',
|
||||
},
|
||||
{
|
||||
label: 'finance',
|
||||
value: 'finance',
|
||||
},
|
||||
];
|
||||
|
||||
module.exports = {
|
||||
/**
|
||||
* Retrieves the categories asynchronously.
|
||||
* @returns {Promise<TGetCategoriesResponse>} An array of category objects.
|
||||
* @throws {Error} If there is an error retrieving the categories.
|
||||
*/
|
||||
getCategories: async () => {
|
||||
try {
|
||||
// const categories = await Categories.find();
|
||||
return options;
|
||||
} catch (error) {
|
||||
logger.error('Error getting categories', error);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
};
|
||||
@@ -19,20 +19,39 @@ const getConvo = async (user, conversationId) => {
|
||||
|
||||
module.exports = {
|
||||
Conversation,
|
||||
saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
|
||||
/**
|
||||
* Saves a conversation to the database.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} conversationId - The conversation's ID.
|
||||
* @param {Object} metadata - Additional metadata to log for operation.
|
||||
* @returns {Promise<TConversation>} The conversation object.
|
||||
*/
|
||||
saveConvo: async (req, { conversationId, newConversationId, ...convo }, metadata) => {
|
||||
try {
|
||||
if (metadata && metadata?.context) {
|
||||
logger.debug(`[saveConvo] ${metadata.context}`);
|
||||
}
|
||||
const messages = await getMessages({ conversationId }, '_id');
|
||||
const update = { ...convo, messages, user };
|
||||
const update = { ...convo, messages, user: req.user.id };
|
||||
if (newConversationId) {
|
||||
update.conversationId = newConversationId;
|
||||
}
|
||||
|
||||
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: true,
|
||||
});
|
||||
const conversation = await Conversation.findOneAndUpdate(
|
||||
{ conversationId, user: req.user.id },
|
||||
update,
|
||||
{
|
||||
new: true,
|
||||
upsert: true,
|
||||
},
|
||||
);
|
||||
|
||||
return conversation.toObject();
|
||||
} catch (error) {
|
||||
logger.error('[saveConvo] Error saving conversation', error);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`[saveConvo] ${metadata.context}`);
|
||||
}
|
||||
return { message: 'Error saving conversation' };
|
||||
}
|
||||
},
|
||||
@@ -54,13 +73,16 @@ module.exports = {
|
||||
throw new Error('Failed to save conversations in bulk.');
|
||||
}
|
||||
},
|
||||
getConvosByPage: async (user, pageNumber = 1, pageSize = 25, isArchived = false) => {
|
||||
getConvosByPage: async (user, pageNumber = 1, pageSize = 25, isArchived = false, tags) => {
|
||||
const query = { user };
|
||||
if (isArchived) {
|
||||
query.isArchived = true;
|
||||
} else {
|
||||
query.$or = [{ isArchived: false }, { isArchived: { $exists: false } }];
|
||||
}
|
||||
if (Array.isArray(tags) && tags.length > 0) {
|
||||
query.tags = { $in: tags };
|
||||
}
|
||||
try {
|
||||
const totalConvos = (await Conversation.countDocuments(query)) || 1;
|
||||
const totalPages = Math.ceil(totalConvos / pageSize);
|
||||
|
||||
268
api/models/ConversationTag.js
Normal file
268
api/models/ConversationTag.js
Normal file
@@ -0,0 +1,268 @@
|
||||
//const crypto = require('crypto');
|
||||
|
||||
const logger = require('~/config/winston');
|
||||
const Conversation = require('./schema/convoSchema');
|
||||
const ConversationTag = require('./schema/conversationTagSchema');
|
||||
|
||||
const SAVED_TAG = 'Saved';
|
||||
|
||||
const updateTagsForConversation = async (user, conversationId, tags) => {
|
||||
try {
|
||||
const conversation = await Conversation.findOne({ user, conversationId });
|
||||
if (!conversation) {
|
||||
return { message: 'Conversation not found' };
|
||||
}
|
||||
|
||||
const addedTags = tags.tags.filter((tag) => !conversation.tags.includes(tag));
|
||||
const removedTags = conversation.tags.filter((tag) => !tags.tags.includes(tag));
|
||||
for (const tag of addedTags) {
|
||||
await ConversationTag.updateOne({ tag, user }, { $inc: { count: 1 } }, { upsert: true });
|
||||
}
|
||||
for (const tag of removedTags) {
|
||||
await ConversationTag.updateOne({ tag, user }, { $inc: { count: -1 } });
|
||||
}
|
||||
conversation.tags = tags.tags;
|
||||
await conversation.save({ timestamps: { updatedAt: false } });
|
||||
return conversation.tags;
|
||||
} catch (error) {
|
||||
logger.error('[updateTagsToConversation] Error updating tags', error);
|
||||
return { message: 'Error updating tags' };
|
||||
}
|
||||
};
|
||||
|
||||
const createConversationTag = async (user, data) => {
|
||||
try {
|
||||
const cTag = await ConversationTag.findOne({ user, tag: data.tag });
|
||||
if (cTag) {
|
||||
return cTag;
|
||||
}
|
||||
|
||||
const addToConversation = data.addToConversation && data.conversationId;
|
||||
const newTag = await ConversationTag.create({
|
||||
user,
|
||||
tag: data.tag,
|
||||
count: 0,
|
||||
description: data.description,
|
||||
position: 1,
|
||||
});
|
||||
|
||||
await ConversationTag.updateMany(
|
||||
{ user, position: { $gte: 1 }, _id: { $ne: newTag._id } },
|
||||
{ $inc: { position: 1 } },
|
||||
);
|
||||
|
||||
if (addToConversation) {
|
||||
const conversation = await Conversation.findOne({
|
||||
user,
|
||||
conversationId: data.conversationId,
|
||||
});
|
||||
if (conversation) {
|
||||
const tags = [...(conversation.tags || []), data.tag];
|
||||
await updateTagsForConversation(user, data.conversationId, { tags });
|
||||
} else {
|
||||
logger.warn('[updateTagsForConversation] Conversation not found', data.conversationId);
|
||||
}
|
||||
}
|
||||
|
||||
return await ConversationTag.findOne({ user, tag: data.tag });
|
||||
} catch (error) {
|
||||
logger.error('[createConversationTag] Error updating conversation tag', error);
|
||||
return { message: 'Error updating conversation tag' };
|
||||
}
|
||||
};
|
||||
|
||||
const replaceOrRemoveTagInConversations = async (user, oldtag, newtag) => {
|
||||
try {
|
||||
const conversations = await Conversation.find({ user, tags: { $in: [oldtag] } });
|
||||
for (const conversation of conversations) {
|
||||
if (newtag && newtag !== '') {
|
||||
conversation.tags = conversation.tags.map((tag) => (tag === oldtag ? newtag : tag));
|
||||
} else {
|
||||
conversation.tags = conversation.tags.filter((tag) => tag !== oldtag);
|
||||
}
|
||||
await conversation.save({ timestamps: { updatedAt: false } });
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[replaceOrRemoveTagInConversations] Error updating conversation tags', error);
|
||||
return { message: 'Error updating conversation tags' };
|
||||
}
|
||||
};
|
||||
|
||||
const updateTagPosition = async (user, tag, newPosition) => {
|
||||
try {
|
||||
const cTag = await ConversationTag.findOne({ user, tag });
|
||||
if (!cTag) {
|
||||
return { message: 'Tag not found' };
|
||||
}
|
||||
|
||||
const oldPosition = cTag.position;
|
||||
|
||||
if (newPosition === oldPosition) {
|
||||
return cTag;
|
||||
}
|
||||
|
||||
const updateOperations = [];
|
||||
|
||||
if (newPosition > oldPosition) {
|
||||
// Move other tags up
|
||||
updateOperations.push({
|
||||
updateMany: {
|
||||
filter: {
|
||||
user,
|
||||
position: { $gt: oldPosition, $lte: newPosition },
|
||||
tag: { $ne: SAVED_TAG },
|
||||
},
|
||||
update: { $inc: { position: -1 } },
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// Move other tags down
|
||||
updateOperations.push({
|
||||
updateMany: {
|
||||
filter: {
|
||||
user,
|
||||
position: { $gte: newPosition, $lt: oldPosition },
|
||||
tag: { $ne: SAVED_TAG },
|
||||
},
|
||||
update: { $inc: { position: 1 } },
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Update the target tag's position
|
||||
updateOperations.push({
|
||||
updateOne: {
|
||||
filter: { _id: cTag._id },
|
||||
update: { $set: { position: newPosition } },
|
||||
},
|
||||
});
|
||||
|
||||
await ConversationTag.bulkWrite(updateOperations);
|
||||
|
||||
return await ConversationTag.findById(cTag._id);
|
||||
} catch (error) {
|
||||
logger.error('[updateTagPosition] Error updating tag position', error);
|
||||
return { message: 'Error updating tag position' };
|
||||
}
|
||||
};
|
||||
module.exports = {
|
||||
SAVED_TAG,
|
||||
ConversationTag,
|
||||
getConversationTags: async (user) => {
|
||||
try {
|
||||
const cTags = await ConversationTag.find({ user }).sort({ position: 1 }).lean();
|
||||
cTags.sort((a, b) => (a.tag === SAVED_TAG ? -1 : b.tag === SAVED_TAG ? 1 : 0));
|
||||
|
||||
return cTags;
|
||||
} catch (error) {
|
||||
logger.error('[getShare] Error getting share link', error);
|
||||
return { message: 'Error getting share link' };
|
||||
}
|
||||
},
|
||||
|
||||
createConversationTag,
|
||||
updateConversationTag: async (user, tag, data) => {
|
||||
try {
|
||||
const cTag = await ConversationTag.findOne({ user, tag });
|
||||
if (!cTag) {
|
||||
return createConversationTag(user, data);
|
||||
}
|
||||
|
||||
if (cTag.tag !== data.tag || cTag.description !== data.description) {
|
||||
cTag.tag = data.tag;
|
||||
cTag.description = data.description === undefined ? cTag.description : data.description;
|
||||
await cTag.save();
|
||||
}
|
||||
|
||||
if (data.position !== undefined && cTag.position !== data.position) {
|
||||
await updateTagPosition(user, tag, data.position);
|
||||
}
|
||||
|
||||
// update conversation tags properties
|
||||
replaceOrRemoveTagInConversations(user, tag, data.tag);
|
||||
return await ConversationTag.findOne({ user, tag: data.tag });
|
||||
} catch (error) {
|
||||
logger.error('[updateConversationTag] Error updating conversation tag', error);
|
||||
return { message: 'Error updating conversation tag' };
|
||||
}
|
||||
},
|
||||
|
||||
deleteConversationTag: async (user, tag) => {
|
||||
try {
|
||||
const currentTag = await ConversationTag.findOne({ user, tag });
|
||||
if (!currentTag) {
|
||||
return;
|
||||
}
|
||||
|
||||
await currentTag.deleteOne({ user, tag });
|
||||
|
||||
await replaceOrRemoveTagInConversations(user, tag, null);
|
||||
return currentTag;
|
||||
} catch (error) {
|
||||
logger.error('[deleteConversationTag] Error deleting conversation tag', error);
|
||||
return { message: 'Error deleting conversation tag' };
|
||||
}
|
||||
},
|
||||
|
||||
updateTagsForConversation,
|
||||
rebuildConversationTags: async (user) => {
|
||||
try {
|
||||
const conversations = await Conversation.find({ user }).select('tags');
|
||||
const tagCountMap = {};
|
||||
|
||||
// Count the occurrences of each tag
|
||||
conversations.forEach((conversation) => {
|
||||
conversation.tags.forEach((tag) => {
|
||||
if (tagCountMap[tag]) {
|
||||
tagCountMap[tag]++;
|
||||
} else {
|
||||
tagCountMap[tag] = 1;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
const tags = await ConversationTag.find({ user }).sort({ position: -1 });
|
||||
|
||||
// Update existing tags and add new tags
|
||||
for (const [tag, count] of Object.entries(tagCountMap)) {
|
||||
const existingTag = tags.find((t) => t.tag === tag);
|
||||
if (existingTag) {
|
||||
existingTag.count = count;
|
||||
await existingTag.save();
|
||||
} else {
|
||||
const newTag = new ConversationTag({ user, tag, count });
|
||||
tags.push(newTag);
|
||||
await newTag.save();
|
||||
}
|
||||
}
|
||||
|
||||
// Set count to 0 for tags that are not in the grouped tags
|
||||
for (const tag of tags) {
|
||||
if (!tagCountMap[tag.tag]) {
|
||||
tag.count = 0;
|
||||
await tag.save();
|
||||
}
|
||||
}
|
||||
|
||||
// Sort tags by position in descending order
|
||||
tags.sort((a, b) => a.position - b.position);
|
||||
|
||||
// Move the tag with name "saved" to the first position
|
||||
const savedTagIndex = tags.findIndex((tag) => tag.tag === SAVED_TAG);
|
||||
if (savedTagIndex !== -1) {
|
||||
const [savedTag] = tags.splice(savedTagIndex, 1);
|
||||
tags.unshift(savedTag);
|
||||
}
|
||||
|
||||
// Reassign positions starting from 0
|
||||
tags.forEach((tag, index) => {
|
||||
tag.position = index;
|
||||
tag.save();
|
||||
});
|
||||
return tags;
|
||||
} catch (error) {
|
||||
logger.error('[rearrangeTags] Error rearranging tags', error);
|
||||
return { message: 'Error rearranging tags' };
|
||||
}
|
||||
},
|
||||
};
|
||||
@@ -1,209 +1,342 @@
|
||||
const { z } = require('zod');
|
||||
const Message = require('./schema/messageSchema');
|
||||
const logger = require('~/config/winston');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const idSchema = z.string().uuid();
|
||||
|
||||
/**
|
||||
* Saves a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function saveMessage
|
||||
* @param {Express.Request} req - The request object containing user information.
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.iconURL - The URL of the sender's icon.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.newMessageId - The new unique identifier for the message (if applicable).
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {string} params.sender - The identifier of the sender.
|
||||
* @param {string} params.text - The text content of the message.
|
||||
* @param {boolean} params.isCreatedByUser - Indicates if the message was created by the user.
|
||||
* @param {string} [params.error] - Any error associated with the message.
|
||||
* @param {boolean} [params.unfinished] - Indicates if the message is unfinished.
|
||||
* @param {Object[]} [params.files] - An array of files associated with the message.
|
||||
* @param {boolean} [params.isEdited] - Indicates if the message was edited.
|
||||
* @param {string} [params.finish_reason] - Reason for finishing the message.
|
||||
* @param {number} [params.tokenCount] - The number of tokens in the message.
|
||||
* @param {string} [params.plugin] - Plugin associated with the message.
|
||||
* @param {string[]} [params.plugins] - An array of plugins associated with the message.
|
||||
* @param {string} [params.model] - The model used to generate the message.
|
||||
* @param {Object} [metadata] - Additional metadata for this operation
|
||||
* @param {string} [metadata.context] - The context of the operation
|
||||
* @returns {Promise<TMessage>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async function saveMessage(req, params, metadata) {
|
||||
try {
|
||||
if (!req || !req.user || !req.user.id) {
|
||||
throw new Error('User not authenticated');
|
||||
}
|
||||
|
||||
const {
|
||||
text,
|
||||
error,
|
||||
model,
|
||||
files,
|
||||
plugin,
|
||||
sender,
|
||||
plugins,
|
||||
iconURL,
|
||||
endpoint,
|
||||
isEdited,
|
||||
messageId,
|
||||
unfinished,
|
||||
tokenCount,
|
||||
newMessageId,
|
||||
finish_reason,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
isCreatedByUser,
|
||||
} = params;
|
||||
|
||||
const validConvoId = idSchema.safeParse(conversationId);
|
||||
if (!validConvoId.success) {
|
||||
logger.warn(`Invalid conversation ID: ${conversationId}`);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`---\`saveMessage\` context: ${metadata.context}`);
|
||||
}
|
||||
|
||||
logger.info(`---Invalid conversation ID Params:
|
||||
|
||||
${JSON.stringify(params, null, 2)}
|
||||
|
||||
`);
|
||||
return;
|
||||
}
|
||||
|
||||
const update = {
|
||||
user: req.user.id,
|
||||
iconURL,
|
||||
endpoint,
|
||||
messageId: newMessageId || messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
isEdited,
|
||||
finish_reason,
|
||||
error,
|
||||
unfinished,
|
||||
tokenCount,
|
||||
plugin,
|
||||
plugins,
|
||||
model,
|
||||
};
|
||||
|
||||
if (files) {
|
||||
update.files = files;
|
||||
}
|
||||
|
||||
const message = await Message.findOneAndUpdate({ messageId, user: req.user.id }, update, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
|
||||
return message.toObject();
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`---\`saveMessage\` context: ${metadata.context}`);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves multiple messages in the database in bulk.
|
||||
*
|
||||
* @async
|
||||
* @function bulkSaveMessages
|
||||
* @param {Object[]} messages - An array of message objects to save.
|
||||
* @returns {Promise<Object>} The result of the bulk write operation.
|
||||
* @throws {Error} If there is an error in saving messages in bulk.
|
||||
*/
|
||||
async function bulkSaveMessages(messages) {
|
||||
try {
|
||||
const bulkOps = messages.map((message) => ({
|
||||
updateOne: {
|
||||
filter: { messageId: message.messageId },
|
||||
update: message,
|
||||
upsert: true,
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await Message.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (err) {
|
||||
logger.error('Error saving messages in bulk:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Records a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function recordMessage
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.user - The identifier of the user.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed.
|
||||
* @returns {Promise<Object>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async function recordMessage({
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest
|
||||
}) {
|
||||
try {
|
||||
// No parsing of convoId as may use threadId
|
||||
const message = {
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest,
|
||||
};
|
||||
|
||||
return await Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('Error recording message:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the text of a message.
|
||||
*
|
||||
* @async
|
||||
* @function updateMessageText
|
||||
* @param {Object} params - The update data object.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.text - The new text content of the message.
|
||||
* @returns {Promise<void>}
|
||||
* @throws {Error} If there is an error in updating the message text.
|
||||
*/
|
||||
async function updateMessageText(req, { messageId, text }) {
|
||||
try {
|
||||
await Message.updateOne({ messageId, user: req.user.id }, { text });
|
||||
} catch (err) {
|
||||
logger.error('Error updating message text:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a message.
|
||||
*
|
||||
* @async
|
||||
* @function updateMessage
|
||||
* @param {Object} message - The message object containing update data.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} message.messageId - The unique identifier for the message.
|
||||
* @param {string} [message.text] - The new text content of the message.
|
||||
* @param {Object[]} [message.files] - The files associated with the message.
|
||||
* @param {boolean} [message.isCreatedByUser] - Indicates if the message was created by the user.
|
||||
* @param {string} [message.sender] - The identifier of the sender.
|
||||
* @param {number} [message.tokenCount] - The number of tokens in the message.
|
||||
* @param {Object} [metadata] - The operation metadata
|
||||
* @param {string} [metadata.context] - The operation metadata
|
||||
* @returns {Promise<TMessage>} The updated message document.
|
||||
* @throws {Error} If there is an error in updating the message or if the message is not found.
|
||||
*/
|
||||
async function updateMessage(req, message, metadata) {
|
||||
try {
|
||||
const { messageId, ...update } = message;
|
||||
update.isEdited = true;
|
||||
const updatedMessage = await Message.findOneAndUpdate(
|
||||
{ messageId, user: req.user.id },
|
||||
update,
|
||||
{
|
||||
new: true,
|
||||
},
|
||||
);
|
||||
|
||||
if (!updatedMessage) {
|
||||
throw new Error('Message not found or user not authorized.');
|
||||
}
|
||||
|
||||
return {
|
||||
messageId: updatedMessage.messageId,
|
||||
conversationId: updatedMessage.conversationId,
|
||||
parentMessageId: updatedMessage.parentMessageId,
|
||||
sender: updatedMessage.sender,
|
||||
text: updatedMessage.text,
|
||||
isCreatedByUser: updatedMessage.isCreatedByUser,
|
||||
tokenCount: updatedMessage.tokenCount,
|
||||
isEdited: true,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error updating message:', err);
|
||||
if (metadata && metadata?.context) {
|
||||
logger.info(`---\`updateMessage\` context: ${metadata.context}`);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages in a conversation since a specific message.
|
||||
*
|
||||
* @async
|
||||
* @function deleteMessagesSince
|
||||
* @param {Object} params - The parameters object.
|
||||
* @param {Object} req - The request object.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @returns {Promise<Number>} The number of deleted messages.
|
||||
* @throws {Error} If there is an error in deleting messages.
|
||||
*/
|
||||
async function deleteMessagesSince(req, { messageId, conversationId }) {
|
||||
try {
|
||||
const message = await Message.findOne({ messageId, user: req.user.id }).lean();
|
||||
|
||||
if (message) {
|
||||
const query = Message.find({ conversationId, user: req.user.id });
|
||||
return await query.deleteMany({
|
||||
createdAt: { $gt: message.createdAt },
|
||||
});
|
||||
}
|
||||
return undefined;
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves messages from the database.
|
||||
* @async
|
||||
* @function getMessages
|
||||
* @param {Record<string, unknown>} filter - The filter criteria.
|
||||
* @param {string | undefined} [select] - The fields to select.
|
||||
* @returns {Promise<TMessage[]>} The messages that match the filter criteria.
|
||||
* @throws {Error} If there is an error in retrieving messages.
|
||||
*/
|
||||
async function getMessages(filter, select) {
|
||||
try {
|
||||
if (select) {
|
||||
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||
}
|
||||
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes messages from the database.
|
||||
*
|
||||
* @async
|
||||
* @function deleteMessages
|
||||
* @param {Object} filter - The filter criteria to find messages to delete.
|
||||
* @returns {Promise<Number>} The number of deleted messages.
|
||||
* @throws {Error} If there is an error in deleting messages.
|
||||
*/
|
||||
async function deleteMessages(filter) {
|
||||
try {
|
||||
return await Message.deleteMany(filter);
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
Message,
|
||||
|
||||
async saveMessage({
|
||||
user,
|
||||
endpoint,
|
||||
iconURL,
|
||||
messageId,
|
||||
newMessageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
error,
|
||||
unfinished,
|
||||
files,
|
||||
isEdited,
|
||||
finish_reason,
|
||||
tokenCount,
|
||||
plugin,
|
||||
plugins,
|
||||
model,
|
||||
}) {
|
||||
try {
|
||||
const validConvoId = idSchema.safeParse(conversationId);
|
||||
if (!validConvoId.success) {
|
||||
return;
|
||||
}
|
||||
|
||||
const update = {
|
||||
user,
|
||||
iconURL,
|
||||
endpoint,
|
||||
messageId: newMessageId || messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
isEdited,
|
||||
finish_reason,
|
||||
error,
|
||||
unfinished,
|
||||
tokenCount,
|
||||
plugin,
|
||||
plugins,
|
||||
model,
|
||||
};
|
||||
|
||||
if (files) {
|
||||
update.files = files;
|
||||
}
|
||||
// may also need to update the conversation here
|
||||
await Message.findOneAndUpdate({ messageId }, update, { upsert: true, new: true });
|
||||
|
||||
return {
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender,
|
||||
text,
|
||||
isCreatedByUser,
|
||||
tokenCount,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
throw new Error('Failed to save message.');
|
||||
}
|
||||
},
|
||||
|
||||
async bulkSaveMessages(messages) {
|
||||
try {
|
||||
const bulkOps = messages.map((message) => ({
|
||||
updateOne: {
|
||||
filter: { messageId: message.messageId },
|
||||
update: message,
|
||||
upsert: true,
|
||||
},
|
||||
}));
|
||||
|
||||
const result = await Message.bulkWrite(bulkOps);
|
||||
return result;
|
||||
} catch (err) {
|
||||
logger.error('Error saving messages in bulk:', err);
|
||||
throw new Error('Failed to save messages in bulk.');
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Records a message in the database.
|
||||
*
|
||||
* @async
|
||||
* @function recordMessage
|
||||
* @param {Object} params - The message data object.
|
||||
* @param {string} params.user - The identifier of the user.
|
||||
* @param {string} params.endpoint - The endpoint where the message originated.
|
||||
* @param {string} params.messageId - The unique identifier for the message.
|
||||
* @param {string} params.conversationId - The identifier of the conversation.
|
||||
* @param {string} [params.parentMessageId] - The identifier of the parent message, if any.
|
||||
* @param {Partial<TMessage>} rest - Any additional properties from the TMessage typedef not explicitly listed.
|
||||
* @returns {Promise<Object>} The updated or newly inserted message document.
|
||||
* @throws {Error} If there is an error in saving the message.
|
||||
*/
|
||||
async recordMessage({ user, endpoint, messageId, conversationId, parentMessageId, ...rest }) {
|
||||
try {
|
||||
// No parsing of convoId as may use threadId
|
||||
const message = {
|
||||
user,
|
||||
endpoint,
|
||||
messageId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
...rest,
|
||||
};
|
||||
|
||||
return await Message.findOneAndUpdate({ user, messageId }, message, {
|
||||
upsert: true,
|
||||
new: true,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error('Error saving message:', err);
|
||||
throw new Error('Failed to save message.');
|
||||
}
|
||||
},
|
||||
async updateMessageText({ messageId, text }) {
|
||||
try {
|
||||
await Message.updateOne({ messageId }, { text });
|
||||
} catch (err) {
|
||||
logger.error('Error updating message text:', err);
|
||||
throw new Error('Failed to update message text.');
|
||||
}
|
||||
},
|
||||
async updateMessage(message) {
|
||||
try {
|
||||
const { messageId, ...update } = message;
|
||||
update.isEdited = true;
|
||||
const updatedMessage = await Message.findOneAndUpdate({ messageId }, update, {
|
||||
new: true,
|
||||
});
|
||||
|
||||
if (!updatedMessage) {
|
||||
throw new Error('Message not found.');
|
||||
}
|
||||
|
||||
return {
|
||||
messageId: updatedMessage.messageId,
|
||||
conversationId: updatedMessage.conversationId,
|
||||
parentMessageId: updatedMessage.parentMessageId,
|
||||
sender: updatedMessage.sender,
|
||||
text: updatedMessage.text,
|
||||
isCreatedByUser: updatedMessage.isCreatedByUser,
|
||||
tokenCount: updatedMessage.tokenCount,
|
||||
isEdited: true,
|
||||
};
|
||||
} catch (err) {
|
||||
logger.error('Error updating message:', err);
|
||||
throw new Error('Failed to update message.');
|
||||
}
|
||||
},
|
||||
async deleteMessagesSince({ messageId, conversationId }) {
|
||||
try {
|
||||
const message = await Message.findOne({ messageId }).lean();
|
||||
|
||||
if (message) {
|
||||
return await Message.find({ conversationId }).deleteMany({
|
||||
createdAt: { $gt: message.createdAt },
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw new Error('Failed to delete messages.');
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Retrieves messages from the database.
|
||||
* @param {Record<string, unknown>} filter
|
||||
* @param {string | undefined} [select]
|
||||
* @returns
|
||||
*/
|
||||
async getMessages(filter, select) {
|
||||
try {
|
||||
if (select) {
|
||||
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||
}
|
||||
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
throw new Error('Failed to get messages.');
|
||||
}
|
||||
},
|
||||
|
||||
async deleteMessages(filter) {
|
||||
try {
|
||||
return await Message.deleteMany(filter);
|
||||
} catch (err) {
|
||||
logger.error('Error deleting messages:', err);
|
||||
throw new Error('Failed to delete messages.');
|
||||
}
|
||||
},
|
||||
saveMessage,
|
||||
bulkSaveMessages,
|
||||
recordMessage,
|
||||
updateMessageText,
|
||||
updateMessage,
|
||||
deleteMessagesSince,
|
||||
getMessages,
|
||||
deleteMessages,
|
||||
};
|
||||
|
||||
239
api/models/Message.spec.js
Normal file
239
api/models/Message.spec.js
Normal file
@@ -0,0 +1,239 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
|
||||
jest.mock('mongoose');
|
||||
|
||||
const mockFindQuery = {
|
||||
select: jest.fn().mockReturnThis(),
|
||||
sort: jest.fn().mockReturnThis(),
|
||||
lean: jest.fn().mockReturnThis(),
|
||||
deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }),
|
||||
};
|
||||
|
||||
const mockSchema = {
|
||||
findOneAndUpdate: jest.fn(),
|
||||
updateOne: jest.fn(),
|
||||
findOne: jest.fn(() => ({
|
||||
lean: jest.fn(),
|
||||
})),
|
||||
find: jest.fn(() => mockFindQuery),
|
||||
deleteMany: jest.fn(),
|
||||
};
|
||||
|
||||
mongoose.model.mockReturnValue(mockSchema);
|
||||
|
||||
jest.mock('~/models/schema/messageSchema', () => mockSchema);
|
||||
|
||||
jest.mock('~/config/winston', () => ({
|
||||
error: jest.fn(),
|
||||
}));
|
||||
|
||||
const {
|
||||
saveMessage,
|
||||
getMessages,
|
||||
updateMessage,
|
||||
deleteMessages,
|
||||
updateMessageText,
|
||||
deleteMessagesSince,
|
||||
} = require('~/models/Message');
|
||||
|
||||
describe('Message Operations', () => {
|
||||
let mockReq;
|
||||
let mockMessage;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
mockReq = {
|
||||
user: { id: 'user123' },
|
||||
};
|
||||
|
||||
mockMessage = {
|
||||
messageId: 'msg123',
|
||||
conversationId: uuidv4(),
|
||||
text: 'Hello, world!',
|
||||
user: 'user123',
|
||||
};
|
||||
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue({
|
||||
toObject: () => mockMessage,
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveMessage', () => {
|
||||
it('should save a message for an authenticated user', async () => {
|
||||
const result = await saveMessage(mockReq, mockMessage);
|
||||
expect(result).toEqual(mockMessage);
|
||||
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
|
||||
{ messageId: 'msg123', user: 'user123' },
|
||||
expect.objectContaining({ user: 'user123' }),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error for unauthenticated user', async () => {
|
||||
mockReq.user = null;
|
||||
await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated');
|
||||
});
|
||||
|
||||
it('should throw an error for invalid conversation ID', async () => {
|
||||
mockMessage.conversationId = 'invalid-id';
|
||||
await expect(saveMessage(mockReq, mockMessage)).resolves.toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateMessageText', () => {
|
||||
it('should update message text for the authenticated user', async () => {
|
||||
await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' });
|
||||
expect(mockSchema.updateOne).toHaveBeenCalledWith(
|
||||
{ messageId: 'msg123', user: 'user123' },
|
||||
{ text: 'Updated text' },
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateMessage', () => {
|
||||
it('should update a message for the authenticated user', async () => {
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage);
|
||||
const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' });
|
||||
expect(result).toEqual(
|
||||
expect.objectContaining({
|
||||
messageId: 'msg123',
|
||||
text: 'Hello, world!',
|
||||
isEdited: true,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if message is not found', async () => {
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue(null);
|
||||
await expect(
|
||||
updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }),
|
||||
).rejects.toThrow('Message not found or user not authorized.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessagesSince', () => {
|
||||
it('should delete messages only for the authenticated user', async () => {
|
||||
mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() });
|
||||
mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 });
|
||||
const result = await deleteMessagesSince(mockReq, {
|
||||
messageId: 'msg123',
|
||||
conversationId: 'convo123',
|
||||
});
|
||||
expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' });
|
||||
expect(mockSchema.find).not.toHaveBeenCalled();
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if no message is found', async () => {
|
||||
mockSchema.findOne().lean.mockResolvedValueOnce(null);
|
||||
const result = await deleteMessagesSince(mockReq, {
|
||||
messageId: 'nonexistent',
|
||||
conversationId: 'convo123',
|
||||
});
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMessages', () => {
|
||||
it('should retrieve messages with the correct filter', async () => {
|
||||
const filter = { conversationId: 'convo123' };
|
||||
await getMessages(filter);
|
||||
expect(mockSchema.find).toHaveBeenCalledWith(filter);
|
||||
expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 });
|
||||
expect(mockFindQuery.lean).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMessages', () => {
|
||||
it('should delete messages with the correct filter', async () => {
|
||||
await deleteMessages({ user: 'user123' });
|
||||
expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('Conversation Hijacking Prevention', () => {
|
||||
it('should not allow editing a message in another user\'s conversation', async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
||||
mockSchema.findOneAndUpdate.mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
updateMessage(attackerReq, {
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
text: 'Hacked message',
|
||||
}),
|
||||
).rejects.toThrow('Message not found or user not authorized.');
|
||||
|
||||
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
|
||||
{ messageId: victimMessageId, user: 'attacker123' },
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not allow deleting messages from another user\'s conversation', async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
const victimMessageId = 'victim-msg-123';
|
||||
|
||||
mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user
|
||||
const result = await deleteMessagesSince(attackerReq, {
|
||||
messageId: victimMessageId,
|
||||
conversationId: victimConversationId,
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
expect(mockSchema.findOne).toHaveBeenCalledWith({
|
||||
messageId: victimMessageId,
|
||||
user: 'attacker123',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not allow inserting a new message into another user\'s conversation', async () => {
|
||||
const attackerReq = { user: { id: 'attacker123' } };
|
||||
const victimConversationId = uuidv4(); // Use a valid UUID
|
||||
|
||||
await expect(
|
||||
saveMessage(attackerReq, {
|
||||
conversationId: victimConversationId,
|
||||
text: 'Inserted malicious message',
|
||||
messageId: 'new-msg-123',
|
||||
}),
|
||||
).resolves.not.toThrow(); // It should not throw an error
|
||||
|
||||
// Check that the message was saved with the attacker's user ID
|
||||
expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith(
|
||||
{ messageId: 'new-msg-123', user: 'attacker123' },
|
||||
expect.objectContaining({
|
||||
user: 'attacker123',
|
||||
conversationId: victimConversationId,
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should allow retrieving messages from any conversation', async () => {
|
||||
const victimConversationId = 'victim-convo-123';
|
||||
|
||||
await getMessages({ conversationId: victimConversationId });
|
||||
|
||||
expect(mockSchema.find).toHaveBeenCalledWith({
|
||||
conversationId: victimConversationId,
|
||||
});
|
||||
|
||||
mockSchema.find.mockReturnValueOnce({
|
||||
select: jest.fn().mockReturnThis(),
|
||||
sort: jest.fn().mockReturnThis(),
|
||||
lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]),
|
||||
});
|
||||
|
||||
const result = await getMessages({ conversationId: victimConversationId });
|
||||
expect(result).toEqual([{ text: 'Test message' }]);
|
||||
});
|
||||
});
|
||||
});
|
||||
90
api/models/Project.js
Normal file
90
api/models/Project.js
Normal file
@@ -0,0 +1,90 @@
|
||||
const { model } = require('mongoose');
|
||||
const projectSchema = require('~/models/schema/projectSchema');
|
||||
|
||||
const Project = model('Project', projectSchema);
|
||||
|
||||
/**
|
||||
* Retrieve a project by ID and convert the found project document to a plain object.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to find and return as a plain object.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoProject>} A plain object representing the project document, or `null` if no project is found.
|
||||
*/
|
||||
const getProjectById = async function (projectId, fieldsToSelect = null) {
|
||||
const query = Project.findById(projectId);
|
||||
|
||||
if (fieldsToSelect) {
|
||||
query.select(fieldsToSelect);
|
||||
}
|
||||
|
||||
return await query.lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieve a project by name and convert the found project document to a plain object.
|
||||
* If the project with the given name doesn't exist and the name is "instance", create it and return the lean version.
|
||||
*
|
||||
* @param {string} projectName - The name of the project to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<MongoProject>} A plain object representing the project document.
|
||||
*/
|
||||
const getProjectByName = async function (projectName, fieldsToSelect = null) {
|
||||
const query = { name: projectName };
|
||||
const update = { $setOnInsert: { name: projectName } };
|
||||
const options = {
|
||||
new: true,
|
||||
upsert: projectName === 'instance',
|
||||
lean: true,
|
||||
select: fieldsToSelect,
|
||||
};
|
||||
|
||||
return await Project.findOneAndUpdate(query, update, options);
|
||||
};
|
||||
|
||||
/**
|
||||
* Add an array of prompt group IDs to a project's promptGroupIds array, ensuring uniqueness.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const addGroupIdsToProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $addToSet: { promptGroupIds: { $each: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove an array of prompt group IDs from a project's promptGroupIds array.
|
||||
*
|
||||
* @param {string} projectId - The ID of the project to update.
|
||||
* @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project.
|
||||
* @returns {Promise<MongoProject>} The updated project document.
|
||||
*/
|
||||
const removeGroupIdsFromProject = async function (projectId, promptGroupIds) {
|
||||
return await Project.findByIdAndUpdate(
|
||||
projectId,
|
||||
{ $pull: { promptGroupIds: { $in: promptGroupIds } } },
|
||||
{ new: true },
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Remove a prompt group ID from all projects.
|
||||
*
|
||||
* @param {string} promptGroupId - The ID of the prompt group to remove from projects.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const removeGroupFromAllProjects = async (promptGroupId) => {
|
||||
await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } });
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getProjectById,
|
||||
getProjectByName,
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
};
|
||||
@@ -1,52 +1,528 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { ObjectId } = require('mongodb');
|
||||
const { SystemRoles, SystemCategories } = require('librechat-data-provider');
|
||||
const {
|
||||
getProjectByName,
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
removeGroupFromAllProjects,
|
||||
} = require('./Project');
|
||||
const { Prompt, PromptGroup } = require('./schema/promptSchema');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const promptSchema = mongoose.Schema(
|
||||
{
|
||||
title: {
|
||||
type: String,
|
||||
required: true,
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get prompt groups
|
||||
* @param {Object} query
|
||||
* @param {number} skip
|
||||
* @param {number} limit
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createGroupPipeline = (query, skip, limit) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{ $skip: skip },
|
||||
{ $limit: limit },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
prompt: {
|
||||
type: String,
|
||||
required: true,
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project: {
|
||||
name: 1,
|
||||
numberOfGenerations: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
projectIds: 1,
|
||||
productionId: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
// 'productionPrompt._id': 1,
|
||||
// 'productionPrompt.type': 1,
|
||||
},
|
||||
},
|
||||
category: {
|
||||
type: String,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
];
|
||||
};
|
||||
|
||||
const Prompt = mongoose.models.Prompt || mongoose.model('Prompt', promptSchema);
|
||||
/**
|
||||
* Create a pipeline for the aggregation to get all prompt groups
|
||||
* @param {Object} query
|
||||
* @param {Partial<MongoPromptGroup>} $project
|
||||
* @returns {[Object]} - The pipeline for the aggregation
|
||||
*/
|
||||
const createAllGroupsPipeline = (
|
||||
query,
|
||||
$project = {
|
||||
name: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
command: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
},
|
||||
) => {
|
||||
return [
|
||||
{ $match: query },
|
||||
{ $sort: { createdAt: -1 } },
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project,
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
/**
|
||||
* Get all prompt groups with filters
|
||||
* @param {Object} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getAllPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { name, ...query } = filter;
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(name, 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
const project = await getProjectByName('instance', 'promptGroupIds');
|
||||
if (project && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const promptGroupsPipeline = createAllGroupsPipeline(combinedQuery);
|
||||
return await PromptGroup.aggregate(promptGroupsPipeline).exec();
|
||||
} catch (error) {
|
||||
console.error('Error getting all prompt groups', error);
|
||||
return { message: 'Error getting all prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {Object} req
|
||||
* @param {TPromptGroupsWithFilterRequest} filter
|
||||
* @returns {Promise<PromptGroupListResponse>}
|
||||
*/
|
||||
const getPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { pageNumber = 1, pageSize = 10, name, ...query } = filter;
|
||||
|
||||
const validatedPageNumber = Math.max(parseInt(pageNumber, 10), 1);
|
||||
const validatedPageSize = Math.max(parseInt(pageSize, 10), 1);
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
query.name = new RegExp(name, 'i');
|
||||
}
|
||||
if (!query.category) {
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.MY_PROMPTS) {
|
||||
searchShared = false;
|
||||
delete query.category;
|
||||
} else if (query.category === SystemCategories.NO_CATEGORY) {
|
||||
query.category = '';
|
||||
} else if (query.category === SystemCategories.SHARED_PROMPTS) {
|
||||
searchSharedOnly = true;
|
||||
delete query.category;
|
||||
}
|
||||
|
||||
let combinedQuery = query;
|
||||
|
||||
if (searchShared) {
|
||||
// const projects = req.user.projects || []; // TODO: handle multiple projects
|
||||
const project = await getProjectByName('instance', 'promptGroupIds');
|
||||
if (project && project.promptGroupIds.length > 0) {
|
||||
const projectQuery = { _id: { $in: project.promptGroupIds }, ...query };
|
||||
delete projectQuery.author;
|
||||
combinedQuery = searchSharedOnly ? projectQuery : { $or: [projectQuery, query] };
|
||||
}
|
||||
}
|
||||
|
||||
const skip = (validatedPageNumber - 1) * validatedPageSize;
|
||||
const limit = validatedPageSize;
|
||||
|
||||
const promptGroupsPipeline = createGroupPipeline(combinedQuery, skip, limit);
|
||||
const totalPromptGroupsPipeline = [{ $match: combinedQuery }, { $count: 'total' }];
|
||||
|
||||
const [promptGroupsResults, totalPromptGroupsResults] = await Promise.all([
|
||||
PromptGroup.aggregate(promptGroupsPipeline).exec(),
|
||||
PromptGroup.aggregate(totalPromptGroupsPipeline).exec(),
|
||||
]);
|
||||
|
||||
const promptGroups = promptGroupsResults;
|
||||
const totalPromptGroups =
|
||||
totalPromptGroupsResults.length > 0 ? totalPromptGroupsResults[0].total : 0;
|
||||
|
||||
return {
|
||||
promptGroups,
|
||||
pageNumber: validatedPageNumber.toString(),
|
||||
pageSize: validatedPageSize.toString(),
|
||||
pages: Math.ceil(totalPromptGroups / validatedPageSize).toString(),
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
savePrompt: async ({ title, prompt }) => {
|
||||
getPromptGroups,
|
||||
getAllPromptGroups,
|
||||
/**
|
||||
* Create a prompt and its respective group
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
* @returns {Promise<TCreatePromptResponse>}
|
||||
*/
|
||||
createPromptGroup: async (saveData) => {
|
||||
try {
|
||||
await Prompt.create({
|
||||
title,
|
||||
prompt,
|
||||
});
|
||||
return { title, prompt };
|
||||
const { prompt, group, author, authorName } = saveData;
|
||||
|
||||
let newPromptGroup = await PromptGroup.findOneAndUpdate(
|
||||
{ ...group, author, authorName, productionId: null },
|
||||
{ $setOnInsert: { ...group, author, authorName, productionId: null } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
const newPrompt = await Prompt.findOneAndUpdate(
|
||||
{ ...prompt, author, groupId: newPromptGroup._id },
|
||||
{ $setOnInsert: { ...prompt, author, groupId: newPromptGroup._id } },
|
||||
{ new: true, upsert: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
newPromptGroup = await PromptGroup.findByIdAndUpdate(
|
||||
newPromptGroup._id,
|
||||
{ productionId: newPrompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.select('-__v')
|
||||
.exec();
|
||||
|
||||
return {
|
||||
prompt: newPrompt,
|
||||
group: {
|
||||
...newPromptGroup,
|
||||
productionPrompt: { prompt: newPrompt.prompt },
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt group', error);
|
||||
throw new Error('Error saving prompt group');
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Save a prompt
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
* @returns {Promise<TCreatePromptResponse>}
|
||||
*/
|
||||
savePrompt: async (saveData) => {
|
||||
try {
|
||||
const { prompt, author } = saveData;
|
||||
const newPromptData = {
|
||||
...prompt,
|
||||
author,
|
||||
};
|
||||
|
||||
/** @type {TPrompt} */
|
||||
let newPrompt;
|
||||
try {
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
} catch (error) {
|
||||
if (error?.message?.includes('groupId_1_version_1')) {
|
||||
await Prompt.db.collection('prompts').dropIndex('groupId_1_version_1');
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
newPrompt = await Prompt.create(newPromptData);
|
||||
}
|
||||
|
||||
return { prompt: newPrompt };
|
||||
} catch (error) {
|
||||
logger.error('Error saving prompt', error);
|
||||
return { prompt: 'Error saving prompt' };
|
||||
return { message: 'Error saving prompt' };
|
||||
}
|
||||
},
|
||||
getPrompts: async (filter) => {
|
||||
try {
|
||||
return await Prompt.find(filter).lean();
|
||||
return await Prompt.find(filter).sort({ createdAt: -1 }).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompts', error);
|
||||
return { prompt: 'Error getting prompts' };
|
||||
return { message: 'Error getting prompts' };
|
||||
}
|
||||
},
|
||||
deletePrompts: async (filter) => {
|
||||
getPrompt: async (filter) => {
|
||||
try {
|
||||
return await Prompt.deleteMany(filter);
|
||||
if (filter.groupId) {
|
||||
filter.groupId = new ObjectId(filter.groupId);
|
||||
}
|
||||
return await Prompt.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error deleting prompts', error);
|
||||
return { prompt: 'Error deleting prompts' };
|
||||
logger.error('Error getting prompt', error);
|
||||
return { message: 'Error getting prompt' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Get prompt groups with filters
|
||||
* @param {TGetRandomPromptsRequest} filter
|
||||
* @returns {Promise<TGetRandomPromptsResponse>}
|
||||
*/
|
||||
getRandomPromptGroups: async (filter) => {
|
||||
try {
|
||||
const result = await PromptGroup.aggregate([
|
||||
{
|
||||
$match: {
|
||||
category: { $ne: '' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$group: {
|
||||
_id: '$category',
|
||||
promptGroup: { $first: '$$ROOT' },
|
||||
},
|
||||
},
|
||||
{
|
||||
$replaceRoot: { newRoot: '$promptGroup' },
|
||||
},
|
||||
{
|
||||
$sample: { size: +filter.limit + +filter.skip },
|
||||
},
|
||||
{
|
||||
$skip: +filter.skip,
|
||||
},
|
||||
{
|
||||
$limit: +filter.limit,
|
||||
},
|
||||
]);
|
||||
return { prompts: result };
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroupsWithPrompts: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter)
|
||||
.populate({
|
||||
path: 'prompts',
|
||||
select: '-_id -__v -user',
|
||||
})
|
||||
.select('-_id -__v -user')
|
||||
.lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt groups', error);
|
||||
return { message: 'Error getting prompt groups' };
|
||||
}
|
||||
},
|
||||
getPromptGroup: async (filter) => {
|
||||
try {
|
||||
return await PromptGroup.findOne(filter).lean();
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
return { message: 'Error getting prompt group' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Deletes a prompt and its corresponding prompt group if it is the last prompt in the group.
|
||||
*
|
||||
* @param {Object} options - The options for deleting the prompt.
|
||||
* @param {ObjectId|string} options.promptId - The ID of the prompt to delete.
|
||||
* @param {ObjectId|string} options.groupId - The ID of the prompt's group.
|
||||
* @param {ObjectId|string} options.author - The ID of the prompt's author.
|
||||
* @param {string} options.role - The role of the prompt's author.
|
||||
* @return {Promise<TDeletePromptResponse>} An object containing the result of the deletion.
|
||||
* If the prompt was deleted successfully, the object will have a property 'prompt' with the value 'Prompt deleted successfully'.
|
||||
* If the prompt group was deleted successfully, the object will have a property 'promptGroup' with the message 'Prompt group deleted successfully' and id of the deleted group.
|
||||
* If there was an error deleting the prompt, the object will have a property 'message' with the value 'Error deleting prompt'.
|
||||
*/
|
||||
deletePrompt: async ({ promptId, groupId, author, role }) => {
|
||||
const query = { _id: promptId, groupId, author };
|
||||
if (role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const { deletedCount } = await Prompt.deleteOne(query);
|
||||
if (deletedCount === 0) {
|
||||
throw new Error('Failed to delete the prompt');
|
||||
}
|
||||
|
||||
const remainingPrompts = await Prompt.find({ groupId })
|
||||
.select('_id')
|
||||
.sort({ createdAt: 1 })
|
||||
.lean();
|
||||
|
||||
if (remainingPrompts.length === 0) {
|
||||
await PromptGroup.deleteOne({ _id: groupId });
|
||||
await removeGroupFromAllProjects(groupId);
|
||||
|
||||
return {
|
||||
prompt: 'Prompt deleted successfully',
|
||||
promptGroup: {
|
||||
message: 'Prompt group deleted successfully',
|
||||
id: groupId,
|
||||
},
|
||||
};
|
||||
} else {
|
||||
const promptGroup = await PromptGroup.findById(groupId).lean();
|
||||
if (promptGroup.productionId.toString() === promptId.toString()) {
|
||||
await PromptGroup.updateOne(
|
||||
{ _id: groupId },
|
||||
{ productionId: remainingPrompts[remainingPrompts.length - 1]._id },
|
||||
);
|
||||
}
|
||||
|
||||
return { prompt: 'Prompt deleted successfully' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Update prompt group
|
||||
* @param {Partial<MongoPromptGroup>} filter - Filter to find prompt group
|
||||
* @param {Partial<MongoPromptGroup>} data - Data to update
|
||||
* @returns {Promise<TUpdatePromptGroupResponse>}
|
||||
*/
|
||||
updatePromptGroup: async (filter, data) => {
|
||||
try {
|
||||
const updateOps = {};
|
||||
if (data.removeProjectIds) {
|
||||
for (const projectId of data.removeProjectIds) {
|
||||
await removeGroupIdsFromProject(projectId, [filter._id]);
|
||||
}
|
||||
|
||||
updateOps.$pull = { projectIds: { $in: data.removeProjectIds } };
|
||||
delete data.removeProjectIds;
|
||||
}
|
||||
|
||||
if (data.projectIds) {
|
||||
for (const projectId of data.projectIds) {
|
||||
await addGroupIdsToProject(projectId, [filter._id]);
|
||||
}
|
||||
|
||||
updateOps.$addToSet = { projectIds: { $each: data.projectIds } };
|
||||
delete data.projectIds;
|
||||
}
|
||||
|
||||
const updateData = { ...data, ...updateOps };
|
||||
const updatedDoc = await PromptGroup.findOneAndUpdate(filter, updateData, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
|
||||
if (!updatedDoc) {
|
||||
throw new Error('Prompt group not found');
|
||||
}
|
||||
|
||||
return updatedDoc;
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt group', error);
|
||||
return { message: 'Error updating prompt group' };
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Function to make a prompt production based on its ID.
|
||||
* @param {String} promptId - The ID of the prompt to make production.
|
||||
* @returns {Object} The result of the production operation.
|
||||
*/
|
||||
makePromptProduction: async (promptId) => {
|
||||
try {
|
||||
const prompt = await Prompt.findById(promptId).lean();
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('Prompt not found');
|
||||
}
|
||||
|
||||
await PromptGroup.findByIdAndUpdate(
|
||||
prompt.groupId,
|
||||
{ productionId: prompt._id },
|
||||
{ new: true },
|
||||
)
|
||||
.lean()
|
||||
.exec();
|
||||
|
||||
return {
|
||||
message: 'Prompt production made successfully',
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('Error making prompt production', error);
|
||||
return { message: 'Error making prompt production' };
|
||||
}
|
||||
},
|
||||
updatePromptLabels: async (_id, labels) => {
|
||||
try {
|
||||
const response = await Prompt.updateOne({ _id }, { $set: { labels } });
|
||||
if (response.matchedCount === 0) {
|
||||
return { message: 'Prompt not found' };
|
||||
}
|
||||
return { message: 'Prompt labels updated successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error updating prompt labels', error);
|
||||
return { message: 'Error updating prompt labels' };
|
||||
}
|
||||
},
|
||||
deletePromptGroup: async (_id) => {
|
||||
try {
|
||||
const response = await PromptGroup.deleteOne({ _id });
|
||||
|
||||
if (response.deletedCount === 0) {
|
||||
return { promptGroup: 'Prompt group not found' };
|
||||
}
|
||||
|
||||
await Prompt.deleteMany({ groupId: new ObjectId(_id) });
|
||||
await removeGroupFromAllProjects(_id);
|
||||
return { promptGroup: 'Prompt group deleted successfully' };
|
||||
} catch (error) {
|
||||
logger.error('Error deleting prompt group', error);
|
||||
return { message: 'Error deleting prompt group' };
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
86
api/models/Role.js
Normal file
86
api/models/Role.js
Normal file
@@ -0,0 +1,86 @@
|
||||
const { SystemRoles, CacheKeys, roleDefaults } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const Role = require('~/models/schema/roleSchema');
|
||||
|
||||
/**
|
||||
* Retrieve a role by name and convert the found role document to a plain object.
|
||||
* If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version.
|
||||
*
|
||||
* @param {string} roleName - The name of the role to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<Object>} A plain object representing the role document.
|
||||
*/
|
||||
const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const cachedRole = await cache.get(roleName);
|
||||
if (cachedRole) {
|
||||
return cachedRole;
|
||||
}
|
||||
let query = Role.findOne({ name: roleName });
|
||||
if (fieldsToSelect) {
|
||||
query = query.select(fieldsToSelect);
|
||||
}
|
||||
let role = await query.lean().exec();
|
||||
|
||||
if (!role && SystemRoles[roleName]) {
|
||||
role = roleDefaults[roleName];
|
||||
role = await new Role(role).save();
|
||||
await cache.set(roleName, role);
|
||||
return role.toObject();
|
||||
}
|
||||
await cache.set(roleName, role);
|
||||
return role;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to retrieve or create role: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Update role values by name.
|
||||
*
|
||||
* @param {string} roleName - The name of the role to update.
|
||||
* @param {Partial<TRole>} updates - The fields to update.
|
||||
* @returns {Promise<TRole>} Updated role document.
|
||||
*/
|
||||
const updateRoleByName = async function (roleName, updates) {
|
||||
try {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
const role = await Role.findOneAndUpdate(
|
||||
{ name: roleName },
|
||||
{ $set: updates },
|
||||
{ new: true, lean: true },
|
||||
)
|
||||
.select('-__v')
|
||||
.lean()
|
||||
.exec();
|
||||
await cache.set(roleName, role);
|
||||
return role;
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to update role: ${error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Initialize default roles in the system.
|
||||
* Creates the default roles (ADMIN, USER) if they don't exist in the database.
|
||||
*
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const initializeRoles = async function () {
|
||||
const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER];
|
||||
|
||||
for (const roleName of defaultRoles) {
|
||||
let role = await Role.findOne({ name: roleName }).select('name').lean();
|
||||
if (!role) {
|
||||
role = new Role(roleDefaults[roleName]);
|
||||
await role.save();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getRoleByName,
|
||||
initializeRoles,
|
||||
updateRoleByName,
|
||||
};
|
||||
@@ -22,7 +22,7 @@ module.exports = {
|
||||
return share;
|
||||
} catch (error) {
|
||||
logger.error('[getShare] Error getting share link', error);
|
||||
return { message: 'Error getting share link' };
|
||||
throw new Error('Error getting share link');
|
||||
}
|
||||
},
|
||||
|
||||
@@ -41,17 +41,17 @@ module.exports = {
|
||||
return { sharedLinks: shares, pages: totalPages, pageNumber, pageSize };
|
||||
} catch (error) {
|
||||
logger.error('[getShareByPage] Error getting shares', error);
|
||||
return { message: 'Error getting shares' };
|
||||
throw new Error('Error getting shares');
|
||||
}
|
||||
},
|
||||
|
||||
createSharedLink: async (user, { conversationId, ...shareData }) => {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (share) {
|
||||
return share;
|
||||
}
|
||||
|
||||
try {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (share) {
|
||||
return share;
|
||||
}
|
||||
|
||||
const shareId = crypto.randomUUID();
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...shareData, shareId, messages, user };
|
||||
@@ -60,31 +60,42 @@ module.exports = {
|
||||
upsert: true,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[saveShareMessage] Error saving conversation', error);
|
||||
return { message: 'Error saving conversation' };
|
||||
logger.error('[createSharedLink] Error creating shared link', error);
|
||||
throw new Error('Error creating shared link');
|
||||
}
|
||||
},
|
||||
|
||||
updateSharedLink: async (user, { conversationId, ...shareData }) => {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
try {
|
||||
const share = await SharedLink.findOne({ conversationId }).select('-_id -__v -user').lean();
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
}
|
||||
|
||||
// update messages to the latest
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...shareData, messages, user };
|
||||
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[updateSharedLink] Error updating shared link', error);
|
||||
throw new Error('Error updating shared link');
|
||||
}
|
||||
// update messages to the latest
|
||||
const messages = await getMessages({ conversationId });
|
||||
const update = { ...shareData, messages, user };
|
||||
return await SharedLink.findOneAndUpdate({ conversationId: conversationId, user }, update, {
|
||||
new: true,
|
||||
upsert: false,
|
||||
});
|
||||
},
|
||||
|
||||
deleteSharedLink: async (user, { shareId }) => {
|
||||
const share = await SharedLink.findOne({ shareId, user });
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
try {
|
||||
const share = await SharedLink.findOne({ shareId, user });
|
||||
if (!share) {
|
||||
return { message: 'Share not found' };
|
||||
}
|
||||
return await SharedLink.findOneAndDelete({ shareId, user });
|
||||
} catch (error) {
|
||||
logger.error('[deleteSharedLink] Error deleting shared link', error);
|
||||
throw new Error('Error deleting shared link');
|
||||
}
|
||||
return await SharedLink.findOneAndDelete({ shareId, user });
|
||||
},
|
||||
/**
|
||||
* Deletes all shared links for a specific user.
|
||||
@@ -100,7 +111,7 @@ module.exports = {
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllSharedLinks] Error deleting shared links', error);
|
||||
return { message: 'Error deleting shared links' };
|
||||
throw new Error('Error deleting shared links');
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
19
api/models/schema/categories.js
Normal file
19
api/models/schema/categories.js
Normal file
@@ -0,0 +1,19 @@
|
||||
const mongoose = require('mongoose');
|
||||
const Schema = mongoose.Schema;
|
||||
|
||||
const categoriesSchema = new Schema({
|
||||
label: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
},
|
||||
value: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
},
|
||||
});
|
||||
|
||||
const categories = mongoose.model('categories', categoriesSchema);
|
||||
|
||||
module.exports = { Categories: categories };
|
||||
31
api/models/schema/conversationTagSchema.js
Normal file
31
api/models/schema/conversationTagSchema.js
Normal file
@@ -0,0 +1,31 @@
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
const conversationTagSchema = mongoose.Schema(
|
||||
{
|
||||
tag: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
user: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
description: {
|
||||
type: String,
|
||||
index: true,
|
||||
},
|
||||
count: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
position: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
conversationTagSchema.index({ tag: 1, user: 1 }, { unique: true });
|
||||
|
||||
module.exports = mongoose.model('ConversationTag', conversationTagSchema);
|
||||
@@ -42,6 +42,11 @@ const convoSchema = mongoose.Schema(
|
||||
invocationId: {
|
||||
type: Number,
|
||||
},
|
||||
tags: {
|
||||
type: [String],
|
||||
default: [],
|
||||
meiliIndex: true,
|
||||
},
|
||||
},
|
||||
{ timestamps: true },
|
||||
);
|
||||
|
||||
@@ -103,6 +103,10 @@ const conversationPreset = {
|
||||
spec: {
|
||||
type: String,
|
||||
},
|
||||
tags: {
|
||||
type: [String],
|
||||
default: [],
|
||||
},
|
||||
tools: { type: [{ type: String }], default: undefined },
|
||||
maxContextTokens: {
|
||||
type: Number,
|
||||
|
||||
@@ -129,6 +129,7 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) {
|
||||
}
|
||||
|
||||
messageSchema.index({ createdAt: 1 });
|
||||
messageSchema.index({ messageId: 1, user: 1 }, { unique: true });
|
||||
|
||||
const Message = mongoose.models.Message || mongoose.model('Message', messageSchema);
|
||||
|
||||
|
||||
30
api/models/schema/projectSchema.js
Normal file
30
api/models/schema/projectSchema.js
Normal file
@@ -0,0 +1,30 @@
|
||||
const { Schema } = require('mongoose');
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoProject
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} name - The name of the project
|
||||
* @property {ObjectId[]} promptGroupIds - Array of PromptGroup IDs associated with the project
|
||||
* @property {Date} [createdAt] - Date when the project was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the project was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
const projectSchema = new Schema(
|
||||
{
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
promptGroupIds: {
|
||||
type: [Schema.Types.ObjectId],
|
||||
ref: 'PromptGroup',
|
||||
default: [],
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = projectSchema;
|
||||
118
api/models/schema/promptSchema.js
Normal file
118
api/models/schema/promptSchema.js
Normal file
@@ -0,0 +1,118 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
const Schema = mongoose.Schema;
|
||||
|
||||
/**
|
||||
* @typedef {Object} MongoPromptGroup
|
||||
* @property {ObjectId} [_id] - MongoDB Document ID
|
||||
* @property {string} name - The name of the prompt group
|
||||
* @property {ObjectId} author - The author of the prompt group
|
||||
* @property {ObjectId} [projectId=null] - The project ID of the prompt group
|
||||
* @property {ObjectId} [productionId=null] - The project ID of the prompt group
|
||||
* @property {string} authorName - The name of the author of the prompt group
|
||||
* @property {number} [numberOfGenerations=0] - Number of generations the prompt group has
|
||||
* @property {string} [oneliner=''] - Oneliner description of the prompt group
|
||||
* @property {string} [category=''] - Category of the prompt group
|
||||
* @property {string} [command] - Command for the prompt group
|
||||
* @property {Date} [createdAt] - Date when the prompt group was created (added by timestamps)
|
||||
* @property {Date} [updatedAt] - Date when the prompt group was last updated (added by timestamps)
|
||||
*/
|
||||
|
||||
const promptGroupSchema = new Schema(
|
||||
{
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
numberOfGenerations: {
|
||||
type: Number,
|
||||
default: 0,
|
||||
},
|
||||
oneliner: {
|
||||
type: String,
|
||||
default: '',
|
||||
},
|
||||
category: {
|
||||
type: String,
|
||||
default: '',
|
||||
index: true,
|
||||
},
|
||||
projectIds: {
|
||||
type: [Schema.Types.ObjectId],
|
||||
ref: 'Project',
|
||||
index: true,
|
||||
},
|
||||
productionId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'Prompt',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
author: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
authorName: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
command: {
|
||||
type: String,
|
||||
index: true,
|
||||
validate: {
|
||||
validator: function (v) {
|
||||
return v === undefined || v === null || v === '' || /^[a-z0-9-]+$/.test(v);
|
||||
},
|
||||
message: (props) =>
|
||||
`${props.value} is not a valid command. Only lowercase alphanumeric characters and highfins (') are allowed.`,
|
||||
},
|
||||
maxlength: [
|
||||
Constants.COMMANDS_MAX_LENGTH,
|
||||
`Command cannot be longer than ${Constants.COMMANDS_MAX_LENGTH} characters`,
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema);
|
||||
|
||||
const promptSchema = new Schema(
|
||||
{
|
||||
groupId: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'PromptGroup',
|
||||
required: true,
|
||||
index: true,
|
||||
},
|
||||
author: {
|
||||
type: Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
},
|
||||
prompt: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
type: {
|
||||
type: String,
|
||||
enum: ['text', 'chat'],
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
timestamps: true,
|
||||
},
|
||||
);
|
||||
|
||||
const Prompt = mongoose.model('Prompt', promptSchema);
|
||||
|
||||
promptSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
promptGroupSchema.index({ createdAt: 1, updatedAt: 1 });
|
||||
|
||||
module.exports = { Prompt, PromptGroup };
|
||||
29
api/models/schema/roleSchema.js
Normal file
29
api/models/schema/roleSchema.js
Normal file
@@ -0,0 +1,29 @@
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const mongoose = require('mongoose');
|
||||
|
||||
const roleSchema = new mongoose.Schema({
|
||||
name: {
|
||||
type: String,
|
||||
required: true,
|
||||
unique: true,
|
||||
index: true,
|
||||
},
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: {
|
||||
type: Boolean,
|
||||
default: false,
|
||||
},
|
||||
[Permissions.USE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
[Permissions.CREATE]: {
|
||||
type: Boolean,
|
||||
default: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const Role = mongoose.model('Role', roleSchema);
|
||||
|
||||
module.exports = Role;
|
||||
@@ -12,11 +12,13 @@ const tokenValues = {
|
||||
'4k': { prompt: 1.5, completion: 2 },
|
||||
'16k': { prompt: 3, completion: 4 },
|
||||
'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 },
|
||||
'gpt-4o-mini': { prompt: 0.15, completion: 0.6 },
|
||||
'gpt-4o': { prompt: 5, completion: 15 },
|
||||
'gpt-4-1106': { prompt: 10, completion: 30 },
|
||||
'gpt-3.5-turbo-0125': { prompt: 0.5, completion: 1.5 },
|
||||
'claude-3-opus': { prompt: 15, completion: 75 },
|
||||
'claude-3-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-5-sonnet': { prompt: 3, completion: 15 },
|
||||
'claude-3-haiku': { prompt: 0.25, completion: 1.25 },
|
||||
'claude-2.1': { prompt: 8, completion: 24 },
|
||||
'claude-2': { prompt: 8, completion: 24 },
|
||||
@@ -53,6 +55,8 @@ const getValueKey = (model, endpoint) => {
|
||||
return 'gpt-3.5-turbo-1106';
|
||||
} else if (modelName.includes('gpt-3.5')) {
|
||||
return '4k';
|
||||
} else if (modelName.includes('gpt-4o-mini')) {
|
||||
return 'gpt-4o-mini';
|
||||
} else if (modelName.includes('gpt-4o')) {
|
||||
return 'gpt-4o';
|
||||
} else if (modelName.includes('gpt-4-vision')) {
|
||||
|
||||
@@ -48,6 +48,19 @@ describe('getValueKey', () => {
|
||||
expect(getValueKey('gpt-4o-turbo')).toBe('gpt-4o');
|
||||
expect(getValueKey('gpt-4o-0125')).toBe('gpt-4o');
|
||||
});
|
||||
|
||||
it('should return "gpt-4o-mini" for model type of "gpt-4o-mini"', () => {
|
||||
expect(getValueKey('gpt-4o-mini-2024-07-18')).toBe('gpt-4o-mini');
|
||||
expect(getValueKey('openai/gpt-4o-mini')).toBe('gpt-4o-mini');
|
||||
expect(getValueKey('gpt-4o-mini-0718')).toBe('gpt-4o-mini');
|
||||
});
|
||||
|
||||
it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => {
|
||||
expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('claude-3-5-sonnet-turbo')).toBe('claude-3-5-sonnet');
|
||||
expect(getValueKey('claude-3-5-sonnet-0125')).toBe('claude-3-5-sonnet');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getMultiplier', () => {
|
||||
@@ -102,6 +115,19 @@ describe('getMultiplier', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct multiplier for gpt-4o-mini', () => {
|
||||
const valueKey = getValueKey('gpt-4o-mini-2024-07-18');
|
||||
expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(
|
||||
tokenValues['gpt-4o-mini'].prompt,
|
||||
);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe(
|
||||
tokenValues['gpt-4o-mini'].completion,
|
||||
);
|
||||
expect(getMultiplier({ valueKey, tokenType: 'completion' })).not.toBe(
|
||||
tokenValues['gpt-4-1106'].completion,
|
||||
);
|
||||
});
|
||||
|
||||
it('should derive the valueKey from the model if not provided for new models', () => {
|
||||
expect(
|
||||
getMultiplier({ tokenType: 'prompt', model: 'gpt-3.5-turbo-1106-some-other-info' }),
|
||||
|
||||
@@ -56,10 +56,11 @@ const updateUser = async function (userId, updateData) {
|
||||
* Creates a new user, optionally with a TTL of 1 week.
|
||||
* @param {MongoUser} data - The user data to be created, must contain user_id.
|
||||
* @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`.
|
||||
* @param {boolean} [returnUser=false] - Whether to disable the TTL. Defaults to `true`.
|
||||
* @returns {Promise<ObjectId>} A promise that resolves to the created user document ID.
|
||||
* @throws {Error} If a user with the same user_id already exists.
|
||||
*/
|
||||
const createUser = async (data, disableTTL = true) => {
|
||||
const createUser = async (data, disableTTL = true, returnUser = false) => {
|
||||
const userData = {
|
||||
...data,
|
||||
expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds
|
||||
@@ -69,17 +70,11 @@ const createUser = async (data, disableTTL = true) => {
|
||||
delete userData.expiresAt;
|
||||
}
|
||||
|
||||
try {
|
||||
const user = await User.create(userData);
|
||||
return user._id;
|
||||
} catch (error) {
|
||||
if (error.code === 11000) {
|
||||
// Duplicate key error code
|
||||
throw new Error(`User with \`_id\` ${data._id} already exists.`);
|
||||
} else {
|
||||
throw error;
|
||||
}
|
||||
const user = await User.create(userData);
|
||||
if (returnUser) {
|
||||
return user.toObject();
|
||||
}
|
||||
return user._id;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@librechat/backend",
|
||||
"version": "0.7.3",
|
||||
"version": "0.7.4-rc1",
|
||||
"description": "",
|
||||
"scripts": {
|
||||
"start": "echo 'please run this from the root directory'",
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
const throttle = require('lodash/throttle');
|
||||
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
|
||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
@@ -18,6 +19,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
logger.debug('[AskController]', { text, conversationId, ...endpointOption });
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
@@ -34,6 +36,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
@@ -48,11 +52,13 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
|
||||
try {
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||
onProgress: throttle(
|
||||
({ text: partialText }) => {
|
||||
saveMessage({
|
||||
/*
|
||||
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
|
||||
messageCache.set(responseMessageId, {
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
@@ -62,7 +68,10 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
unfinished,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
}, Time.FIVE_MINUTES);
|
||||
*/
|
||||
|
||||
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
|
||||
},
|
||||
3000,
|
||||
{ trailing: false },
|
||||
@@ -74,6 +83,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
const getAbortData = () => ({
|
||||
sender,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
@@ -81,7 +91,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
res.on('close', () => {
|
||||
logger.debug('[AskController] Request closed');
|
||||
@@ -108,7 +118,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
};
|
||||
@@ -121,7 +130,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
|
||||
response.endpoint = endpointOption.endpoint;
|
||||
|
||||
const conversation = await getConvo(user, conversationId);
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
@@ -141,10 +150,18 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...response, user },
|
||||
{ context: 'api/server/controllers/AskController.js - response end' },
|
||||
);
|
||||
}
|
||||
|
||||
await saveMessage(userMessage);
|
||||
if (!client.skipSaveUserMessage) {
|
||||
await saveMessage(req, userMessage, {
|
||||
context: 'api/server/controllers/AskController.js - don\'t skip saving user message',
|
||||
});
|
||||
}
|
||||
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
addTitle(req, {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
const throttle = require('lodash/throttle');
|
||||
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
|
||||
const { createAbortController, handleAbortError } = require('~/server/middleware');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const EditController = async (req, res, next, initializeClient) => {
|
||||
@@ -27,6 +28,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
});
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
@@ -40,6 +42,8 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
@@ -48,12 +52,14 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
}
|
||||
};
|
||||
|
||||
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
const { onProgress: progressCallback, getPartialText } = createOnProgress({
|
||||
generation,
|
||||
onProgress: throttle(
|
||||
({ text: partialText }) => {
|
||||
saveMessage({
|
||||
/*
|
||||
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
|
||||
{
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
@@ -64,7 +70,8 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
isEdited: true,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
} */
|
||||
messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
|
||||
},
|
||||
3000,
|
||||
{ trailing: false },
|
||||
@@ -73,6 +80,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
|
||||
const getAbortData = () => ({
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
@@ -81,7 +89,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
res.on('close', () => {
|
||||
logger.debug('[EditController] Request closed');
|
||||
@@ -115,12 +123,11 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
});
|
||||
|
||||
const conversation = await getConvo(user, conversationId);
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
@@ -138,7 +145,11 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveMessage({ ...response, user });
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...response, user },
|
||||
{ context: 'api/server/controllers/EditController.js - response end' },
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
|
||||
@@ -120,21 +120,22 @@ const chatV1 = async (req, res) => {
|
||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(res, messageData, errorMessage);
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
} else if (error?.message?.includes('string too long')) {
|
||||
return sendResponse(
|
||||
req,
|
||||
res,
|
||||
messageData,
|
||||
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
|
||||
);
|
||||
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
||||
return sendResponse(res, messageData, error.message);
|
||||
return sendResponse(req, res, messageData, error.message);
|
||||
} else {
|
||||
logger.error('[/assistants/chat/]', error);
|
||||
}
|
||||
|
||||
if (!openai || !thread_id || !run_id) {
|
||||
return sendResponse(res, messageData, defaultErrorMessage);
|
||||
return sendResponse(req, res, messageData, defaultErrorMessage);
|
||||
}
|
||||
|
||||
await sleep(2000);
|
||||
@@ -221,10 +222,10 @@ const chatV1 = async (req, res) => {
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/chat/] Error finalizing error process', error);
|
||||
return sendResponse(res, messageData, 'The Assistant run failed');
|
||||
return sendResponse(req, res, messageData, 'The Assistant run failed');
|
||||
}
|
||||
|
||||
return sendResponse(res, finalEvent);
|
||||
return sendResponse(req, res, finalEvent);
|
||||
};
|
||||
|
||||
try {
|
||||
@@ -382,6 +383,9 @@ const chatV1 = async (req, res) => {
|
||||
return files;
|
||||
};
|
||||
|
||||
/** @type {Promise<Run>|undefined} */
|
||||
let userMessagePromise;
|
||||
|
||||
const initializeThread = async () => {
|
||||
/** @type {[ undefined | MongoFile[]]}*/
|
||||
const [processedFiles] = await Promise.all([addVisionPrompt(), getRequestFileIds()]);
|
||||
@@ -438,7 +442,7 @@ const chatV1 = async (req, res) => {
|
||||
previousMessages.push(requestMessage);
|
||||
|
||||
/* asynchronous */
|
||||
saveUserMessage({ ...requestMessage, model });
|
||||
userMessagePromise = saveUserMessage(req, { ...requestMessage, model });
|
||||
|
||||
conversation = {
|
||||
conversationId,
|
||||
@@ -582,7 +586,10 @@ const chatV1 = async (req, res) => {
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveAssistantMessage({ ...responseMessage, model });
|
||||
if (userMessagePromise) {
|
||||
await userMessagePromise;
|
||||
}
|
||||
await saveAssistantMessage(req, { ...responseMessage, model });
|
||||
|
||||
if (parentMessageId === Constants.NO_PARENT && !_thread_id) {
|
||||
addTitle(req, {
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
const { v4 } = require('uuid');
|
||||
const {
|
||||
Time,
|
||||
Constants,
|
||||
RunStatus,
|
||||
CacheKeys,
|
||||
ContentTypes,
|
||||
ToolCallTypes,
|
||||
EModelEndpoint,
|
||||
ViolationTypes,
|
||||
retrievalMimeTypes,
|
||||
AssistantStreamEvents,
|
||||
} = require('librechat-data-provider');
|
||||
@@ -14,12 +14,12 @@ const {
|
||||
initThread,
|
||||
recordUsage,
|
||||
saveUserMessage,
|
||||
checkMessageGaps,
|
||||
addThreadMetadata,
|
||||
saveAssistantMessage,
|
||||
} = require('~/server/services/Threads');
|
||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||
const { createErrorHandler } = require('~/server/controllers/assistants/errors');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
@@ -44,7 +44,7 @@ const ten_minutes = 1000 * 60 * 10;
|
||||
const chatV2 = async (req, res) => {
|
||||
logger.debug('[/assistants/chat/] req.body', req.body);
|
||||
|
||||
/** @type {{ files: MongoFile[]}} */
|
||||
/** @type {{files: MongoFile[]}} */
|
||||
const {
|
||||
text,
|
||||
model,
|
||||
@@ -90,139 +90,20 @@ const chatV2 = async (req, res) => {
|
||||
/** @type {Run | undefined} - The completed run, undefined if incomplete */
|
||||
let completedRun;
|
||||
|
||||
const handleError = async (error) => {
|
||||
const defaultErrorMessage =
|
||||
'The Assistant run failed to initialize. Try sending a message in a new conversation.';
|
||||
const messageData = {
|
||||
thread_id,
|
||||
assistant_id,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender: 'System',
|
||||
user: req.user.id,
|
||||
shouldSaveMessage: false,
|
||||
messageId: responseMessageId,
|
||||
endpoint,
|
||||
};
|
||||
const getContext = () => ({
|
||||
openai,
|
||||
run_id,
|
||||
endpoint,
|
||||
cacheKey,
|
||||
thread_id,
|
||||
completedRun,
|
||||
assistant_id,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
responseMessageId,
|
||||
});
|
||||
|
||||
if (error.message === 'Run cancelled') {
|
||||
return res.end();
|
||||
} else if (error.message === 'Request closed' && completedRun) {
|
||||
return;
|
||||
} else if (error.message === 'Request closed') {
|
||||
logger.debug('[/assistants/chat/] Request aborted on close');
|
||||
} else if (/Files.*are invalid/.test(error.message)) {
|
||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||
endpoint === EModelEndpoint.azureAssistants
|
||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(res, messageData, errorMessage);
|
||||
} else if (error?.message?.includes('string too long')) {
|
||||
return sendResponse(
|
||||
res,
|
||||
messageData,
|
||||
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
|
||||
);
|
||||
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
||||
return sendResponse(res, messageData, error.message);
|
||||
} else {
|
||||
logger.error('[/assistants/chat/]', error);
|
||||
}
|
||||
|
||||
if (!openai || !thread_id || !run_id) {
|
||||
return sendResponse(res, messageData, defaultErrorMessage);
|
||||
}
|
||||
|
||||
await sleep(2000);
|
||||
|
||||
try {
|
||||
const status = await cache.get(cacheKey);
|
||||
if (status === 'cancelled') {
|
||||
logger.debug('[/assistants/chat/] Run already cancelled');
|
||||
return res.end();
|
||||
}
|
||||
await cache.delete(cacheKey);
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/chat/] Error cancelling run', error);
|
||||
}
|
||||
|
||||
await sleep(2000);
|
||||
|
||||
let run;
|
||||
try {
|
||||
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
user: req.user.id,
|
||||
conversationId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/chat/] Error fetching or processing run', error);
|
||||
}
|
||||
|
||||
let finalEvent;
|
||||
try {
|
||||
const runMessages = await checkMessageGaps({
|
||||
openai,
|
||||
run_id,
|
||||
endpoint,
|
||||
thread_id,
|
||||
conversationId,
|
||||
latestMessageId: responseMessageId,
|
||||
});
|
||||
|
||||
const errorContentPart = {
|
||||
text: {
|
||||
value:
|
||||
error?.message ?? 'There was an error processing your request. Please try again later.',
|
||||
},
|
||||
type: ContentTypes.ERROR,
|
||||
};
|
||||
|
||||
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
|
||||
runMessages[runMessages.length - 1].content = [errorContentPart];
|
||||
} else {
|
||||
const contentParts = runMessages[runMessages.length - 1].content;
|
||||
for (let i = 0; i < contentParts.length; i++) {
|
||||
const currentPart = contentParts[i];
|
||||
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
|
||||
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
|
||||
if (
|
||||
toolCall &&
|
||||
toolCall?.function &&
|
||||
!(toolCall?.function?.output || toolCall?.function?.output?.length)
|
||||
) {
|
||||
contentParts[i] = {
|
||||
...currentPart,
|
||||
[ContentTypes.TOOL_CALL]: {
|
||||
...toolCall,
|
||||
function: {
|
||||
...toolCall.function,
|
||||
output: 'error processing tool',
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
runMessages[runMessages.length - 1].content.push(errorContentPart);
|
||||
}
|
||||
|
||||
finalEvent = {
|
||||
final: true,
|
||||
conversation: await getConvo(req.user.id, conversationId),
|
||||
runMessages,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error('[/assistants/chat/] Error finalizing error process', error);
|
||||
return sendResponse(res, messageData, 'The Assistant run failed');
|
||||
}
|
||||
|
||||
return sendResponse(res, finalEvent);
|
||||
};
|
||||
const handleError = createErrorHandler({ req, res, getContext });
|
||||
|
||||
try {
|
||||
res.on('close', async () => {
|
||||
@@ -365,6 +246,9 @@ const chatV2 = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
/** @type {Promise<Run>|undefined} */
|
||||
let userMessagePromise;
|
||||
|
||||
const initializeThread = async () => {
|
||||
await getRequestFileIds();
|
||||
|
||||
@@ -407,7 +291,7 @@ const chatV2 = async (req, res) => {
|
||||
previousMessages.push(requestMessage);
|
||||
|
||||
/* asynchronous */
|
||||
saveUserMessage({ ...requestMessage, model });
|
||||
userMessagePromise = saveUserMessage(req, { ...requestMessage, model });
|
||||
|
||||
conversation = {
|
||||
conversationId,
|
||||
@@ -489,6 +373,11 @@ const chatV2 = async (req, res) => {
|
||||
},
|
||||
};
|
||||
|
||||
/** @type {undefined | TAssistantEndpoint} */
|
||||
const config = req.app.locals[endpoint] ?? {};
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const allConfig = req.app.locals.all;
|
||||
|
||||
const streamRunManager = new StreamRunManager({
|
||||
req,
|
||||
res,
|
||||
@@ -498,6 +387,7 @@ const chatV2 = async (req, res) => {
|
||||
attachedFileIds,
|
||||
parentMessageId: userMessageId,
|
||||
responseMessage: openai.responseMessage,
|
||||
streamRate: allConfig?.streamRate ?? config.streamRate,
|
||||
// streamOptions: {
|
||||
|
||||
// },
|
||||
@@ -510,6 +400,16 @@ const chatV2 = async (req, res) => {
|
||||
|
||||
response = streamRunManager;
|
||||
response.text = streamRunManager.intermediateText;
|
||||
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
messageCache.set(
|
||||
responseMessageId,
|
||||
{
|
||||
complete: true,
|
||||
text: response.text,
|
||||
},
|
||||
Time.FIVE_MINUTES,
|
||||
);
|
||||
};
|
||||
|
||||
await processRun();
|
||||
@@ -552,7 +452,10 @@ const chatV2 = async (req, res) => {
|
||||
});
|
||||
res.end();
|
||||
|
||||
await saveAssistantMessage({ ...responseMessage, model });
|
||||
if (userMessagePromise) {
|
||||
await userMessagePromise;
|
||||
}
|
||||
await saveAssistantMessage(req, { ...responseMessage, model });
|
||||
|
||||
if (parentMessageId === Constants.NO_PARENT && !_thread_id) {
|
||||
addTitle(req, {
|
||||
|
||||
193
api/server/controllers/assistants/errors.js
Normal file
193
api/server/controllers/assistants/errors.js
Normal file
@@ -0,0 +1,193 @@
|
||||
// errorHandler.js
|
||||
const { sendResponse } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerContext
|
||||
* @property {OpenAIClient} openai - The OpenAI client
|
||||
* @property {string} thread_id - The thread ID
|
||||
* @property {string} run_id - The run ID
|
||||
* @property {boolean} completedRun - Whether the run has completed
|
||||
* @property {string} assistant_id - The assistant ID
|
||||
* @property {string} conversationId - The conversation ID
|
||||
* @property {string} parentMessageId - The parent message ID
|
||||
* @property {string} responseMessageId - The response message ID
|
||||
* @property {string} endpoint - The endpoint being used
|
||||
* @property {string} cacheKey - The cache key for the current request
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} ErrorHandlerDependencies
|
||||
* @property {Express.Request} req - The Express request object
|
||||
* @property {Express.Response} res - The Express response object
|
||||
* @property {() => ErrorHandlerContext} getContext - Function to get the current context
|
||||
* @property {string} [originPath] - The origin path for the error handler
|
||||
*/
|
||||
|
||||
/**
|
||||
* Creates an error handler function with the given dependencies
|
||||
* @param {ErrorHandlerDependencies} dependencies - The dependencies for the error handler
|
||||
* @returns {(error: Error) => Promise<void>} The error handler function
|
||||
*/
|
||||
const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/chat/' }) => {
|
||||
const cache = getLogStores(CacheKeys.ABORT_KEYS);
|
||||
|
||||
/**
|
||||
* Handles errors that occur during the chat process
|
||||
* @param {Error} error - The error that occurred
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
return async (error) => {
|
||||
const {
|
||||
openai,
|
||||
run_id,
|
||||
endpoint,
|
||||
cacheKey,
|
||||
thread_id,
|
||||
completedRun,
|
||||
assistant_id,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
responseMessageId,
|
||||
} = getContext();
|
||||
|
||||
const defaultErrorMessage =
|
||||
'The Assistant run failed to initialize. Try sending a message in a new conversation.';
|
||||
const messageData = {
|
||||
thread_id,
|
||||
assistant_id,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
sender: 'System',
|
||||
user: req.user.id,
|
||||
shouldSaveMessage: false,
|
||||
messageId: responseMessageId,
|
||||
endpoint,
|
||||
};
|
||||
|
||||
if (error.message === 'Run cancelled') {
|
||||
return res.end();
|
||||
} else if (error.message === 'Request closed' && completedRun) {
|
||||
return;
|
||||
} else if (error.message === 'Request closed') {
|
||||
logger.debug(`[${originPath}] Request aborted on close`);
|
||||
} else if (/Files.*are invalid/.test(error.message)) {
|
||||
const errorMessage = `Files are invalid, or may not have uploaded yet.${
|
||||
endpoint === 'azureAssistants'
|
||||
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
|
||||
: ''
|
||||
}`;
|
||||
return sendResponse(req, res, messageData, errorMessage);
|
||||
} else if (error?.message?.includes('string too long')) {
|
||||
return sendResponse(
|
||||
req,
|
||||
res,
|
||||
messageData,
|
||||
'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
|
||||
);
|
||||
} else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
|
||||
return sendResponse(req, res, messageData, error.message);
|
||||
} else {
|
||||
logger.error(`[${originPath}]`, error);
|
||||
}
|
||||
|
||||
if (!openai || !thread_id || !run_id) {
|
||||
return sendResponse(req, res, messageData, defaultErrorMessage);
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
|
||||
try {
|
||||
const status = await cache.get(cacheKey);
|
||||
if (status === 'cancelled') {
|
||||
logger.debug(`[${originPath}] Run already cancelled`);
|
||||
return res.end();
|
||||
}
|
||||
await cache.delete(cacheKey);
|
||||
const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
|
||||
logger.debug(`[${originPath}] Cancelled run:`, cancelledRun);
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error cancelling run`, error);
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
|
||||
let run;
|
||||
try {
|
||||
run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
|
||||
await recordUsage({
|
||||
...run.usage,
|
||||
model: run.model,
|
||||
user: req.user.id,
|
||||
conversationId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error fetching or processing run`, error);
|
||||
}
|
||||
|
||||
let finalEvent;
|
||||
try {
|
||||
const runMessages = await checkMessageGaps({
|
||||
openai,
|
||||
run_id,
|
||||
endpoint,
|
||||
thread_id,
|
||||
conversationId,
|
||||
latestMessageId: responseMessageId,
|
||||
});
|
||||
|
||||
const errorContentPart = {
|
||||
text: {
|
||||
value:
|
||||
error?.message ?? 'There was an error processing your request. Please try again later.',
|
||||
},
|
||||
type: ContentTypes.ERROR,
|
||||
};
|
||||
|
||||
if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
|
||||
runMessages[runMessages.length - 1].content = [errorContentPart];
|
||||
} else {
|
||||
const contentParts = runMessages[runMessages.length - 1].content;
|
||||
for (let i = 0; i < contentParts.length; i++) {
|
||||
const currentPart = contentParts[i];
|
||||
/** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
|
||||
const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
|
||||
if (
|
||||
toolCall &&
|
||||
toolCall?.function &&
|
||||
!(toolCall?.function?.output || toolCall?.function?.output?.length)
|
||||
) {
|
||||
contentParts[i] = {
|
||||
...currentPart,
|
||||
[ContentTypes.TOOL_CALL]: {
|
||||
...toolCall,
|
||||
function: {
|
||||
...toolCall.function,
|
||||
output: 'error processing tool',
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
runMessages[runMessages.length - 1].content.push(errorContentPart);
|
||||
}
|
||||
|
||||
finalEvent = {
|
||||
final: true,
|
||||
conversation: await getConvo(req.user.id, conversationId),
|
||||
runMessages,
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error(`[${originPath}] Error finalizing error process`, error);
|
||||
return sendResponse(req, res, messageData, 'The Assistant run failed');
|
||||
}
|
||||
|
||||
return sendResponse(req, res, finalEvent);
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = { createErrorHandler };
|
||||
@@ -1,8 +1,9 @@
|
||||
const {
|
||||
EModelEndpoint,
|
||||
CacheKeys,
|
||||
defaultAssistantsVersion,
|
||||
SystemRoles,
|
||||
EModelEndpoint,
|
||||
defaultOrderQuery,
|
||||
defaultAssistantsVersion,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
initializeClient: initAzureClient,
|
||||
@@ -227,7 +228,7 @@ const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
|
||||
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
|
||||
}
|
||||
|
||||
if (req.user.role === 'ADMIN') {
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
return body;
|
||||
} else if (!req.app.locals[endpoint]) {
|
||||
return body;
|
||||
|
||||
@@ -61,7 +61,7 @@ const startServer = async () => {
|
||||
passport.use(passportLogin());
|
||||
|
||||
// LDAP Auth
|
||||
if (process.env.LDAP_URL && process.env.LDAP_BIND_DN && process.env.LDAP_USER_SEARCH_BASE) {
|
||||
if (process.env.LDAP_URL && process.env.LDAP_USER_SEARCH_BASE) {
|
||||
passport.use(ldapLogin);
|
||||
}
|
||||
|
||||
@@ -81,6 +81,7 @@ const startServer = async () => {
|
||||
app.use('/api/convos', routes.convos);
|
||||
app.use('/api/presets', routes.presets);
|
||||
app.use('/api/prompts', routes.prompts);
|
||||
app.use('/api/categories', routes.categories);
|
||||
app.use('/api/tokenizer', routes.tokenizer);
|
||||
app.use('/api/endpoints', routes.endpoints);
|
||||
app.use('/api/balance', routes.balance);
|
||||
@@ -91,7 +92,9 @@ const startServer = async () => {
|
||||
app.use('/api/files', await routes.files.initialize());
|
||||
app.use('/images/', validateImageRequest, routes.staticRoute);
|
||||
app.use('/api/share', routes.share);
|
||||
app.use('/api/roles', routes.roles);
|
||||
|
||||
app.use('/api/tags', routes.tags);
|
||||
app.use((req, res) => {
|
||||
res.sendFile(path.join(app.locals.paths.dist, 'index.html'));
|
||||
});
|
||||
|
||||
@@ -1,31 +1,39 @@
|
||||
const { isAssistantsEndpoint } = require('librechat-data-provider');
|
||||
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const abortControllers = require('./abortControllers');
|
||||
const { saveMessage, getConvo } = require('~/models');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
const { abortRun } = require('./abortRun');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
let { abortKey, conversationId, endpoint } = req.body;
|
||||
|
||||
if (!abortKey && conversationId) {
|
||||
abortKey = conversationId;
|
||||
}
|
||||
let { abortKey, endpoint } = req.body;
|
||||
|
||||
if (isAssistantsEndpoint(endpoint)) {
|
||||
return await abortRun(req, res);
|
||||
}
|
||||
|
||||
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
|
||||
|
||||
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
|
||||
abortKey = conversationId;
|
||||
}
|
||||
|
||||
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey);
|
||||
const { abortController } = abortControllers.get(abortKey) ?? {};
|
||||
if (!abortController) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
const finalEvent = await abortController.abortCompletion();
|
||||
logger.debug('[abortMessage] Aborted request', { abortKey });
|
||||
logger.debug(
|
||||
`[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
|
||||
JSON.stringify({ abortKey }),
|
||||
);
|
||||
abortControllers.delete(abortKey);
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
@@ -50,12 +58,35 @@ const handleAbort = () => {
|
||||
};
|
||||
};
|
||||
|
||||
const createAbortController = (req, res, getAbortData) => {
|
||||
const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
const abortController = new AbortController();
|
||||
const { endpointOption } = req.body;
|
||||
const onStart = (userMessage) => {
|
||||
|
||||
abortController.getAbortData = function () {
|
||||
return getAbortData();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {TMessage} userMessage
|
||||
* @param {string} responseMessageId
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId) => {
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
|
||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const prevRequest = abortControllers.get(abortKey);
|
||||
|
||||
if (prevRequest && prevRequest?.abortController) {
|
||||
const data = prevRequest.abortController.getAbortData();
|
||||
getReqData({ userMessage: data?.userMessage });
|
||||
const addedAbortKey = `${abortKey}:${responseMessageId}`;
|
||||
abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
|
||||
res.on('finish', function () {
|
||||
abortControllers.delete(addedAbortKey);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
abortControllers.set(abortKey, { abortController, ...endpointOption });
|
||||
|
||||
res.on('finish', function () {
|
||||
@@ -65,7 +96,8 @@ const createAbortController = (req, res, getAbortData) => {
|
||||
|
||||
abortController.abortCompletion = async function () {
|
||||
abortController.abort();
|
||||
const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData();
|
||||
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
|
||||
getAbortData();
|
||||
const completionTokens = await countTokens(responseData?.text ?? '');
|
||||
const user = req.user.id;
|
||||
|
||||
@@ -87,12 +119,26 @@ const createAbortController = (req, res, getAbortData) => {
|
||||
{ promptTokens, completionTokens },
|
||||
);
|
||||
|
||||
saveMessage({ ...responseMessage, user });
|
||||
saveMessage(
|
||||
req,
|
||||
{ ...responseMessage, user },
|
||||
{ context: 'api/server/middleware/abortMiddleware.js' },
|
||||
);
|
||||
|
||||
let conversation;
|
||||
if (userMessagePromise) {
|
||||
const resolved = await userMessagePromise;
|
||||
conversation = resolved?.conversation;
|
||||
}
|
||||
|
||||
if (!conversation) {
|
||||
conversation = await getConvo(req.user.id, conversationId);
|
||||
}
|
||||
|
||||
return {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
conversation,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: responseMessage,
|
||||
};
|
||||
@@ -151,7 +197,7 @@ const handleAbortError = async (res, req, error, data) => {
|
||||
}
|
||||
};
|
||||
|
||||
await sendError(res, options, callback);
|
||||
await sendError(req, res, options, callback);
|
||||
};
|
||||
|
||||
if (partialText && partialText.length > 5) {
|
||||
|
||||
@@ -19,7 +19,8 @@ const validateAssistant = async (req, res, next) => {
|
||||
}
|
||||
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
const error = { message: 'Assistant not supported' };
|
||||
const error = { message: 'validateAssistant: Assistant not supported' };
|
||||
|
||||
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { getAssistant } = require('~/models/Assistant');
|
||||
|
||||
/**
|
||||
@@ -11,7 +12,7 @@ const { getAssistant } = require('~/models/Assistant');
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const validateAuthor = async ({ req, openai, overrideEndpoint, overrideAssistantId }) => {
|
||||
if (req.user.role === 'ADMIN') {
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
@@ -16,7 +17,7 @@ const { logger } = require('~/config');
|
||||
const canDeleteAccount = async (req, res, next = () => {}) => {
|
||||
const { user } = req;
|
||||
const { ALLOW_ACCOUNT_DELETION = true } = process.env;
|
||||
if (user?.role === 'ADMIN' || isEnabled(ALLOW_ACCOUNT_DELETION)) {
|
||||
if (user?.role === SystemRoles.ADMIN || isEnabled(ALLOW_ACCOUNT_DELETION)) {
|
||||
return next();
|
||||
} else {
|
||||
logger.error(`[User] [Delete Account] [User cannot delete account] [User: ${user?.id}]`);
|
||||
|
||||
@@ -41,10 +41,14 @@ const denyRequest = async (req, res, errorMessage) => {
|
||||
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;
|
||||
|
||||
if (shouldSaveMessage) {
|
||||
await saveMessage({ ...userMessage, user: req.user.id });
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...userMessage, user: req.user.id },
|
||||
{ context: `api/server/middleware/denyRequest.js - ${responseText}` },
|
||||
);
|
||||
}
|
||||
|
||||
return await sendError(res, {
|
||||
return await sendError(req, res, {
|
||||
sender: getResponseSender(req.body),
|
||||
messageId: crypto.randomUUID(),
|
||||
conversationId,
|
||||
|
||||
@@ -18,10 +18,12 @@ const limiters = require('./limiters');
|
||||
const uaParser = require('./uaParser');
|
||||
const checkBan = require('./checkBan');
|
||||
const noIndex = require('./noIndex');
|
||||
const roles = require('./roles');
|
||||
|
||||
module.exports = {
|
||||
...abortMiddleware,
|
||||
...limiters,
|
||||
...roles,
|
||||
noIndex,
|
||||
checkBan,
|
||||
uaParser,
|
||||
|
||||
@@ -13,7 +13,7 @@ const requireLdapAuth = (req, res, next) => {
|
||||
console.log({
|
||||
title: '(requireLdapAuth) Error: No user',
|
||||
});
|
||||
return res.status(422).send(info);
|
||||
return res.status(404).send(info);
|
||||
}
|
||||
req.user = user;
|
||||
next();
|
||||
|
||||
14
api/server/middleware/roles/checkAdmin.js
Normal file
14
api/server/middleware/roles/checkAdmin.js
Normal file
@@ -0,0 +1,14 @@
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
|
||||
function checkAdmin(req, res, next) {
|
||||
try {
|
||||
if (req.user.role !== SystemRoles.ADMIN) {
|
||||
return res.status(403).json({ message: 'Forbidden' });
|
||||
}
|
||||
next();
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Internal Server Error' });
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = checkAdmin;
|
||||
52
api/server/middleware/roles/generateCheckAccess.js
Normal file
52
api/server/middleware/roles/generateCheckAccess.js
Normal file
@@ -0,0 +1,52 @@
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
/**
|
||||
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
|
||||
*
|
||||
* @param {PermissionTypes} permissionType - The type of permission to check.
|
||||
* @param {Permissions[]} permissions - The list of specific permissions to check.
|
||||
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
|
||||
* @returns {Function} Express middleware function.
|
||||
*/
|
||||
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
const { user } = req;
|
||||
if (!user) {
|
||||
return res.status(401).json({ message: 'Authorization required' });
|
||||
}
|
||||
|
||||
if (user.role === SystemRoles.ADMIN) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const role = await getRoleByName(user.role);
|
||||
if (role && role[permissionType]) {
|
||||
const hasAnyPermission = permissions.some((permission) => {
|
||||
if (role[permissionType][permission]) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (bodyProps[permission] && req.body) {
|
||||
return bodyProps[permission].some((prop) =>
|
||||
Object.prototype.hasOwnProperty.call(req.body, prop),
|
||||
);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
if (hasAnyPermission) {
|
||||
return next();
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
|
||||
} catch (error) {
|
||||
return res.status(500).json({ message: `Server error: ${error.message}` });
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = generateCheckAccess;
|
||||
7
api/server/middleware/roles/index.js
Normal file
7
api/server/middleware/roles/index.js
Normal file
@@ -0,0 +1,7 @@
|
||||
const checkAdmin = require('./checkAdmin');
|
||||
const generateCheckAccess = require('./generateCheckAccess');
|
||||
|
||||
module.exports = {
|
||||
checkAdmin,
|
||||
generateCheckAccess,
|
||||
};
|
||||
@@ -31,10 +31,14 @@ function validateImageRequest(req, res, next) {
|
||||
return res.status(403).send('Access Denied');
|
||||
}
|
||||
|
||||
if (req.path.includes(payload.id)) {
|
||||
const fullPath = decodeURIComponent(req.originalUrl);
|
||||
const pathPattern = new RegExp(`^/images/${payload.id}/[^/]+$`);
|
||||
|
||||
if (pathPattern.test(fullPath)) {
|
||||
logger.debug('[validateImageRequest] Image request validated');
|
||||
next();
|
||||
} else {
|
||||
logger.warn('[validateImageRequest] Invalid image path');
|
||||
res.status(403).send('Access Denied');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
const { getConvo } = require('../../models');
|
||||
const { getConvo } = require('~/models');
|
||||
|
||||
// Middleware to validate conversationId and user relationship
|
||||
const validateMessageReq = async (req, res, next) => {
|
||||
|
||||
@@ -76,7 +76,9 @@ describe.skip('GET /', () => {
|
||||
openidLoginEnabled: true,
|
||||
openidLabel: 'Test OpenID',
|
||||
openidImageUrl: 'http://test-server.com',
|
||||
ldapLoginEnabled: true,
|
||||
ldap: {
|
||||
enabled: true,
|
||||
},
|
||||
serverDomain: 'http://test-server.com',
|
||||
emailLoginEnabled: 'true',
|
||||
registrationEnabled: 'true',
|
||||
|
||||
55
api/server/routes/__tests__/ldap.spec.js
Normal file
55
api/server/routes/__tests__/ldap.spec.js
Normal file
@@ -0,0 +1,55 @@
|
||||
const request = require('supertest');
|
||||
const express = require('express');
|
||||
const { getLdapConfig } = require('~/server/services/Config/ldap');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
jest.mock('~/server/services/Config/ldap');
|
||||
jest.mock('~/server/utils');
|
||||
|
||||
const app = express();
|
||||
|
||||
// Mock the route handler
|
||||
app.get('/api/config', (req, res) => {
|
||||
const ldapConfig = getLdapConfig();
|
||||
res.json({ ldap: ldapConfig });
|
||||
});
|
||||
|
||||
describe('LDAP Config Tests', () => {
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
it('should return LDAP config with username property when LDAP_LOGIN_USES_USERNAME is enabled', async () => {
|
||||
getLdapConfig.mockReturnValue({ enabled: true, username: true });
|
||||
isEnabled.mockReturnValue(true);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.statusCode).toBe(200);
|
||||
expect(response.body.ldap).toEqual({
|
||||
enabled: true,
|
||||
username: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should return LDAP config without username property when LDAP_LOGIN_USES_USERNAME is not enabled', async () => {
|
||||
getLdapConfig.mockReturnValue({ enabled: true });
|
||||
isEnabled.mockReturnValue(false);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.statusCode).toBe(200);
|
||||
expect(response.body.ldap).toEqual({
|
||||
enabled: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should not return LDAP config when LDAP is not enabled', async () => {
|
||||
getLdapConfig.mockReturnValue(undefined);
|
||||
|
||||
const response = await request(app).get('/api/config');
|
||||
|
||||
expect(response.statusCode).toBe(200);
|
||||
expect(response.body.ldap).toBeUndefined();
|
||||
});
|
||||
});
|
||||
@@ -51,8 +51,8 @@ router.post('/', setHeaders, async (req, res) => {
|
||||
});
|
||||
|
||||
if (!overrideParentMessageId) {
|
||||
await saveMessage({ ...userMessage, user: req.user.id });
|
||||
await saveConvo(req.user.id, {
|
||||
await saveMessage(req, { ...userMessage, user: req.user.id });
|
||||
await saveConvo(req, {
|
||||
...userMessage,
|
||||
...endpointOption,
|
||||
conversationId,
|
||||
@@ -93,7 +93,7 @@ const ask = async ({
|
||||
const currentTimestamp = Date.now();
|
||||
if (currentTimestamp - lastSavedTimestamp > 500) {
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
saveMessage(req, {
|
||||
messageId: responseMessageId,
|
||||
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
|
||||
conversationId,
|
||||
@@ -159,7 +159,7 @@ const ask = async ({
|
||||
isCreatedByUser: false,
|
||||
};
|
||||
|
||||
await saveMessage({ ...responseMessage, user });
|
||||
await saveMessage(req, { ...responseMessage, user });
|
||||
responseMessage.messageId = newResponseMessageId;
|
||||
|
||||
// STEP2 update the conversation
|
||||
@@ -183,7 +183,7 @@ const ask = async ({
|
||||
}
|
||||
}
|
||||
|
||||
await saveConvo(user, conversationUpdate);
|
||||
await saveConvo(req, conversationUpdate);
|
||||
conversationId = newConversationId;
|
||||
|
||||
// STEP3 update the user message
|
||||
@@ -192,7 +192,7 @@ const ask = async ({
|
||||
|
||||
// If response has parentMessageId, the fake userMessage.messageId should be updated to the real one.
|
||||
if (!overrideParentMessageId) {
|
||||
await saveMessage({
|
||||
await saveMessage(req, {
|
||||
...userMessage,
|
||||
user,
|
||||
messageId: userMessageId,
|
||||
@@ -213,7 +213,7 @@ const ask = async ({
|
||||
if (userParentMessageId == Constants.NO_PARENT) {
|
||||
// const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage });
|
||||
const title = await response.details.title;
|
||||
await saveConvo(user, {
|
||||
await saveConvo(req, {
|
||||
conversationId: conversationId,
|
||||
title,
|
||||
});
|
||||
@@ -229,7 +229,7 @@ const ask = async ({
|
||||
isCreatedByUser: false,
|
||||
text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"`,
|
||||
};
|
||||
await saveMessage({ ...errorMessage, user });
|
||||
await saveMessage(req, { ...errorMessage, user });
|
||||
handleError(res, errorMessage);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -70,8 +70,8 @@ router.post('/', setHeaders, async (req, res) => {
|
||||
});
|
||||
|
||||
if (!overrideParentMessageId) {
|
||||
await saveMessage({ ...userMessage, user: req.user.id });
|
||||
await saveConvo(req.user.id, {
|
||||
await saveMessage(req, { ...userMessage, user: req.user.id });
|
||||
await saveConvo(req, {
|
||||
...userMessage,
|
||||
...endpointOption,
|
||||
conversationId,
|
||||
@@ -118,7 +118,7 @@ const ask = async ({
|
||||
const currentTimestamp = Date.now();
|
||||
if (currentTimestamp - lastSavedTimestamp > 500) {
|
||||
lastSavedTimestamp = currentTimestamp;
|
||||
saveMessage({
|
||||
saveMessage(req, {
|
||||
messageId: responseMessageId,
|
||||
sender: model,
|
||||
conversationId,
|
||||
@@ -197,7 +197,7 @@ const ask = async ({
|
||||
isCreatedByUser: false,
|
||||
};
|
||||
|
||||
await saveMessage({ ...responseMessage, user });
|
||||
await saveMessage(req, { ...responseMessage, user });
|
||||
responseMessage.messageId = newResponseMessageId;
|
||||
|
||||
let conversationUpdate = {
|
||||
@@ -216,12 +216,12 @@ const ask = async ({
|
||||
conversationUpdate.invocationId = response.invocationId;
|
||||
}
|
||||
|
||||
await saveConvo(user, conversationUpdate);
|
||||
await saveConvo(req, conversationUpdate);
|
||||
userMessage.messageId = newUserMessageId;
|
||||
|
||||
// If response has parentMessageId, the fake userMessage.messageId should be updated to the real one.
|
||||
if (!overrideParentMessageId) {
|
||||
await saveMessage({
|
||||
await saveMessage(req, {
|
||||
...userMessage,
|
||||
user,
|
||||
messageId: userMessageId,
|
||||
@@ -245,7 +245,7 @@ const ask = async ({
|
||||
response: responseMessage,
|
||||
});
|
||||
|
||||
await saveConvo(user, {
|
||||
await saveConvo(req, {
|
||||
conversationId: conversationId,
|
||||
title,
|
||||
});
|
||||
@@ -266,7 +266,7 @@ const ask = async ({
|
||||
isCreatedByUser: false,
|
||||
};
|
||||
|
||||
saveMessage({ ...responseMessage, user });
|
||||
saveMessage(req, { ...responseMessage, user });
|
||||
|
||||
return {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
@@ -288,7 +288,7 @@ const ask = async ({
|
||||
model,
|
||||
isCreatedByUser: false,
|
||||
};
|
||||
await saveMessage({ ...errorMessage, user });
|
||||
await saveMessage(req, { ...errorMessage, user });
|
||||
handleError(res, errorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const express = require('express');
|
||||
const AskController = require('~/server/controllers/AskController');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/google');
|
||||
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
|
||||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
@@ -20,7 +20,7 @@ router.post(
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient);
|
||||
await AskController(req, res, next, initializeClient, addTitle);
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
const express = require('express');
|
||||
const throttle = require('lodash/throttle');
|
||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
||||
const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { addTitle } = require('~/server/services/Endpoints/openAI');
|
||||
const { saveMessage, updateMessage } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
@@ -41,6 +42,7 @@ router.post(
|
||||
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
@@ -58,6 +60,8 @@ router.post(
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
@@ -68,7 +72,15 @@ router.post(
|
||||
}
|
||||
};
|
||||
|
||||
const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
const throttledCacheSet = throttle(
|
||||
(text) => {
|
||||
messageCache.set(responseMessageId, text, Time.FIVE_MINUTES);
|
||||
},
|
||||
3000,
|
||||
{ trailing: false },
|
||||
);
|
||||
|
||||
let streaming = null;
|
||||
let timer = null;
|
||||
|
||||
@@ -82,18 +94,7 @@ router.post(
|
||||
clearTimeout(timer);
|
||||
}
|
||||
|
||||
throttledSaveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
text: partialText,
|
||||
model: endpointOption.modelOptions.model,
|
||||
unfinished: true,
|
||||
error: false,
|
||||
plugins,
|
||||
user,
|
||||
});
|
||||
throttledCacheSet(partialText);
|
||||
|
||||
streaming = new Promise((resolve) => {
|
||||
timer = setTimeout(() => {
|
||||
@@ -148,18 +149,10 @@ router.post(
|
||||
}
|
||||
};
|
||||
|
||||
const onChainEnd = () => {
|
||||
saveMessage({ ...userMessage, user });
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
const getAbortData = () => ({
|
||||
sender,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
@@ -167,12 +160,27 @@ router.post(
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
try {
|
||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
const onChainEnd = () => {
|
||||
if (!client.skipSaveUserMessage) {
|
||||
saveMessage(
|
||||
req,
|
||||
{ ...userMessage, user },
|
||||
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
|
||||
);
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
conversationId,
|
||||
@@ -189,7 +197,6 @@ router.post(
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
plugins,
|
||||
},
|
||||
@@ -202,13 +209,14 @@ router.post(
|
||||
|
||||
logger.debug('[/ask/gptPlugins]', response);
|
||||
|
||||
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
|
||||
await saveMessage({ ...response, user });
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
sendMessage(res, {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
title: conversation.title,
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
conversation,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
@@ -221,6 +229,15 @@ router.post(
|
||||
client,
|
||||
});
|
||||
}
|
||||
|
||||
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
|
||||
if (response.plugins?.length > 0) {
|
||||
await updateMessage(
|
||||
req,
|
||||
{ ...response, user },
|
||||
{ context: 'api/server/routes/ask/gptPlugins.js - save plugins used' },
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
handleAbortError(res, req, error, {
|
||||
|
||||
@@ -21,8 +21,7 @@ const {
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const ldapAuth =
|
||||
!!process.env.LDAP_URL && !!process.env.LDAP_BIND_DN && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
//Local
|
||||
router.post('/logout', requireJwtAuth, logoutController);
|
||||
router.post(
|
||||
|
||||
15
api/server/routes/categories.js
Normal file
15
api/server/routes/categories.js
Normal file
@@ -0,0 +1,15 @@
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { getCategories } = require('~/models/Categories');
|
||||
|
||||
router.get('/', requireJwtAuth, async (req, res) => {
|
||||
try {
|
||||
const categories = await getCategories();
|
||||
res.status(200).send(categories);
|
||||
} catch (error) {
|
||||
res.status(500).send({ message: 'Failed to retrieve categories', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,6 +1,9 @@
|
||||
const express = require('express');
|
||||
const { defaultSocialLogins } = require('librechat-data-provider');
|
||||
const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider');
|
||||
const { getLdapConfig } = require('~/server/services/Config/ldap');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
@@ -17,13 +20,22 @@ const publicSharedLinksEnabled =
|
||||
isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC));
|
||||
|
||||
router.get('/', async function (req, res) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG);
|
||||
if (cachedStartupConfig) {
|
||||
res.send(cachedStartupConfig);
|
||||
return;
|
||||
}
|
||||
|
||||
const isBirthday = () => {
|
||||
const today = new Date();
|
||||
return today.getMonth() === 1 && today.getDate() === 11;
|
||||
};
|
||||
|
||||
const ldapLoginEnabled =
|
||||
!!process.env.LDAP_URL && !!process.env.LDAP_BIND_DN && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
const instanceProject = await getProjectByName('instance', '_id');
|
||||
|
||||
const ldap = getLdapConfig();
|
||||
|
||||
try {
|
||||
/** @type {TStartupConfig} */
|
||||
const payload = {
|
||||
@@ -41,10 +53,9 @@ router.get('/', async function (req, res) {
|
||||
!!process.env.OPENID_SESSION_SECRET,
|
||||
openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
|
||||
openidImageUrl: process.env.OPENID_IMAGE_URL,
|
||||
ldapLoginEnabled,
|
||||
serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
|
||||
emailLoginEnabled,
|
||||
registrationEnabled: !ldapLoginEnabled && isEnabled(process.env.ALLOW_REGISTRATION),
|
||||
registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION),
|
||||
socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN),
|
||||
emailEnabled:
|
||||
(!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
|
||||
@@ -63,12 +74,18 @@ router.get('/', async function (req, res) {
|
||||
sharedLinksEnabled,
|
||||
publicSharedLinksEnabled,
|
||||
analyticsGtmId: process.env.ANALYTICS_GTM_ID,
|
||||
instanceProjectId: instanceProject._id.toString(),
|
||||
};
|
||||
|
||||
if (ldap) {
|
||||
payload.ldap = ldap;
|
||||
}
|
||||
|
||||
if (typeof process.env.CUSTOM_FOOTER === 'string') {
|
||||
payload.customFooter = process.env.CUSTOM_FOOTER;
|
||||
}
|
||||
|
||||
await cache.set(CacheKeys.STARTUP_CONFIG, payload);
|
||||
return res.status(200).send(payload);
|
||||
} catch (err) {
|
||||
logger.error('Error in startup config', err);
|
||||
|
||||
@@ -8,6 +8,7 @@ const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||
const { forkConversation } = require('~/server/utils/import/fork');
|
||||
const { importConversations } = require('~/server/utils/import');
|
||||
const { createImportLimiters } = require('~/server/middleware');
|
||||
const { updateTagsForConversation } = require('~/models/ConversationTag');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
@@ -30,8 +31,13 @@ router.get('/', async (req, res) => {
|
||||
return res.status(400).json({ error: 'Invalid page size' });
|
||||
}
|
||||
const isArchived = req.query.isArchived === 'true';
|
||||
const tags = req.query.tags
|
||||
? Array.isArray(req.query.tags)
|
||||
? req.query.tags
|
||||
: [req.query.tags]
|
||||
: undefined;
|
||||
|
||||
res.status(200).send(await getConvosByPage(req.user.id, pageNumber, pageSize, isArchived));
|
||||
res.status(200).send(await getConvosByPage(req.user.id, pageNumber, pageSize, isArchived, tags));
|
||||
});
|
||||
|
||||
router.get('/:conversationId', async (req, res) => {
|
||||
@@ -104,7 +110,7 @@ router.post('/update', async (req, res) => {
|
||||
const update = req.body.arg;
|
||||
|
||||
try {
|
||||
const dbResponse = await saveConvo(req.user.id, update);
|
||||
const dbResponse = await saveConvo(req, update, { context: 'POST /api/convos/update' });
|
||||
res.status(201).json(dbResponse);
|
||||
} catch (error) {
|
||||
logger.error('Error updating conversation', error);
|
||||
@@ -167,4 +173,9 @@ router.post('/fork', async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
router.put('/tags/:conversationId', async (req, res) => {
|
||||
const tag = await updateTagsForConversation(req.user.id, req.params.conversationId, req.body);
|
||||
res.status(200).json(tag);
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
const express = require('express');
|
||||
const throttle = require('lodash/throttle');
|
||||
const { getResponseSender } = require('librechat-data-provider');
|
||||
const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
handleAbortError,
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
moderateText,
|
||||
validateModel,
|
||||
handleAbortError,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
moderateText,
|
||||
createAbortController,
|
||||
} = require('~/server/middleware');
|
||||
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
|
||||
const { saveMessage, updateMessage } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { validateTools } = require('~/app');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
@@ -49,6 +50,7 @@ router.post(
|
||||
});
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
@@ -68,6 +70,8 @@ router.post(
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
@@ -76,7 +80,15 @@ router.post(
|
||||
}
|
||||
};
|
||||
|
||||
const throttledSaveMessage = throttle(saveMessage, 3000, { trailing: false });
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
const throttledCacheSet = throttle(
|
||||
(text) => {
|
||||
messageCache.set(responseMessageId, text, Time.FIVE_MINUTES);
|
||||
},
|
||||
3000,
|
||||
{ trailing: false },
|
||||
);
|
||||
|
||||
const {
|
||||
onProgress: progressCallback,
|
||||
sendIntermediateMessage,
|
||||
@@ -87,42 +99,19 @@ router.post(
|
||||
if (plugin.loading === true) {
|
||||
plugin.loading = false;
|
||||
}
|
||||
|
||||
throttledSaveMessage({
|
||||
messageId: responseMessageId,
|
||||
sender,
|
||||
conversationId,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
text: partialText,
|
||||
model: endpointOption.modelOptions.model,
|
||||
unfinished: true,
|
||||
isEdited: true,
|
||||
error: false,
|
||||
user,
|
||||
});
|
||||
throttledCacheSet(partialText);
|
||||
},
|
||||
});
|
||||
|
||||
const onAgentAction = (action, start = false) => {
|
||||
const formattedAction = formatAction(action);
|
||||
plugin.inputs.push(formattedAction);
|
||||
plugin.latest = formattedAction.plugin;
|
||||
if (!start) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
// logger.debug('PLUGIN ACTION', formattedAction);
|
||||
};
|
||||
|
||||
const onChainEnd = (data) => {
|
||||
let { intermediateSteps: steps } = data;
|
||||
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
|
||||
plugin.loading = false;
|
||||
saveMessage({ ...userMessage, user });
|
||||
saveMessage(
|
||||
req,
|
||||
{ ...userMessage, user },
|
||||
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
|
||||
);
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
@@ -134,6 +123,7 @@ router.post(
|
||||
const getAbortData = () => ({
|
||||
sender,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
@@ -141,12 +131,31 @@ router.post(
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
try {
|
||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
const onAgentAction = (action, start = false) => {
|
||||
const formattedAction = formatAction(action);
|
||||
plugin.inputs.push(formattedAction);
|
||||
plugin.latest = formattedAction.plugin;
|
||||
if (!start && !client.skipSaveUserMessage) {
|
||||
saveMessage(
|
||||
req,
|
||||
{ ...userMessage, user },
|
||||
{ context: 'api/server/routes/ask/gptPlugins.js - onAgentAction' },
|
||||
);
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
// logger.debug('PLUGIN ACTION', formattedAction);
|
||||
};
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
generation,
|
||||
@@ -164,7 +173,6 @@ router.post(
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
plugin,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
@@ -176,17 +184,26 @@ router.post(
|
||||
}
|
||||
|
||||
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
|
||||
response.plugin = { ...plugin, loading: false };
|
||||
await saveMessage({ ...response, user });
|
||||
|
||||
const { conversation = {} } = await client.responsePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
sendMessage(res, {
|
||||
title: await getConvoTitle(user, conversationId),
|
||||
title: conversation.title,
|
||||
final: true,
|
||||
conversation: await getConvo(user, conversationId),
|
||||
conversation,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
res.end();
|
||||
|
||||
response.plugin = { ...plugin, loading: false };
|
||||
await updateMessage(
|
||||
req,
|
||||
{ ...response, user },
|
||||
{ context: 'api/server/routes/edit/gptPlugins.js' },
|
||||
);
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
handleAbortError(res, req, error, {
|
||||
|
||||
@@ -1,19 +1,11 @@
|
||||
const express = require('express');
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
requireJwtAuth,
|
||||
createFileLimiters,
|
||||
createTTSLimiters,
|
||||
createSTTLimiters,
|
||||
} = require('~/server/middleware');
|
||||
const { uaParser, checkBan, requireJwtAuth, createFileLimiters } = require('~/server/middleware');
|
||||
const { createMulterInstance } = require('./multer');
|
||||
|
||||
const files = require('./files');
|
||||
const images = require('./images');
|
||||
const avatar = require('./avatar');
|
||||
const stt = require('./stt');
|
||||
const tts = require('./tts');
|
||||
const speech = require('./speech');
|
||||
|
||||
const initialize = async () => {
|
||||
const router = express.Router();
|
||||
@@ -21,11 +13,8 @@ const initialize = async () => {
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
||||
/* Important: stt/tts routes must be added before the upload limiters */
|
||||
const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();
|
||||
const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
|
||||
router.use('/stt', sttIpLimiter, sttUserLimiter, stt);
|
||||
router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);
|
||||
/* Important: speech route must be added before the upload limiters */
|
||||
router.use('/speech', speech);
|
||||
|
||||
const upload = await createMulterInstance();
|
||||
const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters();
|
||||
|
||||
10
api/server/routes/files/speech/customConfigSpeech.js
Normal file
10
api/server/routes/files/speech/customConfigSpeech.js
Normal file
@@ -0,0 +1,10 @@
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
|
||||
const { getCustomConfigSpeech } = require('~/server/services/Files/Audio');
|
||||
|
||||
router.get('/get', async (req, res) => {
|
||||
await getCustomConfigSpeech(req, res);
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
17
api/server/routes/files/speech/index.js
Normal file
17
api/server/routes/files/speech/index.js
Normal file
@@ -0,0 +1,17 @@
|
||||
const express = require('express');
|
||||
const { createTTSLimiters, createSTTLimiters } = require('~/server/middleware');
|
||||
|
||||
const stt = require('./stt');
|
||||
const tts = require('./tts');
|
||||
const customConfigSpeech = require('./customConfigSpeech');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();
|
||||
const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
|
||||
router.use('/stt', sttIpLimiter, sttUserLimiter, stt);
|
||||
router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);
|
||||
|
||||
router.use('/config', customConfigSpeech);
|
||||
|
||||
module.exports = router;
|
||||
@@ -19,6 +19,9 @@ const assistants = require('./assistants');
|
||||
const files = require('./files');
|
||||
const staticRoute = require('./static');
|
||||
const share = require('./share');
|
||||
const categories = require('./categories');
|
||||
const roles = require('./roles');
|
||||
const tags = require('./tags');
|
||||
|
||||
module.exports = {
|
||||
search,
|
||||
@@ -42,4 +45,7 @@ module.exports = {
|
||||
files,
|
||||
staticRoute,
|
||||
share,
|
||||
categories,
|
||||
roles,
|
||||
tags,
|
||||
};
|
||||
|
||||
@@ -1,49 +1,79 @@
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const {
|
||||
getMessages,
|
||||
updateMessage,
|
||||
saveConvo,
|
||||
saveMessage,
|
||||
deleteMessages,
|
||||
} = require('../../models');
|
||||
const { countTokens } = require('../utils');
|
||||
const { requireJwtAuth, validateMessageReq } = require('../middleware/');
|
||||
const { saveConvo, saveMessage, getMessages, updateMessage, deleteMessages } = require('~/models');
|
||||
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
router.use(requireJwtAuth);
|
||||
|
||||
/* Note: It's necessary to add `validateMessageReq` within route definition for correct params */
|
||||
router.get('/:conversationId', validateMessageReq, async (req, res) => {
|
||||
const { conversationId } = req.params;
|
||||
res.status(200).send(await getMessages({ conversationId }, '-_id -__v -user'));
|
||||
try {
|
||||
const { conversationId } = req.params;
|
||||
const messages = await getMessages({ conversationId }, '-_id -__v -user');
|
||||
res.status(200).json(messages);
|
||||
} catch (error) {
|
||||
logger.error('Error fetching messages:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
// CREATE
|
||||
router.post('/:conversationId', validateMessageReq, async (req, res) => {
|
||||
const message = req.body;
|
||||
const savedMessage = await saveMessage({ ...message, user: req.user.id });
|
||||
await saveConvo(req.user.id, savedMessage);
|
||||
res.status(201).send(savedMessage);
|
||||
try {
|
||||
const message = req.body;
|
||||
const savedMessage = await saveMessage(
|
||||
req,
|
||||
{ ...message, user: req.user.id },
|
||||
{ context: 'POST /api/messages/:conversationId' },
|
||||
);
|
||||
if (!savedMessage) {
|
||||
return res.status(400).json({ error: 'Message not saved' });
|
||||
}
|
||||
await saveConvo(req, savedMessage, { context: 'POST /api/messages/:conversationId' });
|
||||
res.status(201).json(savedMessage);
|
||||
} catch (error) {
|
||||
logger.error('Error saving message:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
// READ
|
||||
router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
||||
const { conversationId, messageId } = req.params;
|
||||
res.status(200).send(await getMessages({ conversationId, messageId }, '-_id -__v -user'));
|
||||
try {
|
||||
const { conversationId, messageId } = req.params;
|
||||
const message = await getMessages({ conversationId, messageId }, '-_id -__v -user');
|
||||
if (!message) {
|
||||
return res.status(404).json({ error: 'Message not found' });
|
||||
}
|
||||
res.status(200).json(message);
|
||||
} catch (error) {
|
||||
logger.error('Error fetching message:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
// UPDATE
|
||||
router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
||||
const { messageId, model } = req.params;
|
||||
const { text } = req.body;
|
||||
const tokenCount = await countTokens(text, model);
|
||||
res.status(201).json(await updateMessage({ messageId, text, tokenCount }));
|
||||
try {
|
||||
const { messageId, model } = req.params;
|
||||
const { text } = req.body;
|
||||
const tokenCount = await countTokens(text, model);
|
||||
const result = await updateMessage(req, { messageId, text, tokenCount });
|
||||
res.status(200).json(result);
|
||||
} catch (error) {
|
||||
logger.error('Error updating message:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
// DELETE
|
||||
router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
||||
const { messageId } = req.params;
|
||||
await deleteMessages({ messageId });
|
||||
res.status(204).send();
|
||||
try {
|
||||
const { messageId } = req.params;
|
||||
await deleteMessages({ messageId });
|
||||
res.status(204).send();
|
||||
} catch (error) {
|
||||
logger.error('Error deleting message:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -1,14 +1,235 @@
|
||||
const express = require('express');
|
||||
const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider');
|
||||
const {
|
||||
getPrompt,
|
||||
getPrompts,
|
||||
savePrompt,
|
||||
deletePrompt,
|
||||
getPromptGroup,
|
||||
getPromptGroups,
|
||||
updatePromptGroup,
|
||||
deletePromptGroup,
|
||||
createPromptGroup,
|
||||
getAllPromptGroups,
|
||||
// updatePromptLabels,
|
||||
makePromptProduction,
|
||||
} = require('~/models/Prompt');
|
||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
const { getPrompts } = require('../../models/Prompt');
|
||||
|
||||
const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]);
|
||||
const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [
|
||||
Permissions.USE,
|
||||
Permissions.CREATE,
|
||||
]);
|
||||
const checkGlobalPromptShare = generateCheckAccess(
|
||||
PermissionTypes.PROMPTS,
|
||||
[Permissions.USE, Permissions.CREATE],
|
||||
{
|
||||
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
|
||||
},
|
||||
);
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkPromptAccess);
|
||||
|
||||
/**
|
||||
* Route to get single prompt group by its ID
|
||||
* GET /groups/:groupId
|
||||
*/
|
||||
router.get('/groups/:groupId', async (req, res) => {
|
||||
let groupId = req.params.groupId;
|
||||
const author = req.user.id;
|
||||
|
||||
const query = {
|
||||
_id: groupId,
|
||||
$or: [{ projectIds: { $exists: true, $ne: [], $not: { $size: 0 } } }, { author }],
|
||||
};
|
||||
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.$or;
|
||||
}
|
||||
|
||||
try {
|
||||
const group = await getPromptGroup(query);
|
||||
|
||||
if (!group) {
|
||||
return res.status(404).send({ message: 'Prompt group not found' });
|
||||
}
|
||||
|
||||
res.status(200).send(group);
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
res.status(500).send({ message: 'Error getting prompt group' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Route to fetch all prompt groups
|
||||
* GET /groups
|
||||
*/
|
||||
router.get('/all', async (req, res) => {
|
||||
try {
|
||||
const groups = await getAllPromptGroups(req, {
|
||||
author: req.user._id,
|
||||
});
|
||||
res.status(200).send(groups);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Route to fetch paginated prompt groups with filters
|
||||
* GET /groups
|
||||
*/
|
||||
router.get('/groups', async (req, res) => {
|
||||
try {
|
||||
const filter = req.query;
|
||||
/* Note: The aggregation requires an ObjectId */
|
||||
filter.author = req.user._id;
|
||||
const groups = await getPromptGroups(req, filter);
|
||||
res.status(200).send(groups);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Updates or creates a prompt + promptGroup
|
||||
* @param {object} req
|
||||
* @param {TCreatePrompt} req.body
|
||||
* @param {Express.Response} res
|
||||
*/
|
||||
const createPrompt = async (req, res) => {
|
||||
try {
|
||||
const { prompt, group } = req.body;
|
||||
if (!prompt) {
|
||||
return res.status(400).send({ error: 'Prompt is required' });
|
||||
}
|
||||
|
||||
const saveData = {
|
||||
prompt,
|
||||
group,
|
||||
author: req.user.id,
|
||||
authorName: req.user.name,
|
||||
};
|
||||
|
||||
/** @type {TCreatePromptResponse} */
|
||||
let result;
|
||||
if (group && group.name) {
|
||||
result = await createPromptGroup(saveData);
|
||||
} else {
|
||||
result = await savePrompt(saveData);
|
||||
}
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error saving prompt' });
|
||||
}
|
||||
};
|
||||
|
||||
router.post('/', createPrompt);
|
||||
|
||||
/**
|
||||
* Updates a prompt group
|
||||
* @param {object} req
|
||||
* @param {object} req.params - The request parameters
|
||||
* @param {string} req.params.groupId - The group ID
|
||||
* @param {TUpdatePromptGroupPayload} req.body - The request body
|
||||
* @param {Express.Response} res
|
||||
*/
|
||||
const patchPromptGroup = async (req, res) => {
|
||||
try {
|
||||
const { groupId } = req.params;
|
||||
const author = req.user.id;
|
||||
const filter = { _id: groupId, author };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete filter.author;
|
||||
}
|
||||
const promptGroup = await updatePromptGroup(filter, req.body);
|
||||
res.status(200).send(promptGroup);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error updating prompt group' });
|
||||
}
|
||||
};
|
||||
|
||||
router.patch('/groups/:groupId', checkGlobalPromptShare, patchPromptGroup);
|
||||
|
||||
router.patch('/:promptId/tags/production', checkPromptCreate, async (req, res) => {
|
||||
try {
|
||||
const { promptId } = req.params;
|
||||
const result = await makePromptProduction(promptId);
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error updating prompt production' });
|
||||
}
|
||||
});
|
||||
|
||||
router.get('/:promptId', async (req, res) => {
|
||||
const { promptId } = req.params;
|
||||
const author = req.user.id;
|
||||
const query = { _id: promptId, author };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const prompt = await getPrompt(query);
|
||||
res.status(200).send(prompt);
|
||||
});
|
||||
|
||||
router.get('/', async (req, res) => {
|
||||
let filter = {};
|
||||
// const { search } = req.body.arg;
|
||||
// if (!!search) {
|
||||
// filter = { conversationId };
|
||||
// }
|
||||
res.status(200).send(await getPrompts(filter));
|
||||
try {
|
||||
const author = req.user.id;
|
||||
const { groupId } = req.query;
|
||||
const query = { groupId, author };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const prompts = await getPrompts(query);
|
||||
res.status(200).send(prompts);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompts' });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Deletes a prompt
|
||||
*
|
||||
* @param {Express.Request} req - The request object.
|
||||
* @param {TDeletePromptVariables} req.params - The request parameters
|
||||
* @param {import('mongoose').ObjectId} req.params.promptId - The prompt ID
|
||||
* @param {Express.Response} res - The response object.
|
||||
* @return {TDeletePromptResponse} A promise that resolves when the prompt is deleted.
|
||||
*/
|
||||
const deletePromptController = async (req, res) => {
|
||||
try {
|
||||
const { promptId } = req.params;
|
||||
const { groupId } = req.query;
|
||||
const author = req.user.id;
|
||||
const query = { promptId, groupId, author, role: req.user.role };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const result = await deletePrompt(query);
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error deleting prompt' });
|
||||
}
|
||||
};
|
||||
|
||||
router.delete('/:promptId', checkPromptCreate, deletePromptController);
|
||||
|
||||
router.delete('/groups/:groupId', checkPromptCreate, async (req, res) => {
|
||||
const { groupId } = req.params;
|
||||
res.status(200).send(await deletePromptGroup(groupId));
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
|
||||
72
api/server/routes/roles.js
Normal file
72
api/server/routes/roles.js
Normal file
@@ -0,0 +1,72 @@
|
||||
const express = require('express');
|
||||
const {
|
||||
promptPermissionsSchema,
|
||||
PermissionTypes,
|
||||
roleDefaults,
|
||||
SystemRoles,
|
||||
} = require('librechat-data-provider');
|
||||
const { checkAdmin, requireJwtAuth } = require('~/server/middleware');
|
||||
const { updateRoleByName, getRoleByName } = require('~/models/Role');
|
||||
|
||||
const router = express.Router();
|
||||
router.use(requireJwtAuth);
|
||||
|
||||
/**
|
||||
* GET /api/roles/:roleName
|
||||
* Get a specific role by name
|
||||
*/
|
||||
router.get('/:roleName', async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
|
||||
if (req.user.role !== SystemRoles.ADMIN && !roleDefaults[roleName]) {
|
||||
return res.status(403).send({ message: 'Unauthorized' });
|
||||
}
|
||||
|
||||
try {
|
||||
const role = await getRoleByName(roleName, '-_id -__v');
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
res.status(200).send(role);
|
||||
} catch (error) {
|
||||
return res.status(500).send({ message: 'Failed to retrieve role', error: error.message });
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* PUT /api/roles/:roleName/prompts
|
||||
* Update prompt permissions for a specific role
|
||||
*/
|
||||
router.put('/:roleName/prompts', checkAdmin, async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
/** @type {TRole['PROMPTS']} */
|
||||
const updates = req.body;
|
||||
|
||||
try {
|
||||
const parsedUpdates = promptPermissionsSchema.partial().parse(updates);
|
||||
|
||||
const role = await getRoleByName(roleName);
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
const mergedUpdates = {
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
...role[PermissionTypes.PROMPTS],
|
||||
...parsedUpdates,
|
||||
},
|
||||
};
|
||||
|
||||
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
|
||||
res.status(200).send(updatedRole);
|
||||
} catch (error) {
|
||||
return res.status(400).send({ message: 'Invalid prompt permissions.', error: error.errors });
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
@@ -25,12 +25,16 @@ if (allowSharedLinks) {
|
||||
'/:shareId',
|
||||
allowSharedLinksPublic ? (req, res, next) => next() : requireJwtAuth,
|
||||
async (req, res) => {
|
||||
const share = await getSharedMessages(req.params.shareId);
|
||||
try {
|
||||
const share = await getSharedMessages(req.params.shareId);
|
||||
|
||||
if (share) {
|
||||
res.status(200).json(share);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
if (share) {
|
||||
res.status(200).json(share);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error getting shared messages' });
|
||||
}
|
||||
},
|
||||
);
|
||||
@@ -40,47 +44,63 @@ if (allowSharedLinks) {
|
||||
* Shared links
|
||||
*/
|
||||
router.get('/', requireJwtAuth, async (req, res) => {
|
||||
let pageNumber = req.query.pageNumber || 1;
|
||||
pageNumber = parseInt(pageNumber, 10);
|
||||
try {
|
||||
let pageNumber = req.query.pageNumber || 1;
|
||||
pageNumber = parseInt(pageNumber, 10);
|
||||
|
||||
if (isNaN(pageNumber) || pageNumber < 1) {
|
||||
return res.status(400).json({ error: 'Invalid page number' });
|
||||
if (isNaN(pageNumber) || pageNumber < 1) {
|
||||
return res.status(400).json({ error: 'Invalid page number' });
|
||||
}
|
||||
|
||||
let pageSize = req.query.pageSize || 25;
|
||||
pageSize = parseInt(pageSize, 10);
|
||||
|
||||
if (isNaN(pageSize) || pageSize < 1) {
|
||||
return res.status(400).json({ error: 'Invalid page size' });
|
||||
}
|
||||
const isPublic = req.query.isPublic === 'true';
|
||||
res.status(200).send(await getSharedLinks(req.user.id, pageNumber, pageSize, isPublic));
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error getting shared links' });
|
||||
}
|
||||
|
||||
let pageSize = req.query.pageSize || 25;
|
||||
pageSize = parseInt(pageSize, 10);
|
||||
|
||||
if (isNaN(pageSize) || pageSize < 1) {
|
||||
return res.status(400).json({ error: 'Invalid page size' });
|
||||
}
|
||||
const isPublic = req.query.isPublic === 'true';
|
||||
res.status(200).send(await getSharedLinks(req.user.id, pageNumber, pageSize, isPublic));
|
||||
});
|
||||
|
||||
router.post('/', requireJwtAuth, async (req, res) => {
|
||||
const created = await createSharedLink(req.user.id, req.body);
|
||||
if (created) {
|
||||
res.status(200).json(created);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
try {
|
||||
const created = await createSharedLink(req.user.id, req.body);
|
||||
if (created) {
|
||||
res.status(200).json(created);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error creating shared link' });
|
||||
}
|
||||
});
|
||||
|
||||
router.patch('/', requireJwtAuth, async (req, res) => {
|
||||
const updated = await updateSharedLink(req.user.id, req.body);
|
||||
if (updated) {
|
||||
res.status(200).json(updated);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
try {
|
||||
const updated = await updateSharedLink(req.user.id, req.body);
|
||||
if (updated) {
|
||||
res.status(200).json(updated);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error updating shared link' });
|
||||
}
|
||||
});
|
||||
|
||||
router.delete('/:shareId', requireJwtAuth, async (req, res) => {
|
||||
const deleted = await deleteSharedLink(req.user.id, { shareId: req.params.shareId });
|
||||
if (deleted) {
|
||||
res.status(200).json(deleted);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
try {
|
||||
const deleted = await deleteSharedLink(req.user.id, { shareId: req.params.shareId });
|
||||
if (deleted) {
|
||||
res.status(200).json(deleted);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
}
|
||||
} catch (error) {
|
||||
res.status(500).json({ message: 'Error deleting shared link' });
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
44
api/server/routes/tags.js
Normal file
44
api/server/routes/tags.js
Normal file
@@ -0,0 +1,44 @@
|
||||
const express = require('express');
|
||||
|
||||
const {
|
||||
getConversationTags,
|
||||
updateConversationTag,
|
||||
createConversationTag,
|
||||
deleteConversationTag,
|
||||
rebuildConversationTags,
|
||||
} = require('~/models/ConversationTag');
|
||||
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
|
||||
const router = express.Router();
|
||||
router.use(requireJwtAuth);
|
||||
|
||||
router.get('/', async (req, res) => {
|
||||
const tags = await getConversationTags(req.user.id);
|
||||
|
||||
if (tags) {
|
||||
res.status(200).json(tags);
|
||||
} else {
|
||||
res.status(404).end();
|
||||
}
|
||||
});
|
||||
|
||||
router.post('/', async (req, res) => {
|
||||
const tag = await createConversationTag(req.user.id, req.body);
|
||||
res.status(200).json(tag);
|
||||
});
|
||||
|
||||
router.post('/rebuild', async (req, res) => {
|
||||
const tag = await rebuildConversationTags(req.user.id);
|
||||
res.status(200).json(tag);
|
||||
});
|
||||
|
||||
router.put('/:tag', async (req, res) => {
|
||||
const tag = await updateConversationTag(req.user.id, req.params.tag, req.body);
|
||||
res.status(200).json(tag);
|
||||
});
|
||||
|
||||
router.delete('/:tag', async (req, res) => {
|
||||
const tag = await deleteConversationTag(req.user.id, req.params.tag);
|
||||
res.status(200).json(tag);
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
@@ -7,6 +7,7 @@ const handleRateLimits = require('./Config/handleRateLimits');
|
||||
const { loadDefaultInterface } = require('./start/interface');
|
||||
const { azureConfigSetup } = require('./start/azureOpenAI');
|
||||
const { loadAndFormatTools } = require('./ToolService');
|
||||
const { initializeRoles } = require('~/models/Role');
|
||||
const paths = require('~/config/paths');
|
||||
|
||||
/**
|
||||
@@ -16,6 +17,7 @@ const paths = require('~/config/paths');
|
||||
* @param {Express.Application} app - The Express application object.
|
||||
*/
|
||||
const AppService = async (app) => {
|
||||
await initializeRoles();
|
||||
/** @type {TCustomConfig}*/
|
||||
const config = (await loadCustomConfig()) ?? {};
|
||||
const configDefaults = getConfigDefaults();
|
||||
@@ -65,17 +67,18 @@ const AppService = async (app) => {
|
||||
handleRateLimits(config?.rateLimits);
|
||||
|
||||
const endpointLocals = {};
|
||||
const endpoints = config?.endpoints;
|
||||
|
||||
if (config?.endpoints?.[EModelEndpoint.azureOpenAI]) {
|
||||
if (endpoints?.[EModelEndpoint.azureOpenAI]) {
|
||||
endpointLocals[EModelEndpoint.azureOpenAI] = azureConfigSetup(config);
|
||||
checkAzureVariables();
|
||||
}
|
||||
|
||||
if (config?.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
|
||||
if (endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
|
||||
endpointLocals[EModelEndpoint.azureAssistants] = azureAssistantsDefaults();
|
||||
}
|
||||
|
||||
if (config?.endpoints?.[EModelEndpoint.azureAssistants]) {
|
||||
if (endpoints?.[EModelEndpoint.azureAssistants]) {
|
||||
endpointLocals[EModelEndpoint.azureAssistants] = assistantsConfigSetup(
|
||||
config,
|
||||
EModelEndpoint.azureAssistants,
|
||||
@@ -83,7 +86,7 @@ const AppService = async (app) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (config?.endpoints?.[EModelEndpoint.assistants]) {
|
||||
if (endpoints?.[EModelEndpoint.assistants]) {
|
||||
endpointLocals[EModelEndpoint.assistants] = assistantsConfigSetup(
|
||||
config,
|
||||
EModelEndpoint.assistants,
|
||||
@@ -91,6 +94,19 @@ const AppService = async (app) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (endpoints?.[EModelEndpoint.openAI]) {
|
||||
endpointLocals[EModelEndpoint.openAI] = endpoints[EModelEndpoint.openAI];
|
||||
}
|
||||
if (endpoints?.[EModelEndpoint.google]) {
|
||||
endpointLocals[EModelEndpoint.google] = endpoints[EModelEndpoint.google];
|
||||
}
|
||||
if (endpoints?.[EModelEndpoint.anthropic]) {
|
||||
endpointLocals[EModelEndpoint.anthropic] = endpoints[EModelEndpoint.anthropic];
|
||||
}
|
||||
if (endpoints?.[EModelEndpoint.gptPlugins]) {
|
||||
endpointLocals[EModelEndpoint.gptPlugins] = endpoints[EModelEndpoint.gptPlugins];
|
||||
}
|
||||
|
||||
app.locals = {
|
||||
...defaultLocals,
|
||||
modelSpecs: config.modelSpecs,
|
||||
|
||||
@@ -21,6 +21,9 @@ jest.mock('./Config/loadCustomConfig', () => {
|
||||
jest.mock('./Files/Firebase/initialize', () => ({
|
||||
initializeFirebase: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Role', () => ({
|
||||
initializeRoles: jest.fn(),
|
||||
}));
|
||||
jest.mock('./ToolService', () => ({
|
||||
loadAndFormatTools: jest.fn().mockReturnValue({
|
||||
ExampleTool: {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const crypto = require('crypto');
|
||||
const bcrypt = require('bcryptjs');
|
||||
const { errorsToString } = require('librechat-data-provider');
|
||||
const { SystemRoles, errorsToString } = require('librechat-data-provider');
|
||||
const {
|
||||
findUser,
|
||||
countUsers,
|
||||
@@ -62,7 +62,9 @@ const sendVerificationEmail = async (user) => {
|
||||
let verifyToken = crypto.randomBytes(32).toString('hex');
|
||||
const hash = bcrypt.hashSync(verifyToken, 10);
|
||||
|
||||
const verificationLink = `${domains.client}/verify?token=${verifyToken}&email=${user.email}`;
|
||||
const verificationLink = `${
|
||||
domains.client
|
||||
}/verify?token=${verifyToken}&email=${encodeURIComponent(user.email)}`;
|
||||
await sendEmail({
|
||||
email: user.email,
|
||||
subject: 'Verify your email',
|
||||
@@ -91,7 +93,7 @@ const sendVerificationEmail = async (user) => {
|
||||
*/
|
||||
const verifyEmail = async (req) => {
|
||||
const { email, token } = req.body;
|
||||
let emailVerificationData = await Token.findOne({ email });
|
||||
let emailVerificationData = await Token.findOne({ email: decodeURIComponent(email) });
|
||||
|
||||
if (!emailVerificationData) {
|
||||
logger.warn(`[verifyEmail] [No email verification data found] [Email: ${email}]`);
|
||||
@@ -119,9 +121,10 @@ const verifyEmail = async (req) => {
|
||||
/**
|
||||
* Register a new user.
|
||||
* @param {MongoUser} user <email, password, name, username>
|
||||
* @param {Partial<MongoUser>} [additionalData={}]
|
||||
* @returns {Promise<{status: number, message: string, user?: MongoUser}>}
|
||||
*/
|
||||
const registerUser = async (user) => {
|
||||
const registerUser = async (user, additionalData = {}) => {
|
||||
const { error } = registerSchema.safeParse(user);
|
||||
if (error) {
|
||||
const errorMessage = errorsToString(error.errors);
|
||||
@@ -169,13 +172,15 @@ const registerUser = async (user) => {
|
||||
username,
|
||||
name,
|
||||
avatar: null,
|
||||
role: isFirstRegisteredUser ? 'ADMIN' : 'USER',
|
||||
role: isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER,
|
||||
password: bcrypt.hashSync(password, salt),
|
||||
...additionalData,
|
||||
};
|
||||
|
||||
const emailEnabled = checkEmailConfig();
|
||||
newUserId = await createUser(newUserData, false);
|
||||
if (emailEnabled) {
|
||||
const newUser = await createUser(newUserData, false, true);
|
||||
newUserId = newUser._id;
|
||||
if (emailEnabled && !newUser.emailVerified) {
|
||||
await sendVerificationEmail({
|
||||
_id: newUserId,
|
||||
email,
|
||||
@@ -363,7 +368,9 @@ const resendVerificationEmail = async (req) => {
|
||||
let verifyToken = crypto.randomBytes(32).toString('hex');
|
||||
const hash = bcrypt.hashSync(verifyToken, 10);
|
||||
|
||||
const verificationLink = `${domains.client}/verify?token=${verifyToken}&email=${user.email}`;
|
||||
const verificationLink = `${
|
||||
domains.client
|
||||
}/verify?token=${verifyToken}&email=${encodeURIComponent(user.email)}`;
|
||||
|
||||
await sendEmail({
|
||||
email: user.email,
|
||||
|
||||
24
api/server/services/Config/ldap.js
Normal file
24
api/server/services/Config/ldap.js
Normal file
@@ -0,0 +1,24 @@
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
|
||||
/** @returns {TStartupConfig['ldap'] | undefined} */
|
||||
const getLdapConfig = () => {
|
||||
const ldapLoginEnabled = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
|
||||
const ldap = {
|
||||
enabled: ldapLoginEnabled,
|
||||
};
|
||||
const ldapLoginUsesUsername = isEnabled(process.env.LDAP_LOGIN_USES_USERNAME);
|
||||
if (!ldapLoginEnabled) {
|
||||
return ldap;
|
||||
}
|
||||
|
||||
if (ldapLoginUsesUsername) {
|
||||
ldap.username = true;
|
||||
}
|
||||
|
||||
return ldap;
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getLdapConfig,
|
||||
};
|
||||
@@ -76,8 +76,28 @@ Please specify a correct \`imageOutputType\` value (case-sensitive).
|
||||
);
|
||||
}
|
||||
if (!result.success) {
|
||||
i === 0 && logger.error(`Invalid custom config file at ${configPath}`, result.error);
|
||||
i === 0 && i++;
|
||||
let errorMessage = `Invalid custom config file at ${configPath}:
|
||||
${JSON.stringify(result.error, null, 2)}`;
|
||||
|
||||
if (i === 0) {
|
||||
logger.error(errorMessage);
|
||||
const speechError = result.error.errors.find(
|
||||
(err) =>
|
||||
err.code === 'unrecognized_keys' &&
|
||||
(err.message?.includes('stt') || err.message?.includes('tts')),
|
||||
);
|
||||
|
||||
if (speechError) {
|
||||
logger.warn(`
|
||||
The Speech-to-text and Text-to-speech configuration format has recently changed.
|
||||
If you're getting this error, please refer to the latest documentation:
|
||||
|
||||
https://www.librechat.ai/docs/configuration/stt_tts`);
|
||||
}
|
||||
|
||||
i++;
|
||||
}
|
||||
|
||||
return null;
|
||||
} else {
|
||||
logger.info('Custom config file loaded:');
|
||||
|
||||
@@ -23,10 +23,14 @@ const addTitle = async (req, { text, response, client }) => {
|
||||
|
||||
const title = await client.titleConvo({ text, responseText: response?.text });
|
||||
await titleCache.set(key, title, 120000);
|
||||
await saveConvo(req.user.id, {
|
||||
conversationId: response.conversationId,
|
||||
title,
|
||||
});
|
||||
await saveConvo(
|
||||
req,
|
||||
{
|
||||
conversationId: response.conversationId,
|
||||
title,
|
||||
},
|
||||
{ context: 'api/server/services/Endpoints/anthropic/addTitle.js' },
|
||||
);
|
||||
};
|
||||
|
||||
module.exports = addTitle;
|
||||
|
||||
@@ -19,11 +19,27 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
checkUserKeyExpiry(expiresAt, EModelEndpoint.anthropic);
|
||||
}
|
||||
|
||||
const clientOptions = {};
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const anthropicConfig = req.app.locals[EModelEndpoint.anthropic];
|
||||
|
||||
if (anthropicConfig) {
|
||||
clientOptions.streamRate = anthropicConfig.streamRate;
|
||||
}
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const allConfig = req.app.locals.all;
|
||||
if (allConfig) {
|
||||
clientOptions.streamRate = allConfig.streamRate;
|
||||
}
|
||||
|
||||
const client = new AnthropicClient(anthropicApiKey, {
|
||||
req,
|
||||
res,
|
||||
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
|
||||
proxy: PROXY ?? null,
|
||||
...clientOptions,
|
||||
...endpointOption,
|
||||
});
|
||||
|
||||
|
||||
@@ -19,10 +19,14 @@ const addTitle = async (req, { text, responseText, conversationId, client }) =>
|
||||
const title = await client.titleConvo({ text, conversationId, responseText });
|
||||
await titleCache.set(key, title, 120000);
|
||||
|
||||
await saveConvo(req.user.id, {
|
||||
conversationId,
|
||||
title,
|
||||
});
|
||||
await saveConvo(
|
||||
req,
|
||||
{
|
||||
conversationId,
|
||||
title,
|
||||
},
|
||||
{ context: 'api/server/services/Endpoints/assistants/addTitle.js' },
|
||||
);
|
||||
};
|
||||
|
||||
module.exports = addTitle;
|
||||
|
||||
@@ -114,9 +114,16 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
contextStrategy: endpointConfig.summarize ? 'summarize' : null,
|
||||
directEndpoint: endpointConfig.directEndpoint,
|
||||
titleMessageRole: endpointConfig.titleMessageRole,
|
||||
streamRate: endpointConfig.streamRate,
|
||||
endpointTokenConfig,
|
||||
};
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const allConfig = req.app.locals.all;
|
||||
if (allConfig) {
|
||||
customOptions.streamRate = allConfig.streamRate;
|
||||
}
|
||||
|
||||
const clientOptions = {
|
||||
reverseProxyUrl: baseURL ?? null,
|
||||
proxy: PROXY ?? null,
|
||||
|
||||
62
api/server/services/Endpoints/google/addTitle.js
Normal file
62
api/server/services/Endpoints/google/addTitle.js
Normal file
@@ -0,0 +1,62 @@
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { saveConvo } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
const initializeClient = require('./initializeClient');
|
||||
|
||||
const addTitle = async (req, { text, response, client }) => {
|
||||
const { TITLE_CONVO = 'true' } = process.env ?? {};
|
||||
if (!isEnabled(TITLE_CONVO)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (client.options.titleConvo === false) {
|
||||
return;
|
||||
}
|
||||
|
||||
const DEFAULT_TITLE_MODEL = 'gemini-pro';
|
||||
const { GOOGLE_TITLE_MODEL } = process.env ?? {};
|
||||
|
||||
let model = GOOGLE_TITLE_MODEL ?? DEFAULT_TITLE_MODEL;
|
||||
|
||||
if (GOOGLE_TITLE_MODEL === Constants.CURRENT_MODEL) {
|
||||
model = client.options?.modelOptions.model;
|
||||
|
||||
if (client.isVisionModel) {
|
||||
logger.warn(
|
||||
`current_model was specified for Google title request, but the model ${model} cannot process a text-only conversation. Falling back to ${DEFAULT_TITLE_MODEL}`,
|
||||
);
|
||||
|
||||
model = DEFAULT_TITLE_MODEL;
|
||||
}
|
||||
}
|
||||
|
||||
const titleEndpointOptions = {
|
||||
...client.options,
|
||||
modelOptions: { ...client.options?.modelOptions, model: model },
|
||||
attachments: undefined, // After a response, this is set to an empty array which results in an error during setOptions
|
||||
};
|
||||
|
||||
const { client: titleClient } = await initializeClient({
|
||||
req,
|
||||
res: response,
|
||||
endpointOption: titleEndpointOptions,
|
||||
});
|
||||
|
||||
const titleCache = getLogStores(CacheKeys.GEN_TITLE);
|
||||
const key = `${req.user.id}-${response.conversationId}`;
|
||||
|
||||
const title = await titleClient.titleConvo({ text, responseText: response?.text });
|
||||
await titleCache.set(key, title, 120000);
|
||||
await saveConvo(
|
||||
req,
|
||||
{
|
||||
conversationId: response.conversationId,
|
||||
title,
|
||||
},
|
||||
{ context: 'api/server/services/Endpoints/google/addTitle.js' },
|
||||
);
|
||||
};
|
||||
|
||||
module.exports = addTitle;
|
||||
@@ -1,8 +1,9 @@
|
||||
const addTitle = require('./addTitle');
|
||||
const buildOptions = require('./buildOptions');
|
||||
const initializeClient = require('./initializeClient');
|
||||
|
||||
module.exports = {
|
||||
// addTitle, // todo
|
||||
addTitle,
|
||||
buildOptions,
|
||||
initializeClient,
|
||||
};
|
||||
|
||||
@@ -27,11 +27,27 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
[AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
|
||||
};
|
||||
|
||||
const clientOptions = {};
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const allConfig = req.app.locals.all;
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const googleConfig = req.app.locals[EModelEndpoint.google];
|
||||
|
||||
if (googleConfig) {
|
||||
clientOptions.streamRate = googleConfig.streamRate;
|
||||
}
|
||||
|
||||
if (allConfig) {
|
||||
clientOptions.streamRate = allConfig.streamRate;
|
||||
}
|
||||
|
||||
const client = new GoogleClient(credentials, {
|
||||
req,
|
||||
res,
|
||||
reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
|
||||
proxy: PROXY ?? null,
|
||||
...clientOptions,
|
||||
...endpointOption,
|
||||
});
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ jest.mock('~/server/services/UserService', () => ({
|
||||
getUserKey: jest.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
const app = { locals: {} };
|
||||
|
||||
describe('google/initializeClient', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
@@ -23,6 +25,7 @@ describe('google/initializeClient', () => {
|
||||
const req = {
|
||||
body: { key: expiresAt },
|
||||
user: { id: '123' },
|
||||
app,
|
||||
};
|
||||
const res = {};
|
||||
const endpointOption = { modelOptions: { model: 'default-model' } };
|
||||
@@ -44,6 +47,7 @@ describe('google/initializeClient', () => {
|
||||
const req = {
|
||||
body: { key: null },
|
||||
user: { id: '123' },
|
||||
app,
|
||||
};
|
||||
const res = {};
|
||||
const endpointOption = { modelOptions: { model: 'default-model' } };
|
||||
@@ -66,6 +70,7 @@ describe('google/initializeClient', () => {
|
||||
const req = {
|
||||
body: { key: expiresAt },
|
||||
user: { id: '123' },
|
||||
app,
|
||||
};
|
||||
const res = {};
|
||||
const endpointOption = { modelOptions: { model: 'default-model' } };
|
||||
|
||||
@@ -86,6 +86,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
clientOptions.titleModel = azureConfig.titleModel;
|
||||
clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
|
||||
|
||||
const azureRate = modelName.includes('gpt-4') ? 30 : 17;
|
||||
clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
|
||||
|
||||
const groupName = modelGroupMap[modelName].group;
|
||||
clientOptions.addParams = azureConfig.groupMap[groupName].addParams;
|
||||
clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams;
|
||||
@@ -98,6 +101,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
apiKey = clientOptions.azure.azureOpenAIApiKey;
|
||||
}
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const pluginsConfig = req.app.locals[EModelEndpoint.gptPlugins];
|
||||
|
||||
if (!useAzure && pluginsConfig) {
|
||||
clientOptions.streamRate = pluginsConfig.streamRate;
|
||||
}
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const allConfig = req.app.locals.all;
|
||||
if (allConfig) {
|
||||
clientOptions.streamRate = allConfig.streamRate;
|
||||
}
|
||||
|
||||
if (!apiKey) {
|
||||
throw new Error(`${endpoint} API key not provided. Please provide it again.`);
|
||||
}
|
||||
|
||||
@@ -23,10 +23,14 @@ const addTitle = async (req, { text, response, client }) => {
|
||||
|
||||
const title = await client.titleConvo({ text, responseText: response?.text });
|
||||
await titleCache.set(key, title, 120000);
|
||||
await saveConvo(req.user.id, {
|
||||
conversationId: response.conversationId,
|
||||
title,
|
||||
});
|
||||
await saveConvo(
|
||||
req,
|
||||
{
|
||||
conversationId: response.conversationId,
|
||||
title,
|
||||
},
|
||||
{ context: 'api/server/services/Endpoints/openAI/addTitle.js' },
|
||||
);
|
||||
};
|
||||
|
||||
module.exports = addTitle;
|
||||
|
||||
@@ -76,6 +76,10 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
|
||||
clientOptions.titleConvo = azureConfig.titleConvo;
|
||||
clientOptions.titleModel = azureConfig.titleModel;
|
||||
|
||||
const azureRate = modelName.includes('gpt-4') ? 30 : 17;
|
||||
clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
|
||||
|
||||
clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
|
||||
|
||||
const groupName = modelGroupMap[modelName].group;
|
||||
@@ -90,6 +94,19 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
apiKey = clientOptions.azure.azureOpenAIApiKey;
|
||||
}
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const openAIConfig = req.app.locals[EModelEndpoint.openAI];
|
||||
|
||||
if (!isAzureOpenAI && openAIConfig) {
|
||||
clientOptions.streamRate = openAIConfig.streamRate;
|
||||
}
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const allConfig = req.app.locals.all;
|
||||
if (allConfig) {
|
||||
clientOptions.streamRate = allConfig.streamRate;
|
||||
}
|
||||
|
||||
if (userProvidesKey & !apiKey) {
|
||||
throw new Error(
|
||||
JSON.stringify({
|
||||
|
||||
58
api/server/services/Files/Audio/getCustomConfigSpeech.js
Normal file
58
api/server/services/Files/Audio/getCustomConfigSpeech.js
Normal file
@@ -0,0 +1,58 @@
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
|
||||
/**
|
||||
* This function retrieves the speechTab settings from the custom configuration
|
||||
* It first fetches the custom configuration
|
||||
* Then, it checks if the custom configuration and the speechTab schema exist
|
||||
* If they do, it sends the speechTab settings as a JSON response
|
||||
* If they don't, it throws an error
|
||||
*
|
||||
* @param {Object} req - The request object
|
||||
* @param {Object} res - The response object
|
||||
* @returns {Promise<void>}
|
||||
* @throws {Error} - If the custom configuration or the speechTab schema is missing, an error is thrown
|
||||
*/
|
||||
async function getCustomConfigSpeech(req, res) {
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
const sttExternal = !!customConfig.speech?.stt;
|
||||
const ttsExternal = !!customConfig.speech?.tts;
|
||||
let settings = {
|
||||
sttExternal,
|
||||
ttsExternal,
|
||||
};
|
||||
|
||||
if (!customConfig || !customConfig.speech?.speechTab) {
|
||||
return res.status(200).send(settings);
|
||||
}
|
||||
|
||||
const speechTab = customConfig.speech.speechTab;
|
||||
|
||||
if (speechTab.advancedMode !== undefined) {
|
||||
settings.advancedMode = speechTab.advancedMode;
|
||||
}
|
||||
|
||||
if (speechTab.speechToText) {
|
||||
for (const key in speechTab.speechToText) {
|
||||
if (speechTab.speechToText[key] !== undefined) {
|
||||
settings[key] = speechTab.speechToText[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (speechTab.textToSpeech) {
|
||||
for (const key in speechTab.textToSpeech) {
|
||||
if (speechTab.textToSpeech[key] !== undefined) {
|
||||
settings[key] = speechTab.textToSpeech[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(200).send(settings);
|
||||
} catch (error) {
|
||||
console.error('Failed to get custom config speech settings:', error);
|
||||
res.status(500).send('Internal Server Error');
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = getCustomConfigSpeech;
|
||||
@@ -1,4 +1,4 @@
|
||||
const { logger } = require('~/config');
|
||||
const { TTSProviders } = require('librechat-data-provider');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { getProvider } = require('./textToSpeech');
|
||||
|
||||
@@ -16,22 +16,25 @@ async function getVoices(req, res) {
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
|
||||
if (!customConfig || !customConfig?.tts) {
|
||||
if (!customConfig || !customConfig?.speech?.tts) {
|
||||
throw new Error('Configuration or TTS schema is missing');
|
||||
}
|
||||
|
||||
const ttsSchema = customConfig?.tts;
|
||||
const ttsSchema = customConfig?.speech?.tts;
|
||||
const provider = getProvider(ttsSchema);
|
||||
let voices;
|
||||
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
case TTSProviders.OPENAI:
|
||||
voices = ttsSchema.openai?.voices;
|
||||
break;
|
||||
case 'elevenlabs':
|
||||
case TTSProviders.AZURE_OPENAI:
|
||||
voices = ttsSchema.azureOpenAI?.voices;
|
||||
break;
|
||||
case TTSProviders.ELEVENLABS:
|
||||
voices = ttsSchema.elevenlabs?.voices;
|
||||
break;
|
||||
case 'localai':
|
||||
case TTSProviders.LOCALAI:
|
||||
voices = ttsSchema.localai?.voices;
|
||||
break;
|
||||
default:
|
||||
@@ -40,8 +43,7 @@ async function getVoices(req, res) {
|
||||
|
||||
res.json(voices);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to get voices: ${error.message}`);
|
||||
res.status(500).json({ error: 'Failed to get voices' });
|
||||
res.status(500).json({ error: `Failed to get voices: ${error.message}` });
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
const getVoices = require('./getVoices');
|
||||
const getCustomConfigSpeech = require('./getCustomConfigSpeech');
|
||||
const textToSpeech = require('./textToSpeech');
|
||||
const speechToText = require('./speechToText');
|
||||
const { updateTokenWebsocket } = require('./webSocket');
|
||||
|
||||
module.exports = {
|
||||
getVoices,
|
||||
getCustomConfigSpeech,
|
||||
speechToText,
|
||||
...textToSpeech,
|
||||
updateTokenWebsocket,
|
||||
};
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
const axios = require('axios');
|
||||
const { Readable } = require('stream');
|
||||
const { logger } = require('~/config');
|
||||
const axios = require('axios');
|
||||
const { extractEnvVariable, STTProviders } = require('librechat-data-provider');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { extractEnvVariable } = require('librechat-data-provider');
|
||||
const { genAzureEndpoint } = require('~/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Handle the response from the STT API
|
||||
@@ -24,12 +25,34 @@ async function handleResponse(response) {
|
||||
return response.data.text.trim();
|
||||
}
|
||||
|
||||
function getProvider(sttSchema) {
|
||||
if (sttSchema.openai) {
|
||||
return 'openai';
|
||||
/**
|
||||
* getProviderSchema function
|
||||
* This function takes the customConfig object and returns the name of the provider and its schema
|
||||
* If more than one provider is set or no provider is set, it throws an error
|
||||
*
|
||||
* @param {Object} customConfig - The custom configuration containing the STT schema
|
||||
* @returns {Promise<[string, Object]>} The name of the provider and its schema
|
||||
* @throws {Error} Throws an error if multiple providers are set or no provider is set
|
||||
*/
|
||||
async function getProviderSchema(customConfig) {
|
||||
const sttSchema = customConfig.speech.stt;
|
||||
|
||||
if (!sttSchema) {
|
||||
throw new Error(`No STT schema is set. Did you configure STT in the custom config (librechat.yaml)?
|
||||
|
||||
https://www.librechat.ai/docs/configuration/stt_tts#stt`);
|
||||
}
|
||||
|
||||
throw new Error('Invalid provider');
|
||||
const providers = Object.entries(sttSchema).filter(([, value]) => Object.keys(value).length > 0);
|
||||
|
||||
if (providers.length > 1) {
|
||||
throw new Error('Multiple providers are set. Please set only one provider.');
|
||||
} else if (providers.length === 0) {
|
||||
throw new Error('No provider is set. Please set a provider.');
|
||||
} else {
|
||||
const provider = providers[0][0];
|
||||
return [provider, sttSchema[provider]];
|
||||
}
|
||||
}
|
||||
|
||||
function removeUndefined(obj) {
|
||||
@@ -83,72 +106,63 @@ function openAIProvider(sttSchema, audioReadStream) {
|
||||
}
|
||||
|
||||
/**
|
||||
* This function prepares the necessary data and headers for making a request to the Azure API
|
||||
* It uses the provided request and audio stream to create the request
|
||||
* Prepares the necessary data and headers for making a request to the Azure API.
|
||||
* It uses the provided Speech-to-Text (STT) schema and audio file to create the request.
|
||||
*
|
||||
* @param {Object} req - The request object, which should contain the endpoint in its body
|
||||
* @param {Stream} audioReadStream - The audio data to be transcribed
|
||||
* @param {Object} sttSchema - The STT schema object, which should contain instanceName, deploymentName, apiVersion, and apiKey.
|
||||
* @param {Buffer} audioBuffer - The audio data to be transcribed
|
||||
* @param {Object} audioFile - The audio file object, which should contain originalname, mimetype, and size.
|
||||
*
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
|
||||
* If an error occurs, it returns an array with three null values and logs the error with logger
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request.
|
||||
* If an error occurs, it logs the error with logger and returns an array with three null values.
|
||||
*/
|
||||
function azureProvider(req, audioReadStream) {
|
||||
function azureOpenAIProvider(sttSchema, audioBuffer, audioFile) {
|
||||
try {
|
||||
const { endpoint } = req.body;
|
||||
const azureConfig = req.app.locals[endpoint];
|
||||
const instanceName = sttSchema?.instanceName;
|
||||
const deploymentName = sttSchema?.deploymentName;
|
||||
const apiVersion = sttSchema?.apiVersion;
|
||||
|
||||
if (!azureConfig) {
|
||||
throw new Error(`No configuration found for endpoint: ${endpoint}`);
|
||||
const url =
|
||||
genAzureEndpoint({
|
||||
azureOpenAIApiInstanceName: instanceName,
|
||||
azureOpenAIApiDeploymentName: deploymentName,
|
||||
}) +
|
||||
'/audio/transcriptions?api-version=' +
|
||||
apiVersion;
|
||||
|
||||
const apiKey = sttSchema.apiKey ? extractEnvVariable(sttSchema.apiKey) : '';
|
||||
|
||||
if (audioBuffer.byteLength > 25 * 1024 * 1024) {
|
||||
throw new Error('The audio file size exceeds the limit of 25MB');
|
||||
}
|
||||
const acceptedFormats = ['flac', 'mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'ogg', 'wav', 'webm'];
|
||||
const fileFormat = audioFile.mimetype.split('/')[1];
|
||||
if (!acceptedFormats.includes(fileFormat)) {
|
||||
throw new Error(`The audio file format ${fileFormat} is not accepted`);
|
||||
}
|
||||
|
||||
const { apiKey, instanceName, whisperModel, apiVersion } = Object.entries(
|
||||
azureConfig.groupMap,
|
||||
).reduce((acc, [, value]) => {
|
||||
if (acc) {
|
||||
return acc;
|
||||
}
|
||||
const formData = new FormData();
|
||||
|
||||
const whisperKey = Object.keys(value.models).find((modelKey) =>
|
||||
modelKey.startsWith('whisper'),
|
||||
);
|
||||
const audioBlob = new Blob([audioBuffer], { type: audioFile.mimetype });
|
||||
|
||||
if (whisperKey) {
|
||||
return {
|
||||
apiVersion: value.version,
|
||||
apiKey: value.apiKey,
|
||||
instanceName: value.instanceName,
|
||||
whisperModel: value.models[whisperKey]['deploymentName'],
|
||||
};
|
||||
}
|
||||
formData.append('file', audioBlob, audioFile.originalname);
|
||||
|
||||
return null;
|
||||
}, null);
|
||||
let data = formData;
|
||||
|
||||
if (!apiKey || !instanceName || !whisperModel || !apiVersion) {
|
||||
throw new Error('Required Azure configuration values are missing');
|
||||
}
|
||||
|
||||
const baseURL = `https://${instanceName}.openai.azure.com`;
|
||||
|
||||
const url = `${baseURL}/openai/deployments/${whisperModel}/audio/transcriptions?api-version=${apiVersion}`;
|
||||
|
||||
let data = {
|
||||
file: audioReadStream,
|
||||
filename: 'audio.wav',
|
||||
contentType: 'audio/wav',
|
||||
knownLength: audioReadStream.length,
|
||||
};
|
||||
|
||||
const headers = {
|
||||
...data.getHeaders(),
|
||||
let headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
'api-key': apiKey,
|
||||
};
|
||||
|
||||
[headers].forEach(removeUndefined);
|
||||
|
||||
if (apiKey) {
|
||||
headers['api-key'] = apiKey;
|
||||
}
|
||||
|
||||
return [url, data, headers];
|
||||
} catch (error) {
|
||||
logger.error('An error occurred while preparing the Azure API STT request: ', error);
|
||||
return [null, null, null];
|
||||
logger.error('An error occurred while preparing the Azure OpenAI API STT request: ', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,16 +190,16 @@ async function speechToText(req, res) {
|
||||
const audioReadStream = Readable.from(audioBuffer);
|
||||
audioReadStream.path = 'audio.wav';
|
||||
|
||||
const provider = getProvider(customConfig.stt);
|
||||
const [provider, sttSchema] = await getProviderSchema(customConfig);
|
||||
|
||||
let [url, data, headers] = [];
|
||||
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
[url, data, headers] = openAIProvider(customConfig.stt, audioReadStream);
|
||||
case STTProviders.OPENAI:
|
||||
[url, data, headers] = openAIProvider(sttSchema, audioReadStream);
|
||||
break;
|
||||
case 'azure':
|
||||
[url, data, headers] = azureProvider(req, audioReadStream);
|
||||
case STTProviders.AZURE_OPENAI:
|
||||
[url, data, headers] = azureOpenAIProvider(sttSchema, audioBuffer, req.file);
|
||||
break;
|
||||
default:
|
||||
throw new Error('Invalid provider');
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
const WebSocket = require('ws');
|
||||
const { Message } = require('~/models/Message');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* @param {string[]} voiceIds - Array of voice IDs
|
||||
@@ -104,6 +105,8 @@ function createChunkProcessor(messageId) {
|
||||
throw new Error('Message ID is required');
|
||||
}
|
||||
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
|
||||
/**
|
||||
* @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
|
||||
*/
|
||||
@@ -116,14 +119,17 @@ function createChunkProcessor(messageId) {
|
||||
return `No change in message after ${MAX_NO_CHANGE_COUNT} attempts`;
|
||||
}
|
||||
|
||||
const message = await Message.findOne({ messageId }, 'text unfinished').lean();
|
||||
/** @type { string | { text: string; complete: boolean } } */
|
||||
const message = await messageCache.get(messageId);
|
||||
|
||||
if (!message || !message.text) {
|
||||
if (!message) {
|
||||
notFoundCount++;
|
||||
return [];
|
||||
}
|
||||
|
||||
const { text, unfinished } = message;
|
||||
const text = typeof message === 'string' ? message : message.text;
|
||||
const complete = typeof message === 'string' ? false : message.complete;
|
||||
|
||||
if (text === processedText) {
|
||||
noChangeCount++;
|
||||
}
|
||||
@@ -131,7 +137,7 @@ function createChunkProcessor(messageId) {
|
||||
const remainingText = text.slice(processedText.length);
|
||||
const chunks = [];
|
||||
|
||||
if (unfinished && remainingText.length >= 20) {
|
||||
if (!complete && remainingText.length >= 20) {
|
||||
const separatorIndex = findLastSeparatorIndex(remainingText);
|
||||
if (separatorIndex !== -1) {
|
||||
const chunkText = remainingText.slice(0, separatorIndex + 1);
|
||||
@@ -141,7 +147,7 @@ function createChunkProcessor(messageId) {
|
||||
chunks.push({ text: remainingText, isFinished: false });
|
||||
processedText = text;
|
||||
}
|
||||
} else if (!unfinished && remainingText.trim().length > 0) {
|
||||
} else if (complete && remainingText.trim().length > 0) {
|
||||
chunks.push({ text: remainingText.trim(), isFinished: true });
|
||||
processedText = text;
|
||||
}
|
||||
|
||||
@@ -1,89 +1,145 @@
|
||||
const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
|
||||
const { Message } = require('~/models/Message');
|
||||
|
||||
jest.mock('~/models/Message', () => ({
|
||||
Message: {
|
||||
findOne: jest.fn().mockReturnValue({
|
||||
lean: jest.fn(),
|
||||
}),
|
||||
},
|
||||
}));
|
||||
jest.mock('keyv');
|
||||
|
||||
const globalCache = {};
|
||||
jest.mock('~/cache/getLogStores', () => {
|
||||
return jest.fn().mockImplementation(() => {
|
||||
const EventEmitter = require('events');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
|
||||
class KeyvMongo extends EventEmitter {
|
||||
constructor(url = 'mongodb://127.0.0.1:27017', options) {
|
||||
super();
|
||||
this.ttlSupport = false;
|
||||
url = url ?? {};
|
||||
if (typeof url === 'string') {
|
||||
url = { url };
|
||||
}
|
||||
if (url.uri) {
|
||||
url = { url: url.uri, ...url };
|
||||
}
|
||||
this.opts = {
|
||||
url,
|
||||
collection: 'keyv',
|
||||
...url,
|
||||
...options,
|
||||
};
|
||||
}
|
||||
|
||||
get = async (key) => {
|
||||
return new Promise((resolve) => {
|
||||
resolve(globalCache[key] || null);
|
||||
});
|
||||
};
|
||||
|
||||
set = async (key, value) => {
|
||||
return new Promise((resolve) => {
|
||||
globalCache[key] = value;
|
||||
resolve(true);
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
return new KeyvMongo('', {
|
||||
namespace: CacheKeys.MESSAGES,
|
||||
ttl: 0,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('processChunks', () => {
|
||||
let processChunks;
|
||||
let mockMessageCache;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetAllMocks();
|
||||
mockMessageCache = {
|
||||
get: jest.fn(),
|
||||
};
|
||||
require('~/cache/getLogStores').mockReturnValue(mockMessageCache);
|
||||
processChunks = createChunkProcessor('message-id');
|
||||
Message.findOne.mockClear();
|
||||
Message.findOne().lean.mockClear();
|
||||
});
|
||||
|
||||
it('should return an empty array when the message is not found', async () => {
|
||||
Message.findOne().lean.mockResolvedValueOnce(null);
|
||||
mockMessageCache.get.mockResolvedValueOnce(null);
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
expect(mockMessageCache.get).toHaveBeenCalledWith('message-id');
|
||||
});
|
||||
|
||||
it('should return an empty array when the message does not have a text property', async () => {
|
||||
Message.findOne().lean.mockResolvedValueOnce({ unfinished: true });
|
||||
it('should return an error message after MAX_NOT_FOUND_COUNT attempts', async () => {
|
||||
mockMessageCache.get.mockResolvedValue(null);
|
||||
|
||||
for (let i = 0; i < 6; i++) {
|
||||
await processChunks();
|
||||
}
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
expect(result).toBe('Message not found after 6 attempts');
|
||||
});
|
||||
|
||||
it('should return chunks for an unfinished message with separators', async () => {
|
||||
it('should return chunks for an incomplete message with separators', async () => {
|
||||
const messageText = 'This is a long message. It should be split into chunks. Lol hi mom';
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
|
||||
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([
|
||||
{ text: 'This is a long message. It should be split into chunks.', isFinished: false },
|
||||
]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return chunks for an unfinished message without separators', async () => {
|
||||
it('should return chunks for an incomplete message without separators', async () => {
|
||||
const messageText = 'This is a long message without separators hello there my friend';
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
|
||||
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: false });
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([{ text: messageText, isFinished: false }]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return the remaining text as a chunk for a finished message', async () => {
|
||||
it('should return the remaining text as a chunk for a complete message', async () => {
|
||||
const messageText = 'This is a finished message.';
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
|
||||
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([{ text: messageText, isFinished: true }]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return an empty array for a finished message with no remaining text', async () => {
|
||||
it('should return an empty array for a complete message with no remaining text', async () => {
|
||||
const messageText = 'This is a finished message.';
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
|
||||
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
|
||||
|
||||
await processChunks();
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
|
||||
mockMessageCache.get.mockResolvedValueOnce({ text: messageText, complete: true });
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should return an error message after MAX_NO_CHANGE_COUNT attempts with no change', async () => {
|
||||
const messageText = 'This is a message that does not change.';
|
||||
mockMessageCache.get.mockResolvedValue({ text: messageText, complete: false });
|
||||
|
||||
for (let i = 0; i < 11; i++) {
|
||||
await processChunks();
|
||||
}
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toBe('No change in message after 10 attempts');
|
||||
});
|
||||
|
||||
it('should handle string messages as incomplete', async () => {
|
||||
const messageText = 'This is a message as a string.';
|
||||
mockMessageCache.get.mockResolvedValueOnce(messageText);
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([{ text: messageText, isFinished: false }]);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
const axios = require('axios');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
|
||||
const { extractEnvVariable } = require('librechat-data-provider');
|
||||
const { extractEnvVariable, TTSProviders } = require('librechat-data-provider');
|
||||
const { logger } = require('~/config');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { genAzureEndpoint } = require('~/utils');
|
||||
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
|
||||
|
||||
/**
|
||||
* getProvider function
|
||||
@@ -91,6 +92,59 @@ function openAIProvider(ttsSchema, input, voice) {
|
||||
return [url, data, headers];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates the necessary parameters for making a request to Azure's OpenAI Text-to-Speech API.
|
||||
*
|
||||
* @param {TCustomConfig['tts']['azureOpenAI']} ttsSchema - The TTS schema containing the AzureOpenAI configuration
|
||||
* @param {string} input - The text to be converted to speech
|
||||
* @param {string} voice - The voice to be used for the speech
|
||||
*
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
|
||||
* If an error occurs, it throws an error with a message indicating that the selected voice is not available
|
||||
*/
|
||||
function azureOpenAIProvider(ttsSchema, input, voice) {
|
||||
const instanceName = ttsSchema?.instanceName;
|
||||
const deploymentName = ttsSchema?.deploymentName;
|
||||
const apiVersion = ttsSchema?.apiVersion;
|
||||
|
||||
const url =
|
||||
genAzureEndpoint({
|
||||
azureOpenAIApiInstanceName: instanceName,
|
||||
azureOpenAIApiDeploymentName: deploymentName,
|
||||
}) +
|
||||
'/audio/speech?api-version=' +
|
||||
apiVersion;
|
||||
|
||||
const apiKey = ttsSchema.apiKey ? extractEnvVariable(ttsSchema.apiKey) : '';
|
||||
|
||||
if (
|
||||
ttsSchema?.voices &&
|
||||
ttsSchema.voices.length > 0 &&
|
||||
!ttsSchema.voices.includes(voice) &&
|
||||
!ttsSchema.voices.includes('ALL')
|
||||
) {
|
||||
throw new Error(`Voice ${voice} is not available.`);
|
||||
}
|
||||
|
||||
let data = {
|
||||
model: ttsSchema?.model,
|
||||
input,
|
||||
voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
|
||||
};
|
||||
|
||||
let headers = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
[data, headers].forEach(removeUndefined);
|
||||
|
||||
if (apiKey) {
|
||||
headers['api-key'] = apiKey;
|
||||
}
|
||||
|
||||
return [url, data, headers];
|
||||
}
|
||||
|
||||
/**
|
||||
* elevenLabsProvider function
|
||||
* This function prepares the necessary data and headers for making a request to the Eleven Labs TTS
|
||||
@@ -191,8 +245,8 @@ function localAIProvider(ttsSchema, input, voice) {
|
||||
* @returns {Promise<[string, TProviderSchema]>}
|
||||
*/
|
||||
async function getProviderSchema(customConfig) {
|
||||
const provider = getProvider(customConfig.tts);
|
||||
return [provider, customConfig.tts[provider]];
|
||||
const provider = getProvider(customConfig.speech.tts);
|
||||
return [provider, customConfig.speech.tts[provider]];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -225,13 +279,16 @@ async function getVoice(providerSchema, requestVoice) {
|
||||
async function ttsRequest(provider, ttsSchema, { input, voice, stream = true } = { stream: true }) {
|
||||
let [url, data, headers] = [];
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
case TTSProviders.OPENAI:
|
||||
[url, data, headers] = openAIProvider(ttsSchema, input, voice);
|
||||
break;
|
||||
case 'elevenlabs':
|
||||
case TTSProviders.AZURE_OPENAI:
|
||||
[url, data, headers] = azureOpenAIProvider(ttsSchema, input, voice);
|
||||
break;
|
||||
case TTSProviders.ELEVENLABS:
|
||||
[url, data, headers] = elevenLabsProvider(ttsSchema, input, voice, stream);
|
||||
break;
|
||||
case 'localai':
|
||||
case TTSProviders.LOCALAI:
|
||||
[url, data, headers] = localAIProvider(ttsSchema, input, voice);
|
||||
break;
|
||||
default:
|
||||
@@ -385,7 +442,7 @@ async function streamAudio(req, res) {
|
||||
break;
|
||||
}
|
||||
} catch (innerError) {
|
||||
logger.error('Error processing update:', update, innerError);
|
||||
logger.error('Error processing audio stream update:', update, innerError);
|
||||
if (!res.headersSent) {
|
||||
res.status(500).end();
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user