Compare commits

37 Commits

v2-assista...v0.7.3-rc2

| SHA1 |
|---|
| 37ae484fbc |
| 8939d8af37 |
| f9a0166352 |
| 248dfb8b5b |
| b8e35002f4 |
| 8318f26d66 |
| 08d6bea359 |
| a6058c5669 |
| e0402b71f0 |
| a618266905 |
| d5a7806e32 |
| e2cb2905e7 |
| 3f600f0d3f |
| c9e7d4ac18 |
| 40685f6eb4 |
| 0ee060d730 |
| 5dc5d875ba |
| 9f2538fcd9 |
| 2b7a973a33 |
| c704a23749 |
| eb5733083e |
| b80f38e49e |
| 4369e75ca7 |
| 35ba4ba1a4 |
| dcd2e3e62d |
| 514a502b9c |
| 8e66683577 |
| dc1778b11f |
| 795bb9c568 |
| a937650df6 |
| 6cf1c85363 |
| b3e03b75d0 |
| 9d8fd92dd3 |
| f00a8f87f7 |
| 79840763e7 |
| 1a452121fa |
| af8bcb08d6 |
.env.example (28 changes)

```diff
@@ -119,7 +119,7 @@ GOOGLE_KEY=user_provided
 # GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
 
 # Vertex AI
-# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0409,gemini-1.0-pro-vision-001,gemini-pro,gemini-pro-vision,chat-bison,chat-bison-32k,codechat-bison,codechat-bison-32k,text-bison,text-bison-32k,text-unicorn,code-gecko,code-bison,code-bison-32k
+# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
 
 # Google Gemini Safety Settings
 # NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
@@ -164,6 +164,16 @@ ASSISTANTS_API_KEY=user_provided
 # ASSISTANTS_BASE_URL=
 # ASSISTANTS_MODELS=gpt-4o,gpt-3.5-turbo-0125,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-16k,gpt-3.5-turbo,gpt-4,gpt-4-0314,gpt-4-32k-0314,gpt-4-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-1106,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview
 
+#==========================#
+#   Azure Assistants API   #
+#==========================#
+
+# Note: You should map your credentials with custom variables according to your Azure OpenAI Configuration
+# The models for Azure Assistants are also determined by your Azure OpenAI configuration.
+
+# More info, including how to enable use of Assistants with Azure here:
+# https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints/azure#using-assistants-with-azure
+
 #============#
 # OpenRouter #
 #============#
@@ -247,6 +257,14 @@ MEILI_NO_ANALYTICS=true
 MEILI_HOST=http://0.0.0.0:7700
 MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt
 
+#==================================================#
+#        Speech to Text & Text to Speech           #
+#==================================================#
+
+STT_API_KEY=
+TTS_API_KEY=
+
+
 #===================================================#
 #                   User System                     #
 #===================================================#
@@ -342,6 +360,14 @@ OPENID_REQUIRED_ROLE_PARAMETER_PATH=
 OPENID_BUTTON_LABEL=
 OPENID_IMAGE_URL=
 
+# LDAP
+LDAP_URL=
+LDAP_BIND_DN=
+LDAP_BIND_CREDENTIALS=
+LDAP_USER_SEARCH_BASE=
+LDAP_SEARCH_FILTER=mail={{username}}
+LDAP_CA_CERT_PATH=
+
 #========================#
 #  Email Password Reset  #
 #========================#
```
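These `*_MODELS` values are plain comma-separated lists. A minimal sketch of how such a value can be split into an array — the helper name is illustrative, not from the codebase:

```js
// Turn a comma-separated env list such as GOOGLE_MODELS into an array.
function parseModelList(value) {
  return (value ?? '')
    .split(',')
    .map((model) => model.trim())
    .filter(Boolean);
}

// parseModelList('gemini-1.5-flash-preview-0514, gemini-pro')
// → ['gemini-1.5-flash-preview-0514', 'gemini-pro']
```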
```diff
@@ -1,4 +1,5 @@
 const Anthropic = require('@anthropic-ai/sdk');
+const { HttpsProxyAgent } = require('https-proxy-agent');
 const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
 const {
   getResponseSender,
@@ -123,9 +124,14 @@ class AnthropicClient extends BaseClient {
   getClient() {
     /** @type {Anthropic.default.RequestOptions} */
     const options = {
+      fetch: this.fetch,
       apiKey: this.apiKey,
     };
 
+    if (this.options.proxy) {
+      options.httpAgent = new HttpsProxyAgent(this.options.proxy);
+    }
+
     if (this.options.reverseProxyUrl) {
       options.baseURL = this.options.reverseProxyUrl;
     }
```
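What the new `getClient()` options amount to when constructing the SDK client — a minimal sketch with illustrative values (the proxy and relay URLs are made up):

```js
const Anthropic = require('@anthropic-ai/sdk');
const { HttpsProxyAgent } = require('https-proxy-agent');

const anthropic = new Anthropic({
  apiKey: process.env.ANTHROPIC_API_KEY,
  // set when options.reverseProxyUrl is configured:
  baseURL: 'https://relay.example.com/v1',
  // set when options.proxy is configured:
  httpAgent: new HttpsProxyAgent('http://proxy.example.com:8080'),
});
```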
```diff
@@ -1,4 +1,5 @@
 const crypto = require('crypto');
+const fetch = require('node-fetch');
 const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
 const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
 const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
@@ -17,6 +18,7 @@ class BaseClient {
       month: 'long',
       day: 'numeric',
     });
+    this.fetch = this.fetch.bind(this);
   }
 
   setOptions() {
@@ -54,6 +56,22 @@ class BaseClient {
     });
   }
 
+  /**
+   * Makes an HTTP request and logs the process.
+   *
+   * @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
+   * @param {RequestInit} [init] - Optional init options for the request.
+   * @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
+   */
+  async fetch(_url, init) {
+    let url = _url;
+    if (this.options.directEndpoint) {
+      url = this.options.reverseProxyUrl;
+    }
+    logger.debug(`Making request to ${url}`);
+    return await fetch(url, init);
+  }
+
   getBuildMessagesOptions() {
     throw new Error('Subclasses must implement getBuildMessagesOptions');
   }
@@ -373,6 +391,14 @@ class BaseClient {
     const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
       await this.handleStartMethods(message, opts);
 
+    if (opts.progressCallback) {
+      opts.onProgress = opts.progressCallback.call(null, {
+        ...(opts.progressOptions ?? {}),
+        parentMessageId: userMessage.messageId,
+        messageId: responseMessageId,
+      });
+    }
+
    const { generation = '' } = opts;
 
    // It's not necessary to push to currentMessages
```
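The new `BaseClient#fetch` is handed to every SDK (see the `fetch: this.fetch` options above), so one hook logs and optionally reroutes all outbound requests. Its behavior in isolation — a sketch, not the class itself:

```js
const fetch = require('node-fetch');

// When options.directEndpoint is set, every request is rerouted to the
// configured reverseProxyUrl, regardless of the URL the SDK asked for.
async function loggedFetch(options, _url, init) {
  const url = options.directEndpoint ? options.reverseProxyUrl : _url;
  console.debug(`Making request to ${url}`);
  return await fetch(url, init);
}
```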
```diff
@@ -438,9 +438,17 @@ class ChatGPTClient extends BaseClient {
 
       if (message.eventType === 'text-generation' && message.text) {
         onTokenProgress(message.text);
-      } else if (message.eventType === 'stream-end' && message.response) {
         reply += message.text;
       }
+      /*
+      Cohere API Chinese Unicode character replacement hotfix.
+      Should be un-commented when the following issue is resolved:
+      https://github.com/cohere-ai/cohere-typescript/issues/151
+
+      else if (message.eventType === 'stream-end' && message.response) {
+        reply = message.response.text;
+      }
+      */
     }
 
     return reply;
```
```diff
@@ -27,6 +27,7 @@ const {
   createContextHandlers,
 } = require('./prompts');
 const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+const { updateTokenWebsocket } = require('~/server/services/Files/Audio');
 const { isEnabled, sleep } = require('~/server/utils');
 const { handleOpenAIErrors } = require('./tools/util');
 const spendTokens = require('~/models/spendTokens');
@@ -588,12 +589,13 @@ class OpenAIClient extends BaseClient {
     let streamResult = null;
     this.modelOptions.user = this.user;
     const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
-    const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
+    const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion);
     if (typeof opts.onProgress === 'function' && useOldMethod) {
       const completionResult = await this.getCompletion(
         payload,
         (progressMessage) => {
           if (progressMessage === '[DONE]') {
+            updateTokenWebsocket('[DONE]');
             return;
           }
@@ -756,6 +758,8 @@ class OpenAIClient extends BaseClient {
   * In case of failure, it will return the default title, "New Chat".
   */
  async titleConvo({ text, conversationId, responseText = '' }) {
+    this.conversationId = conversationId;
+
    if (this.options.attachments) {
      delete this.options.attachments;
    }
@@ -825,7 +829,7 @@ ${convo}
 
     const instructionsPayload = [
       {
-        role: 'system',
+        role: this.options.titleMessageRole ?? 'system',
         content: `Please generate ${titleInstruction}
 
 ${convo}
@@ -838,13 +842,17 @@ ${convo}
 
     try {
+      let useChatCompletion = true;
+
+      if (this.options.reverseProxyUrl === CohereConstants.API_URL) {
+        useChatCompletion = false;
+      }
+
       title = (
         await this.sendPayload(instructionsPayload, { modelOptions, useChatCompletion })
       ).replaceAll('"', '');
 
       const completionTokens = this.getTokenCount(title);
 
       this.recordTokenUsage({ promptTokens, completionTokens, context: 'title' });
     } catch (e) {
       logger.error(
@@ -868,6 +876,7 @@ ${convo}
         context: 'title',
         tokenBuffer: 150,
       });
+
       title = await runTitleChain({ llm, text, convo, signal: this.abortController.signal });
     } catch (e) {
       if (e?.message?.toLowerCase()?.includes('abort')) {
@@ -1005,9 +1014,9 @@ ${convo}
       await spendTokens(
         {
           context,
-          user: this.user,
           model: this.modelOptions.model,
           conversationId: this.conversationId,
+          user: this.user ?? this.options.req.user?.id,
           endpointTokenConfig: this.options.endpointTokenConfig,
         },
         { promptTokens, completionTokens },
@@ -1099,7 +1108,12 @@ ${convo}
     }
 
     if (this.azure || this.options.azure) {
-      // Azure does not accept `model` in the body, so we need to remove it.
+      /* Azure Bug, extremely short default `max_tokens` response */
+      if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') {
+        modelOptions.max_tokens = 4000;
+      }
+
+      /* Azure does not accept `model` in the body, so we need to remove it. */
       delete modelOptions.model;
 
       opts.baseURL = this.langchainProxy
@@ -1120,6 +1134,7 @@ ${convo}
     let chatCompletion;
     /** @type {OpenAI} */
     const openai = new OpenAI({
+      fetch: this.fetch,
       apiKey: this.apiKey,
       ...opts,
     });
@@ -1209,6 +1224,7 @@ ${convo}
     });
 
+    const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
 
     for await (const chunk of stream) {
       const token = chunk.choices[0]?.delta?.content || '';
       intermediateReply += token;
```
```diff
@@ -250,6 +250,7 @@ class PluginsClient extends OpenAIClient {
       this.setOptions(opts);
       return super.sendMessage(message, opts);
     }
+
     logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
     const {
       user,
@@ -264,6 +265,14 @@ class PluginsClient extends OpenAIClient {
       onToolEnd,
     } = await this.handleStartMethods(message, opts);
 
+    if (opts.progressCallback) {
+      opts.onProgress = opts.progressCallback.call(null, {
+        ...(opts.progressOptions ?? {}),
+        parentMessageId: userMessage.messageId,
+        messageId: responseMessageId,
+      });
+    }
+
     this.currentMessages.push(userMessage);
 
     let {
```
```diff
@@ -28,7 +28,7 @@ ${convo}`,
 };
 
 const titleInstruction =
-  'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"';
+  'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. Never directly mention the language name or the word "title"';
 
 const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.
 
 You may call them like this:
```
```diff
@@ -144,6 +144,7 @@ describe('OpenAIClient', () => {
 
   const defaultOptions = {
     // debug: true,
+    req: {},
     openaiApiKey: 'new-api-key',
     modelOptions: {
       model,
```
```diff
@@ -80,13 +80,18 @@ class StableDiffusionAPI extends StructuredTool {
     const payload = {
       prompt,
       negative_prompt,
       sampler_index: 'DPM++ 2M Karras',
       cfg_scale: 4.5,
       steps: 22,
       width: 1024,
       height: 1024,
     };
-    const generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
+    let generationResponse;
+    try {
+      generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
+    } catch (error) {
+      logger.error('[StableDiffusion] Error while generating image:', error);
+      return 'Error making API request.';
+    }
     const image = generationResponse.data.images[0];
 
     /** @type {{ height: number, width: number, seed: number, infotexts: string[] }} */
```
api/cache/getLogStores.js (8 changes)

```diff
@@ -7,6 +7,7 @@ const keyvMongo = require('./keyvMongo');
 
 const { BAN_DURATION, USE_REDIS } = process.env ?? {};
 const THIRTY_MINUTES = 1800000;
+const TEN_MINUTES = 600000;
 
 const duration = math(BAN_DURATION, 7200000);
 
@@ -24,6 +25,10 @@ const config = isEnabled(USE_REDIS)
   ? new Keyv({ store: keyvRedis })
   : new Keyv({ namespace: CacheKeys.CONFIG_STORE });
 
+const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes
+  ? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES })
+  : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES });
+
 const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
   ? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
   : new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
@@ -55,6 +60,8 @@ const namespaces = {
   message_limit: createViolationInstance('message_limit'),
   token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
   registrations: createViolationInstance('registrations'),
+  [ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
+  [ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
   [ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
   [ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
     ViolationTypes.ILLEGAL_MODEL_REQUEST,
@@ -64,6 +71,7 @@ const namespaces = {
   [CacheKeys.TOKEN_CONFIG]: tokenConfig,
   [CacheKeys.GEN_TITLE]: genTitle,
   [CacheKeys.MODEL_QUERIES]: modelQueries,
+  [CacheKeys.AUDIO_RUNS]: audioRuns,
 };
 
 /**
```
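The new `AUDIO_RUNS` namespace backs the STT/TTS features with a ten-minute TTL, on Redis or in-memory alike. A hypothetical consumer (the run-id value and helper are made up; `getLogStores` and `CacheKeys.AUDIO_RUNS` come from the diff):

```js
const { CacheKeys } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');

async function markAudioRunActive(runId) {
  const audioRuns = getLogStores(CacheKeys.AUDIO_RUNS);
  // Entries expire after TEN_MINUTES (600,000 ms).
  await audioRuns.set(runId, { status: 'in_progress' });
  return audioRuns.get(runId);
}
```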
```diff
@@ -27,26 +27,25 @@ function getMatchingSensitivePatterns(valueStr) {
 }
 
 /**
- * Redacts sensitive information from a console message.
- *
+ * Redacts sensitive information from a console message and trims it to a specified length if provided.
  * @param {string} str - The console message to be redacted.
- * @returns {string} - The redacted console message.
+ * @param {number} [trimLength] - The optional length at which to trim the redacted message.
+ * @returns {string} - The redacted and optionally trimmed console message.
  */
-function redactMessage(str) {
+function redactMessage(str, trimLength) {
   if (!str) {
     return '';
   }
 
   const patterns = getMatchingSensitivePatterns(str);
 
   if (patterns.length === 0) {
     return str;
   }
 
   patterns.forEach((pattern) => {
     str = str.replace(pattern, '$1[REDACTED]');
   });
 
+  if (trimLength !== undefined && str.length > trimLength) {
+    return `${str.substring(0, trimLength)}...`;
+  }
+
   return str;
 }
```
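An illustrative call of the extended function — the input string and the assumption that a pattern matches `apiKey=<value>` are for demonstration only:

```js
const raw = 'apiKey=sk-abc123 followed by a very long tail of log output...';
const redacted = redactMessage(raw, 40);
// 1) matching patterns are rewritten to '$1[REDACTED]'
// 2) because trimLength (40) was provided and exceeded, the result is
//    cut to 40 characters and suffixed with '...'
```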
```diff
@@ -62,8 +62,24 @@ const deleteAction = async (searchParams, session = null) => {
   return await Action.findOneAndDelete(searchParams, options).lean();
 };
 
-module.exports = {
-  updateAction,
-  getActions,
-  deleteAction,
+/**
+ * Deletes actions by params, within a transaction session if provided.
+ *
+ * @param {Object} searchParams - The search parameters to find the actions to delete.
+ * @param {string} searchParams.action_id - The ID of the action(s) to delete.
+ * @param {string} searchParams.user - The user ID of the action's author.
+ * @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
+ * @returns {Promise<Number>} A promise that resolves to the number of deleted action documents.
+ */
+const deleteActions = async (searchParams, session = null) => {
+  const options = session ? { session } : {};
+  const result = await Action.deleteMany(searchParams, options);
+  return result.deletedCount;
+};
+
+module.exports = {
+  getActions,
+  updateAction,
+  deleteAction,
+  deleteActions,
 };
```
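A sketch of calling `deleteActions` inside a transaction, matching its optional-session signature — the surrounding service function and the `~/models/Action` import path are assumptions, not part of this diff:

```js
const mongoose = require('mongoose');
const { deleteActions } = require('~/models/Action');

async function removeActionsForUser(action_id, user) {
  const session = await mongoose.startSession();
  try {
    let deletedCount = 0;
    await session.withTransaction(async () => {
      deletedCount = await deleteActions({ action_id, user }, session);
    });
    return deletedCount; // number of deleted action documents
  } finally {
    await session.endSession();
  }
}
```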
```diff
@@ -14,7 +14,7 @@ const Assistant = mongoose.model('assistant', assistantSchema);
  * @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
  * @returns {Promise<Object>} The updated or newly created assistant document as a plain object.
  */
-const updateAssistant = async (searchParams, updateData, session = null) => {
+const updateAssistantDoc = async (searchParams, updateData, session = null) => {
   const options = { new: true, upsert: true, session };
   return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
 };
@@ -39,8 +39,21 @@ const getAssistants = async (searchParams) => {
   return await Assistant.find(searchParams).lean();
 };
 
+/**
+ * Deletes an assistant based on the provided ID.
+ *
+ * @param {Object} searchParams - The search parameters to find the assistant to delete.
+ * @param {string} searchParams.assistant_id - The ID of the assistant to delete.
+ * @param {string} searchParams.user - The user ID of the assistant's author.
+ * @returns {Promise<void>} Resolves when the assistant has been successfully deleted.
+ */
+const deleteAssistant = async (searchParams) => {
+  return await Assistant.findOneAndDelete(searchParams);
+};
+
 module.exports = {
-  updateAssistant,
+  updateAssistantDoc,
+  deleteAssistant,
   getAssistants,
   getAssistant,
 };
```
```diff
@@ -21,7 +21,7 @@ module.exports = {
   Conversation,
   saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
     try {
-      const messages = await getMessages({ conversationId });
+      const messages = await getMessages({ conversationId }, '_id');
       const update = { ...convo, messages, user };
       if (newConversationId) {
         update.conversationId = newConversationId;
```
```diff
@@ -129,6 +129,14 @@ module.exports = {
       throw new Error('Failed to save message.');
     }
   },
+  async updateMessageText({ messageId, text }) {
+    try {
+      await Message.updateOne({ messageId }, { text });
+    } catch (err) {
+      logger.error('Error updating message text:', err);
+      throw new Error('Failed to update message text.');
+    }
+  },
   async updateMessage(message) {
     try {
       const { messageId, ...update } = message;
@@ -171,8 +179,18 @@ module.exports = {
     }
   },
 
-  async getMessages(filter) {
+  /**
+   * Retrieves messages from the database.
+   * @param {Record<string, unknown>} filter
+   * @param {string | undefined} [select]
+   * @returns
+   */
+  async getMessages(filter, select) {
     try {
+      if (select) {
+        return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
+      }
+
       return await Message.find(filter).sort({ createdAt: 1 }).lean();
     } catch (err) {
       logger.error('Error getting messages:', err);
```
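The optional projection is what keeps the `saveConvo` change above cheap: it fetches only message `_id`s instead of full documents.

```js
// With the new `select` argument:
const messageRefs = await getMessages({ conversationId }, '_id');
// → [{ _id: ObjectId('...') }, ...] — enough to link messages to the conversation.
// Omitting `select` preserves the old behavior and returns full documents.
```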
```diff
@@ -155,7 +155,7 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) {
         function (results, value, key) {
           return { ...results, [key]: 1 };
         },
-        { _id: 1 },
+        { _id: 1, __v: 1 },
       ),
     ).lean();
```
```diff
@@ -64,6 +64,11 @@ const userSchema = mongoose.Schema(
       unique: true,
       sparse: true,
     },
+    ldapId: {
+      type: String,
+      unique: true,
+      sparse: true,
+    },
     githubId: {
       type: String,
       unique: true,
```
```diff
@@ -40,7 +40,7 @@ const spendTokens = async (txData, tokenUsage) => {
     });
   }
 
-  if (!completionTokens) {
+  if (!completionTokens && isNaN(completionTokens)) {
     logger.debug('[spendTokens] !completionTokens', { prompt, completion });
     return;
   }
```
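The reworked guard changes which falsy values short-circuit: `isNaN(0)` is false, so a zero-token completion now falls through and gets recorded instead of being skipped. Illustrative evaluations:

```js
// !completionTokens && isNaN(completionTokens), per value:
// undefined → true  && true  → early return (skipped)
// NaN       → true  && true  → early return (skipped)
// 0         → true  && false → falls through (a 0-token completion is now recorded)
// 42        → false && ...   → falls through (recorded)
```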
```diff
@@ -40,7 +40,7 @@
     "@keyv/redis": "^2.8.1",
     "@langchain/community": "^0.0.46",
     "@langchain/google-genai": "^0.0.11",
-    "@langchain/google-vertexai": "^0.0.5",
+    "@langchain/google-vertexai": "^0.0.17",
     "agenda": "^5.0.0",
     "axios": "^1.3.4",
     "bcryptjs": "^2.4.3",
@@ -76,7 +76,7 @@
     "nodejs-gpt": "^1.37.4",
     "nodemailer": "^6.9.4",
     "ollama": "^0.5.0",
-    "openai": "4.36.0",
+    "openai": "^4.47.1",
     "openai-chat-tokens": "^0.2.8",
     "openid-client": "^5.4.2",
     "passport": "^0.6.0",
@@ -86,6 +86,7 @@
     "passport-github2": "^0.1.12",
     "passport-google-oauth20": "^2.0.0",
     "passport-jwt": "^4.0.1",
+    "passport-ldapauth": "^3.0.1",
     "passport-local": "^1.0.0",
     "pino": "^8.12.1",
     "sharp": "^0.32.6",
@@ -94,6 +95,7 @@
     "ua-parser-js": "^1.0.36",
     "winston": "^3.11.0",
     "winston-daily-rotate-file": "^4.7.1",
+    "ws": "^8.17.0",
     "zod": "^3.22.4"
   },
   "devDependencies": {
```
```diff
@@ -105,11 +105,12 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
     getReqData,
     onStart,
     abortController,
-    onProgress: progressCallback.call(null, {
+    progressCallback,
+    progressOptions: {
       res,
       text,
-      parentMessageId: overrideParentMessageId || userMessageId,
-    }),
+      // parentMessageId: overrideParentMessageId || userMessageId,
+    },
   };
 
   let response = await client.sendMessage(text, messageOptions);
```
```diff
@@ -112,11 +112,12 @@ const EditController = async (req, res, next, initializeClient) => {
     getReqData,
     onStart,
     abortController,
-    onProgress: progressCallback.call(null, {
+    progressCallback,
+    progressOptions: {
       res,
       text,
-      parentMessageId: overrideParentMessageId || userMessageId,
-    }),
+      // parentMessageId: overrideParentMessageId || userMessageId,
+    },
   });
 
   const conversation = await getConvo(user, conversationId);
```
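Both controllers stop pre-binding `onProgress` and instead hand the client a `progressCallback` plus `progressOptions`; per the BaseClient hunk above, the client merges in both message IDs before currying. A sketch of a compatible callback — the `sendTokenEvent` helper is hypothetical:

```js
// progressCallback receives the merged progressOptions and returns the
// actual token handler — now with messageId available, which the old
// pre-bound onProgress never had.
const progressCallback = ({ res, text, parentMessageId, messageId }) => {
  return (token) => {
    sendTokenEvent(res, { token, text, parentMessageId, messageId });
  };
};
```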
```diff
@@ -16,10 +16,28 @@ async function endpointController(req, res) {
   /** @type {TEndpointsConfig} */
   const mergedConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints };
   if (mergedConfig[EModelEndpoint.assistants] && req.app.locals?.[EModelEndpoint.assistants]) {
-    const { disableBuilder, retrievalModels, capabilities, ..._rest } =
+    const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
       req.app.locals[EModelEndpoint.assistants];
 
     mergedConfig[EModelEndpoint.assistants] = {
       ...mergedConfig[EModelEndpoint.assistants],
+      version,
       retrievalModels,
       disableBuilder,
       capabilities,
     };
   }
 
+  if (
+    mergedConfig[EModelEndpoint.azureAssistants] &&
+    req.app.locals?.[EModelEndpoint.azureAssistants]
+  ) {
+    const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
+      req.app.locals[EModelEndpoint.azureAssistants];
+
+    mergedConfig[EModelEndpoint.azureAssistants] = {
+      ...mergedConfig[EModelEndpoint.azureAssistants],
+      version,
+      retrievalModels,
+      disableBuilder,
+      capabilities,
```
```diff
@@ -1,14 +1,13 @@
 const { v4 } = require('uuid');
-const express = require('express');
 const {
   Constants,
   RunStatus,
   CacheKeys,
   FileSources,
   ContentTypes,
   EModelEndpoint,
   ViolationTypes,
-  ImageVisionTool,
+  checkOpenAIStorage,
   AssistantStreamEvents,
 } = require('librechat-data-provider');
 const {
@@ -21,44 +20,36 @@ const {
 } = require('~/server/services/Threads');
 const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
 const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
-const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistants');
+const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
 const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
 const { createRun, StreamRunManager } = require('~/server/services/Runs');
+const { addTitle } = require('~/server/services/Endpoints/assistants');
 const { getTransactions } = require('~/models/Transaction');
 const checkBalance = require('~/models/checkBalance');
 const { getConvo } = require('~/models/Conversation');
 const getLogStores = require('~/cache/getLogStores');
 const { getModelMaxTokens } = require('~/utils');
+const { getOpenAIClient } = require('./helpers');
 const { logger } = require('~/config');
 
-const router = express.Router();
 const {
   setHeaders,
   handleAbort,
   validateModel,
   handleAbortError,
   // validateEndpoint,
   buildEndpointOption,
 } = require('~/server/middleware');
 
-router.post('/abort', handleAbort());
-
 const ten_minutes = 1000 * 60 * 10;
 
 /**
  * @route POST /
  * @desc Chat with an assistant
  * @access Public
- * @param {express.Request} req - The request object, containing the request data.
- * @param {express.Response} res - The response object, used to send back a response.
+ * @param {object} req - The request object, containing the request data.
+ * @param {object} req.body - The request payload.
+ * @param {Express.Response} res - The response object, used to send back a response.
  * @returns {void}
  */
-router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res) => {
+const chatV1 = async (req, res) => {
   logger.debug('[/assistants/chat/] req.body', req.body);
 
   const {
     text,
     model,
+    endpoint,
     files = [],
     promptPrefix,
     assistant_id,
@@ -69,30 +60,6 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
     parentMessageId: _parentId = Constants.NO_PARENT,
   } = req.body;
 
-  /** @type {Partial<TAssistantEndpoint>} */
-  const assistantsConfig = req.app.locals?.[EModelEndpoint.assistants];
-
-  if (assistantsConfig) {
-    const { supportedIds, excludedIds } = assistantsConfig;
-    const error = { message: 'Assistant not supported' };
-    if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
-      return await handleAbortError(res, req, error, {
-        sender: 'System',
-        conversationId: convoId,
-        messageId: v4(),
-        parentMessageId: _messageId,
-        error,
-      });
-    } else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
-      return await handleAbortError(res, req, error, {
-        sender: 'System',
-        conversationId: convoId,
-        messageId: v4(),
-        parentMessageId: _messageId,
-      });
-    }
-  }
-
   /** @type {OpenAIClient} */
   let openai;
   /** @type {string|undefined} - the current thread id */
@@ -138,7 +105,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       user: req.user.id,
       shouldSaveMessage: false,
       messageId: responseMessageId,
-      endpoint: EModelEndpoint.assistants,
+      endpoint,
     };
 
     if (error.message === 'Run cancelled') {
@@ -149,7 +116,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       logger.debug('[/assistants/chat/] Request aborted on close');
     } else if (/Files.*are invalid/.test(error.message)) {
       const errorMessage = `Files are invalid, or may not have uploaded yet.${
-        req.app.locals?.[EModelEndpoint.azureOpenAI].assistants
+        endpoint === EModelEndpoint.azureAssistants
           ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
           : ''
       }`;
@@ -205,6 +172,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       const runMessages = await checkMessageGaps({
         openai,
         run_id,
+        endpoint,
         thread_id,
         conversationId,
         latestMessageId: responseMessageId,
@@ -311,8 +279,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       });
     };
 
-    /** @type {{ openai: OpenAIClient }} */
-    const { openai: _openai, client } = await initializeClient({
+    const { openai: _openai, client } = await getOpenAIClient({
       req,
       res,
       endpointOption: req.body.endpointOption,
@@ -320,6 +287,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
     });
 
     openai = _openai;
+    await validateAuthor({ req, openai });
 
     if (previousMessages.length) {
       parentMessageId = previousMessages[previousMessages.length - 1].messageId;
@@ -370,10 +338,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
 
     /** @type {MongoFile[]} */
     const attachments = await req.body.endpointOption.attachments;
-    if (
-      attachments &&
-      attachments.every((attachment) => attachment.source === FileSources.openai)
-    ) {
+    if (attachments && attachments.every((attachment) => checkOpenAIStorage(attachment.source))) {
       return;
     }
@@ -431,7 +396,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
 
     if (processedFiles) {
       for (const file of processedFiles) {
-        if (file.source !== FileSources.openai) {
+        if (!checkOpenAIStorage(file.source)) {
           attachedFileIds.delete(file.file_id);
           const index = file_ids.indexOf(file.file_id);
           if (index > -1) {
@@ -467,6 +432,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       assistant_id,
       thread_id,
       model: assistant_id,
+      endpoint,
     };
 
     previousMessages.push(requestMessage);
@@ -476,7 +442,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
 
     conversation = {
       conversationId,
-      endpoint: EModelEndpoint.assistants,
+      endpoint,
       promptPrefix: promptPrefix,
       instructions: instructions,
       assistant_id,
@@ -513,7 +479,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
     let response;
 
     const processRun = async (retry = false) => {
-      if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
+      if (endpoint === EModelEndpoint.azureAssistants) {
        body.model = openai._options.model;
        openai.attachedFileIds = attachedFileIds;
        openai.visionPromise = visionPromise;
@@ -603,6 +569,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
       assistant_id,
       thread_id,
       model: assistant_id,
+      endpoint,
     };
 
     sendMessage(res, {
@@ -655,6 +622,6 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res
   } catch (error) {
     await handleError(error);
   }
-});
+};
 
-module.exports = router;
+module.exports = chatV1;
```
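With `express.Router()` removed, both `chatV1` and the new `chatV2` below are plain controllers; presumably a versioned route module wires them up. A hypothetical sketch — the route paths and require locations are assumptions, only the handler and middleware names come from the diff:

```js
const express = require('express');
const { validateModel, buildEndpointOption, setHeaders } = require('~/server/middleware');
const chatV1 = require('~/server/controllers/assistants/chatV1'); // assumed location
const chatV2 = require('~/server/controllers/assistants/chatV2');

const router = express.Router();
router.post('/v1/chat', validateModel, buildEndpointOption, setHeaders, chatV1);
router.post('/v2/chat', validateModel, buildEndpointOption, setHeaders, chatV2);

module.exports = router;
```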
api/server/controllers/assistants/chatV2.js (new file, 597 lines)

```js
const { v4 } = require('uuid');
const {
  Constants,
  RunStatus,
  CacheKeys,
  ContentTypes,
  ToolCallTypes,
  EModelEndpoint,
  ViolationTypes,
  retrievalMimeTypes,
  AssistantStreamEvents,
} = require('librechat-data-provider');
const {
  initThread,
  recordUsage,
  saveUserMessage,
  checkMessageGaps,
  addThreadMetadata,
  saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { getTransactions } = require('~/models/Transaction');
const checkBalance = require('~/models/checkBalance');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');

const ten_minutes = 1000 * 60 * 10;

/**
 * @route POST /
 * @desc Chat with an assistant
 * @access Public
 * @param {Express.Request} req - The request object, containing the request data.
 * @param {Express.Response} res - The response object, used to send back a response.
 * @returns {void}
 */
const chatV2 = async (req, res) => {
  logger.debug('[/assistants/chat/] req.body', req.body);

  /** @type {{ files: MongoFile[]}} */
  const {
    text,
    model,
    endpoint,
    files = [],
    promptPrefix,
    assistant_id,
    instructions,
    thread_id: _thread_id,
    messageId: _messageId,
    conversationId: convoId,
    parentMessageId: _parentId = Constants.NO_PARENT,
  } = req.body;

  /** @type {OpenAIClient} */
  let openai;
  /** @type {string|undefined} - the current thread id */
  let thread_id = _thread_id;
  /** @type {string|undefined} - the current run id */
  let run_id;
  /** @type {string|undefined} - the parent messageId */
  let parentMessageId = _parentId;
  /** @type {TMessage[]} */
  let previousMessages = [];
  /** @type {import('librechat-data-provider').TConversation | null} */
  let conversation = null;
  /** @type {string[]} */
  let file_ids = [];
  /** @type {Set<string>} */
  let attachedFileIds = new Set();
  /** @type {TMessage | null} */
  let requestMessage = null;

  const userMessageId = v4();
  const responseMessageId = v4();

  /** @type {string} - The conversation UUID - created if undefined */
  const conversationId = convoId ?? v4();

  const cache = getLogStores(CacheKeys.ABORT_KEYS);
  const cacheKey = `${req.user.id}:${conversationId}`;

  /** @type {Run | undefined} - The completed run, undefined if incomplete */
  let completedRun;

  const handleError = async (error) => {
    const defaultErrorMessage =
      'The Assistant run failed to initialize. Try sending a message in a new conversation.';
    const messageData = {
      thread_id,
      assistant_id,
      conversationId,
      parentMessageId,
      sender: 'System',
      user: req.user.id,
      shouldSaveMessage: false,
      messageId: responseMessageId,
      endpoint,
    };

    if (error.message === 'Run cancelled') {
      return res.end();
    } else if (error.message === 'Request closed' && completedRun) {
      return;
    } else if (error.message === 'Request closed') {
      logger.debug('[/assistants/chat/] Request aborted on close');
    } else if (/Files.*are invalid/.test(error.message)) {
      const errorMessage = `Files are invalid, or may not have uploaded yet.${
        endpoint === EModelEndpoint.azureAssistants
          ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
          : ''
      }`;
      return sendResponse(res, messageData, errorMessage);
    } else if (error?.message?.includes('string too long')) {
      return sendResponse(
        res,
        messageData,
        'Message too long. The Assistants API has a limit of 32,768 characters per message. Please shorten it and try again.',
      );
    } else if (error?.message?.includes(ViolationTypes.TOKEN_BALANCE)) {
      return sendResponse(res, messageData, error.message);
    } else {
      logger.error('[/assistants/chat/]', error);
    }

    if (!openai || !thread_id || !run_id) {
      return sendResponse(res, messageData, defaultErrorMessage);
    }

    await sleep(2000);

    try {
      const status = await cache.get(cacheKey);
      if (status === 'cancelled') {
        logger.debug('[/assistants/chat/] Run already cancelled');
        return res.end();
      }
      await cache.delete(cacheKey);
      const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id);
      logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun);
    } catch (error) {
      logger.error('[/assistants/chat/] Error cancelling run', error);
    }

    await sleep(2000);

    let run;
    try {
      run = await openai.beta.threads.runs.retrieve(thread_id, run_id);
      await recordUsage({
        ...run.usage,
        model: run.model,
        user: req.user.id,
        conversationId,
      });
    } catch (error) {
      logger.error('[/assistants/chat/] Error fetching or processing run', error);
    }

    let finalEvent;
    try {
      const runMessages = await checkMessageGaps({
        openai,
        run_id,
        endpoint,
        thread_id,
        conversationId,
        latestMessageId: responseMessageId,
      });

      const errorContentPart = {
        text: {
          value:
            error?.message ?? 'There was an error processing your request. Please try again later.',
        },
        type: ContentTypes.ERROR,
      };

      if (!Array.isArray(runMessages[runMessages.length - 1]?.content)) {
        runMessages[runMessages.length - 1].content = [errorContentPart];
      } else {
        const contentParts = runMessages[runMessages.length - 1].content;
        for (let i = 0; i < contentParts.length; i++) {
          const currentPart = contentParts[i];
          /** @type {CodeToolCall | RetrievalToolCall | FunctionToolCall | undefined} */
          const toolCall = currentPart?.[ContentTypes.TOOL_CALL];
          if (
            toolCall &&
            toolCall?.function &&
            !(toolCall?.function?.output || toolCall?.function?.output?.length)
          ) {
            contentParts[i] = {
              ...currentPart,
              [ContentTypes.TOOL_CALL]: {
                ...toolCall,
                function: {
                  ...toolCall.function,
                  output: 'error processing tool',
                },
              },
            };
          }
        }
        runMessages[runMessages.length - 1].content.push(errorContentPart);
      }

      finalEvent = {
        final: true,
        conversation: await getConvo(req.user.id, conversationId),
        runMessages,
      };
    } catch (error) {
      logger.error('[/assistants/chat/] Error finalizing error process', error);
      return sendResponse(res, messageData, 'The Assistant run failed');
    }

    return sendResponse(res, finalEvent);
  };

  try {
    res.on('close', async () => {
      if (!completedRun) {
        await handleError(new Error('Request closed'));
      }
    });

    if (convoId && !_thread_id) {
      completedRun = true;
      throw new Error('Missing thread_id for existing conversation');
    }

    if (!assistant_id) {
      completedRun = true;
      throw new Error('Missing assistant_id');
    }

    const checkBalanceBeforeRun = async () => {
      if (!isEnabled(process.env.CHECK_BALANCE)) {
        return;
      }
      const transactions =
        (await getTransactions({
          user: req.user.id,
          context: 'message',
          conversationId,
        })) ?? [];

      const totalPreviousTokens = Math.abs(
        transactions.reduce((acc, curr) => acc + curr.rawAmount, 0),
      );

      // TODO: make promptBuffer a config option; buffer for titles, needs buffer for system instructions
      const promptBuffer = parentMessageId === Constants.NO_PARENT && !_thread_id ? 200 : 0;
      // 5 is added for labels
      let promptTokens = (await countTokens(text + (promptPrefix ?? ''))) + 5;
      promptTokens += totalPreviousTokens + promptBuffer;
      // Count tokens up to the current context window
      promptTokens = Math.min(promptTokens, getModelMaxTokens(model));

      await checkBalance({
        req,
        res,
        txData: {
          model,
          user: req.user.id,
          tokenType: 'prompt',
          amount: promptTokens,
        },
      });
    };

    const { openai: _openai, client } = await getOpenAIClient({
      req,
      res,
      endpointOption: req.body.endpointOption,
      initAppClient: true,
    });

    openai = _openai;
    await validateAuthor({ req, openai });

    if (previousMessages.length) {
      parentMessageId = previousMessages[previousMessages.length - 1].messageId;
    }

    let userMessage = {
      role: 'user',
      content: [
        {
          type: ContentTypes.TEXT,
          text,
        },
      ],
      metadata: {
        messageId: userMessageId,
      },
    };

    /** @type {CreateRunBody | undefined} */
    const body = {
      assistant_id,
      model,
    };

    if (promptPrefix) {
      body.additional_instructions = promptPrefix;
    }

    if (instructions) {
      body.instructions = instructions;
    }

    const getRequestFileIds = async () => {
      let thread_file_ids = [];
      if (convoId) {
        const convo = await getConvo(req.user.id, convoId);
        if (convo && convo.file_ids) {
          thread_file_ids = convo.file_ids;
        }
      }

      if (files.length || thread_file_ids.length) {
        attachedFileIds = new Set([...file_ids, ...thread_file_ids]);

        let attachmentIndex = 0;
        for (const file of files) {
          file_ids.push(file.file_id);
          if (file.type.startsWith('image')) {
            userMessage.content.push({
              type: ContentTypes.IMAGE_FILE,
              [ContentTypes.IMAGE_FILE]: { file_id: file.file_id },
            });
          }

          if (!userMessage.attachments) {
            userMessage.attachments = [];
          }

          userMessage.attachments.push({
            file_id: file.file_id,
            tools: [{ type: ToolCallTypes.CODE_INTERPRETER }],
          });

          if (file.type.startsWith('image')) {
            continue;
          }

          const mimeType = file.type;
          const isSupportedByRetrieval = retrievalMimeTypes.some((regex) => regex.test(mimeType));
          if (isSupportedByRetrieval) {
            userMessage.attachments[attachmentIndex].tools.push({
              type: ToolCallTypes.FILE_SEARCH,
            });
          }

          attachmentIndex++;
        }
      }
    };

    const initializeThread = async () => {
      await getRequestFileIds();

      // TODO: may allow multiple messages to be created beforehand in a future update
      const initThreadBody = {
        messages: [userMessage],
        metadata: {
          user: req.user.id,
          conversationId,
        },
      };

      const result = await initThread({ openai, body: initThreadBody, thread_id });
      thread_id = result.thread_id;

      createOnTextProgress({
        openai,
        conversationId,
        userMessageId,
        messageId: responseMessageId,
        thread_id,
      });

      requestMessage = {
        user: req.user.id,
        text,
        messageId: userMessageId,
        parentMessageId,
        // TODO: make sure client sends correct format for `files`, use zod
        files,
        file_ids,
        conversationId,
        isCreatedByUser: true,
        assistant_id,
        thread_id,
        model: assistant_id,
        endpoint,
      };

      previousMessages.push(requestMessage);

      /* asynchronous */
      saveUserMessage({ ...requestMessage, model });

      conversation = {
        conversationId,
        endpoint,
        promptPrefix: promptPrefix,
        instructions: instructions,
        assistant_id,
        // model,
      };

      if (file_ids.length) {
        conversation.file_ids = file_ids;
      }
    };

    const promises = [initializeThread(), checkBalanceBeforeRun()];
    await Promise.all(promises);

    const sendInitialResponse = () => {
      sendMessage(res, {
        sync: true,
        conversationId,
        // messages: previousMessages,
        requestMessage,
        responseMessage: {
          user: req.user.id,
          messageId: openai.responseMessage.messageId,
          parentMessageId: userMessageId,
          conversationId,
          assistant_id,
          thread_id,
          model: assistant_id,
        },
      });
    };

    /** @type {RunResponse | typeof StreamRunManager | undefined} */
    let response;

    const processRun = async (retry = false) => {
      if (endpoint === EModelEndpoint.azureAssistants) {
        body.model = openai._options.model;
        openai.attachedFileIds = attachedFileIds;
        if (retry) {
          response = await runAssistant({
            openai,
            thread_id,
            run_id,
            in_progress: openai.in_progress,
          });
          return;
        }

        /* NOTE:
         * By default, a Run will use the model and tools configuration specified in Assistant object,
         * but you can override most of these when creating the Run for added flexibility:
         */
        const run = await createRun({
          openai,
          thread_id,
          body,
        });

        run_id = run.id;
        await cache.set(cacheKey, `${thread_id}:${run_id}`, ten_minutes);
        sendInitialResponse();

        // todo: retry logic
        response = await runAssistant({ openai, thread_id, run_id });
        return;
      }

      /** @type {{[AssistantStreamEvents.ThreadRunCreated]: (event: ThreadRunCreated) => Promise<void>}} */
      const handlers = {
        [AssistantStreamEvents.ThreadRunCreated]: async (event) => {
          await cache.set(cacheKey, `${thread_id}:${event.data.id}`, ten_minutes);
          run_id = event.data.id;
          sendInitialResponse();
        },
      };

      const streamRunManager = new StreamRunManager({
        req,
        res,
        openai,
        handlers,
        thread_id,
        attachedFileIds,
        parentMessageId: userMessageId,
        responseMessage: openai.responseMessage,
        // streamOptions: {

        // },
      });

      await streamRunManager.runAssistant({
        thread_id,
        body,
      });

      response = streamRunManager;
      response.text = streamRunManager.intermediateText;
    };

    await processRun();
    logger.debug('[/assistants/chat/] response', {
      run: response.run,
      steps: response.steps,
    });

    if (response.run.status === RunStatus.CANCELLED) {
      logger.debug('[/assistants/chat/] Run cancelled, handled by `abortRun`');
      return res.end();
    }

    if (response.run.status === RunStatus.IN_PROGRESS) {
      processRun(true);
    }

    completedRun = response.run;

    /** @type {ResponseMessage} */
    const responseMessage = {
      ...(response.responseMessage ?? response.finalMessage),
      text: response.text,
      parentMessageId: userMessageId,
      conversationId,
      user: req.user.id,
      assistant_id,
      thread_id,
      model: assistant_id,
      endpoint,
    };

    sendMessage(res, {
      final: true,
      conversation,
      requestMessage: {
        parentMessageId,
        thread_id,
      },
    });
    res.end();

    await saveAssistantMessage({ ...responseMessage, model });

    if (parentMessageId === Constants.NO_PARENT && !_thread_id) {
      addTitle(req, {
        text,
        responseText: response.text,
        conversationId,
        client,
      });
    }

    await addThreadMetadata({
      openai,
      thread_id,
      messageId: responseMessage.messageId,
      messages: response.messages,
    });

    if (!response.run.usage) {
      await sleep(3000);
      completedRun = await openai.beta.threads.runs.retrieve(thread_id, response.run.id);
      if (completedRun.usage) {
        await recordUsage({
          ...completedRun.usage,
          user: req.user.id,
          model: completedRun.model ?? model,
          conversationId,
        });
      }
    } else {
      await recordUsage({
        ...response.run.usage,
        user: req.user.id,
        model: response.run.model ?? model,
        conversationId,
      });
    }
  } catch (error) {
    await handleError(error);
  }
};

module.exports = chatV2;
```
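chatV2 keys an abort entry as `user:conversationId` → `thread_id:run_id` with a ten-minute TTL; an abort handler can recover both IDs from it. A sketch of that flow — the handler itself is not shown in this diff, so its shape is an assumption:

```js
// Hypothetical abort flow built on the cache entry chatV2 writes:
async function abortRun(req, conversationId, openai) {
  const cache = getLogStores(CacheKeys.ABORT_KEYS);
  const cacheKey = `${req.user.id}:${conversationId}`;
  const value = await cache.get(cacheKey); // e.g. 'thread_abc:run_xyz'
  if (!value) {
    return;
  }
  const [thread_id, run_id] = value.split(':');
  await cache.set(cacheKey, 'cancelled'); // so handleError sees 'Run already cancelled'
  return openai.beta.threads.runs.cancel(thread_id, run_id);
}
```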
269
api/server/controllers/assistants/helpers.js
Normal file
269
api/server/controllers/assistants/helpers.js
Normal file
@@ -0,0 +1,269 @@
|
||||
const {
|
||||
EModelEndpoint,
|
||||
CacheKeys,
|
||||
defaultAssistantsVersion,
|
||||
defaultOrderQuery,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
initializeClient: initAzureClient,
|
||||
} = require('~/server/services/Endpoints/azureAssistants');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* @param {Express.Request} req
|
||||
* @param {string} [endpoint]
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
const getCurrentVersion = async (req, endpoint) => {
|
||||
const index = req.baseUrl.lastIndexOf('/v');
|
||||
let version = index !== -1 ? req.baseUrl.substring(index + 1, index + 3) : null;
|
||||
if (!version && req.body.version) {
|
||||
version = `v${req.body.version}`;
|
||||
}
|
||||
if (!version && endpoint) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
|
||||
version = `v${
|
||||
cachedEndpointsConfig?.[endpoint]?.version ?? defaultAssistantsVersion[endpoint]
|
||||
}`;
|
||||
}
|
||||
if (!version?.startsWith('v') && version.length !== 2) {
|
||||
throw new Error(`[${req.baseUrl}] Invalid version: ${version}`);
|
||||
}
|
||||
return version;
|
||||
};
|
||||
|
||||
/**
|
||||
* Asynchronously lists assistants based on provided query parameters.
|
||||
*
|
||||
* Initializes the client with the current request and response objects and lists assistants
|
||||
* according to the query parameters. This function abstracts the logic for non-Azure paths.
|
||||
*
|
||||
* @deprecated
|
||||
* @async
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {object} params.req - The request object, used for initializing the client.
|
||||
* @param {object} params.res - The response object, used for initializing the client.
|
||||
* @param {string} params.version - The API version to use.
|
||||
* @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||
*/
|
||||
const _listAssistants = async ({ req, res, version, query }) => {
|
||||
const { openai } = await getOpenAIClient({ req, res, version });
|
||||
return openai.beta.assistants.list(query);
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches all assistants based on provided query params, until `has_more` is `false`.
|
||||
*
|
||||
* @async
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {object} params.req - The request object, used for initializing the client.
|
||||
* @param {object} params.res - The response object, used for initializing the client.
|
||||
* @param {string} params.version - The API version to use.
|
||||
* @param {Omit<AssistantListParams, 'endpoint'>} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||
*/
|
||||
const listAllAssistants = async ({ req, res, version, query }) => {
|
||||
/** @type {{ openai: OpenAIClient }} */
|
||||
const { openai } = await getOpenAIClient({ req, res, version });
|
||||
const allAssistants = [];
|
||||
|
||||
let first_id;
|
||||
let last_id;
|
||||
let afterToken = query.after;
|
||||
let hasMore = true;
|
||||
|
||||
while (hasMore) {
|
||||
const response = await openai.beta.assistants.list({
|
||||
...query,
|
||||
after: afterToken,
|
||||
});
|
||||
|
||||
const { body } = response;
|
||||
|
||||
allAssistants.push(...body.data);
|
||||
hasMore = body.has_more;
|
||||
|
||||
if (!first_id) {
|
||||
first_id = body.first_id;
|
||||
}
|
||||
|
||||
if (hasMore) {
|
||||
afterToken = body.last_id;
|
||||
} else {
|
||||
last_id = body.last_id;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
data: allAssistants,
|
||||
body: {
|
||||
data: allAssistants,
|
||||
has_more: false,
|
||||
first_id,
|
||||
last_id,
|
||||
},
|
||||
};
|
||||
};

/**
 * Asynchronously lists assistants for Azure-configured groups.
 *
 * Iterates through the Azure-configured assistant groups, initializes the client with the current request and response objects,
 * lists assistants based on the provided query parameters, and merges their data with the model information into a single array.
 *
 * @async
 * @param {object} params - The parameters object.
 * @param {object} params.req - The request object, used for initializing the client and manipulating the request body.
 * @param {object} params.res - The response object, used for initializing the client.
 * @param {string} params.version - The API version to use.
 * @param {TAzureConfig} params.azureConfig - The Azure configuration object containing assistantGroups and groupMap.
 * @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
 * @returns {Promise<AssistantListResponse>} A promise that resolves to a list response of assistant data merged with their respective model information.
 */
const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, query }) => {
  /** @type {Array<Array<[string, TAzureModelConfig]>>} */
  const groupModelTuples = [];
  const promises = [];
  /** @type {Array<TAzureGroup>} */
  const groups = [];

  const { groupMap, assistantGroups } = azureConfig;

  for (const groupName of assistantGroups) {
    const group = groupMap[groupName];
    groups.push(group);

    const currentModelTuples = Object.entries(group?.models);
    groupModelTuples.push(currentModelTuples);

    /* The specified model is only necessary to
      fetch assistants for the shared instance */
    req.body.model = currentModelTuples[0][0];
    promises.push(listAllAssistants({ req, res, version, query }));
  }

  const resolvedQueries = await Promise.all(promises);
  const data = resolvedQueries.flatMap((res, i) =>
    res.data.map((assistant) => {
      const deploymentName = assistant.model;
      const currentGroup = groups[i];
      const currentModelTuples = groupModelTuples[i];
      const firstModel = currentModelTuples[0][0];

      if (currentGroup.deploymentName === deploymentName) {
        return { ...assistant, model: firstModel };
      }

      for (const [model, modelConfig] of currentModelTuples) {
        if (modelConfig.deploymentName === deploymentName) {
          return { ...assistant, model };
        }
      }

      return { ...assistant, model: firstModel };
    }),
  );

  return {
    first_id: data[0]?.id,
    last_id: data[data.length - 1]?.id,
    object: 'list',
    has_more: false,
    data,
  };
};
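
For orientation, a hypothetical sketch of the shape this function expects from azureConfig (the real schema is librechat-data-provider's TAzureConfig): each group maps logical model names to Azure deployment names, which is why the code above translates assistant.model (a deployment name) back into a model key.

// Illustrative only; not a real configuration.
const azureConfig = {
  assistantGroups: ['us-east'],
  groupMap: {
    'us-east': {
      deploymentName: 'shared-deployment', // group-level fallback
      models: {
        'gpt-4-turbo': { deploymentName: 'gpt-4-turbo-prod' },
        'gpt-35-turbo': { deploymentName: 'gpt-35-prod' },
      },
    },
  },
};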

async function getOpenAIClient({ req, res, endpointOption, initAppClient, overrideEndpoint }) {
  let endpoint = overrideEndpoint ?? req.body.endpoint ?? req.query.endpoint;
  if (!endpoint) {
    throw new Error(`[${req.baseUrl}] Endpoint is required`);
  }
  const version = await getCurrentVersion(req, endpoint);

  let result;
  if (endpoint === EModelEndpoint.assistants) {
    result = await initializeClient({ req, res, version, endpointOption, initAppClient });
  } else if (endpoint === EModelEndpoint.azureAssistants) {
    result = await initAzureClient({ req, res, version, endpointOption, initAppClient });
  }

  return result;
}
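
A hedged dispatch example: the same helper serves both the OpenAI and Azure assistants backends, selected purely by the endpoint value, so a caller can force a backend via overrideEndpoint without mutating the request body.

// Hypothetical: list one page of assistants against the Azure backend explicitly.
const { openai } = await getOpenAIClient({
  req,
  res,
  overrideEndpoint: EModelEndpoint.azureAssistants,
});
const page = await openai.beta.assistants.list({ limit: 20 });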

/**
 * Returns a list of assistants.
 * @param {object} params
 * @param {object} params.req - Express Request
 * @param {AssistantListParams} [params.req.query] - The assistant list parameters for pagination and sorting.
 * @param {object} params.res - Express Response
 * @param {string} [params.overrideEndpoint] - The endpoint to override the request endpoint.
 * @returns {Promise<AssistantListResponse>} 200 - success response - application/json
 */
const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
  const {
    limit = 100,
    order = 'desc',
    after,
    before,
    endpoint,
  } = req.query ?? {
    endpoint: overrideEndpoint,
    ...defaultOrderQuery,
  };

  const version = await getCurrentVersion(req, endpoint);
  const query = { limit, order, after, before };

  /** @type {AssistantListResponse} */
  let body;

  if (endpoint === EModelEndpoint.assistants) {
    ({ body } = await listAllAssistants({ req, res, version, query }));
  } else if (endpoint === EModelEndpoint.azureAssistants) {
    const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
    body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
  }

  if (req.user.role === 'ADMIN' || !req.app.locals[endpoint]) {
    return body;
  }

  body.data = filterAssistants({
    userId: req.user.id,
    assistants: body.data,
    assistantsConfig: req.app.locals[endpoint],
  });
  return body;
};

/**
 * Filter assistants based on configuration.
 *
 * @param {object} params - The parameters object.
 * @param {string} params.userId - The user ID, used to filter private assistants.
 * @param {Assistant[]} params.assistants - The list of assistants to filter.
 * @param {Partial<TAssistantEndpoint>} params.assistantsConfig - The assistant endpoint configuration.
 * @returns {Assistant[]} The filtered list of assistants.
 */
function filterAssistants({ assistants, userId, assistantsConfig }) {
  const { supportedIds, excludedIds, privateAssistants } = assistantsConfig;
  if (privateAssistants) {
    return assistants.filter((assistant) => userId === assistant.metadata?.author);
  } else if (supportedIds?.length) {
    return assistants.filter((assistant) => supportedIds.includes(assistant.id));
  } else if (excludedIds?.length) {
    return assistants.filter((assistant) => !excludedIds.includes(assistant.id));
  }
  return assistants;
}
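
Note the precedence encoded above: privateAssistants wins over the supportedIds allow-list, which wins over the excludedIds deny-list. A quick illustrative check (hypothetical IDs):

// Illustrative only.
const assistants = [
  { id: 'asst_a', metadata: { author: 'user-1' } },
  { id: 'asst_b', metadata: { author: 'user-2' } },
];
filterAssistants({
  assistants,
  userId: 'user-1',
  assistantsConfig: { privateAssistants: true, supportedIds: ['asst_a', 'asst_b'] },
});
// => only 'asst_a': authorship is checked first, even though both IDs are "supported".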

module.exports = {
  getOpenAIClient,
  fetchAssistants,
  getCurrentVersion,
};

@@ -1,34 +1,12 @@
-const multer = require('multer');
-const express = require('express');
-const { FileContext, EModelEndpoint } = require('librechat-data-provider');
-const {
-  initializeClient,
-  listAssistantsForAzure,
-  listAssistants,
-} = require('~/server/services/Endpoints/assistants');
+const { FileContext } = require('librechat-data-provider');
+const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { deleteAssistantActions } = require('~/server/services/ActionService');
+const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
const { uploadImageBuffer } = require('~/server/services/Files/process');
-const { updateAssistant, getAssistants } = require('~/models/Assistant');
+const { getOpenAIClient, fetchAssistants } = require('./helpers');
const { deleteFileByFilter } = require('~/models/File');
const { logger } = require('~/config');
-const actions = require('./actions');
-const tools = require('./tools');
-
-const upload = multer();
-const router = express.Router();
-
-/**
- * Assistant actions route.
- * @route GET|POST /assistants/actions
- */
-router.use('/actions', actions);
-
-/**
- * Create an assistant.
- * @route GET /assistants/tools
- * @returns {TPlugin[]} 200 - application/json
- */
-router.use('/tools', tools);
-
/**
 * Create an assistant.
@@ -36,12 +14,11 @@ router.use('/tools', tools);
 * @param {AssistantCreateParams} req.body - The assistant creation parameters.
 * @returns {Assistant} 201 - success response - application/json
 */
-router.post('/', async (req, res) => {
+const createAssistant = async (req, res) => {
  try {
-    /** @type {{ openai: OpenAI }} */
-    const { openai } = await initializeClient({ req, res });
+    const { openai } = await getOpenAIClient({ req, res });

-    const { tools = [], ...assistantData } = req.body;
+    const { tools = [], endpoint, ...assistantData } = req.body;
    assistantData.tools = tools
      .map((tool) => {
        if (typeof tool !== 'string') {
@@ -52,18 +29,30 @@ router.post('/', async (req, res) => {
      })
      .filter((tool) => tool);

+    let azureModelIdentifier = null;
+    if (openai.locals?.azureOptions) {
+      azureModelIdentifier = assistantData.model;
+      assistantData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName;
+    }
+
    assistantData.metadata = {
      author: req.user.id,
+      endpoint,
    };

    const assistant = await openai.beta.assistants.create(assistantData);
+    const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
+    if (azureModelIdentifier) {
+      assistant.model = azureModelIdentifier;
+    }
+    await promise;
    logger.debug('/assistants/', assistant);
    res.status(201).json(assistant);
  } catch (error) {
    logger.error('[/assistants] Error creating assistant', error);
    res.status(500).json({ error: error.message });
  }
-});
+};

/**
 * Retrieves an assistant.
@@ -71,11 +60,10 @@ router.post('/', async (req, res) => {
 * @param {string} req.params.id - Assistant identifier.
 * @returns {Assistant} 200 - success response - application/json
 */
-router.get('/:id', async (req, res) => {
+const retrieveAssistant = async (req, res) => {
  try {
-    /** @type {{ openai: OpenAI }} */
-    const { openai } = await initializeClient({ req, res });
-
+    /* NOTE: not actually being used right now */
+    const { openai } = await getOpenAIClient({ req, res });
    const assistant_id = req.params.id;
    const assistant = await openai.beta.assistants.retrieve(assistant_id);
    res.json(assistant);
@@ -83,22 +71,24 @@ router.get('/:id', async (req, res) => {
    logger.error('[/assistants/:id] Error retrieving assistant', error);
    res.status(500).json({ error: error.message });
  }
-});
+};

/**
 * Modifies an assistant.
 * @route PATCH /assistants/:id
 * @param {object} req - Express Request
 * @param {object} req.params - Request params
 * @param {string} req.params.id - Assistant identifier.
 * @param {AssistantUpdateParams} req.body - The assistant update parameters.
 * @returns {Assistant} 200 - success response - application/json
 */
-router.patch('/:id', async (req, res) => {
+const patchAssistant = async (req, res) => {
  try {
-    /** @type {{ openai: OpenAI }} */
-    const { openai } = await initializeClient({ req, res });
+    const { openai } = await getOpenAIClient({ req, res });
+    await validateAuthor({ req, openai });

    const assistant_id = req.params.id;
-    const updateData = req.body;
+    const { endpoint: _e, ...updateData } = req.body;
    updateData.tools = (updateData.tools ?? [])
      .map((tool) => {
        if (typeof tool !== 'string') {
@@ -119,90 +109,76 @@ router.patch('/:id', async (req, res) => {
    logger.error('[/assistants/:id] Error updating assistant', error);
    res.status(500).json({ error: error.message });
  }
-});
+};

/**
 * Deletes an assistant.
 * @route DELETE /assistants/:id
 * @param {object} req - Express Request
 * @param {object} req.params - Request params
 * @param {string} req.params.id - Assistant identifier.
 * @returns {Assistant} 200 - success response - application/json
 */
-router.delete('/:id', async (req, res) => {
+const deleteAssistant = async (req, res) => {
  try {
-    /** @type {{ openai: OpenAI }} */
-    const { openai } = await initializeClient({ req, res });
+    const { openai } = await getOpenAIClient({ req, res });
+    await validateAuthor({ req, openai });

    const assistant_id = req.params.id;
    const deletionStatus = await openai.beta.assistants.del(assistant_id);
    if (deletionStatus?.deleted) {
      await deleteAssistantActions({ req, assistant_id });
    }
    res.json(deletionStatus);
  } catch (error) {
    logger.error('[/assistants/:id] Error deleting assistant', error);
    res.status(500).json({ error: 'Error deleting assistant' });
  }
-});
+};

/**
 * Returns a list of assistants.
 * @route GET /assistants
 * @param {object} req - Express Request
 * @param {AssistantListParams} req.query - The assistant list parameters for pagination and sorting.
 * @returns {AssistantListResponse} 200 - success response - application/json
 */
-router.get('/', async (req, res) => {
+const listAssistants = async (req, res) => {
  try {
-    const { limit = 100, order = 'desc', after, before } = req.query;
-    const query = { limit, order, after, before };
-
-    const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
-    /** @type {AssistantListResponse} */
-    let body;
-
-    if (azureConfig?.assistants) {
-      body = await listAssistantsForAzure({ req, res, azureConfig, query });
-    } else {
-      ({ body } = await listAssistants({ req, res, query }));
-    }
-
-    if (req.app.locals?.[EModelEndpoint.assistants]) {
-      /** @type {Partial<TAssistantEndpoint>} */
-      const assistantsConfig = req.app.locals[EModelEndpoint.assistants];
-      const { supportedIds, excludedIds } = assistantsConfig;
-      if (supportedIds?.length) {
-        body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id));
-      } else if (excludedIds?.length) {
-        body.data = body.data.filter((assistant) => !excludedIds.includes(assistant.id));
-      }
-    }
-
+    const body = await fetchAssistants({ req, res });
    res.json(body);
  } catch (error) {
    logger.error('[/assistants] Error listing assistants', error);
    res.status(500).json({ message: 'Error listing assistants' });
  }
-});
+};

/**
 * Returns a list of the user's assistant documents (metadata saved to database).
 * @route GET /assistants/documents
 * @returns {AssistantDocument[]} 200 - success response - application/json
 */
-router.get('/documents', async (req, res) => {
+const getAssistantDocuments = async (req, res) => {
  try {
    res.json(await getAssistants({ user: req.user.id }));
  } catch (error) {
    logger.error('[/assistants/documents] Error listing assistant documents', error);
    res.status(500).json({ error: error.message });
  }
-});
+};

/**
 * Uploads and updates an avatar for a specific assistant.
 * @route POST /avatar/:assistant_id
 * @param {object} req - Express Request
 * @param {object} req.params - Request params
 * @param {string} req.params.assistant_id - The ID of the assistant.
 * @param {Express.Multer.File} req.file - The avatar image file.
 * @param {object} req.body - Request body
 * @param {string} [req.body.metadata] - Optional metadata for the assistant's avatar.
 * @returns {Object} 200 - success response - application/json
 */
-router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) => {
+const uploadAssistantAvatar = async (req, res) => {
  try {
    const { assistant_id } = req.params;
    if (!assistant_id) {
@@ -210,8 +186,8 @@ router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) =>
    }

    let { metadata: _metadata = '{}' } = req.body;
-    /** @type {{ openai: OpenAI }} */
-    const { openai } = await initializeClient({ req, res });
+    const { openai } = await getOpenAIClient({ req, res });
+    await validateAuthor({ req, openai });

    const image = await uploadImageBuffer({
      req,
@@ -246,7 +222,7 @@ router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) =>

    const promises = [];
    promises.push(
-      updateAssistant(
+      updateAssistantDoc(
        { assistant_id },
        {
          avatar: {
@@ -266,6 +242,14 @@ router.post('/avatar/:assistant_id', upload.single('file'), async (req, res) =>
    logger.error(message, error);
    res.status(500).json({ message });
  }
-});
+};

-module.exports = router;
+module.exports = {
+  createAssistant,
+  retrieveAssistant,
+  patchAssistant,
+  deleteAssistant,
+  listAssistants,
+  getAssistantDocuments,
+  uploadAssistantAvatar,
+};

213  api/server/controllers/assistants/v2.js  (new file)
@@ -0,0 +1,213 @@
const { ToolCallTypes } = require('librechat-data-provider');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { validateAndUpdateTool } = require('~/server/services/ActionService');
const { updateAssistantDoc } = require('~/models/Assistant');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');

/**
 * Create an assistant.
 * @route POST /assistants
 * @param {AssistantCreateParams} req.body - The assistant creation parameters.
 * @returns {Assistant} 201 - success response - application/json
 */
const createAssistant = async (req, res) => {
  try {
    /** @type {{ openai: OpenAIClient }} */
    const { openai } = await getOpenAIClient({ req, res });

    const { tools = [], endpoint, ...assistantData } = req.body;
    assistantData.tools = tools
      .map((tool) => {
        if (typeof tool !== 'string') {
          return tool;
        }

        return req.app.locals.availableTools[tool];
      })
      .filter((tool) => tool);

    let azureModelIdentifier = null;
    if (openai.locals?.azureOptions) {
      azureModelIdentifier = assistantData.model;
      assistantData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName;
    }

    assistantData.metadata = {
      author: req.user.id,
      endpoint,
    };

    const assistant = await openai.beta.assistants.create(assistantData);
    const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
    if (azureModelIdentifier) {
      assistant.model = azureModelIdentifier;
    }
    await promise;
    logger.debug('/assistants/', assistant);
    res.status(201).json(assistant);
  } catch (error) {
    logger.error('[/assistants] Error creating assistant', error);
    res.status(500).json({ error: error.message });
  }
};
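
One subtlety worth tracing: on Azure, the Assistants API expects a deployment name in the model field, so the code above records the user-facing model, creates the assistant under the deployment name, then restores the original label on the returned object. A hypothetical trace (names illustrative):

// req.body.model = 'gpt-4-turbo'; deployment name = 'gpt-4-prod'
// -> assistant created upstream with model: 'gpt-4-prod'
// -> response sent back to the client with model: 'gpt-4-turbo'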

/**
 * Modifies an assistant.
 * @param {object} params
 * @param {Express.Request} params.req
 * @param {OpenAIClient} params.openai
 * @param {string} params.assistant_id
 * @param {AssistantUpdateParams} params.updateData
 * @returns {Promise<Assistant>} The updated assistant.
 */
const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
  await validateAuthor({ req, openai });
  const tools = [];

  let hasFileSearch = false;
  for (const tool of updateData.tools ?? []) {
    const actualTool = typeof tool === 'string' ? req.app.locals.availableTools[tool] : tool;

    if (!actualTool) {
      continue;
    }

    if (actualTool.type === ToolCallTypes.FILE_SEARCH) {
      hasFileSearch = true;
    }

    if (!actualTool.function) {
      tools.push(actualTool);
      continue;
    }

    const updatedTool = await validateAndUpdateTool({ req, tool: actualTool, assistant_id });
    if (updatedTool) {
      tools.push(updatedTool);
    }
  }

  if (hasFileSearch && !updateData.tool_resources) {
    const assistant = await openai.beta.assistants.retrieve(assistant_id);
    updateData.tool_resources = assistant.tool_resources ?? null;
  }

  if (hasFileSearch && !updateData.tool_resources?.file_search) {
    updateData.tool_resources = {
      ...(updateData.tool_resources ?? {}),
      file_search: {
        vector_store_ids: [],
      },
    };
  }

  updateData.tools = tools;

  if (openai.locals?.azureOptions && updateData.model) {
    updateData.model = openai.locals.azureOptions.azureOpenAIApiDeploymentName;
  }

  return await openai.beta.assistants.update(assistant_id, updateData);
};
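
A hedged sketch of the file_search fallback above (hypothetical IDs): if a client sends a file_search tool without tool_resources, the assistant's current resources are re-fetched so the update does not silently detach existing vector stores; when none exist, an empty vector_store_ids array is attached instead.

// Illustrative only.
await updateAssistant({
  req,
  openai,
  assistant_id: 'asst_123',
  updateData: { tools: [{ type: 'file_search' }] },
});
// updateData.tool_resources resolves to the assistant's existing resources,
// or { file_search: { vector_store_ids: [] } } if it had none.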

/**
 * Modifies an assistant with the resource file id.
 * @param {object} params
 * @param {Express.Request} params.req
 * @param {OpenAIClient} params.openai
 * @param {string} params.assistant_id
 * @param {string} params.tool_resource
 * @param {string} params.file_id
 * @returns {Promise<Assistant>} The updated assistant.
 */
const addResourceFileId = async ({ req, openai, assistant_id, tool_resource, file_id }) => {
  const assistant = await openai.beta.assistants.retrieve(assistant_id);
  const { tool_resources = {} } = assistant;
  if (tool_resources[tool_resource]) {
    tool_resources[tool_resource].file_ids.push(file_id);
  } else {
    tool_resources[tool_resource] = { file_ids: [file_id] };
  }

  delete assistant.id;
  return await updateAssistant({
    req,
    openai,
    assistant_id,
    updateData: { tools: assistant.tools, tool_resources },
  });
};

/**
 * Deletes a file ID from an assistant's resource.
 * @param {object} params
 * @param {Express.Request} params.req
 * @param {OpenAIClient} params.openai
 * @param {string} params.assistant_id
 * @param {string} [params.tool_resource]
 * @param {string} params.file_id
 * @returns {Promise<Assistant>} The updated assistant.
 */
const deleteResourceFileId = async ({ req, openai, assistant_id, tool_resource, file_id }) => {
  const assistant = await openai.beta.assistants.retrieve(assistant_id);
  const { tool_resources = {} } = assistant;

  if (tool_resource && tool_resources[tool_resource]) {
    const resource = tool_resources[tool_resource];
    const index = resource.file_ids.indexOf(file_id);
    if (index !== -1) {
      resource.file_ids.splice(index, 1);
    }
  } else {
    for (const resourceKey in tool_resources) {
      const resource = tool_resources[resourceKey];
      const index = resource.file_ids.indexOf(file_id);
      if (index !== -1) {
        resource.file_ids.splice(index, 1);
        break;
      }
    }
  }

  delete assistant.id;
  return await updateAssistant({
    req,
    openai,
    assistant_id,
    updateData: { tools: assistant.tools, tool_resources },
  });
};
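
Paired usage sketch (hypothetical IDs; note these helpers manage file_ids, so they apply to resources like code_interpreter rather than file_search's vector stores): the two are symmetric, and deleteResourceFileId falls back to scanning every resource when no key is given.

// Illustrative only.
await addResourceFileId({
  req, openai, assistant_id: 'asst_123', tool_resource: 'code_interpreter', file_id: 'file_abc',
});
await deleteResourceFileId({ req, openai, assistant_id: 'asst_123', file_id: 'file_abc' });
// no tool_resource passed: the second call scans all resources for the file ID.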

/**
 * Modifies an assistant.
 * @route PATCH /assistants/:id
 * @param {object} req - Express Request
 * @param {object} req.params - Request params
 * @param {string} req.params.id - Assistant identifier.
 * @param {AssistantUpdateParams} req.body - The assistant update parameters.
 * @returns {Assistant} 200 - success response - application/json
 */
const patchAssistant = async (req, res) => {
  try {
    const { openai } = await getOpenAIClient({ req, res });
    const assistant_id = req.params.id;
    const { endpoint: _e, ...updateData } = req.body;
    updateData.tools = updateData.tools ?? [];
    const updatedAssistant = await updateAssistant({ req, openai, assistant_id, updateData });
    res.json(updatedAssistant);
  } catch (error) {
    logger.error('[/assistants/:id] Error updating assistant', error);
    res.status(500).json({ error: error.message });
  }
};

module.exports = {
  patchAssistant,
  createAssistant,
  updateAssistant,
  addResourceFileId,
  deleteResourceFileId,
};

@@ -15,7 +15,7 @@ const AppService = require('./services/AppService');
const noIndex = require('./middleware/noIndex');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
-
+const { ldapLogin } = require('~/strategies');
const routes = require('./routes');

const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {};
@@ -60,6 +60,11 @@ const startServer = async () => {
  passport.use(await jwtLogin());
  passport.use(passportLogin());

+  // LDAP Auth
+  if (process.env.LDAP_URL && process.env.LDAP_BIND_DN && process.env.LDAP_USER_SEARCH_BASE) {
+    passport.use(ldapLogin);
+  }
+
  if (isEnabled(ALLOW_SOCIAL_LOGIN)) {
    configureSocialLogins(app);
  }

@@ -1,4 +1,4 @@
-const { EModelEndpoint } = require('librechat-data-provider');
+const { isAssistantsEndpoint } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
@@ -15,7 +15,7 @@ async function abortMessage(req, res) {
    abortKey = conversationId;
  }

-  if (endpoint === EModelEndpoint.assistants) {
+  if (isAssistantsEndpoint(endpoint)) {
    return await abortRun(req, res);
  }

@@ -1,6 +1,7 @@
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
+const { deleteMessages } = require('~/models/Message');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendMessage } = require('~/server/utils');
@@ -10,7 +11,7 @@ const three_minutes = 1000 * 60 * 3;

async function abortRun(req, res) {
  res.setHeader('Content-Type', 'application/json');
-  const { abortKey } = req.body;
+  const { abortKey, endpoint } = req.body;
  const [conversationId, latestMessageId] = abortKey.split(':');
  const conversation = await getConvo(req.user.id, conversationId);

@@ -66,12 +67,19 @@ async function abortRun(req, res) {
    logger.error('[abortRun] Error fetching or processing run', error);
  }

+  /* TODO: a reconciling strategy between the existing intermediate message would be more optimal than deleting it */
+  await deleteMessages({
+    user: req.user.id,
+    unfinished: true,
+    conversationId,
+  });
  runMessages = await checkMessageGaps({
    openai,
-    latestMessageId,
-    thread_id,
    run_id,
+    endpoint,
+    thread_id,
    conversationId,
+    latestMessageId,
  });

  const finalEvent = {

43  api/server/middleware/assistants/validate.js  (new file)
@@ -0,0 +1,43 @@
const { v4 } = require('uuid');
const { handleAbortError } = require('~/server/middleware/abortMiddleware');

/**
 * Checks if the assistant is supported or excluded
 * @param {object} req - Express Request
 * @param {object} req.body - The request payload.
 * @param {object} res - Express Response
 * @param {function} next - Express next middleware function.
 * @returns {Promise<void>}
 */
const validateAssistant = async (req, res, next) => {
  const { endpoint, conversationId, assistant_id, messageId } = req.body;

  /** @type {Partial<TAssistantEndpoint>} */
  const assistantsConfig = req.app.locals?.[endpoint];
  if (!assistantsConfig) {
    return next();
  }

  const { supportedIds, excludedIds } = assistantsConfig;
  const error = { message: 'Assistant not supported' };
  if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
    return await handleAbortError(res, req, error, {
      sender: 'System',
      conversationId,
      messageId: v4(),
      parentMessageId: messageId,
      error,
    });
  } else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
    return await handleAbortError(res, req, error, {
      sender: 'System',
      conversationId,
      messageId: v4(),
      parentMessageId: messageId,
    });
  }

  return next();
};

module.exports = validateAssistant;
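
The configuration this middleware reads is the endpoint entry placed on req.app.locals at startup. A hypothetical shape, for orientation only (mirroring exactly the fields read above):

// Illustrative only.
req.app.locals['assistants'] = {
  supportedIds: ['asst_allowed1', 'asst_allowed2'], // allow-list, checked first
  excludedIds: ['asst_blocked'], // deny-list, consulted only without an allow-list
};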

42  api/server/middleware/assistants/validateAuthor.js  (new file)
@@ -0,0 +1,42 @@
const { getAssistant } = require('~/models/Assistant');

/**
 * Checks that the requesting user is the author of the assistant when private assistants are enabled.
 * @param {object} params
 * @param {object} params.req - Express Request
 * @param {object} params.req.body - The request payload.
 * @param {string} params.overrideEndpoint - The override endpoint
 * @param {string} params.overrideAssistantId - The override assistant ID
 * @param {OpenAIClient} params.openai - OpenAI API Client
 * @returns {Promise<void>}
 */
const validateAuthor = async ({ req, openai, overrideEndpoint, overrideAssistantId }) => {
  if (req.user.role === 'ADMIN') {
    return;
  }

  const endpoint = overrideEndpoint ?? req.body.endpoint ?? req.query.endpoint;
  const assistant_id =
    overrideAssistantId ?? req.params.id ?? req.body.assistant_id ?? req.query.assistant_id;

  /** @type {Partial<TAssistantEndpoint>} */
  const assistantsConfig = req.app.locals?.[endpoint];
  if (!assistantsConfig) {
    return;
  }

  if (!assistantsConfig.privateAssistants) {
    return;
  }

  const assistantDoc = await getAssistant({ assistant_id, user: req.user.id });
  if (assistantDoc) {
    return;
  }
  const assistant = await openai.beta.assistants.retrieve(assistant_id);
  if (req.user.id !== assistant?.metadata?.author) {
    throw new Error(`Assistant ${assistant_id} is not authored by the user.`);
  }
};

module.exports = validateAuthor;

@@ -1,5 +1,6 @@
const { parseConvo, EModelEndpoint } = require('librechat-data-provider');
const { getModelsConfig } = require('~/server/controllers/ModelController');
+const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
const assistants = require('~/server/services/Endpoints/assistants');
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
const { processFiles } = require('~/server/services/Files/process');
@@ -18,6 +19,7 @@ const buildFunction = {
  [EModelEndpoint.anthropic]: anthropic.buildOptions,
  [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
  [EModelEndpoint.assistants]: assistants.buildOptions,
+  [EModelEndpoint.azureAssistants]: azureAssistants.buildOptions,
};

async function buildEndpointOption(req, res, next) {

@@ -2,14 +2,12 @@ const Keyv = require('keyv');
const uap = require('ua-parser-js');
const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, removePorts } = require('../utils');
-const keyvRedis = require('~/cache/keyvRedis');
+const keyvMongo = require('~/cache/keyvMongo');
const denyRequest = require('./denyRequest');
const { getLogStores } = require('~/cache');
const User = require('~/models/User');

-const banCache = isEnabled(process.env.USE_REDIS)
-  ? new Keyv({ store: keyvRedis })
-  : new Keyv({ namespace: ViolationTypes.BAN, ttl: 0 });
+const banCache = new Keyv({ store: keyvMongo, namespace: ViolationTypes.BAN, ttl: 0 });
const message = 'Your account has been temporarily banned due to violations of our service.';

/**

@@ -6,6 +6,7 @@ const setHeaders = require('./setHeaders');
const loginLimiter = require('./loginLimiter');
const validateModel = require('./validateModel');
const requireJwtAuth = require('./requireJwtAuth');
+const requireLdapAuth = require('./requireLdapAuth');
const uploadLimiters = require('./uploadLimiters');
const registerLimiter = require('./registerLimiter');
const messageLimiters = require('./messageLimiters');
@@ -29,6 +30,7 @@ module.exports = {
  setHeaders,
  loginLimiter,
  requireJwtAuth,
+  requireLdapAuth,
  registerLimiter,
  requireLocalAuth,
  validateEndpoint,

22  api/server/middleware/requireLdapAuth.js  (new file)
@@ -0,0 +1,22 @@
const passport = require('passport');

const requireLdapAuth = (req, res, next) => {
  passport.authenticate('ldapauth', (err, user, info) => {
    if (err) {
      console.log({
        title: '(requireLdapAuth) Error at passport.authenticate',
        parameters: [{ name: 'error', value: err }],
      });
      return next(err);
    }
    if (!user) {
      console.log({
        title: '(requireLdapAuth) Error: No user',
      });
      return res.status(422).send(info);
    }
    req.user = user;
    next();
  })(req, res, next);
};
module.exports = requireLdapAuth;
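
For context, the 'ldapauth' strategy invoked here is registered at startup (see the ldapLogin block in the server index diff above). A minimal sketch of such a strategy, assuming the passport-ldapauth package and the LDAP_* environment variables used elsewhere in this changeset:

// Hypothetical sketch: the real strategy lives in ~/strategies.
const LdapStrategy = require('passport-ldapauth');
const ldapLogin = new LdapStrategy({
  server: {
    url: process.env.LDAP_URL,
    bindDN: process.env.LDAP_BIND_DN,
    bindCredentials: process.env.LDAP_BIND_CREDENTIALS,
    searchBase: process.env.LDAP_USER_SEARCH_BASE,
    searchFilter: process.env.LDAP_SEARCH_FILTER || 'mail={{username}}',
  },
});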

7  api/server/middleware/speech/index.js  (new file)
@@ -0,0 +1,7 @@
const createTTSLimiters = require('./ttsLimiters');
const createSTTLimiters = require('./sttLimiters');

module.exports = {
  createTTSLimiters,
  createSTTLimiters,
};

68  api/server/middleware/speech/sttLimiters.js  (new file)
@@ -0,0 +1,68 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');

const getEnvironmentVariables = () => {
  const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
  const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
  const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
  const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;

  const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
  const sttIpMax = STT_IP_MAX;
  const sttIpWindowInMinutes = sttIpWindowMs / 60000;

  const sttUserWindowMs = STT_USER_WINDOW * 60 * 1000;
  const sttUserMax = STT_USER_MAX;
  const sttUserWindowInMinutes = sttUserWindowMs / 60000;

  return {
    sttIpWindowMs,
    sttIpMax,
    sttIpWindowInMinutes,
    sttUserWindowMs,
    sttUserMax,
    sttUserWindowInMinutes,
  };
};

const createSTTHandler = (ip = true) => {
  const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
    getEnvironmentVariables();

  return async (req, res) => {
    const type = ViolationTypes.STT_LIMIT;
    const errorMessage = {
      type,
      max: ip ? sttIpMax : sttUserMax,
      limiter: ip ? 'ip' : 'user',
      windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
    };

    await logViolation(req, res, type, errorMessage);
    res.status(429).json({ message: 'Too many STT requests. Try again later' });
  };
};

const createSTTLimiters = () => {
  const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables();

  const sttIpLimiter = rateLimit({
    windowMs: sttIpWindowMs,
    max: sttIpMax,
    handler: createSTTHandler(),
  });

  const sttUserLimiter = rateLimit({
    windowMs: sttUserWindowMs,
    max: sttUserMax,
    handler: createSTTHandler(false),
    keyGenerator: function (req) {
      return req.user?.id; // Use the user ID or NULL if not available
    },
  });

  return { sttIpLimiter, sttUserLimiter };
};

module.exports = createSTTLimiters;
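
Worked numbers for the defaults above: STT_IP_WINDOW defaults to 1 (minute) and STT_IP_MAX to 100, so sttIpWindowMs = 1 * 60 * 1000 = 60,000 ms and each IP gets at most 100 STT requests per minute; the per-user default is 50 per minute. A hypothetical override:

// Illustrative only: e.g. STT_USER_MAX=10 and STT_USER_WINDOW=5 in the environment
// yield 10 requests per user per 5-minute (300,000 ms) window.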

68  api/server/middleware/speech/ttsLimiters.js  (new file)
@@ -0,0 +1,68 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');

const getEnvironmentVariables = () => {
  const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
  const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
  const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
  const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;

  const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
  const ttsIpMax = TTS_IP_MAX;
  const ttsIpWindowInMinutes = ttsIpWindowMs / 60000;

  const ttsUserWindowMs = TTS_USER_WINDOW * 60 * 1000;
  const ttsUserMax = TTS_USER_MAX;
  const ttsUserWindowInMinutes = ttsUserWindowMs / 60000;

  return {
    ttsIpWindowMs,
    ttsIpMax,
    ttsIpWindowInMinutes,
    ttsUserWindowMs,
    ttsUserMax,
    ttsUserWindowInMinutes,
  };
};

const createTTSHandler = (ip = true) => {
  const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
    getEnvironmentVariables();

  return async (req, res) => {
    const type = ViolationTypes.TTS_LIMIT;
    const errorMessage = {
      type,
      max: ip ? ttsIpMax : ttsUserMax,
      limiter: ip ? 'ip' : 'user',
      windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
    };

    await logViolation(req, res, type, errorMessage);
    res.status(429).json({ message: 'Too many TTS requests. Try again later' });
  };
};

const createTTSLimiters = () => {
  const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables();

  const ttsIpLimiter = rateLimit({
    windowMs: ttsIpWindowMs,
    max: ttsIpMax,
    handler: createTTSHandler(),
  });

  const ttsUserLimiter = rateLimit({
    windowMs: ttsUserWindowMs,
    max: ttsUserMax,
    handler: createTTSHandler(false),
    keyGenerator: function (req) {
      return req.user?.id; // Use the user ID or NULL if not available
    },
  });

  return { ttsIpLimiter, ttsUserLimiter };
};

module.exports = createTTSLimiters;

@@ -25,6 +25,11 @@ afterEach(() => {
  delete process.env.DOMAIN_SERVER;
  delete process.env.ALLOW_REGISTRATION;
  delete process.env.ALLOW_SOCIAL_LOGIN;
+  delete process.env.LDAP_URL;
+  delete process.env.LDAP_BIND_DN;
+  delete process.env.LDAP_BIND_CREDENTIALS;
+  delete process.env.LDAP_USER_SEARCH_BASE;
+  delete process.env.LDAP_SEARCH_FILTER;
});

//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why.
@@ -50,6 +55,11 @@ describe.skip('GET /', () => {
    process.env.DOMAIN_SERVER = 'http://test-server.com';
    process.env.ALLOW_REGISTRATION = 'true';
    process.env.ALLOW_SOCIAL_LOGIN = 'true';
+    process.env.LDAP_URL = 'Test LDAP URL';
+    process.env.LDAP_BIND_DN = 'Test LDAP Bind DN';
+    process.env.LDAP_BIND_CREDENTIALS = 'Test LDAP Bind Credentials';
+    process.env.LDAP_USER_SEARCH_BASE = 'Test LDAP User Search Base';
+    process.env.LDAP_SEARCH_FILTER = 'Test LDAP Search Filter';

    const response = await request(app).get('/');

@@ -64,6 +74,7 @@ describe.skip('GET /', () => {
      openidLoginEnabled: true,
      openidLabel: 'Test OpenID',
      openidImageUrl: 'http://test-server.com',
+      ldapLoginEnabled: true,
      serverDomain: 'http://test-server.com',
      emailLoginEnabled: 'true',
      registrationEnabled: 'true',

@@ -106,7 +106,11 @@ router.post(
    const pluginMap = new Map();
    const onAgentAction = async (action, runId) => {
      pluginMap.set(runId, action.tool);
-      sendIntermediateMessage(res, { plugins });
+      sendIntermediateMessage(res, {
+        plugins,
+        parentMessageId: userMessage.messageId,
+        messageId: responseMessageId,
+      });
    };

    const onToolStart = async (tool, input, runId, parentRunId) => {
@@ -124,7 +128,11 @@ router.post(
      }
      const extraTokens = ':::plugin:::\n';
      plugins.push(latestPlugin);
-      sendIntermediateMessage(res, { plugins }, extraTokens);
+      sendIntermediateMessage(
+        res,
+        { plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId },
+        extraTokens,
+      );
    };

    const onToolEnd = async (output, runId) => {
@@ -142,7 +150,11 @@ router.post(

    const onChainEnd = () => {
      saveMessage({ ...userMessage, user });
-      sendIntermediateMessage(res, { plugins });
+      sendIntermediateMessage(res, {
+        plugins,
+        parentMessageId: userMessage.messageId,
+        messageId: responseMessageId,
+      });
    };

    const getAbortData = () => ({
@@ -174,12 +186,13 @@ router.post(
      onStart,
      getPartialText,
      ...endpointOption,
-      onProgress: progressCallback.call(null, {
+      progressCallback,
+      progressOptions: {
        res,
        text,
-        parentMessageId: overrideParentMessageId || userMessageId,
+        // parentMessageId: overrideParentMessageId || userMessageId,
        plugins,
-      }),
+      },
      abortController,
    });

@@ -2,9 +2,9 @@ const { v4 } = require('uuid');
const express = require('express');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider');
-const { initializeClient } = require('~/server/services/Endpoints/assistants');
+const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
-const { updateAssistant, getAssistant } = require('~/models/Assistant');
+const { updateAssistantDoc, getAssistant } = require('~/models/Assistant');
const { logger } = require('~/config');

const router = express.Router();
@@ -45,7 +45,6 @@ router.post('/:assistant_id', async (req, res) => {
    let metadata = encryptMetadata(_metadata);

    let { domain } = metadata;
-    /* Azure doesn't support periods in function names */
    domain = await domainParser(req, domain, true);

    if (!domain) {
@@ -55,8 +54,7 @@ router.post('/:assistant_id', async (req, res) => {
    const action_id = _action_id ?? v4();
    const initialPromises = [];

-    /** @type {{ openai: OpenAI }} */
-    const { openai } = await initializeClient({ req, res });
+    const { openai } = await getOpenAIClient({ req, res });

    initialPromises.push(getAssistant({ assistant_id }));
    initialPromises.push(openai.beta.assistants.retrieve(assistant_id));
@@ -111,7 +109,7 @@ router.post('/:assistant_id', async (req, res) => {
    let updatedAssistant = await openai.beta.assistants.update(assistant_id, { tools });
    const promises = [];
    promises.push(
-      updateAssistant(
+      updateAssistantDoc(
        { assistant_id },
        {
          actions,
@@ -157,9 +155,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
  try {
    const { assistant_id, action_id, model } = req.params;
    req.body.model = model;

-    /** @type {{ openai: OpenAI }} */
-    const { openai } = await initializeClient({ req, res });
+    const { openai } = await getOpenAIClient({ req, res });

    const initialPromises = [];
    initialPromises.push(getAssistant({ assistant_id }));
@@ -190,7 +186,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {

    const promises = [];
    promises.push(
-      updateAssistant(
+      updateAssistantDoc(
        { assistant_id },
        {
          actions: updatedActions,

26  api/server/routes/assistants/chatV1.js  (new file)
@@ -0,0 +1,26 @@
const express = require('express');

const router = express.Router();
const {
  setHeaders,
  handleAbort,
  validateModel,
  // validateEndpoint,
  buildEndpointOption,
} = require('~/server/middleware');
const validateAssistant = require('~/server/middleware/assistants/validate');
const chatController = require('~/server/controllers/assistants/chatV1');

router.post('/abort', handleAbort());

/**
 * @route POST /
 * @desc Chat with an assistant
 * @access Public
 * @param {express.Request} req - The request object, containing the request data.
 * @param {express.Response} res - The response object, used to send back a response.
 * @returns {void}
 */
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);

module.exports = router;

26  api/server/routes/assistants/chatV2.js  (new file)
@@ -0,0 +1,26 @@
const express = require('express');

const router = express.Router();
const {
  setHeaders,
  handleAbort,
  validateModel,
  // validateEndpoint,
  buildEndpointOption,
} = require('~/server/middleware');
const validateAssistant = require('~/server/middleware/assistants/validate');
const chatController = require('~/server/controllers/assistants/chatV2');

router.post('/abort', handleAbort());

/**
 * @route POST /
 * @desc Chat with an assistant
 * @access Public
 * @param {express.Request} req - The request object, containing the request data.
 * @param {express.Response} res - The response object, used to send back a response.
 * @returns {void}
 */
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);

module.exports = router;

@@ -7,16 +7,19 @@ const {
  // concurrentLimiter,
  // messageIpLimiter,
  // messageUserLimiter,
-} = require('../../middleware');
+} = require('~/server/middleware');

-const assistants = require('./assistants');
-const chat = require('./chat');
+const v1 = require('./v1');
+const chatV1 = require('./chatV1');
+const v2 = require('./v2');
+const chatV2 = require('./chatV2');

router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);

-router.use('/', assistants);
-router.use('/chat', chat);
+router.use('/v1/', v1);
+router.use('/v1/chat', chatV1);
+router.use('/v2/', v2);
+router.use('/v2/chat', chatV2);

module.exports = router;

81  api/server/routes/assistants/v1.js  (new file)
@@ -0,0 +1,81 @@
const multer = require('multer');
const express = require('express');
const controllers = require('~/server/controllers/assistants/v1');
const actions = require('./actions');
const tools = require('./tools');

const upload = multer();
const router = express.Router();

/**
 * Assistant actions route.
 * @route GET|POST /assistants/actions
 */
router.use('/actions', actions);

/**
 * Lists available tools.
 * @route GET /assistants/tools
 * @returns {TPlugin[]} 200 - application/json
 */
router.use('/tools', tools);

/**
 * Create an assistant.
 * @route POST /assistants
 * @param {AssistantCreateParams} req.body - The assistant creation parameters.
 * @returns {Assistant} 201 - success response - application/json
 */
router.post('/', controllers.createAssistant);

/**
 * Retrieves an assistant.
 * @route GET /assistants/:id
 * @param {string} req.params.id - Assistant identifier.
 * @returns {Assistant} 200 - success response - application/json
 */
router.get('/:id', controllers.retrieveAssistant);

/**
 * Modifies an assistant.
 * @route PATCH /assistants/:id
 * @param {string} req.params.id - Assistant identifier.
 * @param {AssistantUpdateParams} req.body - The assistant update parameters.
 * @returns {Assistant} 200 - success response - application/json
 */
router.patch('/:id', controllers.patchAssistant);

/**
 * Deletes an assistant.
 * @route DELETE /assistants/:id
 * @param {string} req.params.id - Assistant identifier.
 * @returns {Assistant} 200 - success response - application/json
 */
router.delete('/:id', controllers.deleteAssistant);

/**
 * Returns a list of assistants.
 * @route GET /assistants
 * @param {AssistantListParams} req.query - The assistant list parameters for pagination and sorting.
 * @returns {AssistantListResponse} 200 - success response - application/json
 */
router.get('/', controllers.listAssistants);

/**
 * Returns a list of the user's assistant documents (metadata saved to database).
 * @route GET /assistants/documents
 * @returns {AssistantDocument[]} 200 - success response - application/json
 */
router.get('/documents', controllers.getAssistantDocuments);

/**
 * Uploads and updates an avatar for a specific assistant.
 * @route POST /avatar/:assistant_id
 * @param {string} req.params.assistant_id - The ID of the assistant.
 * @param {Express.Multer.File} req.file - The avatar image file.
 * @param {string} [req.body.metadata] - Optional metadata for the assistant's avatar.
 * @returns {Object} 200 - success response - application/json
 */
router.post('/avatar/:assistant_id', upload.single('file'), controllers.uploadAssistantAvatar);

module.exports = router;

82  api/server/routes/assistants/v2.js  (new file)
@@ -0,0 +1,82 @@
const multer = require('multer');
const express = require('express');
const v1 = require('~/server/controllers/assistants/v1');
const v2 = require('~/server/controllers/assistants/v2');
const actions = require('./actions');
const tools = require('./tools');

const upload = multer();
const router = express.Router();

/**
 * Assistant actions route.
 * @route GET|POST /assistants/actions
 */
router.use('/actions', actions);

/**
 * Lists available tools.
 * @route GET /assistants/tools
 * @returns {TPlugin[]} 200 - application/json
 */
router.use('/tools', tools);

/**
 * Create an assistant.
 * @route POST /assistants
 * @param {AssistantCreateParams} req.body - The assistant creation parameters.
 * @returns {Assistant} 201 - success response - application/json
 */
router.post('/', v2.createAssistant);

/**
 * Retrieves an assistant.
 * @route GET /assistants/:id
 * @param {string} req.params.id - Assistant identifier.
 * @returns {Assistant} 200 - success response - application/json
 */
router.get('/:id', v1.retrieveAssistant);

/**
 * Modifies an assistant.
 * @route PATCH /assistants/:id
 * @param {string} req.params.id - Assistant identifier.
 * @param {AssistantUpdateParams} req.body - The assistant update parameters.
 * @returns {Assistant} 200 - success response - application/json
 */
router.patch('/:id', v2.patchAssistant);

/**
 * Deletes an assistant.
 * @route DELETE /assistants/:id
 * @param {string} req.params.id - Assistant identifier.
 * @returns {Assistant} 200 - success response - application/json
 */
router.delete('/:id', v1.deleteAssistant);

/**
 * Returns a list of assistants.
 * @route GET /assistants
 * @param {AssistantListParams} req.query - The assistant list parameters for pagination and sorting.
 * @returns {AssistantListResponse} 200 - success response - application/json
 */
router.get('/', v1.listAssistants);

/**
 * Returns a list of the user's assistant documents (metadata saved to database).
 * @route GET /assistants/documents
 * @returns {AssistantDocument[]} 200 - success response - application/json
 */
router.get('/documents', v1.getAssistantDocuments);

/**
 * Uploads and updates an avatar for a specific assistant.
 * @route POST /avatar/:assistant_id
 * @param {string} req.params.assistant_id - The ID of the assistant.
 * @param {Express.Multer.File} req.file - The avatar image file.
 * @param {string} [req.body.metadata] - Optional metadata for the assistant's avatar.
 * @returns {Object} 200 - success response - application/json
 */
router.post('/avatar/:assistant_id', upload.single('file'), v1.uploadAssistantAvatar);

module.exports = router;
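
For a sense of how these routes are consumed, a hypothetical client call (the '/api/assistants/v2' mount path is an assumption based on the router index diff above, which mounts this file under '/v2/'):

// Illustrative only.
const res = await fetch('/api/assistants/v2?limit=25&order=desc', {
  headers: { Authorization: `Bearer ${token}` }, // token: hypothetical JWT
});
const { data } = await res.json(); // AssistantListResponse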

@@ -12,15 +12,24 @@ const {
  loginLimiter,
  registerLimiter,
  requireJwtAuth,
+  requireLdapAuth,
  requireLocalAuth,
  validateRegistration,
} = require('../middleware');

const router = express.Router();

+const ldapAuth =
+  !!process.env.LDAP_URL && !!process.env.LDAP_BIND_DN && !!process.env.LDAP_USER_SEARCH_BASE;
//Local
router.post('/logout', requireJwtAuth, logoutController);
-router.post('/login', loginLimiter, checkBan, requireLocalAuth, loginController);
+router.post(
+  '/login',
+  loginLimiter,
+  checkBan,
+  ldapAuth ? requireLdapAuth : requireLocalAuth,
+  loginController,
+);
router.post('/refresh', refreshController);
router.post('/register', registerLimiter, checkBan, validateRegistration, registrationController);
router.post('/requestPasswordReset', resetPasswordRequestController);

@@ -13,6 +13,8 @@ router.get('/', async function (req, res) {
    return today.getMonth() === 1 && today.getDate() === 11;
  };

+  const ldapLoginEnabled =
+    !!process.env.LDAP_URL && !!process.env.LDAP_BIND_DN && !!process.env.LDAP_USER_SEARCH_BASE;
  try {
    /** @type {TStartupConfig} */
    const payload = {
@@ -30,9 +32,10 @@ router.get('/', async function (req, res) {
        !!process.env.OPENID_SESSION_SECRET,
      openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
      openidImageUrl: process.env.OPENID_IMAGE_URL,
+      ldapLoginEnabled,
      serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
      emailLoginEnabled,
-      registrationEnabled: isEnabled(process.env.ALLOW_REGISTRATION),
+      registrationEnabled: !ldapLoginEnabled && isEnabled(process.env.ALLOW_REGISTRATION),
      socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN),
      emailEnabled:
        (!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
|
||||
|
||||
@@ -110,7 +110,11 @@ router.post(
      if (!start) {
        saveMessage({ ...userMessage, user });
      }
      sendIntermediateMessage(res, { plugin });
      sendIntermediateMessage(res, {
        plugin,
        parentMessageId: userMessage.messageId,
        messageId: responseMessageId,
      });
      // logger.debug('PLUGIN ACTION', formattedAction);
    };

@@ -119,7 +123,11 @@ router.post(
      plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
      plugin.loading = false;
      saveMessage({ ...userMessage, user });
      sendIntermediateMessage(res, { plugin });
      sendIntermediateMessage(res, {
        plugin,
        parentMessageId: userMessage.messageId,
        messageId: responseMessageId,
      });
      // logger.debug('CHAIN END', plugin.outputs);
    };
@@ -153,12 +161,13 @@ router.post(
      onChainEnd,
      onStart,
      ...endpointOption,
      onProgress: progressCallback.call(null, {
      progressCallback,
      progressOptions: {
        res,
        text,
        plugin,
        parentMessageId: overrideParentMessageId || userMessageId,
      }),
        // parentMessageId: overrideParentMessageId || userMessageId,
      },
      abortController,
    });
@@ -1,6 +1,6 @@
const fs = require('fs').promises;
const express = require('express');
const { isUUID, FileSources } = require('librechat-data-provider');
const { isUUID, checkOpenAIStorage } = require('librechat-data-provider');
const {
  filterFile,
  processFileUpload,
@@ -89,7 +89,7 @@ router.get('/download/:userId/:file_id', async (req, res) => {
    return res.status(403).send('Forbidden');
  }

  if (file.source === FileSources.openai && !file.model) {
  if (checkOpenAIStorage(file.source) && !file.model) {
    logger.warn(`${errorPrefix} has no associated model: ${file_id}`);
    return res.status(400).send('The model used when creating this file is not available');
  }
@@ -110,7 +110,8 @@ router.get('/download/:userId/:file_id', async (req, res) => {
  let passThrough;
  /** @type {ReadableStream | undefined} */
  let fileStream;
  if (file.source === FileSources.openai) {

  if (checkOpenAIStorage(file.source)) {
    req.body = { model: file.model };
    const { openai } = await initializeClient({ req, res });
    logger.debug(`Downloading file ${file_id} from OpenAI`);
@@ -1,10 +1,13 @@
const express = require('express');
const { uaParser, checkBan, requireJwtAuth, createFileLimiters } = require('~/server/middleware');
const { createTTSLimiters, createSTTLimiters } = require('~/server/middleware/speech');
const { createMulterInstance } = require('./multer');

const files = require('./files');
const images = require('./images');
const avatar = require('./avatar');
const stt = require('./stt');
const tts = require('./tts');

const initialize = async () => {
  const router = express.Router();
@@ -12,6 +15,12 @@ const initialize = async () => {
  router.use(checkBan);
  router.use(uaParser);

  /* Important: stt/tts routes must be added before the upload limiters */
  const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();
  const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
  router.use('/stt', sttIpLimiter, sttUserLimiter, stt);
  router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);

  const upload = await createMulterInstance();
  const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters();
  router.post('*', fileUploadIpLimiter, fileUploadUserLimiter);
13
api/server/routes/files/stt.js
Normal file
@@ -0,0 +1,13 @@
const express = require('express');
const router = express.Router();
const multer = require('multer');
const { requireJwtAuth } = require('~/server/middleware/');
const { speechToText } = require('~/server/services/Files/Audio');

const upload = multer();

router.post('/', requireJwtAuth, upload.single('audio'), async (req, res) => {
  await speechToText(req, res);
});

module.exports = router;
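A minimal client-side sketch of how this new route could be exercised, assuming the files router is mounted at /api/files and a valid JWT is attached by the browser session (the mount path is an assumption, not shown in this diff):

// Sketch: submit recorded audio to the STT route; the field name must be
// 'audio' to match upload.single('audio') above.
async function transcribe(blob) {
  const form = new FormData();
  form.append('audio', blob, 'audio.wav');
  const res = await fetch('/api/files/stt', { method: 'POST', body: form }); // assumed mount path
  const { text } = await res.json(); // speechToText responds with { text }
  return text;
}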
42
api/server/routes/files/tts.js
Normal file
@@ -0,0 +1,42 @@
const multer = require('multer');
const express = require('express');
const { CacheKeys } = require('librechat-data-provider');
const { getVoices, streamAudio, textToSpeech } = require('~/server/services/Files/Audio');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');

const router = express.Router();
const upload = multer();

router.post('/manual', upload.none(), async (req, res) => {
  await textToSpeech(req, res);
});

const logDebugMessage = (req, message) =>
  logger.debug(`[streamAudio] user: ${req?.user?.id ?? 'UNDEFINED_USER'} | ${message}`);

// TODO: test caching
router.post('/', async (req, res) => {
  try {
    const audioRunsCache = getLogStores(CacheKeys.AUDIO_RUNS);
    const audioRun = await audioRunsCache.get(req.body.runId);
    logDebugMessage(req, 'start stream audio');
    if (audioRun) {
      logDebugMessage(req, 'stream audio already running');
      return res.status(401).json({ error: 'Audio stream already running' });
    }
    audioRunsCache.set(req.body.runId, true);
    await streamAudio(req, res);
    logDebugMessage(req, 'end stream audio');
    res.status(200).end();
  } catch (error) {
    logger.error(`[streamAudio] user: ${req.user.id} | Failed to stream audio: ${error}`);
    res.status(500).json({ error: 'Failed to stream audio' });
  }
});

router.get('/voices', async (req, res) => {
  await getVoices(req, res);
});

module.exports = router;
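The AUDIO_RUNS cache entry keyed by runId acts as a per-run guard: a second POST for the same run is rejected while the first stream is in flight. A hedged client sketch (the /api/files/tts mount path and the request shape beyond runId are assumptions):

// Sketch: request a streamed TTS rendition for a message run.
async function streamTTS(runId) {
  const res = await fetch('/api/files/tts', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ runId }), // runId gates the AUDIO_RUNS cache
  });
  if (res.status === 401) {
    throw new Error('Audio stream already running for this run');
  }
  return res.body; // audio is streamed in the response body
}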
@@ -14,7 +14,7 @@ router.use(requireJwtAuth);

router.get('/:conversationId', validateMessageReq, async (req, res) => {
  const { conversationId } = req.params;
  res.status(200).send(await getMessages({ conversationId }));
  res.status(200).send(await getMessages({ conversationId }, '-_id -__v -user'));
});

// CREATE
@@ -28,7 +28,7 @@ router.post('/:conversationId', validateMessageReq, async (req, res) => {
// READ
router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
  const { conversationId, messageId } = req.params;
  res.status(200).send(await getMessages({ conversationId, messageId }));
  res.status(200).send(await getMessages({ conversationId, messageId }, '-_id -__v -user'));
});

// UPDATE
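The new second argument is a Mongoose-style exclusion projection: prefixing a field with '-' removes it from the returned documents, so internal fields (_id, __v, and the owning user id) no longer leak into API responses. A minimal sketch of the idea, assuming getMessages forwards the string to Mongoose's select():

// Sketch (assumed shape of the model helper, not shown in this diff):
async function getMessages(filter, select) {
  const query = Message.find(filter);
  if (select) {
    query.select(select); // e.g. '-_id -__v -user' excludes those fields
  }
  return query.lean();
}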
@@ -41,29 +41,10 @@ router.get('/', async function (req, res) {
      return;
    }

    const messages = (
      await Message.meiliSearch(
        q,
        {
          attributesToHighlight: ['text'],
          highlightPreTag: '**',
          highlightPostTag: '**',
        },
        true,
      )
    ).hits.map((message) => {
      const { _formatted, ...rest } = message;
      return {
        ...rest,
        searchResult: true,
        text: _formatted.text,
      };
    });
    const messages = (await Message.meiliSearch(q, undefined, true)).hits;
    const titles = (await Conversation.meiliSearch(q)).hits;

    const sortedHits = reduceHits(messages, titles);
    // debugging:
    // logger.debug('user:', user, 'message hits:', messages.length, 'convo hits:', titles.length);
    // logger.debug('sorted hits:', sortedHits.length);
    const result = await getConvosQueried(user, sortedHits, pageNumber);

    const activeMessages = [];
@@ -86,8 +67,7 @@ router.get('/', async function (req, res) {
      delete result.cache;
    }
    delete result.convoMap;
    // for debugging
    // logger.debug(result, messages.length);

    res.status(200).send(result);
  } catch (error) {
    logger.error('[/search] Error while searching messages & conversations', error);
@@ -1,20 +1,59 @@
const {
  AuthTypeEnum,
  EModelEndpoint,
  actionDomainSeparator,
  CacheKeys,
  Constants,
  AuthTypeEnum,
  actionDelimiter,
  isImageVisionTool,
  actionDomainSeparator,
} = require('librechat-data-provider');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions } = require('~/models/Action');
const { getActions, deleteActions } = require('~/models/Action');
const { deleteAssistant } = require('~/models/Assistant');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');

const toolNameRegex = /^[a-zA-Z0-9_-]+$/;

/**
 * Validates tool name against regex pattern and updates if necessary.
 * @param {object} params - The parameters for the function.
 * @param {object} params.req - Express Request.
 * @param {FunctionTool} params.tool - The tool object.
 * @param {string} params.assistant_id - The assistant ID.
 * @returns {object|null} - Updated tool object, or null if invalid and not an action.
 */
const validateAndUpdateTool = async ({ req, tool, assistant_id }) => {
  let actions;
  if (isImageVisionTool(tool)) {
    return null;
  }
  if (!toolNameRegex.test(tool.function.name)) {
    const [functionName, domain] = tool.function.name.split(actionDelimiter);
    actions = await getActions({ assistant_id, user: req.user.id }, true);
    const matchingActions = actions.filter((action) => {
      const metadata = action.metadata;
      return metadata && metadata.domain === domain;
    });
    const action = matchingActions[0];
    if (!action) {
      return null;
    }

    const parsedDomain = await domainParser(req, domain, true);

    if (!parsedDomain) {
      return null;
    }

    tool.function.name = `${functionName}${actionDelimiter}${parsedDomain}`;
  }
  return tool;
};

/**
 * Encodes or decodes a domain name to/from base64, replacing periods with a custom separator.
 *
 * Necessary because the Azure OpenAI Assistants API doesn't support periods in function
 * names due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum.
 * Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum.
 *
 * @param {Express.Request} req - The Express Request object.
 * @param {string} domain - The domain name to encode/decode.
@@ -26,10 +65,6 @@ async function domainParser(req, domain, inverse = false) {
    return;
  }

  if (!req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
    return domain;
  }

  const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS);
  const cachedDomain = await domainsCache.get(domain);
  if (inverse && cachedDomain) {
@@ -170,10 +205,29 @@ function decryptMetadata(metadata) {
  return decryptedMetadata;
}

/**
 * Deletes an action and its corresponding assistant.
 * @param {Object} params - The parameters for the function.
 * @param {Express.Request} params.req - The Express Request object.
 * @param {string} params.assistant_id - The ID of the assistant.
 */
const deleteAssistantActions = async ({ req, assistant_id }) => {
  try {
    await deleteActions({ assistant_id, user: req.user.id });
    await deleteAssistant({ assistant_id, user: req.user.id });
  } catch (error) {
    const message = 'Trouble deleting Assistant Actions for Assistant ID: ' + assistant_id;
    logger.error(message, error);
    throw new Error(message);
  }
};

module.exports = {
  loadActionSets,
  deleteAssistantActions,
  validateAndUpdateTool,
  createActionTool,
  encryptMetadata,
  decryptMetadata,
  loadActionSets,
  domainParser,
};
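The underlying idea behind domainParser, sketched under assumptions (the real implementation also consults the ENCODED_DOMAINS cache and LibreChat constants not shown in this hunk): a domain like "example.com" cannot appear verbatim in an OpenAI function name, so periods are swapped for actionDomainSeparator, and domains that would push the name past the 64-character limit are base64url-encoded instead.

// Sketch only, illustrative rather than the actual implementation:
const SEPARATOR = '_-_'; // stand-in for actionDomainSeparator
function encodeDomain(domain) {
  const replaced = domain.replaceAll('.', SEPARATOR);
  // fall back to base64url when the friendly form is too long
  return replaced.length <= 40 ? replaced : Buffer.from(domain).toString('base64url');
}
function decodeDomain(encoded) {
  return encoded.includes(SEPARATOR)
    ? encoded.replaceAll(SEPARATOR, '.')
    : Buffer.from(encoded, 'base64url').toString('utf8');
}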
@@ -73,12 +73,12 @@ describe('domainParser', () => {
  const TLD = '.com';

  // Non-azure request
  it('returns domain as is if not azure', async () => {
  it('does not return domain as is if not azure', async () => {
    const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`;
    const result1 = await domainParser(reqNoAzure, domain, false);
    const result2 = await domainParser(reqNoAzure, domain, true);
    expect(result1).toEqual(domain);
    expect(result2).toEqual(domain);
    expect(result1).not.toEqual(domain);
    expect(result2).not.toEqual(domain);
  });

  // Test for Empty or Null Inputs
@@ -72,12 +72,21 @@ const AppService = async (app) => {
  }

  if (config?.endpoints?.[EModelEndpoint.azureOpenAI]?.assistants) {
    endpointLocals[EModelEndpoint.assistants] = azureAssistantsDefaults();
    endpointLocals[EModelEndpoint.azureAssistants] = azureAssistantsDefaults();
  }

  if (config?.endpoints?.[EModelEndpoint.azureAssistants]) {
    endpointLocals[EModelEndpoint.azureAssistants] = assistantsConfigSetup(
      config,
      EModelEndpoint.azureAssistants,
      endpointLocals[EModelEndpoint.azureAssistants],
    );
  }

  if (config?.endpoints?.[EModelEndpoint.assistants]) {
    endpointLocals[EModelEndpoint.assistants] = assistantsConfigSetup(
      config,
      EModelEndpoint.assistants,
      endpointLocals[EModelEndpoint.assistants],
    );
  }
@@ -218,6 +218,7 @@ describe('AppService', () => {
            pollIntervalMs: 5000,
            timeoutMs: 30000,
            supportedIds: ['id1', 'id2'],
            privateAssistants: false,
          },
        },
      }),
@@ -232,6 +233,7 @@ describe('AppService', () => {
        pollIntervalMs: 5000,
        timeoutMs: 30000,
        supportedIds: expect.arrayContaining(['id1', 'id2']),
        privateAssistants: false,
      }),
    );
  });
@@ -253,8 +255,8 @@ describe('AppService', () => {
    process.env.EASTUS_API_KEY = 'eastus-key';

    await AppService(app);
    expect(app.locals).toHaveProperty(EModelEndpoint.assistants);
    expect(app.locals[EModelEndpoint.assistants].capabilities.length).toEqual(3);
    expect(app.locals).toHaveProperty(EModelEndpoint.azureAssistants);
    expect(app.locals[EModelEndpoint.azureAssistants].capabilities.length).toEqual(3);
  });

  it('should correctly configure Azure OpenAI endpoint based on custom config', async () => {
@@ -505,7 +507,31 @@ describe('AppService updating app.locals and issuing warnings', () => {

    const { logger } = require('~/config');
    expect(logger.warn).toHaveBeenCalledWith(
      expect.stringContaining('Both `supportedIds` and `excludedIds` are defined'),
      expect.stringContaining(
        'The \'assistants\' endpoint has both \'supportedIds\' and \'excludedIds\' defined.',
      ),
    );
  });

  it('should log a warning when privateAssistants and supportedIds or excludedIds are provided', async () => {
    const mockConfig = {
      endpoints: {
        assistants: {
          privateAssistants: true,
          supportedIds: ['id1'],
        },
      },
    };
    require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig));

    const app = { locals: {} };
    await require('./AppService')(app);

    const { logger } = require('~/config');
    expect(logger.warn).toHaveBeenCalledWith(
      expect.stringContaining(
        'The \'assistants\' endpoint has both \'privateAssistants\' and \'supportedIds\' or \'excludedIds\' defined.',
      ),
    );
  });
@@ -78,7 +78,7 @@ async function createOnTextProgress({
 * @return {Promise<OpenAIAssistantFinish | OpenAIAssistantAction[] | ThreadMessage[] | RequiredActionFunctionToolCall[]>}
 */
async function getResponse({ openai, run_id, thread_id }) {
  const run = await waitForRun({ openai, run_id, thread_id, pollIntervalMs: 500 });
  const run = await waitForRun({ openai, run_id, thread_id, pollIntervalMs: 2000 });

  if (run.status === RunStatus.COMPLETED) {
    const messages = await openai.beta.threads.messages.list(thread_id, defaultOrderQuery);
@@ -393,8 +393,9 @@ async function runAssistant({
    },
  });

  const { endpoint = EModelEndpoint.azureAssistants } = openai.req.body;
  /** @type {TCustomConfig.endpoints.assistants} */
  const assistantsEndpointConfig = openai.req.app.locals?.[EModelEndpoint.assistants] ?? {};
  const assistantsEndpointConfig = openai.req.app.locals?.[endpoint] ?? {};
  const { pollIntervalMs, timeoutMs } = assistantsEndpointConfig;

  const run = await waitForRun({
@@ -3,6 +3,7 @@ const { isUserProvided, generateConfig } = require('~/server/utils');

const {
  OPENAI_API_KEY: openAIApiKey,
  AZURE_ASSISTANTS_API_KEY: azureAssistantsApiKey,
  ASSISTANTS_API_KEY: assistantsApiKey,
  AZURE_API_KEY: azureOpenAIApiKey,
  ANTHROPIC_API_KEY: anthropicApiKey,
@@ -13,6 +14,7 @@ const {
  OPENAI_REVERSE_PROXY,
  AZURE_OPENAI_BASEURL,
  ASSISTANTS_BASE_URL,
  AZURE_ASSISTANTS_BASE_URL,
} = process.env ?? {};

const useAzurePlugins = !!PLUGINS_USE_AZURE;
@@ -28,11 +30,20 @@ module.exports = {
    useAzurePlugins,
    userProvidedOpenAI,
    googleKey,
    [EModelEndpoint.openAI]: generateConfig(openAIApiKey, OPENAI_REVERSE_PROXY),
    [EModelEndpoint.assistants]: generateConfig(assistantsApiKey, ASSISTANTS_BASE_URL, true),
    [EModelEndpoint.azureOpenAI]: generateConfig(azureOpenAIApiKey, AZURE_OPENAI_BASEURL),
    [EModelEndpoint.chatGPTBrowser]: generateConfig(chatGPTToken),
    [EModelEndpoint.anthropic]: generateConfig(anthropicApiKey),
    [EModelEndpoint.bingAI]: generateConfig(bingToken),
    [EModelEndpoint.anthropic]: generateConfig(anthropicApiKey),
    [EModelEndpoint.chatGPTBrowser]: generateConfig(chatGPTToken),
    [EModelEndpoint.openAI]: generateConfig(openAIApiKey, OPENAI_REVERSE_PROXY),
    [EModelEndpoint.azureOpenAI]: generateConfig(azureOpenAIApiKey, AZURE_OPENAI_BASEURL),
    [EModelEndpoint.assistants]: generateConfig(
      assistantsApiKey,
      ASSISTANTS_BASE_URL,
      EModelEndpoint.assistants,
    ),
    [EModelEndpoint.azureAssistants]: generateConfig(
      azureAssistantsApiKey,
      AZURE_ASSISTANTS_BASE_URL,
      EModelEndpoint.azureAssistants,
    ),
  },
};
@@ -1,3 +1,5 @@
const { RateLimitPrefix } = require('librechat-data-provider');

/**
 *
 * @param {TCustomConfig['rateLimits'] | undefined} rateLimits
@@ -6,24 +8,41 @@ const handleRateLimits = (rateLimits) => {
  if (!rateLimits) {
    return;
  }
  const { fileUploads, conversationsImport } = rateLimits;
  if (fileUploads) {
    process.env.FILE_UPLOAD_IP_MAX = fileUploads.ipMax ?? process.env.FILE_UPLOAD_IP_MAX;
    process.env.FILE_UPLOAD_IP_WINDOW =
      fileUploads.ipWindowInMinutes ?? process.env.FILE_UPLOAD_IP_WINDOW;
    process.env.FILE_UPLOAD_USER_MAX = fileUploads.userMax ?? process.env.FILE_UPLOAD_USER_MAX;
    process.env.FILE_UPLOAD_USER_WINDOW =
      fileUploads.userWindowInMinutes ?? process.env.FILE_UPLOAD_USER_WINDOW;
  }

  if (conversationsImport) {
    process.env.IMPORT_IP_MAX = conversationsImport.ipMax ?? process.env.IMPORT_IP_MAX;
    process.env.IMPORT_IP_WINDOW =
      conversationsImport.ipWindowInMinutes ?? process.env.IMPORT_IP_WINDOW;
    process.env.IMPORT_USER_MAX = conversationsImport.userMax ?? process.env.IMPORT_USER_MAX;
    process.env.IMPORT_USER_WINDOW =
      conversationsImport.userWindowInMinutes ?? process.env.IMPORT_USER_WINDOW;
  }
  const rateLimitKeys = {
    fileUploads: RateLimitPrefix.FILE_UPLOAD,
    conversationsImport: RateLimitPrefix.IMPORT,
    tts: RateLimitPrefix.TTS,
    stt: RateLimitPrefix.STT,
  };

  Object.entries(rateLimitKeys).forEach(([key, prefix]) => {
    const rateLimit = rateLimits[key];
    if (rateLimit) {
      setRateLimitEnvVars(prefix, rateLimit);
    }
  });
};

/**
 * Set environment variables for rate limit configurations
 *
 * @param {string} prefix - Prefix for environment variable names
 * @param {object} rateLimit - Rate limit configuration object
 */
const setRateLimitEnvVars = (prefix, rateLimit) => {
  const envVarsMapping = {
    ipMax: `${prefix}_IP_MAX`,
    ipWindowInMinutes: `${prefix}_IP_WINDOW`,
    userMax: `${prefix}_USER_MAX`,
    userWindowInMinutes: `${prefix}_USER_WINDOW`,
  };

  Object.entries(envVarsMapping).forEach(([key, envVar]) => {
    if (rateLimit[key] !== undefined) {
      process.env[envVar] = rateLimit[key];
    }
  });
};

module.exports = handleRateLimits;
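A hedged usage sketch of the refactored helper; it assumes RateLimitPrefix.TTS and RateLimitPrefix.STT resolve to the string prefixes 'TTS' and 'STT' (their values are not shown in this diff):

// Sketch: the generic mapping turns each configured limit into env vars.
handleRateLimits({
  tts: { ipMax: 100, ipWindowInMinutes: 1 },
  stt: { userMax: 50, userWindowInMinutes: 1 },
});
// Expected effect under the stated assumption:
//   process.env.TTS_IP_MAX = 100, process.env.TTS_IP_WINDOW = 1,
//   process.env.STT_USER_MAX = 50, process.env.STT_USER_WINDOW = 1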
@@ -53,7 +53,7 @@ async function loadConfigEndpoints(req) {

  if (req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
    /** @type {Omit<TConfig, 'order'>} */
    endpointsConfig[EModelEndpoint.assistants] = {
    endpointsConfig[EModelEndpoint.azureAssistants] = {
      userProvide: false,
    };
  }
@@ -30,7 +30,7 @@ async function loadConfigModels(req) {
  }

  if (azureEndpoint?.assistants && azureConfig.assistantModels) {
    modelsConfig[EModelEndpoint.assistants] = azureConfig.assistantModels;
    modelsConfig[EModelEndpoint.azureAssistants] = azureConfig.assistantModels;
  }

  if (!Array.isArray(endpoints[EModelEndpoint.custom])) {
@@ -9,13 +9,15 @@ const { config } = require('./EndpointService');
 */
async function loadDefaultEndpointsConfig(req) {
  const { google, gptPlugins } = await loadAsyncEndpoints(req);
  const { openAI, assistants, bingAI, anthropic, azureOpenAI, chatGPTBrowser } = config;
  const { openAI, assistants, azureAssistants, bingAI, anthropic, azureOpenAI, chatGPTBrowser } =
    config;

  const enabledEndpoints = getEnabledEndpoints();

  const endpointConfig = {
    [EModelEndpoint.openAI]: openAI,
    [EModelEndpoint.assistants]: assistants,
    [EModelEndpoint.azureAssistants]: azureAssistants,
    [EModelEndpoint.azureOpenAI]: azureOpenAI,
    [EModelEndpoint.google]: google,
    [EModelEndpoint.bingAI]: bingAI,
@@ -25,6 +25,7 @@ async function loadDefaultModels(req) {
    plugins: true,
  });
  const assistants = await getOpenAIModels({ assistants: true });
  const azureAssistants = await getOpenAIModels({ azureAssistants: true });

  return {
    [EModelEndpoint.openAI]: openAI,
@@ -35,6 +36,7 @@ async function loadDefaultModels(req) {
    [EModelEndpoint.bingAI]: ['BingAI', 'Sydney'],
    [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
    [EModelEndpoint.assistants]: assistants,
    [EModelEndpoint.azureAssistants]: azureAssistants,
  };
}
@@ -2,95 +2,8 @@ const addTitle = require('./addTitle');
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');

/**
 * Asynchronously lists assistants based on provided query parameters.
 *
 * Initializes the client with the current request and response objects and lists assistants
 * according to the query parameters. This function abstracts the logic for non-Azure paths.
 *
 * @async
 * @param {object} params - The parameters object.
 * @param {object} params.req - The request object, used for initializing the client.
 * @param {object} params.res - The response object, used for initializing the client.
 * @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
 * @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
 */
const listAssistants = async ({ req, res, query }) => {
  const { openai } = await initializeClient({ req, res });
  return openai.beta.assistants.list(query);
};

/**
 * Asynchronously lists assistants for Azure configured groups.
 *
 * Iterates through Azure configured assistant groups, initializes the client with the current request and response objects,
 * lists assistants based on the provided query parameters, and merges their data alongside the model information into a single array.
 *
 * @async
 * @param {object} params - The parameters object.
 * @param {object} params.req - The request object, used for initializing the client and manipulating the request body.
 * @param {object} params.res - The response object, used for initializing the client.
 * @param {TAzureConfig} params.azureConfig - The Azure configuration object containing assistantGroups and groupMap.
 * @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
 * @returns {Promise<AssistantListResponse>} A promise that resolves to an array of assistant data merged with their respective model information.
 */
const listAssistantsForAzure = async ({ req, res, azureConfig = {}, query }) => {
  /** @type {Array<[string, TAzureModelConfig]>} */
  const groupModelTuples = [];
  const promises = [];
  /** @type {Array<TAzureGroup>} */
  const groups = [];

  const { groupMap, assistantGroups } = azureConfig;

  for (const groupName of assistantGroups) {
    const group = groupMap[groupName];
    groups.push(group);

    const currentModelTuples = Object.entries(group?.models);
    groupModelTuples.push(currentModelTuples);

    /* The specified model is only necessary to
      fetch assistants for the shared instance */
    req.body.model = currentModelTuples[0][0];
    promises.push(listAssistants({ req, res, query }));
  }

  const resolvedQueries = await Promise.all(promises);
  const data = resolvedQueries.flatMap((res, i) =>
    res.data.map((assistant) => {
      const deploymentName = assistant.model;
      const currentGroup = groups[i];
      const currentModelTuples = groupModelTuples[i];
      const firstModel = currentModelTuples[0][0];

      if (currentGroup.deploymentName === deploymentName) {
        return { ...assistant, model: firstModel };
      }

      for (const [model, modelConfig] of currentModelTuples) {
        if (modelConfig.deploymentName === deploymentName) {
          return { ...assistant, model };
        }
      }

      return { ...assistant, model: firstModel };
    }),
  );

  return {
    first_id: data[0]?.id,
    last_id: data[data.length - 1]?.id,
    object: 'list',
    has_more: false,
    data,
  };
};

module.exports = {
  addTitle,
  buildOptions,
  initializeClient,
  listAssistants,
  listAssistantsForAzure,
};
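listAssistantsForAzure flattens one list call per Azure group into a single OpenAI-style list payload, remapping each assistant's deployment name back to the user-facing model key. A sketch of the resulting shape (field values are illustrative, not from the source):

// Illustrative result shape only:
const exampleResponse = {
  object: 'list',
  first_id: 'asst_abc123',
  last_id: 'asst_xyz789',
  has_more: false, // pagination is not carried over from the per-group queries
  data: [
    { id: 'asst_abc123', model: 'gpt-4-turbo' /* remapped from its deploymentName */ },
  ],
};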
@@ -1,11 +1,6 @@
const OpenAI = require('openai');
const { HttpsProxyAgent } = require('https-proxy-agent');
const {
  ErrorTypes,
  EModelEndpoint,
  resolveHeaders,
  mapModelToAzureConfig,
} = require('librechat-data-provider');
const { ErrorTypes, EModelEndpoint } = require('librechat-data-provider');
const {
  getUserKeyValues,
  getUserKeyExpiry,
@@ -13,9 +8,8 @@ const {
} = require('~/server/services/UserService');
const OpenAIClient = require('~/app/clients/OpenAIClient');
const { isUserProvided } = require('~/server/utils');
const { constructAzureURL } = require('~/utils');

const initializeClient = async ({ req, res, endpointOption, initAppClient = false }) => {
const initializeClient = async ({ req, res, endpointOption, version, initAppClient = false }) => {
  const { PROXY, OPENAI_ORGANIZATION, ASSISTANTS_API_KEY, ASSISTANTS_BASE_URL } = process.env;

  const userProvidesKey = isUserProvided(ASSISTANTS_API_KEY);
@@ -34,7 +28,11 @@ const initializeClient = async ({ req, res, endpointOption, initAppClient = fals
  let apiKey = userProvidesKey ? userValues.apiKey : ASSISTANTS_API_KEY;
  let baseURL = userProvidesURL ? userValues.baseURL : ASSISTANTS_BASE_URL;

  const opts = {};
  const opts = {
    defaultHeaders: {
      'OpenAI-Beta': `assistants=${version}`,
    },
  };

  const clientOptions = {
    reverseProxyUrl: baseURL ?? null,
@@ -44,54 +42,6 @@ const initializeClient = async ({ req, res, endpointOption, initAppClient = fals
    ...endpointOption,
  };

  /** @type {TAzureConfig | undefined} */
  const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];

  /** @type {AzureOptions | undefined} */
  let azureOptions;

  if (azureConfig && azureConfig.assistants) {
    const { modelGroupMap, groupMap, assistantModels } = azureConfig;
    const modelName = req.body.model ?? req.query.model ?? assistantModels[0];
    const {
      azureOptions: currentOptions,
      baseURL: azureBaseURL,
      headers = {},
      serverless,
    } = mapModelToAzureConfig({
      modelName,
      modelGroupMap,
      groupMap,
    });

    azureOptions = currentOptions;

    baseURL = constructAzureURL({
      baseURL: azureBaseURL ?? 'https://${INSTANCE_NAME}.openai.azure.com/openai',
      azureOptions,
    });

    apiKey = azureOptions.azureOpenAIApiKey;
    opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion };
    opts.defaultHeaders = resolveHeaders({ ...headers, 'api-key': apiKey });
    opts.model = azureOptions.azureOpenAIApiDeploymentName;

    if (initAppClient) {
      clientOptions.titleConvo = azureConfig.titleConvo;
      clientOptions.titleModel = azureConfig.titleModel;
      clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';

      const groupName = modelGroupMap[modelName].group;
      clientOptions.addParams = azureConfig.groupMap[groupName].addParams;
      clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams;
      clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt;

      clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
      clientOptions.headers = opts.defaultHeaders;
      clientOptions.azure = !serverless && azureOptions;
    }
  }

  if (userProvidesKey && !apiKey) {
    throw new Error(
      JSON.stringify({
@@ -125,10 +75,6 @@ const initializeClient = async ({ req, res, endpointOption, initAppClient = fals
  openai.req = req;
  openai.res = res;

  if (azureOptions) {
    openai.locals = { ...(openai.locals ?? {}), azureOptions };
  }

  if (endpointOption && initAppClient) {
    const client = new OpenAIClient(apiKey, clientOptions);
    return {
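The new version parameter selects the Assistants API revision via the OpenAI-Beta header, so the same client code can target v1 or v2. A minimal sketch of the effect (the header format is as shown in the hunk above; the call site itself is an assumption):

// Sketch: callers pick the Assistants API version per endpoint.
const { openai } = await initializeClient({ req, res, version: 'v2' });
// Every request from this client now carries: OpenAI-Beta: assistants=v2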
@@ -0,0 +1,19 @@
const buildOptions = (endpoint, parsedBody) => {
  // eslint-disable-next-line no-unused-vars
  const { promptPrefix, assistant_id, iconURL, greeting, spec, ...rest } = parsedBody;
  const endpointOption = {
    endpoint,
    promptPrefix,
    assistant_id,
    iconURL,
    greeting,
    spec,
    modelOptions: {
      ...rest,
    },
  };

  return endpointOption;
};

module.exports = buildOptions;
7
api/server/services/Endpoints/azureAssistants/index.js
Normal file
@@ -0,0 +1,7 @@
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');

module.exports = {
  buildOptions,
  initializeClient,
};
@@ -0,0 +1,195 @@
const OpenAI = require('openai');
const { HttpsProxyAgent } = require('https-proxy-agent');
const {
  ErrorTypes,
  EModelEndpoint,
  resolveHeaders,
  mapModelToAzureConfig,
} = require('librechat-data-provider');
const {
  getUserKeyValues,
  getUserKeyExpiry,
  checkUserKeyExpiry,
} = require('~/server/services/UserService');
const OpenAIClient = require('~/app/clients/OpenAIClient');
const { isUserProvided } = require('~/server/utils');
const { constructAzureURL } = require('~/utils');

class Files {
  constructor(client) {
    this._client = client;
  }
  /**
   * Create an assistant file by attaching a
   * [File](https://platform.openai.com/docs/api-reference/files) to an
   * [assistant](https://platform.openai.com/docs/api-reference/assistants).
   */
  create(assistantId, body, options) {
    return this._client.post(`/assistants/${assistantId}/files`, {
      body,
      ...options,
      headers: { 'OpenAI-Beta': 'assistants=v1', ...options?.headers },
    });
  }

  /**
   * Retrieves an AssistantFile.
   */
  retrieve(assistantId, fileId, options) {
    return this._client.get(`/assistants/${assistantId}/files/${fileId}`, {
      ...options,
      headers: { 'OpenAI-Beta': 'assistants=v1', ...options?.headers },
    });
  }

  /**
   * Delete an assistant file.
   */
  del(assistantId, fileId, options) {
    return this._client.delete(`/assistants/${assistantId}/files/${fileId}`, {
      ...options,
      headers: { 'OpenAI-Beta': 'assistants=v1', ...options?.headers },
    });
  }
}

const initializeClient = async ({ req, res, version, endpointOption, initAppClient = false }) => {
  const { PROXY, OPENAI_ORGANIZATION, AZURE_ASSISTANTS_API_KEY, AZURE_ASSISTANTS_BASE_URL } =
    process.env;

  const userProvidesKey = isUserProvided(AZURE_ASSISTANTS_API_KEY);
  const userProvidesURL = isUserProvided(AZURE_ASSISTANTS_BASE_URL);

  let userValues = null;
  if (userProvidesKey || userProvidesURL) {
    const expiresAt = await getUserKeyExpiry({
      userId: req.user.id,
      name: EModelEndpoint.azureAssistants,
    });
    checkUserKeyExpiry(expiresAt, EModelEndpoint.azureAssistants);
    userValues = await getUserKeyValues({
      userId: req.user.id,
      name: EModelEndpoint.azureAssistants,
    });
  }

  let apiKey = userProvidesKey ? userValues.apiKey : AZURE_ASSISTANTS_API_KEY;
  let baseURL = userProvidesURL ? userValues.baseURL : AZURE_ASSISTANTS_BASE_URL;

  const opts = {};

  const clientOptions = {
    reverseProxyUrl: baseURL ?? null,
    proxy: PROXY ?? null,
    req,
    res,
    ...endpointOption,
  };

  /** @type {TAzureConfig | undefined} */
  const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];

  /** @type {AzureOptions | undefined} */
  let azureOptions;

  if (azureConfig && azureConfig.assistants) {
    const { modelGroupMap, groupMap, assistantModels } = azureConfig;
    const modelName = req.body.model ?? req.query.model ?? assistantModels[0];
    const {
      azureOptions: currentOptions,
      baseURL: azureBaseURL,
      headers = {},
      serverless,
    } = mapModelToAzureConfig({
      modelName,
      modelGroupMap,
      groupMap,
    });

    azureOptions = currentOptions;

    baseURL = constructAzureURL({
      baseURL: azureBaseURL ?? 'https://${INSTANCE_NAME}.openai.azure.com/openai',
      azureOptions,
    });

    apiKey = azureOptions.azureOpenAIApiKey;
    opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion };
    opts.defaultHeaders = resolveHeaders({
      ...headers,
      'api-key': apiKey,
      'OpenAI-Beta': `assistants=${version}`,
    });
    opts.model = azureOptions.azureOpenAIApiDeploymentName;

    if (initAppClient) {
      clientOptions.titleConvo = azureConfig.titleConvo;
      clientOptions.titleModel = azureConfig.titleModel;
      clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';

      const groupName = modelGroupMap[modelName].group;
      clientOptions.addParams = azureConfig.groupMap[groupName].addParams;
      clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams;
      clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt;

      clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
      clientOptions.headers = opts.defaultHeaders;
      clientOptions.azure = !serverless && azureOptions;
    }
  }

  if (userProvidesKey && !apiKey) {
    throw new Error(
      JSON.stringify({
        type: ErrorTypes.NO_USER_KEY,
      }),
    );
  }

  if (!apiKey) {
    throw new Error('Assistants API key not provided. Please provide it again.');
  }

  if (baseURL) {
    opts.baseURL = baseURL;
  }

  if (PROXY) {
    opts.httpAgent = new HttpsProxyAgent(PROXY);
  }

  if (OPENAI_ORGANIZATION) {
    opts.organization = OPENAI_ORGANIZATION;
  }

  /** @type {OpenAIClient} */
  const openai = new OpenAI({
    apiKey,
    ...opts,
  });

  openai.beta.assistants.files = new Files(openai);

  openai.req = req;
  openai.res = res;

  if (azureOptions) {
    openai.locals = { ...(openai.locals ?? {}), azureOptions };
  }

  if (endpointOption && initAppClient) {
    const client = new OpenAIClient(apiKey, clientOptions);
    return {
      client,
      openai,
      openAIApiKey: apiKey,
    };
  }

  return {
    openai,
    openAIApiKey: apiKey,
  };
};

module.exports = initializeClient;
@@ -0,0 +1,112 @@
// const OpenAI = require('openai');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { ErrorTypes } = require('librechat-data-provider');
const { getUserKey, getUserKeyExpiry, getUserKeyValues } = require('~/server/services/UserService');
const initializeClient = require('./initializeClient');
// const { OpenAIClient } = require('~/app');

jest.mock('~/server/services/UserService', () => ({
  getUserKey: jest.fn(),
  getUserKeyExpiry: jest.fn(),
  getUserKeyValues: jest.fn(),
  checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry,
}));

const today = new Date();
const tenDaysFromToday = new Date(today.setDate(today.getDate() + 10));
const isoString = tenDaysFromToday.toISOString();

describe('initializeClient', () => {
  // Set up environment variables
  const originalEnvironment = process.env;
  const app = {
    locals: {},
  };

  beforeEach(() => {
    jest.resetModules(); // Clears the cache
    process.env = { ...originalEnvironment }; // Make a copy
  });

  afterAll(() => {
    process.env = originalEnvironment; // Restore original env vars
  });

  test('initializes OpenAI client with default API key and URL', async () => {
    process.env.AZURE_ASSISTANTS_API_KEY = 'default-api-key';
    process.env.AZURE_ASSISTANTS_BASE_URL = 'https://default.api.url';

    // Assuming 'isUserProvided' to return false for this test case
    jest.mock('~/server/utils', () => ({
      isUserProvided: jest.fn().mockReturnValueOnce(false),
    }));

    const req = { user: { id: 'user123' }, app };
    const res = {};

    const { openai, openAIApiKey } = await initializeClient({ req, res });
    expect(openai.apiKey).toBe('default-api-key');
    expect(openAIApiKey).toBe('default-api-key');
    expect(openai.baseURL).toBe('https://default.api.url');
  });

  test('initializes OpenAI client with user-provided API key and URL', async () => {
    process.env.AZURE_ASSISTANTS_API_KEY = 'user_provided';
    process.env.AZURE_ASSISTANTS_BASE_URL = 'user_provided';

    getUserKeyValues.mockResolvedValue({ apiKey: 'user-api-key', baseURL: 'https://user.api.url' });
    getUserKeyExpiry.mockResolvedValue(isoString);

    const req = { user: { id: 'user123' }, app };
    const res = {};

    const { openai, openAIApiKey } = await initializeClient({ req, res });
    expect(openAIApiKey).toBe('user-api-key');
    expect(openai.apiKey).toBe('user-api-key');
    expect(openai.baseURL).toBe('https://user.api.url');
  });

  test('throws error for invalid JSON in user-provided values', async () => {
    process.env.AZURE_ASSISTANTS_API_KEY = 'user_provided';
    getUserKey.mockResolvedValue('invalid-json');
    getUserKeyExpiry.mockResolvedValue(isoString);
    getUserKeyValues.mockImplementation(() => {
      let userValues = getUserKey();
      try {
        userValues = JSON.parse(userValues);
      } catch (e) {
        throw new Error(
          JSON.stringify({
            type: ErrorTypes.INVALID_USER_KEY,
          }),
        );
      }
      return userValues;
    });

    const req = { user: { id: 'user123' } };
    const res = {};

    await expect(initializeClient({ req, res })).rejects.toThrow(/invalid_user_key/);
  });

  test('throws error if API key is not provided', async () => {
    delete process.env.AZURE_ASSISTANTS_API_KEY; // Simulate missing API key

    const req = { user: { id: 'user123' }, app };
    const res = {};

    await expect(initializeClient({ req, res })).rejects.toThrow(/Assistants API key not/);
  });

  test('initializes OpenAI client with proxy configuration', async () => {
    process.env.AZURE_ASSISTANTS_API_KEY = 'test-key';
    process.env.PROXY = 'http://proxy.server';

    const req = { user: { id: 'user123' }, app };
    const res = {};

    const { openai } = await initializeClient({ req, res });
    expect(openai.httpAgent).toBeInstanceOf(HttpsProxyAgent);
  });
});
@@ -112,6 +112,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
    modelDisplayLabel: endpointConfig.modelDisplayLabel,
    titleMethod: endpointConfig.titleMethod ?? 'completion',
    contextStrategy: endpointConfig.summarize ? 'summarize' : null,
    directEndpoint: endpointConfig.directEndpoint,
    titleMessageRole: endpointConfig.titleMessageRole,
    endpointTokenConfig,
  };
48
api/server/services/Files/Audio/getVoices.js
Normal file
@@ -0,0 +1,48 @@
const { logger } = require('~/config');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getProvider } = require('./textToSpeech');

/**
 * This function retrieves the available voices for the current TTS provider.
 * It first fetches the TTS configuration and determines the provider.
 * Then, based on the provider, it sends the corresponding voices as a JSON response.
 *
 * @param {Object} req - The request object
 * @param {Object} res - The response object
 * @returns {Promise<void>}
 * @throws {Error} - If the provider is not 'openai', 'elevenlabs', or 'localai', an error is thrown
 */
async function getVoices(req, res) {
  try {
    const customConfig = await getCustomConfig();

    if (!customConfig || !customConfig?.tts) {
      throw new Error('Configuration or TTS schema is missing');
    }

    const ttsSchema = customConfig?.tts;
    const provider = getProvider(ttsSchema);
    let voices;

    switch (provider) {
      case 'openai':
        voices = ttsSchema.openai?.voices;
        break;
      case 'elevenlabs':
        voices = ttsSchema.elevenlabs?.voices;
        break;
      case 'localai':
        voices = ttsSchema.localai?.voices;
        break;
      default:
        throw new Error('Invalid provider');
    }

    res.json(voices);
  } catch (error) {
    logger.error(`Failed to get voices: ${error.message}`);
    res.status(500).json({ error: 'Failed to get voices' });
  }
}

module.exports = getVoices;
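A hedged client sketch for the voices lookup (the /api/files/tts/voices mount path follows from the route registrations earlier in this diff, but is not shown verbatim here):

// Sketch: populate a voice picker from the configured TTS provider.
async function loadVoices() {
  const res = await fetch('/api/files/tts/voices'); // assumed mount path
  if (!res.ok) {
    throw new Error('Failed to get voices');
  }
  return res.json(); // the configured provider's voices array
}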
11
api/server/services/Files/Audio/index.js
Normal file
@@ -0,0 +1,11 @@
const getVoices = require('./getVoices');
const textToSpeech = require('./textToSpeech');
const speechToText = require('./speechToText');
const { updateTokenWebsocket } = require('./webSocket');

module.exports = {
  getVoices,
  speechToText,
  ...textToSpeech,
  updateTokenWebsocket,
};
211
api/server/services/Files/Audio/speechToText.js
Normal file
@@ -0,0 +1,211 @@
const axios = require('axios');
const { Readable } = require('stream');
const { logger } = require('~/config');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { extractEnvVariable } = require('librechat-data-provider');

/**
 * Handle the response from the STT API
 * @param {Object} response - The response from the STT API
 *
 * @returns {string} The text from the response data
 *
 * @throws Will throw an error if the response status is not 200 or the response data is missing
 */
async function handleResponse(response) {
  if (response.status !== 200) {
    throw new Error('Invalid response from the STT API');
  }

  if (!response.data || !response.data.text) {
    throw new Error('Missing data in response from the STT API');
  }

  return response.data.text.trim();
}

function getProvider(sttSchema) {
  if (sttSchema.openai) {
    return 'openai';
  }

  throw new Error('Invalid provider');
}

function removeUndefined(obj) {
  Object.keys(obj).forEach((key) => {
    if (obj[key] && typeof obj[key] === 'object') {
      removeUndefined(obj[key]);
      if (Object.keys(obj[key]).length === 0) {
        delete obj[key];
      }
    } else if (obj[key] === undefined) {
      delete obj[key];
    }
  });
}

/**
 * This function prepares the necessary data and headers for making a request to the OpenAI API
 * It uses the provided speech-to-text schema and audio stream to create the request
 *
 * @param {Object} sttSchema - The speech-to-text schema containing the OpenAI configuration
 * @param {Stream} audioReadStream - The audio data to be transcribed
 *
 * @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
 * If an error occurs, it returns an array with three null values and logs the error with logger
 */
function openAIProvider(sttSchema, audioReadStream) {
  try {
    const url = sttSchema.openai?.url || 'https://api.openai.com/v1/audio/transcriptions';
    const apiKey = sttSchema.openai.apiKey ? extractEnvVariable(sttSchema.openai.apiKey) : '';

    let data = {
      file: audioReadStream,
      model: sttSchema.openai.model,
    };

    let headers = {
      'Content-Type': 'multipart/form-data',
    };

    [headers].forEach(removeUndefined);

    if (apiKey) {
      headers.Authorization = 'Bearer ' + apiKey;
    }

    return [url, data, headers];
  } catch (error) {
    logger.error('An error occurred while preparing the OpenAI API STT request: ', error);
    return [null, null, null];
  }
}

/**
 * This function prepares the necessary data and headers for making a request to the Azure API
 * It uses the provided request and audio stream to create the request
 *
 * @param {Object} req - The request object, which should contain the endpoint in its body
 * @param {Stream} audioReadStream - The audio data to be transcribed
 *
 * @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
 * If an error occurs, it returns an array with three null values and logs the error with logger
 */
function azureProvider(req, audioReadStream) {
  try {
    const { endpoint } = req.body;
    const azureConfig = req.app.locals[endpoint];

    if (!azureConfig) {
      throw new Error(`No configuration found for endpoint: ${endpoint}`);
    }

    const { apiKey, instanceName, whisperModel, apiVersion } = Object.entries(
      azureConfig.groupMap,
    ).reduce((acc, [, value]) => {
      if (acc) {
        return acc;
      }

      const whisperKey = Object.keys(value.models).find((modelKey) =>
        modelKey.startsWith('whisper'),
      );

      if (whisperKey) {
        return {
          apiVersion: value.version,
          apiKey: value.apiKey,
          instanceName: value.instanceName,
          whisperModel: value.models[whisperKey]['deploymentName'],
        };
      }

      return null;
    }, null);

    if (!apiKey || !instanceName || !whisperModel || !apiVersion) {
      throw new Error('Required Azure configuration values are missing');
    }

    const baseURL = `https://${instanceName}.openai.azure.com`;

    const url = `${baseURL}/openai/deployments/${whisperModel}/audio/transcriptions?api-version=${apiVersion}`;

    let data = {
      file: audioReadStream,
      filename: 'audio.wav',
      contentType: 'audio/wav',
      knownLength: audioReadStream.length,
    };

    const headers = {
      ...data.getHeaders(),
      'Content-Type': 'multipart/form-data',
      'api-key': apiKey,
    };

    return [url, data, headers];
  } catch (error) {
    logger.error('An error occurred while preparing the Azure API STT request: ', error);
    return [null, null, null];
  }
}

/**
 * Convert speech to text
 * @param {Object} req - The request object
 * @param {Object} res - The response object
 *
 * @returns {Object} The response object with the text from the STT API
 *
 * @throws Will throw an error if an error occurs while processing the audio
 */

async function speechToText(req, res) {
  const customConfig = await getCustomConfig();
  if (!customConfig) {
    return res.status(500).send('Custom config not found');
  }

  if (!req.file || !req.file.buffer) {
    return res.status(400).json({ message: 'No audio file provided in the FormData' });
  }

  const audioBuffer = req.file.buffer;
  const audioReadStream = Readable.from(audioBuffer);
  audioReadStream.path = 'audio.wav';

  const provider = getProvider(customConfig.stt);

  let [url, data, headers] = [];

  switch (provider) {
    case 'openai':
      [url, data, headers] = openAIProvider(customConfig.stt, audioReadStream);
      break;
    case 'azure':
      [url, data, headers] = azureProvider(req, audioReadStream);
      break;
    default:
      throw new Error('Invalid provider');
  }

  if (!Readable.from) {
    const audioBlob = new Blob([audioBuffer], { type: req.file.mimetype });
    delete data['file'];
    data['file'] = audioBlob;
  }

  try {
    const response = await axios.post(url, data, { headers: headers });
    const text = await handleResponse(response);

    res.json({ text });
  } catch (error) {
    logger.error('An error occurred while processing the audio:', error);
    res.sendStatus(500);
  }
}

module.exports = speechToText;
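The azureProvider branch calls data.getHeaders(), which presupposes the payload is a form-data instance rather than the plain object literal shown above. A hedged sketch of how the multipart body would be assembled with the form-data package (that package is not imported in this hunk; its use here is an assumption):

// Sketch: build a multipart payload whose getHeaders() supplies the boundary.
const FormData = require('form-data');

function buildSttForm(audioReadStream) {
  const form = new FormData();
  form.append('file', audioReadStream, {
    filename: 'audio.wav',
    contentType: 'audio/wav',
  });
  return form; // form.getHeaders() yields { 'content-type': 'multipart/form-data; boundary=...' }
}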
371
api/server/services/Files/Audio/streamAudio.js
Normal file
@@ -0,0 +1,371 @@
|
||||
const WebSocket = require('ws');
|
||||
const { Message } = require('~/models/Message');
|
||||
|
||||
/**
|
||||
* @param {string[]} voiceIds - Array of voice IDs
|
||||
* @returns {string}
|
||||
*/
|
||||
function getRandomVoiceId(voiceIds) {
|
||||
const randomIndex = Math.floor(Math.random() * voiceIds.length);
|
||||
return voiceIds[randomIndex];
|
||||
}
|
||||
|
||||
/**
|
||||
* @typedef {Object} VoiceSettings
|
||||
* @property {number} similarity_boost
|
||||
* @property {number} stability
|
||||
* @property {boolean} use_speaker_boost
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} GenerateAudioBulk
|
||||
* @property {string} model_id
|
||||
* @property {string} text
|
||||
* @property {VoiceSettings} voice_settings
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} TextToSpeechClient
|
||||
* @property {function(Object): Promise<stream.Readable>} generate
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} AudioChunk
|
||||
* @property {string} audio
|
||||
* @property {boolean} isFinal
|
||||
* @property {Object} alignment
|
||||
* @property {number[]} alignment.char_start_times_ms
|
||||
* @property {number[]} alignment.chars_durations_ms
|
||||
* @property {string[]} alignment.chars
|
||||
* @property {Object} normalizedAlignment
|
||||
* @property {number[]} normalizedAlignment.char_start_times_ms
|
||||
* @property {number[]} normalizedAlignment.chars_durations_ms
|
||||
* @property {string[]} normalizedAlignment.chars
|
||||
*/
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {Record<string, unknown | undefined>} parameters
|
||||
* @returns
|
||||
*/
|
||||
function assembleQuery(parameters) {
|
||||
let query = '';
|
||||
let hasQuestionMark = false;
|
||||
|
||||
for (const [key, value] of Object.entries(parameters)) {
|
||||
if (value == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!hasQuestionMark) {
|
||||
query += '?';
|
||||
hasQuestionMark = true;
|
||||
} else {
|
||||
query += '&';
|
||||
}
|
||||
|
||||
query += `${key}=${value}`;
|
||||
}
|
||||
|
||||
return query;
|
||||
}
|
||||
|
||||
const SEPARATORS = ['.', '?', '!', '۔', '。', '‥', ';', '¡', '¿', '\n'];
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {string} text
|
||||
* @param {string[] | undefined} [separators]
|
||||
* @returns
|
||||
*/
|
||||
function findLastSeparatorIndex(text, separators = SEPARATORS) {
|
||||
let lastIndex = -1;
|
||||
for (const separator of separators) {
|
||||
const index = text.lastIndexOf(separator);
|
||||
if (index > lastIndex) {
|
||||
lastIndex = index;
|
||||
}
|
||||
}
|
||||
return lastIndex;
|
||||
}
|
||||

const MAX_NOT_FOUND_COUNT = 6;
const MAX_NO_CHANGE_COUNT = 10;

/**
 * Creates a processor that polls the Message collection for new text and returns it in sentence-sized chunks.
 * @param {string} messageId
 * @returns {() => Promise<{ text: string, isFinished: boolean }[] | string>}
 */
function createChunkProcessor(messageId) {
  let notFoundCount = 0;
  let noChangeCount = 0;
  let processedText = '';
  if (!messageId) {
    throw new Error('Message ID is required');
  }

  /**
   * @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
   */
  async function processChunks() {
    if (notFoundCount >= MAX_NOT_FOUND_COUNT) {
      return `Message not found after ${MAX_NOT_FOUND_COUNT} attempts`;
    }

    if (noChangeCount >= MAX_NO_CHANGE_COUNT) {
      return `No change in message after ${MAX_NO_CHANGE_COUNT} attempts`;
    }

    const message = await Message.findOne({ messageId }, 'text unfinished').lean();

    if (!message || !message.text) {
      notFoundCount++;
      return [];
    }

    const { text, unfinished } = message;
    if (text === processedText) {
      noChangeCount++;
    }

    const remainingText = text.slice(processedText.length);
    const chunks = [];

    if (unfinished && remainingText.length >= 20) {
      const separatorIndex = findLastSeparatorIndex(remainingText);
      if (separatorIndex !== -1) {
        const chunkText = remainingText.slice(0, separatorIndex + 1);
        chunks.push({ text: chunkText, isFinished: false });
        processedText += chunkText;
      } else {
        chunks.push({ text: remainingText, isFinished: false });
        processedText = text;
      }
    } else if (!unfinished && remainingText.trim().length > 0) {
      chunks.push({ text: remainingText.trim(), isFinished: true });
      processedText = text;
    }

    return chunks;
  }

  return processChunks;
}
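
A minimal sketch of driving the returned poller (illustrative; the actual consumer is the `streamAudio` handler in `textToSpeech.js`, later in this diff):

```js
// Illustrative driver; the real consumer is streamAudio() in textToSpeech.js.
async function driveMessageChunks(messageId) {
  const processChunks = createChunkProcessor(messageId);
  for (;;) {
    const updates = await processChunks();
    if (typeof updates === 'string') {
      console.error(updates); // gave up: not found / unchanged too many times
      return;
    }
    for (const { text, isFinished } of updates) {
      // hand `text` to the TTS provider here
      if (isFinished) {
        return;
      }
    }
    await new Promise((r) => setTimeout(r, 1000)); // poll again shortly
  }
}
```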

/**
 * @param {string} text
 * @param {number} [chunkSize=4000]
 * @returns {{ text: string, isFinished: boolean }[]}
 */
function splitTextIntoChunks(text, chunkSize = 4000) {
  if (!text) {
    throw new Error('Text is required');
  }

  const chunks = [];
  let startIndex = 0;
  const textLength = text.length;

  while (startIndex < textLength) {
    let endIndex = Math.min(startIndex + chunkSize, textLength);
    let chunkText = text.slice(startIndex, endIndex);

    if (endIndex < textLength) {
      let lastSeparatorIndex = -1;
      for (const separator of SEPARATORS) {
        const index = chunkText.lastIndexOf(separator);
        if (index !== -1) {
          lastSeparatorIndex = Math.max(lastSeparatorIndex, index);
        }
      }

      if (lastSeparatorIndex !== -1) {
        endIndex = startIndex + lastSeparatorIndex + 1;
        chunkText = text.slice(startIndex, endIndex);
      } else {
        const nextSeparatorIndex = text.slice(endIndex).search(/\S/);
        if (nextSeparatorIndex !== -1) {
          endIndex += nextSeparatorIndex;
          chunkText = text.slice(startIndex, endIndex);
        }
      }
    }

    chunkText = chunkText.trim();
    if (chunkText) {
      chunks.push({
        text: chunkText,
        isFinished: endIndex >= textLength,
      });
    } else if (chunks.length > 0) {
      chunks[chunks.length - 1].isFinished = true;
    }

    startIndex = endIndex;
    while (startIndex < textLength && text[startIndex].trim() === '') {
      startIndex++;
    }
  }

  return chunks;
}

/**
 * Input stream text to speech
 * @param {Express.Response} res
 * @param {AsyncIterable<string>} textStream
 * @param {(token: string) => Promise<boolean>} callback - Whether to continue the stream or not
 * @returns {AsyncGenerator<AudioChunk>}
 */
function inputStreamTextToSpeech(res, textStream, callback) {
  const model = 'eleven_monolingual_v1';
  const wsUrl = `wss://api.elevenlabs.io/v1/text-to-speech/${getRandomVoiceId()}/stream-input${assembleQuery(
    {
      model_id: model,
      // flush: true,
      // optimize_streaming_latency: this.settings.optimizeStreamingLatency,
      optimize_streaming_latency: 1,
      // output_format: this.settings.outputFormat,
    },
  )}`;
  const socket = new WebSocket(wsUrl);

  socket.onopen = function () {
    const streamStart = {
      text: ' ',
      voice_settings: {
        stability: 0.5,
        similarity_boost: 0.8,
      },
      xi_api_key: process.env.ELEVENLABS_API_KEY,
      // generation_config: { chunk_length_schedule: [50, 90, 120, 150, 200] },
    };

    socket.send(JSON.stringify(streamStart));

    // send stream until done
    const streamComplete = new Promise((resolve, reject) => {
      (async () => {
        let textBuffer = '';
        let shouldContinue = true;
        for await (const textDelta of textStream) {
          textBuffer += textDelta;

          // using ". " as separator: sending in full sentences improves the quality
          // of the audio output significantly.
          const separatorIndex = findLastSeparatorIndex(textBuffer);

          // Callback for textStream (will return false if signal is aborted)
          shouldContinue = await callback(textDelta);

          if (separatorIndex === -1) {
            continue;
          }

          if (!shouldContinue) {
            break;
          }

          const textToProcess = textBuffer.slice(0, separatorIndex);
          textBuffer = textBuffer.slice(separatorIndex + 1);

          const request = {
            text: textToProcess,
            try_trigger_generation: true,
          };

          socket.send(JSON.stringify(request));
        }

        // send remaining text:
        if (shouldContinue && textBuffer.length > 0) {
          socket.send(
            JSON.stringify({
              text: `${textBuffer} `, // append space
              try_trigger_generation: true,
            }),
          );
        }
      })()
        .then(resolve)
        .catch(reject);
    });

    streamComplete
      .then(() => {
        const endStream = {
          text: '',
        };

        socket.send(JSON.stringify(endStream));
      })
      .catch((e) => {
        console.error('Error streaming text to speech:', e);
        throw e;
      });
  };

  return (async function* audioStream() {
    let isDone = false;
    let chunks = [];
    let resolve;
    let waitForMessage = new Promise((r) => (resolve = r));

    socket.onmessage = function (event) {
      // console.log(event);
      const audioChunk = JSON.parse(event.data);
      if (audioChunk.audio && audioChunk.alignment) {
        res.write(`event: audio\ndata: ${event.data}\n\n`);
        chunks.push(audioChunk);
        resolve(null);
        waitForMessage = new Promise((r) => (resolve = r));
      } else if (audioChunk.isFinal) {
        isDone = true;
        resolve(null);
      } else if (audioChunk.message) {
        console.warn('Received Elevenlabs message:', audioChunk.message);
        resolve(null);
      }
    };

    socket.onerror = function (error) {
      console.error('WebSocket error:', error);
      // throw error;
    };

    socket.onclose = function () {
      isDone = true;
      resolve(null);
    };

    while (!isDone) {
      await waitForMessage;
      yield* chunks;
      chunks = [];
    }

    res.write('event: end\ndata: \n\n');
  })();
}
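
The returned async generator can be consumed with `for await`; a sketch (illustrative, not from the PR; assumes an SSE-style `res`, any async iterable of text deltas such as `llmMessageSource` below, and `ELEVENLABS_API_KEY` set):

```js
// Illustrative consumer — not part of the PR.
async function speak(res, textStream) {
  const keepGoing = async () => true; // no abort signal in this sketch
  for await (const chunk of inputStreamTextToSpeech(res, textStream, keepGoing)) {
    // chunk.audio is base64-encoded audio; alignment maps characters to timings
    console.log('received', Buffer.from(chunk.audio, 'base64').length, 'audio bytes');
  }
}
```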

/**
 * Yields the text deltas of an LLM completion stream.
 * @param {AsyncIterable<string>} llmStream
 */
async function* llmMessageSource(llmStream) {
  for await (const chunk of llmStream) {
    const message = chunk.choices[0].delta.content;
    if (message) {
      yield message;
    }
  }
}

module.exports = {
  inputStreamTextToSpeech,
  findLastSeparatorIndex,
  createChunkProcessor,
  splitTextIntoChunks,
  llmMessageSource,
  getRandomVoiceId,
};
137 api/server/services/Files/Audio/streamAudio.spec.js (Normal file)
@@ -0,0 +1,137 @@
const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
const { Message } = require('~/models/Message');

jest.mock('~/models/Message', () => ({
  Message: {
    findOne: jest.fn().mockReturnValue({
      lean: jest.fn(),
    }),
  },
}));

describe('processChunks', () => {
  let processChunks;

  beforeEach(() => {
    processChunks = createChunkProcessor('message-id');
    Message.findOne.mockClear();
    Message.findOne().lean.mockClear();
  });

  it('should return an empty array when the message is not found', async () => {
    Message.findOne().lean.mockResolvedValueOnce(null);

    const result = await processChunks();

    expect(result).toEqual([]);
    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
    expect(Message.findOne().lean).toHaveBeenCalled();
  });

  it('should return an empty array when the message does not have a text property', async () => {
    Message.findOne().lean.mockResolvedValueOnce({ unfinished: true });

    const result = await processChunks();

    expect(result).toEqual([]);
    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
    expect(Message.findOne().lean).toHaveBeenCalled();
  });

  it('should return chunks for an unfinished message with separators', async () => {
    const messageText = 'This is a long message. It should be split into chunks. Lol hi mom';
    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });

    const result = await processChunks();

    expect(result).toEqual([
      { text: 'This is a long message. It should be split into chunks.', isFinished: false },
    ]);
    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
    expect(Message.findOne().lean).toHaveBeenCalled();
  });

  it('should return chunks for an unfinished message without separators', async () => {
    const messageText = 'This is a long message without separators hello there my friend';
    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });

    const result = await processChunks();

    expect(result).toEqual([{ text: messageText, isFinished: false }]);
    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
    expect(Message.findOne().lean).toHaveBeenCalled();
  });

  it('should return the remaining text as a chunk for a finished message', async () => {
    const messageText = 'This is a finished message.';
    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });

    const result = await processChunks();

    expect(result).toEqual([{ text: messageText, isFinished: true }]);
    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
    expect(Message.findOne().lean).toHaveBeenCalled();
  });

  it('should return an empty array for a finished message with no remaining text', async () => {
    const messageText = 'This is a finished message.';
    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });

    await processChunks();
    Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
    const result = await processChunks();

    expect(result).toEqual([]);
    expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
    expect(Message.findOne().lean).toHaveBeenCalledTimes(2);
  });
});

describe('splitTextIntoChunks', () => {
  test('splits text into chunks of specified size with default separators', () => {
    const text = 'This is a test. This is only a test! Make sure it works properly? Okay.';
    const chunkSize = 20;
    const expectedChunks = [
      { text: 'This is a test.', isFinished: false },
      { text: 'This is only a test!', isFinished: false },
      { text: 'Make sure it works p', isFinished: false },
      { text: 'roperly? Okay.', isFinished: true },
    ];

    const result = splitTextIntoChunks(text, chunkSize);
    expect(result).toEqual(expectedChunks);
  });

  test('splits text into chunks with default size', () => {
    const text = 'A'.repeat(8000) + '. The end.';
    const expectedChunks = [
      { text: 'A'.repeat(4000), isFinished: false },
      { text: 'A'.repeat(4000), isFinished: false },
      { text: '. The end.', isFinished: true },
    ];

    const result = splitTextIntoChunks(text);
    expect(result).toEqual(expectedChunks);
  });

  test('returns a single chunk if text length is less than chunk size', () => {
    const text = 'Short text.';
    const expectedChunks = [{ text: 'Short text.', isFinished: true }];

    const result = splitTextIntoChunks(text, 4000);
    expect(result).toEqual(expectedChunks);
  });

  test('handles text with no separators correctly', () => {
    const text = 'ThisTextHasNoSeparatorsAndIsVeryLong'.repeat(100);
    const chunkSize = 4000;
    const expectedChunks = [{ text: text, isFinished: true }];

    const result = splitTextIntoChunks(text, chunkSize);
    expect(result).toEqual(expectedChunks);
  });

  test('throws an error when text is empty', () => {
    expect(() => splitTextIntoChunks('')).toThrow('Text is required');
  });
});
416 api/server/services/Files/Audio/textToSpeech.js (Normal file)
@@ -0,0 +1,416 @@
const axios = require('axios');
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
const { extractEnvVariable } = require('librechat-data-provider');
const { logger } = require('~/config');

/**
 * getProvider function
 * This function takes the ttsSchema object and returns the name of the provider.
 * If more than one provider is set or no provider is set, it throws an error.
 *
 * @param {Object} ttsSchema - The TTS schema containing the provider configuration
 * @returns {string} The name of the provider
 * @throws {Error} Throws an error if multiple providers are set or no provider is set
 */
function getProvider(ttsSchema) {
  if (!ttsSchema) {
    throw new Error(`No TTS schema is set. Did you configure TTS in the custom config (librechat.yaml)?

    https://www.librechat.ai/docs/configuration/stt_tts#tts`);
  }
  const providers = Object.entries(ttsSchema).filter(([, value]) => Object.keys(value).length > 0);

  if (providers.length > 1) {
    throw new Error('Multiple providers are set. Please set only one provider.');
  } else if (providers.length === 0) {
    throw new Error('No provider is set. Please set a provider.');
  } else {
    return providers[0][0];
  }
}
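
For reference, `getProvider` expects exactly one non-empty provider block under `tts`; a sketch of the shape it validates (keys follow the `switch` in `ttsRequest` below; the values are placeholders, not a definitive config):

```js
// Placeholder values — the real schema comes from librechat.yaml.
const customConfig = {
  tts: {
    openai: { apiKey: '${TTS_API_KEY}', model: 'tts-1', voices: ['alloy', 'echo'] },
    elevenlabs: {}, // empty blocks don't count as "set"
  },
};
getProvider(customConfig.tts); // => 'openai'
```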

/**
 * removeUndefined function
 * This function takes an object and removes all keys with undefined values.
 * It also removes keys with empty objects as values.
 *
 * @param {Object} obj - The object to be cleaned
 * @returns {void} This function does not return a value. It modifies the input object directly
 */
function removeUndefined(obj) {
  Object.keys(obj).forEach((key) => {
    if (obj[key] && typeof obj[key] === 'object') {
      removeUndefined(obj[key]);
      if (Object.keys(obj[key]).length === 0) {
        delete obj[key];
      }
    } else if (obj[key] === undefined) {
      delete obj[key];
    }
  });
}
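
Its in-place behavior, including pruning of nested objects that end up empty (illustrative example):

```js
const payload = { a: 1, b: undefined, c: { d: undefined } };
removeUndefined(payload);
// payload is now { a: 1 }: `b` was undefined, and `c` was deleted once emptied
```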

/**
 * This function prepares the necessary data and headers for making a request to the OpenAI TTS API.
 * It uses the provided TTS schema, input text, and voice to create the request.
 *
 * @param {TCustomConfig['tts']['openai']} ttsSchema - The TTS schema containing the OpenAI configuration
 * @param {string} input - The text to be converted to speech
 * @param {string} voice - The voice to be used for the speech
 *
 * @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
 * @throws {Error} Throws an error if the selected voice is not available
 */
function openAIProvider(ttsSchema, input, voice) {
  const url = ttsSchema?.url || 'https://api.openai.com/v1/audio/speech';

  if (
    ttsSchema?.voices &&
    ttsSchema.voices.length > 0 &&
    !ttsSchema.voices.includes(voice) &&
    !ttsSchema.voices.includes('ALL')
  ) {
    throw new Error(`Voice ${voice} is not available.`);
  }

  let data = {
    input,
    model: ttsSchema?.model,
    voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
    backend: ttsSchema?.backend,
  };

  let headers = {
    'Content-Type': 'application/json',
    Authorization: 'Bearer ' + extractEnvVariable(ttsSchema?.apiKey),
  };

  [data, headers].forEach(removeUndefined);

  return [url, data, headers];
}

/**
 * elevenLabsProvider function
 * This function prepares the necessary data and headers for making a request to the ElevenLabs TTS API.
 * It uses the provided TTS schema, input text, and voice to create the request.
 *
 * @param {TCustomConfig['tts']['elevenLabs']} ttsSchema - The TTS schema containing the ElevenLabs configuration
 * @param {string} input - The text to be converted to speech
 * @param {string} voice - The voice to be used for the speech
 * @param {boolean} stream - Whether to stream the audio or not
 *
 * @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
 * @throws {Error} Throws an error if the selected voice is not available
 */
function elevenLabsProvider(ttsSchema, input, voice, stream) {
  let url =
    ttsSchema?.url ||
    `https://api.elevenlabs.io/v1/text-to-speech/{voice_id}${stream ? '/stream' : ''}`;

  if (!ttsSchema?.voices.includes(voice) && !ttsSchema?.voices.includes('ALL')) {
    throw new Error(`Voice ${voice} is not available.`);
  }

  url = url.replace('{voice_id}', voice);

  let data = {
    model_id: ttsSchema?.model,
    text: input,
    // voice_id: voice,
    voice_settings: {
      similarity_boost: ttsSchema?.voice_settings?.similarity_boost,
      stability: ttsSchema?.voice_settings?.stability,
      style: ttsSchema?.voice_settings?.style,
      use_speaker_boost: ttsSchema?.voice_settings?.use_speaker_boost || undefined,
    },
    pronunciation_dictionary_locators: ttsSchema?.pronunciation_dictionary_locators,
  };

  let headers = {
    'Content-Type': 'application/json',
    'xi-api-key': extractEnvVariable(ttsSchema?.apiKey),
    Accept: 'audio/mpeg',
  };

  [data, headers].forEach(removeUndefined);

  return [url, data, headers];
}

/**
 * localAIProvider function
 * This function prepares the necessary data and headers for making a request to the LocalAI TTS API.
 * It uses the provided TTS schema, input text, and voice to create the request.
 *
 * @param {TCustomConfig['tts']['localai']} ttsSchema - The TTS schema containing the LocalAI configuration
 * @param {string} input - The text to be converted to speech
 * @param {string} voice - The voice to be used for the speech
 *
 * @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
 * @throws {Error} Throws an error if the selected voice is not available
 */
function localAIProvider(ttsSchema, input, voice) {
  let url = ttsSchema?.url;

  if (
    ttsSchema?.voices &&
    ttsSchema.voices.length > 0 &&
    !ttsSchema.voices.includes(voice) &&
    !ttsSchema.voices.includes('ALL')
  ) {
    throw new Error(`Voice ${voice} is not available.`);
  }

  let data = {
    input,
    model: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
    backend: ttsSchema?.backend,
  };

  let headers = {
    'Content-Type': 'application/json',
    Authorization: 'Bearer ' + extractEnvVariable(ttsSchema?.apiKey),
  };

  [data, headers].forEach(removeUndefined);

  if (extractEnvVariable(ttsSchema.apiKey) === '') {
    delete headers.Authorization;
  }

  return [url, data, headers];
}

/**
 * Returns the provider name and its schema for use with TTS requests.
 * @param {TCustomConfig} customConfig
 * @returns {Promise<[string, TProviderSchema]>}
 */
async function getProviderSchema(customConfig) {
  const provider = getProvider(customConfig.tts);
  return [provider, customConfig.tts[provider]];
}

/**
 * Resolves the voice to use for the TTS request, falling back to a random configured voice
 * when the requested one is missing, not configured, or the 'ALL' wildcard.
 * @param {TProviderSchema} providerSchema
 * @param {string} requestVoice
 * @returns {Promise<string>}
 */
async function getVoice(providerSchema, requestVoice) {
  const voices = providerSchema.voices.filter((voice) => voice && voice.toUpperCase() !== 'ALL');
  let voice = requestVoice;
  if (!voice || !voices.includes(voice) || (voice.toUpperCase() === 'ALL' && voices.length > 1)) {
    voice = getRandomVoiceId(voices);
  }

  return voice;
}
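
So a requested voice is honored only when it is configured; anything else falls back to a random configured voice. An illustrative example ('nova' is a deliberately unconfigured placeholder, and these calls assume an async context):

```js
const schema = { voices: ['alloy', 'echo', 'ALL'] };
await getVoice(schema, 'echo'); // => 'echo'
await getVoice(schema, 'nova'); // => 'alloy' or 'echo', picked at random
```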

/**
 * Builds and sends the TTS request for the given provider.
 * @param {string} provider
 * @param {TProviderSchema} ttsSchema
 * @param {object} params
 * @param {string} params.voice
 * @param {string} params.input
 * @param {boolean} [params.stream]
 * @returns {Promise<AxiosResponse>} The axios response; `data` is a stream or an ArrayBuffer.
 */
async function ttsRequest(provider, ttsSchema, { input, voice, stream = true } = { stream: true }) {
  let [url, data, headers] = [];
  switch (provider) {
    case 'openai':
      [url, data, headers] = openAIProvider(ttsSchema, input, voice);
      break;
    case 'elevenlabs':
      [url, data, headers] = elevenLabsProvider(ttsSchema, input, voice, stream);
      break;
    case 'localai':
      [url, data, headers] = localAIProvider(ttsSchema, input, voice);
      break;
    default:
      throw new Error('Invalid provider');
  }

  if (stream) {
    return await axios.post(url, data, { headers, responseType: 'stream' });
  }

  return await axios.post(url, data, { headers, responseType: 'arraybuffer' });
}
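
A hedged sketch of wiring the pieces above together inside an async handler (`req` and `someWritable` are stand-ins, e.g. an Express request and response):

```js
// Illustrative wiring of getProviderSchema -> getVoice -> ttsRequest.
const customConfig = await getCustomConfig();
const [provider, ttsSchema] = await getProviderSchema(customConfig);
const voice = await getVoice(ttsSchema, req.body.voice);
const response = await ttsRequest(provider, ttsSchema, {
  input: 'Hello there.',
  voice,
  stream: true, // axios resolves with a readable stream in response.data
});
response.data.pipe(someWritable);
```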

/**
 * Handles a text-to-speech request. Extracts input and voice from the request, retrieves the TTS configuration,
 * and sends a request to the appropriate provider. The resulting audio data is sent in the response.
 *
 * @param {Object} req - The request object, which should contain the input text and voice in its body
 * @param {Object} res - The response object, used to send the audio data or an error message
 *
 * @returns {Promise<void>} This function does not return a value. It sends the audio data or an error message in the response
 *
 * @throws {Error} Throws an error if the provider is invalid
 */
async function textToSpeech(req, res) {
  const { input } = req.body;

  if (!input) {
    return res.status(400).send('Missing text in request body');
  }

  const customConfig = await getCustomConfig();
  if (!customConfig) {
    return res.status(500).send('Custom config not found');
  }

  try {
    res.setHeader('Content-Type', 'audio/mpeg');
    const [provider, ttsSchema] = await getProviderSchema(customConfig);
    const voice = await getVoice(ttsSchema, req.body.voice);
    if (input.length < 4096) {
      const response = await ttsRequest(provider, ttsSchema, { input, voice });
      response.data.pipe(res);
      return;
    }

    const textChunks = splitTextIntoChunks(input, 1000);

    for (const chunk of textChunks) {
      try {
        const response = await ttsRequest(provider, ttsSchema, {
          voice,
          input: chunk.text,
          stream: true,
        });

        logger.debug(`[textToSpeech] user: ${req?.user?.id} | writing audio stream`);
        await new Promise((resolve) => {
          response.data.pipe(res, { end: chunk.isFinished });
          response.data.on('end', () => {
            resolve();
          });
        });

        if (chunk.isFinished) {
          break;
        }
      } catch (innerError) {
        logger.error('Error processing manual update:', chunk, innerError);
        if (!res.headersSent) {
          res.status(500).end();
        }
        return;
      }
    }

    if (!res.headersSent) {
      res.end();
    }
  } catch (error) {
    logger.error(
      'Error creating the audio stream. Suggestion: check your provider quota. Error:',
      error,
    );
    res.status(500).send('An error occurred');
  }
}

async function streamAudio(req, res) {
  res.setHeader('Content-Type', 'audio/mpeg');
  const customConfig = await getCustomConfig();
  if (!customConfig) {
    return res.status(500).send('Custom config not found');
  }

  const [provider, ttsSchema] = await getProviderSchema(customConfig);
  const voice = await getVoice(ttsSchema, req.body.voice);

  try {
    let shouldContinue = true;

    req.on('close', () => {
      logger.warn('[streamAudio] Audio Stream Request closed by client');
      shouldContinue = false;
    });

    const processChunks = createChunkProcessor(req.body.messageId);

    while (shouldContinue) {
      // example updates
      // const updates = [
      //   { text: 'This is a test.', isFinished: false },
      //   { text: 'This is only a test.', isFinished: false },
      //   { text: 'Your voice is like a combination of Fergie and Jesus!', isFinished: true },
      // ];

      const updates = await processChunks();
      if (typeof updates === 'string') {
        logger.error(`Error processing audio stream updates: ${JSON.stringify(updates)}`);
        res.status(500).end();
        return;
      }

      if (updates.length === 0) {
        await new Promise((resolve) => setTimeout(resolve, 1250));
        continue;
      }

      for (const update of updates) {
        try {
          const response = await ttsRequest(provider, ttsSchema, {
            voice,
            input: update.text,
            stream: true,
          });

          if (!shouldContinue) {
            break;
          }

          logger.debug(`[streamAudio] user: ${req?.user?.id} | writing audio stream`);
          await new Promise((resolve) => {
            response.data.pipe(res, { end: update.isFinished });
            response.data.on('end', () => {
              resolve();
            });
          });

          if (update.isFinished) {
            shouldContinue = false;
            break;
          }
        } catch (innerError) {
          logger.error('Error processing update:', update, innerError);
          if (!res.headersSent) {
            res.status(500).end();
          }
          return;
        }
      }

      if (!shouldContinue) {
        break;
      }
    }

    if (!res.headersSent) {
      res.end();
    }
  } catch (error) {
    logger.error('Failed to fetch audio:', error);
    if (!res.headersSent) {
      res.status(500).end();
    }
  }
}

module.exports = {
  textToSpeech,
  getProvider,
  streamAudio,
};
31 api/server/services/Files/Audio/webSocket.js (Normal file)
@@ -0,0 +1,31 @@
const WebSocket = require('ws'); // needed for WebSocket.OPEN below (missing in the file as written)

let token = '';

function updateTokenWebsocket(newToken) {
  console.log('Token:', newToken);
  token = newToken;
}

function sendTextToWebsocket(ws, onDataReceived) {
  if (token === '[DONE]') {
    ws.send(' ');
    return;
  }

  if (ws.readyState === WebSocket.OPEN) {
    ws.send(token);

    ws.onmessage = function (event) {
      console.log('Received:', event.data);
      if (onDataReceived) {
        onDataReceived(event.data); // Pass the received data to the callback function
      }
    };
  } else {
    console.error('WebSocket is not open. Ready state is: ' + ws.readyState);
  }
}

module.exports = {
  updateTokenWebsocket,
  sendTextToWebsocket,
};
@@ -180,7 +180,15 @@ const deleteFirebaseFile = async (req, file) => {
   if (!fileName.includes(req.user.id)) {
     throw new Error('Invalid file path');
   }
-  await deleteFile('', fileName);
+  try {
+    await deleteFile('', fileName);
+  } catch (error) {
+    logger.error('Error deleting file from Firebase:', error);
+    if (error.code === 'storage/object-not-found') {
+      return;
+    }
+    throw error;
+  }
 };

 /**
@@ -14,9 +14,11 @@ const { logger } = require('~/config');
  * @returns {Promise<OpenAIFile>}
  */
 async function uploadOpenAIFile({ req, file, openai }) {
+  const { height, width } = req.body;
+  const isImage = height && width;
   const uploadedFile = await openai.files.create({
     file: fs.createReadStream(file.path),
-    purpose: FilePurpose.Assistants,
+    purpose: isImage ? FilePurpose.Vision : FilePurpose.Assistants,
   });

   logger.debug(
@@ -34,7 +36,7 @@ async function uploadOpenAIFile({ req, file, openai }) {
     await sleep(sleepTime);
   }

-  return uploadedFile;
+  return isImage ? { ...uploadedFile, height, width } : uploadedFile;
 }

 /**
@@ -10,10 +10,13 @@ const {
   EModelEndpoint,
   mergeFileConfig,
   hostImageIdSuffix,
+  checkOpenAIStorage,
   hostImageNamePrefix,
+  isAssistantsEndpoint,
 } = require('librechat-data-provider');
+const { addResourceFileId, deleteResourceFileId } = require('~/server/controllers/assistants/v2');
 const { convertImage, resizeAndConvert } = require('~/server/services/Files/images');
-const { initializeClient } = require('~/server/services/Endpoints/assistants');
+const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
 const { createFile, updateFileUsage, deleteFiles } = require('~/models/File');
 const { LB_QueueAsyncCall } = require('~/server/utils/queue');
 const { getStrategyFunctions } = require('./strategies');
@@ -34,14 +37,16 @@ const processFiles = async (files) => {
 /**
  * Enqueues the delete operation to the leaky bucket queue if necessary, or adds it directly to promises.
  *
- * @param {Express.Request} req - The express request object.
- * @param {MongoFile} file - The file object to delete.
- * @param {Function} deleteFile - The delete file function.
- * @param {Promise[]} promises - The array of promises to await.
- * @param {OpenAI | undefined} [openai] - If an OpenAI file, the initialized OpenAI client.
+ * @param {object} params - The passed parameters.
+ * @param {Express.Request} params.req - The express request object.
+ * @param {MongoFile} params.file - The file object to delete.
+ * @param {Function} params.deleteFile - The delete file function.
+ * @param {Promise[]} params.promises - The array of promises to await.
+ * @param {string[]} params.resolvedFileIds - The array of successfully deleted file IDs.
+ * @param {OpenAI | undefined} [params.openai] - If an OpenAI file, the initialized OpenAI client.
  */
-function enqueueDeleteOperation(req, file, deleteFile, promises, openai) {
-  if (file.source === FileSources.openai) {
+function enqueueDeleteOperation({ req, file, deleteFile, promises, resolvedFileIds, openai }) {
+  if (checkOpenAIStorage(file.source)) {
     // Enqueue to leaky bucket
     promises.push(
       new Promise((resolve, reject) => {
@@ -53,6 +58,7 @@ function enqueueDeleteOperation(req, file, deleteFile, promises, openai) {
             logger.error('Error deleting file from OpenAI source', err);
             reject(err);
           } else {
+            resolvedFileIds.push(file.file_id);
             resolve(result);
           }
         },
@@ -62,10 +68,12 @@ function enqueueDeleteOperation(req, file, deleteFile, promises, openai) {
   } else {
     // Add directly to promises
     promises.push(
-      deleteFile(req, file).catch((err) => {
-        logger.error('Error deleting file', err);
-        return Promise.reject(err);
-      }),
+      deleteFile(req, file)
+        .then(() => resolvedFileIds.push(file.file_id))
+        .catch((err) => {
+          logger.error('Error deleting file', err);
+          return Promise.reject(err);
+        }),
     );
   }
 }
@@ -80,35 +88,71 @@ function enqueueDeleteOperation(req, file, deleteFile, promises, openai) {
  * @param {Express.Request} params.req - The express request object.
  * @param {DeleteFilesBody} params.req.body - The request body.
  * @param {string} [params.req.body.assistant_id] - The assistant ID if file uploaded is associated to an assistant.
+ * @param {string} [params.req.body.tool_resource] - The tool resource if assistant file uploaded is associated to a tool resource.
  *
  * @returns {Promise<void>}
  */
 const processDeleteRequest = async ({ req, files }) => {
   const file_ids = files.map((file) => file.file_id);

+  const resolvedFileIds = [];
   const deletionMethods = {};
   const promises = [];
   promises.push(deleteFiles(file_ids));

-  /** @type {OpenAI | undefined} */
-  let openai;
-  if (req.body.assistant_id) {
-    ({ openai } = await initializeClient({ req }));
+  /** @type {Record<string, OpenAI | undefined>} */
+  const client = { [FileSources.openai]: undefined, [FileSources.azure]: undefined };
+  const initializeClients = async () => {
+    const openAIClient = await getOpenAIClient({
+      req,
+      overrideEndpoint: EModelEndpoint.assistants,
+    });
+    client[FileSources.openai] = openAIClient.openai;
+
+    if (!req.app.locals[EModelEndpoint.azureOpenAI]?.assistants) {
+      return;
+    }
+
+    const azureClient = await getOpenAIClient({
+      req,
+      overrideEndpoint: EModelEndpoint.azureAssistants,
+    });
+    client[FileSources.azure] = azureClient.openai;
+  };
+
+  if (req.body.assistant_id !== undefined) {
+    await initializeClients();
   }

   for (const file of files) {
     const source = file.source ?? FileSources.local;

-    if (source === FileSources.openai && !openai) {
-      ({ openai } = await initializeClient({ req }));
+    if (checkOpenAIStorage(source) && !client[source]) {
+      await initializeClients();
     }

-    if (req.body.assistant_id) {
+    const openai = client[source];
+
+    if (req.body.assistant_id && req.body.tool_resource) {
+      promises.push(
+        deleteResourceFileId({
+          req,
+          openai,
+          file_id: file.file_id,
+          assistant_id: req.body.assistant_id,
+          tool_resource: req.body.tool_resource,
+        }),
+      );
+    } else if (req.body.assistant_id) {
       promises.push(openai.beta.assistants.files.del(req.body.assistant_id, file.file_id));
     }

     if (deletionMethods[source]) {
-      enqueueDeleteOperation(req, file, deletionMethods[source], promises, openai);
+      enqueueDeleteOperation({
+        req,
+        file,
+        deleteFile: deletionMethods[source],
+        promises,
+        resolvedFileIds,
+        openai,
+      });
       continue;
     }
@@ -118,10 +162,11 @@ const processDeleteRequest = async ({ req, files }) => {
     }

     deletionMethods[source] = deleteFile;
-    enqueueDeleteOperation(req, file, deleteFile, promises, openai);
+    enqueueDeleteOperation({ req, file, deleteFile, promises, resolvedFileIds, openai });
   }

   await Promise.allSettled(promises);
+  await deleteFiles(resolvedFileIds);
 };

 /**
@@ -180,12 +225,13 @@ const processFileURL = async ({ fileStrategy, userId, URL, fileName, basePath, c
  *
  * @param {Object} params - The parameters object.
  * @param {Express.Request} params.req - The Express request object.
- * @param {Express.Response} params.res - The Express response object.
+ * @param {Express.Response} [params.res] - The Express response object.
  * @param {Express.Multer.File} params.file - The uploaded file.
  * @param {ImageMetadata} params.metadata - Additional metadata for the file.
+ * @param {boolean} params.returnFile - Whether to return the file metadata or return response as normal.
  * @returns {Promise<void>}
  */
-const processImageFile = async ({ req, res, file, metadata }) => {
+const processImageFile = async ({ req, res, file, metadata, returnFile = false }) => {
   const source = req.app.locals.fileStrategy;
   const { handleImageUpload } = getStrategyFunctions(source);
   const { file_id, temp_file_id, endpoint } = metadata;
@@ -213,6 +259,10 @@ const processImageFile = async ({ req, res, file, metadata }) => {
     },
     true,
   );
+
+  if (returnFile) {
+    return result;
+  }
   res.status(200).json({ message: 'File uploaded and processed successfully', ...result });
 };
@@ -274,28 +324,57 @@ const uploadImageBuffer = async ({ req, context, metadata = {}, resize = true })
  * @returns {Promise<void>}
  */
 const processFileUpload = async ({ req, res, file, metadata }) => {
-  const isAssistantUpload = metadata.endpoint === EModelEndpoint.assistants;
-  const source = isAssistantUpload ? FileSources.openai : FileSources.vectordb;
+  const isAssistantUpload = isAssistantsEndpoint(metadata.endpoint);
+  const assistantSource =
+    metadata.endpoint === EModelEndpoint.azureAssistants ? FileSources.azure : FileSources.openai;
+  const source = isAssistantUpload ? assistantSource : FileSources.vectordb;
   const { handleFileUpload } = getStrategyFunctions(source);
   const { file_id, temp_file_id } = metadata;

   /** @type {OpenAI | undefined} */
   let openai;
-  if (source === FileSources.openai) {
-    ({ openai } = await initializeClient({ req }));
+  if (checkOpenAIStorage(source)) {
+    ({ openai } = await getOpenAIClient({ req }));
   }

-  const { id, bytes, filename, filepath, embedded } = await handleFileUpload({
+  const {
+    id,
+    bytes,
+    filename,
+    filepath: _filepath,
+    embedded,
+    height,
+    width,
+  } = await handleFileUpload({
     req,
     file,
     file_id,
     openai,
   });

-  if (isAssistantUpload && !metadata.message_file) {
+  if (isAssistantUpload && !metadata.message_file && !metadata.tool_resource) {
     await openai.beta.assistants.files.create(metadata.assistant_id, {
       file_id: id,
     });
+  } else if (isAssistantUpload && !metadata.message_file) {
+    await addResourceFileId({
+      req,
+      openai,
+      file_id: id,
+      assistant_id: metadata.assistant_id,
+      tool_resource: metadata.tool_resource,
+    });
   }

+  let filepath = isAssistantUpload ? `${openai.baseURL}/files/${id}` : _filepath;
+  if (isAssistantUpload && file.mimetype.startsWith('image')) {
+    const result = await processImageFile({
+      req,
+      file,
+      metadata: { file_id: v4() },
+      returnFile: true,
+    });
+    filepath = result.filepath;
+  }

   const result = await createFile(
@@ -304,13 +383,15 @@ const processFileUpload = async ({ req, res, file, metadata }) => {
       file_id: id ?? file_id,
       temp_file_id,
       bytes,
+      filepath,
       filename: filename ?? file.originalname,
-      filepath: isAssistantUpload ? `${openai.baseURL}/files/${id}` : filepath,
       context: isAssistantUpload ? FileContext.assistants : FileContext.message_attachment,
       model: isAssistantUpload ? req.body.model : undefined,
       type: file.mimetype,
       embedded,
       source,
+      height,
+      width,
     },
     true,
   );
@@ -340,7 +421,10 @@ const processOpenAIFile = async ({
     originalName ? `/${originalName}` : ''
   }`;
   const type = mime.getType(originalName ?? file_id);
-
+  const source =
+    openai.req.body.endpoint === EModelEndpoint.azureAssistants
+      ? FileSources.azure
+      : FileSources.openai;
   const file = {
     ..._file,
     type,
@@ -349,7 +433,7 @@ const processOpenAIFile = async ({
     usage: 1,
     user: userId,
     context: _file.purpose,
-    source: FileSources.openai,
+    source,
     model: openai.req.body.model,
     filename: originalName ?? file_id,
   };
@@ -394,12 +478,14 @@ const processOpenAIImageOutput = async ({ req, buffer, file_id, filename, fileEx
     filename: `${hostImageNamePrefix}${filename}`,
   };
   createFile(file, true);
+  const source =
+    req.body.endpoint === EModelEndpoint.azureAssistants ? FileSources.azure : FileSources.openai;
   createFile(
     {
       ...file,
       file_id,
       filename,
-      source: FileSources.openai,
+      source,
       type: mime.getType(fileExt),
     },
     true,
@@ -500,7 +586,12 @@ async function retrieveAndProcessFile({
  * Filters a file based on its size and the endpoint origin.
  *
  * @param {Object} params - The parameters for the function.
- * @param {Express.Request} params.req - The request object from Express.
+ * @param {object} params.req - The request object from Express.
+ * @param {string} [params.req.endpoint]
+ * @param {string} [params.req.file_id]
+ * @param {number} [params.req.width]
+ * @param {number} [params.req.height]
+ * @param {number} [params.req.version]
  * @param {Express.Multer.File} params.file - The file uploaded to the server via multer.
  * @param {boolean} [params.image] - Whether the file expected is an image.
  * @returns {void}
@@ -111,6 +111,8 @@ const getStrategyFunctions = (fileSource) => {
     return localStrategy();
   } else if (fileSource === FileSources.openai) {
     return openAIStrategy();
+  } else if (fileSource === FileSources.azure) {
+    return openAIStrategy();
   } else if (fileSource === FileSources.vectordb) {
     return vectorStrategy();
   } else {
@@ -167,6 +167,8 @@ const getOpenAIModels = async (opts) => {

   if (opts.assistants) {
     models = defaultModels[EModelEndpoint.assistants];
+  } else if (opts.azure) {
+    models = defaultModels[EModelEndpoint.azureAssistants];
   }

   if (opts.plugins) {
@@ -1,3 +1,4 @@
+const throttle = require('lodash/throttle');
 const {
   StepTypes,
   ContentTypes,
@@ -8,6 +9,7 @@ const {
 } = require('librechat-data-provider');
 const { retrieveAndProcessFile } = require('~/server/services/Files/process');
 const { processRequiredActions } = require('~/server/services/ToolService');
+const { saveMessage, updateMessageText } = require('~/models/Message');
 const { createOnProgress, sendMessage } = require('~/server/utils');
 const { processMessages } = require('~/server/services/Threads');
 const { logger } = require('~/config');
@@ -43,6 +45,8 @@ class StreamRunManager {
     /** @type {string} */
     this.apiKey = this.openai.apiKey;
     /** @type {string} */
+    this.parentMessageId = fields.parentMessageId;
+    /** @type {string} */
     this.thread_id = fields.thread_id;
     /** @type {RunCreateAndStreamParams} */
     this.initialRunBody = fields.runBody;
@@ -58,10 +62,14 @@ class StreamRunManager {
     this.messages = [];
     /** @type {string} */
     this.text = '';
+    /** @type {string} */
+    this.intermediateText = '';
     /** @type {Set<string>} */
     this.attachedFileIds = fields.attachedFileIds;
     /** @type {undefined | Promise<ChatCompletion>} */
     this.visionPromise = fields.visionPromise;
+    /** @type {boolean} */
+    this.savedInitialMessage = false;

     /**
      * @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
@@ -123,6 +131,33 @@ class StreamRunManager {
     sendMessage(this.res, contentData);
   }

+  /* <------------------ Misc. Helpers ------------------> */
+  /** Returns the latest intermediate text
+   * @returns {string}
+   */
+  getText() {
+    return this.intermediateText;
+  }
+
+  /** Saves the initial intermediate message
+   * @returns {Promise<void>}
+   */
+  async saveInitialMessage() {
+    return saveMessage({
+      conversationId: this.finalMessage.conversationId,
+      messageId: this.finalMessage.messageId,
+      parentMessageId: this.parentMessageId,
+      model: this.req.body.assistant_id,
+      endpoint: this.req.body.endpoint,
+      isCreatedByUser: false,
+      user: this.req.user.id,
+      text: this.getText(),
+      sender: 'Assistant',
+      unfinished: true,
+      error: false,
+    });
+  }
+
   /* <------------------ Main Event Handlers ------------------> */

   /**
@@ -407,6 +442,7 @@ class StreamRunManager {
     const content = message.delta.content?.[0];

     if (content && content.type === MessageContentTypes.TEXT) {
+      this.intermediateText += content.text.value;
       onProgress(content.text.value);
     }
   }
@@ -523,10 +559,24 @@ class StreamRunManager {
     const stepKey = message_creation.message_id;
     const index = this.getStepIndex(stepKey);
     this.orderedRunSteps.set(index, message_creation);
+
     // Create the Factory Function to stream the message
     const { onProgress: progressCallback } = createOnProgress({
-      // todo: add option to save partialText to db
-      // onProgress: () => {},
+      onProgress: throttle(
+        () => {
+          if (!this.savedInitialMessage) {
+            this.saveInitialMessage();
+            this.savedInitialMessage = true;
+          } else {
+            updateMessageText({
+              messageId: this.finalMessage.messageId,
+              text: this.getText(),
+            });
+          }
+        },
+        2000,
+        { trailing: false },
+      ),
     });

     // This creates a function that attaches all of the parameters
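
The `{ trailing: false }` option is what bounds the write rate here: lodash's `throttle` then invokes on the leading edge only, so at most one save fires per 2-second window and no deferred call replays after the stream quiets down. A standalone sketch of that behavior:

```js
const throttle = require('lodash/throttle');

const save = throttle(() => console.log('saved'), 2000, { trailing: false });
save();                 // runs immediately (leading edge)
save();                 // dropped: same 2s window, no trailing invocation
setTimeout(save, 2500); // runs: a new window has opened
```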
@@ -55,7 +55,7 @@ async function createRun({ openai, thread_id, body }) {
  * @param {string} params.run_id - The ID of the run to wait for.
  * @param {string} params.thread_id - The ID of the thread associated with the run.
  * @param {RunManager} params.runManager - The RunManager instance to manage run steps.
- * @param {number} [params.pollIntervalMs=750] - The interval for polling the run status; default is 750 milliseconds.
+ * @param {number} [params.pollIntervalMs=2000] - The interval for polling the run status; default is 2000 milliseconds.
  * @param {number} [params.timeout=180000] - The period to wait until timing out polling; default is 3 minutes (in ms).
  * @return {Promise<Run>} A promise that resolves to the last fetched run object.
  */
@@ -64,7 +64,7 @@ async function waitForRun({
   run_id,
   thread_id,
   runManager,
-  pollIntervalMs = 750,
+  pollIntervalMs = 2000,
   timeout = 60000 * 3,
 }) {
   let timeElapsed = 0;
@@ -233,7 +233,7 @@ async function _handleRun({ openai, run_id, thread_id }) {
     run_id,
     thread_id,
     runManager,
-    pollIntervalMs: 750,
+    pollIntervalMs: 2000,
     timeout: 60000,
   });
   const actions = [];
@@ -3,7 +3,6 @@ const { v4 } = require('uuid');
 const {
   Constants,
   ContentTypes,
-  EModelEndpoint,
   AnnotationTypes,
   defaultOrderQuery,
 } = require('librechat-data-provider');
@@ -50,6 +49,7 @@ async function initThread({ openai, body, thread_id: _thread_id }) {
  * @param {string} params.assistant_id - The current assistant Id.
  * @param {string} params.thread_id - The thread Id.
  * @param {string} params.conversationId - The message's conversationId
+ * @param {string} params.endpoint - The conversation endpoint
  * @param {string} [params.parentMessageId] - Optional if initial message.
  * Defaults to Constants.NO_PARENT.
  * @param {string} [params.instructions] - Optional: from preset for `instructions` field.
@@ -82,7 +82,7 @@ async function saveUserMessage(params) {

   const userMessage = {
     user: params.user,
-    endpoint: EModelEndpoint.assistants,
+    endpoint: params.endpoint,
     messageId: params.messageId,
     conversationId: params.conversationId,
     parentMessageId: params.parentMessageId ?? Constants.NO_PARENT,
@@ -96,7 +96,7 @@ async function saveUserMessage(params) {
   };

   const convo = {
-    endpoint: EModelEndpoint.assistants,
+    endpoint: params.endpoint,
     conversationId: params.conversationId,
     promptPrefix: params.promptPrefix,
     instructions: params.instructions,
@@ -121,11 +121,13 @@ async function saveUserMessage(params) {
  * @param {Object} params - The parameters of the Assistant message
  * @param {string} params.user - The user's ID.
  * @param {string} params.messageId - The message Id.
+ * @param {string} params.text - The concatenated text of the message.
  * @param {string} params.assistant_id - The assistant Id.
  * @param {string} params.thread_id - The thread Id.
  * @param {string} params.model - The model used by the assistant.
  * @param {ContentPart[]} params.content - The message content parts.
  * @param {string} params.conversationId - The message's conversationId
+ * @param {string} params.endpoint - The conversation endpoint
  * @param {string} params.parentMessageId - The latest user message that triggered this response.
  * @param {string} [params.instructions] - Optional: from preset for `instructions` field.
  * Overrides the instructions of the assistant.
@@ -133,19 +135,11 @@ async function saveUserMessage(params) {
  * @return {Promise<Run>} A promise that resolves to the created run object.
  */
 async function saveAssistantMessage(params) {
-  const text = params.content.reduce((acc, part) => {
-    if (!part.value) {
-      return acc;
-    }
-
-    return acc + ' ' + part.value;
-  }, '');
-
-  // const tokenCount = // TODO: need to count each content part
-
   const message = await recordMessage({
     user: params.user,
-    endpoint: EModelEndpoint.assistants,
+    endpoint: params.endpoint,
     messageId: params.messageId,
     conversationId: params.conversationId,
     parentMessageId: params.parentMessageId,
@@ -155,12 +149,13 @@ async function saveAssistantMessage(params) {
     content: params.content,
     sender: 'Assistant',
     isCreatedByUser: false,
-    text: text.trim(),
+    text: params.text,
     unfinished: false,
     // tokenCount,
   });

   await saveConvo(params.user, {
-    endpoint: EModelEndpoint.assistants,
+    endpoint: params.endpoint,
     conversationId: params.conversationId,
     promptPrefix: params.promptPrefix,
     instructions: params.instructions,
@@ -205,20 +200,22 @@ async function addThreadMetadata({ openai, thread_id, messageId, messages }) {
  *
  * @param {Object} params - The parameters for synchronizing messages.
  * @param {OpenAIClient} params.openai - The OpenAI client instance.
+ * @param {string} params.endpoint - The current endpoint.
+ * @param {string} params.thread_id - The current thread ID.
  * @param {TMessage[]} params.dbMessages - The LibreChat DB messages.
  * @param {ThreadMessage[]} params.apiMessages - The thread messages from the API.
- * @param {string} params.conversationId - The current conversation ID.
- * @param {string} params.thread_id - The current thread ID.
  * @param {string} [params.assistant_id] - The current assistant ID.
+ * @param {string} params.conversationId - The current conversation ID.
  * @return {Promise<TMessage[]>} A promise that resolves to the updated messages
  */
 async function syncMessages({
   openai,
-  apiMessages,
-  dbMessages,
-  conversationId,
+  endpoint,
   thread_id,
+  dbMessages,
+  apiMessages,
   assistant_id,
+  conversationId,
 }) {
   let result = [];
   let dbMessageMap = new Map(dbMessages.map((msg) => [msg.messageId, msg]));
@@ -290,7 +287,7 @@ async function syncMessages({
       thread_id,
       conversationId,
       messageId: v4(),
-      endpoint: EModelEndpoint.assistants,
+      endpoint,
       parentMessageId: lastMessage ? lastMessage.messageId : Constants.NO_PARENT,
       role: apiMessage.role,
       isCreatedByUser: apiMessage.role === 'user',
@@ -299,6 +296,7 @@ async function syncMessages({
       aggregateMessages: [{ id: apiMessage.id }],
       model: apiMessage.role === 'user' ? null : apiMessage.assistant_id,
       user: openai.req.user.id,
+      unfinished: false,
     };

     if (apiMessage.file_ids?.length) {
@@ -382,13 +380,21 @@ function mapMessagesToSteps(steps, messages) {
  *
  * @param {Object} params - The parameters for initializing a thread.
  * @param {OpenAIClient} params.openai - The OpenAI client instance.
+ * @param {string} params.endpoint - The current endpoint.
  * @param {string} [params.latestMessageId] - Optional: The latest message ID from LibreChat.
  * @param {string} params.thread_id - Response thread ID.
  * @param {string} params.run_id - Response Run ID.
  * @param {string} params.conversationId - LibreChat conversation ID.
  * @return {Promise<TMessage[]>} A promise that resolves to the updated messages
  */
-async function checkMessageGaps({ openai, latestMessageId, thread_id, run_id, conversationId }) {
+async function checkMessageGaps({
+  openai,
+  endpoint,
+  latestMessageId,
+  thread_id,
+  run_id,
+  conversationId,
+}) {
   const promises = [];
   promises.push(openai.beta.threads.messages.list(thread_id, defaultOrderQuery));
   promises.push(openai.beta.threads.runs.steps.list(thread_id, run_id));
@@ -406,6 +412,7 @@ async function checkMessageGaps({ openai, latestMessageId, thread_id, run_id, co
     role: 'assistant',
     run_id,
     thread_id,
+    endpoint,
     metadata: {
       messageId: latestMessageId,
     },
@@ -452,11 +459,12 @@ async function checkMessageGaps({ openai, latestMessageId, thread_id, run_id, co

   const syncedMessages = await syncMessages({
     openai,
+    endpoint,
+    thread_id,
     dbMessages,
     apiMessages,
-    thread_id,
-    conversationId,
     assistant_id,
+    conversationId,
   });

   return Object.values(
@@ -498,41 +506,62 @@ const recordUsage = async ({
};

/**
 * Safely replaces the annotated text within the specified range denoted by start_index and end_index,
 * after verifying that the text within that range matches the given annotation text.
 * Proceeds with the replacement even if a mismatch is found, but logs a warning.
 * Creates a replaceAnnotation function with internal state for tracking the index offset.
 *
 * @param {string} originalText The original text content.
 * @param {number} start_index The starting index where replacement should begin.
 * @param {number} end_index The ending index where replacement should end.
 * @param {string} expectedText The text expected to be found in the specified range.
 * @param {string} replacementText The text to insert in place of the existing content.
 * @returns {string} The text with the replacement applied, regardless of text match.
 * @returns {function} The replaceAnnotation function with closure for index offset.
 */
function replaceAnnotation(originalText, start_index, end_index, expectedText, replacementText) {
  if (start_index < 0 || end_index > originalText.length || start_index > end_index) {
    logger.warn(`Invalid range specified for annotation replacement.
    Attempting replacement with \`replace\` method instead...
    length: ${originalText.length}
    start_index: ${start_index}
    end_index: ${end_index}`);
    return originalText.replace(originalText, replacementText);
function createReplaceAnnotation() {
  let indexOffset = 0;

  /**
   * Safely replaces the annotated text within the specified range denoted by start_index and end_index,
   * after verifying that the text within that range matches the given annotation text.
   * Proceeds with the replacement even if a mismatch is found, but logs a warning.
   *
   * @param {object} params The original text content.
   * @param {string} params.currentText The current text content, with/without replacements.
   * @param {number} params.start_index The starting index where replacement should begin.
   * @param {number} params.end_index The ending index where replacement should end.
   * @param {string} params.expectedText The text expected to be found in the specified range.
   * @param {string} params.replacementText The text to insert in place of the existing content.
   * @returns {string} The text with the replacement applied, regardless of text match.
   */
  function replaceAnnotation({
    currentText,
    start_index,
    end_index,
    expectedText,
    replacementText,
  }) {
    const adjustedStartIndex = start_index + indexOffset;
    const adjustedEndIndex = end_index + indexOffset;

    if (
      adjustedStartIndex < 0 ||
      adjustedEndIndex > currentText.length ||
      adjustedStartIndex > adjustedEndIndex
    ) {
      logger.warn(`Invalid range specified for annotation replacement.
      Attempting replacement with \`replace\` method instead...
      length: ${currentText.length}
      start_index: ${adjustedStartIndex}
      end_index: ${adjustedEndIndex}`);
      return currentText.replace(expectedText, replacementText);
    }

    if (currentText.substring(adjustedStartIndex, adjustedEndIndex) !== expectedText) {
      return currentText.replace(expectedText, replacementText);
    }

    indexOffset += replacementText.length - (adjustedEndIndex - adjustedStartIndex);
    return (
      currentText.slice(0, adjustedStartIndex) +
      replacementText +
      currentText.slice(adjustedEndIndex)
    );
  }

  const actualTextInRange = originalText.substring(start_index, end_index);

  if (actualTextInRange !== expectedText) {
    logger.warn(`The text within the specified range does not match the expected annotation text.
    Attempting replacement with \`replace\` method instead...
    Expected: ${expectedText}
    Actual: ${actualTextInRange}`);

    return originalText.replace(originalText, replacementText);
  }

  const beforeText = originalText.substring(0, start_index);
  const afterText = originalText.substring(end_index);
  return beforeText + replacementText + afterText;
  return replaceAnnotation;
}

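Note on the refactor above: annotation start_index/end_index values refer to the original message text, but each substitution changes the string's length. The closure's indexOffset accumulates that drift so later annotations still land on the right span. A minimal sketch of the intended call pattern (the values are hypothetical, not taken from the diff):

const replaceAnnotation = createReplaceAnnotation();

let text = 'See [source1] and [source2].';
// '[source1]' occupies original indices 4..13; replacing it with '^1^'
// shrinks the string by 6 characters, which the closure records.
text = replaceAnnotation({
  currentText: text,
  start_index: 4,
  end_index: 13,
  expectedText: '[source1]',
  replacementText: '^1^',
});
// '[source2]' still passes its ORIGINAL indices 18..27; indexOffset (-6)
// shifts them to 12..21 in the mutated string before slicing.
text = replaceAnnotation({
  currentText: text,
  start_index: 18,
  end_index: 27,
  expectedText: '[source2]',
  replacementText: '^2^',
});
// text === 'See ^1^ and ^2^.'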
/**
@@ -581,6 +610,11 @@ async function processMessages({ openai, client, messages = [] }) {
      continue;
    }

    const originalText = currentText;
    text += originalText;

    const replaceAnnotation = createReplaceAnnotation();

    logger.debug('[processMessages] Processing annotations:', annotations);
    for (const annotation of annotations) {
      let file;
@@ -589,14 +623,16 @@ async function processMessages({ openai, client, messages = [] }) {
      const file_id = annotationType?.file_id;
      const alreadyProcessed = client.processedFileIds.has(file_id);

      const replaceCurrentAnnotation = (replacement = '') => {
        currentText = replaceAnnotation(
      const replaceCurrentAnnotation = (replacementText = '') => {
        const { start_index, end_index, text: expectedText } = annotation;
        currentText = replaceAnnotation({
          originalText,
          currentText,
          annotation.start_index,
          annotation.end_index,
          annotation.text,
          replacement,
        );
          start_index,
          end_index,
          expectedText,
          replacementText,
        });
        edited = true;
      };

@@ -623,7 +659,7 @@ async function processMessages({ openai, client, messages = [] }) {
        replaceCurrentAnnotation(`^${sources.length}^`);
      }

      text += currentText + ' ';
      text = currentText;

      if (!file) {
        continue;

@@ -340,29 +340,26 @@ async function processRequiredActions(client, requiredActions) {
      currentAction.toolInput = currentAction.toolInput.input;
    }

    try {
      const promise = tool
        ._call(currentAction.toolInput)
        .then(handleToolOutput)
        .catch((error) => {
          logger.error(`Error processing tool ${currentAction.tool}`, error);
          return {
            tool_call_id: currentAction.toolCallId,
            output: `Error processing tool ${currentAction.tool}: ${redactMessage(error.message)}`,
          };
        });
      promises.push(promise);
    } catch (error) {
    const handleToolError = (error) => {
      logger.error(
        `tool_call_id: ${currentAction.toolCallId} | Error processing tool ${currentAction.tool}`,
        error,
      );
      promises.push(
        Promise.resolve({
          tool_call_id: currentAction.toolCallId,
          error: error.message,
        }),
      );
      return {
        tool_call_id: currentAction.toolCallId,
        output: `Error processing tool ${currentAction.tool}: ${redactMessage(error.message, 256)}`,
      };
    };

    try {
      const promise = tool
        ._call(currentAction.toolInput)
        .then(handleToolOutput)
        .catch(handleToolError);
      promises.push(promise);
    } catch (error) {
      const toolOutputError = handleToolError(error);
      promises.push(Promise.resolve(toolOutputError));
    }
  }

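The rewrite above funnels both failure modes through one handler: a rejected promise from tool._call reaches handleToolError via .catch, and a synchronous throw reaches it in the catch block, so every entry pushed to promises resolves to the same { tool_call_id, output } shape and Promise.all cannot reject mid-batch. A reduced sketch of the pattern in isolation (names are hypothetical, not from the diff):

// One normalizer covers async rejections and sync throws alike.
function buildToolPromise(tool, action, handleToolOutput) {
  const handleToolError = (error) => ({
    tool_call_id: action.toolCallId,
    output: `Error processing tool ${action.tool}: ${error.message}`,
  });

  try {
    return tool
      ._call(action.toolInput)
      .then(handleToolOutput)
      .catch(handleToolError); // async failure
  } catch (error) {
    return Promise.resolve(handleToolError(error)); // sync failure
  }
}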
@@ -1,7 +1,7 @@
const {
  Capabilities,
  EModelEndpoint,
  assistantEndpointSchema,
  defaultAssistantsVersion,
} = require('librechat-data-provider');
const { logger } = require('~/config');

@@ -12,22 +12,32 @@ const { logger } = require('~/config');
function azureAssistantsDefaults() {
  return {
    capabilities: [Capabilities.tools, Capabilities.actions, Capabilities.code_interpreter],
    version: defaultAssistantsVersion.azureAssistants,
  };
}

/**
 * Sets up the Assistants configuration from the config (`librechat.yaml`) file.
 * @param {TCustomConfig} config - The loaded custom configuration.
 * @param {Partial<TAssistantEndpoint>} [prevConfig]
 * @param {EModelEndpoint.assistants|EModelEndpoint.azureAssistants} assistantsEndpoint - The Assistants endpoint name.
 * - The previously loaded assistants configuration from Azure OpenAI Assistants option.
 * @param {Partial<TAssistantEndpoint>} [prevConfig]
 * @returns {Partial<TAssistantEndpoint>} The Assistants endpoint configuration.
 */
function assistantsConfigSetup(config, prevConfig = {}) {
  const assistantsConfig = config.endpoints[EModelEndpoint.assistants];
function assistantsConfigSetup(config, assistantsEndpoint, prevConfig = {}) {
  const assistantsConfig = config.endpoints[assistantsEndpoint];
  const parsedConfig = assistantEndpointSchema.parse(assistantsConfig);
  if (assistantsConfig.supportedIds?.length && assistantsConfig.excludedIds?.length) {
    logger.warn(
      `Both \`supportedIds\` and \`excludedIds\` are defined for the ${EModelEndpoint.assistants} endpoint; \`excludedIds\` field will be ignored.`,
      `Configuration conflict: The '${assistantsEndpoint}' endpoint has both 'supportedIds' and 'excludedIds' defined. The 'excludedIds' will be ignored.`,
    );
  }
  if (
    assistantsConfig.privateAssistants &&
    (assistantsConfig.supportedIds?.length || assistantsConfig.excludedIds?.length)
  ) {
    logger.warn(
      `Configuration conflict: The '${assistantsEndpoint}' endpoint has both 'privateAssistants' and 'supportedIds' or 'excludedIds' defined. The 'supportedIds' and 'excludedIds' will be ignored.`,
    );
  }

@@ -39,6 +49,7 @@ function assistantsConfigSetup(config, prevConfig = {}) {
    supportedIds: parsedConfig.supportedIds,
    capabilities: parsedConfig.capabilities,
    excludedIds: parsedConfig.excludedIds,
    privateAssistants: parsedConfig.privateAssistants,
    timeoutMs: parsedConfig.timeoutMs,
  };
}

@@ -41,6 +41,17 @@ function azureConfigSetup(config) {
    );
  }

  if (
    azureConfiguration.assistants &&
    process.env.ENDPOINTS &&
    !process.env.ENDPOINTS.includes(EModelEndpoint.azureAssistants)
  ) {
    logger.warn(
      `Azure Assistants are configured, but the endpoint will not be accessible as it's not included in the ENDPOINTS environment variable.
      Please add the value "${EModelEndpoint.azureAssistants}" to the ENDPOINTS list if expected.`,
    );
  }

  return {
    modelNames,
    modelGroupMap,

@@ -1,14 +1,15 @@
const Redis = require('ioredis');
const passport = require('passport');
const session = require('express-session');
const RedisStore = require('connect-redis').default;
const passport = require('passport');
const {
  setupOpenId,
  googleLogin,
  githubLogin,
  discordLogin,
  facebookLogin,
  setupOpenId,
} = require('../strategies');
const client = require('../cache/redis');
} = require('~/strategies');
const { logger } = require('~/config');

/**
 *
@@ -40,6 +41,11 @@ const configureSocialLogins = (app) => {
    saveUninitialized: false,
  };
  if (process.env.USE_REDIS) {
    const client = new Redis(process.env.REDIS_URI);
    client
      .on('error', (err) => logger.error('ioredis error:', err))
      .on('ready', () => logger.info('ioredis successfully initialized.'))
      .on('reconnecting', () => logger.info('ioredis reconnecting...'));
    sessionOptions.store = new RedisStore({ client, prefix: 'librechat' });
  }
  app.use(session(sessionOptions));

@@ -1,4 +1,10 @@
const { Capabilities, defaultRetrievalModels } = require('librechat-data-provider');
const {
  Capabilities,
  EModelEndpoint,
  isAssistantsEndpoint,
  defaultRetrievalModels,
  defaultAssistantsVersion,
} = require('librechat-data-provider');
const { getCitations, citeText } = require('./citations');
const partialRight = require('lodash/partialRight');
const { sendMessage } = require('./streamResponse');
@@ -154,9 +160,10 @@ const isUserProvided = (value) => value === 'user_provided';
 * Generate the configuration for a given key and base URL.
 * @param {string} key
 * @param {string} baseURL
 * @param {string} endpoint
 * @returns {boolean | { userProvide: boolean, userProvideURL?: boolean }}
 */
function generateConfig(key, baseURL, assistants = false) {
function generateConfig(key, baseURL, endpoint) {
  if (!key) {
    return false;
  }
@@ -168,6 +175,8 @@ function generateConfig(key, baseURL, assistants = false) {
    config.userProvideURL = isUserProvided(baseURL);
  }

  const assistants = isAssistantsEndpoint(endpoint);

  if (assistants) {
    config.retrievalModels = defaultRetrievalModels;
    config.capabilities = [
@@ -179,6 +188,12 @@ function generateConfig(key, baseURL, assistants = false) {
    ];
  }

  if (assistants && endpoint === EModelEndpoint.azureAssistants) {
    config.version = defaultAssistantsVersion.azureAssistants;
  } else if (assistants) {
    config.version = defaultAssistantsVersion.assistants;
  }

  return config;
}

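With the boolean flag replaced by the endpoint name, generateConfig can infer both whether assistants capabilities apply and which API version to pin. A hedged sketch of the expected results; the exact field values come from librechat-data-provider defaults, and the endpoint strings are assumed to match the EModelEndpoint enum values:

// Assumes EModelEndpoint.assistants === 'assistants' and
// EModelEndpoint.azureAssistants === 'azureAssistants'.
const assistantsCfg = generateConfig('sk-example', undefined, 'assistants');
// -> includes retrievalModels, capabilities, and
//    version === defaultAssistantsVersion.assistants

const azureCfg = generateConfig('user_provided', undefined, 'azureAssistants');
// -> userProvide: true, plus assistants fields and
//    version === defaultAssistantsVersion.azureAssistants

const chatCfg = generateConfig('sk-example', 'user_provided', 'openAI');
// -> { userProvide: false, userProvideURL: true }; no assistants fields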
@@ -19,7 +19,7 @@
    "endpoint": "openAI",
    "title": "VW Transporter 2014 Fuel Consumption. Web Search"
  },
  "messagesTree": [
  "messages": [
    {
      "_id": "6615516574dc2ddcdebe40b6",
      "messageId": "b123942f-ca1a-4b16-9e1f-ea4af5171168",

40
api/server/utils/import/__data__/librechat-linear.json
Normal file
@@ -0,0 +1,40 @@
{
  "conversationId": "4c9e1402-f746-479d-8a18-2dfab1bf750f",
  "endpoint": "azureOpenAI",
  "title": "Adventure Across Time and Space",
  "exportAt": "20:33:31 GMT-0400 (Eastern Daylight Time)",
  "branches": false,
  "recursive": false,
  "messages": [
    {
      "messageId": "ed3fdaaa-e174-4d70-89a9-b06e70939723",
      "conversationId": "4c9e1402-f746-479d-8a18-2dfab1bf750f",
      "createdAt": "2024-05-29T00:33:03.661Z",
      "endpoint": "azureOpenAI",
      "error": false,
      "isCreatedByUser": true,
      "isEdited": false,
      "model": null,
      "parentMessageId": null,
      "sender": "user",
      "text": "tell me a long story",
      "unfinished": false,
      "updatedAt": "2024-05-29T00:33:03.661Z"
    },
    {
      "messageId": "113b8336-4b19-4f91-8eed-741b3dc2940e",
      "conversationId": "4c9e1402-f746-479d-8a18-2dfab1bf750f",
      "createdAt": "2024-05-29T00:33:03.661Z",
      "endpoint": "azureOpenAI",
      "error": false,
      "isCreatedByUser": false,
      "isEdited": false,
      "model": "gpt-4o",
      "parentMessageId": "ed3fdaaa-e174-4d70-89a9-b06e70939723",
      "sender": "GPT-4",
      "text": "Of course! Here’s a detailed and engaging story:\n\n---\n\n### The Legend of Eldoria\n\nNestled between towering mountains and dense, ancient forests was the enigmatic kingdom of Eldoria. This realm, cloaked in perpetual twilight, was the stuff of legends. It was said that the land was blessed by the gods and guarded by mythical creatures. Eldoria was a place where magic and realism intertwined seamlessly, creating a land of beauty, wonder, and peril.\n\nIn the heart of this kingdom lay the grand city of Lumina, known",
      "unfinished": false,
      "updatedAt": "2024-05-29T00:33:03.661Z"
    }
  ]
}
@@ -0,0 +1,104 @@
{
  "conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
  "endpoint": "azureOpenAI",
  "title": "Original",
  "exportAt": "08:12:16 GMT-0400 (Eastern Daylight Time)",
  "branches": true,
  "recursive": false,
  "options": {
    "presetId": null,
    "model": "gpt-4o",
    "chatGptLabel": null,
    "promptPrefix": null,
    "temperature": 1,
    "top_p": 1,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "resendFiles": true,
    "imageDetail": "auto",
    "endpoint": "azureOpenAI",
    "title": "Original"
  },
  "messages": [
    {
      "messageId": "115a6247-8fb0-4937-a536-12956669098d",
      "conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
      "createdAt": "2024-05-28T18:08:55.014Z",
      "endpoint": "azureOpenAI",
      "error": false,
      "isCreatedByUser": true,
      "isEdited": false,
      "model": null,
      "parentMessageId": "00000000-0000-0000-0000-000000000000",
      "sender": "User",
      "text": "tell me a long story",
      "tokenCount": 9,
      "unfinished": false,
      "updatedAt": "2024-05-28T18:09:27.193Z"
    },
    {
      "messageId": "069b9c22-7649-45a9-b90b-fc050533ea21",
      "conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
      "createdAt": "2024-05-28T18:08:55.390Z",
      "error": false,
      "isCreatedByUser": false,
      "isEdited": false,
      "model": "gpt-4o",
      "parentMessageId": "115a6247-8fb0-4937-a536-12956669098d",
      "sender": "GPT-4",
      "text": "Of course! Settle in for a tale of adventure across time and space.\n\n---\n\nOnce upon a time in the small, sleepy village of Eldoria, there was a young woman named Elara who longed for adventure. Eldoria was a place of routine and simplicity, nestled between rolling hills and dense forests, but Elara always felt that there was more to the world than the boundaries",
      "unfinished": false,
      "updatedAt": "2024-05-28T18:08:58.669Z",
      "endpoint": "azureOpenAI",
      "tokenCount": 78,
      "finish_reason": "incomplete"
    },
    {
      "messageId": "303e4c2c-f03e-4e0a-8551-c96ec73be5fe",
      "conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
      "createdAt": "2024-05-28T18:09:27.444Z",
      "error": false,
      "isCreatedByUser": false,
      "isEdited": false,
      "model": "gpt-4o",
      "parentMessageId": "115a6247-8fb0-4937-a536-12956669098d",
      "sender": "GPT-4",
      "text": "Sure, I can craft a long story for you. Here it goes:\n\n### The Chronicles of Elenor: The Luminary of Anduril\n\nIn an age long forgotten by men, in a world kissed by the glow of dual suns, the Kingdom of Anduril flourished. Verdant valleys graced its land, majestic mountains shielded",
      "unfinished": true,
      "updatedAt": "2024-05-28T18:09:30.448Z"
    },
    {
      "messageId": "599e1908-8c52-4a73-ba6b-f6dffbd79ba0",
      "conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
      "createdAt": "2024-05-28T18:14:07.988Z",
      "endpoint": "azureOpenAI",
      "error": false,
      "isCreatedByUser": true,
      "isEdited": false,
      "model": null,
      "parentMessageId": "00000000-0000-0000-0000-000000000000",
      "sender": "User",
      "text": "tell me a long long story",
      "tokenCount": 9,
      "unfinished": false,
      "updatedAt": "2024-05-28T18:14:07.988Z"
    },
    {
      "messageId": "de9a4e7c-020d-4856-a5a6-ce6794efef99",
      "conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
      "createdAt": "2024-05-28T18:14:08.403Z",
      "error": false,
      "isCreatedByUser": false,
      "isEdited": true,
      "model": "gpt-4o",
      "parentMessageId": "599e1908-8c52-4a73-ba6b-f6dffbd79ba0",
      "sender": "GPT-4",
      "text": "Of course! Here’s a detailed and engaging story:\n\n---\n\n### The Legend of Eldoria\n\nNestled between towering mountains and dense, ancient forests was the enigmatic kingdom of Eldoria. This realm, cloaked in perpetual twilight, was the stuff of legends. It was said that the land was blessed by the gods and guarded by mythical creatures. Eldoria was a place where magic and realism intertwined seamlessly, creating a land of beauty, wonder, and peril.\n\nIn the heart of this kingdom lay the grand city of Lumina, known",
      "unfinished": false,
      "updatedAt": "2024-05-28T18:14:20.349Z",
      "endpoint": "azureOpenAI",
      "finish_reason": "incomplete",
      "tokenCount": 110
    }
  ]
}
@@ -19,7 +19,7 @@
    "endpoint": "openAI",
    "title": "Troubleshooting Python Virtual Environment Activation Issue"
  },
  "messagesTree": [
  "messages": [
    {
      "_id": "66326f3f04bed94b7f5be68d",
      "messageId": "9501f99d-9bbb-40cb-bbb2-16d79aeceb72",

@@ -1,6 +1,7 @@
const { v4: uuidv4 } = require('uuid');
const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider');
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
const { createImportBatchBuilder } = require('./importBatchBuilder');
const getLogStores = require('~/cache/getLogStores');
const logger = require('~/config/winston');

/**
@@ -24,7 +25,7 @@ function getImporter(jsonData) {
  }

  // For LibreChat
  if (jsonData.conversationId && jsonData.messagesTree) {
  if (jsonData.conversationId && (jsonData.messagesTree || jsonData.messages)) {
    logger.info('Importing LibreChat conversation');
    return importLibreChatConvo;
  }
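The widened condition means both LibreChat export shapes now route to importLibreChatConvo: recursive exports carrying a nested messagesTree, and linear exports carrying a flat messages array. Trimmed sketches of the two shapes (fields abbreviated, values hypothetical):

// Recursive export: children nest under each message.
const recursiveExport = {
  conversationId: 'convo-1',
  recursive: true,
  messagesTree: [{ text: 'hi', sender: 'user', children: [ /* ... */ ] }],
};

// Linear export: parentMessageId links a flat array.
const linearExport = {
  conversationId: 'convo-2',
  recursive: false,
  messages: [{ messageId: 'a', parentMessageId: null, text: 'hi' }],
};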
@@ -85,47 +86,92 @@ async function importLibreChatConvo(
  try {
    /** @type {ImportBatchBuilder} */
    const importBatchBuilder = builderFactory(requestUserId);
    importBatchBuilder.startConversation(EModelEndpoint.openAI);
    const options = jsonData.options || {};

    /* Endpoint configuration */
    let endpoint = jsonData.endpoint ?? options.endpoint ?? EModelEndpoint.openAI;
    const cache = getLogStores(CacheKeys.CONFIG_STORE);
    const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
    const endpointConfig = endpointsConfig?.[endpoint];
    if (!endpointConfig && endpointsConfig) {
      endpoint = Object.keys(endpointsConfig)[0];
    } else if (!endpointConfig) {
      endpoint = EModelEndpoint.openAI;
    }

    importBatchBuilder.startConversation(endpoint);

    let firstMessageDate = null;

    const traverseMessages = (messages, parentMessageId = null) => {
      for (const message of messages) {
        if (!message.text) {
          continue;
        }
    const messagesToImport = jsonData.messagesTree || jsonData.messages;

        let savedMessage;
        if (message.sender?.toLowerCase() === 'user') {
          savedMessage = importBatchBuilder.saveMessage({
            text: message.text,
            sender: 'user',
            isCreatedByUser: true,
            parentMessageId: parentMessageId,
          });
        } else {
          savedMessage = importBatchBuilder.saveMessage({
            text: message.text,
            sender: message.sender,
            isCreatedByUser: false,
            model: jsonData.options.model,
            parentMessageId: parentMessageId,
          });
        }
    if (jsonData.recursive) {
      /**
       * Recursively traverse the messages tree and save each message to the database.
       * @param {TMessage[]} messages
       * @param {string} parentMessageId
       */
      const traverseMessages = async (messages, parentMessageId = null) => {
        for (const message of messages) {
          if (!message.text) {
            continue;
          }

          let savedMessage;
          if (message.sender?.toLowerCase() === 'user' || message.isCreatedByUser) {
            savedMessage = await importBatchBuilder.saveMessage({
              text: message.text,
              sender: 'user',
              isCreatedByUser: true,
              parentMessageId: parentMessageId,
            });
          } else {
            savedMessage = await importBatchBuilder.saveMessage({
              text: message.text,
              sender: message.sender,
              isCreatedByUser: false,
              model: options.model,
              parentMessageId: parentMessageId,
            });
          }

          if (!firstMessageDate) {
            firstMessageDate = new Date(message.createdAt);
          }

          if (message.children && message.children.length > 0) {
            await traverseMessages(message.children, savedMessage.messageId);
          }
        }
      };

      await traverseMessages(messagesToImport);
    } else if (messagesToImport) {
      const idMapping = new Map();

      for (const message of messagesToImport) {
        if (!firstMessageDate) {
          firstMessageDate = new Date(message.createdAt);
        }
        const newMessageId = uuidv4();
        idMapping.set(message.messageId, newMessageId);

        if (message.children) {
          traverseMessages(message.children, savedMessage.messageId);
        }
        const clonedMessage = {
          ...message,
          messageId: newMessageId,
          parentMessageId:
            message.parentMessageId && message.parentMessageId !== Constants.NO_PARENT
              ? idMapping.get(message.parentMessageId) || Constants.NO_PARENT
              : Constants.NO_PARENT,
        };

        importBatchBuilder.saveMessage(clonedMessage);
      }
    };
    } else {
      throw new Error('Invalid LibreChat file format');
    }

    traverseMessages(jsonData.messagesTree);

    importBatchBuilder.finishConversation(jsonData.title, firstMessageDate);
    importBatchBuilder.finishConversation(jsonData.title, firstMessageDate ?? new Date(), options);
    await importBatchBuilder.saveBatch();
    logger.debug(`user: ${requestUserId} | Conversation "${jsonData.title}" imported`);
  } catch (error) {

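For the non-recursive branch above, hierarchy is preserved by minting a new UUID per message and rewriting parentMessageId through idMapping, which relies on parents appearing before their children in the flat array. A reduced, standalone sketch of that remapping (assuming Constants.NO_PARENT is the all-zero sentinel seen in the fixtures):

const { v4: uuidv4 } = require('uuid');

// NO_PARENT stands in for Constants.NO_PARENT here.
function remapLinearMessages(messages, NO_PARENT = '00000000-0000-0000-0000-000000000000') {
  const idMapping = new Map();
  return messages.map((message) => {
    const newMessageId = uuidv4();
    idMapping.set(message.messageId, newMessageId);
    return {
      ...message,
      messageId: newMessageId,
      // Parents precede children, so the parent's new ID is already mapped.
      parentMessageId:
        message.parentMessageId && message.parentMessageId !== NO_PARENT
          ? idMapping.get(message.parentMessageId) || NO_PARENT
          : NO_PARENT,
    };
  });
}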
@@ -1,70 +1,68 @@
const fs = require('fs');
const path = require('path');
const { EModelEndpoint, Constants } = require('librechat-data-provider');
const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider');
const { bulkSaveConvos: _bulkSaveConvos } = require('~/models/Conversation');
const { ImportBatchBuilder } = require('./importBatchBuilder');
const { bulkSaveMessages } = require('~/models/Message');
const getLogStores = require('~/cache/getLogStores');
const { getImporter } = require('./importers');

// Mocking the ImportBatchBuilder class and its methods
jest.mock('./importBatchBuilder', () => {
  return {
    ImportBatchBuilder: jest.fn().mockImplementation(() => {
      return {
        startConversation: jest.fn().mockResolvedValue(undefined),
        addUserMessage: jest.fn().mockResolvedValue(undefined),
        addGptMessage: jest.fn().mockResolvedValue(undefined),
        saveMessage: jest.fn().mockResolvedValue(undefined),
        finishConversation: jest.fn().mockResolvedValue(undefined),
        saveBatch: jest.fn().mockResolvedValue(undefined),
      };
    }),
  };
jest.mock('~/cache/getLogStores');
const mockedCacheGet = jest.fn();
getLogStores.mockImplementation(() => ({
  get: mockedCacheGet,
}));

// Mock the database methods
jest.mock('~/models/Conversation', () => ({
  bulkSaveConvos: jest.fn(),
}));
jest.mock('~/models/Message', () => ({
  bulkSaveMessages: jest.fn(),
}));

afterEach(() => {
  jest.clearAllMocks();
});

describe('importChatGptConvo', () => {
  it('should import conversation correctly', async () => {
    const expectedNumberOfMessages = 19;
    const expectedNumberOfConversations = 2;
    // Given
    const jsonData = JSON.parse(
      fs.readFileSync(path.join(__dirname, '__data__', 'chatgpt-export.json'), 'utf8'),
    );
    const requestUserId = 'user-123';
    const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
    const importBatchBuilder = new ImportBatchBuilder(requestUserId);

    // Spy on instance methods
    jest.spyOn(importBatchBuilder, 'startConversation');
    jest.spyOn(importBatchBuilder, 'saveMessage');
    jest.spyOn(importBatchBuilder, 'finishConversation');
    jest.spyOn(importBatchBuilder, 'saveBatch');

    // When
    const importer = getImporter(jsonData);
    await importer(jsonData, requestUserId, mockedBuilderFactory);
    await importer(jsonData, requestUserId, () => importBatchBuilder);

    // Then
    expect(mockedBuilderFactory).toHaveBeenCalledWith(requestUserId);
    const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;

    expect(mockImportBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
    expect(mockImportBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedNumberOfMessages); // Adjust expected number
    expect(mockImportBatchBuilder.finishConversation).toHaveBeenCalledTimes(
      expectedNumberOfConversations,
    ); // Adjust expected number
    expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
    expect(importBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
    expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedNumberOfMessages);
    expect(importBatchBuilder.finishConversation).toHaveBeenCalledTimes(jsonData.length);
    expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
  });

  it('should maintain correct message hierarchy (tree parent/children relationship)', async () => {
    // Prepare test data with known hierarchy
    const jsonData = JSON.parse(
      fs.readFileSync(path.join(__dirname, '__data__', 'chatgpt-tree.json'), 'utf8'),
    );

    const requestUserId = 'user-123';
    const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
    const importBatchBuilder = new ImportBatchBuilder(requestUserId);

    jest.spyOn(importBatchBuilder, 'saveMessage');
    jest.spyOn(importBatchBuilder, 'saveBatch');

    // When
    const importer = getImporter(jsonData);
    await importer(jsonData, requestUserId, mockedBuilderFactory);

    // Then
    expect(mockedBuilderFactory).toHaveBeenCalledWith(requestUserId);
    const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;
    await importer(jsonData, requestUserId, () => importBatchBuilder);

    const entries = Object.keys(jsonData[0].mapping);
    // Filter entries that should be processed (not system and have content)
    const messageEntries = entries.filter(
      (id) =>
        jsonData[0].mapping[id].message &&
@@ -72,20 +70,16 @@ describe('importChatGptConvo', () => {
        jsonData[0].mapping[id].message.content,
    );

    // Expect the saveMessage to be called for each valid entry
    expect(mockImportBatchBuilder.saveMessage).toHaveBeenCalledTimes(messageEntries.length);
    expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(messageEntries.length);

    const idToUUIDMap = new Map();
    // Map original IDs to dynamically generated UUIDs
    mockImportBatchBuilder.saveMessage.mock.calls.forEach((call, index) => {
    importBatchBuilder.saveMessage.mock.calls.forEach((call, index) => {
      const originalId = messageEntries[index];
      idToUUIDMap.set(originalId, call[0].messageId);
    });

    // Validate the UUID map contains all expected entries
    expect(idToUUIDMap.size).toBe(messageEntries.length);

    // Validate correct parent-child relationships
    messageEntries.forEach((id) => {
      const { parent } = jsonData[0].mapping[id];

@@ -93,72 +87,110 @@ describe('importChatGptConvo', () => {
        ? idToUUIDMap.get(parent) ?? Constants.NO_PARENT
        : Constants.NO_PARENT;

      const actualParentId = idToUUIDMap.get(id)
        ? mockImportBatchBuilder.saveMessage.mock.calls.find(
            (call) => call[0].messageId === idToUUIDMap.get(id),
      const actualMessageId = idToUUIDMap.get(id);
      const actualParentId = actualMessageId
        ? importBatchBuilder.saveMessage.mock.calls.find(
            (call) => call[0].messageId === actualMessageId,
          )[0].parentMessageId
        : Constants.NO_PARENT;

      expect(actualParentId).toBe(expectedParentId);
    });

    expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
    expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
  });
});

describe('importLibreChatConvo', () => {
  it('should import conversation correctly', async () => {
    const expectedNumberOfMessages = 6;
    const expectedNumberOfConversations = 1;
  const jsonDataNonRecursiveBranches = JSON.parse(
    fs.readFileSync(path.join(__dirname, '__data__', 'librechat-opts-nonr-branches.json'), 'utf8'),
  );

    // Given
  it('should import conversation correctly', async () => {
    mockedCacheGet.mockResolvedValue({
      [EModelEndpoint.openAI]: {},
    });
    const expectedNumberOfMessages = 6;
    const jsonData = JSON.parse(
      fs.readFileSync(path.join(__dirname, '__data__', 'librechat-export.json'), 'utf8'),
    );
    const requestUserId = 'user-123';
    const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
    const importBatchBuilder = new ImportBatchBuilder(requestUserId);

    // Spy on instance methods
    jest.spyOn(importBatchBuilder, 'startConversation');
    jest.spyOn(importBatchBuilder, 'saveMessage');
    jest.spyOn(importBatchBuilder, 'finishConversation');
    jest.spyOn(importBatchBuilder, 'saveBatch');

    // When
    const importer = getImporter(jsonData);
    await importer(jsonData, requestUserId, mockedBuilderFactory);
    await importer(jsonData, requestUserId, () => importBatchBuilder);

    // Then
    const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;
    expect(mockImportBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
    expect(mockImportBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedNumberOfMessages); // Adjust expected number
    expect(mockImportBatchBuilder.finishConversation).toHaveBeenCalledTimes(
      expectedNumberOfConversations,
    ); // Adjust expected number
    expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
    expect(importBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
    expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedNumberOfMessages);
    expect(importBatchBuilder.finishConversation).toHaveBeenCalledTimes(1);
    expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
  });

  it('should import linear, non-recursive thread correctly with correct endpoint', async () => {
    mockedCacheGet.mockResolvedValue({
      [EModelEndpoint.azureOpenAI]: {},
    });

    const jsonData = JSON.parse(
      fs.readFileSync(path.join(__dirname, '__data__', 'librechat-linear.json'), 'utf8'),
    );
    const requestUserId = 'user-123';
    const importBatchBuilder = new ImportBatchBuilder(requestUserId);

    jest.spyOn(importBatchBuilder, 'startConversation');
    jest.spyOn(importBatchBuilder, 'saveMessage');
    jest.spyOn(importBatchBuilder, 'finishConversation');
    jest.spyOn(importBatchBuilder, 'saveBatch');

    const importer = getImporter(jsonData);
    await importer(jsonData, requestUserId, () => importBatchBuilder);

    expect(bulkSaveMessages).toHaveBeenCalledTimes(1);

    const messages = bulkSaveMessages.mock.calls[0][0];
    let lastMessageId = Constants.NO_PARENT;
    for (const message of messages) {
      expect(message.parentMessageId).toBe(lastMessageId);
      lastMessageId = message.messageId;
    }

    expect(importBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.azureOpenAI);
    expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(jsonData.messages.length);
    expect(importBatchBuilder.finishConversation).toHaveBeenCalled();
    expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
  });

  it('should maintain correct message hierarchy (tree parent/children relationship)', async () => {
    // Load test data
    const jsonData = JSON.parse(
      fs.readFileSync(path.join(__dirname, '__data__', 'librechat-tree.json'), 'utf8'),
    );
    const requestUserId = 'user-123';
    const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
    const importBatchBuilder = new ImportBatchBuilder(requestUserId);
    jest.spyOn(importBatchBuilder, 'saveMessage');
    jest.spyOn(importBatchBuilder, 'saveBatch');

    // When
    const importer = getImporter(jsonData);
    await importer(jsonData, requestUserId, mockedBuilderFactory);

    // Then
    const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;
    await importer(jsonData, requestUserId, () => importBatchBuilder);

    // Create a map to track original message IDs to new UUIDs
    const idToUUIDMap = new Map();
    mockImportBatchBuilder.saveMessage.mock.calls.forEach((call) => {
    importBatchBuilder.saveMessage.mock.calls.forEach((call) => {
      const message = call[0];
      idToUUIDMap.set(message.originalMessageId, message.messageId);
    });

    // Function to recursively check children
    const checkChildren = (children, parentId) => {
      children.forEach((child) => {
        const childUUID = idToUUIDMap.get(child.messageId);
        const expectedParentId = idToUUIDMap.get(parentId) ?? null;
        const messageCall = mockImportBatchBuilder.saveMessage.mock.calls.find(
        const messageCall = importBatchBuilder.saveMessage.mock.calls.find(
          (call) => call[0].messageId === childUUID,
        );

@@ -172,75 +204,203 @@ describe('importLibreChatConvo', () => {
    };

    // Start hierarchy validation from root messages
    checkChildren(jsonData.messagesTree, null); // Assuming root messages have no parent
    checkChildren(jsonData.messages, null);

    expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
    expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
  });

  it('should maintain correct message hierarchy (non-recursive)', async () => {
    const jsonData = jsonDataNonRecursiveBranches;
    const requestUserId = 'user-123';
    const importBatchBuilder = new ImportBatchBuilder(requestUserId);
    jest.spyOn(importBatchBuilder, 'saveMessage');
    jest.spyOn(importBatchBuilder, 'saveBatch');

    const importer = getImporter(jsonData);
    await importer(jsonData, requestUserId, () => importBatchBuilder);

    const textToMessageMap = new Map();
    importBatchBuilder.saveMessage.mock.calls.forEach((call) => {
      const message = call[0];
      textToMessageMap.set(message.text, message);
    });

    const relationships = {
      'tell me a long story': [
        'Of course! Settle in for a tale of adventure across time and space.\n\n---\n\nOnce upon a time in the small, sleepy village of Eldoria, there was a young woman named Elara who longed for adventure. Eldoria was a place of routine and simplicity, nestled between rolling hills and dense forests, but Elara always felt that there was more to the world than the boundaries',
        'Sure, I can craft a long story for you. Here it goes:\n\n### The Chronicles of Elenor: The Luminary of Anduril\n\nIn an age long forgotten by men, in a world kissed by the glow of dual suns, the Kingdom of Anduril flourished. Verdant valleys graced its land, majestic mountains shielded',
      ],
      'tell me a long long story': [
        'Of course! Here’s a detailed and engaging story:\n\n---\n\n### The Legend of Eldoria\n\nNestled between towering mountains and dense, ancient forests was the enigmatic kingdom of Eldoria. This realm, cloaked in perpetual twilight, was the stuff of legends. It was said that the land was blessed by the gods and guarded by mythical creatures. Eldoria was a place where magic and realism intertwined seamlessly, creating a land of beauty, wonder, and peril.\n\nIn the heart of this kingdom lay the grand city of Lumina, known',
      ],
    };

    Object.keys(relationships).forEach((parentText) => {
      const parentMessage = textToMessageMap.get(parentText);
      const childrenTexts = relationships[parentText];

      childrenTexts.forEach((childText) => {
        const childMessage = textToMessageMap.get(childText);
        expect(childMessage.parentMessageId).toBe(parentMessage.messageId);
      });
    });

    expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
  });

  it('should retain properties from the original conversation as well as new settings', async () => {
    mockedCacheGet.mockResolvedValue({
      [EModelEndpoint.azureOpenAI]: {},
    });
    const requestUserId = 'user-123';
    const importBatchBuilder = new ImportBatchBuilder(requestUserId);
    jest.spyOn(importBatchBuilder, 'finishConversation');

    const importer = getImporter(jsonDataNonRecursiveBranches);
    await importer(jsonDataNonRecursiveBranches, requestUserId, () => importBatchBuilder);

    expect(importBatchBuilder.finishConversation).toHaveBeenCalledTimes(1);

    const [_title, createdAt, originalConvo] = importBatchBuilder.finishConversation.mock.calls[0];
    const convo = importBatchBuilder.conversations[0];

    expect(convo).toEqual({
      ...jsonDataNonRecursiveBranches.options,
      user: requestUserId,
      conversationId: importBatchBuilder.conversationId,
      title: originalConvo.title || 'Imported Chat',
      createdAt: createdAt,
      updatedAt: createdAt,
      overrideTimestamp: true,
      endpoint: importBatchBuilder.endpoint,
      model: originalConvo.model || openAISettings.model.default,
    });

    expect(convo.title).toBe('Original');
    expect(convo.createdAt).toBeInstanceOf(Date);
    expect(convo.endpoint).toBe(EModelEndpoint.azureOpenAI);
    expect(convo.model).toBe('gpt-4o');
  });

  describe('finishConversation', () => {
    it('should retain properties from the original conversation as well as update with new settings', () => {
      const requestUserId = 'user-123';
      const builder = new ImportBatchBuilder(requestUserId);
      builder.conversationId = 'conv-id-123';
      builder.messages = [{ text: 'Hello, world!' }];

      const originalConvo = {
        _id: 'old-convo-id',
        model: 'custom-model',
      };

      builder.endpoint = 'test-endpoint';

      const title = 'New Chat Title';
      const createdAt = new Date('2023-10-01T00:00:00Z');

      const result = builder.finishConversation(title, createdAt, originalConvo);

      expect(result).toEqual({
        conversation: {
          user: requestUserId,
          conversationId: builder.conversationId,
          title: 'New Chat Title',
          createdAt: createdAt,
          updatedAt: createdAt,
          overrideTimestamp: true,
          endpoint: 'test-endpoint',
          model: 'custom-model',
        },
        messages: builder.messages,
      });

      expect(builder.conversations).toContainEqual({
        user: requestUserId,
        conversationId: builder.conversationId,
        title: 'New Chat Title',
        createdAt: createdAt,
        updatedAt: createdAt,
        overrideTimestamp: true,
        endpoint: 'test-endpoint',
        model: 'custom-model',
      });
    });

    it('should use default values if not provided in the original conversation or as parameters', () => {
      const requestUserId = 'user-123';
      const builder = new ImportBatchBuilder(requestUserId);
      builder.conversationId = 'conv-id-123';
      builder.messages = [{ text: 'Hello, world!' }];
      builder.endpoint = 'test-endpoint';
      const result = builder.finishConversation();
      expect(result.conversation.title).toBe('Imported Chat');
      expect(result.conversation.model).toBe(openAISettings.model.default);
    });
  });
});

describe('importChatBotUiConvo', () => {
  it('should import custom conversation correctly', async () => {
    // Given
    const jsonData = JSON.parse(
      fs.readFileSync(path.join(__dirname, '__data__', 'chatbotui-export.json'), 'utf8'),
    );
    const requestUserId = 'custom-user-456';
    const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
    const importBatchBuilder = new ImportBatchBuilder(requestUserId);

    // Spy on instance methods
    jest.spyOn(importBatchBuilder, 'startConversation');
    jest.spyOn(importBatchBuilder, 'saveMessage');
    jest.spyOn(importBatchBuilder, 'addUserMessage');
    jest.spyOn(importBatchBuilder, 'addGptMessage');
    jest.spyOn(importBatchBuilder, 'finishConversation');
    jest.spyOn(importBatchBuilder, 'saveBatch');

    // When
    const importer = getImporter(jsonData);
    await importer(jsonData, requestUserId, mockedBuilderFactory);
    await importer(jsonData, requestUserId, () => importBatchBuilder);

    // Then
    const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;
    expect(mockImportBatchBuilder.startConversation).toHaveBeenCalledWith('openAI');

    // User messages
    expect(mockImportBatchBuilder.addUserMessage).toHaveBeenCalledTimes(3);
    expect(mockImportBatchBuilder.addUserMessage).toHaveBeenNthCalledWith(
    expect(importBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
    expect(importBatchBuilder.addUserMessage).toHaveBeenCalledTimes(3);
    expect(importBatchBuilder.addUserMessage).toHaveBeenNthCalledWith(
      1,
      'Hello what are you able to do?',
    );
    expect(mockImportBatchBuilder.addUserMessage).toHaveBeenNthCalledWith(
    expect(importBatchBuilder.addUserMessage).toHaveBeenNthCalledWith(
      3,
      'Give me the code that inverts binary tree in COBOL',
    );

    // GPT messages
    expect(mockImportBatchBuilder.addGptMessage).toHaveBeenCalledTimes(3);
    expect(mockImportBatchBuilder.addGptMessage).toHaveBeenNthCalledWith(
    expect(importBatchBuilder.addGptMessage).toHaveBeenCalledTimes(3);
    expect(importBatchBuilder.addGptMessage).toHaveBeenNthCalledWith(
      1,
      expect.stringMatching(/^Hello! As an AI developed by OpenAI/),
      'gpt-4-1106-preview',
    );
    expect(mockImportBatchBuilder.addGptMessage).toHaveBeenNthCalledWith(
    expect(importBatchBuilder.addGptMessage).toHaveBeenNthCalledWith(
      3,
      expect.stringContaining('```cobol'),
      'gpt-3.5-turbo',
    );

    expect(mockImportBatchBuilder.finishConversation).toHaveBeenCalledTimes(2);
    expect(mockImportBatchBuilder.finishConversation).toHaveBeenNthCalledWith(
    expect(importBatchBuilder.finishConversation).toHaveBeenCalledTimes(2);
    expect(importBatchBuilder.finishConversation).toHaveBeenNthCalledWith(
      1,
      'Hello what are you able to do?',
      expect.any(Date),
    );
    expect(mockImportBatchBuilder.finishConversation).toHaveBeenNthCalledWith(
    expect(importBatchBuilder.finishConversation).toHaveBeenNthCalledWith(
      2,
      'Give me the code that inverts ...',
      expect.any(Date),
    );

    expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
    expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
  });
});

describe('getImporter', () => {
  it('should throw an error if the import type is not supported', () => {
    // Given
    const jsonData = { unsupported: 'data' };

    // When
    expect(() => getImporter(jsonData)).toThrow('Unsupported import type');
  });
});
Some files were not shown because too many files have changed in this diff.