Compare commits
31 Commits
update-tit
...
bottleneck
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b76233bd66 | ||
|
|
b8e35002f4 | ||
|
|
8318f26d66 | ||
|
|
08d6bea359 | ||
|
|
a6058c5669 | ||
|
|
e0402b71f0 | ||
|
|
a618266905 | ||
|
|
d5a7806e32 | ||
|
|
e2cb2905e7 | ||
|
|
3f600f0d3f | ||
|
|
c9e7d4ac18 | ||
|
|
40685f6eb4 | ||
|
|
0ee060d730 | ||
|
|
5dc5d875ba | ||
|
|
9f2538fcd9 | ||
|
|
2b7a973a33 | ||
|
|
c704a23749 | ||
|
|
eb5733083e | ||
|
|
b80f38e49e | ||
|
|
4369e75ca7 | ||
|
|
35ba4ba1a4 | ||
|
|
dcd2e3e62d | ||
|
|
514a502b9c | ||
|
|
8e66683577 | ||
|
|
dc1778b11f | ||
|
|
795bb9c568 | ||
|
|
a937650df6 | ||
|
|
6cf1c85363 | ||
|
|
b3e03b75d0 | ||
|
|
9d8fd92dd3 | ||
|
|
f00a8f87f7 |
18
.env.example
18
.env.example
@@ -119,7 +119,7 @@ GOOGLE_KEY=user_provided
|
||||
# GOOGLE_MODELS=gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
|
||||
|
||||
# Vertex AI
|
||||
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0409,gemini-1.0-pro-vision-001,gemini-pro,gemini-pro-vision,chat-bison,chat-bison-32k,codechat-bison,codechat-bison-32k,text-bison,text-bison-32k,text-unicorn,code-gecko,code-bison,code-bison-32k
|
||||
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
|
||||
|
||||
# Google Gemini Safety Settings
|
||||
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
|
||||
@@ -257,6 +257,14 @@ MEILI_NO_ANALYTICS=true
|
||||
MEILI_HOST=http://0.0.0.0:7700
|
||||
MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt
|
||||
|
||||
|
||||
#==================================================#
|
||||
# Speech to Text & Text to Speech #
|
||||
#==================================================#
|
||||
|
||||
STT_API_KEY=
|
||||
TTS_API_KEY=
|
||||
|
||||
#===================================================#
|
||||
# User System #
|
||||
#===================================================#
|
||||
@@ -352,6 +360,14 @@ OPENID_REQUIRED_ROLE_PARAMETER_PATH=
|
||||
OPENID_BUTTON_LABEL=
|
||||
OPENID_IMAGE_URL=
|
||||
|
||||
# LDAP
|
||||
LDAP_URL=
|
||||
LDAP_BIND_DN=
|
||||
LDAP_BIND_CREDENTIALS=
|
||||
LDAP_USER_SEARCH_BASE=
|
||||
LDAP_SEARCH_FILTER=mail={{username}}
|
||||
LDAP_CA_CERT_PATH=
|
||||
|
||||
#========================#
|
||||
# Email Password Reset #
|
||||
#========================#
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const Anthropic = require('@anthropic-ai/sdk');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const {
|
||||
getResponseSender,
|
||||
@@ -123,9 +124,14 @@ class AnthropicClient extends BaseClient {
|
||||
getClient() {
|
||||
/** @type {Anthropic.default.RequestOptions} */
|
||||
const options = {
|
||||
fetch: this.fetch,
|
||||
apiKey: this.apiKey,
|
||||
};
|
||||
|
||||
if (this.options.proxy) {
|
||||
options.httpAgent = new HttpsProxyAgent(this.options.proxy);
|
||||
}
|
||||
|
||||
if (this.options.reverseProxyUrl) {
|
||||
options.baseURL = this.options.reverseProxyUrl;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const crypto = require('crypto');
|
||||
const fetch = require('node-fetch');
|
||||
const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
|
||||
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
|
||||
@@ -17,6 +18,7 @@ class BaseClient {
|
||||
month: 'long',
|
||||
day: 'numeric',
|
||||
});
|
||||
this.fetch = this.fetch.bind(this);
|
||||
}
|
||||
|
||||
setOptions() {
|
||||
@@ -54,6 +56,22 @@ class BaseClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Makes an HTTP request and logs the process.
|
||||
*
|
||||
* @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object.
|
||||
* @param {RequestInit} [init] - Optional init options for the request.
|
||||
* @returns {Promise<Response>} - A promise that resolves to the response of the fetch request.
|
||||
*/
|
||||
async fetch(_url, init) {
|
||||
let url = _url;
|
||||
if (this.options.directEndpoint) {
|
||||
url = this.options.reverseProxyUrl;
|
||||
}
|
||||
logger.debug(`Making request to ${url}`);
|
||||
return await fetch(url, init);
|
||||
}
|
||||
|
||||
getBuildMessagesOptions() {
|
||||
throw new Error('Subclasses must implement getBuildMessagesOptions');
|
||||
}
|
||||
@@ -373,6 +391,14 @@ class BaseClient {
|
||||
const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } =
|
||||
await this.handleStartMethods(message, opts);
|
||||
|
||||
if (opts.progressCallback) {
|
||||
opts.onProgress = opts.progressCallback.call(null, {
|
||||
...(opts.progressOptions ?? {}),
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
}
|
||||
|
||||
const { generation = '' } = opts;
|
||||
|
||||
// It's not necessary to push to currentMessages
|
||||
|
||||
@@ -438,9 +438,17 @@ class ChatGPTClient extends BaseClient {
|
||||
|
||||
if (message.eventType === 'text-generation' && message.text) {
|
||||
onTokenProgress(message.text);
|
||||
} else if (message.eventType === 'stream-end' && message.response) {
|
||||
reply += message.text;
|
||||
}
|
||||
/*
|
||||
Cohere API Chinese Unicode character replacement hotfix.
|
||||
Should be un-commented when the following issue is resolved:
|
||||
https://github.com/cohere-ai/cohere-typescript/issues/151
|
||||
|
||||
else if (message.eventType === 'stream-end' && message.response) {
|
||||
reply = message.response.text;
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
return reply;
|
||||
|
||||
@@ -27,6 +27,7 @@ const {
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { updateTokenWebsocket } = require('~/server/services/Files/Audio');
|
||||
const { isEnabled, sleep } = require('~/server/utils');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
const spendTokens = require('~/models/spendTokens');
|
||||
@@ -588,12 +589,13 @@ class OpenAIClient extends BaseClient {
|
||||
let streamResult = null;
|
||||
this.modelOptions.user = this.user;
|
||||
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
|
||||
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion || typeof Bun !== 'undefined');
|
||||
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion);
|
||||
if (typeof opts.onProgress === 'function' && useOldMethod) {
|
||||
const completionResult = await this.getCompletion(
|
||||
payload,
|
||||
(progressMessage) => {
|
||||
if (progressMessage === '[DONE]') {
|
||||
updateTokenWebsocket('[DONE]');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -827,7 +829,7 @@ class OpenAIClient extends BaseClient {
|
||||
|
||||
const instructionsPayload = [
|
||||
{
|
||||
role: 'system',
|
||||
role: this.options.titleMessageRole ?? 'system',
|
||||
content: `Please generate ${titleInstruction}
|
||||
|
||||
${convo}
|
||||
@@ -1106,7 +1108,12 @@ ${convo}
|
||||
}
|
||||
|
||||
if (this.azure || this.options.azure) {
|
||||
// Azure does not accept `model` in the body, so we need to remove it.
|
||||
/* Azure Bug, extremely short default `max_tokens` response */
|
||||
if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') {
|
||||
modelOptions.max_tokens = 4000;
|
||||
}
|
||||
|
||||
/* Azure does not accept `model` in the body, so we need to remove it. */
|
||||
delete modelOptions.model;
|
||||
|
||||
opts.baseURL = this.langchainProxy
|
||||
@@ -1127,6 +1134,7 @@ ${convo}
|
||||
let chatCompletion;
|
||||
/** @type {OpenAI} */
|
||||
const openai = new OpenAI({
|
||||
fetch: this.fetch,
|
||||
apiKey: this.apiKey,
|
||||
...opts,
|
||||
});
|
||||
@@ -1216,6 +1224,7 @@ ${convo}
|
||||
});
|
||||
|
||||
const azureDelay = this.modelOptions.model?.includes('gpt-4') ? 30 : 17;
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const token = chunk.choices[0]?.delta?.content || '';
|
||||
intermediateReply += token;
|
||||
|
||||
@@ -250,6 +250,7 @@ class PluginsClient extends OpenAIClient {
|
||||
this.setOptions(opts);
|
||||
return super.sendMessage(message, opts);
|
||||
}
|
||||
|
||||
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
|
||||
const {
|
||||
user,
|
||||
@@ -264,6 +265,14 @@ class PluginsClient extends OpenAIClient {
|
||||
onToolEnd,
|
||||
} = await this.handleStartMethods(message, opts);
|
||||
|
||||
if (opts.progressCallback) {
|
||||
opts.onProgress = opts.progressCallback.call(null, {
|
||||
...(opts.progressOptions ?? {}),
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
}
|
||||
|
||||
this.currentMessages.push(userMessage);
|
||||
|
||||
let {
|
||||
|
||||
@@ -28,7 +28,7 @@ ${convo}`,
|
||||
};
|
||||
|
||||
const titleInstruction =
|
||||
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. For English, use AP Stylebook Title Case. Never directly mention the language name or the word "title"';
|
||||
'a concise, 5-word-or-less title for the conversation, using its same language, with no punctuation. Apply title case conventions appropriate for the language. Never directly mention the language name or the word "title"';
|
||||
const titleFunctionPrompt = `In this environment you have access to a set of tools you can use to generate the conversation title.
|
||||
|
||||
You may call them like this:
|
||||
|
||||
@@ -80,13 +80,18 @@ class StableDiffusionAPI extends StructuredTool {
|
||||
const payload = {
|
||||
prompt,
|
||||
negative_prompt,
|
||||
sampler_index: 'DPM++ 2M Karras',
|
||||
cfg_scale: 4.5,
|
||||
steps: 22,
|
||||
width: 1024,
|
||||
height: 1024,
|
||||
};
|
||||
const generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
|
||||
let generationResponse;
|
||||
try {
|
||||
generationResponse = await axios.post(`${url}/sdapi/v1/txt2img`, payload);
|
||||
} catch (error) {
|
||||
logger.error('[StableDiffusion] Error while generating image:', error);
|
||||
return 'Error making API request.';
|
||||
}
|
||||
const image = generationResponse.data.images[0];
|
||||
|
||||
/** @type {{ height: number, width: number, seed: number, infotexts: string[] }} */
|
||||
|
||||
8
api/cache/getLogStores.js
vendored
8
api/cache/getLogStores.js
vendored
@@ -7,6 +7,7 @@ const keyvMongo = require('./keyvMongo');
|
||||
|
||||
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
|
||||
const THIRTY_MINUTES = 1800000;
|
||||
const TEN_MINUTES = 600000;
|
||||
|
||||
const duration = math(BAN_DURATION, 7200000);
|
||||
|
||||
@@ -24,6 +25,10 @@ const config = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
|
||||
|
||||
const audioRuns = isEnabled(USE_REDIS) // ttl: 30 minutes
|
||||
? new Keyv({ store: keyvRedis, ttl: TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: TEN_MINUTES });
|
||||
|
||||
const tokenConfig = isEnabled(USE_REDIS) // ttl: 30 minutes
|
||||
? new Keyv({ store: keyvRedis, ttl: THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: THIRTY_MINUTES });
|
||||
@@ -55,6 +60,8 @@ const namespaces = {
|
||||
message_limit: createViolationInstance('message_limit'),
|
||||
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
|
||||
registrations: createViolationInstance('registrations'),
|
||||
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
|
||||
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
|
||||
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
|
||||
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
|
||||
ViolationTypes.ILLEGAL_MODEL_REQUEST,
|
||||
@@ -64,6 +71,7 @@ const namespaces = {
|
||||
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
|
||||
[CacheKeys.GEN_TITLE]: genTitle,
|
||||
[CacheKeys.MODEL_QUERIES]: modelQueries,
|
||||
[CacheKeys.AUDIO_RUNS]: audioRuns,
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -27,26 +27,25 @@ function getMatchingSensitivePatterns(valueStr) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Redacts sensitive information from a console message.
|
||||
*
|
||||
* Redacts sensitive information from a console message and trims it to a specified length if provided.
|
||||
* @param {string} str - The console message to be redacted.
|
||||
* @returns {string} - The redacted console message.
|
||||
* @param {number} [trimLength] - The optional length at which to trim the redacted message.
|
||||
* @returns {string} - The redacted and optionally trimmed console message.
|
||||
*/
|
||||
function redactMessage(str) {
|
||||
function redactMessage(str, trimLength) {
|
||||
if (!str) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const patterns = getMatchingSensitivePatterns(str);
|
||||
|
||||
if (patterns.length === 0) {
|
||||
return str;
|
||||
}
|
||||
|
||||
patterns.forEach((pattern) => {
|
||||
str = str.replace(pattern, '$1[REDACTED]');
|
||||
});
|
||||
|
||||
if (trimLength !== undefined && str.length > trimLength) {
|
||||
return `${str.substring(0, trimLength)}...`;
|
||||
}
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ const Assistant = mongoose.model('assistant', assistantSchema);
|
||||
* @param {mongoose.ClientSession} [session] - The transaction session to use (optional).
|
||||
* @returns {Promise<Object>} The updated or newly created assistant document as a plain object.
|
||||
*/
|
||||
const updateAssistant = async (searchParams, updateData, session = null) => {
|
||||
const updateAssistantDoc = async (searchParams, updateData, session = null) => {
|
||||
const options = { new: true, upsert: true, session };
|
||||
return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean();
|
||||
};
|
||||
@@ -52,7 +52,7 @@ const deleteAssistant = async (searchParams) => {
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
updateAssistant,
|
||||
updateAssistantDoc,
|
||||
deleteAssistant,
|
||||
getAssistants,
|
||||
getAssistant,
|
||||
|
||||
@@ -21,7 +21,7 @@ module.exports = {
|
||||
Conversation,
|
||||
saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
|
||||
try {
|
||||
const messages = await getMessages({ conversationId });
|
||||
const messages = await getMessages({ conversationId }, '_id');
|
||||
const update = { ...convo, messages, user };
|
||||
if (newConversationId) {
|
||||
update.conversationId = newConversationId;
|
||||
|
||||
@@ -129,6 +129,14 @@ module.exports = {
|
||||
throw new Error('Failed to save message.');
|
||||
}
|
||||
},
|
||||
async updateMessageText({ messageId, text }) {
|
||||
try {
|
||||
await Message.updateOne({ messageId }, { text });
|
||||
} catch (err) {
|
||||
logger.error('Error updating message text:', err);
|
||||
throw new Error('Failed to update message text.');
|
||||
}
|
||||
},
|
||||
async updateMessage(message) {
|
||||
try {
|
||||
const { messageId, ...update } = message;
|
||||
@@ -171,8 +179,18 @@ module.exports = {
|
||||
}
|
||||
},
|
||||
|
||||
async getMessages(filter) {
|
||||
/**
|
||||
* Retrieves messages from the database.
|
||||
* @param {Record<string, unknown>} filter
|
||||
* @param {string | undefined} [select]
|
||||
* @returns
|
||||
*/
|
||||
async getMessages(filter, select) {
|
||||
try {
|
||||
if (select) {
|
||||
return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean();
|
||||
}
|
||||
|
||||
return await Message.find(filter).sort({ createdAt: 1 }).lean();
|
||||
} catch (err) {
|
||||
logger.error('Error getting messages:', err);
|
||||
|
||||
@@ -64,6 +64,11 @@ const userSchema = mongoose.Schema(
|
||||
unique: true,
|
||||
sparse: true,
|
||||
},
|
||||
ldapId: {
|
||||
type: String,
|
||||
unique: true,
|
||||
sparse: true,
|
||||
},
|
||||
githubId: {
|
||||
type: String,
|
||||
unique: true,
|
||||
|
||||
@@ -40,10 +40,11 @@
|
||||
"@keyv/redis": "^2.8.1",
|
||||
"@langchain/community": "^0.0.46",
|
||||
"@langchain/google-genai": "^0.0.11",
|
||||
"@langchain/google-vertexai": "^0.0.5",
|
||||
"@langchain/google-vertexai": "^0.0.17",
|
||||
"agenda": "^5.0.0",
|
||||
"axios": "^1.3.4",
|
||||
"bcryptjs": "^2.4.3",
|
||||
"bottleneck": "^2.19.5",
|
||||
"cheerio": "^1.0.0-rc.12",
|
||||
"cohere-ai": "^7.9.1",
|
||||
"connect-redis": "^7.1.0",
|
||||
@@ -86,6 +87,7 @@
|
||||
"passport-github2": "^0.1.12",
|
||||
"passport-google-oauth20": "^2.0.0",
|
||||
"passport-jwt": "^4.0.1",
|
||||
"passport-ldapauth": "^3.0.1",
|
||||
"passport-local": "^1.0.0",
|
||||
"pino": "^8.12.1",
|
||||
"sharp": "^0.32.6",
|
||||
@@ -94,6 +96,7 @@
|
||||
"ua-parser-js": "^1.0.36",
|
||||
"winston": "^3.11.0",
|
||||
"winston-daily-rotate-file": "^4.7.1",
|
||||
"ws": "^8.17.0",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -105,11 +105,12 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
getReqData,
|
||||
onStart,
|
||||
abortController,
|
||||
onProgress: progressCallback.call(null, {
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
}),
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
};
|
||||
|
||||
let response = await client.sendMessage(text, messageOptions);
|
||||
|
||||
@@ -112,11 +112,12 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
getReqData,
|
||||
onStart,
|
||||
abortController,
|
||||
onProgress: progressCallback.call(null, {
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
}),
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
});
|
||||
|
||||
const conversation = await getConvo(user, conversationId);
|
||||
|
||||
@@ -20,6 +20,7 @@ const {
|
||||
} = require('~/server/services/Threads');
|
||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
@@ -31,15 +32,14 @@ const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { handleAbortError } = require('~/server/middleware');
|
||||
|
||||
const ten_minutes = 1000 * 60 * 10;
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
* @desc Chat with an assistant
|
||||
* @access Public
|
||||
* @param {Express.Request} req - The request object, containing the request data.
|
||||
* @param {object} req - The request object, containing the request data.
|
||||
* @param {object} req.body - The request payload.
|
||||
* @param {Express.Response} res - The response object, used to send back a response.
|
||||
* @returns {void}
|
||||
*/
|
||||
@@ -60,30 +60,6 @@ const chatV1 = async (req, res) => {
|
||||
parentMessageId: _parentId = Constants.NO_PARENT,
|
||||
} = req.body;
|
||||
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals?.[endpoint];
|
||||
|
||||
if (assistantsConfig) {
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
const error = { message: 'Assistant not supported' };
|
||||
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
error,
|
||||
});
|
||||
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {OpenAIClient} */
|
||||
let openai;
|
||||
/** @type {string|undefined} - the current thread id */
|
||||
@@ -311,6 +287,7 @@ const chatV1 = async (req, res) => {
|
||||
});
|
||||
|
||||
openai = _openai;
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
if (previousMessages.length) {
|
||||
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
||||
|
||||
@@ -20,6 +20,7 @@ const {
|
||||
} = require('~/server/services/Threads');
|
||||
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
|
||||
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { createRun, StreamRunManager } = require('~/server/services/Runs');
|
||||
const { addTitle } = require('~/server/services/Endpoints/assistants');
|
||||
const { getTransactions } = require('~/models/Transaction');
|
||||
@@ -30,8 +31,6 @@ const { getModelMaxTokens } = require('~/utils');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { handleAbortError } = require('~/server/middleware');
|
||||
|
||||
const ten_minutes = 1000 * 60 * 10;
|
||||
|
||||
/**
|
||||
@@ -60,30 +59,6 @@ const chatV2 = async (req, res) => {
|
||||
parentMessageId: _parentId = Constants.NO_PARENT,
|
||||
} = req.body;
|
||||
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals?.[endpoint];
|
||||
|
||||
if (assistantsConfig) {
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
const error = { message: 'Assistant not supported' };
|
||||
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
error,
|
||||
});
|
||||
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId: convoId,
|
||||
messageId: v4(),
|
||||
parentMessageId: _messageId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/** @type {OpenAIClient} */
|
||||
let openai;
|
||||
/** @type {string|undefined} - the current thread id */
|
||||
@@ -309,6 +284,7 @@ const chatV2 = async (req, res) => {
|
||||
});
|
||||
|
||||
openai = _openai;
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
if (previousMessages.length) {
|
||||
parentMessageId = previousMessages[previousMessages.length - 1].messageId;
|
||||
@@ -520,6 +496,7 @@ const chatV2 = async (req, res) => {
|
||||
handlers,
|
||||
thread_id,
|
||||
attachedFileIds,
|
||||
parentMessageId: userMessageId,
|
||||
responseMessage: openai.responseMessage,
|
||||
// streamOptions: {
|
||||
|
||||
@@ -532,6 +509,7 @@ const chatV2 = async (req, res) => {
|
||||
});
|
||||
|
||||
response = streamRunManager;
|
||||
response.text = streamRunManager.intermediateText;
|
||||
};
|
||||
|
||||
await processRun();
|
||||
@@ -554,6 +532,7 @@ const chatV2 = async (req, res) => {
|
||||
/** @type {ResponseMessage} */
|
||||
const responseMessage = {
|
||||
...(response.responseMessage ?? response.finalMessage),
|
||||
text: response.text,
|
||||
parentMessageId: userMessageId,
|
||||
conversationId,
|
||||
user: req.user.id,
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
const { EModelEndpoint, CacheKeys, defaultAssistantsVersion } = require('librechat-data-provider');
|
||||
const {
|
||||
EModelEndpoint,
|
||||
CacheKeys,
|
||||
defaultAssistantsVersion,
|
||||
defaultOrderQuery,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
initializeClient: initAzureClient,
|
||||
} = require('~/server/services/Endpoints/azureAssistants');
|
||||
@@ -35,6 +40,7 @@ const getCurrentVersion = async (req, endpoint) => {
|
||||
* Initializes the client with the current request and response objects and lists assistants
|
||||
* according to the query parameters. This function abstracts the logic for non-Azure paths.
|
||||
*
|
||||
* @deprecated
|
||||
* @async
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {object} params.req - The request object, used for initializing the client.
|
||||
@@ -43,11 +49,65 @@ const getCurrentVersion = async (req, endpoint) => {
|
||||
* @param {object} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||
*/
|
||||
const listAssistants = async ({ req, res, version, query }) => {
|
||||
const _listAssistants = async ({ req, res, version, query }) => {
|
||||
const { openai } = await getOpenAIClient({ req, res, version });
|
||||
return openai.beta.assistants.list(query);
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches all assistants based on provided query params, until `has_more` is `false`.
|
||||
*
|
||||
* @async
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {object} params.req - The request object, used for initializing the client.
|
||||
* @param {object} params.res - The response object, used for initializing the client.
|
||||
* @param {string} params.version - The API version to use.
|
||||
* @param {Omit<AssistantListParams, 'endpoint'>} params.query - The query parameters to list assistants (e.g., limit, order).
|
||||
* @returns {Promise<object>} A promise that resolves to the response from the `openai.beta.assistants.list` method call.
|
||||
*/
|
||||
const listAllAssistants = async ({ req, res, version, query }) => {
|
||||
/** @type {{ openai: OpenAIClient }} */
|
||||
const { openai } = await getOpenAIClient({ req, res, version });
|
||||
const allAssistants = [];
|
||||
|
||||
let first_id;
|
||||
let last_id;
|
||||
let afterToken = query.after;
|
||||
let hasMore = true;
|
||||
|
||||
while (hasMore) {
|
||||
const response = await openai.beta.assistants.list({
|
||||
...query,
|
||||
after: afterToken,
|
||||
});
|
||||
|
||||
const { body } = response;
|
||||
|
||||
allAssistants.push(...body.data);
|
||||
hasMore = body.has_more;
|
||||
|
||||
if (!first_id) {
|
||||
first_id = body.first_id;
|
||||
}
|
||||
|
||||
if (hasMore) {
|
||||
afterToken = body.last_id;
|
||||
} else {
|
||||
last_id = body.last_id;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
data: allAssistants,
|
||||
body: {
|
||||
data: allAssistants,
|
||||
has_more: false,
|
||||
first_id,
|
||||
last_id,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Asynchronously lists assistants for Azure configured groups.
|
||||
*
|
||||
@@ -82,7 +142,7 @@ const listAssistantsForAzure = async ({ req, res, version, azureConfig = {}, que
|
||||
/* The specified model is only necessary to
|
||||
fetch assistants for the shared instance */
|
||||
req.body.model = currentModelTuples[0][0];
|
||||
promises.push(listAssistants({ req, res, version, query }));
|
||||
promises.push(listAllAssistants({ req, res, version, query }));
|
||||
}
|
||||
|
||||
const resolvedQueries = await Promise.all(promises);
|
||||
@@ -133,8 +193,27 @@ async function getOpenAIClient({ req, res, endpointOption, initAppClient, overri
|
||||
return result;
|
||||
}
|
||||
|
||||
const fetchAssistants = async (req, res) => {
|
||||
const { limit = 100, order = 'desc', after, before, endpoint } = req.query;
|
||||
/**
|
||||
* Returns a list of assistants.
|
||||
* @param {object} params
|
||||
* @param {object} params.req - Express Request
|
||||
* @param {AssistantListParams} [params.req.query] - The assistant list parameters for pagination and sorting.
|
||||
* @param {object} params.res - Express Response
|
||||
* @param {string} [params.overrideEndpoint] - The endpoint to override the request endpoint.
|
||||
* @returns {Promise<AssistantListResponse>} 200 - success response - application/json
|
||||
*/
|
||||
const fetchAssistants = async ({ req, res, overrideEndpoint }) => {
|
||||
const {
|
||||
limit = 100,
|
||||
order = 'desc',
|
||||
after,
|
||||
before,
|
||||
endpoint,
|
||||
} = req.query ?? {
|
||||
endpoint: overrideEndpoint,
|
||||
...defaultOrderQuery,
|
||||
};
|
||||
|
||||
const version = await getCurrentVersion(req, endpoint);
|
||||
const query = { limit, order, after, before };
|
||||
|
||||
@@ -142,15 +221,47 @@ const fetchAssistants = async (req, res) => {
|
||||
let body;
|
||||
|
||||
if (endpoint === EModelEndpoint.assistants) {
|
||||
({ body } = await listAssistants({ req, res, version, query }));
|
||||
({ body } = await listAllAssistants({ req, res, version, query }));
|
||||
} else if (endpoint === EModelEndpoint.azureAssistants) {
|
||||
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
|
||||
body = await listAssistantsForAzure({ req, res, version, azureConfig, query });
|
||||
}
|
||||
|
||||
if (req.user.role === 'ADMIN') {
|
||||
return body;
|
||||
} else if (!req.app.locals[endpoint]) {
|
||||
return body;
|
||||
}
|
||||
|
||||
body.data = filterAssistants({
|
||||
userId: req.user.id,
|
||||
assistants: body.data,
|
||||
assistantsConfig: req.app.locals[endpoint],
|
||||
});
|
||||
return body;
|
||||
};
|
||||
|
||||
/**
|
||||
* Filter assistants based on configuration.
|
||||
*
|
||||
* @param {object} params - The parameters object.
|
||||
* @param {string} params.userId - The user ID to filter private assistants.
|
||||
* @param {Assistant[]} params.assistants - The list of assistants to filter.
|
||||
* @param {Partial<TAssistantEndpoint>} params.assistantsConfig - The assistant configuration.
|
||||
* @returns {Assistant[]} - The filtered list of assistants.
|
||||
*/
|
||||
function filterAssistants({ assistants, userId, assistantsConfig }) {
|
||||
const { supportedIds, excludedIds, privateAssistants } = assistantsConfig;
|
||||
if (privateAssistants) {
|
||||
return assistants.filter((assistant) => userId === assistant.metadata?.author);
|
||||
} else if (supportedIds?.length) {
|
||||
return assistants.filter((assistant) => supportedIds.includes(assistant.id));
|
||||
} else if (excludedIds?.length) {
|
||||
return assistants.filter((assistant) => !excludedIds.includes(assistant.id));
|
||||
}
|
||||
return assistants;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getOpenAIClient,
|
||||
fetchAssistants,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
const { FileContext } = require('librechat-data-provider');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { deleteAssistantActions } = require('~/server/services/ActionService');
|
||||
const { updateAssistantDoc, getAssistants } = require('~/models/Assistant');
|
||||
const { uploadImageBuffer } = require('~/server/services/Files/process');
|
||||
const { updateAssistant, getAssistants } = require('~/models/Assistant');
|
||||
const { getOpenAIClient, fetchAssistants } = require('./helpers');
|
||||
const { deleteFileByFilter } = require('~/models/File');
|
||||
const { logger } = require('~/config');
|
||||
@@ -40,9 +41,11 @@ const createAssistant = async (req, res) => {
|
||||
};
|
||||
|
||||
const assistant = await openai.beta.assistants.create(assistantData);
|
||||
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
|
||||
if (azureModelIdentifier) {
|
||||
assistant.model = azureModelIdentifier;
|
||||
}
|
||||
await promise;
|
||||
logger.debug('/assistants/', assistant);
|
||||
res.status(201).json(assistant);
|
||||
} catch (error) {
|
||||
@@ -61,7 +64,6 @@ const retrieveAssistant = async (req, res) => {
|
||||
try {
|
||||
/* NOTE: not actually being used right now */
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
||||
res.json(assistant);
|
||||
@@ -83,6 +85,7 @@ const retrieveAssistant = async (req, res) => {
|
||||
const patchAssistant = async (req, res) => {
|
||||
try {
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const { endpoint: _e, ...updateData } = req.body;
|
||||
@@ -119,6 +122,7 @@ const patchAssistant = async (req, res) => {
|
||||
const deleteAssistant = async (req, res) => {
|
||||
try {
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const assistant_id = req.params.id;
|
||||
const deletionStatus = await openai.beta.assistants.del(assistant_id);
|
||||
@@ -141,19 +145,7 @@ const deleteAssistant = async (req, res) => {
|
||||
*/
|
||||
const listAssistants = async (req, res) => {
|
||||
try {
|
||||
const body = await fetchAssistants(req, res);
|
||||
|
||||
if (req.app.locals?.[req.query.endpoint]) {
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals[req.query.endpoint];
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
if (supportedIds?.length) {
|
||||
body.data = body.data.filter((assistant) => supportedIds.includes(assistant.id));
|
||||
} else if (excludedIds?.length) {
|
||||
body.data = body.data.filter((assistant) => !excludedIds.includes(assistant.id));
|
||||
}
|
||||
}
|
||||
|
||||
const body = await fetchAssistants({ req, res });
|
||||
res.json(body);
|
||||
} catch (error) {
|
||||
logger.error('[/assistants] Error listing assistants', error);
|
||||
@@ -195,6 +187,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
||||
|
||||
let { metadata: _metadata = '{}' } = req.body;
|
||||
const { openai } = await getOpenAIClient({ req, res });
|
||||
await validateAuthor({ req, openai });
|
||||
|
||||
const image = await uploadImageBuffer({
|
||||
req,
|
||||
@@ -229,7 +222,7 @@ const uploadAssistantAvatar = async (req, res) => {
|
||||
|
||||
const promises = [];
|
||||
promises.push(
|
||||
updateAssistant(
|
||||
updateAssistantDoc(
|
||||
{ assistant_id },
|
||||
{
|
||||
avatar: {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const { ToolCallTypes } = require('librechat-data-provider');
|
||||
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
|
||||
const { validateAndUpdateTool } = require('~/server/services/ActionService');
|
||||
const { updateAssistantDoc } = require('~/models/Assistant');
|
||||
const { getOpenAIClient } = require('./helpers');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
@@ -37,9 +39,11 @@ const createAssistant = async (req, res) => {
|
||||
};
|
||||
|
||||
const assistant = await openai.beta.assistants.create(assistantData);
|
||||
const promise = updateAssistantDoc({ assistant_id: assistant.id }, { user: req.user.id });
|
||||
if (azureModelIdentifier) {
|
||||
assistant.model = azureModelIdentifier;
|
||||
}
|
||||
await promise;
|
||||
logger.debug('/assistants/', assistant);
|
||||
res.status(201).json(assistant);
|
||||
} catch (error) {
|
||||
@@ -58,6 +62,7 @@ const createAssistant = async (req, res) => {
|
||||
* @returns {Promise<Assistant>} The updated assistant.
|
||||
*/
|
||||
const updateAssistant = async ({ req, openai, assistant_id, updateData }) => {
|
||||
await validateAuthor({ req, openai });
|
||||
const tools = [];
|
||||
|
||||
let hasFileSearch = false;
|
||||
|
||||
@@ -15,7 +15,7 @@ const AppService = require('./services/AppService');
|
||||
const noIndex = require('./middleware/noIndex');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const { ldapLogin } = require('~/strategies');
|
||||
const routes = require('./routes');
|
||||
|
||||
const { PORT, HOST, ALLOW_SOCIAL_LOGIN } = process.env ?? {};
|
||||
@@ -60,6 +60,11 @@ const startServer = async () => {
|
||||
passport.use(await jwtLogin());
|
||||
passport.use(passportLogin());
|
||||
|
||||
// LDAP Auth
|
||||
if (process.env.LDAP_URL && process.env.LDAP_BIND_DN && process.env.LDAP_USER_SEARCH_BASE) {
|
||||
passport.use(ldapLogin);
|
||||
}
|
||||
|
||||
if (isEnabled(ALLOW_SOCIAL_LOGIN)) {
|
||||
configureSocialLogins(app);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/assistants');
|
||||
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
|
||||
const { deleteMessages } = require('~/models/Message');
|
||||
const { getConvo } = require('~/models/Conversation');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { sendMessage } = require('~/server/utils');
|
||||
@@ -66,13 +67,19 @@ async function abortRun(req, res) {
|
||||
logger.error('[abortRun] Error fetching or processing run', error);
|
||||
}
|
||||
|
||||
/* TODO: a reconciling strategy between the existing intermediate message would be more optimal than deleting it */
|
||||
await deleteMessages({
|
||||
user: req.user.id,
|
||||
unfinished: true,
|
||||
conversationId,
|
||||
});
|
||||
runMessages = await checkMessageGaps({
|
||||
openai,
|
||||
run_id,
|
||||
endpoint,
|
||||
thread_id,
|
||||
run_id,
|
||||
latestMessageId,
|
||||
conversationId,
|
||||
latestMessageId,
|
||||
});
|
||||
|
||||
const finalEvent = {
|
||||
|
||||
43
api/server/middleware/assistants/validate.js
Normal file
43
api/server/middleware/assistants/validate.js
Normal file
@@ -0,0 +1,43 @@
|
||||
const { v4 } = require('uuid');
|
||||
const { handleAbortError } = require('~/server/middleware/abortMiddleware');
|
||||
|
||||
/**
|
||||
* Checks if the assistant is supported or excluded
|
||||
* @param {object} req - Express Request
|
||||
* @param {object} req.body - The request payload.
|
||||
* @param {object} res - Express Response
|
||||
* @param {function} next - Express next middleware function.
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const validateAssistant = async (req, res, next) => {
|
||||
const { endpoint, conversationId, assistant_id, messageId } = req.body;
|
||||
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals?.[endpoint];
|
||||
if (!assistantsConfig) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const { supportedIds, excludedIds } = assistantsConfig;
|
||||
const error = { message: 'Assistant not supported' };
|
||||
if (supportedIds?.length && !supportedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId,
|
||||
messageId: v4(),
|
||||
parentMessageId: messageId,
|
||||
error,
|
||||
});
|
||||
} else if (excludedIds?.length && excludedIds.includes(assistant_id)) {
|
||||
return await handleAbortError(res, req, error, {
|
||||
sender: 'System',
|
||||
conversationId,
|
||||
messageId: v4(),
|
||||
parentMessageId: messageId,
|
||||
});
|
||||
}
|
||||
|
||||
return next();
|
||||
};
|
||||
|
||||
module.exports = validateAssistant;
|
||||
42
api/server/middleware/assistants/validateAuthor.js
Normal file
42
api/server/middleware/assistants/validateAuthor.js
Normal file
@@ -0,0 +1,42 @@
|
||||
const { getAssistant } = require('~/models/Assistant');
|
||||
|
||||
/**
|
||||
* Checks if the assistant is supported or excluded
|
||||
* @param {object} params
|
||||
* @param {object} params.req - Express Request
|
||||
* @param {object} params.req.body - The request payload.
|
||||
* @param {string} params.overrideEndpoint - The override endpoint
|
||||
* @param {string} params.overrideAssistantId - The override assistant ID
|
||||
* @param {OpenAIClient} params.openai - OpenAI API Client
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const validateAuthor = async ({ req, openai, overrideEndpoint, overrideAssistantId }) => {
|
||||
if (req.user.role === 'ADMIN') {
|
||||
return;
|
||||
}
|
||||
|
||||
const endpoint = overrideEndpoint ?? req.body.endpoint ?? req.query.endpoint;
|
||||
const assistant_id =
|
||||
overrideAssistantId ?? req.params.id ?? req.body.assistant_id ?? req.query.assistant_id;
|
||||
|
||||
/** @type {Partial<TAssistantEndpoint>} */
|
||||
const assistantsConfig = req.app.locals?.[endpoint];
|
||||
if (!assistantsConfig) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!assistantsConfig.privateAssistants) {
|
||||
return;
|
||||
}
|
||||
|
||||
const assistantDoc = await getAssistant({ assistant_id, user: req.user.id });
|
||||
if (assistantDoc) {
|
||||
return;
|
||||
}
|
||||
const assistant = await openai.beta.assistants.retrieve(assistant_id);
|
||||
if (req.user.id !== assistant?.metadata?.author) {
|
||||
throw new Error(`Assistant ${assistant_id} is not authored by the user.`);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = validateAuthor;
|
||||
@@ -2,14 +2,12 @@ const Keyv = require('keyv');
|
||||
const uap = require('ua-parser-js');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, removePorts } = require('../utils');
|
||||
const keyvRedis = require('~/cache/keyvRedis');
|
||||
const keyvMongo = require('~/cache/keyvMongo');
|
||||
const denyRequest = require('./denyRequest');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const User = require('~/models/User');
|
||||
|
||||
const banCache = isEnabled(process.env.USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: ViolationTypes.BAN, ttl: 0 });
|
||||
const banCache = new Keyv({ store: keyvMongo, namespace: ViolationTypes.BAN, ttl: 0 });
|
||||
const message = 'Your account has been temporarily banned due to violations of our service.';
|
||||
|
||||
/**
|
||||
|
||||
@@ -6,6 +6,7 @@ const setHeaders = require('./setHeaders');
|
||||
const loginLimiter = require('./loginLimiter');
|
||||
const validateModel = require('./validateModel');
|
||||
const requireJwtAuth = require('./requireJwtAuth');
|
||||
const requireLdapAuth = require('./requireLdapAuth');
|
||||
const uploadLimiters = require('./uploadLimiters');
|
||||
const registerLimiter = require('./registerLimiter');
|
||||
const messageLimiters = require('./messageLimiters');
|
||||
@@ -29,6 +30,7 @@ module.exports = {
|
||||
setHeaders,
|
||||
loginLimiter,
|
||||
requireJwtAuth,
|
||||
requireLdapAuth,
|
||||
registerLimiter,
|
||||
requireLocalAuth,
|
||||
validateEndpoint,
|
||||
|
||||
22
api/server/middleware/requireLdapAuth.js
Normal file
22
api/server/middleware/requireLdapAuth.js
Normal file
@@ -0,0 +1,22 @@
|
||||
const passport = require('passport');
|
||||
|
||||
const requireLdapAuth = (req, res, next) => {
|
||||
passport.authenticate('ldapauth', (err, user, info) => {
|
||||
if (err) {
|
||||
console.log({
|
||||
title: '(requireLdapAuth) Error at passport.authenticate',
|
||||
parameters: [{ name: 'error', value: err }],
|
||||
});
|
||||
return next(err);
|
||||
}
|
||||
if (!user) {
|
||||
console.log({
|
||||
title: '(requireLdapAuth) Error: No user',
|
||||
});
|
||||
return res.status(422).send(info);
|
||||
}
|
||||
req.user = user;
|
||||
next();
|
||||
})(req, res, next);
|
||||
};
|
||||
module.exports = requireLdapAuth;
|
||||
7
api/server/middleware/speech/index.js
Normal file
7
api/server/middleware/speech/index.js
Normal file
@@ -0,0 +1,7 @@
|
||||
const createTTSLimiters = require('./ttsLimiters');
|
||||
const createSTTLimiters = require('./sttLimiters');
|
||||
|
||||
module.exports = {
|
||||
createTTSLimiters,
|
||||
createSTTLimiters,
|
||||
};
|
||||
68
api/server/middleware/speech/sttLimiters.js
Normal file
68
api/server/middleware/speech/sttLimiters.js
Normal file
@@ -0,0 +1,68 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
|
||||
const STT_IP_WINDOW = parseInt(process.env.STT_IP_WINDOW) || 1;
|
||||
const STT_USER_MAX = parseInt(process.env.STT_USER_MAX) || 50;
|
||||
const STT_USER_WINDOW = parseInt(process.env.STT_USER_WINDOW) || 1;
|
||||
|
||||
const sttIpWindowMs = STT_IP_WINDOW * 60 * 1000;
|
||||
const sttIpMax = STT_IP_MAX;
|
||||
const sttIpWindowInMinutes = sttIpWindowMs / 60000;
|
||||
|
||||
const sttUserWindowMs = STT_USER_WINDOW * 60 * 1000;
|
||||
const sttUserMax = STT_USER_MAX;
|
||||
const sttUserWindowInMinutes = sttUserWindowMs / 60000;
|
||||
|
||||
return {
|
||||
sttIpWindowMs,
|
||||
sttIpMax,
|
||||
sttIpWindowInMinutes,
|
||||
sttUserWindowMs,
|
||||
sttUserMax,
|
||||
sttUserWindowInMinutes,
|
||||
};
|
||||
};
|
||||
|
||||
const createSTTHandler = (ip = true) => {
|
||||
const { sttIpMax, sttIpWindowInMinutes, sttUserMax, sttUserWindowInMinutes } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
const type = ViolationTypes.STT_LIMIT;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: ip ? sttIpMax : sttUserMax,
|
||||
limiter: ip ? 'ip' : 'user',
|
||||
windowInMinutes: ip ? sttIpWindowInMinutes : sttUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
res.status(429).json({ message: 'Too many STT requests. Try again later' });
|
||||
};
|
||||
};
|
||||
|
||||
const createSTTLimiters = () => {
|
||||
const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables();
|
||||
|
||||
const sttIpLimiter = rateLimit({
|
||||
windowMs: sttIpWindowMs,
|
||||
max: sttIpMax,
|
||||
handler: createSTTHandler(),
|
||||
});
|
||||
|
||||
const sttUserLimiter = rateLimit({
|
||||
windowMs: sttUserWindowMs,
|
||||
max: sttUserMax,
|
||||
handler: createSTTHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
});
|
||||
|
||||
return { sttIpLimiter, sttUserLimiter };
|
||||
};
|
||||
|
||||
module.exports = createSTTLimiters;
|
||||
68
api/server/middleware/speech/ttsLimiters.js
Normal file
68
api/server/middleware/speech/ttsLimiters.js
Normal file
@@ -0,0 +1,68 @@
|
||||
const rateLimit = require('express-rate-limit');
|
||||
const { ViolationTypes } = require('librechat-data-provider');
|
||||
const logViolation = require('~/cache/logViolation');
|
||||
|
||||
const getEnvironmentVariables = () => {
|
||||
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
|
||||
const TTS_IP_WINDOW = parseInt(process.env.TTS_IP_WINDOW) || 1;
|
||||
const TTS_USER_MAX = parseInt(process.env.TTS_USER_MAX) || 50;
|
||||
const TTS_USER_WINDOW = parseInt(process.env.TTS_USER_WINDOW) || 1;
|
||||
|
||||
const ttsIpWindowMs = TTS_IP_WINDOW * 60 * 1000;
|
||||
const ttsIpMax = TTS_IP_MAX;
|
||||
const ttsIpWindowInMinutes = ttsIpWindowMs / 60000;
|
||||
|
||||
const ttsUserWindowMs = TTS_USER_WINDOW * 60 * 1000;
|
||||
const ttsUserMax = TTS_USER_MAX;
|
||||
const ttsUserWindowInMinutes = ttsUserWindowMs / 60000;
|
||||
|
||||
return {
|
||||
ttsIpWindowMs,
|
||||
ttsIpMax,
|
||||
ttsIpWindowInMinutes,
|
||||
ttsUserWindowMs,
|
||||
ttsUserMax,
|
||||
ttsUserWindowInMinutes,
|
||||
};
|
||||
};
|
||||
|
||||
const createTTSHandler = (ip = true) => {
|
||||
const { ttsIpMax, ttsIpWindowInMinutes, ttsUserMax, ttsUserWindowInMinutes } =
|
||||
getEnvironmentVariables();
|
||||
|
||||
return async (req, res) => {
|
||||
const type = ViolationTypes.TTS_LIMIT;
|
||||
const errorMessage = {
|
||||
type,
|
||||
max: ip ? ttsIpMax : ttsUserMax,
|
||||
limiter: ip ? 'ip' : 'user',
|
||||
windowInMinutes: ip ? ttsIpWindowInMinutes : ttsUserWindowInMinutes,
|
||||
};
|
||||
|
||||
await logViolation(req, res, type, errorMessage);
|
||||
res.status(429).json({ message: 'Too many TTS requests. Try again later' });
|
||||
};
|
||||
};
|
||||
|
||||
const createTTSLimiters = () => {
|
||||
const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables();
|
||||
|
||||
const ttsIpLimiter = rateLimit({
|
||||
windowMs: ttsIpWindowMs,
|
||||
max: ttsIpMax,
|
||||
handler: createTTSHandler(),
|
||||
});
|
||||
|
||||
const ttsUserLimiter = rateLimit({
|
||||
windowMs: ttsUserWindowMs,
|
||||
max: ttsUserMax,
|
||||
handler: createTTSHandler(false),
|
||||
keyGenerator: function (req) {
|
||||
return req.user?.id; // Use the user ID or NULL if not available
|
||||
},
|
||||
});
|
||||
|
||||
return { ttsIpLimiter, ttsUserLimiter };
|
||||
};
|
||||
|
||||
module.exports = createTTSLimiters;
|
||||
@@ -25,6 +25,11 @@ afterEach(() => {
|
||||
delete process.env.DOMAIN_SERVER;
|
||||
delete process.env.ALLOW_REGISTRATION;
|
||||
delete process.env.ALLOW_SOCIAL_LOGIN;
|
||||
delete process.env.LDAP_URL;
|
||||
delete process.env.LDAP_BIND_DN;
|
||||
delete process.env.LDAP_BIND_CREDENTIALS;
|
||||
delete process.env.LDAP_USER_SEARCH_BASE;
|
||||
delete process.env.LDAP_SEARCH_FILTER;
|
||||
});
|
||||
|
||||
//TODO: This works/passes locally but http request tests fail with 404 in CI. Need to figure out why.
|
||||
@@ -50,6 +55,11 @@ describe.skip('GET /', () => {
|
||||
process.env.DOMAIN_SERVER = 'http://test-server.com';
|
||||
process.env.ALLOW_REGISTRATION = 'true';
|
||||
process.env.ALLOW_SOCIAL_LOGIN = 'true';
|
||||
process.env.LDAP_URL = 'Test LDAP URL';
|
||||
process.env.LDAP_BIND_DN = 'Test LDAP Bind DN';
|
||||
process.env.LDAP_BIND_CREDENTIALS = 'Test LDAP Bind Credentials';
|
||||
process.env.LDAP_USER_SEARCH_BASE = 'Test LDAP User Search Base';
|
||||
process.env.LDAP_SEARCH_FILTER = 'Test LDAP Search Filter';
|
||||
|
||||
const response = await request(app).get('/');
|
||||
|
||||
@@ -64,6 +74,7 @@ describe.skip('GET /', () => {
|
||||
openidLoginEnabled: true,
|
||||
openidLabel: 'Test OpenID',
|
||||
openidImageUrl: 'http://test-server.com',
|
||||
ldapLoginEnabled: true,
|
||||
serverDomain: 'http://test-server.com',
|
||||
emailLoginEnabled: 'true',
|
||||
registrationEnabled: 'true',
|
||||
|
||||
@@ -106,7 +106,11 @@ router.post(
|
||||
const pluginMap = new Map();
|
||||
const onAgentAction = async (action, runId) => {
|
||||
pluginMap.set(runId, action.tool);
|
||||
sendIntermediateMessage(res, { plugins });
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
const onToolStart = async (tool, input, runId, parentRunId) => {
|
||||
@@ -124,7 +128,11 @@ router.post(
|
||||
}
|
||||
const extraTokens = ':::plugin:::\n';
|
||||
plugins.push(latestPlugin);
|
||||
sendIntermediateMessage(res, { plugins }, extraTokens);
|
||||
sendIntermediateMessage(
|
||||
res,
|
||||
{ plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId },
|
||||
extraTokens,
|
||||
);
|
||||
};
|
||||
|
||||
const onToolEnd = async (output, runId) => {
|
||||
@@ -142,7 +150,11 @@ router.post(
|
||||
|
||||
const onChainEnd = () => {
|
||||
saveMessage({ ...userMessage, user });
|
||||
sendIntermediateMessage(res, { plugins });
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
const getAbortData = () => ({
|
||||
@@ -174,12 +186,13 @@ router.post(
|
||||
onStart,
|
||||
getPartialText,
|
||||
...endpointOption,
|
||||
onProgress: progressCallback.call(null, {
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
plugins,
|
||||
}),
|
||||
},
|
||||
abortController,
|
||||
});
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ const { encryptMetadata, domainParser } = require('~/server/services/ActionServi
|
||||
const { actionDelimiter, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { updateAction, getActions, deleteAction } = require('~/models/Action');
|
||||
const { updateAssistant, getAssistant } = require('~/models/Assistant');
|
||||
const { updateAssistantDoc, getAssistant } = require('~/models/Assistant');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
@@ -109,7 +109,7 @@ router.post('/:assistant_id', async (req, res) => {
|
||||
let updatedAssistant = await openai.beta.assistants.update(assistant_id, { tools });
|
||||
const promises = [];
|
||||
promises.push(
|
||||
updateAssistant(
|
||||
updateAssistantDoc(
|
||||
{ assistant_id },
|
||||
{
|
||||
actions,
|
||||
@@ -186,7 +186,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => {
|
||||
|
||||
const promises = [];
|
||||
promises.push(
|
||||
updateAssistant(
|
||||
updateAssistantDoc(
|
||||
{ assistant_id },
|
||||
{
|
||||
actions: updatedActions,
|
||||
|
||||
@@ -8,6 +8,7 @@ const {
|
||||
// validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
const validateAssistant = require('~/server/middleware/assistants/validate');
|
||||
const chatController = require('~/server/controllers/assistants/chatV1');
|
||||
|
||||
router.post('/abort', handleAbort());
|
||||
@@ -20,6 +21,6 @@ router.post('/abort', handleAbort());
|
||||
* @param {express.Response} res - The response object, used to send back a response.
|
||||
* @returns {void}
|
||||
*/
|
||||
router.post('/', validateModel, buildEndpointOption, setHeaders, chatController);
|
||||
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -8,6 +8,7 @@ const {
|
||||
// validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
const validateAssistant = require('~/server/middleware/assistants/validate');
|
||||
const chatController = require('~/server/controllers/assistants/chatV2');
|
||||
|
||||
router.post('/abort', handleAbort());
|
||||
@@ -20,6 +21,6 @@ router.post('/abort', handleAbort());
|
||||
* @param {express.Response} res - The response object, used to send back a response.
|
||||
* @returns {void}
|
||||
*/
|
||||
router.post('/', validateModel, buildEndpointOption, setHeaders, chatController);
|
||||
router.post('/', validateModel, buildEndpointOption, validateAssistant, setHeaders, chatController);
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -12,15 +12,24 @@ const {
|
||||
loginLimiter,
|
||||
registerLimiter,
|
||||
requireJwtAuth,
|
||||
requireLdapAuth,
|
||||
requireLocalAuth,
|
||||
validateRegistration,
|
||||
} = require('../middleware');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const ldapAuth =
|
||||
!!process.env.LDAP_URL && !!process.env.LDAP_BIND_DN && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
//Local
|
||||
router.post('/logout', requireJwtAuth, logoutController);
|
||||
router.post('/login', loginLimiter, checkBan, requireLocalAuth, loginController);
|
||||
router.post(
|
||||
'/login',
|
||||
loginLimiter,
|
||||
checkBan,
|
||||
ldapAuth ? requireLdapAuth : requireLocalAuth,
|
||||
loginController,
|
||||
);
|
||||
router.post('/refresh', refreshController);
|
||||
router.post('/register', registerLimiter, checkBan, validateRegistration, registrationController);
|
||||
router.post('/requestPasswordReset', resetPasswordRequestController);
|
||||
|
||||
@@ -13,6 +13,8 @@ router.get('/', async function (req, res) {
|
||||
return today.getMonth() === 1 && today.getDate() === 11;
|
||||
};
|
||||
|
||||
const ldapLoginEnabled =
|
||||
!!process.env.LDAP_URL && !!process.env.LDAP_BIND_DN && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
try {
|
||||
/** @type {TStartupConfig} */
|
||||
const payload = {
|
||||
@@ -30,9 +32,10 @@ router.get('/', async function (req, res) {
|
||||
!!process.env.OPENID_SESSION_SECRET,
|
||||
openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID',
|
||||
openidImageUrl: process.env.OPENID_IMAGE_URL,
|
||||
ldapLoginEnabled,
|
||||
serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080',
|
||||
emailLoginEnabled,
|
||||
registrationEnabled: isEnabled(process.env.ALLOW_REGISTRATION),
|
||||
registrationEnabled: !ldapLoginEnabled && isEnabled(process.env.ALLOW_REGISTRATION),
|
||||
socialLoginEnabled: isEnabled(process.env.ALLOW_SOCIAL_LOGIN),
|
||||
emailEnabled:
|
||||
(!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) &&
|
||||
|
||||
@@ -110,7 +110,11 @@ router.post(
|
||||
if (!start) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, { plugin });
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
// logger.debug('PLUGIN ACTION', formattedAction);
|
||||
};
|
||||
|
||||
@@ -119,7 +123,11 @@ router.post(
|
||||
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
|
||||
plugin.loading = false;
|
||||
saveMessage({ ...userMessage, user });
|
||||
sendIntermediateMessage(res, { plugin });
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
// logger.debug('CHAIN END', plugin.outputs);
|
||||
};
|
||||
|
||||
@@ -153,12 +161,13 @@ router.post(
|
||||
onChainEnd,
|
||||
onStart,
|
||||
...endpointOption,
|
||||
onProgress: progressCallback.call(null, {
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
text,
|
||||
plugin,
|
||||
parentMessageId: overrideParentMessageId || userMessageId,
|
||||
}),
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
},
|
||||
abortController,
|
||||
});
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
const express = require('express');
|
||||
const { uaParser, checkBan, requireJwtAuth, createFileLimiters } = require('~/server/middleware');
|
||||
const { createTTSLimiters, createSTTLimiters } = require('~/server/middleware/speech');
|
||||
const { createMulterInstance } = require('./multer');
|
||||
|
||||
const files = require('./files');
|
||||
const images = require('./images');
|
||||
const avatar = require('./avatar');
|
||||
const stt = require('./stt');
|
||||
const tts = require('./tts');
|
||||
|
||||
const initialize = async () => {
|
||||
const router = express.Router();
|
||||
@@ -12,6 +15,12 @@ const initialize = async () => {
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
||||
/* Important: stt/tts routes must be added before the upload limiters */
|
||||
const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();
|
||||
const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
|
||||
router.use('/stt', sttIpLimiter, sttUserLimiter, stt);
|
||||
router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);
|
||||
|
||||
const upload = await createMulterInstance();
|
||||
const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters();
|
||||
router.post('*', fileUploadIpLimiter, fileUploadUserLimiter);
|
||||
|
||||
13
api/server/routes/files/stt.js
Normal file
13
api/server/routes/files/stt.js
Normal file
@@ -0,0 +1,13 @@
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const multer = require('multer');
|
||||
const { requireJwtAuth } = require('~/server/middleware/');
|
||||
const { speechToText } = require('~/server/services/Files/Audio');
|
||||
|
||||
const upload = multer();
|
||||
|
||||
router.post('/', requireJwtAuth, upload.single('audio'), async (req, res) => {
|
||||
await speechToText(req, res);
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
61
api/server/routes/files/tts.js
Normal file
61
api/server/routes/files/tts.js
Normal file
@@ -0,0 +1,61 @@
|
||||
const multer = require('multer');
|
||||
const express = require('express');
|
||||
const Bottleneck = require('bottleneck');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { getVoices, streamAudio, textToSpeech } = require('~/server/services/Files/Audio');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
const upload = multer();
|
||||
|
||||
// todo: can add Redis support for limiter
|
||||
const limiter = new Bottleneck({
|
||||
minTime: 240, // Minimum time between requests (240ms per request = 250 requests per minute)
|
||||
maxConcurrent: 100, // Maximum number of concurrent requests
|
||||
reservoir: 250, // Initial number of available requests
|
||||
reservoirRefreshAmount: 250, // Number of requests replenished in each interval
|
||||
reservoirRefreshInterval: 60 * 1000, // Reservoir refresh interval (60 seconds)
|
||||
});
|
||||
|
||||
const limitedStreamAudio = limiter.wrap(streamAudio);
|
||||
const limitedTextToSpeech = limiter.wrap(textToSpeech);
|
||||
|
||||
router.post('/manual', upload.none(), async (req, res) => {
|
||||
try {
|
||||
await limitedTextToSpeech(req, res);
|
||||
} catch (error) {
|
||||
logger.error(`[textToSpeech] user: ${req.user.id} | Failed to process textToSpeech: ${error}`);
|
||||
res.status(500).json({ error: 'Failed to process textToSpeech' });
|
||||
}
|
||||
});
|
||||
|
||||
const logDebugMessage = (req, message) =>
|
||||
logger.debug(`[streamAudio] user: ${req?.user?.id ?? 'UNDEFINED_USER'} | ${message}`);
|
||||
|
||||
// TODO: test caching
|
||||
router.post('/', async (req, res) => {
|
||||
try {
|
||||
const audioRunsCache = getLogStores(CacheKeys.AUDIO_RUNS);
|
||||
const audioRun = await audioRunsCache.get(req.body.runId);
|
||||
logDebugMessage(req, 'start stream audio');
|
||||
if (audioRun) {
|
||||
logDebugMessage(req, 'stream audio already running');
|
||||
return res.status(401).json({ error: 'Audio stream already running' });
|
||||
}
|
||||
audioRunsCache.set(req.body.runId, true);
|
||||
await limitedStreamAudio(req, res);
|
||||
logDebugMessage(req, 'end stream audio');
|
||||
res.status(200).end();
|
||||
} catch (error) {
|
||||
logger.error(`[streamAudio] user: ${req.user.id} | Failed to stream audio: ${error}`);
|
||||
res.status(500).json({ error: 'Failed to stream audio' });
|
||||
}
|
||||
});
|
||||
|
||||
// todo: cache voices
|
||||
router.get('/voices', async (req, res) => {
|
||||
await getVoices(req, res);
|
||||
});
|
||||
|
||||
module.exports = router;
|
||||
@@ -14,7 +14,7 @@ router.use(requireJwtAuth);
|
||||
|
||||
router.get('/:conversationId', validateMessageReq, async (req, res) => {
|
||||
const { conversationId } = req.params;
|
||||
res.status(200).send(await getMessages({ conversationId }));
|
||||
res.status(200).send(await getMessages({ conversationId }, '-_id -__v -user'));
|
||||
});
|
||||
|
||||
// CREATE
|
||||
@@ -28,7 +28,7 @@ router.post('/:conversationId', validateMessageReq, async (req, res) => {
|
||||
// READ
|
||||
router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
|
||||
const { conversationId, messageId } = req.params;
|
||||
res.status(200).send(await getMessages({ conversationId, messageId }));
|
||||
res.status(200).send(await getMessages({ conversationId, messageId }, '-_id -__v -user'));
|
||||
});
|
||||
|
||||
// UPDATE
|
||||
|
||||
@@ -78,6 +78,7 @@ const AppService = async (app) => {
|
||||
if (config?.endpoints?.[EModelEndpoint.azureAssistants]) {
|
||||
endpointLocals[EModelEndpoint.azureAssistants] = assistantsConfigSetup(
|
||||
config,
|
||||
EModelEndpoint.azureAssistants,
|
||||
endpointLocals[EModelEndpoint.azureAssistants],
|
||||
);
|
||||
}
|
||||
@@ -85,6 +86,7 @@ const AppService = async (app) => {
|
||||
if (config?.endpoints?.[EModelEndpoint.assistants]) {
|
||||
endpointLocals[EModelEndpoint.assistants] = assistantsConfigSetup(
|
||||
config,
|
||||
EModelEndpoint.assistants,
|
||||
endpointLocals[EModelEndpoint.assistants],
|
||||
);
|
||||
}
|
||||
|
||||
@@ -218,6 +218,7 @@ describe('AppService', () => {
|
||||
pollIntervalMs: 5000,
|
||||
timeoutMs: 30000,
|
||||
supportedIds: ['id1', 'id2'],
|
||||
privateAssistants: false,
|
||||
},
|
||||
},
|
||||
}),
|
||||
@@ -232,6 +233,7 @@ describe('AppService', () => {
|
||||
pollIntervalMs: 5000,
|
||||
timeoutMs: 30000,
|
||||
supportedIds: expect.arrayContaining(['id1', 'id2']),
|
||||
privateAssistants: false,
|
||||
}),
|
||||
);
|
||||
});
|
||||
@@ -505,7 +507,31 @@ describe('AppService updating app.locals and issuing warnings', () => {
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Both `supportedIds` and `excludedIds` are defined'),
|
||||
expect.stringContaining(
|
||||
'The \'assistants\' endpoint has both \'supportedIds\' and \'excludedIds\' defined.',
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
it('should log a warning when privateAssistants and supportedIds or excludedIds are provided', async () => {
|
||||
const mockConfig = {
|
||||
endpoints: {
|
||||
assistants: {
|
||||
privateAssistants: true,
|
||||
supportedIds: ['id1'],
|
||||
},
|
||||
},
|
||||
};
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig));
|
||||
|
||||
const app = { locals: {} };
|
||||
await require('./AppService')(app);
|
||||
|
||||
const { logger } = require('~/config');
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'The \'assistants\' endpoint has both \'privateAssistants\' and \'supportedIds\' or \'excludedIds\' defined.',
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
const { RateLimitPrefix } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {TCustomConfig['rateLimits'] | undefined} rateLimits
|
||||
@@ -6,24 +8,41 @@ const handleRateLimits = (rateLimits) => {
|
||||
if (!rateLimits) {
|
||||
return;
|
||||
}
|
||||
const { fileUploads, conversationsImport } = rateLimits;
|
||||
if (fileUploads) {
|
||||
process.env.FILE_UPLOAD_IP_MAX = fileUploads.ipMax ?? process.env.FILE_UPLOAD_IP_MAX;
|
||||
process.env.FILE_UPLOAD_IP_WINDOW =
|
||||
fileUploads.ipWindowInMinutes ?? process.env.FILE_UPLOAD_IP_WINDOW;
|
||||
process.env.FILE_UPLOAD_USER_MAX = fileUploads.userMax ?? process.env.FILE_UPLOAD_USER_MAX;
|
||||
process.env.FILE_UPLOAD_USER_WINDOW =
|
||||
fileUploads.userWindowInMinutes ?? process.env.FILE_UPLOAD_USER_WINDOW;
|
||||
}
|
||||
|
||||
if (conversationsImport) {
|
||||
process.env.IMPORT_IP_MAX = conversationsImport.ipMax ?? process.env.IMPORT_IP_MAX;
|
||||
process.env.IMPORT_IP_WINDOW =
|
||||
conversationsImport.ipWindowInMinutes ?? process.env.IMPORT_IP_WINDOW;
|
||||
process.env.IMPORT_USER_MAX = conversationsImport.userMax ?? process.env.IMPORT_USER_MAX;
|
||||
process.env.IMPORT_USER_WINDOW =
|
||||
conversationsImport.userWindowInMinutes ?? process.env.IMPORT_USER_WINDOW;
|
||||
}
|
||||
const rateLimitKeys = {
|
||||
fileUploads: RateLimitPrefix.FILE_UPLOAD,
|
||||
conversationsImport: RateLimitPrefix.IMPORT,
|
||||
tts: RateLimitPrefix.TTS,
|
||||
stt: RateLimitPrefix.STT,
|
||||
};
|
||||
|
||||
Object.entries(rateLimitKeys).forEach(([key, prefix]) => {
|
||||
const rateLimit = rateLimits[key];
|
||||
if (rateLimit) {
|
||||
setRateLimitEnvVars(prefix, rateLimit);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Set environment variables for rate limit configurations
|
||||
*
|
||||
* @param {string} prefix - Prefix for environment variable names
|
||||
* @param {object} rateLimit - Rate limit configuration object
|
||||
*/
|
||||
const setRateLimitEnvVars = (prefix, rateLimit) => {
|
||||
const envVarsMapping = {
|
||||
ipMax: `${prefix}_IP_MAX`,
|
||||
ipWindowInMinutes: `${prefix}_IP_WINDOW`,
|
||||
userMax: `${prefix}_USER_MAX`,
|
||||
userWindowInMinutes: `${prefix}_USER_WINDOW`,
|
||||
};
|
||||
|
||||
Object.entries(envVarsMapping).forEach(([key, envVar]) => {
|
||||
if (rateLimit[key] !== undefined) {
|
||||
process.env[envVar] = rateLimit[key];
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = handleRateLimits;
|
||||
|
||||
@@ -112,6 +112,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
modelDisplayLabel: endpointConfig.modelDisplayLabel,
|
||||
titleMethod: endpointConfig.titleMethod ?? 'completion',
|
||||
contextStrategy: endpointConfig.summarize ? 'summarize' : null,
|
||||
directEndpoint: endpointConfig.directEndpoint,
|
||||
titleMessageRole: endpointConfig.titleMessageRole,
|
||||
endpointTokenConfig,
|
||||
};
|
||||
|
||||
|
||||
48
api/server/services/Files/Audio/getVoices.js
Normal file
48
api/server/services/Files/Audio/getVoices.js
Normal file
@@ -0,0 +1,48 @@
|
||||
const { logger } = require('~/config');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { getProvider } = require('./textToSpeech');
|
||||
|
||||
/**
|
||||
* This function retrieves the available voices for the current TTS provider
|
||||
* It first fetches the TTS configuration and determines the provider
|
||||
* Then, based on the provider, it sends the corresponding voices as a JSON response
|
||||
*
|
||||
* @param {Object} req - The request object
|
||||
* @param {Object} res - The response object
|
||||
* @returns {Promise<void>}
|
||||
* @throws {Error} - If the provider is not 'openai' or 'elevenlabs', an error is thrown
|
||||
*/
|
||||
async function getVoices(req, res) {
|
||||
try {
|
||||
const customConfig = await getCustomConfig();
|
||||
|
||||
if (!customConfig || !customConfig?.tts) {
|
||||
throw new Error('Configuration or TTS schema is missing');
|
||||
}
|
||||
|
||||
const ttsSchema = customConfig?.tts;
|
||||
const provider = getProvider(ttsSchema);
|
||||
let voices;
|
||||
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
voices = ttsSchema.openai?.voices;
|
||||
break;
|
||||
case 'elevenlabs':
|
||||
voices = ttsSchema.elevenlabs?.voices;
|
||||
break;
|
||||
case 'localai':
|
||||
voices = ttsSchema.localai?.voices;
|
||||
break;
|
||||
default:
|
||||
throw new Error('Invalid provider');
|
||||
}
|
||||
|
||||
res.json(voices);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to get voices: ${error.message}`);
|
||||
res.status(500).json({ error: 'Failed to get voices' });
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = getVoices;
|
||||
11
api/server/services/Files/Audio/index.js
Normal file
11
api/server/services/Files/Audio/index.js
Normal file
@@ -0,0 +1,11 @@
|
||||
const getVoices = require('./getVoices');
|
||||
const textToSpeech = require('./textToSpeech');
|
||||
const speechToText = require('./speechToText');
|
||||
const { updateTokenWebsocket } = require('./webSocket');
|
||||
|
||||
module.exports = {
|
||||
getVoices,
|
||||
speechToText,
|
||||
...textToSpeech,
|
||||
updateTokenWebsocket,
|
||||
};
|
||||
211
api/server/services/Files/Audio/speechToText.js
Normal file
211
api/server/services/Files/Audio/speechToText.js
Normal file
@@ -0,0 +1,211 @@
|
||||
const axios = require('axios');
|
||||
const { Readable } = require('stream');
|
||||
const { logger } = require('~/config');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { extractEnvVariable } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Handle the response from the STT API
|
||||
* @param {Object} response - The response from the STT API
|
||||
*
|
||||
* @returns {string} The text from the response data
|
||||
*
|
||||
* @throws Will throw an error if the response status is not 200 or the response data is missing
|
||||
*/
|
||||
async function handleResponse(response) {
|
||||
if (response.status !== 200) {
|
||||
throw new Error('Invalid response from the STT API');
|
||||
}
|
||||
|
||||
if (!response.data || !response.data.text) {
|
||||
throw new Error('Missing data in response from the STT API');
|
||||
}
|
||||
|
||||
return response.data.text.trim();
|
||||
}
|
||||
|
||||
function getProvider(sttSchema) {
|
||||
if (sttSchema.openai) {
|
||||
return 'openai';
|
||||
}
|
||||
|
||||
throw new Error('Invalid provider');
|
||||
}
|
||||
|
||||
function removeUndefined(obj) {
|
||||
Object.keys(obj).forEach((key) => {
|
||||
if (obj[key] && typeof obj[key] === 'object') {
|
||||
removeUndefined(obj[key]);
|
||||
if (Object.keys(obj[key]).length === 0) {
|
||||
delete obj[key];
|
||||
}
|
||||
} else if (obj[key] === undefined) {
|
||||
delete obj[key];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* This function prepares the necessary data and headers for making a request to the OpenAI API
|
||||
* It uses the provided speech-to-text schema and audio stream to create the request
|
||||
*
|
||||
* @param {Object} sttSchema - The speech-to-text schema containing the OpenAI configuration
|
||||
* @param {Stream} audioReadStream - The audio data to be transcribed
|
||||
*
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
|
||||
* If an error occurs, it returns an array with three null values and logs the error with logger
|
||||
*/
|
||||
function openAIProvider(sttSchema, audioReadStream) {
|
||||
try {
|
||||
const url = sttSchema.openai?.url || 'https://api.openai.com/v1/audio/transcriptions';
|
||||
const apiKey = sttSchema.openai.apiKey ? extractEnvVariable(sttSchema.openai.apiKey) : '';
|
||||
|
||||
let data = {
|
||||
file: audioReadStream,
|
||||
model: sttSchema.openai.model,
|
||||
};
|
||||
|
||||
let headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
};
|
||||
|
||||
[headers].forEach(removeUndefined);
|
||||
|
||||
if (apiKey) {
|
||||
headers.Authorization = 'Bearer ' + apiKey;
|
||||
}
|
||||
|
||||
return [url, data, headers];
|
||||
} catch (error) {
|
||||
logger.error('An error occurred while preparing the OpenAI API STT request: ', error);
|
||||
return [null, null, null];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This function prepares the necessary data and headers for making a request to the Azure API
|
||||
* It uses the provided request and audio stream to create the request
|
||||
*
|
||||
* @param {Object} req - The request object, which should contain the endpoint in its body
|
||||
* @param {Stream} audioReadStream - The audio data to be transcribed
|
||||
*
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
|
||||
* If an error occurs, it returns an array with three null values and logs the error with logger
|
||||
*/
|
||||
function azureProvider(req, audioReadStream) {
|
||||
try {
|
||||
const { endpoint } = req.body;
|
||||
const azureConfig = req.app.locals[endpoint];
|
||||
|
||||
if (!azureConfig) {
|
||||
throw new Error(`No configuration found for endpoint: ${endpoint}`);
|
||||
}
|
||||
|
||||
const { apiKey, instanceName, whisperModel, apiVersion } = Object.entries(
|
||||
azureConfig.groupMap,
|
||||
).reduce((acc, [, value]) => {
|
||||
if (acc) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
const whisperKey = Object.keys(value.models).find((modelKey) =>
|
||||
modelKey.startsWith('whisper'),
|
||||
);
|
||||
|
||||
if (whisperKey) {
|
||||
return {
|
||||
apiVersion: value.version,
|
||||
apiKey: value.apiKey,
|
||||
instanceName: value.instanceName,
|
||||
whisperModel: value.models[whisperKey]['deploymentName'],
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
}, null);
|
||||
|
||||
if (!apiKey || !instanceName || !whisperModel || !apiVersion) {
|
||||
throw new Error('Required Azure configuration values are missing');
|
||||
}
|
||||
|
||||
const baseURL = `https://${instanceName}.openai.azure.com`;
|
||||
|
||||
const url = `${baseURL}/openai/deployments/${whisperModel}/audio/transcriptions?api-version=${apiVersion}`;
|
||||
|
||||
let data = {
|
||||
file: audioReadStream,
|
||||
filename: 'audio.wav',
|
||||
contentType: 'audio/wav',
|
||||
knownLength: audioReadStream.length,
|
||||
};
|
||||
|
||||
const headers = {
|
||||
...data.getHeaders(),
|
||||
'Content-Type': 'multipart/form-data',
|
||||
'api-key': apiKey,
|
||||
};
|
||||
|
||||
return [url, data, headers];
|
||||
} catch (error) {
|
||||
logger.error('An error occurred while preparing the Azure API STT request: ', error);
|
||||
return [null, null, null];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert speech to text
|
||||
* @param {Object} req - The request object
|
||||
* @param {Object} res - The response object
|
||||
*
|
||||
* @returns {Object} The response object with the text from the STT API
|
||||
*
|
||||
* @throws Will throw an error if an error occurs while processing the audio
|
||||
*/
|
||||
|
||||
async function speechToText(req, res) {
|
||||
const customConfig = await getCustomConfig();
|
||||
if (!customConfig) {
|
||||
return res.status(500).send('Custom config not found');
|
||||
}
|
||||
|
||||
if (!req.file || !req.file.buffer) {
|
||||
return res.status(400).json({ message: 'No audio file provided in the FormData' });
|
||||
}
|
||||
|
||||
const audioBuffer = req.file.buffer;
|
||||
const audioReadStream = Readable.from(audioBuffer);
|
||||
audioReadStream.path = 'audio.wav';
|
||||
|
||||
const provider = getProvider(customConfig.stt);
|
||||
|
||||
let [url, data, headers] = [];
|
||||
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
[url, data, headers] = openAIProvider(customConfig.stt, audioReadStream);
|
||||
break;
|
||||
case 'azure':
|
||||
[url, data, headers] = azureProvider(req, audioReadStream);
|
||||
break;
|
||||
default:
|
||||
throw new Error('Invalid provider');
|
||||
}
|
||||
|
||||
if (!Readable.from) {
|
||||
const audioBlob = new Blob([audioBuffer], { type: req.file.mimetype });
|
||||
delete data['file'];
|
||||
data['file'] = audioBlob;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios.post(url, data, { headers: headers });
|
||||
const text = await handleResponse(response);
|
||||
|
||||
res.json({ text });
|
||||
} catch (error) {
|
||||
logger.error('An error occurred while processing the audio:', error);
|
||||
res.sendStatus(500);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = speechToText;
|
||||
371
api/server/services/Files/Audio/streamAudio.js
Normal file
371
api/server/services/Files/Audio/streamAudio.js
Normal file
@@ -0,0 +1,371 @@
|
||||
const WebSocket = require('ws');
|
||||
const { Message } = require('~/models/Message');
|
||||
|
||||
/**
|
||||
* @param {string[]} voiceIds - Array of voice IDs
|
||||
* @returns {string}
|
||||
*/
|
||||
function getRandomVoiceId(voiceIds) {
|
||||
const randomIndex = Math.floor(Math.random() * voiceIds.length);
|
||||
return voiceIds[randomIndex];
|
||||
}
|
||||
|
||||
/**
|
||||
* @typedef {Object} VoiceSettings
|
||||
* @property {number} similarity_boost
|
||||
* @property {number} stability
|
||||
* @property {boolean} use_speaker_boost
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} GenerateAudioBulk
|
||||
* @property {string} model_id
|
||||
* @property {string} text
|
||||
* @property {VoiceSettings} voice_settings
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} TextToSpeechClient
|
||||
* @property {function(Object): Promise<stream.Readable>} generate
|
||||
*/
|
||||
|
||||
/**
|
||||
* @typedef {Object} AudioChunk
|
||||
* @property {string} audio
|
||||
* @property {boolean} isFinal
|
||||
* @property {Object} alignment
|
||||
* @property {number[]} alignment.char_start_times_ms
|
||||
* @property {number[]} alignment.chars_durations_ms
|
||||
* @property {string[]} alignment.chars
|
||||
* @property {Object} normalizedAlignment
|
||||
* @property {number[]} normalizedAlignment.char_start_times_ms
|
||||
* @property {number[]} normalizedAlignment.chars_durations_ms
|
||||
* @property {string[]} normalizedAlignment.chars
|
||||
*/
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {Record<string, unknown | undefined>} parameters
|
||||
* @returns
|
||||
*/
|
||||
function assembleQuery(parameters) {
|
||||
let query = '';
|
||||
let hasQuestionMark = false;
|
||||
|
||||
for (const [key, value] of Object.entries(parameters)) {
|
||||
if (value == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!hasQuestionMark) {
|
||||
query += '?';
|
||||
hasQuestionMark = true;
|
||||
} else {
|
||||
query += '&';
|
||||
}
|
||||
|
||||
query += `${key}=${value}`;
|
||||
}
|
||||
|
||||
return query;
|
||||
}
|
||||
|
||||
const SEPARATORS = ['.', '?', '!', '۔', '。', '‥', ';', '¡', '¿', '\n'];
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {string} text
|
||||
* @param {string[] | undefined} [separators]
|
||||
* @returns
|
||||
*/
|
||||
function findLastSeparatorIndex(text, separators = SEPARATORS) {
|
||||
let lastIndex = -1;
|
||||
for (const separator of separators) {
|
||||
const index = text.lastIndexOf(separator);
|
||||
if (index > lastIndex) {
|
||||
lastIndex = index;
|
||||
}
|
||||
}
|
||||
return lastIndex;
|
||||
}
|
||||
|
||||
const MAX_NOT_FOUND_COUNT = 6;
|
||||
const MAX_NO_CHANGE_COUNT = 10;
|
||||
|
||||
/**
|
||||
* @param {string} messageId
|
||||
* @returns {() => Promise<{ text: string, isFinished: boolean }[]>}
|
||||
*/
|
||||
function createChunkProcessor(messageId) {
|
||||
let notFoundCount = 0;
|
||||
let noChangeCount = 0;
|
||||
let processedText = '';
|
||||
if (!messageId) {
|
||||
throw new Error('Message ID is required');
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {Promise<{ text: string, isFinished: boolean }[] | string>}
|
||||
*/
|
||||
async function processChunks() {
|
||||
if (notFoundCount >= MAX_NOT_FOUND_COUNT) {
|
||||
return `Message not found after ${MAX_NOT_FOUND_COUNT} attempts`;
|
||||
}
|
||||
|
||||
if (noChangeCount >= MAX_NO_CHANGE_COUNT) {
|
||||
return `No change in message after ${MAX_NO_CHANGE_COUNT} attempts`;
|
||||
}
|
||||
|
||||
const message = await Message.findOne({ messageId }, 'text unfinished').lean();
|
||||
|
||||
if (!message || !message.text) {
|
||||
notFoundCount++;
|
||||
return [];
|
||||
}
|
||||
|
||||
const { text, unfinished } = message;
|
||||
if (text === processedText) {
|
||||
noChangeCount++;
|
||||
}
|
||||
|
||||
const remainingText = text.slice(processedText.length);
|
||||
const chunks = [];
|
||||
|
||||
if (unfinished && remainingText.length >= 20) {
|
||||
const separatorIndex = findLastSeparatorIndex(remainingText);
|
||||
if (separatorIndex !== -1) {
|
||||
const chunkText = remainingText.slice(0, separatorIndex + 1);
|
||||
chunks.push({ text: chunkText, isFinished: false });
|
||||
processedText += chunkText;
|
||||
} else {
|
||||
chunks.push({ text: remainingText, isFinished: false });
|
||||
processedText = text;
|
||||
}
|
||||
} else if (!unfinished && remainingText.trim().length > 0) {
|
||||
chunks.push({ text: remainingText.trim(), isFinished: true });
|
||||
processedText = text;
|
||||
}
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
return processChunks;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} text
|
||||
* @param {number} [chunkSize=4000]
|
||||
* @returns {{ text: string, isFinished: boolean }[]}
|
||||
*/
|
||||
function splitTextIntoChunks(text, chunkSize = 4000) {
|
||||
if (!text) {
|
||||
throw new Error('Text is required');
|
||||
}
|
||||
|
||||
const chunks = [];
|
||||
let startIndex = 0;
|
||||
const textLength = text.length;
|
||||
|
||||
while (startIndex < textLength) {
|
||||
let endIndex = Math.min(startIndex + chunkSize, textLength);
|
||||
let chunkText = text.slice(startIndex, endIndex);
|
||||
|
||||
if (endIndex < textLength) {
|
||||
let lastSeparatorIndex = -1;
|
||||
for (const separator of SEPARATORS) {
|
||||
const index = chunkText.lastIndexOf(separator);
|
||||
if (index !== -1) {
|
||||
lastSeparatorIndex = Math.max(lastSeparatorIndex, index);
|
||||
}
|
||||
}
|
||||
|
||||
if (lastSeparatorIndex !== -1) {
|
||||
endIndex = startIndex + lastSeparatorIndex + 1;
|
||||
chunkText = text.slice(startIndex, endIndex);
|
||||
} else {
|
||||
const nextSeparatorIndex = text.slice(endIndex).search(/\S/);
|
||||
if (nextSeparatorIndex !== -1) {
|
||||
endIndex += nextSeparatorIndex;
|
||||
chunkText = text.slice(startIndex, endIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chunkText = chunkText.trim();
|
||||
if (chunkText) {
|
||||
chunks.push({
|
||||
text: chunkText,
|
||||
isFinished: endIndex >= textLength,
|
||||
});
|
||||
} else if (chunks.length > 0) {
|
||||
chunks[chunks.length - 1].isFinished = true;
|
||||
}
|
||||
|
||||
startIndex = endIndex;
|
||||
while (startIndex < textLength && text[startIndex].trim() === '') {
|
||||
startIndex++;
|
||||
}
|
||||
}
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
/**
|
||||
* Input stream text to speech
|
||||
* @param {Express.Response} res
|
||||
* @param {AsyncIterable<string>} textStream
|
||||
* @param {(token: string) => Promise<boolean>} callback - Whether to continue the stream or not
|
||||
* @returns {AsyncGenerator<AudioChunk>}
|
||||
*/
|
||||
function inputStreamTextToSpeech(res, textStream, callback) {
|
||||
const model = 'eleven_monolingual_v1';
|
||||
const wsUrl = `wss://api.elevenlabs.io/v1/text-to-speech/${getRandomVoiceId()}/stream-input${assembleQuery(
|
||||
{
|
||||
model_id: model,
|
||||
// flush: true,
|
||||
// optimize_streaming_latency: this.settings.optimizeStreamingLatency,
|
||||
optimize_streaming_latency: 1,
|
||||
// output_format: this.settings.outputFormat,
|
||||
},
|
||||
)}`;
|
||||
const socket = new WebSocket(wsUrl);
|
||||
|
||||
socket.onopen = function () {
|
||||
const streamStart = {
|
||||
text: ' ',
|
||||
voice_settings: {
|
||||
stability: 0.5,
|
||||
similarity_boost: 0.8,
|
||||
},
|
||||
xi_api_key: process.env.ELEVENLABS_API_KEY,
|
||||
// generation_config: { chunk_length_schedule: [50, 90, 120, 150, 200] },
|
||||
};
|
||||
|
||||
socket.send(JSON.stringify(streamStart));
|
||||
|
||||
// send stream until done
|
||||
const streamComplete = new Promise((resolve, reject) => {
|
||||
(async () => {
|
||||
let textBuffer = '';
|
||||
let shouldContinue = true;
|
||||
for await (const textDelta of textStream) {
|
||||
textBuffer += textDelta;
|
||||
|
||||
// using ". " as separator: sending in full sentences improves the quality
|
||||
// of the audio output significantly.
|
||||
const separatorIndex = findLastSeparatorIndex(textBuffer);
|
||||
|
||||
// Callback for textStream (will return false if signal is aborted)
|
||||
shouldContinue = await callback(textDelta);
|
||||
|
||||
if (separatorIndex === -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!shouldContinue) {
|
||||
break;
|
||||
}
|
||||
|
||||
const textToProcess = textBuffer.slice(0, separatorIndex);
|
||||
textBuffer = textBuffer.slice(separatorIndex + 1);
|
||||
|
||||
const request = {
|
||||
text: textToProcess,
|
||||
try_trigger_generation: true,
|
||||
};
|
||||
|
||||
socket.send(JSON.stringify(request));
|
||||
}
|
||||
|
||||
// send remaining text:
|
||||
if (shouldContinue && textBuffer.length > 0) {
|
||||
socket.send(
|
||||
JSON.stringify({
|
||||
text: `${textBuffer} `, // append space
|
||||
try_trigger_generation: true,
|
||||
}),
|
||||
);
|
||||
}
|
||||
})()
|
||||
.then(resolve)
|
||||
.catch(reject);
|
||||
});
|
||||
|
||||
streamComplete
|
||||
.then(() => {
|
||||
const endStream = {
|
||||
text: '',
|
||||
};
|
||||
|
||||
socket.send(JSON.stringify(endStream));
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error('Error streaming text to speech:', e);
|
||||
throw e;
|
||||
});
|
||||
};
|
||||
|
||||
return (async function* audioStream() {
|
||||
let isDone = false;
|
||||
let chunks = [];
|
||||
let resolve;
|
||||
let waitForMessage = new Promise((r) => (resolve = r));
|
||||
|
||||
socket.onmessage = function (event) {
|
||||
// console.log(event);
|
||||
const audioChunk = JSON.parse(event.data);
|
||||
if (audioChunk.audio && audioChunk.alignment) {
|
||||
res.write(`event: audio\ndata: ${event.data}\n\n`);
|
||||
chunks.push(audioChunk);
|
||||
resolve(null);
|
||||
waitForMessage = new Promise((r) => (resolve = r));
|
||||
} else if (audioChunk.isFinal) {
|
||||
isDone = true;
|
||||
resolve(null);
|
||||
} else if (audioChunk.message) {
|
||||
console.warn('Received Elevenlabs message:', audioChunk.message);
|
||||
resolve(null);
|
||||
}
|
||||
};
|
||||
|
||||
socket.onerror = function (error) {
|
||||
console.error('WebSocket error:', error);
|
||||
// throw error;
|
||||
};
|
||||
|
||||
socket.onclose = function () {
|
||||
isDone = true;
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
while (!isDone) {
|
||||
await waitForMessage;
|
||||
yield* chunks;
|
||||
chunks = [];
|
||||
}
|
||||
|
||||
res.write('event: end\ndata: \n\n');
|
||||
})();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {AsyncIterable<string>} llmStream
|
||||
*/
|
||||
async function* llmMessageSource(llmStream) {
|
||||
for await (const chunk of llmStream) {
|
||||
const message = chunk.choices[0].delta.content;
|
||||
if (message) {
|
||||
yield message;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
inputStreamTextToSpeech,
|
||||
findLastSeparatorIndex,
|
||||
createChunkProcessor,
|
||||
splitTextIntoChunks,
|
||||
llmMessageSource,
|
||||
getRandomVoiceId,
|
||||
};
|
||||
137
api/server/services/Files/Audio/streamAudio.spec.js
Normal file
137
api/server/services/Files/Audio/streamAudio.spec.js
Normal file
@@ -0,0 +1,137 @@
|
||||
const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
|
||||
const { Message } = require('~/models/Message');
|
||||
|
||||
jest.mock('~/models/Message', () => ({
|
||||
Message: {
|
||||
findOne: jest.fn().mockReturnValue({
|
||||
lean: jest.fn(),
|
||||
}),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('processChunks', () => {
|
||||
let processChunks;
|
||||
|
||||
beforeEach(() => {
|
||||
processChunks = createChunkProcessor('message-id');
|
||||
Message.findOne.mockClear();
|
||||
Message.findOne().lean.mockClear();
|
||||
});
|
||||
|
||||
it('should return an empty array when the message is not found', async () => {
|
||||
Message.findOne().lean.mockResolvedValueOnce(null);
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return an empty array when the message does not have a text property', async () => {
|
||||
Message.findOne().lean.mockResolvedValueOnce({ unfinished: true });
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return chunks for an unfinished message with separators', async () => {
|
||||
const messageText = 'This is a long message. It should be split into chunks. Lol hi mom';
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([
|
||||
{ text: 'This is a long message. It should be split into chunks.', isFinished: false },
|
||||
]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return chunks for an unfinished message without separators', async () => {
|
||||
const messageText = 'This is a long message without separators hello there my friend';
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: true });
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([{ text: messageText, isFinished: false }]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return the remaining text as a chunk for a finished message', async () => {
|
||||
const messageText = 'This is a finished message.';
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
|
||||
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([{ text: messageText, isFinished: true }]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return an empty array for a finished message with no remaining text', async () => {
|
||||
const messageText = 'This is a finished message.';
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
|
||||
|
||||
await processChunks();
|
||||
Message.findOne().lean.mockResolvedValueOnce({ text: messageText, unfinished: false });
|
||||
const result = await processChunks();
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(Message.findOne).toHaveBeenCalledWith({ messageId: 'message-id' }, 'text unfinished');
|
||||
expect(Message.findOne().lean).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('splitTextIntoChunks', () => {
|
||||
test('splits text into chunks of specified size with default separators', () => {
|
||||
const text = 'This is a test. This is only a test! Make sure it works properly? Okay.';
|
||||
const chunkSize = 20;
|
||||
const expectedChunks = [
|
||||
{ text: 'This is a test.', isFinished: false },
|
||||
{ text: 'This is only a test!', isFinished: false },
|
||||
{ text: 'Make sure it works p', isFinished: false },
|
||||
{ text: 'roperly? Okay.', isFinished: true },
|
||||
];
|
||||
|
||||
const result = splitTextIntoChunks(text, chunkSize);
|
||||
expect(result).toEqual(expectedChunks);
|
||||
});
|
||||
|
||||
test('splits text into chunks with default size', () => {
|
||||
const text = 'A'.repeat(8000) + '. The end.';
|
||||
const expectedChunks = [
|
||||
{ text: 'A'.repeat(4000), isFinished: false },
|
||||
{ text: 'A'.repeat(4000), isFinished: false },
|
||||
{ text: '. The end.', isFinished: true },
|
||||
];
|
||||
|
||||
const result = splitTextIntoChunks(text);
|
||||
expect(result).toEqual(expectedChunks);
|
||||
});
|
||||
|
||||
test('returns a single chunk if text length is less than chunk size', () => {
|
||||
const text = 'Short text.';
|
||||
const expectedChunks = [{ text: 'Short text.', isFinished: true }];
|
||||
|
||||
const result = splitTextIntoChunks(text, 4000);
|
||||
expect(result).toEqual(expectedChunks);
|
||||
});
|
||||
|
||||
test('handles text with no separators correctly', () => {
|
||||
const text = 'ThisTextHasNoSeparatorsAndIsVeryLong'.repeat(100);
|
||||
const chunkSize = 4000;
|
||||
const expectedChunks = [{ text: text, isFinished: true }];
|
||||
|
||||
const result = splitTextIntoChunks(text, chunkSize);
|
||||
expect(result).toEqual(expectedChunks);
|
||||
});
|
||||
|
||||
test('throws an error when text is empty', () => {
|
||||
expect(() => splitTextIntoChunks('')).toThrow('Text is required');
|
||||
});
|
||||
});
|
||||
416
api/server/services/Files/Audio/textToSpeech.js
Normal file
416
api/server/services/Files/Audio/textToSpeech.js
Normal file
@@ -0,0 +1,416 @@
|
||||
const axios = require('axios');
|
||||
const getCustomConfig = require('~/server/services/Config/getCustomConfig');
|
||||
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
|
||||
const { extractEnvVariable } = require('librechat-data-provider');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* getProvider function
|
||||
* This function takes the ttsSchema object and returns the name of the provider
|
||||
* If more than one provider is set or no provider is set, it throws an error
|
||||
*
|
||||
* @param {Object} ttsSchema - The TTS schema containing the provider configuration
|
||||
* @returns {string} The name of the provider
|
||||
* @throws {Error} Throws an error if multiple providers are set or no provider is set
|
||||
*/
|
||||
function getProvider(ttsSchema) {
|
||||
if (!ttsSchema) {
|
||||
throw new Error(`No TTS schema is set. Did you configure TTS in the custom config (librechat.yaml)?
|
||||
|
||||
https://www.librechat.ai/docs/configuration/stt_tts#tts`);
|
||||
}
|
||||
const providers = Object.entries(ttsSchema).filter(([, value]) => Object.keys(value).length > 0);
|
||||
|
||||
if (providers.length > 1) {
|
||||
throw new Error('Multiple providers are set. Please set only one provider.');
|
||||
} else if (providers.length === 0) {
|
||||
throw new Error('No provider is set. Please set a provider.');
|
||||
} else {
|
||||
return providers[0][0];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* removeUndefined function
|
||||
* This function takes an object and removes all keys with undefined values
|
||||
* It also removes keys with empty objects as values
|
||||
*
|
||||
* @param {Object} obj - The object to be cleaned
|
||||
* @returns {void} This function does not return a value. It modifies the input object directly
|
||||
*/
|
||||
function removeUndefined(obj) {
|
||||
Object.keys(obj).forEach((key) => {
|
||||
if (obj[key] && typeof obj[key] === 'object') {
|
||||
removeUndefined(obj[key]);
|
||||
if (Object.keys(obj[key]).length === 0) {
|
||||
delete obj[key];
|
||||
}
|
||||
} else if (obj[key] === undefined) {
|
||||
delete obj[key];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* This function prepares the necessary data and headers for making a request to the OpenAI TTS
|
||||
* It uses the provided TTS schema, input text, and voice to create the request
|
||||
*
|
||||
* @param {TCustomConfig['tts']['openai']} ttsSchema - The TTS schema containing the OpenAI configuration
|
||||
* @param {string} input - The text to be converted to speech
|
||||
* @param {string} voice - The voice to be used for the speech
|
||||
*
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
|
||||
* If an error occurs, it throws an error with a message indicating that the selected voice is not available
|
||||
*/
|
||||
function openAIProvider(ttsSchema, input, voice) {
|
||||
const url = ttsSchema?.url || 'https://api.openai.com/v1/audio/speech';
|
||||
|
||||
if (
|
||||
ttsSchema?.voices &&
|
||||
ttsSchema.voices.length > 0 &&
|
||||
!ttsSchema.voices.includes(voice) &&
|
||||
!ttsSchema.voices.includes('ALL')
|
||||
) {
|
||||
throw new Error(`Voice ${voice} is not available.`);
|
||||
}
|
||||
|
||||
let data = {
|
||||
input,
|
||||
model: ttsSchema?.model,
|
||||
voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
|
||||
backend: ttsSchema?.backend,
|
||||
};
|
||||
|
||||
let headers = {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: 'Bearer ' + extractEnvVariable(ttsSchema?.apiKey),
|
||||
};
|
||||
|
||||
[data, headers].forEach(removeUndefined);
|
||||
|
||||
return [url, data, headers];
|
||||
}
|
||||
|
||||
/**
|
||||
* elevenLabsProvider function
|
||||
* This function prepares the necessary data and headers for making a request to the Eleven Labs TTS
|
||||
* It uses the provided TTS schema, input text, and voice to create the request
|
||||
*
|
||||
* @param {TCustomConfig['tts']['elevenLabs']} ttsSchema - The TTS schema containing the Eleven Labs configuration
|
||||
* @param {string} input - The text to be converted to speech
|
||||
* @param {string} voice - The voice to be used for the speech
|
||||
* @param {boolean} stream - Whether to stream the audio or not
|
||||
*
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
|
||||
* @throws {Error} Throws an error if the selected voice is not available
|
||||
*/
|
||||
function elevenLabsProvider(ttsSchema, input, voice, stream) {
|
||||
let url =
|
||||
ttsSchema?.url ||
|
||||
`https://api.elevenlabs.io/v1/text-to-speech/{voice_id}${stream ? '/stream' : ''}`;
|
||||
|
||||
if (!ttsSchema?.voices.includes(voice) && !ttsSchema?.voices.includes('ALL')) {
|
||||
throw new Error(`Voice ${voice} is not available.`);
|
||||
}
|
||||
|
||||
url = url.replace('{voice_id}', voice);
|
||||
|
||||
let data = {
|
||||
model_id: ttsSchema?.model,
|
||||
text: input,
|
||||
// voice_id: voice,
|
||||
voice_settings: {
|
||||
similarity_boost: ttsSchema?.voice_settings?.similarity_boost,
|
||||
stability: ttsSchema?.voice_settings?.stability,
|
||||
style: ttsSchema?.voice_settings?.style,
|
||||
use_speaker_boost: ttsSchema?.voice_settings?.use_speaker_boost || undefined,
|
||||
},
|
||||
pronunciation_dictionary_locators: ttsSchema?.pronunciation_dictionary_locators,
|
||||
};
|
||||
|
||||
let headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'xi-api-key': extractEnvVariable(ttsSchema?.apiKey),
|
||||
Accept: 'audio/mpeg',
|
||||
};
|
||||
|
||||
[data, headers].forEach(removeUndefined);
|
||||
|
||||
return [url, data, headers];
|
||||
}
|
||||
|
||||
/**
|
||||
* localAIProvider function
|
||||
* This function prepares the necessary data and headers for making a request to the LocalAI TTS
|
||||
* It uses the provided TTS schema, input text, and voice to create the request
|
||||
*
|
||||
* @param {TCustomConfig['tts']['localai']} ttsSchema - The TTS schema containing the LocalAI configuration
|
||||
* @param {string} input - The text to be converted to speech
|
||||
* @param {string} voice - The voice to be used for the speech
|
||||
*
|
||||
* @returns {Array} An array containing the URL for the API request, the data to be sent, and the headers for the request
|
||||
* @throws {Error} Throws an error if the selected voice is not available
|
||||
*/
|
||||
function localAIProvider(ttsSchema, input, voice) {
|
||||
let url = ttsSchema?.url;
|
||||
|
||||
if (
|
||||
ttsSchema?.voices &&
|
||||
ttsSchema.voices.length > 0 &&
|
||||
!ttsSchema.voices.includes(voice) &&
|
||||
!ttsSchema.voices.includes('ALL')
|
||||
) {
|
||||
throw new Error(`Voice ${voice} is not available.`);
|
||||
}
|
||||
|
||||
let data = {
|
||||
input,
|
||||
model: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
|
||||
backend: ttsSchema?.backend,
|
||||
};
|
||||
|
||||
let headers = {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: 'Bearer ' + extractEnvVariable(ttsSchema?.apiKey),
|
||||
};
|
||||
|
||||
[data, headers].forEach(removeUndefined);
|
||||
|
||||
if (extractEnvVariable(ttsSchema.apiKey) === '') {
|
||||
delete headers.Authorization;
|
||||
}
|
||||
|
||||
return [url, data, headers];
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Returns provider and its schema for use with TTS requests
|
||||
* @param {TCustomConfig} customConfig
|
||||
* @param {string} _voice
|
||||
* @returns {Promise<[string, TProviderSchema]>}
|
||||
*/
|
||||
async function getProviderSchema(customConfig) {
|
||||
const provider = getProvider(customConfig.tts);
|
||||
return [provider, customConfig.tts[provider]];
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Returns a tuple of the TTS schema as well as the voice for the TTS request
|
||||
* @param {TProviderSchema} providerSchema
|
||||
* @param {string} requestVoice
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
async function getVoice(providerSchema, requestVoice) {
|
||||
const voices = providerSchema.voices.filter((voice) => voice && voice.toUpperCase() !== 'ALL');
|
||||
let voice = requestVoice;
|
||||
if (!voice || !voices.includes(voice) || (voice.toUpperCase() === 'ALL' && voices.length > 1)) {
|
||||
voice = getRandomVoiceId(voices);
|
||||
}
|
||||
|
||||
return voice;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {string} provider
|
||||
* @param {TProviderSchema} ttsSchema
|
||||
* @param {object} params
|
||||
* @param {string} params.voice
|
||||
* @param {string} params.input
|
||||
* @param {boolean} [params.stream]
|
||||
* @returns {Promise<ArrayBuffer>}
|
||||
*/
|
||||
async function ttsRequest(provider, ttsSchema, { input, voice, stream = true } = { stream: true }) {
|
||||
let [url, data, headers] = [];
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
[url, data, headers] = openAIProvider(ttsSchema, input, voice);
|
||||
break;
|
||||
case 'elevenlabs':
|
||||
[url, data, headers] = elevenLabsProvider(ttsSchema, input, voice, stream);
|
||||
break;
|
||||
case 'localai':
|
||||
[url, data, headers] = localAIProvider(ttsSchema, input, voice);
|
||||
break;
|
||||
default:
|
||||
throw new Error('Invalid provider');
|
||||
}
|
||||
|
||||
if (stream) {
|
||||
return await axios.post(url, data, { headers, responseType: 'stream' });
|
||||
}
|
||||
|
||||
return await axios.post(url, data, { headers, responseType: 'arraybuffer' });
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles a text-to-speech request. Extracts input and voice from the request, retrieves the TTS configuration,
|
||||
* and sends a request to the appropriate provider. The resulting audio data is sent in the response
|
||||
*
|
||||
* @param {Object} req - The request object, which should contain the input text and voice in its body
|
||||
* @param {Object} res - The response object, used to send the audio data or an error message
|
||||
*
|
||||
* @returns {Promise<void>} This function does not return a value. It sends the audio data or an error message in the response
|
||||
*
|
||||
* @throws {Error} Throws an error if the provider is invalid
|
||||
*/
|
||||
async function textToSpeech(req, res) {
|
||||
const { input } = req.body;
|
||||
|
||||
if (!input) {
|
||||
return res.status(400).send('Missing text in request body');
|
||||
}
|
||||
|
||||
const customConfig = await getCustomConfig();
|
||||
if (!customConfig) {
|
||||
res.status(500).send('Custom config not found');
|
||||
}
|
||||
|
||||
try {
|
||||
res.setHeader('Content-Type', 'audio/mpeg');
|
||||
const [provider, ttsSchema] = await getProviderSchema(customConfig);
|
||||
const voice = await getVoice(ttsSchema, req.body.voice);
|
||||
if (input.length < 4096) {
|
||||
const response = await ttsRequest(provider, ttsSchema, { input, voice });
|
||||
response.data.pipe(res);
|
||||
return;
|
||||
}
|
||||
|
||||
const textChunks = splitTextIntoChunks(input, 1000);
|
||||
|
||||
for (const chunk of textChunks) {
|
||||
try {
|
||||
const response = await ttsRequest(provider, ttsSchema, {
|
||||
voice,
|
||||
input: chunk.text,
|
||||
stream: true,
|
||||
});
|
||||
|
||||
logger.debug(`[textToSpeech] user: ${req?.user?.id} | writing audio stream`);
|
||||
await new Promise((resolve) => {
|
||||
response.data.pipe(res, { end: chunk.isFinished });
|
||||
response.data.on('end', () => {
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
if (chunk.isFinished) {
|
||||
break;
|
||||
}
|
||||
} catch (innerError) {
|
||||
logger.error('Error processing manual update:', chunk, innerError);
|
||||
if (!res.headersSent) {
|
||||
res.status(500).end();
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!res.headersSent) {
|
||||
res.end();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'Error creating the audio stream. Suggestion: check your provider quota. Error:',
|
||||
error,
|
||||
);
|
||||
res.status(500).send('An error occurred');
|
||||
}
|
||||
}
|
||||
|
||||
async function streamAudio(req, res) {
|
||||
res.setHeader('Content-Type', 'audio/mpeg');
|
||||
const customConfig = await getCustomConfig();
|
||||
if (!customConfig) {
|
||||
return res.status(500).send('Custom config not found');
|
||||
}
|
||||
|
||||
const [provider, ttsSchema] = await getProviderSchema(customConfig);
|
||||
const voice = await getVoice(ttsSchema, req.body.voice);
|
||||
|
||||
try {
|
||||
let shouldContinue = true;
|
||||
|
||||
req.on('close', () => {
|
||||
logger.warn('[streamAudio] Audio Stream Request closed by client');
|
||||
shouldContinue = false;
|
||||
});
|
||||
|
||||
const processChunks = createChunkProcessor(req.body.messageId);
|
||||
|
||||
while (shouldContinue) {
|
||||
// example updates
|
||||
// const updates = [
|
||||
// { text: 'This is a test.', isFinished: false },
|
||||
// { text: 'This is only a test.', isFinished: false },
|
||||
// { text: 'Your voice is like a combination of Fergie and Jesus!', isFinished: true },
|
||||
// ];
|
||||
|
||||
const updates = await processChunks();
|
||||
if (typeof updates === 'string') {
|
||||
logger.error(`Error processing audio stream updates: ${JSON.stringify(updates)}`);
|
||||
res.status(500).end();
|
||||
return;
|
||||
}
|
||||
|
||||
if (updates.length === 0) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 1250));
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const update of updates) {
|
||||
try {
|
||||
const response = await ttsRequest(provider, ttsSchema, {
|
||||
voice,
|
||||
input: update.text,
|
||||
stream: true,
|
||||
});
|
||||
|
||||
if (!shouldContinue) {
|
||||
break;
|
||||
}
|
||||
|
||||
logger.debug(`[streamAudio] user: ${req?.user?.id} | writing audio stream`);
|
||||
await new Promise((resolve) => {
|
||||
response.data.pipe(res, { end: update.isFinished });
|
||||
response.data.on('end', () => {
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
if (update.isFinished) {
|
||||
shouldContinue = false;
|
||||
break;
|
||||
}
|
||||
} catch (innerError) {
|
||||
logger.error('Error processing update:', update, innerError);
|
||||
if (!res.headersSent) {
|
||||
res.status(500).end();
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!shouldContinue) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!res.headersSent) {
|
||||
res.end();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to fetch audio:', error);
|
||||
if (!res.headersSent) {
|
||||
res.status(500).end();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
textToSpeech,
|
||||
getProvider,
|
||||
streamAudio,
|
||||
};
|
||||
31
api/server/services/Files/Audio/webSocket.js
Normal file
31
api/server/services/Files/Audio/webSocket.js
Normal file
@@ -0,0 +1,31 @@
|
||||
let token = '';
|
||||
|
||||
function updateTokenWebsocket(newToken) {
|
||||
console.log('Token:', newToken);
|
||||
token = newToken;
|
||||
}
|
||||
|
||||
function sendTextToWebsocket(ws, onDataReceived) {
|
||||
if (token === '[DONE]') {
|
||||
ws.send(' ');
|
||||
return;
|
||||
}
|
||||
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(token);
|
||||
|
||||
ws.onmessage = function (event) {
|
||||
console.log('Received:', event.data);
|
||||
if (onDataReceived) {
|
||||
onDataReceived(event.data); // Pass the received data to the callback function
|
||||
}
|
||||
};
|
||||
} else {
|
||||
console.error('WebSocket is not open. Ready state is: ' + ws.readyState);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
updateTokenWebsocket,
|
||||
sendTextToWebsocket,
|
||||
};
|
||||
@@ -1,3 +1,4 @@
|
||||
const throttle = require('lodash/throttle');
|
||||
const {
|
||||
StepTypes,
|
||||
ContentTypes,
|
||||
@@ -8,6 +9,7 @@ const {
|
||||
} = require('librechat-data-provider');
|
||||
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
|
||||
const { processRequiredActions } = require('~/server/services/ToolService');
|
||||
const { saveMessage, updateMessageText } = require('~/models/Message');
|
||||
const { createOnProgress, sendMessage } = require('~/server/utils');
|
||||
const { processMessages } = require('~/server/services/Threads');
|
||||
const { logger } = require('~/config');
|
||||
@@ -43,6 +45,8 @@ class StreamRunManager {
|
||||
/** @type {string} */
|
||||
this.apiKey = this.openai.apiKey;
|
||||
/** @type {string} */
|
||||
this.parentMessageId = fields.parentMessageId;
|
||||
/** @type {string} */
|
||||
this.thread_id = fields.thread_id;
|
||||
/** @type {RunCreateAndStreamParams} */
|
||||
this.initialRunBody = fields.runBody;
|
||||
@@ -58,10 +62,14 @@ class StreamRunManager {
|
||||
this.messages = [];
|
||||
/** @type {string} */
|
||||
this.text = '';
|
||||
/** @type {string} */
|
||||
this.intermediateText = '';
|
||||
/** @type {Set<string>} */
|
||||
this.attachedFileIds = fields.attachedFileIds;
|
||||
/** @type {undefined | Promise<ChatCompletion>} */
|
||||
this.visionPromise = fields.visionPromise;
|
||||
/** @type {boolean} */
|
||||
this.savedInitialMessage = false;
|
||||
|
||||
/**
|
||||
* @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
|
||||
@@ -123,6 +131,33 @@ class StreamRunManager {
|
||||
sendMessage(this.res, contentData);
|
||||
}
|
||||
|
||||
/* <------------------ Misc. Helpers ------------------> */
|
||||
/** Returns the latest intermediate text
|
||||
* @returns {string}
|
||||
*/
|
||||
getText() {
|
||||
return this.intermediateText;
|
||||
}
|
||||
|
||||
/** Saves the initial intermediate message
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async saveInitialMessage() {
|
||||
return saveMessage({
|
||||
conversationId: this.finalMessage.conversationId,
|
||||
messageId: this.finalMessage.messageId,
|
||||
parentMessageId: this.parentMessageId,
|
||||
model: this.req.body.assistant_id,
|
||||
endpoint: this.req.body.endpoint,
|
||||
isCreatedByUser: false,
|
||||
user: this.req.user.id,
|
||||
text: this.getText(),
|
||||
sender: 'Assistant',
|
||||
unfinished: true,
|
||||
error: false,
|
||||
});
|
||||
}
|
||||
|
||||
/* <------------------ Main Event Handlers ------------------> */
|
||||
|
||||
/**
|
||||
@@ -407,6 +442,7 @@ class StreamRunManager {
|
||||
const content = message.delta.content?.[0];
|
||||
|
||||
if (content && content.type === MessageContentTypes.TEXT) {
|
||||
this.intermediateText += content.text.value;
|
||||
onProgress(content.text.value);
|
||||
}
|
||||
}
|
||||
@@ -523,10 +559,24 @@ class StreamRunManager {
|
||||
const stepKey = message_creation.message_id;
|
||||
const index = this.getStepIndex(stepKey);
|
||||
this.orderedRunSteps.set(index, message_creation);
|
||||
|
||||
// Create the Factory Function to stream the message
|
||||
const { onProgress: progressCallback } = createOnProgress({
|
||||
// todo: add option to save partialText to db
|
||||
// onProgress: () => {},
|
||||
onProgress: throttle(
|
||||
() => {
|
||||
if (!this.savedInitialMessage) {
|
||||
this.saveInitialMessage();
|
||||
this.savedInitialMessage = true;
|
||||
} else {
|
||||
updateMessageText({
|
||||
messageId: this.finalMessage.messageId,
|
||||
text: this.getText(),
|
||||
});
|
||||
}
|
||||
},
|
||||
2000,
|
||||
{ trailing: false },
|
||||
),
|
||||
});
|
||||
|
||||
// This creates a function that attaches all of the parameters
|
||||
|
||||
@@ -121,6 +121,7 @@ async function saveUserMessage(params) {
|
||||
* @param {Object} params - The parameters of the Assistant message
|
||||
* @param {string} params.user - The user's ID.
|
||||
* @param {string} params.messageId - The message Id.
|
||||
* @param {string} params.text - The concatenated text of the message.
|
||||
* @param {string} params.assistant_id - The assistant Id.
|
||||
* @param {string} params.thread_id - The thread Id.
|
||||
* @param {string} params.model - The model used by the assistant.
|
||||
@@ -134,14 +135,6 @@ async function saveUserMessage(params) {
|
||||
* @return {Promise<Run>} A promise that resolves to the created run object.
|
||||
*/
|
||||
async function saveAssistantMessage(params) {
|
||||
const text = params.content.reduce((acc, part) => {
|
||||
if (!part.value) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
return acc + ' ' + part.value;
|
||||
}, '');
|
||||
|
||||
// const tokenCount = // TODO: need to count each content part
|
||||
|
||||
const message = await recordMessage({
|
||||
@@ -156,7 +149,8 @@ async function saveAssistantMessage(params) {
|
||||
content: params.content,
|
||||
sender: 'Assistant',
|
||||
isCreatedByUser: false,
|
||||
text: text.trim(),
|
||||
text: params.text,
|
||||
unfinished: false,
|
||||
// tokenCount,
|
||||
});
|
||||
|
||||
@@ -302,6 +296,7 @@ async function syncMessages({
|
||||
aggregateMessages: [{ id: apiMessage.id }],
|
||||
model: apiMessage.role === 'user' ? null : apiMessage.assistant_id,
|
||||
user: openai.req.user.id,
|
||||
unfinished: false,
|
||||
};
|
||||
|
||||
if (apiMessage.file_ids?.length) {
|
||||
|
||||
@@ -340,29 +340,26 @@ async function processRequiredActions(client, requiredActions) {
|
||||
currentAction.toolInput = currentAction.toolInput.input;
|
||||
}
|
||||
|
||||
try {
|
||||
const promise = tool
|
||||
._call(currentAction.toolInput)
|
||||
.then(handleToolOutput)
|
||||
.catch((error) => {
|
||||
logger.error(`Error processing tool ${currentAction.tool}`, error);
|
||||
return {
|
||||
tool_call_id: currentAction.toolCallId,
|
||||
output: `Error processing tool ${currentAction.tool}: ${redactMessage(error.message)}`,
|
||||
};
|
||||
});
|
||||
promises.push(promise);
|
||||
} catch (error) {
|
||||
const handleToolError = (error) => {
|
||||
logger.error(
|
||||
`tool_call_id: ${currentAction.toolCallId} | Error processing tool ${currentAction.tool}`,
|
||||
error,
|
||||
);
|
||||
promises.push(
|
||||
Promise.resolve({
|
||||
tool_call_id: currentAction.toolCallId,
|
||||
error: error.message,
|
||||
}),
|
||||
);
|
||||
return {
|
||||
tool_call_id: currentAction.toolCallId,
|
||||
output: `Error processing tool ${currentAction.tool}: ${redactMessage(error.message, 256)}`,
|
||||
};
|
||||
};
|
||||
|
||||
try {
|
||||
const promise = tool
|
||||
._call(currentAction.toolInput)
|
||||
.then(handleToolOutput)
|
||||
.catch(handleToolError);
|
||||
promises.push(promise);
|
||||
} catch (error) {
|
||||
const toolOutputError = handleToolError(error);
|
||||
promises.push(Promise.resolve(toolOutputError));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
const {
|
||||
Capabilities,
|
||||
EModelEndpoint,
|
||||
assistantEndpointSchema,
|
||||
defaultAssistantsVersion,
|
||||
} = require('librechat-data-provider');
|
||||
@@ -20,16 +19,25 @@ function azureAssistantsDefaults() {
|
||||
/**
|
||||
* Sets up the Assistants configuration from the config (`librechat.yaml`) file.
|
||||
* @param {TCustomConfig} config - The loaded custom configuration.
|
||||
* @param {Partial<TAssistantEndpoint>} [prevConfig]
|
||||
* @param {EModelEndpoint.assistants|EModelEndpoint.azureAssistants} assistantsEndpoint - The Assistants endpoint name.
|
||||
* - The previously loaded assistants configuration from Azure OpenAI Assistants option.
|
||||
* @param {Partial<TAssistantEndpoint>} [prevConfig]
|
||||
* @returns {Partial<TAssistantEndpoint>} The Assistants endpoint configuration.
|
||||
*/
|
||||
function assistantsConfigSetup(config, prevConfig = {}) {
|
||||
const assistantsConfig = config.endpoints[EModelEndpoint.assistants];
|
||||
function assistantsConfigSetup(config, assistantsEndpoint, prevConfig = {}) {
|
||||
const assistantsConfig = config.endpoints[assistantsEndpoint];
|
||||
const parsedConfig = assistantEndpointSchema.parse(assistantsConfig);
|
||||
if (assistantsConfig.supportedIds?.length && assistantsConfig.excludedIds?.length) {
|
||||
logger.warn(
|
||||
`Both \`supportedIds\` and \`excludedIds\` are defined for the ${EModelEndpoint.assistants} endpoint; \`excludedIds\` field will be ignored.`,
|
||||
`Configuration conflict: The '${assistantsEndpoint}' endpoint has both 'supportedIds' and 'excludedIds' defined. The 'excludedIds' will be ignored.`,
|
||||
);
|
||||
}
|
||||
if (
|
||||
assistantsConfig.privateAssistants &&
|
||||
(assistantsConfig.supportedIds?.length || assistantsConfig.excludedIds?.length)
|
||||
) {
|
||||
logger.warn(
|
||||
`Configuration conflict: The '${assistantsEndpoint}' endpoint has both 'privateAssistants' and 'supportedIds' or 'excludedIds' defined. The 'supportedIds' and 'excludedIds' will be ignored.`,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -41,6 +49,7 @@ function assistantsConfigSetup(config, prevConfig = {}) {
|
||||
supportedIds: parsedConfig.supportedIds,
|
||||
capabilities: parsedConfig.capabilities,
|
||||
excludedIds: parsedConfig.excludedIds,
|
||||
privateAssistants: parsedConfig.privateAssistants,
|
||||
timeoutMs: parsedConfig.timeoutMs,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
const Redis = require('ioredis');
|
||||
const passport = require('passport');
|
||||
const session = require('express-session');
|
||||
const RedisStore = require('connect-redis').default;
|
||||
const passport = require('passport');
|
||||
const {
|
||||
setupOpenId,
|
||||
googleLogin,
|
||||
githubLogin,
|
||||
discordLogin,
|
||||
facebookLogin,
|
||||
setupOpenId,
|
||||
} = require('../strategies');
|
||||
const client = require('../cache/redis');
|
||||
} = require('~/strategies');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
*
|
||||
@@ -40,6 +41,11 @@ const configureSocialLogins = (app) => {
|
||||
saveUninitialized: false,
|
||||
};
|
||||
if (process.env.USE_REDIS) {
|
||||
const client = new Redis(process.env.REDIS_URI);
|
||||
client
|
||||
.on('error', (err) => logger.error('ioredis error:', err))
|
||||
.on('ready', () => logger.info('ioredis successfully initialized.'))
|
||||
.on('reconnecting', () => logger.info('ioredis reconnecting...'));
|
||||
sessionOptions.store = new RedisStore({ client, prefix: 'librechat' });
|
||||
}
|
||||
app.use(session(sessionOptions));
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
"endpoint": "openAI",
|
||||
"title": "VW Transporter 2014 Fuel Consumption. Web Search"
|
||||
},
|
||||
"messagesTree": [
|
||||
"messages": [
|
||||
{
|
||||
"_id": "6615516574dc2ddcdebe40b6",
|
||||
"messageId": "b123942f-ca1a-4b16-9e1f-ea4af5171168",
|
||||
|
||||
40
api/server/utils/import/__data__/librechat-linear.json
Normal file
40
api/server/utils/import/__data__/librechat-linear.json
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"conversationId": "4c9e1402-f746-479d-8a18-2dfab1bf750f",
|
||||
"endpoint": "azureOpenAI",
|
||||
"title": "Adventure Across Time and Space",
|
||||
"exportAt": "20:33:31 GMT-0400 (Eastern Daylight Time)",
|
||||
"branches": false,
|
||||
"recursive": false,
|
||||
"messages": [
|
||||
{
|
||||
"messageId": "ed3fdaaa-e174-4d70-89a9-b06e70939723",
|
||||
"conversationId": "4c9e1402-f746-479d-8a18-2dfab1bf750f",
|
||||
"createdAt": "2024-05-29T00:33:03.661Z",
|
||||
"endpoint": "azureOpenAI",
|
||||
"error": false,
|
||||
"isCreatedByUser": true,
|
||||
"isEdited": false,
|
||||
"model": null,
|
||||
"parentMessageId": null,
|
||||
"sender": "user",
|
||||
"text": "tell me a long story",
|
||||
"unfinished": false,
|
||||
"updatedAt": "2024-05-29T00:33:03.661Z"
|
||||
},
|
||||
{
|
||||
"messageId": "113b8336-4b19-4f91-8eed-741b3dc2940e",
|
||||
"conversationId": "4c9e1402-f746-479d-8a18-2dfab1bf750f",
|
||||
"createdAt": "2024-05-29T00:33:03.661Z",
|
||||
"endpoint": "azureOpenAI",
|
||||
"error": false,
|
||||
"isCreatedByUser": false,
|
||||
"isEdited": false,
|
||||
"model": "gpt-4o",
|
||||
"parentMessageId": "ed3fdaaa-e174-4d70-89a9-b06e70939723",
|
||||
"sender": "GPT-4",
|
||||
"text": "Of course! Here’s a detailed and engaging story:\n\n---\n\n### The Legend of Eldoria\n\nNestled between towering mountains and dense, ancient forests was the enigmatic kingdom of Eldoria. This realm, clo aked in perpetual twilight, was the stuff of legends. It was said that the land was blessed by the gods and guarded by mythical creatures. Eldoria was a place where magic and realism intertwined seamlessly, creating a land of beauty, wonder, and peril.\n\nIn the heart of this kingdom lay the grand city of Lumina, known",
|
||||
"unfinished": false,
|
||||
"updatedAt": "2024-05-29T00:33:03.661Z"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
{
|
||||
"conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
|
||||
"endpoint": "azureOpenAI",
|
||||
"title": "Original",
|
||||
"exportAt": "08:12:16 GMT-0400 (Eastern Daylight Time)",
|
||||
"branches": true,
|
||||
"recursive": false,
|
||||
"options": {
|
||||
"presetId": null,
|
||||
"model": "gpt-4o",
|
||||
"chatGptLabel": null,
|
||||
"promptPrefix": null,
|
||||
"temperature": 1,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0,
|
||||
"resendFiles": true,
|
||||
"imageDetail": "auto",
|
||||
"endpoint": "azureOpenAI",
|
||||
"title": "Original"
|
||||
},
|
||||
"messages": [
|
||||
{
|
||||
"messageId": "115a6247-8fb0-4937-a536-12956669098d",
|
||||
"conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
|
||||
"createdAt": "2024-05-28T18:08:55.014Z",
|
||||
"endpoint": "azureOpenAI",
|
||||
"error": false,
|
||||
"isCreatedByUser": true,
|
||||
"isEdited": false,
|
||||
"model": null,
|
||||
"parentMessageId": "00000000-0000-0000-0000-000000000000",
|
||||
"sender": "User",
|
||||
"text": "tell me a long story",
|
||||
"tokenCount": 9,
|
||||
"unfinished": false,
|
||||
"updatedAt": "2024-05-28T18:09:27.193Z"
|
||||
},
|
||||
{
|
||||
"messageId": "069b9c22-7649-45a9-b90b-fc050533ea21",
|
||||
"conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
|
||||
"createdAt": "2024-05-28T18:08:55.390Z",
|
||||
"error": false,
|
||||
"isCreatedByUser": false,
|
||||
"isEdited": false,
|
||||
"model": "gpt-4o",
|
||||
"parentMessageId": "115a6247-8fb0-4937-a536-12956669098d",
|
||||
"sender": "GPT-4",
|
||||
"text": "Of course! Settle in for a tale of adventure across time and space.\n\n---\n\nOnce upon a time in the small, sleepy village of Eldoria, there was a young woman named Elara who longed for adventure. Eldoria was a place of routine and simplicity, nestled between rolling hills and dense forests, but Elara always felt that there was more to the world than the boundaries",
|
||||
"unfinished": false,
|
||||
"updatedAt": "2024-05-28T18:08:58.669Z",
|
||||
"endpoint": "azureOpenAI",
|
||||
"tokenCount": 78,
|
||||
"finish_reason": "incomplete"
|
||||
},
|
||||
{
|
||||
"messageId": "303e4c2c-f03e-4e0a-8551-c96ec73be5fe",
|
||||
"conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
|
||||
"createdAt": "2024-05-28T18:09:27.444Z",
|
||||
"error": false,
|
||||
"isCreatedByUser": false,
|
||||
"isEdited": false,
|
||||
"model": "gpt-4o",
|
||||
"parentMessageId": "115a6247-8fb0-4937-a536-12956669098d",
|
||||
"sender": "GPT-4",
|
||||
"text": "Sure, I can craft a long story for you. Here it goes:\n\n### The Chronicles of Elenor: The Luminary of Anduril\n\nIn an age long forgotten by men, in a world kissed by the glow of dual suns, the Kingdom of Anduril flourished. Verdant valleys graced its land, majestic mountains shielded",
|
||||
"unfinished": true,
|
||||
"updatedAt": "2024-05-28T18:09:30.448Z"
|
||||
},
|
||||
{
|
||||
"messageId": "599e1908-8c52-4a73-ba6b-f6dffbd79ba0",
|
||||
"conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
|
||||
"createdAt": "2024-05-28T18:14:07.988Z",
|
||||
"endpoint": "azureOpenAI",
|
||||
"error": false,
|
||||
"isCreatedByUser": true,
|
||||
"isEdited": false,
|
||||
"model": null,
|
||||
"parentMessageId": "00000000-0000-0000-0000-000000000000",
|
||||
"sender": "User",
|
||||
"text": "tell me a long long story",
|
||||
"tokenCount": 9,
|
||||
"unfinished": false,
|
||||
"updatedAt": "2024-05-28T18:14:07.988Z"
|
||||
},
|
||||
{
|
||||
"messageId": "de9a4e7c-020d-4856-a5a6-ce6794efef99",
|
||||
"conversationId": "27b593be-9500-479c-94cb-050cab8f5033",
|
||||
"createdAt": "2024-05-28T18:14:08.403Z",
|
||||
"error": false,
|
||||
"isCreatedByUser": false,
|
||||
"isEdited": true,
|
||||
"model": "gpt-4o",
|
||||
"parentMessageId": "599e1908-8c52-4a73-ba6b-f6dffbd79ba0",
|
||||
"sender": "GPT-4",
|
||||
"text": "Of course! Here’s a detailed and engaging story:\n\n---\n\n### The Legend of Eldoria\n\nNestled between towering mountains and dense, ancient forests was the enigmatic kingdom of Eldoria. This realm, clo aked in perpetual twilight, was the stuff of legends. It was said that the land was blessed by the gods and guarded by mythical creatures. Eldoria was a place where magic and realism intertwined seamlessly, creating a land of beauty, wonder, and peril.\n\nIn the heart of this kingdom lay the grand city of Lumina, known",
|
||||
"unfinished": false,
|
||||
"updatedAt": "2024-05-28T18:14:20.349Z",
|
||||
"endpoint": "azureOpenAI",
|
||||
"finish_reason": "incomplete",
|
||||
"tokenCount": 110
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -19,7 +19,7 @@
|
||||
"endpoint": "openAI",
|
||||
"title": "Troubleshooting Python Virtual Environment Activation Issue"
|
||||
},
|
||||
"messagesTree": [
|
||||
"messages": [
|
||||
{
|
||||
"_id": "66326f3f04bed94b7f5be68d",
|
||||
"messageId": "9501f99d-9bbb-40cb-bbb2-16d79aeceb72",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider');
|
||||
const { EModelEndpoint, Constants, openAISettings, CacheKeys } = require('librechat-data-provider');
|
||||
const { createImportBatchBuilder } = require('./importBatchBuilder');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const logger = require('~/config/winston');
|
||||
|
||||
/**
|
||||
@@ -24,7 +25,7 @@ function getImporter(jsonData) {
|
||||
}
|
||||
|
||||
// For LibreChat
|
||||
if (jsonData.conversationId && jsonData.messagesTree) {
|
||||
if (jsonData.conversationId && (jsonData.messagesTree || jsonData.messages)) {
|
||||
logger.info('Importing LibreChat conversation');
|
||||
return importLibreChatConvo;
|
||||
}
|
||||
@@ -85,47 +86,92 @@ async function importLibreChatConvo(
|
||||
try {
|
||||
/** @type {ImportBatchBuilder} */
|
||||
const importBatchBuilder = builderFactory(requestUserId);
|
||||
importBatchBuilder.startConversation(EModelEndpoint.openAI);
|
||||
const options = jsonData.options || {};
|
||||
|
||||
/* Endpoint configuration */
|
||||
let endpoint = jsonData.endpoint ?? options.endpoint ?? EModelEndpoint.openAI;
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const endpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
|
||||
const endpointConfig = endpointsConfig?.[endpoint];
|
||||
if (!endpointConfig && endpointsConfig) {
|
||||
endpoint = Object.keys(endpointsConfig)[0];
|
||||
} else if (!endpointConfig) {
|
||||
endpoint = EModelEndpoint.openAI;
|
||||
}
|
||||
|
||||
importBatchBuilder.startConversation(endpoint);
|
||||
|
||||
let firstMessageDate = null;
|
||||
|
||||
const traverseMessages = (messages, parentMessageId = null) => {
|
||||
for (const message of messages) {
|
||||
if (!message.text) {
|
||||
continue;
|
||||
}
|
||||
const messagesToImport = jsonData.messagesTree || jsonData.messages;
|
||||
|
||||
let savedMessage;
|
||||
if (message.sender?.toLowerCase() === 'user') {
|
||||
savedMessage = importBatchBuilder.saveMessage({
|
||||
text: message.text,
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
parentMessageId: parentMessageId,
|
||||
});
|
||||
} else {
|
||||
savedMessage = importBatchBuilder.saveMessage({
|
||||
text: message.text,
|
||||
sender: message.sender,
|
||||
isCreatedByUser: false,
|
||||
model: jsonData.options.model,
|
||||
parentMessageId: parentMessageId,
|
||||
});
|
||||
}
|
||||
if (jsonData.recursive) {
|
||||
/**
|
||||
* Recursively traverse the messages tree and save each message to the database.
|
||||
* @param {TMessage[]} messages
|
||||
* @param {string} parentMessageId
|
||||
*/
|
||||
const traverseMessages = async (messages, parentMessageId = null) => {
|
||||
for (const message of messages) {
|
||||
if (!message.text) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let savedMessage;
|
||||
if (message.sender?.toLowerCase() === 'user' || message.isCreatedByUser) {
|
||||
savedMessage = await importBatchBuilder.saveMessage({
|
||||
text: message.text,
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
parentMessageId: parentMessageId,
|
||||
});
|
||||
} else {
|
||||
savedMessage = await importBatchBuilder.saveMessage({
|
||||
text: message.text,
|
||||
sender: message.sender,
|
||||
isCreatedByUser: false,
|
||||
model: options.model,
|
||||
parentMessageId: parentMessageId,
|
||||
});
|
||||
}
|
||||
|
||||
if (!firstMessageDate) {
|
||||
firstMessageDate = new Date(message.createdAt);
|
||||
}
|
||||
|
||||
if (message.children && message.children.length > 0) {
|
||||
await traverseMessages(message.children, savedMessage.messageId);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
await traverseMessages(messagesToImport);
|
||||
} else if (messagesToImport) {
|
||||
const idMapping = new Map();
|
||||
|
||||
for (const message of messagesToImport) {
|
||||
if (!firstMessageDate) {
|
||||
firstMessageDate = new Date(message.createdAt);
|
||||
}
|
||||
const newMessageId = uuidv4();
|
||||
idMapping.set(message.messageId, newMessageId);
|
||||
|
||||
if (message.children) {
|
||||
traverseMessages(message.children, savedMessage.messageId);
|
||||
}
|
||||
const clonedMessage = {
|
||||
...message,
|
||||
messageId: newMessageId,
|
||||
parentMessageId:
|
||||
message.parentMessageId && message.parentMessageId !== Constants.NO_PARENT
|
||||
? idMapping.get(message.parentMessageId) || Constants.NO_PARENT
|
||||
: Constants.NO_PARENT,
|
||||
};
|
||||
|
||||
importBatchBuilder.saveMessage(clonedMessage);
|
||||
}
|
||||
};
|
||||
} else {
|
||||
throw new Error('Invalid LibreChat file format');
|
||||
}
|
||||
|
||||
traverseMessages(jsonData.messagesTree);
|
||||
|
||||
importBatchBuilder.finishConversation(jsonData.title, firstMessageDate);
|
||||
importBatchBuilder.finishConversation(jsonData.title, firstMessageDate ?? new Date(), options);
|
||||
await importBatchBuilder.saveBatch();
|
||||
logger.debug(`user: ${requestUserId} | Conversation "${jsonData.title}" imported`);
|
||||
} catch (error) {
|
||||
|
||||
@@ -1,70 +1,68 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const { EModelEndpoint, Constants } = require('librechat-data-provider');
|
||||
const { EModelEndpoint, Constants, openAISettings } = require('librechat-data-provider');
|
||||
const { bulkSaveConvos: _bulkSaveConvos } = require('~/models/Conversation');
|
||||
const { ImportBatchBuilder } = require('./importBatchBuilder');
|
||||
const { bulkSaveMessages } = require('~/models/Message');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getImporter } = require('./importers');
|
||||
|
||||
// Mocking the ImportBatchBuilder class and its methods
|
||||
jest.mock('./importBatchBuilder', () => {
|
||||
return {
|
||||
ImportBatchBuilder: jest.fn().mockImplementation(() => {
|
||||
return {
|
||||
startConversation: jest.fn().mockResolvedValue(undefined),
|
||||
addUserMessage: jest.fn().mockResolvedValue(undefined),
|
||||
addGptMessage: jest.fn().mockResolvedValue(undefined),
|
||||
saveMessage: jest.fn().mockResolvedValue(undefined),
|
||||
finishConversation: jest.fn().mockResolvedValue(undefined),
|
||||
saveBatch: jest.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
}),
|
||||
};
|
||||
jest.mock('~/cache/getLogStores');
|
||||
const mockedCacheGet = jest.fn();
|
||||
getLogStores.mockImplementation(() => ({
|
||||
get: mockedCacheGet,
|
||||
}));
|
||||
|
||||
// Mock the database methods
|
||||
jest.mock('~/models/Conversation', () => ({
|
||||
bulkSaveConvos: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Message', () => ({
|
||||
bulkSaveMessages: jest.fn(),
|
||||
}));
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('importChatGptConvo', () => {
|
||||
it('should import conversation correctly', async () => {
|
||||
const expectedNumberOfMessages = 19;
|
||||
const expectedNumberOfConversations = 2;
|
||||
// Given
|
||||
const jsonData = JSON.parse(
|
||||
fs.readFileSync(path.join(__dirname, '__data__', 'chatgpt-export.json'), 'utf8'),
|
||||
);
|
||||
const requestUserId = 'user-123';
|
||||
const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
|
||||
// Spy on instance methods
|
||||
jest.spyOn(importBatchBuilder, 'startConversation');
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
jest.spyOn(importBatchBuilder, 'finishConversation');
|
||||
jest.spyOn(importBatchBuilder, 'saveBatch');
|
||||
|
||||
// When
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, mockedBuilderFactory);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
// Then
|
||||
expect(mockedBuilderFactory).toHaveBeenCalledWith(requestUserId);
|
||||
const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;
|
||||
|
||||
expect(mockImportBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
|
||||
expect(mockImportBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedNumberOfMessages); // Adjust expected number
|
||||
expect(mockImportBatchBuilder.finishConversation).toHaveBeenCalledTimes(
|
||||
expectedNumberOfConversations,
|
||||
); // Adjust expected number
|
||||
expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
expect(importBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
|
||||
expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedNumberOfMessages);
|
||||
expect(importBatchBuilder.finishConversation).toHaveBeenCalledTimes(jsonData.length);
|
||||
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should maintain correct message hierarchy (tree parent/children relationship)', async () => {
|
||||
// Prepare test data with known hierarchy
|
||||
const jsonData = JSON.parse(
|
||||
fs.readFileSync(path.join(__dirname, '__data__', 'chatgpt-tree.json'), 'utf8'),
|
||||
);
|
||||
|
||||
const requestUserId = 'user-123';
|
||||
const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
jest.spyOn(importBatchBuilder, 'saveBatch');
|
||||
|
||||
// When
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, mockedBuilderFactory);
|
||||
|
||||
// Then
|
||||
expect(mockedBuilderFactory).toHaveBeenCalledWith(requestUserId);
|
||||
const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const entries = Object.keys(jsonData[0].mapping);
|
||||
// Filter entries that should be processed (not system and have content)
|
||||
const messageEntries = entries.filter(
|
||||
(id) =>
|
||||
jsonData[0].mapping[id].message &&
|
||||
@@ -72,20 +70,16 @@ describe('importChatGptConvo', () => {
|
||||
jsonData[0].mapping[id].message.content,
|
||||
);
|
||||
|
||||
// Expect the saveMessage to be called for each valid entry
|
||||
expect(mockImportBatchBuilder.saveMessage).toHaveBeenCalledTimes(messageEntries.length);
|
||||
expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(messageEntries.length);
|
||||
|
||||
const idToUUIDMap = new Map();
|
||||
// Map original IDs to dynamically generated UUIDs
|
||||
mockImportBatchBuilder.saveMessage.mock.calls.forEach((call, index) => {
|
||||
importBatchBuilder.saveMessage.mock.calls.forEach((call, index) => {
|
||||
const originalId = messageEntries[index];
|
||||
idToUUIDMap.set(originalId, call[0].messageId);
|
||||
});
|
||||
|
||||
// Validate the UUID map contains all expected entries
|
||||
expect(idToUUIDMap.size).toBe(messageEntries.length);
|
||||
|
||||
// Validate correct parent-child relationships
|
||||
messageEntries.forEach((id) => {
|
||||
const { parent } = jsonData[0].mapping[id];
|
||||
|
||||
@@ -93,72 +87,110 @@ describe('importChatGptConvo', () => {
|
||||
? idToUUIDMap.get(parent) ?? Constants.NO_PARENT
|
||||
: Constants.NO_PARENT;
|
||||
|
||||
const actualParentId = idToUUIDMap.get(id)
|
||||
? mockImportBatchBuilder.saveMessage.mock.calls.find(
|
||||
(call) => call[0].messageId === idToUUIDMap.get(id),
|
||||
const actualMessageId = idToUUIDMap.get(id);
|
||||
const actualParentId = actualMessageId
|
||||
? importBatchBuilder.saveMessage.mock.calls.find(
|
||||
(call) => call[0].messageId === actualMessageId,
|
||||
)[0].parentMessageId
|
||||
: Constants.NO_PARENT;
|
||||
|
||||
expect(actualParentId).toBe(expectedParentId);
|
||||
});
|
||||
|
||||
expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('importLibreChatConvo', () => {
|
||||
it('should import conversation correctly', async () => {
|
||||
const expectedNumberOfMessages = 6;
|
||||
const expectedNumberOfConversations = 1;
|
||||
const jsonDataNonRecursiveBranches = JSON.parse(
|
||||
fs.readFileSync(path.join(__dirname, '__data__', 'librechat-opts-nonr-branches.json'), 'utf8'),
|
||||
);
|
||||
|
||||
// Given
|
||||
it('should import conversation correctly', async () => {
|
||||
mockedCacheGet.mockResolvedValue({
|
||||
[EModelEndpoint.openAI]: {},
|
||||
});
|
||||
const expectedNumberOfMessages = 6;
|
||||
const jsonData = JSON.parse(
|
||||
fs.readFileSync(path.join(__dirname, '__data__', 'librechat-export.json'), 'utf8'),
|
||||
);
|
||||
const requestUserId = 'user-123';
|
||||
const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
|
||||
// Spy on instance methods
|
||||
jest.spyOn(importBatchBuilder, 'startConversation');
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
jest.spyOn(importBatchBuilder, 'finishConversation');
|
||||
jest.spyOn(importBatchBuilder, 'saveBatch');
|
||||
|
||||
// When
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, mockedBuilderFactory);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
// Then
|
||||
const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;
|
||||
expect(mockImportBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
|
||||
expect(mockImportBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedNumberOfMessages); // Adjust expected number
|
||||
expect(mockImportBatchBuilder.finishConversation).toHaveBeenCalledTimes(
|
||||
expectedNumberOfConversations,
|
||||
); // Adjust expected number
|
||||
expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
expect(importBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
|
||||
expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(expectedNumberOfMessages);
|
||||
expect(importBatchBuilder.finishConversation).toHaveBeenCalledTimes(1);
|
||||
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should import linear, non-recursive thread correctly with correct endpoint', async () => {
|
||||
mockedCacheGet.mockResolvedValue({
|
||||
[EModelEndpoint.azureOpenAI]: {},
|
||||
});
|
||||
|
||||
const jsonData = JSON.parse(
|
||||
fs.readFileSync(path.join(__dirname, '__data__', 'librechat-linear.json'), 'utf8'),
|
||||
);
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
|
||||
jest.spyOn(importBatchBuilder, 'startConversation');
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
jest.spyOn(importBatchBuilder, 'finishConversation');
|
||||
jest.spyOn(importBatchBuilder, 'saveBatch');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
expect(bulkSaveMessages).toHaveBeenCalledTimes(1);
|
||||
|
||||
const messages = bulkSaveMessages.mock.calls[0][0];
|
||||
let lastMessageId = Constants.NO_PARENT;
|
||||
for (const message of messages) {
|
||||
expect(message.parentMessageId).toBe(lastMessageId);
|
||||
lastMessageId = message.messageId;
|
||||
}
|
||||
|
||||
expect(importBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.azureOpenAI);
|
||||
expect(importBatchBuilder.saveMessage).toHaveBeenCalledTimes(jsonData.messages.length);
|
||||
expect(importBatchBuilder.finishConversation).toHaveBeenCalled();
|
||||
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should maintain correct message hierarchy (tree parent/children relationship)', async () => {
|
||||
// Load test data
|
||||
const jsonData = JSON.parse(
|
||||
fs.readFileSync(path.join(__dirname, '__data__', 'librechat-tree.json'), 'utf8'),
|
||||
);
|
||||
const requestUserId = 'user-123';
|
||||
const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
jest.spyOn(importBatchBuilder, 'saveBatch');
|
||||
|
||||
// When
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, mockedBuilderFactory);
|
||||
|
||||
// Then
|
||||
const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
// Create a map to track original message IDs to new UUIDs
|
||||
const idToUUIDMap = new Map();
|
||||
mockImportBatchBuilder.saveMessage.mock.calls.forEach((call) => {
|
||||
importBatchBuilder.saveMessage.mock.calls.forEach((call) => {
|
||||
const message = call[0];
|
||||
idToUUIDMap.set(message.originalMessageId, message.messageId);
|
||||
});
|
||||
|
||||
// Function to recursively check children
|
||||
const checkChildren = (children, parentId) => {
|
||||
children.forEach((child) => {
|
||||
const childUUID = idToUUIDMap.get(child.messageId);
|
||||
const expectedParentId = idToUUIDMap.get(parentId) ?? null;
|
||||
const messageCall = mockImportBatchBuilder.saveMessage.mock.calls.find(
|
||||
const messageCall = importBatchBuilder.saveMessage.mock.calls.find(
|
||||
(call) => call[0].messageId === childUUID,
|
||||
);
|
||||
|
||||
@@ -172,75 +204,203 @@ describe('importLibreChatConvo', () => {
|
||||
};
|
||||
|
||||
// Start hierarchy validation from root messages
|
||||
checkChildren(jsonData.messagesTree, null); // Assuming root messages have no parent
|
||||
checkChildren(jsonData.messages, null);
|
||||
|
||||
expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should maintain correct message hierarchy (non-recursive)', async () => {
|
||||
const jsonData = jsonDataNonRecursiveBranches;
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
jest.spyOn(importBatchBuilder, 'saveBatch');
|
||||
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
const textToMessageMap = new Map();
|
||||
importBatchBuilder.saveMessage.mock.calls.forEach((call) => {
|
||||
const message = call[0];
|
||||
textToMessageMap.set(message.text, message);
|
||||
});
|
||||
|
||||
const relationships = {
|
||||
'tell me a long story': [
|
||||
'Of course! Settle in for a tale of adventure across time and space.\n\n---\n\nOnce upon a time in the small, sleepy village of Eldoria, there was a young woman named Elara who longed for adventure. Eldoria was a place of routine and simplicity, nestled between rolling hills and dense forests, but Elara always felt that there was more to the world than the boundaries',
|
||||
'Sure, I can craft a long story for you. Here it goes:\n\n### The Chronicles of Elenor: The Luminary of Anduril\n\nIn an age long forgotten by men, in a world kissed by the glow of dual suns, the Kingdom of Anduril flourished. Verdant valleys graced its land, majestic mountains shielded',
|
||||
],
|
||||
'tell me a long long story': [
|
||||
'Of course! Here’s a detailed and engaging story:\n\n---\n\n### The Legend of Eldoria\n\nNestled between towering mountains and dense, ancient forests was the enigmatic kingdom of Eldoria. This realm, clo aked in perpetual twilight, was the stuff of legends. It was said that the land was blessed by the gods and guarded by mythical creatures. Eldoria was a place where magic and realism intertwined seamlessly, creating a land of beauty, wonder, and peril.\n\nIn the heart of this kingdom lay the grand city of Lumina, known',
|
||||
],
|
||||
};
|
||||
|
||||
Object.keys(relationships).forEach((parentText) => {
|
||||
const parentMessage = textToMessageMap.get(parentText);
|
||||
const childrenTexts = relationships[parentText];
|
||||
|
||||
childrenTexts.forEach((childText) => {
|
||||
const childMessage = textToMessageMap.get(childText);
|
||||
expect(childMessage.parentMessageId).toBe(parentMessage.messageId);
|
||||
});
|
||||
});
|
||||
|
||||
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should retain properties from the original conversation as well as new settings', async () => {
|
||||
mockedCacheGet.mockResolvedValue({
|
||||
[EModelEndpoint.azureOpenAI]: {},
|
||||
});
|
||||
const requestUserId = 'user-123';
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
jest.spyOn(importBatchBuilder, 'finishConversation');
|
||||
|
||||
const importer = getImporter(jsonDataNonRecursiveBranches);
|
||||
await importer(jsonDataNonRecursiveBranches, requestUserId, () => importBatchBuilder);
|
||||
|
||||
expect(importBatchBuilder.finishConversation).toHaveBeenCalledTimes(1);
|
||||
|
||||
const [_title, createdAt, originalConvo] = importBatchBuilder.finishConversation.mock.calls[0];
|
||||
const convo = importBatchBuilder.conversations[0];
|
||||
|
||||
expect(convo).toEqual({
|
||||
...jsonDataNonRecursiveBranches.options,
|
||||
user: requestUserId,
|
||||
conversationId: importBatchBuilder.conversationId,
|
||||
title: originalConvo.title || 'Imported Chat',
|
||||
createdAt: createdAt,
|
||||
updatedAt: createdAt,
|
||||
overrideTimestamp: true,
|
||||
endpoint: importBatchBuilder.endpoint,
|
||||
model: originalConvo.model || openAISettings.model.default,
|
||||
});
|
||||
|
||||
expect(convo.title).toBe('Original');
|
||||
expect(convo.createdAt).toBeInstanceOf(Date);
|
||||
expect(convo.endpoint).toBe(EModelEndpoint.azureOpenAI);
|
||||
expect(convo.model).toBe('gpt-4o');
|
||||
});
|
||||
|
||||
describe('finishConversation', () => {
|
||||
it('should retain properties from the original conversation as well as update with new settings', () => {
|
||||
const requestUserId = 'user-123';
|
||||
const builder = new ImportBatchBuilder(requestUserId);
|
||||
builder.conversationId = 'conv-id-123';
|
||||
builder.messages = [{ text: 'Hello, world!' }];
|
||||
|
||||
const originalConvo = {
|
||||
_id: 'old-convo-id',
|
||||
model: 'custom-model',
|
||||
};
|
||||
|
||||
builder.endpoint = 'test-endpoint';
|
||||
|
||||
const title = 'New Chat Title';
|
||||
const createdAt = new Date('2023-10-01T00:00:00Z');
|
||||
|
||||
const result = builder.finishConversation(title, createdAt, originalConvo);
|
||||
|
||||
expect(result).toEqual({
|
||||
conversation: {
|
||||
user: requestUserId,
|
||||
conversationId: builder.conversationId,
|
||||
title: 'New Chat Title',
|
||||
createdAt: createdAt,
|
||||
updatedAt: createdAt,
|
||||
overrideTimestamp: true,
|
||||
endpoint: 'test-endpoint',
|
||||
model: 'custom-model',
|
||||
},
|
||||
messages: builder.messages,
|
||||
});
|
||||
|
||||
expect(builder.conversations).toContainEqual({
|
||||
user: requestUserId,
|
||||
conversationId: builder.conversationId,
|
||||
title: 'New Chat Title',
|
||||
createdAt: createdAt,
|
||||
updatedAt: createdAt,
|
||||
overrideTimestamp: true,
|
||||
endpoint: 'test-endpoint',
|
||||
model: 'custom-model',
|
||||
});
|
||||
});
|
||||
|
||||
it('should use default values if not provided in the original conversation or as parameters', () => {
|
||||
const requestUserId = 'user-123';
|
||||
const builder = new ImportBatchBuilder(requestUserId);
|
||||
builder.conversationId = 'conv-id-123';
|
||||
builder.messages = [{ text: 'Hello, world!' }];
|
||||
builder.endpoint = 'test-endpoint';
|
||||
const result = builder.finishConversation();
|
||||
expect(result.conversation.title).toBe('Imported Chat');
|
||||
expect(result.conversation.model).toBe(openAISettings.model.default);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('importChatBotUiConvo', () => {
|
||||
it('should import custom conversation correctly', async () => {
|
||||
// Given
|
||||
const jsonData = JSON.parse(
|
||||
fs.readFileSync(path.join(__dirname, '__data__', 'chatbotui-export.json'), 'utf8'),
|
||||
);
|
||||
const requestUserId = 'custom-user-456';
|
||||
const mockedBuilderFactory = jest.fn().mockReturnValue(new ImportBatchBuilder(requestUserId));
|
||||
const importBatchBuilder = new ImportBatchBuilder(requestUserId);
|
||||
|
||||
// Spy on instance methods
|
||||
jest.spyOn(importBatchBuilder, 'startConversation');
|
||||
jest.spyOn(importBatchBuilder, 'saveMessage');
|
||||
jest.spyOn(importBatchBuilder, 'addUserMessage');
|
||||
jest.spyOn(importBatchBuilder, 'addGptMessage');
|
||||
jest.spyOn(importBatchBuilder, 'finishConversation');
|
||||
jest.spyOn(importBatchBuilder, 'saveBatch');
|
||||
|
||||
// When
|
||||
const importer = getImporter(jsonData);
|
||||
await importer(jsonData, requestUserId, mockedBuilderFactory);
|
||||
await importer(jsonData, requestUserId, () => importBatchBuilder);
|
||||
|
||||
// Then
|
||||
const mockImportBatchBuilder = mockedBuilderFactory.mock.results[0].value;
|
||||
expect(mockImportBatchBuilder.startConversation).toHaveBeenCalledWith('openAI');
|
||||
|
||||
// User messages
|
||||
expect(mockImportBatchBuilder.addUserMessage).toHaveBeenCalledTimes(3);
|
||||
expect(mockImportBatchBuilder.addUserMessage).toHaveBeenNthCalledWith(
|
||||
expect(importBatchBuilder.startConversation).toHaveBeenCalledWith(EModelEndpoint.openAI);
|
||||
expect(importBatchBuilder.addUserMessage).toHaveBeenCalledTimes(3);
|
||||
expect(importBatchBuilder.addUserMessage).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'Hello what are you able to do?',
|
||||
);
|
||||
expect(mockImportBatchBuilder.addUserMessage).toHaveBeenNthCalledWith(
|
||||
expect(importBatchBuilder.addUserMessage).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
'Give me the code that inverts binary tree in COBOL',
|
||||
);
|
||||
|
||||
// GPT messages
|
||||
expect(mockImportBatchBuilder.addGptMessage).toHaveBeenCalledTimes(3);
|
||||
expect(mockImportBatchBuilder.addGptMessage).toHaveBeenNthCalledWith(
|
||||
expect(importBatchBuilder.addGptMessage).toHaveBeenCalledTimes(3);
|
||||
expect(importBatchBuilder.addGptMessage).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.stringMatching(/^Hello! As an AI developed by OpenAI/),
|
||||
'gpt-4-1106-preview',
|
||||
);
|
||||
expect(mockImportBatchBuilder.addGptMessage).toHaveBeenNthCalledWith(
|
||||
expect(importBatchBuilder.addGptMessage).toHaveBeenNthCalledWith(
|
||||
3,
|
||||
expect.stringContaining('```cobol'),
|
||||
'gpt-3.5-turbo',
|
||||
);
|
||||
|
||||
expect(mockImportBatchBuilder.finishConversation).toHaveBeenCalledTimes(2);
|
||||
expect(mockImportBatchBuilder.finishConversation).toHaveBeenNthCalledWith(
|
||||
expect(importBatchBuilder.finishConversation).toHaveBeenCalledTimes(2);
|
||||
expect(importBatchBuilder.finishConversation).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'Hello what are you able to do?',
|
||||
expect.any(Date),
|
||||
);
|
||||
expect(mockImportBatchBuilder.finishConversation).toHaveBeenNthCalledWith(
|
||||
expect(importBatchBuilder.finishConversation).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
'Give me the code that inverts ...',
|
||||
expect.any(Date),
|
||||
);
|
||||
|
||||
expect(mockImportBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
expect(importBatchBuilder.saveBatch).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getImporter', () => {
|
||||
it('should throw an error if the import type is not supported', () => {
|
||||
// Given
|
||||
const jsonData = { unsupported: 'data' };
|
||||
|
||||
// When
|
||||
expect(() => getImporter(jsonData)).toThrow('Unsupported import type');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -5,6 +5,7 @@ const discordLogin = require('./discordStrategy');
|
||||
const facebookLogin = require('./facebookStrategy');
|
||||
const setupOpenId = require('./openidStrategy');
|
||||
const jwtLogin = require('./jwtStrategy');
|
||||
const ldapLogin = require('./ldapStrategy');
|
||||
|
||||
module.exports = {
|
||||
passportLogin,
|
||||
@@ -14,4 +15,5 @@ module.exports = {
|
||||
jwtLogin,
|
||||
facebookLogin,
|
||||
setupOpenId,
|
||||
ldapLogin,
|
||||
};
|
||||
|
||||
67
api/strategies/ldapStrategy.js
Normal file
67
api/strategies/ldapStrategy.js
Normal file
@@ -0,0 +1,67 @@
|
||||
const LdapStrategy = require('passport-ldapauth');
|
||||
const User = require('~/models/User');
|
||||
const fs = require('fs');
|
||||
|
||||
const ldapOptions = {
|
||||
server: {
|
||||
url: process.env.LDAP_URL,
|
||||
bindDN: process.env.LDAP_BIND_DN,
|
||||
bindCredentials: process.env.LDAP_BIND_CREDENTIALS,
|
||||
searchBase: process.env.LDAP_USER_SEARCH_BASE,
|
||||
searchFilter: process.env.LDAP_SEARCH_FILTER || 'mail={{username}}',
|
||||
searchAttributes: ['displayName', 'mail', 'uid', 'cn', 'name', 'commonname', 'givenName', 'sn'],
|
||||
...(process.env.LDAP_CA_CERT_PATH && {
|
||||
tlsOptions: { ca: [fs.readFileSync(process.env.LDAP_CA_CERT_PATH)] },
|
||||
}),
|
||||
},
|
||||
usernameField: 'email',
|
||||
passwordField: 'password',
|
||||
};
|
||||
|
||||
const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => {
|
||||
if (!userinfo) {
|
||||
return done(null, false, { message: 'Invalid credentials' });
|
||||
}
|
||||
|
||||
try {
|
||||
const firstName = userinfo.givenName;
|
||||
const familyName = userinfo.surname || userinfo.sn;
|
||||
const fullName =
|
||||
firstName && familyName
|
||||
? `${firstName} ${familyName}`
|
||||
: userinfo.cn ||
|
||||
userinfo.name ||
|
||||
userinfo.commonname ||
|
||||
userinfo.displayName ||
|
||||
userinfo.mail;
|
||||
|
||||
const username = userinfo.givenName || userinfo.mail;
|
||||
let user = await User.findOne({ email: userinfo.mail });
|
||||
if (user && user.provider !== 'ldap') {
|
||||
return done(null, false, { message: 'Invalid credentials' });
|
||||
}
|
||||
if (!user) {
|
||||
user = new User({
|
||||
provider: 'ldap',
|
||||
ldapId: userinfo.uid,
|
||||
username,
|
||||
email: userinfo.mail || '',
|
||||
emailVerified: true,
|
||||
name: fullName,
|
||||
});
|
||||
} else {
|
||||
user.provider = 'ldap';
|
||||
user.ldapId = userinfo.uid;
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
}
|
||||
|
||||
await user.save();
|
||||
|
||||
done(null, user);
|
||||
} catch (err) {
|
||||
done(err);
|
||||
}
|
||||
});
|
||||
|
||||
module.exports = ldapLogin;
|
||||
@@ -85,10 +85,21 @@ async function setupOpenId() {
|
||||
},
|
||||
async (tokenset, userinfo, done) => {
|
||||
try {
|
||||
logger.info(`[openidStrategy] verify login openidId: ${userinfo.sub}`);
|
||||
logger.debug('[openidStrategy] very login tokenset and userinfo', { tokenset, userinfo });
|
||||
|
||||
let user = await User.findOne({ openidId: userinfo.sub });
|
||||
logger.info(
|
||||
`[openidStrategy] user ${user ? 'found' : 'not found'} with openidId: ${userinfo.sub}`,
|
||||
);
|
||||
|
||||
if (!user) {
|
||||
user = await User.findOne({ email: userinfo.email });
|
||||
logger.info(
|
||||
`[openidStrategy] user ${user ? 'found' : 'not found'} with email: ${
|
||||
userinfo.email
|
||||
} for openidId: ${userinfo.sub}`,
|
||||
);
|
||||
}
|
||||
|
||||
let fullName = '';
|
||||
@@ -120,8 +131,8 @@ async function setupOpenId() {
|
||||
}, decodedToken);
|
||||
|
||||
if (!found) {
|
||||
console.error(
|
||||
`Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
logger.error(
|
||||
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -183,8 +194,21 @@ async function setupOpenId() {
|
||||
|
||||
await user.save();
|
||||
|
||||
logger.info(
|
||||
`[openidStrategy] login success openidId: ${user.openidId} username: ${user.username} email: ${user.email}`,
|
||||
{
|
||||
user: {
|
||||
openidId: user.openidId,
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
done(null, user);
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy] login failed', err);
|
||||
done(err);
|
||||
}
|
||||
},
|
||||
|
||||
@@ -272,6 +272,12 @@
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TPayload
|
||||
* @typedef {import('librechat-data-provider').TPayload} TPayload
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TAzureModelConfig
|
||||
* @typedef {import('librechat-data-provider').TAzureModelConfig} TAzureModelConfig
|
||||
@@ -349,6 +355,12 @@
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TProviderSchema
|
||||
* @typedef {import('librechat-data-provider').TProviderSchema} TProviderSchema
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports TEndpoint
|
||||
* @typedef {import('librechat-data-provider').TEndpoint} TEndpoint
|
||||
|
||||
@@ -79,9 +79,11 @@
|
||||
"react-markdown": "^8.0.6",
|
||||
"react-resizable-panels": "^1.0.9",
|
||||
"react-router-dom": "^6.11.2",
|
||||
"react-speech-recognition": "^3.10.0",
|
||||
"react-textarea-autosize": "^8.4.0",
|
||||
"react-transition-group": "^4.4.5",
|
||||
"recoil": "^0.7.7",
|
||||
"regenerator-runtime": "^0.14.1",
|
||||
"rehype-highlight": "^6.0.0",
|
||||
"rehype-katex": "^6.0.2",
|
||||
"rehype-raw": "^6.1.1",
|
||||
|
||||
@@ -12,6 +12,7 @@ import type {
|
||||
TLoginUser,
|
||||
AuthTypeEnum,
|
||||
TConversation,
|
||||
TStartupConfig,
|
||||
EModelEndpoint,
|
||||
AssistantsEndpoint,
|
||||
AuthorizationTypeEnum,
|
||||
@@ -21,6 +22,21 @@ import type {
|
||||
import type { UseMutationResult } from '@tanstack/react-query';
|
||||
import type { LucideIcon } from 'lucide-react';
|
||||
|
||||
export type AudioChunk = {
|
||||
audio: string;
|
||||
isFinal: boolean;
|
||||
alignment: {
|
||||
char_start_times_ms: number[];
|
||||
chars_durations_ms: number[];
|
||||
chars: string[];
|
||||
};
|
||||
normalizedAlignment: {
|
||||
char_start_times_ms: number[];
|
||||
chars_durations_ms: number[];
|
||||
chars: string[];
|
||||
};
|
||||
};
|
||||
|
||||
export type AssistantListItem = {
|
||||
id: string;
|
||||
name: string;
|
||||
@@ -37,6 +53,7 @@ export type LastSelectedModels = Record<EModelEndpoint, string>;
|
||||
export type LocalizeFunction = (phraseKey: string, ...values: string[]) => string;
|
||||
|
||||
export const mainTextareaId = 'prompt-textarea';
|
||||
export const globalAudioId = 'global-audio';
|
||||
|
||||
export enum IconContext {
|
||||
landing = 'landing',
|
||||
@@ -374,3 +391,13 @@ export interface SwitcherProps {
|
||||
endpointKeyProvided: boolean;
|
||||
isCollapsed: boolean;
|
||||
}
|
||||
|
||||
export type TLoginLayoutContext = {
|
||||
startupConfig: TStartupConfig | null;
|
||||
startupConfigError: unknown;
|
||||
isFetching: boolean;
|
||||
error: string | null;
|
||||
setError: React.Dispatch<React.SetStateAction<string | null>>;
|
||||
headerText: string;
|
||||
setHeaderText: React.Dispatch<React.SetStateAction<string>>;
|
||||
};
|
||||
|
||||
90
client/src/components/Auth/AuthLayout.tsx
Normal file
90
client/src/components/Auth/AuthLayout.tsx
Normal file
@@ -0,0 +1,90 @@
|
||||
import { ThemeSelector } from '~/components/ui';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { BlinkAnimation } from './BlinkAnimation';
|
||||
import { TStartupConfig } from 'librechat-data-provider';
|
||||
import SocialLoginRender from './SocialLoginRender';
|
||||
import Footer from './Footer';
|
||||
|
||||
const ErrorRender = ({ children }: { children: React.ReactNode }) => (
|
||||
<div className="mt-16 flex justify-center">
|
||||
<div
|
||||
className="rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-200"
|
||||
role="alert"
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
function AuthLayout({
|
||||
children,
|
||||
header,
|
||||
isFetching,
|
||||
startupConfig,
|
||||
startupConfigError,
|
||||
pathname,
|
||||
error,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
header: React.ReactNode;
|
||||
isFetching: boolean;
|
||||
startupConfig: TStartupConfig | null | undefined;
|
||||
startupConfigError: unknown | null | undefined;
|
||||
pathname: string;
|
||||
error: string | null;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
|
||||
const DisplayError = () => {
|
||||
if (startupConfigError !== null && startupConfigError !== undefined) {
|
||||
return <ErrorRender>{localize('com_auth_error_login_server')}</ErrorRender>;
|
||||
} else if (error === 'com_auth_error_invalid_reset_token') {
|
||||
return (
|
||||
<ErrorRender>
|
||||
{localize('com_auth_error_invalid_reset_token')}{' '}
|
||||
<a className="font-semibold text-green-600 hover:underline" href="/forgot-password">
|
||||
{localize('com_auth_click_here')}
|
||||
</a>{' '}
|
||||
{localize('com_auth_to_try_again')}
|
||||
</ErrorRender>
|
||||
);
|
||||
} else if (error) {
|
||||
return <ErrorRender>{localize(error)}</ErrorRender>;
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="relative flex min-h-screen flex-col bg-white dark:bg-gray-900">
|
||||
<BlinkAnimation active={isFetching}>
|
||||
<div className="mt-12 h-24 w-full bg-cover">
|
||||
<img src="/assets/logo.svg" className="h-full w-full object-contain" alt="Logo" />
|
||||
</div>
|
||||
</BlinkAnimation>
|
||||
<DisplayError />
|
||||
<div className="absolute bottom-0 left-0 md:m-4">
|
||||
<ThemeSelector />
|
||||
</div>
|
||||
|
||||
<div className="flex flex-grow items-center justify-center">
|
||||
<div className="w-authPageWidth overflow-hidden bg-white px-6 py-4 dark:bg-gray-900 sm:max-w-md sm:rounded-lg">
|
||||
{!startupConfigError && !isFetching && (
|
||||
<h1
|
||||
className="mb-4 text-center text-3xl font-semibold text-black dark:text-white"
|
||||
style={{ userSelect: 'none' }}
|
||||
>
|
||||
{header}
|
||||
</h1>
|
||||
)}
|
||||
{children}
|
||||
{(pathname.includes('login') || pathname.includes('register')) && (
|
||||
<SocialLoginRender startupConfig={startupConfig} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<Footer startupConfig={startupConfig} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default AuthLayout;
|
||||
29
client/src/components/Auth/BlinkAnimation.tsx
Normal file
29
client/src/components/Auth/BlinkAnimation.tsx
Normal file
@@ -0,0 +1,29 @@
|
||||
export const BlinkAnimation = ({
|
||||
active,
|
||||
children,
|
||||
}: {
|
||||
active: boolean;
|
||||
children: React.ReactNode;
|
||||
}) => {
|
||||
const style = `
|
||||
@keyframes blink-animation {
|
||||
0%,
|
||||
100% {
|
||||
opacity: 1;
|
||||
}
|
||||
50% {
|
||||
opacity: 0;
|
||||
}
|
||||
}`;
|
||||
|
||||
if (!active) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<style>{style}</style>
|
||||
<div style={{ animation: 'blink-animation 3s infinite' }}>{children}</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
8
client/src/components/Auth/ErrorMessage.tsx
Normal file
8
client/src/components/Auth/ErrorMessage.tsx
Normal file
@@ -0,0 +1,8 @@
|
||||
export const ErrorMessage = ({ children }: { children: React.ReactNode }) => (
|
||||
<div
|
||||
className="rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-200"
|
||||
role="alert"
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
45
client/src/components/Auth/Footer.tsx
Normal file
45
client/src/components/Auth/Footer.tsx
Normal file
@@ -0,0 +1,45 @@
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { TStartupConfig } from 'librechat-data-provider';
|
||||
|
||||
function Footer({ startupConfig }: { startupConfig: TStartupConfig | null | undefined }) {
|
||||
const localize = useLocalize();
|
||||
if (!startupConfig) {
|
||||
return null;
|
||||
}
|
||||
const privacyPolicy = startupConfig.interface?.privacyPolicy;
|
||||
const termsOfService = startupConfig.interface?.termsOfService;
|
||||
|
||||
const privacyPolicyRender = privacyPolicy?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={privacyPolicy.externalUrl}
|
||||
target={privacyPolicy.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_privacy_policy')}
|
||||
</a>
|
||||
);
|
||||
|
||||
const termsOfServiceRender = termsOfService?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={termsOfService.externalUrl}
|
||||
target={termsOfService.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_terms_of_service')}
|
||||
</a>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="align-end m-4 flex justify-center gap-2">
|
||||
{privacyPolicyRender}
|
||||
{privacyPolicyRender && termsOfServiceRender && (
|
||||
<div className="border-r-[1px] border-gray-300 dark:border-gray-600" />
|
||||
)}
|
||||
{termsOfServiceRender}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default Footer;
|
||||
@@ -1,182 +1,30 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { useGetStartupConfig } from 'librechat-data-provider/react-query';
|
||||
import { GoogleIcon, FacebookIcon, OpenIDIcon, GithubIcon, DiscordIcon } from '~/components';
|
||||
import { useOutletContext } from 'react-router-dom';
|
||||
import { useAuthContext } from '~/hooks/AuthContext';
|
||||
import { ThemeSelector } from '~/components/ui';
|
||||
import SocialButton from './SocialButton';
|
||||
import type { TLoginLayoutContext } from '~/common';
|
||||
import { ErrorMessage } from '~/components/Auth/ErrorMessage';
|
||||
import { getLoginError } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import LoginForm from './LoginForm';
|
||||
|
||||
function Login() {
|
||||
const { login, error, isAuthenticated } = useAuthContext();
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const localize = useLocalize();
|
||||
const navigate = useNavigate();
|
||||
|
||||
useEffect(() => {
|
||||
if (isAuthenticated) {
|
||||
navigate('/c/new', { replace: true });
|
||||
}
|
||||
}, [isAuthenticated, navigate]);
|
||||
|
||||
if (!startupConfig) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const socialLogins = startupConfig.socialLogins ?? [];
|
||||
|
||||
const providerComponents = {
|
||||
discord: (
|
||||
<SocialButton
|
||||
key="discord"
|
||||
enabled={startupConfig.discordLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="discord"
|
||||
Icon={DiscordIcon}
|
||||
label={localize('com_auth_discord_login')}
|
||||
id="discord"
|
||||
/>
|
||||
),
|
||||
facebook: (
|
||||
<SocialButton
|
||||
key="facebook"
|
||||
enabled={startupConfig.facebookLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="facebook"
|
||||
Icon={FacebookIcon}
|
||||
label={localize('com_auth_facebook_login')}
|
||||
id="facebook"
|
||||
/>
|
||||
),
|
||||
github: (
|
||||
<SocialButton
|
||||
key="github"
|
||||
enabled={startupConfig.githubLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="github"
|
||||
Icon={GithubIcon}
|
||||
label={localize('com_auth_github_login')}
|
||||
id="github"
|
||||
/>
|
||||
),
|
||||
google: (
|
||||
<SocialButton
|
||||
key="google"
|
||||
enabled={startupConfig.googleLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="google"
|
||||
Icon={GoogleIcon}
|
||||
label={localize('com_auth_google_login')}
|
||||
id="google"
|
||||
/>
|
||||
),
|
||||
openid: (
|
||||
<SocialButton
|
||||
key="openid"
|
||||
enabled={startupConfig.openidLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="openid"
|
||||
Icon={() =>
|
||||
startupConfig.openidImageUrl ? (
|
||||
<img src={startupConfig.openidImageUrl} alt="OpenID Logo" className="h-5 w-5" />
|
||||
) : (
|
||||
<OpenIDIcon />
|
||||
)
|
||||
}
|
||||
label={startupConfig.openidLabel}
|
||||
id="openid"
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
const privacyPolicy = startupConfig.interface?.privacyPolicy;
|
||||
const termsOfService = startupConfig.interface?.termsOfService;
|
||||
|
||||
const privacyPolicyRender = privacyPolicy?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={privacyPolicy.externalUrl}
|
||||
target={privacyPolicy.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_privacy_policy')}
|
||||
</a>
|
||||
);
|
||||
|
||||
const termsOfServiceRender = termsOfService?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={termsOfService.externalUrl}
|
||||
target={termsOfService.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_terms_of_service')}
|
||||
</a>
|
||||
);
|
||||
const { error, login } = useAuthContext();
|
||||
const { startupConfig } = useOutletContext<TLoginLayoutContext>();
|
||||
|
||||
return (
|
||||
<div className="relative flex min-h-screen flex-col bg-white dark:bg-gray-900">
|
||||
<div className="mt-12 h-24 w-full bg-cover">
|
||||
<img src="/assets/logo.svg" className="h-full w-full object-contain" alt="Logo" />
|
||||
</div>
|
||||
<div className="absolute bottom-0 left-0 md:m-4">
|
||||
<ThemeSelector />
|
||||
</div>
|
||||
<div className="flex flex-grow items-center justify-center">
|
||||
<div className="w-authPageWidth overflow-hidden bg-white px-6 py-4 dark:bg-gray-900 sm:max-w-md sm:rounded-lg">
|
||||
<h1
|
||||
className="mb-4 text-center text-3xl font-semibold text-black dark:text-white"
|
||||
style={{ userSelect: 'none' }}
|
||||
>
|
||||
{localize('com_auth_welcome_back')}
|
||||
</h1>
|
||||
{error && (
|
||||
<div
|
||||
className="rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-200"
|
||||
role="alert"
|
||||
>
|
||||
{localize(getLoginError(error))}
|
||||
</div>
|
||||
)}
|
||||
{startupConfig.emailLoginEnabled && <LoginForm onSubmit={login} />}
|
||||
{startupConfig.registrationEnabled && (
|
||||
<p className="my-4 text-center text-sm font-light text-gray-700 dark:text-white">
|
||||
{' '}
|
||||
{localize('com_auth_no_account')}{' '}
|
||||
<a href="/register" className="p-1 text-green-500">
|
||||
{localize('com_auth_sign_up')}
|
||||
</a>
|
||||
</p>
|
||||
)}
|
||||
{startupConfig.socialLoginEnabled && (
|
||||
<>
|
||||
{startupConfig.emailLoginEnabled && (
|
||||
<>
|
||||
<div className="relative mt-6 flex w-full items-center justify-center border border-t border-gray-300 uppercase dark:border-gray-600">
|
||||
<div className="absolute bg-white px-3 text-xs text-black dark:bg-gray-900 dark:text-white">
|
||||
Or
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-8" />
|
||||
</>
|
||||
)}
|
||||
<div className="mt-2">
|
||||
{socialLogins.map((provider) => providerComponents[provider] || null)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="align-end m-4 flex justify-center gap-2">
|
||||
{privacyPolicyRender}
|
||||
{privacyPolicyRender && termsOfServiceRender && (
|
||||
<div className="border-r-[1px] border-gray-300 dark:border-gray-600" />
|
||||
)}
|
||||
{termsOfServiceRender}
|
||||
</div>
|
||||
</div>
|
||||
<>
|
||||
{error && <ErrorMessage>{localize(getLoginError(error))}</ErrorMessage>}
|
||||
{startupConfig?.emailLoginEnabled && <LoginForm onSubmit={login} />}
|
||||
{startupConfig?.registrationEnabled && (
|
||||
<p className="my-4 text-center text-sm font-light text-gray-700 dark:text-white">
|
||||
{' '}
|
||||
{localize('com_auth_no_account')}{' '}
|
||||
<a href="/register" className="p-1 text-green-500">
|
||||
{localize('com_auth_sign_up')}
|
||||
</a>
|
||||
</p>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { useRegisterUserMutation, useGetStartupConfig } from 'librechat-data-provider/react-query';
|
||||
import type { TRegisterUser } from 'librechat-data-provider';
|
||||
import { GoogleIcon, FacebookIcon, OpenIDIcon, GithubIcon, DiscordIcon } from '~/components';
|
||||
import { ThemeSelector } from '~/components/ui';
|
||||
import SocialButton from './SocialButton';
|
||||
import { useNavigate, useOutletContext } from 'react-router-dom';
|
||||
import { useRegisterUserMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TRegisterUser, TError } from 'librechat-data-provider';
|
||||
import type { TLoginLayoutContext } from '~/common';
|
||||
import { ErrorMessage } from './ErrorMessage';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
const Registration: React.FC = () => {
|
||||
const navigate = useNavigate();
|
||||
const { data: startupConfig } = useGetStartupConfig();
|
||||
const localize = useLocalize();
|
||||
const { startupConfig, startupConfigError, isFetching } = useOutletContext<TLoginLayoutContext>();
|
||||
|
||||
const {
|
||||
register,
|
||||
@@ -31,10 +30,8 @@ const Registration: React.FC = () => {
|
||||
navigate('/c/new');
|
||||
} catch (error) {
|
||||
setError(true);
|
||||
//@ts-ignore - error is of type unknown
|
||||
if (error.response?.data?.message) {
|
||||
//@ts-ignore - error is of type unknown
|
||||
setErrorMessage(error.response?.data?.message);
|
||||
if ((error as TError).response?.data?.message) {
|
||||
setErrorMessage((error as TError).response?.data?.message ?? '');
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -45,12 +42,6 @@ const Registration: React.FC = () => {
|
||||
}
|
||||
}, [startupConfig, navigate]);
|
||||
|
||||
if (!startupConfig) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const socialLogins = startupConfig.socialLogins ?? [];
|
||||
|
||||
const renderInput = (id: string, label: string, type: string, validation: object) => (
|
||||
<div className="mb-2">
|
||||
<div className="relative">
|
||||
@@ -67,7 +58,7 @@ const Registration: React.FC = () => {
|
||||
className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-gray-300 bg-transparent px-3.5 pb-3.5 pt-4 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0 dark:border-gray-600 dark:text-white dark:focus:border-green-500"
|
||||
placeholder=" "
|
||||
data-testid={id}
|
||||
></input>
|
||||
/>
|
||||
<label
|
||||
htmlFor={id}
|
||||
className="absolute start-1 top-2 z-10 origin-[0] -translate-y-4 scale-75 transform bg-white px-3 text-sm text-gray-500 duration-100 peer-placeholder-shown:top-1/2 peer-placeholder-shown:-translate-y-1/2 peer-placeholder-shown:scale-100 peer-focus:top-2 peer-focus:-translate-y-4 peer-focus:scale-75 peer-focus:px-3 peer-focus:text-green-600 dark:bg-gray-900 dark:text-gray-400 dark:peer-focus:text-green-500 rtl:peer-focus:left-auto rtl:peer-focus:translate-x-1/4"
|
||||
@@ -83,120 +74,16 @@ const Registration: React.FC = () => {
|
||||
</div>
|
||||
);
|
||||
|
||||
const providerComponents = {
|
||||
discord: (
|
||||
<SocialButton
|
||||
key="discord"
|
||||
enabled={startupConfig.discordLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="discord"
|
||||
Icon={DiscordIcon}
|
||||
label={localize('com_auth_discord_login')}
|
||||
id="discord"
|
||||
/>
|
||||
),
|
||||
facebook: (
|
||||
<SocialButton
|
||||
key="facebook"
|
||||
enabled={startupConfig.facebookLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="facebook"
|
||||
Icon={FacebookIcon}
|
||||
label={localize('com_auth_facebook_login')}
|
||||
id="facebook"
|
||||
/>
|
||||
),
|
||||
github: (
|
||||
<SocialButton
|
||||
key="github"
|
||||
enabled={startupConfig.githubLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="github"
|
||||
Icon={GithubIcon}
|
||||
label={localize('com_auth_github_login')}
|
||||
id="github"
|
||||
/>
|
||||
),
|
||||
google: (
|
||||
<SocialButton
|
||||
key="google"
|
||||
enabled={startupConfig.googleLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="google"
|
||||
Icon={GoogleIcon}
|
||||
label={localize('com_auth_google_login')}
|
||||
id="google"
|
||||
/>
|
||||
),
|
||||
openid: (
|
||||
<SocialButton
|
||||
key="openid"
|
||||
enabled={startupConfig.openidLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="openid"
|
||||
Icon={() =>
|
||||
startupConfig.openidImageUrl ? (
|
||||
<img src={startupConfig.openidImageUrl} alt="OpenID Logo" className="h-5 w-5" />
|
||||
) : (
|
||||
<OpenIDIcon />
|
||||
)
|
||||
}
|
||||
label={startupConfig.openidLabel}
|
||||
id="openid"
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
const privacyPolicy = startupConfig.interface?.privacyPolicy;
|
||||
const termsOfService = startupConfig.interface?.termsOfService;
|
||||
|
||||
const privacyPolicyRender = privacyPolicy?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={privacyPolicy.externalUrl}
|
||||
target={privacyPolicy.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_privacy_policy')}
|
||||
</a>
|
||||
);
|
||||
|
||||
const termsOfServiceRender = termsOfService?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={termsOfService.externalUrl}
|
||||
target={termsOfService.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_terms_of_service')}
|
||||
</a>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="relative flex min-h-screen flex-col bg-white dark:bg-gray-900">
|
||||
<div className="mt-12 h-24 w-full bg-cover">
|
||||
<img src="/assets/logo.svg" className="h-full w-full object-contain" alt="Logo" />
|
||||
</div>
|
||||
<div className="absolute bottom-0 left-0 md:m-4">
|
||||
<ThemeSelector />
|
||||
</div>
|
||||
<div className="flex flex-grow items-center justify-center">
|
||||
<div className="w-authPageWidth overflow-hidden bg-white px-6 py-4 dark:bg-gray-900 sm:max-w-md sm:rounded-lg">
|
||||
<h1
|
||||
className="mb-4 text-center text-3xl font-semibold text-black dark:text-white"
|
||||
style={{ userSelect: 'none' }}
|
||||
>
|
||||
{localize('com_auth_create_account')}
|
||||
</h1>
|
||||
{error && (
|
||||
<div
|
||||
className="rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-200"
|
||||
role="alert"
|
||||
data-testid="registration-error"
|
||||
>
|
||||
{localize('com_auth_error_create')} {errorMessage}
|
||||
</div>
|
||||
)}
|
||||
<>
|
||||
{error && (
|
||||
<ErrorMessage>
|
||||
{localize('com_auth_error_create')} {errorMessage}
|
||||
</ErrorMessage>
|
||||
)}
|
||||
|
||||
{!startupConfigError && !isFetching && (
|
||||
<>
|
||||
<form
|
||||
className="mt-6"
|
||||
aria-label="Registration form"
|
||||
@@ -251,7 +138,8 @@ const Registration: React.FC = () => {
|
||||
},
|
||||
})}
|
||||
{renderInput('confirm_password', 'com_auth_password_confirm', 'password', {
|
||||
validate: (value) => value === password || localize('com_auth_password_not_match'),
|
||||
validate: (value: string) =>
|
||||
value === password || localize('com_auth_password_not_match'),
|
||||
})}
|
||||
<div className="mt-6">
|
||||
<button
|
||||
@@ -264,39 +152,16 @@ const Registration: React.FC = () => {
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<p className="my-4 text-center text-sm font-light text-gray-700 dark:text-white">
|
||||
{localize('com_auth_already_have_account')}{' '}
|
||||
<a href="/login" aria-label="Login" className="p-1 text-green-500">
|
||||
{localize('com_auth_login')}
|
||||
</a>
|
||||
</p>
|
||||
{startupConfig.socialLoginEnabled && (
|
||||
<>
|
||||
{startupConfig.emailLoginEnabled && (
|
||||
<>
|
||||
<div className="relative mt-6 flex w-full items-center justify-center border border-t border-gray-300 uppercase dark:border-gray-600">
|
||||
<div className="absolute bg-white px-3 text-xs text-black dark:bg-gray-900 dark:text-white">
|
||||
Or
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-8" />
|
||||
</>
|
||||
)}
|
||||
<div className="mt-2">
|
||||
{socialLogins.map((provider) => providerComponents[provider] || null)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="align-end m-4 flex justify-center gap-2">
|
||||
{privacyPolicyRender}
|
||||
{privacyPolicyRender && termsOfServiceRender && (
|
||||
<div className="border-r-[1px] border-gray-300 dark:border-gray-600" />
|
||||
)}
|
||||
{termsOfServiceRender}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useState, useEffect } from 'react';
|
||||
import {
|
||||
useGetStartupConfig,
|
||||
useRequestPasswordResetMutation,
|
||||
} from 'librechat-data-provider/react-query';
|
||||
import { useOutletContext } from 'react-router-dom';
|
||||
import { useRequestPasswordResetMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TRequestPasswordReset, TRequestPasswordResetResponse } from 'librechat-data-provider';
|
||||
import { ThemeSelector } from '~/components/ui';
|
||||
import type { TLoginLayoutContext } from '~/common';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
function RequestPasswordReset() {
|
||||
@@ -15,187 +13,135 @@ function RequestPasswordReset() {
|
||||
handleSubmit,
|
||||
formState: { errors },
|
||||
} = useForm<TRequestPasswordReset>();
|
||||
const requestPasswordReset = useRequestPasswordResetMutation();
|
||||
const config = useGetStartupConfig();
|
||||
const [requestError, setRequestError] = useState<boolean>(false);
|
||||
const [resetLink, setResetLink] = useState<string | undefined>(undefined);
|
||||
const [headerText, setHeaderText] = useState<string>('');
|
||||
const [bodyText, setBodyText] = useState<React.ReactNode | undefined>(undefined);
|
||||
const { startupConfig, setError, setHeaderText } = useOutletContext<TLoginLayoutContext>();
|
||||
|
||||
const requestPasswordReset = useRequestPasswordResetMutation();
|
||||
|
||||
const onSubmit = (data: TRequestPasswordReset) => {
|
||||
requestPasswordReset.mutate(data, {
|
||||
onSuccess: (data: TRequestPasswordResetResponse) => {
|
||||
console.log('emailEnabled: ', config.data?.emailEnabled);
|
||||
if (!config.data?.emailEnabled) {
|
||||
if (!startupConfig?.emailEnabled) {
|
||||
setResetLink(data.link);
|
||||
}
|
||||
},
|
||||
onError: () => {
|
||||
setRequestError(true);
|
||||
setError('com_auth_error_reset_password');
|
||||
setTimeout(() => {
|
||||
setRequestError(false);
|
||||
setError(null);
|
||||
}, 5000);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (requestPasswordReset.isSuccess) {
|
||||
if (config.data?.emailEnabled) {
|
||||
setHeaderText(localize('com_auth_reset_password_link_sent'));
|
||||
setBodyText(localize('com_auth_reset_password_email_sent'));
|
||||
} else {
|
||||
setHeaderText(localize('com_auth_reset_password'));
|
||||
setBodyText(
|
||||
<span>
|
||||
{localize('com_auth_click')}{' '}
|
||||
<a className="text-green-500 hover:underline" href={resetLink}>
|
||||
{localize('com_auth_here')}
|
||||
</a>{' '}
|
||||
{localize('com_auth_to_reset_your_password')}
|
||||
</span>,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
setHeaderText(localize('com_auth_reset_password'));
|
||||
setBodyText(undefined);
|
||||
}
|
||||
}, [requestPasswordReset.isSuccess, config.data?.emailEnabled, resetLink, localize]);
|
||||
|
||||
const renderFormContent = () => {
|
||||
if (bodyText) {
|
||||
return (
|
||||
<div
|
||||
className="relative mt-4 rounded border border-green-400 bg-green-100 px-4 py-3 text-green-700 dark:bg-green-900 dark:text-white"
|
||||
role="alert"
|
||||
>
|
||||
{bodyText}
|
||||
</div>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<form
|
||||
className="mt-6"
|
||||
aria-label="Password reset form"
|
||||
method="POST"
|
||||
onSubmit={handleSubmit(onSubmit)}
|
||||
>
|
||||
<div className="mb-2">
|
||||
<div className="relative">
|
||||
<input
|
||||
type="email"
|
||||
id="email"
|
||||
autoComplete="off"
|
||||
aria-label={localize('com_auth_email')}
|
||||
{...register('email', {
|
||||
required: localize('com_auth_email_required'),
|
||||
minLength: {
|
||||
value: 3,
|
||||
message: localize('com_auth_email_min_length'),
|
||||
},
|
||||
maxLength: {
|
||||
value: 120,
|
||||
message: localize('com_auth_email_max_length'),
|
||||
},
|
||||
pattern: {
|
||||
value: /\S+@\S+\.\S+/,
|
||||
message: localize('com_auth_email_pattern'),
|
||||
},
|
||||
})}
|
||||
aria-invalid={!!errors.email}
|
||||
className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-gray-300 bg-transparent px-3.5 pb-3.5 pt-4 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0 dark:border-gray-600 dark:text-white dark:focus:border-green-500"
|
||||
placeholder=" "
|
||||
></input>
|
||||
<label
|
||||
htmlFor="email"
|
||||
className="absolute start-1 top-2 z-10 origin-[0] -translate-y-4 scale-75 transform bg-white px-3 text-sm text-gray-500 duration-100 peer-placeholder-shown:top-1/2 peer-placeholder-shown:-translate-y-1/2 peer-placeholder-shown:scale-100 peer-focus:top-2 peer-focus:-translate-y-4 peer-focus:scale-75 peer-focus:px-3 peer-focus:text-green-600 dark:bg-gray-900 dark:text-gray-400 dark:peer-focus:text-green-500 rtl:peer-focus:left-auto rtl:peer-focus:translate-x-1/4"
|
||||
>
|
||||
{localize('com_auth_email_address')}
|
||||
</label>
|
||||
</div>
|
||||
{errors.email && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{/* @ts-ignore not sure why */}
|
||||
{errors.email.message}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-6">
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!!errors.email}
|
||||
className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-colors duration-200 hover:bg-green-550 focus:bg-green-550 focus:outline-none disabled:cursor-not-allowed disabled:hover:bg-green-500"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
<div className="mt-4 flex justify-center">
|
||||
<a href="/login" className="text-sm text-green-500">
|
||||
{localize('com_auth_back_to_login')}
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (!requestPasswordReset.isSuccess) {
|
||||
setHeaderText('com_auth_reset_password');
|
||||
setBodyText(undefined);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
const privacyPolicy = config.data?.interface?.privacyPolicy;
|
||||
const termsOfService = config.data?.interface?.termsOfService;
|
||||
if (startupConfig?.emailEnabled) {
|
||||
setHeaderText('com_auth_reset_password_link_sent');
|
||||
setBodyText(localize('com_auth_reset_password_email_sent'));
|
||||
return;
|
||||
}
|
||||
|
||||
const privacyPolicyRender = privacyPolicy?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={privacyPolicy.externalUrl}
|
||||
target={privacyPolicy.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_privacy_policy')}
|
||||
</a>
|
||||
);
|
||||
setHeaderText('com_auth_reset_password');
|
||||
setBodyText(
|
||||
<span>
|
||||
{localize('com_auth_click')}{' '}
|
||||
<a className="text-green-500 hover:underline" href={resetLink}>
|
||||
{localize('com_auth_here')}
|
||||
</a>{' '}
|
||||
{localize('com_auth_to_reset_your_password')}
|
||||
</span>,
|
||||
);
|
||||
}, [
|
||||
requestPasswordReset.isSuccess,
|
||||
startupConfig?.emailEnabled,
|
||||
resetLink,
|
||||
localize,
|
||||
setHeaderText,
|
||||
bodyText,
|
||||
]);
|
||||
|
||||
const termsOfServiceRender = termsOfService?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={termsOfService.externalUrl}
|
||||
target={termsOfService.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_terms_of_service')}
|
||||
</a>
|
||||
);
|
||||
if (bodyText) {
|
||||
return (
|
||||
<div
|
||||
className="relative mt-4 rounded border border-green-400 bg-green-100 px-4 py-3 text-green-700 dark:bg-green-900 dark:text-white"
|
||||
role="alert"
|
||||
>
|
||||
{bodyText}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="relative flex min-h-screen flex-col bg-white dark:bg-gray-900">
|
||||
<div className="mt-12 h-24 w-full bg-cover">
|
||||
<img src="/assets/logo.svg" className="h-full w-full object-contain" alt="Logo" />
|
||||
<form
|
||||
className="mt-6"
|
||||
aria-label="Password reset form"
|
||||
method="POST"
|
||||
onSubmit={handleSubmit(onSubmit)}
|
||||
>
|
||||
<div className="mb-2">
|
||||
<div className="relative">
|
||||
<input
|
||||
type="email"
|
||||
id="email"
|
||||
autoComplete="off"
|
||||
aria-label={localize('com_auth_email')}
|
||||
{...register('email', {
|
||||
required: localize('com_auth_email_required'),
|
||||
minLength: {
|
||||
value: 3,
|
||||
message: localize('com_auth_email_min_length'),
|
||||
},
|
||||
maxLength: {
|
||||
value: 120,
|
||||
message: localize('com_auth_email_max_length'),
|
||||
},
|
||||
pattern: {
|
||||
value: /\S+@\S+\.\S+/,
|
||||
message: localize('com_auth_email_pattern'),
|
||||
},
|
||||
})}
|
||||
aria-invalid={!!errors.email}
|
||||
className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-gray-300 bg-transparent px-3.5 pb-3.5 pt-4 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0 dark:border-gray-600 dark:text-white dark:focus:border-green-500"
|
||||
placeholder=" "
|
||||
/>
|
||||
<label
|
||||
htmlFor="email"
|
||||
className="absolute start-1 top-2 z-10 origin-[0] -translate-y-4 scale-75 transform bg-white px-3 text-sm text-gray-500 duration-100 peer-placeholder-shown:top-1/2 peer-placeholder-shown:-translate-y-1/2 peer-placeholder-shown:scale-100 peer-focus:top-2 peer-focus:-translate-y-4 peer-focus:scale-75 peer-focus:px-3 peer-focus:text-green-600 dark:bg-gray-900 dark:text-gray-400 dark:peer-focus:text-green-500 rtl:peer-focus:left-auto rtl:peer-focus:translate-x-1/4"
|
||||
>
|
||||
{localize('com_auth_email_address')}
|
||||
</label>
|
||||
</div>
|
||||
{errors.email && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{errors.email.message}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="absolute bottom-0 left-0 md:m-4">
|
||||
<ThemeSelector />
|
||||
</div>
|
||||
<div className="flex flex-grow items-center justify-center">
|
||||
<div className="w-authPageWidth overflow-hidden bg-white px-6 py-4 dark:bg-gray-900 sm:max-w-md sm:rounded-lg">
|
||||
<h1 className="mb-4 text-center text-3xl font-semibold text-black dark:text-white">
|
||||
{headerText}
|
||||
</h1>
|
||||
{requestError && (
|
||||
<div
|
||||
className="rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-200"
|
||||
role="alert"
|
||||
>
|
||||
{localize('com_auth_error_reset_password')}
|
||||
</div>
|
||||
)}
|
||||
{renderFormContent()}
|
||||
<div className="mt-6">
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!!errors.email}
|
||||
className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-colors duration-200 hover:bg-green-550 focus:bg-green-550 focus:outline-none disabled:cursor-not-allowed disabled:hover:bg-green-500"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
<div className="mt-4 flex justify-center">
|
||||
<a href="/login" className="text-sm text-green-500">
|
||||
{localize('com_auth_back_to_login')}
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
<div className="align-end m-4 flex justify-center gap-2">
|
||||
{privacyPolicyRender}
|
||||
{privacyPolicyRender && termsOfServiceRender && (
|
||||
<div className="border-r-[1px] border-gray-300 dark:border-gray-600" />
|
||||
)}
|
||||
{termsOfServiceRender}
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useState } from 'react';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useOutletContext } from 'react-router-dom';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { useGetStartupConfig, useResetPasswordMutation } from 'librechat-data-provider/react-query';
|
||||
import { useResetPasswordMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TResetPassword } from 'librechat-data-provider';
|
||||
import { ThemeSelector } from '~/components/ui';
|
||||
import type { TLoginLayoutContext } from '~/common';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
function ResetPassword() {
|
||||
@@ -14,218 +14,146 @@ function ResetPassword() {
|
||||
watch,
|
||||
formState: { errors },
|
||||
} = useForm<TResetPassword>();
|
||||
const resetPassword = useResetPasswordMutation();
|
||||
const config = useGetStartupConfig();
|
||||
const [resetError, setResetError] = useState<boolean>(false);
|
||||
const [params] = useSearchParams();
|
||||
const navigate = useNavigate();
|
||||
const [params] = useSearchParams();
|
||||
const password = watch('password');
|
||||
const resetPassword = useResetPasswordMutation();
|
||||
const { setError, setHeaderText } = useOutletContext<TLoginLayoutContext>();
|
||||
|
||||
const onSubmit = (data: TResetPassword) => {
|
||||
resetPassword.mutate(data, {
|
||||
onError: () => {
|
||||
setResetError(true);
|
||||
setError('com_auth_error_invalid_reset_token');
|
||||
},
|
||||
onSuccess: () => {
|
||||
setHeaderText('com_auth_reset_password_success');
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const privacyPolicy = config.data?.interface?.privacyPolicy;
|
||||
const termsOfService = config.data?.interface?.termsOfService;
|
||||
|
||||
const privacyPolicyRender = privacyPolicy?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={privacyPolicy.externalUrl}
|
||||
target={privacyPolicy.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_privacy_policy')}
|
||||
</a>
|
||||
);
|
||||
|
||||
const termsOfServiceRender = termsOfService?.externalUrl && (
|
||||
<a
|
||||
className="text-sm text-green-500"
|
||||
href={termsOfService.externalUrl}
|
||||
target={termsOfService.openNewTab ? '_blank' : undefined}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{localize('com_ui_terms_of_service')}
|
||||
</a>
|
||||
);
|
||||
|
||||
if (resetPassword.isSuccess) {
|
||||
return (
|
||||
<div className="flex min-h-screen flex-col items-center justify-center bg-white pt-6 dark:bg-gray-900 sm:pt-0">
|
||||
<div className="absolute bottom-0 left-0 m-4">
|
||||
<ThemeSelector />
|
||||
<>
|
||||
<div
|
||||
className="relative mb-8 mt-4 rounded border border-green-400 bg-green-100 px-4 py-3 text-center text-green-700 dark:bg-gray-900 dark:text-white"
|
||||
role="alert"
|
||||
>
|
||||
{localize('com_auth_login_with_new_password')}
|
||||
</div>
|
||||
<div className="mt-6 w-authPageWidth overflow-hidden bg-white px-6 py-4 dark:bg-gray-900 sm:max-w-md sm:rounded-lg">
|
||||
<h1 className="mb-4 text-center text-3xl font-semibold text-black dark:text-white">
|
||||
{localize('com_auth_reset_password_success')}
|
||||
</h1>
|
||||
<div
|
||||
className="relative mb-8 mt-4 rounded border border-green-400 bg-green-100 px-4 py-3 text-center text-green-700 dark:bg-gray-900 dark:text-white"
|
||||
role="alert"
|
||||
>
|
||||
{localize('com_auth_login_with_new_password')}
|
||||
</div>
|
||||
<button
|
||||
onClick={() => navigate('/login')}
|
||||
aria-label={localize('com_auth_sign_in')}
|
||||
className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-colors duration-200 hover:bg-green-600 focus:bg-green-600 focus:outline-none"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<div className="relative flex min-h-screen flex-col bg-white dark:bg-gray-900">
|
||||
<div className="mt-12 h-24 w-full bg-cover">
|
||||
<img src="/assets/logo.svg" className="h-full w-full object-contain" alt="Logo" />
|
||||
</div>
|
||||
<div className="absolute bottom-0 left-0 md:m-4">
|
||||
<ThemeSelector />
|
||||
</div>
|
||||
<div className="flex flex-grow items-center justify-center">
|
||||
<div className="w-authPageWidth overflow-hidden bg-white px-6 py-4 dark:bg-gray-900 sm:max-w-md sm:rounded-lg">
|
||||
<h1 className="mb-4 text-center text-3xl font-semibold text-black dark:text-white">
|
||||
{localize('com_auth_reset_password')}
|
||||
</h1>
|
||||
{resetError && (
|
||||
<div
|
||||
className="rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-200"
|
||||
role="alert"
|
||||
>
|
||||
{localize('com_auth_error_invalid_reset_token')}{' '}
|
||||
<a className="font-semibold text-green-600 hover:underline" href="/forgot-password">
|
||||
{localize('com_auth_click_here')}
|
||||
</a>{' '}
|
||||
{localize('com_auth_to_try_again')}
|
||||
</div>
|
||||
)}
|
||||
<form
|
||||
className="mt-6"
|
||||
aria-label="Password reset form"
|
||||
method="POST"
|
||||
onSubmit={handleSubmit(onSubmit)}
|
||||
>
|
||||
<div className="mb-2">
|
||||
<div className="relative">
|
||||
<input
|
||||
type="hidden"
|
||||
id="token"
|
||||
// @ts-ignore - Type 'string | null' is not assignable to type 'string | number | readonly string[] | undefined'
|
||||
value={params.get('token')}
|
||||
{...register('token', { required: 'Unable to process: No valid reset token' })}
|
||||
/>
|
||||
<input
|
||||
type="hidden"
|
||||
id="userId"
|
||||
// @ts-ignore - Type 'string | null' is not assignable to type 'string | number | readonly string[] | undefined'
|
||||
value={params.get('userId')}
|
||||
{...register('userId', { required: 'Unable to process: No valid user id' })}
|
||||
/>
|
||||
<input
|
||||
type="password"
|
||||
id="password"
|
||||
autoComplete="current-password"
|
||||
aria-label={localize('com_auth_password')}
|
||||
{...register('password', {
|
||||
required: localize('com_auth_password_required'),
|
||||
minLength: {
|
||||
value: 8,
|
||||
message: localize('com_auth_password_min_length'),
|
||||
},
|
||||
maxLength: {
|
||||
value: 128,
|
||||
message: localize('com_auth_password_max_length'),
|
||||
},
|
||||
})}
|
||||
aria-invalid={!!errors.password}
|
||||
className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-gray-300 bg-transparent px-3.5 pb-3.5 pt-4 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0 dark:border-gray-600 dark:text-white dark:focus:border-green-500"
|
||||
placeholder=" "
|
||||
></input>
|
||||
<label
|
||||
htmlFor="password"
|
||||
className="absolute start-1 top-2 z-10 origin-[0] -translate-y-4 scale-75 transform bg-white px-3 text-sm text-gray-500 duration-100 peer-placeholder-shown:top-1/2 peer-placeholder-shown:-translate-y-1/2 peer-placeholder-shown:scale-100 peer-focus:top-2 peer-focus:-translate-y-4 peer-focus:scale-75 peer-focus:px-3 peer-focus:text-green-600 dark:bg-gray-900 dark:text-gray-400 dark:peer-focus:text-green-500 rtl:peer-focus:left-auto rtl:peer-focus:translate-x-1/4"
|
||||
>
|
||||
{localize('com_auth_password')}
|
||||
</label>
|
||||
</div>
|
||||
|
||||
{errors.password && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{/* @ts-ignore not sure why */}
|
||||
{errors.password.message}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="mb-2">
|
||||
<div className="relative">
|
||||
<input
|
||||
type="password"
|
||||
id="confirm_password"
|
||||
aria-label={localize('com_auth_password_confirm')}
|
||||
{...register('confirm_password', {
|
||||
validate: (value) =>
|
||||
value === password || localize('com_auth_password_not_match'),
|
||||
})}
|
||||
aria-invalid={!!errors.confirm_password}
|
||||
className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-gray-300 bg-transparent px-3.5 pb-3.5 pt-4 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0 dark:border-gray-600 dark:text-white dark:focus:border-green-500"
|
||||
placeholder=" "
|
||||
></input>
|
||||
<label
|
||||
htmlFor="confirm_password"
|
||||
className="absolute start-1 top-2 z-10 origin-[0] -translate-y-4 scale-75 transform bg-white px-3 text-sm text-gray-500 duration-100 peer-placeholder-shown:top-1/2 peer-placeholder-shown:-translate-y-1/2 peer-placeholder-shown:scale-100 peer-focus:top-2 peer-focus:-translate-y-4 peer-focus:scale-75 peer-focus:px-3 peer-focus:text-green-600 dark:bg-gray-900 dark:text-gray-400 dark:peer-focus:text-green-500 rtl:peer-focus:left-auto rtl:peer-focus:translate-x-1/4"
|
||||
>
|
||||
{localize('com_auth_password_confirm')}
|
||||
</label>
|
||||
</div>
|
||||
{errors.confirm_password && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{/* @ts-ignore not sure why */}
|
||||
{errors.confirm_password.message}
|
||||
</span>
|
||||
)}
|
||||
{errors.token && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{/* @ts-ignore not sure why */}
|
||||
{errors.token.message}
|
||||
</span>
|
||||
)}
|
||||
{errors.userId && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{/* @ts-ignore not sure why */}
|
||||
{errors.userId.message}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-6">
|
||||
<button
|
||||
disabled={!!errors.password || !!errors.confirm_password}
|
||||
type="submit"
|
||||
aria-label={localize('com_auth_submit_registration')}
|
||||
className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-all duration-300 hover:bg-green-550 focus:bg-green-550 focus:outline-none"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
<div className="align-end m-4 flex justify-center gap-2">
|
||||
{privacyPolicyRender}
|
||||
{privacyPolicyRender && termsOfServiceRender && (
|
||||
<div className="border-r-[1px] border-gray-300 dark:border-gray-600" />
|
||||
)}
|
||||
{termsOfServiceRender}
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => navigate('/login')}
|
||||
aria-label={localize('com_auth_sign_in')}
|
||||
className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-colors duration-200 hover:bg-green-600 focus:bg-green-600 focus:outline-none"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<form
|
||||
className="mt-6"
|
||||
aria-label="Password reset form"
|
||||
method="POST"
|
||||
onSubmit={handleSubmit(onSubmit)}
|
||||
>
|
||||
<div className="mb-2">
|
||||
<div className="relative">
|
||||
<input
|
||||
type="hidden"
|
||||
id="token"
|
||||
value={params.get('token') ?? ''}
|
||||
{...register('token', { required: 'Unable to process: No valid reset token' })}
|
||||
/>
|
||||
<input
|
||||
type="hidden"
|
||||
id="userId"
|
||||
value={params.get('userId') ?? ''}
|
||||
{...register('userId', { required: 'Unable to process: No valid user id' })}
|
||||
/>
|
||||
<input
|
||||
type="password"
|
||||
id="password"
|
||||
autoComplete="current-password"
|
||||
aria-label={localize('com_auth_password')}
|
||||
{...register('password', {
|
||||
required: localize('com_auth_password_required'),
|
||||
minLength: {
|
||||
value: 8,
|
||||
message: localize('com_auth_password_min_length'),
|
||||
},
|
||||
maxLength: {
|
||||
value: 128,
|
||||
message: localize('com_auth_password_max_length'),
|
||||
},
|
||||
})}
|
||||
aria-invalid={!!errors.password}
|
||||
className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-gray-300 bg-transparent px-3.5 pb-3.5 pt-4 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0 dark:border-gray-600 dark:text-white dark:focus:border-green-500"
|
||||
placeholder=" "
|
||||
/>
|
||||
<label
|
||||
htmlFor="password"
|
||||
className="absolute start-1 top-2 z-10 origin-[0] -translate-y-4 scale-75 transform bg-white px-3 text-sm text-gray-500 duration-100 peer-placeholder-shown:top-1/2 peer-placeholder-shown:-translate-y-1/2 peer-placeholder-shown:scale-100 peer-focus:top-2 peer-focus:-translate-y-4 peer-focus:scale-75 peer-focus:px-3 peer-focus:text-green-600 dark:bg-gray-900 dark:text-gray-400 dark:peer-focus:text-green-500 rtl:peer-focus:left-auto rtl:peer-focus:translate-x-1/4"
|
||||
>
|
||||
{localize('com_auth_password')}
|
||||
</label>
|
||||
</div>
|
||||
|
||||
{errors.password && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{errors.password.message}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="mb-2">
|
||||
<div className="relative">
|
||||
<input
|
||||
type="password"
|
||||
id="confirm_password"
|
||||
aria-label={localize('com_auth_password_confirm')}
|
||||
{...register('confirm_password', {
|
||||
validate: (value) => value === password || localize('com_auth_password_not_match'),
|
||||
})}
|
||||
aria-invalid={!!errors.confirm_password}
|
||||
className="webkit-dark-styles peer block w-full appearance-none rounded-md border border-gray-300 bg-transparent px-3.5 pb-3.5 pt-4 text-sm text-gray-900 focus:border-green-500 focus:outline-none focus:ring-0 dark:border-gray-600 dark:text-white dark:focus:border-green-500"
|
||||
placeholder=" "
|
||||
/>
|
||||
<label
|
||||
htmlFor="confirm_password"
|
||||
className="absolute start-1 top-2 z-10 origin-[0] -translate-y-4 scale-75 transform bg-white px-3 text-sm text-gray-500 duration-100 peer-placeholder-shown:top-1/2 peer-placeholder-shown:-translate-y-1/2 peer-placeholder-shown:scale-100 peer-focus:top-2 peer-focus:-translate-y-4 peer-focus:scale-75 peer-focus:px-3 peer-focus:text-green-600 dark:bg-gray-900 dark:text-gray-400 dark:peer-focus:text-green-500 rtl:peer-focus:left-auto rtl:peer-focus:translate-x-1/4"
|
||||
>
|
||||
{localize('com_auth_password_confirm')}
|
||||
</label>
|
||||
</div>
|
||||
{errors.confirm_password && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{errors.confirm_password.message}
|
||||
</span>
|
||||
)}
|
||||
{errors.token && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{errors.token.message}
|
||||
</span>
|
||||
)}
|
||||
{errors.userId && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
{errors.userId.message}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-6">
|
||||
<button
|
||||
disabled={!!errors.password || !!errors.confirm_password}
|
||||
type="submit"
|
||||
aria-label={localize('com_auth_submit_registration')}
|
||||
className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-all duration-300 hover:bg-green-550 focus:bg-green-550 focus:outline-none"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
|
||||
export default ResetPassword;
|
||||
|
||||
105
client/src/components/Auth/SocialLoginRender.tsx
Normal file
105
client/src/components/Auth/SocialLoginRender.tsx
Normal file
@@ -0,0 +1,105 @@
|
||||
import { GoogleIcon, FacebookIcon, OpenIDIcon, GithubIcon, DiscordIcon } from '~/components';
|
||||
|
||||
import SocialButton from './SocialButton';
|
||||
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
import { TStartupConfig } from 'librechat-data-provider';
|
||||
|
||||
function SocialLoginRender({
|
||||
startupConfig,
|
||||
}: {
|
||||
startupConfig: TStartupConfig | null | undefined;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
|
||||
if (!startupConfig) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const providerComponents = {
|
||||
discord: startupConfig?.discordLoginEnabled && (
|
||||
<SocialButton
|
||||
key="discord"
|
||||
enabled={startupConfig.discordLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="discord"
|
||||
Icon={DiscordIcon}
|
||||
label={localize('com_auth_discord_login')}
|
||||
id="discord"
|
||||
/>
|
||||
),
|
||||
facebook: startupConfig?.facebookLoginEnabled && (
|
||||
<SocialButton
|
||||
key="facebook"
|
||||
enabled={startupConfig.facebookLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="facebook"
|
||||
Icon={FacebookIcon}
|
||||
label={localize('com_auth_facebook_login')}
|
||||
id="facebook"
|
||||
/>
|
||||
),
|
||||
github: startupConfig?.githubLoginEnabled && (
|
||||
<SocialButton
|
||||
key="github"
|
||||
enabled={startupConfig.githubLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="github"
|
||||
Icon={GithubIcon}
|
||||
label={localize('com_auth_github_login')}
|
||||
id="github"
|
||||
/>
|
||||
),
|
||||
google: startupConfig?.googleLoginEnabled && (
|
||||
<SocialButton
|
||||
key="google"
|
||||
enabled={startupConfig.googleLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="google"
|
||||
Icon={GoogleIcon}
|
||||
label={localize('com_auth_google_login')}
|
||||
id="google"
|
||||
/>
|
||||
),
|
||||
openid: startupConfig?.openidLoginEnabled && (
|
||||
<SocialButton
|
||||
key="openid"
|
||||
enabled={startupConfig.openidLoginEnabled}
|
||||
serverDomain={startupConfig.serverDomain}
|
||||
oauthPath="openid"
|
||||
Icon={() =>
|
||||
startupConfig.openidImageUrl ? (
|
||||
<img src={startupConfig.openidImageUrl} alt="OpenID Logo" className="h-5 w-5" />
|
||||
) : (
|
||||
<OpenIDIcon />
|
||||
)
|
||||
}
|
||||
label={startupConfig.openidLabel}
|
||||
id="openid"
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
return (
|
||||
startupConfig.socialLoginEnabled && (
|
||||
<>
|
||||
{startupConfig.emailLoginEnabled && (
|
||||
<>
|
||||
<div className="relative mt-6 flex w-full items-center justify-center border border-t border-gray-300 uppercase dark:border-gray-600">
|
||||
<div className="absolute bg-white px-3 text-xs text-black dark:bg-gray-900 dark:text-white">
|
||||
Or
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-8" />
|
||||
</>
|
||||
)}
|
||||
<div className="mt-2">
|
||||
{startupConfig.socialLogins?.map((provider) => providerComponents[provider] || null)}
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
export default SocialLoginRender;
|
||||
@@ -1,10 +1,34 @@
|
||||
import { render, waitFor } from 'test/layout-test-utils';
|
||||
import reactRouter from 'react-router-dom';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
import Login from '../Login';
|
||||
import { render, waitFor } from 'test/layout-test-utils';
|
||||
import * as mockDataProvider from 'librechat-data-provider/react-query';
|
||||
import type { TStartupConfig } from 'librechat-data-provider';
|
||||
import AuthLayout from '~/components/Auth/AuthLayout';
|
||||
import Login from '~/components/Auth/Login';
|
||||
|
||||
jest.mock('librechat-data-provider/react-query');
|
||||
|
||||
const mockStartupConfig = {
|
||||
isFetching: false,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
data: {
|
||||
socialLogins: ['google', 'facebook', 'openid', 'github', 'discord'],
|
||||
discordLoginEnabled: true,
|
||||
facebookLoginEnabled: true,
|
||||
githubLoginEnabled: true,
|
||||
googleLoginEnabled: true,
|
||||
openidLoginEnabled: true,
|
||||
openidLabel: 'Test OpenID',
|
||||
openidImageUrl: 'http://test-server.com',
|
||||
ldapLoginEnabled: false,
|
||||
registrationEnabled: true,
|
||||
emailLoginEnabled: true,
|
||||
socialLoginEnabled: true,
|
||||
serverDomain: 'mock-server',
|
||||
},
|
||||
};
|
||||
|
||||
const setup = ({
|
||||
useGetUserQueryReturnValue = {
|
||||
isLoading: false,
|
||||
@@ -27,24 +51,7 @@ const setup = ({
|
||||
user: {},
|
||||
},
|
||||
},
|
||||
useGetStartupCongfigReturnValue = {
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
data: {
|
||||
socialLogins: ['google', 'facebook', 'openid', 'github', 'discord'],
|
||||
discordLoginEnabled: true,
|
||||
facebookLoginEnabled: true,
|
||||
githubLoginEnabled: true,
|
||||
googleLoginEnabled: true,
|
||||
openidLoginEnabled: true,
|
||||
openidLabel: 'Test OpenID',
|
||||
openidImageUrl: 'http://test-server.com',
|
||||
registrationEnabled: true,
|
||||
emailLoginEnabled: true,
|
||||
socialLoginEnabled: true,
|
||||
serverDomain: 'mock-server',
|
||||
},
|
||||
},
|
||||
useGetStartupCongfigReturnValue = mockStartupConfig,
|
||||
} = {}) => {
|
||||
const mockUseLoginUser = jest
|
||||
.spyOn(mockDataProvider, 'useLoginUserMutation')
|
||||
@@ -62,16 +69,38 @@ const setup = ({
|
||||
.spyOn(mockDataProvider, 'useRefreshTokenMutation')
|
||||
//@ts-ignore - we don't need all parameters of the QueryObserverSuccessResult
|
||||
.mockReturnValue(useRefreshTokenMutationReturnValue);
|
||||
const renderResult = render(<Login />);
|
||||
const mockUseOutletContext = jest.spyOn(reactRouter, 'useOutletContext').mockReturnValue({
|
||||
startupConfig: useGetStartupCongfigReturnValue.data,
|
||||
});
|
||||
const renderResult = render(
|
||||
<AuthLayout
|
||||
startupConfig={useGetStartupCongfigReturnValue.data as TStartupConfig}
|
||||
isFetching={useGetStartupCongfigReturnValue.isFetching}
|
||||
error={null}
|
||||
startupConfigError={null}
|
||||
header={'Welcome back'}
|
||||
pathname="login"
|
||||
>
|
||||
<Login />
|
||||
</AuthLayout>,
|
||||
);
|
||||
return {
|
||||
...renderResult,
|
||||
mockUseLoginUser,
|
||||
mockUseGetUserQuery,
|
||||
mockUseOutletContext,
|
||||
mockUseGetStartupConfig,
|
||||
mockUseRefreshTokenMutation,
|
||||
};
|
||||
};
|
||||
|
||||
jest.mock('react-router-dom', () => ({
|
||||
...jest.requireActual('react-router-dom'),
|
||||
useOutletContext: () => ({
|
||||
startupConfig: mockStartupConfig,
|
||||
}),
|
||||
}));
|
||||
|
||||
test('renders login form', () => {
|
||||
const { getByLabelText, getByRole } = setup();
|
||||
expect(getByLabelText(/email/i)).toBeInTheDocument();
|
||||
@@ -132,6 +161,14 @@ test('Navigates to / on successful login', async () => {
|
||||
isError: false,
|
||||
isSuccess: true,
|
||||
},
|
||||
useGetStartupCongfigReturnValue: {
|
||||
...mockStartupConfig,
|
||||
data: {
|
||||
...mockStartupConfig.data,
|
||||
emailLoginEnabled: true,
|
||||
registrationEnabled: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const emailInput = getByLabelText(/email/i);
|
||||
|
||||
@@ -1,10 +1,32 @@
|
||||
import { render, waitFor, screen } from 'test/layout-test-utils';
|
||||
import reactRouter from 'react-router-dom';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
import Registration from '../Registration';
|
||||
import { render, waitFor, screen } from 'test/layout-test-utils';
|
||||
import * as mockDataProvider from 'librechat-data-provider/react-query';
|
||||
import type { TStartupConfig } from 'librechat-data-provider';
|
||||
import Registration from '~/components/Auth/Registration';
|
||||
import AuthLayout from '~/components/Auth/AuthLayout';
|
||||
|
||||
jest.mock('librechat-data-provider/react-query');
|
||||
|
||||
const mockStartupConfig = {
|
||||
isFetching: false,
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
data: {
|
||||
socialLogins: ['google', 'facebook', 'openid', 'github', 'discord'],
|
||||
discordLoginEnabled: true,
|
||||
facebookLoginEnabled: true,
|
||||
githubLoginEnabled: true,
|
||||
googleLoginEnabled: true,
|
||||
openidLoginEnabled: true,
|
||||
openidLabel: 'Test OpenID',
|
||||
openidImageUrl: 'http://test-server.com',
|
||||
registrationEnabled: true,
|
||||
socialLoginEnabled: true,
|
||||
serverDomain: 'mock-server',
|
||||
},
|
||||
};
|
||||
|
||||
const setup = ({
|
||||
useGetUserQueryReturnValue = {
|
||||
isLoading: false,
|
||||
@@ -28,23 +50,7 @@ const setup = ({
|
||||
user: {},
|
||||
},
|
||||
},
|
||||
useGetStartupCongfigReturnValue = {
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
data: {
|
||||
socialLogins: ['google', 'facebook', 'openid', 'github', 'discord'],
|
||||
discordLoginEnabled: true,
|
||||
facebookLoginEnabled: true,
|
||||
githubLoginEnabled: true,
|
||||
googleLoginEnabled: true,
|
||||
openidLoginEnabled: true,
|
||||
openidLabel: 'Test OpenID',
|
||||
openidImageUrl: 'http://test-server.com',
|
||||
registrationEnabled: true,
|
||||
socialLoginEnabled: true,
|
||||
serverDomain: 'mock-server',
|
||||
},
|
||||
},
|
||||
useGetStartupCongfigReturnValue = mockStartupConfig,
|
||||
} = {}) => {
|
||||
const mockUseRegisterUserMutation = jest
|
||||
.spyOn(mockDataProvider, 'useRegisterUserMutation')
|
||||
@@ -62,17 +68,39 @@ const setup = ({
|
||||
.spyOn(mockDataProvider, 'useRefreshTokenMutation')
|
||||
//@ts-ignore - we don't need all parameters of the QueryObserverSuccessResult
|
||||
.mockReturnValue(useRefreshTokenMutationReturnValue);
|
||||
const renderResult = render(<Registration />);
|
||||
const mockUseOutletContext = jest.spyOn(reactRouter, 'useOutletContext').mockReturnValue({
|
||||
startupConfig: useGetStartupCongfigReturnValue.data,
|
||||
});
|
||||
const renderResult = render(
|
||||
<AuthLayout
|
||||
startupConfig={useGetStartupCongfigReturnValue.data as TStartupConfig}
|
||||
isFetching={useGetStartupCongfigReturnValue.isFetching}
|
||||
error={null}
|
||||
startupConfigError={null}
|
||||
header={'Create your account'}
|
||||
pathname="register"
|
||||
>
|
||||
<Registration />
|
||||
</AuthLayout>,
|
||||
);
|
||||
|
||||
return {
|
||||
...renderResult,
|
||||
mockUseRegisterUserMutation,
|
||||
mockUseGetUserQuery,
|
||||
mockUseOutletContext,
|
||||
mockUseGetStartupConfig,
|
||||
mockUseRegisterUserMutation,
|
||||
mockUseRefreshTokenMutation,
|
||||
};
|
||||
};
|
||||
|
||||
jest.mock('react-router-dom', () => ({
|
||||
...jest.requireActual('react-router-dom'),
|
||||
useOutletContext: () => ({
|
||||
startupConfig: mockStartupConfig,
|
||||
}),
|
||||
}));
|
||||
|
||||
test('renders registration form', () => {
|
||||
const { getByText, getByTestId, getByRole } = setup();
|
||||
expect(getByText(/Create your account/i)).toBeInTheDocument();
|
||||
|
||||
@@ -4,9 +4,8 @@ import { getConfigDefaults } from 'librechat-data-provider';
|
||||
import { useGetStartupConfig } from 'librechat-data-provider/react-query';
|
||||
import type { ContextType } from '~/common';
|
||||
import { EndpointsMenu, ModelSpecsMenu, PresetsMenu, HeaderNewChat } from './Menus';
|
||||
import HeaderOptions from './Input/HeaderOptions';
|
||||
import ExportButton from './ExportButton';
|
||||
import ExportAndShareMenu from './ExportAndShareMenu';
|
||||
import HeaderOptions from './Input/HeaderOptions';
|
||||
|
||||
const defaultInterface = getConfigDefaults().interface;
|
||||
|
||||
|
||||
81
client/src/components/Chat/Input/AudioRecorder.tsx
Normal file
81
client/src/components/Chat/Input/AudioRecorder.tsx
Normal file
@@ -0,0 +1,81 @@
|
||||
import { useEffect } from 'react';
|
||||
import type { UseFormReturn } from 'react-hook-form';
|
||||
import { TooltipProvider, Tooltip, TooltipTrigger, TooltipContent } from '~/components/ui/';
|
||||
import { ListeningIcon, Spinner } from '~/components/svg';
|
||||
import { useLocalize, useSpeechToText } from '~/hooks';
|
||||
import { globalAudioId } from '~/common';
|
||||
|
||||
export default function AudioRecorder({
|
||||
textAreaRef,
|
||||
methods,
|
||||
ask,
|
||||
disabled,
|
||||
}: {
|
||||
textAreaRef: React.RefObject<HTMLTextAreaElement>;
|
||||
methods: UseFormReturn<{ text: string }>;
|
||||
ask: (data: { text: string }) => void;
|
||||
disabled: boolean;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
|
||||
const handleTranscriptionComplete = (text: string) => {
|
||||
if (text) {
|
||||
const globalAudio = document.getElementById(globalAudioId) as HTMLAudioElement;
|
||||
if (globalAudio) {
|
||||
console.log('Unmuting global audio');
|
||||
globalAudio.muted = false;
|
||||
}
|
||||
ask({ text });
|
||||
methods.reset({ text: '' });
|
||||
clearText();
|
||||
}
|
||||
};
|
||||
|
||||
const { isListening, isLoading, startRecording, stopRecording, speechText, clearText } =
|
||||
useSpeechToText(handleTranscriptionComplete);
|
||||
|
||||
useEffect(() => {
|
||||
if (textAreaRef.current) {
|
||||
textAreaRef.current.value = speechText;
|
||||
methods.setValue('text', speechText, { shouldValidate: true });
|
||||
}
|
||||
}, [speechText, methods, textAreaRef]);
|
||||
|
||||
const handleStartRecording = async () => {
|
||||
await startRecording();
|
||||
};
|
||||
|
||||
const handleStopRecording = async () => {
|
||||
await stopRecording();
|
||||
};
|
||||
|
||||
const renderIcon = () => {
|
||||
if (isListening) {
|
||||
return <ListeningIcon className="stroke-red-500" />;
|
||||
}
|
||||
if (isLoading) {
|
||||
return <Spinner className="stroke-gray-700 dark:stroke-gray-300" />;
|
||||
}
|
||||
return <ListeningIcon className="stroke-gray-700 dark:stroke-gray-300" />;
|
||||
};
|
||||
|
||||
return (
|
||||
<TooltipProvider delayDuration={250}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
onClick={isListening ? handleStopRecording : handleStartRecording}
|
||||
disabled={disabled}
|
||||
className="absolute bottom-1.5 right-12 flex h-[30px] w-[30px] items-center justify-center rounded-lg p-0.5 transition-colors hover:bg-gray-200 dark:hover:bg-gray-700 md:bottom-3 md:right-12"
|
||||
type="button"
|
||||
>
|
||||
{renderIcon()}
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" sideOffset={10}>
|
||||
{localize('com_ui_use_micrphone')}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useRecoilState } from 'recoil';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { memo, useCallback, useRef, useMemo } from 'react';
|
||||
import {
|
||||
supportsFiles,
|
||||
@@ -11,9 +11,11 @@ import { useChatContext, useAssistantsMapContext } from '~/Providers';
|
||||
import { useRequiresKey, useTextarea } from '~/hooks';
|
||||
import { TextareaAutosize } from '~/components/ui';
|
||||
import { useGetFileConfig } from '~/data-provider';
|
||||
import { cn, removeFocusOutlines } from '~/utils';
|
||||
import { cn, removeFocusRings } from '~/utils';
|
||||
import AttachFile from './Files/AttachFile';
|
||||
import AudioRecorder from './AudioRecorder';
|
||||
import { mainTextareaId } from '~/common';
|
||||
import StreamAudio from './StreamAudio';
|
||||
import StopButton from './StopButton';
|
||||
import SendButton from './SendButton';
|
||||
import FileRow from './Files/FileRow';
|
||||
@@ -23,6 +25,9 @@ import store from '~/store';
|
||||
const ChatForm = ({ index = 0 }) => {
|
||||
const submitButtonRef = useRef<HTMLButtonElement>(null);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement | null>(null);
|
||||
const SpeechToText = useRecoilValue(store.SpeechToText);
|
||||
const TextToSpeech = useRecoilValue(store.TextToSpeech);
|
||||
const automaticPlayback = useRecoilValue(store.automaticPlayback);
|
||||
const [showStopButton, setShowStopButton] = useRecoilState(store.showStopButtonByIndex(index));
|
||||
const [showMentionPopover, setShowMentionPopover] = useRecoilState(
|
||||
store.showMentionPopoverFamily(index),
|
||||
@@ -87,7 +92,7 @@ const ChatForm = ({ index = 0 }) => {
|
||||
const { ref, ...registerProps } = methods.register('text', {
|
||||
required: true,
|
||||
onChange: (e) => {
|
||||
methods.setValue('text', e.target.value);
|
||||
methods.setValue('text', e.target.value, { shouldValidate: true });
|
||||
},
|
||||
});
|
||||
|
||||
@@ -135,9 +140,10 @@ const ChatForm = ({ index = 0 }) => {
|
||||
supportsFiles[endpointType ?? endpoint ?? ''] && !endpointFileConfig?.disabled
|
||||
? ' pl-10 md:pl-[55px]'
|
||||
: 'pl-3 md:pl-4',
|
||||
'm-0 w-full resize-none border-0 bg-transparent py-[10px] pr-10 placeholder-black/50 focus:ring-0 focus-visible:ring-0 dark:bg-transparent dark:placeholder-white/50 md:py-3.5 md:pr-12 ',
|
||||
removeFocusOutlines,
|
||||
'm-0 w-full resize-none border-0 bg-transparent py-[10px] placeholder-black/50 focus:ring-0 focus-visible:ring-0 dark:bg-transparent dark:placeholder-white/50 md:py-3.5 ',
|
||||
SpeechToText ? 'pr-20 md:pr-[85px]' : 'pr-10 md:pr-12',
|
||||
'max-h-[65vh] md:max-h-[75vh]',
|
||||
removeFocusRings,
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
@@ -157,6 +163,15 @@ const ChatForm = ({ index = 0 }) => {
|
||||
/>
|
||||
)
|
||||
)}
|
||||
{SpeechToText && (
|
||||
<AudioRecorder
|
||||
disabled={!!disableInputs}
|
||||
textAreaRef={textAreaRef}
|
||||
ask={submitMessage}
|
||||
methods={methods}
|
||||
/>
|
||||
)}
|
||||
{TextToSpeech && automaticPlayback && <StreamAudio index={index} />}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -225,7 +225,7 @@ export default function DataTable<TData, TValue>({ columns, data }: DataTablePro
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
className="dark:border-gray-500 dark:hover:bg-gray-600 select-none"
|
||||
className="select-none dark:border-gray-500 dark:hover:bg-gray-600"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => table.previousPage()}
|
||||
@@ -234,7 +234,7 @@ export default function DataTable<TData, TValue>({ columns, data }: DataTablePro
|
||||
{localize('com_ui_prev')}
|
||||
</Button>
|
||||
<Button
|
||||
className="dark:border-gray-500 dark:hover:bg-gray-600 select-none"
|
||||
className="select-none dark:border-gray-500 dark:hover:bg-gray-600"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => table.nextPage()}
|
||||
|
||||
231
client/src/components/Chat/Input/StreamAudio.tsx
Normal file
231
client/src/components/Chat/Input/StreamAudio.tsx
Normal file
@@ -0,0 +1,231 @@
|
||||
import { useParams } from 'react-router-dom';
|
||||
import { useEffect, useCallback } from 'react';
|
||||
import { QueryKeys } from 'librechat-data-provider';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { useCustomAudioRef, MediaSourceAppender, usePauseGlobalAudio } from '~/hooks/Audio';
|
||||
import { useAuthContext } from '~/hooks';
|
||||
import { globalAudioId } from '~/common';
|
||||
import { getLatestText } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
function timeoutPromise(ms: number, message?: string) {
|
||||
return new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error(message ?? 'Promise timed out')), ms),
|
||||
);
|
||||
}
|
||||
|
||||
const promiseTimeoutMessage = 'Reader promise timed out';
|
||||
const maxPromiseTime = 15000;
|
||||
|
||||
export default function StreamAudio({ index = 0 }) {
|
||||
const { token } = useAuthContext();
|
||||
|
||||
const cacheTTS = useRecoilValue(store.cacheTTS);
|
||||
const playbackRate = useRecoilValue(store.playbackRate);
|
||||
|
||||
const voice = useRecoilValue(store.voice);
|
||||
const activeRunId = useRecoilValue(store.activeRunFamily(index));
|
||||
const automaticPlayback = useRecoilValue(store.automaticPlayback);
|
||||
const isSubmitting = useRecoilValue(store.isSubmittingFamily(index));
|
||||
const latestMessage = useRecoilValue(store.latestMessageFamily(index));
|
||||
const setIsPlaying = useSetRecoilState(store.globalAudioPlayingFamily(index));
|
||||
const [audioRunId, setAudioRunId] = useRecoilState(store.audioRunFamily(index));
|
||||
const [isFetching, setIsFetching] = useRecoilState(store.globalAudioFetchingFamily(index));
|
||||
const [globalAudioURL, setGlobalAudioURL] = useRecoilState(store.globalAudioURLFamily(index));
|
||||
|
||||
const { audioRef } = useCustomAudioRef({ setIsPlaying });
|
||||
const { pauseGlobalAudio } = usePauseGlobalAudio();
|
||||
|
||||
const { conversationId: paramId } = useParams();
|
||||
const queryParam = paramId === 'new' ? paramId : latestMessage?.conversationId ?? paramId ?? '';
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const getMessages = useCallback(
|
||||
() => queryClient.getQueryData<TMessage[]>([QueryKeys.messages, queryParam]),
|
||||
[queryParam, queryClient],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const latestText = getLatestText(latestMessage);
|
||||
|
||||
const shouldFetch = !!(
|
||||
token &&
|
||||
automaticPlayback &&
|
||||
isSubmitting &&
|
||||
latestMessage &&
|
||||
!latestMessage.isCreatedByUser &&
|
||||
latestText &&
|
||||
latestMessage.messageId &&
|
||||
!latestMessage.messageId.includes('_') &&
|
||||
!isFetching &&
|
||||
activeRunId &&
|
||||
activeRunId !== audioRunId
|
||||
);
|
||||
|
||||
if (!shouldFetch) {
|
||||
return;
|
||||
}
|
||||
|
||||
async function fetchAudio() {
|
||||
setIsFetching(true);
|
||||
|
||||
try {
|
||||
if (audioRef.current) {
|
||||
audioRef.current.pause();
|
||||
URL.revokeObjectURL(audioRef.current.src);
|
||||
setGlobalAudioURL(null);
|
||||
}
|
||||
|
||||
let cacheKey = latestMessage?.text ?? '';
|
||||
const cache = await caches.open('tts-responses');
|
||||
const cachedResponse = await cache.match(cacheKey);
|
||||
|
||||
setAudioRunId(activeRunId);
|
||||
if (cachedResponse) {
|
||||
console.log('Audio found in cache');
|
||||
const audioBlob = await cachedResponse.blob();
|
||||
const blobUrl = URL.createObjectURL(audioBlob);
|
||||
setGlobalAudioURL(blobUrl);
|
||||
setIsFetching(false);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('Fetching audio...', navigator.userAgent);
|
||||
const response = await fetch('/api/files/tts', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` },
|
||||
body: JSON.stringify({ messageId: latestMessage?.messageId, runId: activeRunId, voice }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to fetch audio');
|
||||
}
|
||||
if (!response.body) {
|
||||
throw new Error('Null Response body');
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
|
||||
const type = 'audio/mpeg';
|
||||
const browserSupportsType = MediaSource.isTypeSupported(type);
|
||||
let mediaSource: MediaSourceAppender | undefined;
|
||||
if (browserSupportsType) {
|
||||
mediaSource = new MediaSourceAppender(type);
|
||||
setGlobalAudioURL(mediaSource.mediaSourceUrl);
|
||||
}
|
||||
|
||||
let done = false;
|
||||
const chunks: Uint8Array[] = [];
|
||||
|
||||
while (!done) {
|
||||
const readPromise = reader.read();
|
||||
const { value, done: readerDone } = (await Promise.race([
|
||||
readPromise,
|
||||
timeoutPromise(maxPromiseTime, promiseTimeoutMessage),
|
||||
])) as ReadableStreamReadResult<Uint8Array>;
|
||||
|
||||
if (cacheTTS && value) {
|
||||
chunks.push(value);
|
||||
}
|
||||
if (value && mediaSource) {
|
||||
mediaSource.addData(value);
|
||||
}
|
||||
done = readerDone;
|
||||
}
|
||||
|
||||
if (chunks.length) {
|
||||
console.log('Adding audio to cache');
|
||||
const latestMessages = getMessages() ?? [];
|
||||
const targetMessage = latestMessages.find(
|
||||
(msg) => msg.messageId === latestMessage?.messageId,
|
||||
);
|
||||
cacheKey = targetMessage?.text ?? '';
|
||||
if (!cacheKey) {
|
||||
throw new Error('Cache key not found');
|
||||
}
|
||||
const audioBlob = new Blob(chunks, { type });
|
||||
const cachedResponse = new Response(audioBlob);
|
||||
await cache.put(cacheKey, cachedResponse);
|
||||
if (!browserSupportsType) {
|
||||
const unconsumedResponse = await cache.match(cacheKey);
|
||||
if (!unconsumedResponse) {
|
||||
throw new Error('Failed to fetch audio from cache');
|
||||
}
|
||||
const audioBlob = await unconsumedResponse.blob();
|
||||
const blobUrl = URL.createObjectURL(audioBlob);
|
||||
setGlobalAudioURL(blobUrl);
|
||||
}
|
||||
setIsFetching(false);
|
||||
}
|
||||
|
||||
console.log('Audio stream reading ended');
|
||||
} catch (error) {
|
||||
if (error?.['message'] !== promiseTimeoutMessage) {
|
||||
console.log(promiseTimeoutMessage);
|
||||
return;
|
||||
}
|
||||
console.error('Error fetching audio:', error);
|
||||
setIsFetching(false);
|
||||
setGlobalAudioURL(null);
|
||||
} finally {
|
||||
setIsFetching(false);
|
||||
}
|
||||
}
|
||||
|
||||
fetchAudio();
|
||||
}, [
|
||||
automaticPlayback,
|
||||
setGlobalAudioURL,
|
||||
setAudioRunId,
|
||||
setIsFetching,
|
||||
latestMessage,
|
||||
isSubmitting,
|
||||
activeRunId,
|
||||
getMessages,
|
||||
isFetching,
|
||||
audioRunId,
|
||||
cacheTTS,
|
||||
audioRef,
|
||||
voice,
|
||||
token,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
playbackRate &&
|
||||
globalAudioURL &&
|
||||
playbackRate > 0 &&
|
||||
audioRef.current &&
|
||||
audioRef.current.playbackRate !== playbackRate
|
||||
) {
|
||||
audioRef.current.playbackRate = playbackRate;
|
||||
}
|
||||
}, [audioRef, globalAudioURL, playbackRate]);
|
||||
|
||||
useEffect(() => {
|
||||
pauseGlobalAudio();
|
||||
// We only want the effect to run when the paramId changes
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [paramId]);
|
||||
|
||||
return (
|
||||
<audio
|
||||
ref={audioRef}
|
||||
controls
|
||||
controlsList="nodownload nofullscreen noremoteplayback"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
overflow: 'hidden',
|
||||
display: 'none',
|
||||
height: '0px',
|
||||
width: '0px',
|
||||
}}
|
||||
src={globalAudioURL || undefined}
|
||||
id={globalAudioId}
|
||||
muted
|
||||
autoPlay
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -87,10 +87,19 @@ const MenuItem: FC<MenuItemProps> = ({
|
||||
<>
|
||||
<div
|
||||
role="menuitem"
|
||||
className="group m-1.5 flex max-h-[40px] cursor-pointer gap-2 rounded px-5 py-2.5 !pr-3 text-sm !opacity-100 hover:bg-black/5 focus:ring-0 radix-disabled:pointer-events-none radix-disabled:opacity-50 dark:hover:bg-gray-600"
|
||||
tabIndex={-1}
|
||||
className={cn(
|
||||
'group m-1.5 flex max-h-[40px] cursor-pointer gap-2 rounded px-5 py-2.5 !pr-3 text-sm !opacity-100 hover:bg-black/5 radix-disabled:pointer-events-none radix-disabled:opacity-50 dark:hover:bg-gray-600',
|
||||
'focus:outline-none focus:ring-2 focus:ring-gray-400 focus:ring-offset-2 dark:focus:ring-gray-400 dark:focus:ring-offset-gray-900',
|
||||
)}
|
||||
tabIndex={1}
|
||||
{...rest}
|
||||
onClick={() => onSelectEndpoint(endpoint)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault();
|
||||
onSelectEndpoint(endpoint);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex grow items-center justify-between gap-2">
|
||||
<div>
|
||||
@@ -120,6 +129,7 @@ const MenuItem: FC<MenuItemProps> = ({
|
||||
expiryTime
|
||||
? 'w-full rounded-lg p-2 hover:text-gray-400 dark:hover:text-gray-400'
|
||||
: '',
|
||||
'focus:outline-none focus:ring-2 focus:ring-gray-400 focus:ring-offset-2 dark:focus:ring-gray-400 dark:focus:ring-offset-gray-900',
|
||||
)}
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
|
||||
@@ -35,13 +35,14 @@ any) => {
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{!isSubmitting && unfinished && (
|
||||
{/* Temporarily remove this */}
|
||||
{/* {!isSubmitting && unfinished && (
|
||||
<Suspense>
|
||||
<DelayedRender delay={250}>
|
||||
<UnfinishedMessage message={message} key={`unfinished-${messageId}`} />
|
||||
</DelayedRender>
|
||||
</Suspense>
|
||||
)}
|
||||
)} */}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { useState } from 'react';
|
||||
import React, { useState } from 'react';
|
||||
import { useRecoilState } from 'recoil';
|
||||
import type { TConversation, TMessage } from 'librechat-data-provider';
|
||||
import { Clipboard, CheckMark, EditIcon, RegenerateIcon, ContinueIcon } from '~/components/svg';
|
||||
import { EditIcon, Clipboard, CheckMark, ContinueIcon, RegenerateIcon } from '~/components/svg';
|
||||
import { useGenerationsByLatest, useLocalize } from '~/hooks';
|
||||
import { Fork } from '~/components/Conversations';
|
||||
import MessageAudio from './MessageAudio';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
type THoverButtons = {
|
||||
isEditing: boolean;
|
||||
@@ -16,9 +19,11 @@ type THoverButtons = {
|
||||
handleContinue: (e: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
latestMessage: TMessage | null;
|
||||
isLast: boolean;
|
||||
index: number;
|
||||
};
|
||||
|
||||
export default function HoverButtons({
|
||||
index,
|
||||
isEditing,
|
||||
enterEdit,
|
||||
copyToClipboard,
|
||||
@@ -34,6 +39,8 @@ export default function HoverButtons({
|
||||
const { endpoint: _endpoint, endpointType } = conversation ?? {};
|
||||
const endpoint = endpointType ?? _endpoint;
|
||||
const [isCopied, setIsCopied] = useState(false);
|
||||
const [TextToSpeech] = useRecoilState<boolean>(store.TextToSpeech);
|
||||
|
||||
const {
|
||||
hideEditButton,
|
||||
regenerateEnabled,
|
||||
@@ -62,13 +69,14 @@ export default function HoverButtons({
|
||||
|
||||
return (
|
||||
<div className="visible mt-0 flex justify-center gap-1 self-end text-gray-400 lg:justify-start">
|
||||
{TextToSpeech && <MessageAudio index={index} message={message} isLast={isLast} />}
|
||||
{isEditableEndpoint && (
|
||||
<button
|
||||
className={cn(
|
||||
'hover-button rounded-md p-1 text-gray-400 hover:text-gray-900 dark:text-gray-400/70 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:group-hover:visible md:group-[.final-completion]:visible',
|
||||
isCreatedByUser ? '' : 'active',
|
||||
hideEditButton ? 'opacity-0' : '',
|
||||
isEditing ? 'active bg-gray-200 text-gray-700 dark:bg-gray-700 dark:text-gray-200' : '',
|
||||
isEditing ? 'active text-gray-700 dark:text-gray-200' : '',
|
||||
!isLast ? 'md:opacity-0 md:group-hover:opacity-100' : '',
|
||||
)}
|
||||
onClick={onEdit}
|
||||
@@ -76,7 +84,7 @@ export default function HoverButtons({
|
||||
title={localize('com_ui_edit')}
|
||||
disabled={hideEditButton}
|
||||
>
|
||||
<EditIcon />
|
||||
<EditIcon size="19" />
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
@@ -91,7 +99,7 @@ export default function HoverButtons({
|
||||
isCopied ? localize('com_ui_copied_to_clipboard') : localize('com_ui_copy_to_clipboard')
|
||||
}
|
||||
>
|
||||
{isCopied ? <CheckMark className="h-[18px] w-[18px]" /> : <Clipboard />}
|
||||
{isCopied ? <CheckMark className="h-[18px] w-[18px]" /> : <Clipboard size="19" />}
|
||||
</button>
|
||||
{regenerateEnabled ? (
|
||||
<button
|
||||
@@ -103,7 +111,10 @@ export default function HoverButtons({
|
||||
type="button"
|
||||
title={localize('com_ui_regenerate')}
|
||||
>
|
||||
<RegenerateIcon className="hover:text-gray-700 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400" />
|
||||
<RegenerateIcon
|
||||
className="hover:text-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400"
|
||||
size="19"
|
||||
/>
|
||||
</button>
|
||||
) : null}
|
||||
<Fork
|
||||
@@ -116,14 +127,14 @@ export default function HoverButtons({
|
||||
{continueSupported ? (
|
||||
<button
|
||||
className={cn(
|
||||
'hover-button active rounded-md p-1 hover:bg-gray-200 hover:text-gray-700 dark:text-gray-400/70 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible ',
|
||||
'hover-button active rounded-md p-1 hover:bg-gray-200 hover:text-gray-700 dark:text-gray-400/70 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:invisible md:group-hover:visible ',
|
||||
!isLast ? 'md:opacity-0 md:group-hover:opacity-100' : '',
|
||||
)}
|
||||
onClick={handleContinue}
|
||||
type="button"
|
||||
title={localize('com_ui_continue')}
|
||||
>
|
||||
<ContinueIcon className="h-4 w-4 hover:text-gray-700 dark:hover:bg-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400" />
|
||||
<ContinueIcon className="h-4 w-4 hover:text-gray-700 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400" />
|
||||
</button>
|
||||
) : null}
|
||||
</div>
|
||||
|
||||
@@ -20,6 +20,7 @@ export default function Message(props: TMessageProps) {
|
||||
const {
|
||||
ask,
|
||||
edit,
|
||||
index,
|
||||
isLast,
|
||||
enterEdit,
|
||||
handleScroll,
|
||||
@@ -102,6 +103,7 @@ export default function Message(props: TMessageProps) {
|
||||
setSiblingIdx={setSiblingIdx}
|
||||
/>
|
||||
<HoverButtons
|
||||
index={index}
|
||||
isEditing={edit}
|
||||
message={message}
|
||||
enterEdit={enterEdit}
|
||||
|
||||
94
client/src/components/Chat/Messages/MessageAudio.tsx
Normal file
94
client/src/components/Chat/Messages/MessageAudio.tsx
Normal file
@@ -0,0 +1,94 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { VolumeIcon, VolumeMuteIcon, Spinner } from '~/components/svg';
|
||||
import { useLocalize, useTextToSpeech } from '~/hooks';
|
||||
import store from '~/store';
|
||||
|
||||
type THoverButtons = {
|
||||
message: TMessage;
|
||||
isLast: boolean;
|
||||
index: number;
|
||||
};
|
||||
|
||||
export default function MessageAudio({ index, message, isLast }: THoverButtons) {
|
||||
const localize = useLocalize();
|
||||
const playbackRate = useRecoilValue(store.playbackRate);
|
||||
|
||||
const { toggleSpeech, isSpeaking, isLoading, audioRef } = useTextToSpeech(message, isLast, index);
|
||||
|
||||
const renderIcon = (size: string) => {
|
||||
if (isLoading) {
|
||||
return <Spinner size={size} />;
|
||||
}
|
||||
|
||||
if (isSpeaking) {
|
||||
return <VolumeMuteIcon size={size} />;
|
||||
}
|
||||
|
||||
return <VolumeIcon size={size} />;
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const messageAudio = document.getElementById(
|
||||
`audio-${message.messageId}`,
|
||||
) as HTMLAudioElement | null;
|
||||
if (!messageAudio) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
playbackRate &&
|
||||
playbackRate > 0 &&
|
||||
messageAudio &&
|
||||
messageAudio.playbackRate !== playbackRate
|
||||
) {
|
||||
messageAudio.playbackRate = playbackRate;
|
||||
}
|
||||
}, [audioRef, isSpeaking, playbackRate, message.messageId]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<button
|
||||
className="hover-button rounded-md p-1 pl-0 text-gray-400 hover:text-gray-950 dark:text-gray-400/70 dark:hover:text-gray-200 disabled:dark:hover:text-gray-400 md:group-hover:visible md:group-[.final-completion]:visible"
|
||||
// onMouseDownCapture={() => {
|
||||
// if (audioRef.current) {
|
||||
// audioRef.current.muted = false;
|
||||
// }
|
||||
// handleMouseDown();
|
||||
// }}
|
||||
// onMouseUpCapture={() => {
|
||||
// if (audioRef.current) {
|
||||
// audioRef.current.muted = false;
|
||||
// }
|
||||
// handleMouseUp();
|
||||
// }}
|
||||
onClickCapture={() => {
|
||||
if (audioRef.current) {
|
||||
audioRef.current.muted = false;
|
||||
}
|
||||
toggleSpeech();
|
||||
}}
|
||||
type="button"
|
||||
title={isSpeaking ? localize('com_ui_stop') : localize('com_ui_read_aloud')}
|
||||
>
|
||||
{renderIcon('19')}
|
||||
</button>
|
||||
<audio
|
||||
ref={audioRef}
|
||||
controls
|
||||
controlsList="nodownload nofullscreen noremoteplayback"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
overflow: 'hidden',
|
||||
display: 'none',
|
||||
height: '0px',
|
||||
width: '0px',
|
||||
}}
|
||||
src={audioRef.current?.src || undefined}
|
||||
id={`audio-${message.messageId}`}
|
||||
muted
|
||||
autoPlay
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -16,6 +16,7 @@ export default function Message(props: TMessageProps) {
|
||||
const {
|
||||
ask,
|
||||
edit,
|
||||
index,
|
||||
isLast,
|
||||
enterEdit,
|
||||
assistant,
|
||||
@@ -90,6 +91,7 @@ export default function Message(props: TMessageProps) {
|
||||
setSiblingIdx={setSiblingIdx}
|
||||
/>
|
||||
<HoverButtons
|
||||
index={index}
|
||||
isEditing={edit}
|
||||
message={message}
|
||||
enterEdit={enterEdit}
|
||||
|
||||
@@ -2,15 +2,15 @@ import type { TModelSelectProps } from '~/common';
|
||||
import { ESide } from '~/common';
|
||||
import {
|
||||
Switch,
|
||||
SelectDropDown,
|
||||
Label,
|
||||
Slider,
|
||||
InputNumber,
|
||||
HoverCard,
|
||||
InputNumber,
|
||||
SelectDropDown,
|
||||
HoverCardTrigger,
|
||||
} from '~/components';
|
||||
import OptionHover from './OptionHover';
|
||||
import { cn, optionText, defaultTextProps, removeFocusOutlines } from '~/utils/';
|
||||
import { cn, optionText, defaultTextProps, removeFocusRings } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
export default function Settings({ conversation, setOption, models, readonly }: TModelSelectProps) {
|
||||
@@ -42,7 +42,7 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
setValue={setModel}
|
||||
availableValues={models}
|
||||
disabled={readonly}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusOutlines)}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusRings)}
|
||||
containerClassName="flex w-full resize-none"
|
||||
/>
|
||||
</div>
|
||||
@@ -88,7 +88,7 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
</HoverCard>
|
||||
<div className="grid w-full grid-cols-2 items-center gap-10">
|
||||
<HoverCard openDelay={500}>
|
||||
<HoverCardTrigger className="w-[100px] flex flex-col items-center text-center space-y-4">
|
||||
<HoverCardTrigger className="flex w-[100px] flex-col items-center space-y-4 text-center">
|
||||
<label
|
||||
htmlFor="functions-agent"
|
||||
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70 dark:text-gray-50"
|
||||
@@ -106,7 +106,7 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
<OptionHover endpoint={conversation.endpoint ?? ''} type="func" side={ESide.Bottom} />
|
||||
</HoverCard>
|
||||
<HoverCard openDelay={500}>
|
||||
<HoverCardTrigger className="ml-[-60px] w-[100px] flex flex-col items-center text-center space-y-4">
|
||||
<HoverCardTrigger className="ml-[-60px] flex w-[100px] flex-col items-center space-y-4 text-center">
|
||||
<label
|
||||
htmlFor="skip-completion"
|
||||
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70 dark:text-gray-50"
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
SelectDropDown,
|
||||
HoverCardTrigger,
|
||||
} from '~/components/ui';
|
||||
import { cn, defaultTextProps, optionText, removeFocusOutlines } from '~/utils';
|
||||
import { cn, defaultTextProps, optionText, removeFocusOutlines, removeFocusRings } from '~/utils';
|
||||
import OptionHoverAlt from '~/components/SidePanel/Parameters/OptionHover';
|
||||
import { useLocalize, useDebouncedInput } from '~/hooks';
|
||||
import OptionHover from './OptionHover';
|
||||
@@ -59,7 +59,7 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
setValue={setModel}
|
||||
availableValues={models}
|
||||
disabled={readonly}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusOutlines)}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusRings)}
|
||||
containerClassName="flex w-full resize-none"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -3,7 +3,7 @@ import TextareaAutosize from 'react-textarea-autosize';
|
||||
import type { Assistant, TPreset } from 'librechat-data-provider';
|
||||
import type { TModelSelectProps, Option } from '~/common';
|
||||
import { Label, HoverCard, SelectDropDown, HoverCardTrigger } from '~/components/ui';
|
||||
import { cn, defaultTextProps, removeFocusOutlines, mapAssistants } from '~/utils';
|
||||
import { cn, defaultTextProps, removeFocusRings, mapAssistants } from '~/utils';
|
||||
import { useLocalize, useDebouncedInput, useAssistantListMap } from '~/hooks';
|
||||
import OptionHover from './OptionHover';
|
||||
import { ESide } from '~/common';
|
||||
@@ -116,7 +116,7 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
setValue={setModel}
|
||||
availableValues={modelOptions}
|
||||
disabled={readonly}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusOutlines)}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusRings)}
|
||||
containerClassName="flex w-full resize-none"
|
||||
/>
|
||||
</div>
|
||||
@@ -131,7 +131,7 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
setValue={setAssistant}
|
||||
availableValues={assistants as Option[]}
|
||||
disabled={readonly}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusOutlines)}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusRings)}
|
||||
containerClassName="flex w-full resize-none"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -2,8 +2,8 @@ import { useEffect, useState } from 'react';
|
||||
import TextareaAutosize from 'react-textarea-autosize';
|
||||
import { useUpdateTokenCountMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TUpdateTokenCountResponse } from 'librechat-data-provider';
|
||||
import { cn, defaultTextProps, removeFocusOutlines } from '~/utils/';
|
||||
import { Label, Checkbox, SelectDropDown } from '~/components/ui';
|
||||
import { cn, defaultTextProps, removeFocusRings } from '~/utils';
|
||||
import { useLocalize, useDebounce } from '~/hooks';
|
||||
import type { TSettingsProps } from '~/common';
|
||||
|
||||
@@ -60,7 +60,7 @@ export default function Settings({ conversation, setOption, readonly }: TSetting
|
||||
setValue={setToneStyle}
|
||||
availableValues={['Creative', 'Fast', 'Balanced', 'Precise']}
|
||||
disabled={readonly}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusOutlines)}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusRings)}
|
||||
containerClassName="flex w-full resize-none"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { useEffect } from 'react';
|
||||
import TextareaAutosize from 'react-textarea-autosize';
|
||||
import { EModelEndpoint, endpointSettings } from 'librechat-data-provider';
|
||||
import type { TModelSelectProps, OnInputNumberChange } from '~/common';
|
||||
@@ -11,7 +10,7 @@ import {
|
||||
SelectDropDown,
|
||||
HoverCardTrigger,
|
||||
} from '~/components/ui';
|
||||
import { cn, defaultTextProps, optionText, removeFocusOutlines } from '~/utils';
|
||||
import { cn, defaultTextProps, optionText, removeFocusOutlines, removeFocusRings } from '~/utils';
|
||||
import OptionHoverAlt from '~/components/SidePanel/Parameters/OptionHover';
|
||||
import { useLocalize, useDebouncedInput } from '~/hooks';
|
||||
import OptionHover from './OptionHover';
|
||||
@@ -31,25 +30,6 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
maxOutputTokens,
|
||||
} = conversation ?? {};
|
||||
|
||||
const isGemini = model?.toLowerCase()?.includes('gemini');
|
||||
|
||||
const maxOutputTokensMax = isGemini
|
||||
? google.maxOutputTokens.maxGemini
|
||||
: google.maxOutputTokens.max;
|
||||
const maxOutputTokensDefault = isGemini
|
||||
? google.maxOutputTokens.defaultGemini
|
||||
: google.maxOutputTokens.default;
|
||||
|
||||
useEffect(
|
||||
() => {
|
||||
if (model) {
|
||||
setOption('maxOutputTokens')(Math.min(Number(maxOutputTokens) ?? 0, maxOutputTokensMax));
|
||||
}
|
||||
},
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[model],
|
||||
);
|
||||
|
||||
const [setMaxContextTokens, maxContextTokensValue] = useDebouncedInput<number | null | undefined>(
|
||||
{
|
||||
setOption,
|
||||
@@ -79,7 +59,7 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
setValue={setModel}
|
||||
availableValues={models}
|
||||
disabled={readonly}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusOutlines)}
|
||||
className={cn(defaultTextProps, 'flex w-full resize-none', removeFocusRings)}
|
||||
containerClassName="flex w-full resize-none"
|
||||
/>
|
||||
</div>
|
||||
@@ -281,15 +261,15 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
<Label htmlFor="max-tokens-int" className="text-left text-sm font-medium">
|
||||
{localize('com_endpoint_max_output_tokens')}{' '}
|
||||
<small className="opacity-40">
|
||||
({localize('com_endpoint_default_with_num', maxOutputTokensDefault + '')})
|
||||
({localize('com_endpoint_default_with_num', google.maxOutputTokens.default + '')})
|
||||
</small>
|
||||
</Label>
|
||||
<InputNumber
|
||||
id="max-tokens-int"
|
||||
disabled={readonly}
|
||||
value={maxOutputTokens}
|
||||
onChange={(value) => setMaxOutputTokens(value ?? maxOutputTokensDefault)}
|
||||
max={maxOutputTokensMax}
|
||||
onChange={(value) => setMaxOutputTokens(Number(value))}
|
||||
max={google.maxOutputTokens.max}
|
||||
min={google.maxOutputTokens.min}
|
||||
step={google.maxOutputTokens.step}
|
||||
controls={false}
|
||||
@@ -304,10 +284,10 @@ export default function Settings({ conversation, setOption, models, readonly }:
|
||||
</div>
|
||||
<Slider
|
||||
disabled={readonly}
|
||||
value={[maxOutputTokens ?? maxOutputTokensDefault]}
|
||||
value={[maxOutputTokens ?? google.maxOutputTokens.default]}
|
||||
onValueChange={(value) => setMaxOutputTokens(value[0])}
|
||||
doubleClickHandler={() => setMaxOutputTokens(maxOutputTokensDefault)}
|
||||
max={maxOutputTokensMax}
|
||||
doubleClickHandler={() => setMaxOutputTokens(google.maxOutputTokens.default)}
|
||||
max={google.maxOutputTokens.max}
|
||||
min={google.maxOutputTokens.min}
|
||||
step={google.maxOutputTokens.step}
|
||||
className="flex h-4 w-full"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user