Compare commits


6 Commits

Author · SHA1 · Message · Date
Ruben Talstra · e34503edce · 🔧 feat: Localize redirect message to OpenID provider in Login component · 2025-03-10 12:55:46 +01:00
Ruben Talstra · 14bd9f03fa · 🔧 refactor: Update getLoginError to use TranslationKeys for improved type safety · 2025-03-10 12:48:42 +01:00
Ruben Talstra · 17b0f35f93 · 🔧 feat: Implement custom logout redirect handling and enhance OpenID auto-redirect logic · 2025-03-10 11:52:36 +01:00
Ruben Talstra · a2f953460b · Merge branch 'main' into feat/oidc-auto-redirect · 2025-03-10 09:35:49 +01:00
Danilo Pejakovic · bfc7179f16 · Added Cooldown logic for OIDC auto redirect for failed login attempts · 2025-02-27 10:58:52 +01:00
Danilo Pejakovic · caaadf2fdb · added feature for oidc auto redirection · 2025-02-26 15:39:55 +01:00
367 changed files with 7733 additions and 21400 deletions

View File

@@ -142,7 +142,7 @@ GOOGLE_KEY=user_provided
# GOOGLE_AUTH_HEADER=true
# Gemini API (AI Studio)
# GOOGLE_MODELS=gemini-2.5-pro-exp-03-25,gemini-2.0-flash-exp,gemini-2.0-flash-thinking-exp-1219,gemini-exp-1121,gemini-exp-1114,gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
# GOOGLE_MODELS=gemini-2.0-flash-exp,gemini-2.0-flash-thinking-exp-1219,gemini-exp-1121,gemini-exp-1114,gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
# Vertex AI
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
@@ -364,7 +364,7 @@ ILLEGAL_MODEL_REQ_SCORE=5
# Balance #
#========================#
# CHECK_BALANCE=false
CHECK_BALANCE=false
# START_BALANCE=20000 # note: the number of tokens that will be credited after registration.
#========================#
@@ -441,10 +441,9 @@ LDAP_URL=
LDAP_BIND_DN=
LDAP_BIND_CREDENTIALS=
LDAP_USER_SEARCH_BASE=
#LDAP_SEARCH_FILTER="mail="
LDAP_SEARCH_FILTER=mail={{username}}
LDAP_CA_CERT_PATH=
# LDAP_TLS_REJECT_UNAUTHORIZED=
# LDAP_STARTTLS=
# LDAP_LOGIN_USES_USERNAME=true
# LDAP_ID=
# LDAP_USERNAME=
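
The templated filter above substitutes the login name into the LDAP query. A minimal sketch of the `{{username}}` interpolation, with a hypothetical login, assuming a simple placeholder substitution by the LDAP strategy:

```js
// Hypothetical illustration of LDAP_SEARCH_FILTER=mail={{username}}:
const searchFilter = 'mail={{username}}';
const username = 'jane@example.com'; // hypothetical login
const resolved = searchFilter.replace(/\{\{username\}\}/g, username);
console.log(resolved); // -> mail=jane@example.com
```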
@@ -477,24 +476,6 @@ FIREBASE_STORAGE_BUCKET=
FIREBASE_MESSAGING_SENDER_ID=
FIREBASE_APP_ID=
#========================#
# S3 AWS Bucket #
#========================#
AWS_ENDPOINT_URL=
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
AWS_REGION=
AWS_BUCKET_NAME=
#========================#
# Azure Blob Storage #
#========================#
AZURE_STORAGE_CONNECTION_STRING=
AZURE_STORAGE_PUBLIC_ACCESS=false
AZURE_CONTAINER_NAME=files
#========================#
# Shared Links #
#========================#

.gitignore
View File

@@ -37,10 +37,6 @@ client/public/main.js
client/public/main.js.map
client/public/main.js.LICENSE.txt
# Azure Blob Storage Emulator (Azurite)
__azurite**
__blobstorage__/**/*
# Dependency directories
# Deployed apps should consider commenting these lines out:
# see https://npmjs.org/doc/faq.html#Should-I-check-my-node_modules-folder-into-git

View File

@@ -2,7 +2,6 @@ const Anthropic = require('@anthropic-ai/sdk');
const { HttpsProxyAgent } = require('https-proxy-agent');
const {
Constants,
ErrorTypes,
EModelEndpoint,
anthropicSettings,
getResponseSender,
@@ -148,17 +147,12 @@ class AnthropicClient extends BaseClient {
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
const reservedTokens = this.maxPromptTokens + this.maxResponseTokens;
if (reservedTokens > this.maxContextTokens) {
const info = `Total Possible Tokens + Max Output Tokens must be less than or equal to Max Context Tokens: ${this.maxPromptTokens} (total possible output) + ${this.maxResponseTokens} (max output) = ${reservedTokens}/${this.maxContextTokens} (max context)`;
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
logger.warn(info);
throw new Error(errorMessage);
} else if (this.maxResponseTokens === this.maxContextTokens) {
const info = `Max Output Tokens must be less than Max Context Tokens: ${this.maxResponseTokens} (max output) = ${this.maxContextTokens} (max context)`;
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
logger.warn(info);
throw new Error(errorMessage);
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
throw new Error(
`maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
);
}
this.sender =
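
The simplified check above folds the validation into a single budget inequality: the prompt budget plus the response budget may not exceed the context window. With hypothetical numbers:

```js
// Hypothetical numbers for the token-budget check above.
const maxContextTokens = 200000; // assumed context window
const maxResponseTokens = 8192;
// Default derivation mirrors the code above when no maxPromptTokens is set:
const maxPromptTokens = maxContextTokens - maxResponseTokens; // 191808
const reservedTokens = maxPromptTokens + maxResponseTokens;   // 200000
console.log(reservedTokens <= maxContextTokens); // true -> no error thrown
```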
@@ -695,16 +689,6 @@ class AnthropicClient extends BaseClient {
return (msg) => {
if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
} else if (msg.content != null) {
/** @type {import('@librechat/agents').MessageContentComplex} */
const newContent = [];
for (let part of msg.content) {
if (part.think != null) {
continue;
}
newContent.push(part);
}
msg.content = newContent;
}
return msg;
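
Both clients' stream handlers rely on the same `:::thinking` regex to strip reasoning blocks from message text; a quick sketch of its behavior:

```js
// Minimal check of the reasoning-stripping regex above: the dotAll (s) flag
// lets `.` cross newlines, so the entire :::thinking ... ::: block is removed.
const strip = (text) => text.replace(/:::thinking.*?:::/gs, '').trim();
console.log(strip(':::thinking\nchain of thought here\n::: The answer is 42.'));
// -> "The answer is 42."
```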

View File

@@ -5,15 +5,14 @@ const {
isAgentsEndpoint,
isParamEndpoint,
EModelEndpoint,
ContentTypes,
excludedKeys,
ErrorTypes,
Constants,
} = require('librechat-data-provider');
const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
const { checkBalance } = require('~/models/balanceMethods');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const { truncateToolCallOutputs } = require('./prompts');
const { addSpaceIfNeeded } = require('~/server/utils');
const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File');
const TextStream = require('./TextStream');
const { logger } = require('~/config');
@@ -366,14 +365,17 @@ class BaseClient {
* context: TMessage[],
* remainingContextTokens: number,
* messagesToRefine: TMessage[],
* }>} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`.
* summaryIndex: number,
* }>} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
* `context` is an array of messages that fit within the token limit.
* `summaryIndex` is the index of the first message in the `messagesToRefine` array.
* `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context.
* `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
*/
async getMessagesWithinTokenLimit({ messages: _messages, maxContextTokens, instructions }) {
// Every reply is primed with <|start|>assistant<|message|>, so we
// start with 3 tokens for the label after all messages have been counted.
let summaryIndex = -1;
let currentTokenCount = 3;
const instructionsTokenCount = instructions?.tokenCount ?? 0;
let remainingContextTokens =
@@ -406,12 +408,14 @@ class BaseClient {
}
const prunedMemory = messages;
summaryIndex = prunedMemory.length - 1;
remainingContextTokens -= currentTokenCount;
return {
context: context.reverse(),
remainingContextTokens,
messagesToRefine: prunedMemory,
summaryIndex,
};
}
@@ -454,7 +458,7 @@ class BaseClient {
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
let { context, remainingContextTokens, messagesToRefine } =
let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
await this.getMessagesWithinTokenLimit({
messages: orderedWithInstructions,
instructions,
@@ -524,7 +528,7 @@ class BaseClient {
}
// Make sure to only continue summarization logic if the summary message was generated
shouldSummarize = summaryMessage != null && shouldSummarize === true;
shouldSummarize = summaryMessage && shouldSummarize;
logger.debug('[BaseClient] Context Count (2/2)', {
remainingContextTokens,
@@ -534,18 +538,17 @@ class BaseClient {
/** @type {Record<string, number> | undefined} */
let tokenCountMap;
if (buildTokenMap) {
const currentPayload = shouldSummarize ? orderedWithInstructions : context;
tokenCountMap = currentPayload.reduce((map, message, index) => {
tokenCountMap = orderedWithInstructions.reduce((map, message, index) => {
const { messageId } = message;
if (!messageId) {
return map;
}
if (shouldSummarize && index === messagesToRefine.length - 1 && !usePrevSummary) {
if (shouldSummarize && index === summaryIndex && !usePrevSummary) {
map.summaryMessage = { ...summaryMessage, messageId, tokenCount: summaryTokenCount };
}
map[messageId] = currentPayload[index].tokenCount;
map[messageId] = orderedWithInstructions[index].tokenCount;
return map;
}, {});
}
@@ -634,9 +637,8 @@ class BaseClient {
}
}
const balance = this.options.req?.app?.locals?.balance;
if (
balance?.enabled &&
isEnabled(process.env.CHECK_BALANCE) &&
supportsBalanceCheck[this.options.endpointType ?? this.options.endpoint]
) {
await checkBalance({
@@ -879,14 +881,13 @@ class BaseClient {
: await getConvo(this.options.req?.user?.id, message.conversationId);
const unsetFields = {};
const exceptions = new Set(['spec', 'iconURL']);
if (existingConvo != null) {
this.fetchedConvo = true;
for (const key in existingConvo) {
if (!key) {
continue;
}
if (excludedKeys.has(key) && !exceptions.has(key)) {
if (excludedKeys.has(key)) {
continue;
}
@@ -1020,17 +1021,11 @@ class BaseClient {
const processValue = (value) => {
if (Array.isArray(value)) {
for (let item of value) {
if (
!item ||
!item.type ||
item.type === ContentTypes.THINK ||
item.type === ContentTypes.ERROR ||
item.type === ContentTypes.IMAGE_URL
) {
if (!item || !item.type || item.type === 'image_url') {
continue;
}
if (item.type === ContentTypes.TOOL_CALL && item.tool_call != null) {
if (item.type === 'tool_call' && item.tool_call != null) {
const toolName = item.tool_call?.name || '';
if (toolName != null && toolName && typeof toolName === 'string') {
numTokens += this.getTokenCount(toolName);
@@ -1126,13 +1121,9 @@ class BaseClient {
return message;
}
const files = await getFiles(
{
file_id: { $in: fileIds },
},
{},
{},
);
const files = await getFiles({
file_id: { $in: fileIds },
});
await this.addImageURLs(message, files, this.visionMode);

View File

@@ -198,11 +198,7 @@ class GoogleClient extends BaseClient {
*/
checkVisionRequest(attachments) {
/* Validate vision request */
this.defaultVisionModel =
this.options.visionModel ??
(!EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)
? this.modelOptions.model
: 'gemini-pro-vision');
this.defaultVisionModel = this.options.visionModel ?? 'gemini-pro-vision';
const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });

View File

@@ -5,7 +5,6 @@ const { SplitStreamHandler, GraphEvents } = require('@librechat/agents');
const {
Constants,
ImageDetail,
ContentTypes,
EModelEndpoint,
resolveHeaders,
KnownEndpoints,
@@ -226,6 +225,10 @@ class OpenAIClient extends BaseClient {
logger.debug('Using Azure endpoint');
}
if (this.useOpenRouter) {
this.completionsUrl = 'https://openrouter.ai/api/v1/chat/completions';
}
return this;
}
@@ -502,24 +505,8 @@ class OpenAIClient extends BaseClient {
if (promptPrefix && this.isOmni === true) {
const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user');
if (lastUserMessageIndex !== -1) {
if (Array.isArray(payload[lastUserMessageIndex].content)) {
const firstTextPartIndex = payload[lastUserMessageIndex].content.findIndex(
(part) => part.type === ContentTypes.TEXT,
);
if (firstTextPartIndex !== -1) {
const firstTextPart = payload[lastUserMessageIndex].content[firstTextPartIndex];
payload[lastUserMessageIndex].content[firstTextPartIndex].text =
`${promptPrefix}\n${firstTextPart.text}`;
} else {
payload[lastUserMessageIndex].content.unshift({
type: ContentTypes.TEXT,
text: promptPrefix,
});
}
} else {
payload[lastUserMessageIndex].content =
`${promptPrefix}\n${payload[lastUserMessageIndex].content}`;
}
payload[lastUserMessageIndex].content =
`${promptPrefix}\n${payload[lastUserMessageIndex].content}`;
}
}
@@ -1120,16 +1107,6 @@ ${convo}
return (msg) => {
if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
} else if (msg.content != null) {
/** @type {import('@librechat/agents').MessageContentComplex} */
const newContent = [];
for (let part of msg.content) {
if (part.think != null) {
continue;
}
newContent.push(part);
}
msg.content = newContent;
}
return msg;
@@ -1181,6 +1158,10 @@ ${convo}
opts.httpAgent = new HttpsProxyAgent(this.options.proxy);
}
if (this.isVisionModel) {
modelOptions.max_tokens = 4000;
}
/** @type {TAzureConfig | undefined} */
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
@@ -1291,29 +1272,6 @@ ${convo}
});
}
/** Note: OpenAI Web Search models do not support any known parameters besides `max_tokens` */
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
const searchExcludeParams = [
'frequency_penalty',
'presence_penalty',
'temperature',
'top_p',
'top_k',
'stop',
'logit_bias',
'seed',
'response_format',
'n',
'logprobs',
'user',
];
this.options.dropParams = this.options.dropParams || [];
this.options.dropParams = [
...new Set([...this.options.dropParams, ...searchExcludeParams]),
];
}
if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
this.options.dropParams.forEach((param) => {
delete modelOptions[param];
@@ -1342,6 +1300,14 @@ ${convo}
let streamResolve;
if (
this.isOmni === true &&
(this.azure || /o1(?!-(?:mini|preview)).*$/.test(modelOptions.model)) &&
!/o3-.*$/.test(this.modelOptions.model) &&
modelOptions.stream
) {
delete modelOptions.stream;
delete modelOptions.stop;
} else if (
(!this.isOmni || /^o1-(mini|preview)/i.test(modelOptions.model)) &&
modelOptions.reasoning_effort != null
) {

View File

@@ -5,8 +5,9 @@ const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_pars
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
const { processFileURL } = require('~/server/services/Files/process');
const { EModelEndpoint } = require('librechat-data-provider');
const { checkBalance } = require('~/models/balanceMethods');
const { formatLangChainMessages } = require('./prompts');
const checkBalance = require('~/models/checkBalance');
const { isEnabled } = require('~/server/utils');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
const { logger } = require('~/config');
@@ -335,8 +336,7 @@ class PluginsClient extends OpenAIClient {
}
}
const balance = this.options.req?.app?.locals?.balance;
if (balance?.enabled) {
if (isEnabled(process.env.CHECK_BALANCE)) {
await checkBalance({
req: this.options.req,
res: this.options.res,

View File

@@ -1,8 +1,8 @@
const { promptTokensEstimate } = require('openai-chat-tokens');
const { EModelEndpoint, supportsBalanceCheck } = require('librechat-data-provider');
const { formatFromLangChain } = require('~/app/clients/prompts');
const { getBalanceConfig } = require('~/server/services/Config');
const { checkBalance } = require('~/models/balanceMethods');
const checkBalance = require('~/models/checkBalance');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const createStartHandler = ({
@@ -49,8 +49,8 @@ const createStartHandler = ({
prelimPromptTokens += tokenBuffer;
try {
const balance = await getBalanceConfig();
if (balance?.enabled && supportsBalanceCheck[EModelEndpoint.openAI]) {
// TODO: if plugins extends to non-OpenAI models, this will need to be updated
if (isEnabled(process.env.CHECK_BALANCE) && supportsBalanceCheck[EModelEndpoint.openAI]) {
const generations =
initialMessageCount && messages.length > initialMessageCount
? messages.slice(initialMessageCount)

View File

@@ -211,7 +211,7 @@ const formatAgentMessages = (payload) => {
} else if (part.type === ContentTypes.THINK) {
hasReasoning = true;
continue;
} else if (part.type === ContentTypes.ERROR || part.type === ContentTypes.AGENT_UPDATE) {
} else if (part.type === ContentTypes.ERROR) {
continue;
} else {
currentContent.push(part);

View File

@@ -164,7 +164,7 @@ describe('BaseClient', () => {
const result = await TestClient.getMessagesWithinTokenLimit({ messages });
expect(result.context).toEqual(expectedContext);
expect(result.messagesToRefine.length - 1).toEqual(expectedIndex);
expect(result.summaryIndex).toEqual(expectedIndex);
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
});
@@ -200,7 +200,7 @@ describe('BaseClient', () => {
const result = await TestClient.getMessagesWithinTokenLimit({ messages });
expect(result.context).toEqual(expectedContext);
expect(result.messagesToRefine.length - 1).toEqual(expectedIndex);
expect(result.summaryIndex).toEqual(expectedIndex);
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
});

View File

@@ -136,11 +136,10 @@ OpenAI.mockImplementation(() => ({
}));
describe('OpenAIClient', () => {
const mockSet = jest.fn();
const mockCache = { set: mockSet };
beforeEach(() => {
const mockCache = {
get: jest.fn().mockResolvedValue({}),
set: jest.fn(),
};
getLogStores.mockReturnValue(mockCache);
});
let client;

View File

@@ -172,7 +172,7 @@ Error Message: ${error.message}`);
{
type: ContentTypes.IMAGE_URL,
image_url: {
url: `data:image/png;base64,${base64}`,
url: `data:image/jpeg;base64,${base64}`,
},
},
];

View File

@@ -21,7 +21,6 @@ const {
} = require('../');
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { createMCPTool } = require('~/server/services/MCP');
const { loadSpecs } = require('./loadSpecs');
const { logger } = require('~/config');
@@ -91,6 +90,45 @@ const validateTools = async (user, tools = []) => {
}
};
const loadAuthValues = async ({ userId, authFields, throwError = true }) => {
let authValues = {};
/**
* Finds the first non-empty value for the given authentication field, supporting alternate fields.
* @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||".
* @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found.
*/
const findAuthValue = async (fields) => {
for (const field of fields) {
let value = process.env[field];
if (value) {
return { authField: field, authValue: value };
}
try {
value = await getUserPluginAuthValue(userId, field, throwError);
} catch (err) {
if (field === fields[fields.length - 1] && !value) {
throw err;
}
}
if (value) {
return { authField: field, authValue: value };
}
}
return null;
};
for (let authField of authFields) {
const fields = authField.split('||');
const result = await findAuthValue(fields);
if (result) {
authValues[result.authField] = result.authValue;
}
}
return authValues;
};
/** @typedef {typeof import('@langchain/core/tools').Tool} ToolConstructor */
/** @typedef {import('@langchain/core/tools').Tool} Tool */
@@ -310,6 +348,7 @@ const loadTools = async ({
module.exports = {
loadToolWithAuth,
loadAuthValues,
validateTools,
loadTools,
};
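
The inlined `findAuthValue` helper above walks `||`-delimited alternates, preferring environment variables before falling back to stored per-user plugin credentials. A self-contained sketch of the environment-only half of that resolution, with hypothetical variable names:

```js
// Environment-only sketch of the "||" fallback resolution above
// (the real helper also falls back to per-user stored credentials):
const resolveAuthField = (spec) => {
  for (const field of spec.split('||')) {
    if (process.env[field]) {
      return { authField: field, authValue: process.env[field] };
    }
  }
  return null;
};

process.env.SERPAPI_KEY = 'demo-key'; // hypothetical fallback variable
console.log(resolveAuthField('SERPAPI_API_KEY||SERPAPI_KEY'));
// -> { authField: 'SERPAPI_KEY', authValue: 'demo-key' }
```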

View File

@@ -1,8 +1,9 @@
const { validateTools, loadTools } = require('./handleTools');
const { validateTools, loadTools, loadAuthValues } = require('./handleTools');
const handleOpenAIErrors = require('./handleOpenAIErrors');
module.exports = {
handleOpenAIErrors,
loadAuthValues,
validateTools,
loadTools,
};

View File

@@ -49,10 +49,6 @@ const genTitle = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
const s3ExpiryInterval = isRedisEnabled
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
: new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES });
const modelQueries = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
@@ -93,7 +89,6 @@ const namespaces = {
[CacheKeys.ABORT_KEYS]: abortKeys,
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
[CacheKeys.GEN_TITLE]: genTitle,
[CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval,
[CacheKeys.MODEL_QUERIES]: modelQueries,
[CacheKeys.AUDIO_RUNS]: audioRuns,
[CacheKeys.MESSAGES]: messages,

View File

@@ -9,7 +9,7 @@ const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, RED
let keyvRedis;
const redis_prefix = REDIS_KEY_PREFIX || '';
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40;
const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 10;
function mapURI(uri) {
const regex =
@@ -77,10 +77,10 @@ if (REDIS_URI && isEnabled(USE_REDIS)) {
keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err));
keyvRedis.setMaxListeners(redis_max_listeners);
logger.info(
'[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.',
'[Optional] Redis initialized. Note: Redis support is experimental. If you have issues, disable it. Cache needs to be flushed for values to refresh.',
);
} else {
logger.info('[Optional] Redis not initialized.');
logger.info('[Optional] Redis not initialized. Note: Redis support is experimental.');
}
module.exports = keyvRedis;

View File

@@ -1,4 +1,3 @@
const axios = require('axios');
const { EventSource } = require('eventsource');
const { Time, CacheKeys } = require('librechat-data-provider');
const logger = require('./winston');
@@ -48,46 +47,9 @@ const sendEvent = (res, event) => {
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
};
/**
* Creates and configures an Axios instance with optional proxy settings.
*
* @typedef {import('axios').AxiosInstance} AxiosInstance
* @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig
*
* @returns {AxiosInstance} A configured Axios instance
* @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL
*/
function createAxiosInstance() {
const instance = axios.create();
if (process.env.proxy) {
try {
const url = new URL(process.env.proxy);
/** @type {AxiosProxyConfig} */
const proxyConfig = {
host: url.hostname.replace(/^\[|\]$/g, ''),
protocol: url.protocol.replace(':', ''),
};
if (url.port) {
proxyConfig.port = parseInt(url.port, 10);
}
instance.defaults.proxy = proxyConfig;
} catch (error) {
console.error('Error parsing proxy URL:', error);
throw new Error(`Invalid proxy URL: ${process.env.proxy}`);
}
}
return instance;
}
module.exports = {
logger,
sendEvent,
getMCPManager,
createAxiosInstance,
getFlowStateManager,
};

View File

@@ -1,126 +0,0 @@
const axios = require('axios');
const { createAxiosInstance } = require('./index');
// Mock axios
jest.mock('axios', () => ({
interceptors: {
request: { use: jest.fn(), eject: jest.fn() },
response: { use: jest.fn(), eject: jest.fn() },
},
create: jest.fn().mockReturnValue({
defaults: {
proxy: null,
},
get: jest.fn().mockResolvedValue({ data: {} }),
post: jest.fn().mockResolvedValue({ data: {} }),
put: jest.fn().mockResolvedValue({ data: {} }),
delete: jest.fn().mockResolvedValue({ data: {} }),
}),
get: jest.fn().mockResolvedValue({ data: {} }),
post: jest.fn().mockResolvedValue({ data: {} }),
put: jest.fn().mockResolvedValue({ data: {} }),
delete: jest.fn().mockResolvedValue({ data: {} }),
reset: jest.fn().mockImplementation(function () {
this.get.mockClear();
this.post.mockClear();
this.put.mockClear();
this.delete.mockClear();
this.create.mockClear();
}),
}));
describe('createAxiosInstance', () => {
const originalEnv = process.env;
beforeEach(() => {
// Reset mocks
jest.clearAllMocks();
// Create a clean copy of process.env
process.env = { ...originalEnv };
// Default: no proxy
delete process.env.proxy;
});
afterAll(() => {
// Restore original process.env
process.env = originalEnv;
});
test('creates an axios instance without proxy when no proxy env is set', () => {
const instance = createAxiosInstance();
expect(axios.create).toHaveBeenCalledTimes(1);
expect(instance.defaults.proxy).toBeNull();
});
test('configures proxy correctly with hostname and protocol', () => {
process.env.proxy = 'http://example.com';
const instance = createAxiosInstance();
expect(axios.create).toHaveBeenCalledTimes(1);
expect(instance.defaults.proxy).toEqual({
host: 'example.com',
protocol: 'http',
});
});
test('configures proxy correctly with hostname, protocol and port', () => {
process.env.proxy = 'https://proxy.example.com:8080';
const instance = createAxiosInstance();
expect(axios.create).toHaveBeenCalledTimes(1);
expect(instance.defaults.proxy).toEqual({
host: 'proxy.example.com',
protocol: 'https',
port: 8080,
});
});
test('handles proxy URLs with authentication', () => {
process.env.proxy = 'http://user:pass@proxy.example.com:3128';
const instance = createAxiosInstance();
expect(axios.create).toHaveBeenCalledTimes(1);
expect(instance.defaults.proxy).toEqual({
host: 'proxy.example.com',
protocol: 'http',
port: 3128,
// Note: The current implementation doesn't handle auth - if needed, add this functionality
});
});
test('throws error when proxy URL is invalid', () => {
process.env.proxy = 'invalid-url';
expect(() => createAxiosInstance()).toThrow('Invalid proxy URL');
expect(axios.create).toHaveBeenCalledTimes(1);
});
// If you want to test the actual URL parsing more thoroughly
test('handles edge case proxy URLs correctly', () => {
// IPv6 address
process.env.proxy = 'http://[::1]:8080';
let instance = createAxiosInstance();
expect(instance.defaults.proxy).toEqual({
host: '::1',
protocol: 'http',
port: 8080,
});
// URL with path (which should be ignored for proxy config)
process.env.proxy = 'http://proxy.example.com:8080/some/path';
instance = createAxiosInstance();
expect(instance.defaults.proxy).toEqual({
host: 'proxy.example.com',
protocol: 'http',
port: 8080,
});
});
});

View File

@@ -4,11 +4,7 @@ require('winston-daily-rotate-file');
const logDir = path.join(__dirname, '..', 'logs');
const { NODE_ENV, DEBUG_LOGGING = false } = process.env;
const useDebugLogging =
(typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') ||
DEBUG_LOGGING === true;
const { NODE_ENV } = process.env;
const levels = {
error: 0,
@@ -40,10 +36,9 @@ const fileFormat = winston.format.combine(
winston.format.splat(),
);
const logLevel = useDebugLogging ? 'debug' : 'error';
const transports = [
new winston.transports.DailyRotateFile({
level: logLevel,
level: 'debug',
filename: `${logDir}/meiliSync-%DATE%.log`,
datePattern: 'YYYY-MM-DD',
zippedArchive: true,
@@ -53,6 +48,14 @@ const transports = [
}),
];
// if (NODE_ENV !== 'production') {
// transports.push(
// new winston.transports.Console({
// format: winston.format.combine(winston.format.colorize(), winston.format.simple()),
// }),
// );
// }
const consoleFormat = winston.format.combine(
winston.format.colorize({ all: true }),
winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }),

View File

@@ -5,7 +5,7 @@ const { redactFormat, redactMessage, debugTraverse, jsonTruncateFormat } = requi
const logDir = path.join(__dirname, '..', 'logs');
const { NODE_ENV, DEBUG_LOGGING = true, CONSOLE_JSON = false, DEBUG_CONSOLE = false } = process.env;
const { NODE_ENV, DEBUG_LOGGING = true, DEBUG_CONSOLE = false, CONSOLE_JSON = false } = process.env;
const useConsoleJson =
(typeof CONSOLE_JSON === 'string' && CONSOLE_JSON?.toLowerCase() === 'true') ||
@@ -15,10 +15,6 @@ const useDebugConsole =
(typeof DEBUG_CONSOLE === 'string' && DEBUG_CONSOLE?.toLowerCase() === 'true') ||
DEBUG_CONSOLE === true;
const useDebugLogging =
(typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') ||
DEBUG_LOGGING === true;
const levels = {
error: 0,
warn: 1,
@@ -61,9 +57,28 @@ const transports = [
maxFiles: '14d',
format: fileFormat,
}),
// new winston.transports.DailyRotateFile({
// level: 'info',
// filename: `${logDir}/info-%DATE%.log`,
// datePattern: 'YYYY-MM-DD',
// zippedArchive: true,
// maxSize: '20m',
// maxFiles: '14d',
// }),
];
if (useDebugLogging) {
// if (NODE_ENV !== 'production') {
// transports.push(
// new winston.transports.Console({
// format: winston.format.combine(winston.format.colorize(), winston.format.simple()),
// }),
// );
// }
if (
(typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') ||
DEBUG_LOGGING === true
) {
transports.push(
new winston.transports.DailyRotateFile({
level: 'debug',
@@ -92,16 +107,10 @@ const consoleFormat = winston.format.combine(
}),
);
// Determine console log level
let consoleLogLevel = 'info';
if (useDebugConsole) {
consoleLogLevel = 'debug';
}
if (useDebugConsole) {
transports.push(
new winston.transports.Console({
level: consoleLogLevel,
level: 'debug',
format: useConsoleJson
? winston.format.combine(fileFormat, jsonTruncateFormat(), winston.format.json())
: winston.format.combine(fileFormat, debugTraverse),
@@ -110,14 +119,14 @@ if (useDebugConsole) {
} else if (useConsoleJson) {
transports.push(
new winston.transports.Console({
level: consoleLogLevel,
level: 'info',
format: winston.format.combine(fileFormat, jsonTruncateFormat(), winston.format.json()),
}),
);
} else {
transports.push(
new winston.transports.Console({
level: consoleLogLevel,
level: 'info',
format: consoleFormat,
}),
);

View File

@@ -46,10 +46,6 @@ const loadAgent = async ({ req, agent_id }) => {
id: agent_id,
});
if (!agent) {
return null;
}
if (agent.author.toString() === req.user.id) {
return agent;
}

View File

@@ -1,4 +1,44 @@
const mongoose = require('mongoose');
const { balanceSchema } = require('@librechat/data-schemas');
const { getMultiplier } = require('./tx');
const { logger } = require('~/config');
balanceSchema.statics.check = async function ({
user,
model,
endpoint,
valueKey,
tokenType,
amount,
endpointTokenConfig,
}) {
const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig });
const tokenCost = amount * multiplier;
const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {};
logger.debug('[Balance.check]', {
user,
model,
endpoint,
valueKey,
tokenType,
amount,
balance,
multiplier,
endpointTokenConfig: !!endpointTokenConfig,
});
if (!balance) {
return {
canSpend: false,
balance: 0,
tokenCost,
};
}
logger.debug('[Balance.check]', { tokenCost });
return { canSpend: balance >= tokenCost, balance, tokenCost };
};
module.exports = mongoose.model('Balance', balanceSchema);
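
The static `check` above prices a request as `amount * multiplier` and compares the cost to the stored credits. With hypothetical numbers:

```js
// Illustrative arithmetic for Balance.check above (rates are assumed):
const amount = 1000;                    // prompt tokens
const multiplier = 15;                  // assumed credits per token
const tokenCost = amount * multiplier;  // 15000
const balance = 20000;                  // stored tokenCredits
console.log({ canSpend: balance >= tokenCost, balance, tokenCost });
// -> { canSpend: true, balance: 20000, tokenCost: 15000 }
```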

View File

@@ -28,4 +28,4 @@ const getBanner = async (user) => {
}
};
module.exports = { Banner, getBanner };
module.exports = { getBanner };

View File

@@ -15,6 +15,19 @@ const searchConversation = async (conversationId) => {
throw new Error('Error searching conversation');
}
};
/**
* Searches for a conversation by conversationId and returns associated file ids.
* @param {string} conversationId - The conversation's ID.
* @returns {Promise<string[] | null>}
*/
const getConvoFiles = async (conversationId) => {
try {
return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? [];
} catch (error) {
logger.error('[getConvoFiles] Error getting conversation files', error);
throw new Error('Error getting conversation files');
}
};
/**
* Retrieves a single conversation for a given user and conversation ID.
@@ -60,20 +73,6 @@ const deleteNullOrEmptyConversations = async () => {
}
};
/**
* Searches for a conversation by conversationId and returns associated file ids.
* @param {string} conversationId - The conversation's ID.
* @returns {Promise<string[] | null>}
*/
const getConvoFiles = async (conversationId) => {
try {
return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? [];
} catch (error) {
logger.error('[getConvoFiles] Error getting conversation files', error);
throw new Error('Error getting conversation files');
}
};
module.exports = {
Conversation,
getConvoFiles,

View File

@@ -1,6 +1,5 @@
const mongoose = require('mongoose');
const { fileSchema } = require('@librechat/data-schemas');
const { logger } = require('~/config');
const File = mongoose.model('File', fileSchema);
@@ -18,39 +17,11 @@ const findFileById = async (file_id, options = {}) => {
* Retrieves files matching a given filter, sorted by the most recently updated.
* @param {Object} filter - The filter criteria to apply.
* @param {Object} [_sortOptions] - Optional sort parameters.
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
* Default excludes the 'text' field.
* @returns {Promise<Array<IMongoFile>>} A promise that resolves to an array of file documents.
*/
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
const getFiles = async (filter, _sortOptions) => {
const sortOptions = { updatedAt: -1, ..._sortOptions };
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
};
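
A hypothetical call against the three-argument signature documented above, keeping the default projection that omits the heavy `text` field:

```js
// Sketch of getFiles(filter, _sortOptions, selectFields) usage; the require
// path matches the one used elsewhere in this diff.
const { getFiles } = require('~/models/File');

async function listUserFiles(userId) {
  return getFiles(
    { user: userId },  // filter
    { updatedAt: 1 },  // override the default newest-first sort
    { text: 0 },       // default projection: exclude extracted text
  );
}
```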
/**
* Retrieves tool files (files that are embedded or have a fileIdentifier) from an array of file IDs
* @param {string[]} fileIds - Array of file_id strings to search for
* @returns {Promise<Array<IMongoFile>>} Files that match the criteria
*/
const getToolFilesByIds = async (fileIds) => {
if (!fileIds || !fileIds.length) {
return [];
}
try {
const filter = {
file_id: { $in: fileIds },
$or: [{ embedded: true }, { 'metadata.fileIdentifier': { $exists: true } }],
};
const selectFields = { text: 0 };
const sortOptions = { updatedAt: -1 };
return await getFiles(filter, sortOptions, selectFields);
} catch (error) {
logger.error('[getToolFilesByIds] Error retrieving tool files:', error);
throw new Error('Error retrieving tool files');
}
return await File.find(filter).sort(sortOptions).lean();
};
/**
@@ -134,38 +105,14 @@ const deleteFiles = async (file_ids, user) => {
return await File.deleteMany(deleteQuery);
};
/**
* Batch updates files with new signed URLs in MongoDB
*
* @param {MongoFile[]} updates - Array of updates in the format { file_id, filepath }
* @returns {Promise<void>}
*/
async function batchUpdateFiles(updates) {
if (!updates || updates.length === 0) {
return;
}
const bulkOperations = updates.map((update) => ({
updateOne: {
filter: { file_id: update.file_id },
update: { $set: { filepath: update.filepath } },
},
}));
const result = await File.bulkWrite(bulkOperations);
logger.info(`Updated ${result.modifiedCount} files with new S3 URLs`);
}
module.exports = {
File,
findFileById,
getFiles,
getToolFilesByIds,
createFile,
updateFile,
updateFileUsage,
deleteFile,
deleteFiles,
deleteFileByFilter,
batchUpdateFiles,
};

View File

@@ -71,42 +71,7 @@ async function saveMessage(req, params, metadata) {
} catch (err) {
logger.error('Error saving message:', err);
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
// Check if this is a duplicate key error (MongoDB error code 11000)
if (err.code === 11000 && err.message.includes('duplicate key error')) {
// Log the duplicate key error but don't crash the application
logger.warn(`Duplicate messageId detected: ${params.messageId}. Continuing execution.`);
try {
// Try to find the existing message with this ID
const existingMessage = await Message.findOne({
messageId: params.messageId,
user: req.user.id,
});
// If we found it, return it
if (existingMessage) {
return existingMessage.toObject();
}
// If we can't find it (unlikely but possible in race conditions)
return {
...params,
messageId: params.messageId,
user: req.user.id,
};
} catch (findError) {
// If the findOne also fails, log it but don't crash
logger.warn(`Could not retrieve existing message with ID ${params.messageId}: ${findError.message}`);
return {
...params,
messageId: params.messageId,
user: req.user.id,
};
}
}
throw err; // Re-throw other errors
throw err;
}
}

View File

@@ -1,144 +1,11 @@
const mongoose = require('mongoose');
const { isEnabled } = require('~/server/utils/handleText');
const { transactionSchema } = require('@librechat/data-schemas');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { logger } = require('~/config');
const Balance = require('./Balance');
const cancelRate = 1.15;
/**
* Updates a user's token balance based on a transaction using optimistic concurrency control
* without schema changes. Compatible with DocumentDB.
* @async
* @function
* @param {Object} params - The function parameters.
* @param {string|mongoose.Types.ObjectId} params.user - The user ID.
* @param {number} params.incrementValue - The value to increment the balance by (can be negative).
* @param {import('mongoose').UpdateQuery<import('@librechat/data-schemas').IBalance>['$set']} [params.setValues] - Optional additional fields to set.
* @returns {Promise<Object>} Returns the updated balance document (lean).
* @throws {Error} Throws an error if the update fails after multiple retries.
*/
const updateBalance = async ({ user, incrementValue, setValues }) => {
let maxRetries = 10; // Number of times to retry on conflict
let delay = 50; // Initial retry delay in ms
let lastError = null;
for (let attempt = 1; attempt <= maxRetries; attempt++) {
let currentBalanceDoc;
try {
// 1. Read the current document state
currentBalanceDoc = await Balance.findOne({ user }).lean();
const currentCredits = currentBalanceDoc ? currentBalanceDoc.tokenCredits : 0;
// 2. Calculate the desired new state
const potentialNewCredits = currentCredits + incrementValue;
const newCredits = Math.max(0, potentialNewCredits); // Ensure balance doesn't go below zero
// 3. Prepare the update payload
const updatePayload = {
$set: {
tokenCredits: newCredits,
...(setValues || {}), // Merge other values to set
},
};
// 4. Attempt the conditional update or upsert
let updatedBalance = null;
if (currentBalanceDoc) {
// --- Document Exists: Perform Conditional Update ---
// Try to update only if the tokenCredits match the value we read (currentCredits)
updatedBalance = await Balance.findOneAndUpdate(
{
user: user,
tokenCredits: currentCredits, // Optimistic lock: condition based on the read value
},
updatePayload,
{
new: true, // Return the modified document
// lean: true, // .lean() is applied after query execution in Mongoose >= 6
},
).lean(); // Use lean() for plain JS object
if (updatedBalance) {
// Success! The update was applied based on the expected current state.
return updatedBalance;
}
// If updatedBalance is null, it means tokenCredits changed between read and write (conflict).
lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`);
// Proceed to retry logic below.
} else {
// --- Document Does Not Exist: Perform Conditional Upsert ---
// Try to insert the document, but only if it still doesn't exist.
// Using tokenCredits: {$exists: false} helps prevent race conditions where
// another process creates the doc between our findOne and findOneAndUpdate.
try {
updatedBalance = await Balance.findOneAndUpdate(
{
user: user,
// Attempt to match only if the document doesn't exist OR was just created
// without tokenCredits (less likely but possible). A simple { user } filter
// might also work, relying on the retry for conflicts.
// Let's use a simpler filter and rely on retry for races.
// tokenCredits: { $exists: false } // This condition might be too strict if doc exists with 0 credits
},
updatePayload,
{
upsert: true, // Create if doesn't exist
new: true, // Return the created/updated document
// setDefaultsOnInsert: true, // Ensure schema defaults are applied on insert
// lean: true,
},
).lean();
if (updatedBalance) {
// Upsert succeeded (likely created the document)
return updatedBalance;
}
// If null, potentially a rare race condition during upsert. Retry should handle it.
lastError = new Error(
`Upsert race condition suspected for user ${user} on attempt ${attempt}.`,
);
} catch (error) {
if (error.code === 11000) {
// E11000 duplicate key error on index
// This means another process created the document *just* before our upsert.
// It's a concurrency conflict during creation. We should retry.
lastError = error; // Store the error
// Proceed to retry logic below.
} else {
// Different error, rethrow
throw error;
}
}
} // End if/else (document exists?)
} catch (error) {
// Catch errors from findOne or unexpected findOneAndUpdate errors
logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error);
lastError = error; // Store the error
// Consider stopping retries for non-transient errors, but for now, we retry.
}
// If we reached here, it means the update failed (conflict or error), wait and retry
if (attempt < maxRetries) {
const jitter = Math.random() * delay * 0.5; // Add jitter to delay
await new Promise((resolve) => setTimeout(resolve, delay + jitter));
delay = Math.min(delay * 2, 2000); // Exponential backoff with cap
}
} // End for loop (retries)
// If loop finishes without success, throw the last encountered error or a generic one
logger.error(
`[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`,
);
throw (
lastError ||
new Error(
`Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`,
)
);
};
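
The retry loop above uses capped exponential backoff with jitter. A quick sketch of the base delays it waits between attempts:

```js
// Base delays (before the random 0-50% jitter) between the 10 attempts above:
let delay = 50;
const waits = [];
for (let attempt = 1; attempt < 10; attempt++) {
  waits.push(delay);
  delay = Math.min(delay * 2, 2000); // doubles, capped at 2s
}
console.log(waits.join(', '));
// -> 50, 100, 200, 400, 800, 1600, 2000, 2000, 2000
```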
/** Method to calculate and set the tokenValue for a transaction */
transactionSchema.methods.calculateTokenValue = function () {
if (!this.valueKey || !this.tokenType) {
@@ -154,39 +21,6 @@ transactionSchema.methods.calculateTokenValue = function () {
}
};
/**
* New static method to create an auto-refill transaction that does NOT trigger a balance update.
* @param {object} txData - Transaction data.
* @param {string} txData.user - The user ID.
* @param {string} txData.tokenType - The type of token.
* @param {string} txData.context - The context of the transaction.
* @param {number} txData.rawAmount - The raw amount of tokens.
* @returns {Promise<object>} - The created transaction.
*/
transactionSchema.statics.createAutoRefillTransaction = async function (txData) {
if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
return;
}
const transaction = new this(txData);
transaction.endpointTokenConfig = txData.endpointTokenConfig;
transaction.calculateTokenValue();
await transaction.save();
const balanceResponse = await updateBalance({
user: transaction.user,
incrementValue: txData.rawAmount,
setValues: { lastRefill: new Date() },
});
const result = {
rate: transaction.rate,
user: transaction.user.toString(),
balance: balanceResponse.tokenCredits,
};
logger.debug('[Balance.check] Auto-refill performed', result);
result.transaction = transaction;
return result;
};
/**
* Static method to create a transaction and update the balance
* @param {txData} txData - Transaction data.
@@ -203,22 +37,27 @@ transactionSchema.statics.create = async function (txData) {
await transaction.save();
const balance = await getBalanceConfig();
if (!balance?.enabled) {
if (!isEnabled(process.env.CHECK_BALANCE)) {
return;
}
let balance = await Balance.findOne({ user: transaction.user }).lean();
let incrementValue = transaction.tokenValue;
const balanceResponse = await updateBalance({
user: transaction.user,
incrementValue,
});
if (balance && balance?.tokenCredits + incrementValue < 0) {
incrementValue = -balance.tokenCredits;
}
balance = await Balance.findOneAndUpdate(
{ user: transaction.user },
{ $inc: { tokenCredits: incrementValue } },
{ upsert: true, new: true },
).lean();
return {
rate: transaction.rate,
user: transaction.user.toString(),
balance: balanceResponse.tokenCredits,
balance: balance.tokenCredits,
[transaction.tokenType]: incrementValue,
};
};
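
The branch-side clamp above keeps a spend from driving the balance negative. Hypothetical numbers:

```js
// Spending 700 credits against a 500-credit balance is clamped so the
// balance bottoms out at 0 rather than -200 (illustrative values):
let tokenCredits = 500;
let incrementValue = -700; // tokenValue for a spend is negative
if (tokenCredits + incrementValue < 0) {
  incrementValue = -tokenCredits; // clamp to -500
}
tokenCredits += incrementValue;
console.log(tokenCredits); // -> 0
```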
@@ -239,22 +78,27 @@ transactionSchema.statics.createStructured = async function (txData) {
await transaction.save();
const balance = await getBalanceConfig();
if (!balance?.enabled) {
if (!isEnabled(process.env.CHECK_BALANCE)) {
return;
}
let balance = await Balance.findOne({ user: transaction.user }).lean();
let incrementValue = transaction.tokenValue;
const balanceResponse = await updateBalance({
user: transaction.user,
incrementValue,
});
if (balance && balance?.tokenCredits + incrementValue < 0) {
incrementValue = -balance.tokenCredits;
}
balance = await Balance.findOneAndUpdate(
{ user: transaction.user },
{ $inc: { tokenCredits: incrementValue } },
{ upsert: true, new: true },
).lean();
return {
rate: transaction.rate,
user: transaction.user.toString(),
balance: balanceResponse.tokenCredits,
balance: balance.tokenCredits,
[transaction.tokenType]: incrementValue,
};
};

View File

@@ -1,13 +1,9 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getBalanceConfig } = require('~/server/services/Config');
const { getMultiplier, getCacheMultiplier } = require('./tx');
const { Transaction } = require('./Transaction');
const Balance = require('./Balance');
// Mock the custom config module so we can control the balance flag.
jest.mock('~/server/services/Config');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getMultiplier, getCacheMultiplier } = require('./tx');
let mongoServer;
@@ -24,8 +20,6 @@ afterAll(async () => {
beforeEach(async () => {
await mongoose.connection.dropDatabase();
// Default: enable balance updates in tests.
getBalanceConfig.mockResolvedValue({ enabled: true });
});
describe('Regular Token Spending Tests', () => {
@@ -50,22 +44,34 @@ describe('Regular Token Spending Tests', () => {
};
// Act
process.env.CHECK_BALANCE = 'true';
await spendTokens(txData, tokenUsage);
// Assert
console.log('Initial Balance:', initialBalance);
const updatedBalance = await Balance.findOne({ user: userId });
console.log('Updated Balance:', updatedBalance.tokenCredits);
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
const completionMultiplier = getMultiplier({ model, tokenType: 'completion' });
const expectedTotalCost = 100 * promptMultiplier + 50 * completionMultiplier;
const expectedPromptCost = tokenUsage.promptTokens * promptMultiplier;
const expectedCompletionCost = tokenUsage.completionTokens * completionMultiplier;
const expectedTotalCost = expectedPromptCost + expectedCompletionCost;
const expectedBalance = initialBalance - expectedTotalCost;
expect(updatedBalance.tokenCredits).toBeLessThan(initialBalance);
expect(updatedBalance.tokenCredits).toBeCloseTo(expectedBalance, 0);
console.log('Expected Total Cost:', expectedTotalCost);
console.log('Actual Balance Decrease:', initialBalance - updatedBalance.tokenCredits);
});
test('spendTokens should handle zero completion tokens', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
const initialBalance = 10000000; // $10.00
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
@@ -83,19 +89,24 @@ describe('Regular Token Spending Tests', () => {
};
// Act
process.env.CHECK_BALANCE = 'true';
await spendTokens(txData, tokenUsage);
// Assert
const updatedBalance = await Balance.findOne({ user: userId });
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
const expectedCost = 100 * promptMultiplier;
const expectedCost = tokenUsage.promptTokens * promptMultiplier;
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
console.log('Initial Balance:', initialBalance);
console.log('Updated Balance:', updatedBalance.tokenCredits);
console.log('Expected Cost:', expectedCost);
});
test('spendTokens should handle undefined token counts', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
const initialBalance = 10000000; // $10.00
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
@@ -109,17 +120,14 @@ describe('Regular Token Spending Tests', () => {
const tokenUsage = {};
// Act
const result = await spendTokens(txData, tokenUsage);
// Assert: No transaction should be created
expect(result).toBeUndefined();
});
test('spendTokens should handle only prompt tokens', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
const initialBalance = 10000000; // $10.00
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
@@ -133,44 +141,14 @@ describe('Regular Token Spending Tests', () => {
const tokenUsage = { promptTokens: 100 };
// Act
await spendTokens(txData, tokenUsage);
// Assert
const updatedBalance = await Balance.findOne({ user: userId });
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
const expectedCost = 100 * promptMultiplier;
expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0);
});
test('spendTokens should not update balance when balance feature is disabled', async () => {
// Arrange: Override the config to disable balance updates.
getBalanceConfig.mockResolvedValue({ balance: { enabled: false } });
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
const model = 'gpt-3.5-turbo';
const txData = {
user: userId,
conversationId: 'test-conversation-id',
model,
context: 'test',
endpointTokenConfig: null,
};
const tokenUsage = {
promptTokens: 100,
completionTokens: 50,
};
// Act
await spendTokens(txData, tokenUsage);
// Assert: Balance should remain unchanged.
const updatedBalance = await Balance.findOne({ user: userId });
expect(updatedBalance.tokenCredits).toBe(initialBalance);
});
});
describe('Structured Token Spending Tests', () => {
@@ -186,7 +164,7 @@ describe('Structured Token Spending Tests', () => {
conversationId: 'c23a18da-706c-470a-ac28-ec87ed065199',
model,
context: 'message',
endpointTokenConfig: null,
endpointTokenConfig: null, // We'll use the default rates
};
const tokenUsage = {
@@ -198,15 +176,28 @@ describe('Structured Token Spending Tests', () => {
completionTokens: 5,
};
// Get the actual multipliers
const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
const completionMultiplier = getMultiplier({ model, tokenType: 'completion' });
const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' });
const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' });
console.log('Multipliers:', {
promptMultiplier,
completionMultiplier,
writeMultiplier,
readMultiplier,
});
// Act
process.env.CHECK_BALANCE = 'true';
const result = await spendStructuredTokens(txData, tokenUsage);
// Calculate expected costs.
// Assert
console.log('Initial Balance:', initialBalance);
console.log('Updated Balance:', result.completion.balance);
console.log('Transaction Result:', result);
const expectedPromptCost =
tokenUsage.promptTokens.input * promptMultiplier +
tokenUsage.promptTokens.write * writeMultiplier +
@@ -215,21 +206,37 @@ describe('Structured Token Spending Tests', () => {
const expectedTotalCost = expectedPromptCost + expectedCompletionCost;
const expectedBalance = initialBalance - expectedTotalCost;
// Assert
console.log('Expected Cost:', expectedTotalCost);
console.log('Expected Balance:', expectedBalance);
expect(result.completion.balance).toBeLessThan(initialBalance);
// Allow for a small difference (e.g., 100 token credits, which is $0.0001)
const allowedDifference = 100;
expect(Math.abs(result.completion.balance - expectedBalance)).toBeLessThan(allowedDifference);
// Check if the decrease is approximately as expected
const balanceDecrease = initialBalance - result.completion.balance;
expect(balanceDecrease).toBeCloseTo(expectedTotalCost, 0);
const expectedPromptTokenValue = -expectedPromptCost;
const expectedCompletionTokenValue = -expectedCompletionCost;
// Check token values
const expectedPromptTokenValue = -(
tokenUsage.promptTokens.input * promptMultiplier +
tokenUsage.promptTokens.write * writeMultiplier +
tokenUsage.promptTokens.read * readMultiplier
);
const expectedCompletionTokenValue = -tokenUsage.completionTokens * completionMultiplier;
expect(result.prompt.prompt).toBeCloseTo(expectedPromptTokenValue, 1);
expect(result.completion.completion).toBe(expectedCompletionTokenValue);
console.log('Expected prompt tokenValue:', expectedPromptTokenValue);
console.log('Actual prompt tokenValue:', result.prompt.prompt);
console.log('Expected completion tokenValue:', expectedCompletionTokenValue);
console.log('Actual completion tokenValue:', result.completion.completion);
});
test('should handle zero completion tokens in structured spending', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 17613154.55;
await Balance.create({ user: userId, tokenCredits: initialBalance });
@@ -251,17 +258,15 @@ describe('Structured Token Spending Tests', () => {
completionTokens: 0,
};
// Act
process.env.CHECK_BALANCE = 'true';
const result = await spendStructuredTokens(txData, tokenUsage);
// Assert
expect(result.prompt).toBeDefined();
expect(result.completion).toBeUndefined();
expect(result.prompt.prompt).toBeLessThan(0);
});
test('should handle only prompt tokens in structured spending', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 17613154.55;
await Balance.create({ user: userId, tokenCredits: initialBalance });
@@ -282,17 +287,15 @@ describe('Structured Token Spending Tests', () => {
},
};
// Act
process.env.CHECK_BALANCE = 'true';
const result = await spendStructuredTokens(txData, tokenUsage);
// Assert
expect(result.prompt).toBeDefined();
expect(result.completion).toBeUndefined();
expect(result.prompt.prompt).toBeLessThan(0);
});
test('should handle undefined token counts in structured spending', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 17613154.55;
await Balance.create({ user: userId, tokenCredits: initialBalance });
@@ -307,10 +310,9 @@ describe('Structured Token Spending Tests', () => {
const tokenUsage = {};
// Act
process.env.CHECK_BALANCE = 'true';
const result = await spendStructuredTokens(txData, tokenUsage);
// Assert
expect(result).toEqual({
prompt: undefined,
completion: undefined,
@@ -318,7 +320,6 @@ describe('Structured Token Spending Tests', () => {
});
test('should handle incomplete context for completion tokens', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 17613154.55;
await Balance.create({ user: userId, tokenCredits: initialBalance });
@@ -340,18 +341,15 @@ describe('Structured Token Spending Tests', () => {
completionTokens: 50,
};
// Act
process.env.CHECK_BALANCE = 'true';
const result = await spendStructuredTokens(txData, tokenUsage);
// Assert:
// (Assuming a multiplier for completion of 15 and a cancel rate of 1.15 as noted in the original test.)
expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0);
expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0); // Assuming multiplier is 15 and cancelRate is 1.15
});
});
describe('NaN Handling Tests', () => {
test('should skip transaction creation when rawAmount is NaN', async () => {
// Arrange
const userId = new mongoose.Types.ObjectId();
const initialBalance = 10000000;
await Balance.create({ user: userId, tokenCredits: initialBalance });
@@ -367,11 +365,9 @@ describe('NaN Handling Tests', () => {
tokenType: 'prompt',
};
// Act
const result = await Transaction.create(txData);
// Assert: No transaction should be created and balance remains unchanged.
expect(result).toBeUndefined();
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(initialBalance);
});

View File

@@ -1,156 +0,0 @@
const { ViolationTypes } = require('librechat-data-provider');
const { Transaction } = require('./Transaction');
const { logViolation } = require('~/cache');
const { getMultiplier } = require('./tx');
const { logger } = require('~/config');
const Balance = require('./Balance');
function isInvalidDate(date) {
return isNaN(date);
}
/**
* Simple check method that calculates token cost and returns balance info.
* The auto-refill logic has been moved to balanceMethods.js to prevent circular dependencies.
*/
const checkBalanceRecord = async function ({
user,
model,
endpoint,
valueKey,
tokenType,
amount,
endpointTokenConfig,
}) {
const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig });
const tokenCost = amount * multiplier;
// Retrieve the balance record
let record = await Balance.findOne({ user }).lean();
if (!record) {
logger.debug('[Balance.check] No balance record found for user', { user });
return {
canSpend: false,
balance: 0,
tokenCost,
};
}
let balance = record.tokenCredits;
logger.debug('[Balance.check] Initial state', {
user,
model,
endpoint,
valueKey,
tokenType,
amount,
balance,
multiplier,
endpointTokenConfig: !!endpointTokenConfig,
});
// Only perform auto-refill if spending would bring the balance to 0 or below
if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) {
const lastRefillDate = new Date(record.lastRefill);
const now = new Date();
if (
isInvalidDate(lastRefillDate) ||
now >=
addIntervalToDate(lastRefillDate, record.refillIntervalValue, record.refillIntervalUnit)
) {
try {
/** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */
const result = await Transaction.createAutoRefillTransaction({
user: user,
tokenType: 'credits',
context: 'autoRefill',
rawAmount: record.refillAmount,
});
balance = result.balance;
} catch (error) {
logger.error('[Balance.check] Failed to record transaction for auto-refill', error);
}
}
}
logger.debug('[Balance.check] Token cost', { tokenCost });
return { canSpend: balance >= tokenCost, balance, tokenCost };
};
/**
* Adds a time interval to a given date.
* @param {Date} date - The starting date.
* @param {number} value - The numeric value of the interval.
* @param {'seconds'|'minutes'|'hours'|'days'|'weeks'|'months'} unit - The unit of time.
* @returns {Date} A new Date representing the starting date plus the interval.
*/
const addIntervalToDate = (date, value, unit) => {
const result = new Date(date);
switch (unit) {
case 'seconds':
result.setSeconds(result.getSeconds() + value);
break;
case 'minutes':
result.setMinutes(result.getMinutes() + value);
break;
case 'hours':
result.setHours(result.getHours() + value);
break;
case 'days':
result.setDate(result.getDate() + value);
break;
case 'weeks':
result.setDate(result.getDate() + value * 7);
break;
case 'months':
result.setMonth(result.getMonth() + value);
break;
default:
break;
}
return result;
};
/**
* Checks the balance for a user and determines if they can spend a certain amount.
* If the user cannot spend the amount, it logs a violation and denies the request.
*
* @async
* @function
* @param {Object} params - The function parameters.
* @param {Express.Request} params.req - The Express request object.
* @param {Express.Response} params.res - The Express response object.
* @param {Object} params.txData - The transaction data.
* @param {string} params.txData.user - The user ID or identifier.
* @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
* @param {number} params.txData.amount - The amount of tokens.
* @param {string} params.txData.model - The model name or identifier.
* @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint.
* @returns {Promise<boolean>} Throws error if the user cannot spend the amount.
* @throws {Error} Throws an error if there's an issue with the balance check.
*/
const checkBalance = async ({ req, res, txData }) => {
const { canSpend, balance, tokenCost } = await checkBalanceRecord(txData);
if (canSpend) {
return true;
}
const type = ViolationTypes.TOKEN_BALANCE;
const errorMessage = {
type,
balance,
tokenCost,
promptTokens: txData.amount,
};
if (txData.generations && txData.generations.length > 0) {
errorMessage.generations = txData.generations;
}
await logViolation(req, res, type, errorMessage, 0);
throw new Error(JSON.stringify(errorMessage));
};
module.exports = {
checkBalance,
};
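A usage sketch for addIntervalToDate as defined above, showing how checkBalanceRecord derives the next eligible auto-refill time (dates and interval values are illustrative):
// With refillIntervalValue = 2 and refillIntervalUnit = 'weeks',
// the next refill becomes eligible 14 days after the last one.
const lastRefill = new Date('2025-03-01T00:00:00Z');
const nextRefill = addIntervalToDate(lastRefill, 2, 'weeks');
console.log(nextRefill.toISOString()); // 2025-03-15T00:00:00.000Z
// The gate in checkBalanceRecord then reduces to:
const refillDue = isInvalidDate(lastRefill) || new Date() >= nextRefill;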

View File

@@ -0,0 +1,45 @@
const { ViolationTypes } = require('librechat-data-provider');
const { logViolation } = require('~/cache');
const Balance = require('./Balance');
/**
* Checks the balance for a user and determines if they can spend a certain amount.
* If the user cannot spend the amount, it logs a violation and denies the request.
*
* @async
* @function
* @param {Object} params - The function parameters.
* @param {Express.Request} params.req - The Express request object.
* @param {Express.Response} params.res - The Express response object.
* @param {Object} params.txData - The transaction data.
* @param {string} params.txData.user - The user ID or identifier.
* @param {('prompt' | 'completion')} params.txData.tokenType - The type of token.
* @param {number} params.txData.amount - The amount of tokens.
* @param {string} params.txData.model - The model name or identifier.
* @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint.
* @returns {Promise<boolean>} Returns true if the user can spend the amount, otherwise denies the request.
* @throws {Error} Throws an error if there's an issue with the balance check.
*/
const checkBalance = async ({ req, res, txData }) => {
const { canSpend, balance, tokenCost } = await Balance.check(txData);
if (canSpend) {
return true;
}
const type = ViolationTypes.TOKEN_BALANCE;
const errorMessage = {
type,
balance,
tokenCost,
promptTokens: txData.amount,
};
if (txData.generations && txData.generations.length > 0) {
errorMessage.generations = txData.generations;
}
await logViolation(req, res, type, errorMessage, 0);
throw new Error(JSON.stringify(errorMessage));
};
module.exports = checkBalance;
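A hedged sketch of a consumer for this module; the middleware itself is hypothetical, but the txData fields and the JSON-encoded violation follow the JSDoc and throw above:
const checkBalance = require('~/models/checkBalance');
// Hypothetical middleware: deny the request when the user cannot cover
// the estimated prompt cost.
async function guardCompletion(req, res, next) {
  try {
    await checkBalance({
      req,
      res,
      txData: { user: req.user.id, tokenType: 'prompt', amount: 1024, model: 'gpt-4' },
    });
    next();
  } catch (err) {
    // checkBalance throws a JSON string, e.g.
    // {"type":"token_balance","balance":0,"tokenCost":30720,"promptTokens":1024}
    res.status(429).json(JSON.parse(err.message));
  }
}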

View File

@@ -1,7 +1,6 @@
const _ = require('lodash');
const mongoose = require('mongoose');
const { MeiliSearch } = require('meilisearch');
const { parseTextParts, ContentTypes } = require('librechat-data-provider');
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
const logger = require('~/config/meiliLogger');
@@ -239,7 +238,10 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) {
}
if (object.content && Array.isArray(object.content)) {
object.text = parseTextParts(object.content);
object.text = object.content
.filter((item) => item.type === 'text' && item.text && item.text.value)
.map((item) => item.text.value)
.join(' ');
delete object.content;
}
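Both sides of this hunk flatten a structured content array into a plain text string for indexing. The inline variant's behavior on sample data (the content shape is inferred from the filter it applies):
const content = [
  { type: 'text', text: { value: 'Hello' } },
  { type: 'image_url', image_url: { url: 'https://example.com/x.png' } },
  { type: 'text', text: { value: 'world' } },
];
const text = content
  .filter((item) => item.type === 'text' && item.text && item.text.value)
  .map((item) => item.text.value)
  .join(' ');
console.log(text); // "Hello world"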

View File

@@ -36,7 +36,7 @@ const spendTokens = async (txData, tokenUsage) => {
prompt = await Transaction.create({
...txData,
tokenType: 'prompt',
rawAmount: promptTokens === 0 ? 0 : -Math.max(promptTokens, 0),
rawAmount: -Math.max(promptTokens, 0),
});
}
@@ -44,7 +44,7 @@ const spendTokens = async (txData, tokenUsage) => {
completion = await Transaction.create({
...txData,
tokenType: 'completion',
rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0),
rawAmount: -Math.max(completionTokens, 0),
});
}
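The two variants above differ only in how a zero count is recorded: negating 0 in JavaScript yields -0, which is why the ternary guard existed and why the tests later normalize with Math.abs. A quick illustration:
const a = -Math.max(0, 0);       // -0
console.log(a === 0);            // true  (strict equality treats -0 and 0 as equal)
console.log(Object.is(a, 0));    // false (Object.is distinguishes them)
console.log(Math.abs(a) === 0);  // true  (the normalization used in the tests)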

View File

@@ -1,10 +1,17 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { Transaction } = require('./Transaction');
const Balance = require('./Balance');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
// Mock the logger to prevent console output during tests
jest.mock('./Transaction', () => ({
Transaction: {
create: jest.fn(),
createStructured: jest.fn(),
},
}));
jest.mock('./Balance', () => ({
findOne: jest.fn(),
findOneAndUpdate: jest.fn(),
}));
jest.mock('~/config', () => ({
logger: {
debug: jest.fn(),
@@ -12,46 +19,19 @@ jest.mock('~/config', () => ({
},
}));
// Mock the Config service
const { getBalanceConfig } = require('~/server/services/Config');
jest.mock('~/server/services/Config');
// Import after mocking
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { Transaction } = require('./Transaction');
const Balance = require('./Balance');
describe('spendTokens', () => {
let mongoServer;
let userId;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
// Clear collections before each test
await Transaction.deleteMany({});
await Balance.deleteMany({});
// Create a new user ID for each test
userId = new mongoose.Types.ObjectId();
// Mock the balance config to be enabled by default
getBalanceConfig.mockResolvedValue({ enabled: true });
beforeEach(() => {
jest.clearAllMocks();
process.env.CHECK_BALANCE = 'true';
});
it('should create transactions for both prompt and completion tokens', async () => {
// Create a balance for the user
await Balance.create({
user: userId,
tokenCredits: 10000,
});
const txData = {
user: userId,
user: new mongoose.Types.ObjectId(),
conversationId: 'test-convo',
model: 'gpt-3.5-turbo',
context: 'test',
@@ -61,35 +41,31 @@ describe('spendTokens', () => {
completionTokens: 50,
};
Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 });
Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -50 });
Balance.findOne.mockResolvedValue({ tokenCredits: 10000 });
Balance.findOneAndUpdate.mockResolvedValue({ tokenCredits: 9850 });
await spendTokens(txData, tokenUsage);
// Verify transactions were created
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
expect(transactions).toHaveLength(2);
// Check completion transaction
expect(transactions[0].tokenType).toBe('completion');
expect(transactions[0].rawAmount).toBe(-50);
// Check prompt transaction
expect(transactions[1].tokenType).toBe('prompt');
expect(transactions[1].rawAmount).toBe(-100);
// Verify balance was updated
const balance = await Balance.findOne({ user: userId });
expect(balance).toBeDefined();
expect(balance.tokenCredits).toBeLessThan(10000); // Balance should be reduced
expect(Transaction.create).toHaveBeenCalledTimes(2);
expect(Transaction.create).toHaveBeenCalledWith(
expect.objectContaining({
tokenType: 'prompt',
rawAmount: -100,
}),
);
expect(Transaction.create).toHaveBeenCalledWith(
expect.objectContaining({
tokenType: 'completion',
rawAmount: -50,
}),
);
});
it('should handle zero completion tokens', async () => {
// Create a balance for the user
await Balance.create({
user: userId,
tokenCredits: 10000,
});
const txData = {
user: userId,
user: new mongoose.Types.ObjectId(),
conversationId: 'test-convo',
model: 'gpt-3.5-turbo',
context: 'test',
@@ -99,26 +75,31 @@ describe('spendTokens', () => {
completionTokens: 0,
};
Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 });
Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -0 });
Balance.findOne.mockResolvedValue({ tokenCredits: 10000 });
Balance.findOneAndUpdate.mockResolvedValue({ tokenCredits: 9850 });
await spendTokens(txData, tokenUsage);
// Verify transactions were created
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
expect(transactions).toHaveLength(2);
// Check completion transaction
expect(transactions[0].tokenType).toBe('completion');
// In JavaScript -0 and 0 are different but functionally equivalent
// Use Math.abs to handle both 0 and -0
expect(Math.abs(transactions[0].rawAmount)).toBe(0);
// Check prompt transaction
expect(transactions[1].tokenType).toBe('prompt');
expect(transactions[1].rawAmount).toBe(-100);
expect(Transaction.create).toHaveBeenCalledTimes(2);
expect(Transaction.create).toHaveBeenCalledWith(
expect.objectContaining({
tokenType: 'prompt',
rawAmount: -100,
}),
);
expect(Transaction.create).toHaveBeenCalledWith(
expect.objectContaining({
tokenType: 'completion',
rawAmount: -0, // Changed from 0 to -0
}),
);
});
it('should handle undefined token counts', async () => {
const txData = {
user: userId,
user: new mongoose.Types.ObjectId(),
conversationId: 'test-convo',
model: 'gpt-3.5-turbo',
context: 'test',
@@ -127,22 +108,13 @@ describe('spendTokens', () => {
await spendTokens(txData, tokenUsage);
// Verify no transactions were created
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(0);
expect(Transaction.create).not.toHaveBeenCalled();
});
it('should not update balance when the balance feature is disabled', async () => {
// Override configuration: disable balance updates
getBalanceConfig.mockResolvedValue({ enabled: false });
// Create a balance for the user
await Balance.create({
user: userId,
tokenCredits: 10000,
});
it('should not update balance when CHECK_BALANCE is false', async () => {
process.env.CHECK_BALANCE = 'false';
const txData = {
user: userId,
user: new mongoose.Types.ObjectId(),
conversationId: 'test-convo',
model: 'gpt-3.5-turbo',
context: 'test',
@@ -152,529 +124,19 @@ describe('spendTokens', () => {
completionTokens: 50,
};
await spendTokens(txData, tokenUsage);
// Verify transactions were created
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(2);
// Verify balance was not updated (should still be 10000)
const balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(10000);
});
it('should not allow balance to go below zero when spending tokens', async () => {
// Create a balance with a low amount
await Balance.create({
user: userId,
tokenCredits: 5000,
});
const txData = {
user: userId,
conversationId: 'test-convo',
model: 'gpt-4', // Using a more expensive model
context: 'test',
};
// Spending more tokens than the user has balance for
const tokenUsage = {
promptTokens: 1000,
completionTokens: 500,
};
Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 });
Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -50 });
await spendTokens(txData, tokenUsage);
// Verify transactions were created
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
expect(transactions).toHaveLength(2);
// Verify balance was reduced to exactly 0, not negative
const balance = await Balance.findOne({ user: userId });
expect(balance).toBeDefined();
expect(balance.tokenCredits).toBe(0);
// Check that the transaction records show the adjusted values
const transactionResults = await Promise.all(
transactions.map((t) =>
Transaction.create({
...txData,
tokenType: t.tokenType,
rawAmount: t.rawAmount,
}),
),
);
// The second transaction should have an adjusted value since balance is already 0
expect(transactionResults[1]).toEqual(
expect.objectContaining({
balance: 0,
}),
);
});
it('should handle multiple transactions in sequence with low balance and not increase balance', async () => {
// This test is specifically checking for the issue reported in production
// where the balance increases after a transaction when it should remain at 0
// Create a balance with a very low amount
await Balance.create({
user: userId,
tokenCredits: 100,
});
// First transaction - should reduce balance to 0
const txData1 = {
user: userId,
conversationId: 'test-convo-1',
model: 'gpt-4',
context: 'test',
};
const tokenUsage1 = {
promptTokens: 100,
completionTokens: 50,
};
await spendTokens(txData1, tokenUsage1);
// Check balance after first transaction
let balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(0);
// Second transaction - should keep balance at 0, not make it negative or increase it
const txData2 = {
user: userId,
conversationId: 'test-convo-2',
model: 'gpt-4',
context: 'test',
};
const tokenUsage2 = {
promptTokens: 200,
completionTokens: 100,
};
await spendTokens(txData2, tokenUsage2);
// Check balance after second transaction - should still be 0
balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(0);
// Verify all transactions were created
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(4); // 2 transactions (prompt+completion) for each call
// Let's examine the actual transaction records to see what's happening
const transactionDetails = await Transaction.find({ user: userId }).sort({ createdAt: 1 });
// Log the transaction details for debugging
console.log('Transaction details:');
transactionDetails.forEach((tx, i) => {
console.log(`Transaction ${i + 1}:`, {
tokenType: tx.tokenType,
rawAmount: tx.rawAmount,
tokenValue: tx.tokenValue,
model: tx.model,
});
});
// Check the return values from Transaction.create directly
// This is to verify that the incrementValue is not becoming positive
const directResult = await Transaction.create({
user: userId,
conversationId: 'test-convo-3',
model: 'gpt-4',
tokenType: 'completion',
rawAmount: -100,
context: 'test',
});
console.log('Direct Transaction.create result:', directResult);
// The completion value should never be positive
expect(directResult.completion).not.toBeGreaterThan(0);
});
it('should ensure tokenValue is always negative for spending tokens', async () => {
// Create a balance for the user
await Balance.create({
user: userId,
tokenCredits: 10000,
});
// Test with various models to check multiplier calculations
const models = ['gpt-3.5-turbo', 'gpt-4', 'claude-3-5-sonnet'];
for (const model of models) {
const txData = {
user: userId,
conversationId: `test-convo-${model}`,
model,
context: 'test',
};
const tokenUsage = {
promptTokens: 100,
completionTokens: 50,
};
await spendTokens(txData, tokenUsage);
// Get the transactions for this model
const transactions = await Transaction.find({
user: userId,
model,
});
// Verify tokenValue is negative for all transactions
transactions.forEach((tx) => {
console.log(`Model ${model}, Type ${tx.tokenType}: tokenValue = ${tx.tokenValue}`);
expect(tx.tokenValue).toBeLessThan(0);
});
}
});
it('should handle structured transactions in sequence with low balance', async () => {
// Create a balance with a very low amount
await Balance.create({
user: userId,
tokenCredits: 100,
});
// First transaction - should reduce balance to 0
const txData1 = {
user: userId,
conversationId: 'test-convo-1',
model: 'claude-3-5-sonnet',
context: 'test',
};
const tokenUsage1 = {
promptTokens: {
input: 10,
write: 100,
read: 5,
},
completionTokens: 50,
};
await spendStructuredTokens(txData1, tokenUsage1);
// Check balance after first transaction
let balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(0);
// Second transaction - should keep balance at 0, not make it negative or increase it
const txData2 = {
user: userId,
conversationId: 'test-convo-2',
model: 'claude-3-5-sonnet',
context: 'test',
};
const tokenUsage2 = {
promptTokens: {
input: 20,
write: 200,
read: 10,
},
completionTokens: 100,
};
await spendStructuredTokens(txData2, tokenUsage2);
// Check balance after second transaction - should still be 0
balance = await Balance.findOne({ user: userId });
expect(balance.tokenCredits).toBe(0);
// Verify all transactions were created
const transactions = await Transaction.find({ user: userId });
expect(transactions).toHaveLength(4); // 2 transactions (prompt+completion) for each call
// Let's examine the actual transaction records to see what's happening
const transactionDetails = await Transaction.find({ user: userId }).sort({ createdAt: 1 });
// Log the transaction details for debugging
console.log('Structured transaction details:');
transactionDetails.forEach((tx, i) => {
console.log(`Transaction ${i + 1}:`, {
tokenType: tx.tokenType,
rawAmount: tx.rawAmount,
tokenValue: tx.tokenValue,
inputTokens: tx.inputTokens,
writeTokens: tx.writeTokens,
readTokens: tx.readTokens,
model: tx.model,
});
});
});
it('should not allow balance to go below zero when spending structured tokens', async () => {
// Create a balance with a low amount
await Balance.create({
user: userId,
tokenCredits: 5000,
});
const txData = {
user: userId,
conversationId: 'test-convo',
model: 'claude-3-5-sonnet', // Using a model that supports structured tokens
context: 'test',
};
// Spending more tokens than the user has balance for
const tokenUsage = {
promptTokens: {
input: 100,
write: 1000,
read: 50,
},
completionTokens: 500,
};
const result = await spendStructuredTokens(txData, tokenUsage);
// Verify transactions were created
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
expect(transactions).toHaveLength(2);
// Verify balance was reduced to exactly 0, not negative
const balance = await Balance.findOne({ user: userId });
expect(balance).toBeDefined();
expect(balance.tokenCredits).toBe(0);
// The result should show the adjusted values
expect(result).toEqual({
prompt: expect.objectContaining({
user: userId.toString(),
balance: expect.any(Number),
}),
completion: expect.objectContaining({
user: userId.toString(),
balance: 0, // Final balance should be 0
}),
});
});
it('should handle multiple concurrent transactions correctly with a high balance', async () => {
// Create a balance with a high amount
const initialBalance = 10000000;
await Balance.create({
user: userId,
tokenCredits: initialBalance,
});
// Simulate the recordCollectedUsage function from the production code
const conversationId = 'test-concurrent-convo';
const context = 'message';
const model = 'gpt-4';
const amount = 50;
// Create `amount` of usage records to simulate multiple transactions
const collectedUsage = Array.from({ length: amount }, (_, i) => ({
model,
input_tokens: 100 + i * 10, // Increasing input tokens
output_tokens: 50 + i * 5, // Increasing output tokens
input_token_details: {
cache_creation: i % 2 === 0 ? 20 : 0, // Some have cache creation
cache_read: i % 3 === 0 ? 10 : 0, // Some have cache read
},
}));
// Process all transactions concurrently to simulate race conditions
const promises = [];
let expectedTotalSpend = 0;
for (let i = 0; i < collectedUsage.length; i++) {
const usage = collectedUsage[i];
if (!usage) {
continue;
}
const cache_creation = Number(usage.input_token_details?.cache_creation) || 0;
const cache_read = Number(usage.input_token_details?.cache_read) || 0;
const txMetadata = {
context,
conversationId,
user: userId,
model: usage.model,
};
// Calculate expected spend for this transaction
const promptTokens = usage.input_tokens;
const completionTokens = usage.output_tokens;
// For regular transactions
if (cache_creation === 0 && cache_read === 0) {
// Add to expected spend using the correct multipliers from tx.js
// For gpt-4, the multipliers are: prompt=30, completion=60
expectedTotalSpend += promptTokens * 30; // gpt-4 prompt rate is 30
expectedTotalSpend += completionTokens * 60; // gpt-4 completion rate is 60
promises.push(
spendTokens(txMetadata, {
promptTokens,
completionTokens,
}),
);
} else {
// For structured transactions with cache operations
// The multipliers for claude models with cache operations are different
// But since we're using gpt-4 in the test, we need to use appropriate values
expectedTotalSpend += promptTokens * 30; // Base prompt rate for gpt-4
// Since gpt-4 doesn't have cache multipliers defined, we'll use the prompt rate
expectedTotalSpend += cache_creation * 30; // Write rate (using prompt rate as fallback)
expectedTotalSpend += cache_read * 30; // Read rate (using prompt rate as fallback)
expectedTotalSpend += completionTokens * 60; // Completion rate for gpt-4
promises.push(
spendStructuredTokens(txMetadata, {
promptTokens: {
input: promptTokens,
write: cache_creation,
read: cache_read,
},
completionTokens,
}),
);
}
}
// Wait for all transactions to complete
await Promise.all(promises);
// Verify final balance
const finalBalance = await Balance.findOne({ user: userId });
expect(finalBalance).toBeDefined();
// The final balance should be the initial balance minus the expected total spend
const expectedFinalBalance = initialBalance - expectedTotalSpend;
console.log('Initial balance:', initialBalance);
console.log('Expected total spend:', expectedTotalSpend);
console.log('Expected final balance:', expectedFinalBalance);
console.log('Actual final balance:', finalBalance.tokenCredits);
// Allow for small rounding differences
expect(finalBalance.tokenCredits).toBeCloseTo(expectedFinalBalance, 0);
// Verify all transactions were created
const transactions = await Transaction.find({
user: userId,
conversationId,
});
// We should have 2 transactions (prompt + completion) for each usage record
// Some might be structured, some regular
expect(transactions.length).toBeGreaterThanOrEqual(collectedUsage.length);
// Log transaction details for debugging
console.log('Transaction summary:');
let totalTokenValue = 0;
transactions.forEach((tx) => {
console.log(`${tx.tokenType}: rawAmount=${tx.rawAmount}, tokenValue=${tx.tokenValue}`);
totalTokenValue += tx.tokenValue;
});
console.log('Total token value from transactions:', totalTokenValue);
// The difference between expected and actual is significant
// This is likely due to the multipliers being different in the test environment
// Let's adjust our expectation based on the actual transactions
const actualSpend = initialBalance - finalBalance.tokenCredits;
console.log('Actual spend:', actualSpend);
// Instead of checking the exact balance, let's verify that:
// 1. The balance was reduced (tokens were spent)
expect(finalBalance.tokenCredits).toBeLessThan(initialBalance);
// 2. The total token value from transactions matches the actual spend
expect(Math.abs(totalTokenValue)).toBeCloseTo(actualSpend, -3); // Allow for larger differences
});
// Add this new test case
it('should handle multiple concurrent balance increases correctly', async () => {
// Start with zero balance
const initialBalance = 0;
await Balance.create({
user: userId,
tokenCredits: initialBalance,
});
const numberOfRefills = 25;
const refillAmount = 1000;
const promises = [];
for (let i = 0; i < numberOfRefills; i++) {
promises.push(
Transaction.createAutoRefillTransaction({
user: userId,
tokenType: 'credits',
context: 'concurrent-refill-test',
rawAmount: refillAmount,
}),
);
}
// Wait for all refill transactions to complete
const results = await Promise.all(promises);
// Verify final balance
const finalBalance = await Balance.findOne({ user: userId });
expect(finalBalance).toBeDefined();
// The final balance should be the initial balance plus the sum of all refills
const expectedFinalBalance = initialBalance + numberOfRefills * refillAmount;
console.log('Initial balance (Increase Test):', initialBalance);
console.log(`Performed ${numberOfRefills} refills of ${refillAmount} each.`);
console.log('Expected final balance (Increase Test):', expectedFinalBalance);
console.log('Actual final balance (Increase Test):', finalBalance.tokenCredits);
// Use toBeCloseTo for safety, though toBe should work for integer math
expect(finalBalance.tokenCredits).toBeCloseTo(expectedFinalBalance, 0);
// Verify all transactions were created
const transactions = await Transaction.find({
user: userId,
context: 'concurrent-refill-test',
});
// We should have one transaction for each refill attempt
expect(transactions.length).toBe(numberOfRefills);
// Optional: Verify the sum of increments from the results matches the balance change
const totalIncrementReported = results.reduce((sum, result) => {
// Assuming createAutoRefillTransaction returns an object with the increment amount
// Adjust this based on the actual return structure.
// Let's assume it returns { balance: newBalance, transaction: { rawAmount: ... } }
// Or perhaps we check the transaction.rawAmount directly
return sum + (result?.transaction?.rawAmount || 0);
}, 0);
console.log('Total increment reported by results:', totalIncrementReported);
expect(totalIncrementReported).toBe(expectedFinalBalance - initialBalance);
// Optional: Check the sum of tokenValue from saved transactions
let totalTokenValueFromDb = 0;
transactions.forEach((tx) => {
// For refills, rawAmount is positive, and tokenValue might be calculated based on it
// Let's assume tokenValue directly reflects the increment for simplicity here
// If calculation is involved, adjust accordingly
totalTokenValueFromDb += tx.rawAmount; // Or tx.tokenValue if that holds the increment
});
console.log('Total rawAmount from DB transactions:', totalTokenValueFromDb);
expect(totalTokenValueFromDb).toBeCloseTo(expectedFinalBalance - initialBalance, 0);
expect(Transaction.create).toHaveBeenCalledTimes(2);
expect(Balance.findOne).not.toHaveBeenCalled();
expect(Balance.findOneAndUpdate).not.toHaveBeenCalled();
});
it('should create structured transactions for both prompt and completion tokens', async () => {
// Create a balance for the user
await Balance.create({
user: userId,
tokenCredits: 10000,
});
const txData = {
user: userId,
user: new mongoose.Types.ObjectId(),
conversationId: 'test-convo',
model: 'claude-3-5-sonnet',
context: 'test',
@@ -688,37 +150,48 @@ describe('spendTokens', () => {
completionTokens: 50,
};
const result = await spendStructuredTokens(txData, tokenUsage);
// Verify transactions were created
const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 });
expect(transactions).toHaveLength(2);
// Check completion transaction
expect(transactions[0].tokenType).toBe('completion');
expect(transactions[0].rawAmount).toBe(-50);
// Check prompt transaction
expect(transactions[1].tokenType).toBe('prompt');
expect(transactions[1].inputTokens).toBe(-10);
expect(transactions[1].writeTokens).toBe(-100);
expect(transactions[1].readTokens).toBe(-5);
// Verify result contains transaction info
expect(result).toEqual({
prompt: expect.objectContaining({
user: userId.toString(),
prompt: expect.any(Number),
}),
completion: expect.objectContaining({
user: userId.toString(),
completion: expect.any(Number),
}),
Transaction.createStructured.mockResolvedValueOnce({
rate: 3.75,
user: txData.user.toString(),
balance: 9570,
prompt: -430,
});
Transaction.create.mockResolvedValueOnce({
rate: 15,
user: txData.user.toString(),
balance: 8820,
completion: -750,
});
// Verify balance was updated
const balance = await Balance.findOne({ user: userId });
expect(balance).toBeDefined();
expect(balance.tokenCredits).toBeLessThan(10000); // Balance should be reduced
const result = await spendStructuredTokens(txData, tokenUsage);
expect(Transaction.createStructured).toHaveBeenCalledWith(
expect.objectContaining({
tokenType: 'prompt',
inputTokens: -10,
writeTokens: -100,
readTokens: -5,
}),
);
expect(Transaction.create).toHaveBeenCalledWith(
expect.objectContaining({
tokenType: 'completion',
rawAmount: -50,
}),
);
expect(result).toEqual({
prompt: expect.objectContaining({
rate: 3.75,
user: txData.user.toString(),
balance: 9570,
prompt: -430,
}),
completion: expect.objectContaining({
rate: 15,
user: txData.user.toString(),
balance: 8820,
completion: -750,
}),
});
});
});
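The concurrency tests above depend on balance updates being atomic. A minimal sketch of the assumed pattern (not the repo's exact Transaction internals): a single findOneAndUpdate with $inc is atomic in MongoDB, so concurrent spends and refills never lose updates.
const mongoose = require('mongoose');
const BalanceSketch = mongoose.model(
  'BalanceSketch',
  new mongoose.Schema({ user: mongoose.Schema.Types.ObjectId, tokenCredits: Number }),
);
// Atomic increment: 25 concurrent calls of +1000 land at exactly +25000.
async function increment(user, amount) {
  return BalanceSketch.findOneAndUpdate(
    { user },
    { $inc: { tokenCredits: amount } },
    { upsert: true, new: true },
  ).lean();
}
// await Promise.all(Array.from({ length: 25 }, () => increment(userId, 1000)));
// A read-modify-write (find, add in JS, save) would race here; $inc does not.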

View File

@@ -61,7 +61,6 @@ const bedrockValues = {
'amazon.nova-micro-v1:0': { prompt: 0.035, completion: 0.14 },
'amazon.nova-lite-v1:0': { prompt: 0.06, completion: 0.24 },
'amazon.nova-pro-v1:0': { prompt: 0.8, completion: 3.2 },
'deepseek.r1': { prompt: 1.35, completion: 5.4 },
};
/**
@@ -109,7 +108,6 @@ const tokenValues = Object.assign(
'gemini-2.0-flash-lite': { prompt: 0.075, completion: 0.3 },
'gemini-2.0-flash': { prompt: 0.1, completion: 0.7 },
'gemini-2.0': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing
'gemini-2.5': { prompt: 0, completion: 0 }, // Free for a period of time
'gemini-1.5-flash-8b': { prompt: 0.075, completion: 0.3 },
'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 },
'gemini-1.5': { prompt: 2.5, completion: 10 },
@@ -123,12 +121,6 @@ const tokenValues = Object.assign(
'grok-2-latest': { prompt: 2.0, completion: 10.0 },
'grok-2': { prompt: 2.0, completion: 10.0 },
'grok-beta': { prompt: 5.0, completion: 15.0 },
'mistral-large': { prompt: 2.0, completion: 6.0 },
'pixtral-large': { prompt: 2.0, completion: 6.0 },
'mistral-saba': { prompt: 0.2, completion: 0.6 },
codestral: { prompt: 0.3, completion: 0.9 },
'ministral-8b': { prompt: 0.1, completion: 0.1 },
'ministral-3b': { prompt: 0.04, completion: 0.04 },
},
bedrockValues,
);
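The entries above are prompt/completion multipliers; per the checkBalanceRecord logic earlier in this diff, a transaction's cost is simply amount * multiplier. The lookup helper below is illustrative only (the repo resolves keys through getMultiplier, which also handles fallback value keys):
const rates = {
  'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 },
  'grok-2': { prompt: 2.0, completion: 10.0 },
};
function cost(model, tokenType, amount) {
  const multiplier = rates[model]?.[tokenType] ?? 1;
  return amount * multiplier;
}
console.log(cost('grok-2', 'completion', 50)); // 500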

View File

@@ -288,7 +288,7 @@ describe('AWS Bedrock Model Tests', () => {
});
describe('Deepseek Model Tests', () => {
const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner', 'deepseek.r1'];
const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner'];
it('should return the correct prompt multipliers for all models', () => {
const results = deepseekModels.map((model) => {

View File

@@ -1,6 +1,6 @@
const bcrypt = require('bcryptjs');
const { getBalanceConfig } = require('~/server/services/Config');
const signPayload = require('~/server/services/signPayload');
const { isEnabled } = require('~/server/utils/handleText');
const Balance = require('./Balance');
const User = require('./User');
@@ -13,9 +13,11 @@ const User = require('./User');
*/
const getUserById = async function (userId, fieldsToSelect = null) {
const query = User.findById(userId);
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
return await query.lean();
};
@@ -30,6 +32,7 @@ const findUser = async function (searchCriteria, fieldsToSelect = null) {
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
return await query.lean();
};
@@ -55,12 +58,11 @@ const updateUser = async function (userId, updateData) {
* Creates a new user, optionally with a TTL of 1 week.
* @param {MongoUser} data - The user data to be created, must contain user_id.
* @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`.
* @param {boolean} [returnUser=false] - Whether to return the created user object.
* @returns {Promise<ObjectId|MongoUser>} A promise that resolves to the created user document ID or user object.
* @param {boolean} [returnUser=false] - Whether to disable the TTL. Defaults to `true`.
* @returns {Promise<ObjectId>} A promise that resolves to the created user document ID.
* @throws {Error} If a user with the same user_id already exists.
*/
const createUser = async (data, disableTTL = true, returnUser = false) => {
const balance = await getBalanceConfig();
const userData = {
...data,
expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds
@@ -72,27 +74,13 @@ const createUser = async (data, disableTTL = true, returnUser = false) => {
const user = await User.create(userData);
// If balance is enabled, create or update a balance record for the user using global.interfaceConfig.balance
if (balance?.enabled && balance?.startBalance) {
const update = {
$inc: { tokenCredits: balance.startBalance },
};
if (
balance.autoRefillEnabled &&
balance.refillIntervalValue != null &&
balance.refillIntervalUnit != null &&
balance.refillAmount != null
) {
update.$set = {
autoRefillEnabled: true,
refillIntervalValue: balance.refillIntervalValue,
refillIntervalUnit: balance.refillIntervalUnit,
refillAmount: balance.refillAmount,
};
}
await Balance.findOneAndUpdate({ user: user._id }, update, { upsert: true, new: true }).lean();
if (isEnabled(process.env.CHECK_BALANCE) && process.env.START_BALANCE) {
let incrementValue = parseInt(process.env.START_BALANCE);
await Balance.findOneAndUpdate(
{ user: user._id },
{ $inc: { tokenCredits: incrementValue } },
{ upsert: true, new: true },
).lean();
}
if (returnUser) {
@@ -135,7 +123,7 @@ const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15;
/**
* Generates a JWT token for a given user.
*
* @param {MongoUser} user - The user for whom the token is being generated.
* @param {MongoUser} user - ID of the user for whom the token is being generated.
* @returns {Promise<string>} A promise that resolves to a JWT token.
*/
const generateToken = async (user) => {
@@ -158,7 +146,7 @@ const generateToken = async (user) => {
/**
* Compares the provided password with the user's password.
*
* @param {MongoUser} user - The user to compare the password for.
* @param {MongoUser} user - the user to compare password for.
* @param {string} candidatePassword - The password to test against the user's password.
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the password matches.
*/
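Both branches of the createUser hunk seed the starting balance with an upsert plus $inc, so a first call creates the record and later calls increment it rather than overwrite it. An in-memory stand-in for that semantics (values illustrative):
const balances = new Map();
function seedBalance(userId, startBalance) {
  balances.set(userId, (balances.get(userId) ?? 0) + startBalance);
  return { user: userId, tokenCredits: balances.get(userId) };
}
console.log(seedBalance('u1', 20000)); // { user: 'u1', tokenCredits: 20000 }
console.log(seedBalance('u1', 20000)); // { user: 'u1', tokenCredits: 40000 }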

View File

@@ -35,21 +35,17 @@
"homepage": "https://librechat.ai",
"dependencies": {
"@anthropic-ai/sdk": "^0.37.0",
"@aws-sdk/client-s3": "^3.758.0",
"@aws-sdk/s3-request-presigner": "^3.758.0",
"@azure/identity": "^4.7.0",
"@azure/search-documents": "^12.0.0",
"@azure/storage-blob": "^12.26.0",
"@google/generative-ai": "^0.23.0",
"@googleapis/youtube": "^20.0.0",
"@keyv/mongo": "^2.1.8",
"@keyv/redis": "^2.8.1",
"@langchain/community": "^0.3.34",
"@langchain/core": "^0.3.40",
"@langchain/google-genai": "^0.1.11",
"@langchain/google-vertexai": "^0.2.2",
"@langchain/google-genai": "^0.1.9",
"@langchain/google-vertexai": "^0.2.0",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.3.95",
"@librechat/agents": "^2.2.0",
"@librechat/data-schemas": "*",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.8.2",
@@ -86,7 +82,7 @@
"memorystore": "^1.6.7",
"mime": "^3.0.0",
"module-alias": "^2.2.3",
"mongoose": "^8.12.1",
"mongoose": "^8.9.5",
"multer": "^1.4.5-lts.1",
"nanoid": "^3.3.7",
"nodemailer": "^6.9.15",
@@ -103,16 +99,13 @@
"passport-jwt": "^4.0.1",
"passport-ldapauth": "^3.0.1",
"passport-local": "^1.0.0",
"rate-limit-redis": "^4.2.0",
"sharp": "^0.33.5",
"socket.io": "^4.8.1",
"sharp": "^0.32.6",
"tiktoken": "^1.0.15",
"traverse": "^0.6.7",
"ua-parser-js": "^1.0.36",
"winston": "^3.11.0",
"winston-daily-rotate-file": "^4.7.1",
"youtube-transcript": "^1.2.1",
"wrtc": "^0.4.7",
"zod": "^3.22.4"
},
"devDependencies": {

View File

@@ -1,30 +1,22 @@
const {
generateTOTPSecret,
generateBackupCodes,
verifyTOTP,
verifyBackupCode,
generateTOTPSecret,
generateBackupCodes,
getTOTPSecret,
} = require('~/server/services/twoFactorService');
const { updateUser, getUserById } = require('~/models');
const { logger } = require('~/config');
const { encryptV3 } = require('~/server/utils/crypto');
const { encryptV2 } = require('~/server/utils/crypto');
const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
/**
* Enable 2FA for the user by generating a new TOTP secret and backup codes.
* The secret is encrypted and stored, and 2FA is marked as disabled until confirmed.
*/
const enable2FA = async (req, res) => {
const enable2FAController = async (req, res) => {
const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
try {
const userId = req.user.id;
const secret = generateTOTPSecret();
const { plainCodes, codeObjects } = await generateBackupCodes();
// Encrypt the secret with v3 encryption before saving.
const encryptedSecret = encryptV3(secret);
// Update the user record: store the secret & backup codes and set twoFactorEnabled to false.
const encryptedSecret = await encryptV2(secret);
// Set twoFactorEnabled to false until the user confirms 2FA.
const user = await updateUser(userId, {
totpSecret: encryptedSecret,
backupCodes: codeObjects,
@@ -32,50 +24,45 @@ const enable2FA = async (req, res) => {
});
const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`;
return res.status(200).json({ otpauthUrl, backupCodes: plainCodes });
res.status(200).json({
otpauthUrl,
backupCodes: plainCodes,
});
} catch (err) {
logger.error('[enable2FA]', err);
return res.status(500).json({ message: err.message });
logger.error('[enable2FAController]', err);
res.status(500).json({ message: err.message });
}
};
/**
* Verify a 2FA code (either TOTP or backup code) during setup.
*/
const verify2FA = async (req, res) => {
const verify2FAController = async (req, res) => {
try {
const userId = req.user.id;
const { token, backupCode } = req.body;
const user = await getUserById(userId);
// Ensure that 2FA is enabled for this user.
if (!user || !user.totpSecret) {
return res.status(400).json({ message: '2FA not initiated' });
}
// Retrieve the plain TOTP secret using getTOTPSecret.
const secret = await getTOTPSecret(user.totpSecret);
let isVerified = false;
if (token) {
isVerified = await verifyTOTP(secret, token);
} else if (backupCode) {
isVerified = await verifyBackupCode({ user, backupCode });
}
if (isVerified) {
if (token && (await verifyTOTP(secret, token))) {
return res.status(200).json();
} else if (backupCode) {
const verified = await verifyBackupCode({ user, backupCode });
if (verified) {
return res.status(200).json();
}
}
return res.status(400).json({ message: 'Invalid token or backup code.' });
return res.status(400).json({ message: 'Invalid token.' });
} catch (err) {
logger.error('[verify2FA]', err);
return res.status(500).json({ message: err.message });
logger.error('[verify2FAController]', err);
res.status(500).json({ message: err.message });
}
};
/**
* Confirm and enable 2FA after a successful verification.
*/
const confirm2FA = async (req, res) => {
const confirm2FAController = async (req, res) => {
try {
const userId = req.user.id;
const { token } = req.body;
@@ -85,54 +72,52 @@ const confirm2FA = async (req, res) => {
return res.status(400).json({ message: '2FA not initiated' });
}
// Retrieve the plain TOTP secret using getTOTPSecret.
const secret = await getTOTPSecret(user.totpSecret);
if (await verifyTOTP(secret, token)) {
// Upon successful verification, enable 2FA.
await updateUser(userId, { twoFactorEnabled: true });
return res.status(200).json();
}
return res.status(400).json({ message: 'Invalid token.' });
} catch (err) {
logger.error('[confirm2FA]', err);
return res.status(500).json({ message: err.message });
logger.error('[confirm2FAController]', err);
res.status(500).json({ message: err.message });
}
};
/**
* Disable 2FA by clearing the stored secret and backup codes.
*/
const disable2FA = async (req, res) => {
const disable2FAController = async (req, res) => {
try {
const userId = req.user.id;
await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false });
return res.status(200).json();
res.status(200).json();
} catch (err) {
logger.error('[disable2FA]', err);
return res.status(500).json({ message: err.message });
logger.error('[disable2FAController]', err);
res.status(500).json({ message: err.message });
}
};
/**
* Regenerate backup codes for the user.
*/
const regenerateBackupCodes = async (req, res) => {
const regenerateBackupCodesController = async (req, res) => {
try {
const userId = req.user.id;
const { plainCodes, codeObjects } = await generateBackupCodes();
await updateUser(userId, { backupCodes: codeObjects });
return res.status(200).json({
res.status(200).json({
backupCodes: plainCodes,
backupCodesHash: codeObjects,
});
} catch (err) {
logger.error('[regenerateBackupCodes]', err);
return res.status(500).json({ message: err.message });
logger.error('[regenerateBackupCodesController]', err);
res.status(500).json({ message: err.message });
}
};
module.exports = {
enable2FA,
verify2FA,
confirm2FA,
disable2FA,
regenerateBackupCodes,
enable2FAController,
verify2FAController,
confirm2FAController,
disable2FAController,
regenerateBackupCodesController,
};
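The otpauth URL assembled in enable2FA follows the standard Key URI format. Its shape for sample values (the secret shown is an illustrative base32 string, not a real one):
const safeAppTitle = 'My LibreChat'.replace(/\s+/g, ''); // from APP_TITLE
const email = 'user@example.com';
const secret = 'JBSWY3DPEHPK3PXP'; // illustrative base32 TOTP secret
const otpauthUrl = `otpauth://totp/${safeAppTitle}:${email}?secret=${secret}&issuer=${safeAppTitle}`;
console.log(otpauthUrl);
// otpauth://totp/MyLibreChat:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=MyLibreChat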

View File

@@ -1,8 +1,6 @@
const { FileSources } = require('librechat-data-provider');
const {
Balance,
getFiles,
updateUser,
deleteFiles,
deleteConvos,
deletePresets,
@@ -14,7 +12,6 @@ const User = require('~/models/User');
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
const { processDeleteRequest } = require('~/server/services/Files/process');
const { deleteAllSharedLinks } = require('~/models/Share');
const { deleteToolCalls } = require('~/models/ToolCall');
@@ -22,23 +19,8 @@ const { Transaction } = require('~/models/Transaction');
const { logger } = require('~/config');
const getUserController = async (req, res) => {
/** @type {MongoUser} */
const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
delete userData.totpSecret;
if (req.app.locals.fileStrategy === FileSources.s3 && userData.avatar) {
const avatarNeedsRefresh = needsRefresh(userData.avatar, 3600);
if (!avatarNeedsRefresh) {
return res.status(200).send(userData);
}
const originalAvatar = userData.avatar;
try {
userData.avatar = await getNewS3URL(userData.avatar);
await updateUser(userData.id, { avatar: userData.avatar });
} catch (error) {
userData.avatar = originalAvatar;
logger.error('Error getting new S3 URL for avatar:', error);
}
}
res.status(200).send(userData);
};
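The removed branch refreshes presigned S3 avatar URLs only when they are close to expiry. A sketch of a TTL check of the kind needsRefresh(userData.avatar, 3600) performs; the real helper in Files/S3/crud parses the URL's signing parameters and may differ in detail:
// Hypothetical: refresh when fewer than thresholdSeconds of validity remain.
function needsRefreshSketch(expiresAtMs, thresholdSeconds) {
  const remainingSeconds = (expiresAtMs - Date.now()) / 1000;
  return remainingSeconds < thresholdSeconds;
}
// URL expiring in 10 minutes, threshold one hour -> refresh now.
console.log(needsRefreshSketch(Date.now() + 10 * 60 * 1000, 3600)); // true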

View File

@@ -10,8 +10,8 @@ const {
ChatModelStreamHandler,
} = require('@librechat/agents');
const { processCodeOutput } = require('~/server/services/Files/Code/process');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { saveBase64Image } = require('~/server/services/Files/process');
const { loadAuthValues } = require('~/app/clients/tools/util');
const { logger, sendEvent } = require('~/config');
/** @typedef {import('@librechat/agents').Graph} Graph */

View File

@@ -7,16 +7,7 @@
// validateVisionModel,
// mapModelToAzureConfig,
// } = require('librechat-data-provider');
require('events').EventEmitter.defaultMaxListeners = 100;
const {
Callback,
GraphEvents,
formatMessage,
formatAgentMessages,
formatContentStrings,
getTokenCountForMessage,
createMetadataAggregator,
} = require('@librechat/agents');
const { Callback, createMetadataAggregator } = require('@librechat/agents');
const {
Constants,
VisionModes,
@@ -26,19 +17,24 @@ const {
KnownEndpoints,
anthropicSchema,
isAgentsEndpoint,
AgentCapabilities,
bedrockInputSchema,
removeNullishValues,
} = require('librechat-data-provider');
const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const {
formatMessage,
addCacheControl,
formatAgentMessages,
formatContentStrings,
createContextHandlers,
} = require('~/app/clients/prompts');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const Tokenizer = require('~/server/services/Tokenizer');
const BaseClient = require('~/app/clients/BaseClient');
const { logger, sendEvent } = require('~/config');
const { createRun } = require('./run');
const { logger } = require('~/config');
/** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */
/** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */
@@ -103,8 +99,6 @@ class AgentClient extends BaseClient {
this.outputTokensKey = 'output_tokens';
/** @type {UsageMetadata} */
this.usage;
/** @type {Record<string, number>} */
this.indexTokenCountMap = {};
}
/**
@@ -229,23 +223,14 @@ class AgentClient extends BaseClient {
};
}
/**
*
* @param {TMessage} message
* @param {Array<MongoFile>} attachments
* @returns {Promise<Array<Partial<MongoFile>>>}
*/
async addImageURLs(message, attachments) {
const { files, text, image_urls } = await encodeAndFormat(
const { files, image_urls } = await encodeAndFormat(
this.options.req,
attachments,
this.options.agent.provider,
VisionModes.agents,
);
message.image_urls = image_urls.length ? image_urls : undefined;
if (text && text.length) {
message.ocr = text;
}
return files;
}
@@ -323,21 +308,7 @@ class AgentClient extends BaseClient {
assistantName: this.options?.modelLabel,
});
if (message.ocr && i !== orderedMessages.length - 1) {
if (typeof formattedMessage.content === 'string') {
formattedMessage.content = message.ocr + '\n' + formattedMessage.content;
} else {
const textPart = formattedMessage.content.find((part) => part.type === 'text');
textPart
? (textPart.text = message.ocr + '\n' + textPart.text)
: formattedMessage.content.unshift({ type: 'text', text: message.ocr });
}
} else if (message.ocr && i === orderedMessages.length - 1) {
systemContent = [systemContent, message.ocr].join('\n');
}
const needsTokenCount =
(this.contextStrategy && !orderedMessages[i].tokenCount) || message.ocr;
const needsTokenCount = this.contextStrategy && !orderedMessages[i].tokenCount;
/* If tokens were never counted, or, is a Vision request and the message has files, count again */
if (needsTokenCount || (this.isVisionModel && (message.image_urls || message.files))) {
@@ -383,10 +354,6 @@ class AgentClient extends BaseClient {
}));
}
for (let i = 0; i < messages.length; i++) {
this.indexTokenCountMap[i] = messages[i].tokenCount;
}
const result = {
tokenCountMap,
prompt: payload,
@@ -471,7 +438,6 @@ class AgentClient extends BaseClient {
err,
);
});
continue;
}
spendTokens(txMetadata, {
promptTokens: usage.input_tokens,
@@ -633,9 +599,6 @@ class AgentClient extends BaseClient {
// });
// }
/** @type {TCustomConfig['endpoints']['agents']} */
const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
/** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */
const config = {
configurable: {
@@ -643,30 +606,19 @@ class AgentClient extends BaseClient {
last_agent_index: this.agentConfigs?.size ?? 0,
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
},
recursionLimit: agentsEConfig?.recursionLimit,
recursionLimit: this.options.req.app.locals[EModelEndpoint.agents]?.recursionLimit,
signal: abortController.signal,
streamMode: 'values',
version: 'v2',
};
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
payload,
this.indexTokenCountMap,
toolSet,
);
const initialMessages = formatAgentMessages(payload);
if (legacyContentEndpoints.has(this.options.agent.endpoint)) {
initialMessages = formatContentStrings(initialMessages);
formatContentStrings(initialMessages);
}
/** @type {ReturnType<createRun>} */
let run;
const countTokens = ((text) => this.getTokenCount(text)).bind(this);
/** @type {(message: BaseMessage) => number} */
const tokenCounter = (message) => {
return getTokenCountForMessage(message, countTokens);
};
/**
*
@@ -674,23 +626,12 @@ class AgentClient extends BaseClient {
* @param {BaseMessage[]} messages
* @param {number} [i]
* @param {TMessageContentParts[]} [contentData]
* @param {Record<string, number>} [currentIndexCountMap]
*/
const runAgent = async (agent, _messages, i = 0, contentData = [], _currentIndexCountMap) => {
const runAgent = async (agent, _messages, i = 0, contentData = []) => {
config.configurable.model = agent.model_parameters.model;
const currentIndexCountMap = _currentIndexCountMap ?? indexTokenCountMap;
if (i > 0) {
this.model = agent.model_parameters.model;
}
if (agent.recursion_limit && typeof agent.recursion_limit === 'number') {
config.recursionLimit = agent.recursion_limit;
}
if (
agentsEConfig?.maxRecursionLimit &&
config.recursionLimit > agentsEConfig?.maxRecursionLimit
) {
config.recursionLimit = agentsEConfig?.maxRecursionLimit;
}
config.configurable.agent_id = agent.id;
config.configurable.name = agent.name;
config.configurable.agent_index = i;
@@ -753,29 +694,11 @@ class AgentClient extends BaseClient {
}
if (contentData.length) {
const agentUpdate = {
type: ContentTypes.AGENT_UPDATE,
[ContentTypes.AGENT_UPDATE]: {
index: contentData.length,
runId: this.responseMessageId,
agentId: agent.id,
},
};
const streamData = {
event: GraphEvents.ON_AGENT_UPDATE,
data: agentUpdate,
};
this.options.aggregateContent(streamData);
sendEvent(this.options.res, streamData);
contentData.push(agentUpdate);
run.Graph.contentData = contentData;
}
await run.processStream({ messages }, config, {
keepContent: i !== 0,
tokenCounter,
indexTokenCountMap: currentIndexCountMap,
maxContextTokens: agent.maxContextTokens,
callbacks: {
[Callback.TOOL_ERROR]: (graph, error, toolId) => {
logger.error(
@@ -789,13 +712,9 @@ class AgentClient extends BaseClient {
};
await runAgent(this.options.agent, initialMessages);
let finalContentStart = 0;
if (
this.agentConfigs &&
this.agentConfigs.size > 0 &&
(await checkCapability(this.options.req, AgentCapabilities.chain))
) {
const windowSize = 5;
if (this.agentConfigs && this.agentConfigs.size > 0) {
let latestMessage = initialMessages.pop().content;
if (typeof latestMessage !== 'string') {
latestMessage = latestMessage[0].text;
@@ -803,16 +722,7 @@ class AgentClient extends BaseClient {
let i = 1;
let runMessages = [];
const windowIndexCountMap = {};
const windowMessages = initialMessages.slice(-windowSize);
let currentIndex = 4;
for (let i = initialMessages.length - 1; i >= 0; i--) {
windowIndexCountMap[currentIndex] = indexTokenCountMap[i];
currentIndex--;
if (currentIndex < 0) {
break;
}
}
const lastFiveMessages = initialMessages.slice(-5);
for (const [agentId, agent] of this.agentConfigs) {
if (abortController.signal.aborted === true) {
break;
@@ -847,9 +757,7 @@ class AgentClient extends BaseClient {
}
try {
const contextMessages = [];
const runIndexCountMap = {};
for (let i = 0; i < windowMessages.length; i++) {
const message = windowMessages[i];
for (const message of lastFiveMessages) {
const messageType = message._getType();
if (
(!agent.tools || agent.tools.length === 0) &&
@@ -857,13 +765,11 @@ class AgentClient extends BaseClient {
) {
continue;
}
runIndexCountMap[contextMessages.length] = windowIndexCountMap[i];
contextMessages.push(message);
}
const bufferMessage = new HumanMessage(bufferString);
runIndexCountMap[contextMessages.length] = tokenCounter(bufferMessage);
const currentMessages = [...contextMessages, bufferMessage];
await runAgent(agent, currentMessages, i, contentData, runIndexCountMap);
const currentMessages = [...contextMessages, new HumanMessage(bufferString)];
await runAgent(agent, currentMessages, i, contentData);
} catch (err) {
logger.error(
`[api/server/controllers/agents/client.js #chatCompletion] Error running agent ${agentId} (${i})`,
@@ -874,7 +780,6 @@ class AgentClient extends BaseClient {
}
}
/** Note: not implemented */
if (config.configurable.hide_sequential_outputs !== true) {
finalContentStart = 0;
}
@@ -932,14 +837,7 @@ class AgentClient extends BaseClient {
};
let endpointConfig = this.options.req.app.locals[this.options.agent.endpoint];
if (!endpointConfig) {
try {
endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint);
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
err,
);
}
endpointConfig = await getCustomEndpointConfig(this.options.agent.endpoint);
}
if (
endpointConfig &&
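The recursion-limit hunk earlier in this file lets a per-agent recursion_limit override the endpoint default while maxRecursionLimit caps the result. A condensed sketch of that resolution (config shapes are illustrative):
function resolveRecursionLimit(agent, agentsEConfig) {
  let limit = agentsEConfig?.recursionLimit;
  if (agent.recursion_limit && typeof agent.recursion_limit === 'number') {
    limit = agent.recursion_limit;
  }
  if (agentsEConfig?.maxRecursionLimit && limit > agentsEConfig.maxRecursionLimit) {
    limit = agentsEConfig.maxRecursionLimit;
  }
  return limit;
}
// Per-agent 100 overrides the default 25, then the cap of 50 applies:
console.log(resolveRecursionLimit({ recursion_limit: 100 }, { recursionLimit: 25, maxRecursionLimit: 50 })); // 50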

View File

@@ -43,12 +43,6 @@ async function createRun({
agent.model_parameters,
);
/** Resolves Mistral type strictness due to new OpenAI usage field */
if (agent.endpoint?.toLowerCase().includes(KnownEndpoints.mistral)) {
llmConfig.streamUsage = false;
llmConfig.usage = true;
}
/** @type {'reasoning_content' | 'reasoning'} */
let reasoningKey;
if (
@@ -57,6 +51,10 @@ async function createRun({
) {
reasoningKey = 'reasoning';
}
if (/o1(?!-(?:mini|preview)).*$/.test(llmConfig.model)) {
llmConfig.streaming = false;
llmConfig.disableStreaming = true;
}
/** @type {StandardGraphConfig} */
const graphConfig = {
@@ -70,7 +68,7 @@ async function createRun({
};
// TEMPORARY FOR TESTING
if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
if (agent.provider === Providers.ANTHROPIC) {
graphConfig.streamBuffer = 2000;
}
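The negative lookahead added above disables streaming for full o1 models while leaving o1-mini and o1-preview untouched:
const pattern = /o1(?!-(?:mini|preview)).*$/;
console.log(pattern.test('o1'));            // true  -> streaming disabled
console.log(pattern.test('o1-2024-12-17')); // true  -> streaming disabled
console.log(pattern.test('o1-mini'));       // false -> streaming kept
console.log(pattern.test('o1-preview'));    // false -> streaming kept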

View File

@@ -1,12 +1,10 @@
const fs = require('fs').promises;
const { nanoid } = require('nanoid');
const {
Tools,
Constants,
FileContext,
FileSources,
Constants,
Tools,
SystemRoles,
EToolResources,
actionDelimiter,
} = require('librechat-data-provider');
const {
@@ -18,10 +16,9 @@ const {
} = require('~/models/Agent');
const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
const { updateAction, getActions } = require('~/models/Action');
const { updateAgentProjects } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
const { updateAgentProjects } = require('~/models/Agent');
const { deleteFileByFilter } = require('~/models/File');
const { logger } = require('~/config');
@@ -104,14 +101,6 @@ const getAgentHandler = async (req, res) => {
return res.status(404).json({ error: 'Agent not found' });
}
if (agent.avatar && agent.avatar?.source === FileSources.s3) {
const originalUrl = agent.avatar.filepath;
agent.avatar.filepath = await refreshS3Url(agent.avatar);
if (originalUrl !== agent.avatar.filepath) {
await updateAgent({ id }, { avatar: agent.avatar });
}
}
agent.author = agent.author.toString();
agent.isCollaborative = !!agent.isCollaborative;
@@ -214,25 +203,13 @@ const duplicateAgentHandler = async (req, res) => {
}
const {
id: _id,
_id: __id,
id: _id,
author: _author,
createdAt: _createdAt,
updatedAt: _updatedAt,
tool_resources: _tool_resources = {},
...cloneData
} = agent;
cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
dateStyle: 'short',
timeStyle: 'short',
hour12: false,
})})`;
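// Example rendering of the name suffix above (en-US short date/time, 24-hour):
// new Date('2025-03-10T12:55:00').toLocaleString('en-US', {
//   dateStyle: 'short', timeStyle: 'short', hour12: false,
// }) === '3/10/25, 12:55'
// so a duplicated agent is named e.g. `My Agent (3/10/25, 12:55)`.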
if (_tool_resources?.[EToolResources.ocr]) {
cloneData.tool_resources = {
[EToolResources.ocr]: _tool_resources[EToolResources.ocr],
};
}
const newAgentId = `agent_${nanoid()}`;
const newAgentData = Object.assign(cloneData, {

View File

@@ -19,7 +19,7 @@ const {
addThreadMetadata,
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
@@ -27,7 +27,7 @@ const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { createRunBody } = require('~/server/services/createRunBody');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const checkBalance = require('~/models/checkBalance');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { getModelMaxTokens } = require('~/utils');
@@ -248,8 +248,7 @@ const chatV1 = async (req, res) => {
}
const checkBalanceBeforeRun = async () => {
const balance = req.app?.locals?.balance;
if (!balance?.enabled) {
if (!isEnabled(process.env.CHECK_BALANCE)) {
return;
}
const transactions =

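`isEnabled` here is the project's env-flag helper; a minimal sketch of its assumed semantics as used in this gate (not the project's actual source). The same env-driven gate replaces the config-driven one in chatV2 in the next file.

// treats the string 'true' (trimmed, case-insensitive) as on; everything else as off
const isEnabled = (value) => typeof value === 'string' && value.trim().toLowerCase() === 'true';

isEnabled(process.env.CHECK_BALANCE); // false when CHECK_BALANCE is unset or 'false'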
View File

@@ -18,14 +18,14 @@ const {
saveAssistantMessage,
} = require('~/server/services/Threads');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
const { createErrorHandler } = require('~/server/controllers/assistants/errors');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { sendMessage, sleep, countTokens } = require('~/server/utils');
const { createRunBody } = require('~/server/services/createRunBody');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const checkBalance = require('~/models/checkBalance');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { getModelMaxTokens } = require('~/utils');
@@ -124,8 +124,7 @@ const chatV2 = async (req, res) => {
}
const checkBalanceBeforeRun = async () => {
const balance = req.app?.locals?.balance;
if (!balance?.enabled) {
if (!isEnabled(process.env.CHECK_BALANCE)) {
return;
}
const transactions =

View File

@@ -8,10 +8,7 @@ const { setAuthTokens } = require('~/server/services/AuthService');
const { getUserById } = require('~/models/userMethods');
const { logger } = require('~/config');
/**
* Verifies the 2FA code during login using a temporary token.
*/
const verify2FAWithTempToken = async (req, res) => {
const verify2FA = async (req, res) => {
try {
const { tempToken, token, backupCode } = req.body;
if (!tempToken) {
@@ -26,23 +23,26 @@ const verify2FAWithTempToken = async (req, res) => {
}
const user = await getUserById(payload.userId);
// Ensure that the user exists and has 2FA enabled
if (!user || !user.twoFactorEnabled) {
return res.status(400).json({ message: '2FA is not enabled for this user' });
}
// Retrieve (and decrypt if necessary) the TOTP secret.
const secret = await getTOTPSecret(user.totpSecret);
let isVerified = false;
if (token) {
isVerified = await verifyTOTP(secret, token);
let verified = false;
if (token && (await verifyTOTP(secret, token))) {
verified = true;
} else if (backupCode) {
isVerified = await verifyBackupCode({ user, backupCode });
verified = await verifyBackupCode({ user, backupCode });
}
if (!isVerified) {
if (!verified) {
return res.status(401).json({ message: 'Invalid 2FA code or backup code' });
}
// Prepare user data to return (omit sensitive fields).
// Prepare user data for response.
const userData = user.toObject ? user.toObject() : { ...user };
delete userData.password;
delete userData.__v;
@@ -52,9 +52,9 @@ const verify2FAWithTempToken = async (req, res) => {
const authToken = await setAuthTokens(user._id, res);
return res.status(200).json({ token: authToken, user: userData });
} catch (err) {
logger.error('[verify2FAWithTempToken]', err);
logger.error('[verify2FA]', err);
return res.status(500).json({ message: 'Something went wrong' });
}
};
module.exports = { verify2FAWithTempToken };
module.exports = { verify2FA };
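
For reference, a hypothetical client call against the temp-token route this controller backs (`/api/auth/2fa/verify-temp`, wired up in the auth routes further down); the field names follow the destructuring above:

await fetch('/api/auth/2fa/verify-temp', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  // either a TOTP code or a backup code accompanies the temporary token
  body: JSON.stringify({ tempToken, token: '123456' }),
});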

View File

@@ -10,8 +10,7 @@ const {
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { processCodeOutput } = require('~/server/services/Files/Code/process');
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { loadTools } = require('~/app/clients/tools/util');
const { loadAuthValues, loadTools } = require('~/app/clients/tools/util');
const { checkAccess } = require('~/server/middleware');
const { getMessage } = require('~/models/Message');
const { logger } = require('~/config');

View File

@@ -4,7 +4,6 @@ require('module-alias')({ base: path.resolve(__dirname, '..') });
const cors = require('cors');
const axios = require('axios');
const express = require('express');
const { createServer } = require('http');
const compression = require('compression');
const passport = require('passport');
const mongoSanitize = require('express-mongo-sanitize');
@@ -15,8 +14,6 @@ const { connectDb, indexSync } = require('~/lib/db');
const { isEnabled } = require('~/server/utils');
const { ldapLogin } = require('~/strategies');
const { logger } = require('~/config');
const { AudioSocketModule } = require('./services/Files/Audio/AudioSocketModule');
const { SocketIOService } = require('./services/WebSocket/WebSocketServer');
const validateImageRequest = require('./middleware/validateImageRequest');
const errorController = require('./controllers/ErrorController');
const configureSocialLogins = require('./socialLogins');
@@ -31,9 +28,6 @@ const port = Number(PORT) || 3080;
const host = HOST || 'localhost';
const trusted_proxy = Number(TRUST_PROXY) || 1; /* trust first proxy by default */
let socketIOService;
let audioModule;
const startServer = async () => {
if (typeof Bun !== 'undefined') {
axios.defaults.headers.common['Accept-Encoding'] = 'gzip';
@@ -43,21 +37,7 @@ const startServer = async () => {
await indexSync();
const app = express();
const server = createServer(app);
app.disable('x-powered-by');
app.use(
cors({
origin: true,
credentials: true,
}),
);
socketIOService = new SocketIOService(server);
audioModule = new AudioSocketModule(socketIOService);
logger.info('WebSocket server and Audio module initialized');
await AppService(app);
const indexPath = path.join(app.locals.paths.dist, 'index.html');
@@ -130,7 +110,6 @@ const startServer = async () => {
app.use('/api/agents', routes.agents);
app.use('/api/banner', routes.banner);
app.use('/api/bedrock', routes.bedrock);
app.use('/api/websocket', routes.websocket);
app.use('/api/tags', routes.tags);
@@ -148,7 +127,7 @@ const startServer = async () => {
res.send(updatedIndexHtml);
});
server.listen(port, host, () => {
app.listen(port, host, () => {
if (host == '0.0.0.0') {
logger.info(
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
@@ -156,26 +135,11 @@ const startServer = async () => {
} else {
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
}
logger.info(`Socket.IO endpoint: http://${host}:${port}`);
});
};
startServer();
process.on('SIGINT', () => {
logger.info('Shutting down server...');
if (audioModule) {
audioModule.cleanup();
logger.info('Audio module cleaned up');
}
if (socketIOService) {
socketIOService.shutdown();
logger.info('WebSocket server shut down');
}
process.exit(0);
});
let messageCount = 0;
process.on('uncaughtException', (err) => {
if (!err.message.includes('fetch failed')) {

View File

@@ -148,13 +148,6 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
return { abortController, onStart };
};
/**
* @param {ServerResponse} res
* @param {ServerRequest} req
* @param {Error | unknown} error
* @param {Partial<TMessage> & { partialText?: string }} data
* @returns { Promise<void> }
*/
const handleAbortError = async (res, req, error, data) => {
if (error?.message?.includes('base64')) {
logger.error('[handleAbortError] Error in base64 encoding', {
@@ -185,30 +178,17 @@ const handleAbortError = async (res, req, error, data) => {
errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`;
}
/**
* @param {string} partialText
* @returns {Promise<void>}
*/
const respondWithError = async (partialText) => {
const endpointOption = req.body?.endpointOption;
let options = {
sender,
messageId,
conversationId,
parentMessageId,
text: errorText,
user: req.user.id,
shouldSaveMessage: true,
spec: endpointOption?.spec,
iconURL: endpointOption?.iconURL,
modelLabel: endpointOption?.modelLabel,
model: endpointOption?.modelOptions?.model || req.body?.model,
user: req.user.id,
};
if (req.body?.agent_id) {
options.agent_id = req.body.agent_id;
}
if (partialText) {
options = {
...options,

View File

@@ -10,6 +10,7 @@ const openAI = require('~/server/services/Endpoints/openAI');
const agents = require('~/server/services/Endpoints/agents');
const custom = require('~/server/services/Endpoints/custom');
const google = require('~/server/services/Endpoints/google');
const { getConvoFiles } = require('~/models/Conversation');
const { handleError } = require('~/server/utils');
const buildFunction = {
@@ -86,8 +87,16 @@ async function buildEndpointOption(req, res, next) {
// TODO: use `getModelsConfig` only when necessary
const modelsConfig = await getModelsConfig(req);
const { resendFiles = true } = req.body.endpointOption;
req.body.endpointOption.modelsConfig = modelsConfig;
if (req.body.files && !isAgents) {
if (isAgents && resendFiles && req.body.conversationId) {
const fileIds = await getConvoFiles(req.body.conversationId);
const requestFiles = req.body.files ?? [];
if (requestFiles.length || fileIds.length) {
req.body.endpointOption.attachments = processFiles(requestFiles, fileIds);
}
} else if (req.body.files) {
// hold the promise; it is awaited downstream so file processing overlaps request setup
req.body.endpointOption.attachments = processFiles(req.body.files);
}
next();
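
The unawaited `processFiles(...)` assignment above is deliberate: the middleware stores the pending promise on `endpointOption.attachments` so file processing starts immediately, and the endpoint handler resolves it later. A sketch of the assumed consumer side:

// later, inside the endpoint handler (illustrative):
const attachments = await endpointOption.attachments; // resolves the promise held here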

View File

@@ -41,7 +41,7 @@ const banResponse = async (req, res) => {
* @function
* @param {Object} req - Express request object.
* @param {Object} res - Express response object.
* @param {import('express').NextFunction} next - Next middleware function.
* @param {Function} next - Next middleware function.
*
* @returns {Promise<function|Object>} - Returns a Promise which when resolved calls next middleware if user or source IP is not banned. Otherwise calls `banResponse()` and sets ban details in `banCache`.
*/

View File

@@ -21,7 +21,7 @@ const {
* @function
* @param {Object} req - Express request object containing user information.
* @param {Object} res - Express response object.
* @param {import('express').NextFunction} next - Next middleware function.
* @param {function} next - Express next middleware function.
* @throws {Error} Throws an error if the user exceeds the concurrent request limit.
*/
const concurrentLimiter = async (req, res, next) => {

View File

@@ -8,14 +8,12 @@ const concurrentLimiter = require('./concurrentLimiter');
const validateEndpoint = require('./validateEndpoint');
const requireLocalAuth = require('./requireLocalAuth');
const canDeleteAccount = require('./canDeleteAccount');
const setBalanceConfig = require('./setBalanceConfig');
const requireLdapAuth = require('./requireLdapAuth');
const abortMiddleware = require('./abortMiddleware');
const checkInviteUser = require('./checkInviteUser');
const requireJwtAuth = require('./requireJwtAuth');
const validateModel = require('./validateModel');
const moderateText = require('./moderateText');
const logHeaders = require('./logHeaders');
const setHeaders = require('./setHeaders');
const validate = require('./validate');
const limiters = require('./limiters');
@@ -33,7 +31,6 @@ module.exports = {
checkBan,
uaParser,
setHeaders,
logHeaders,
moderateText,
validateModel,
requireJwtAuth,
@@ -42,7 +39,6 @@ module.exports = {
requireLocalAuth,
canDeleteAccount,
validateEndpoint,
setBalanceConfig,
concurrentLimiter,
checkDomainAllowed,
validateMessageReq,

View File

@@ -1,11 +1,6 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
@@ -53,39 +48,21 @@ const createImportLimiters = () => {
const { importIpWindowMs, importIpMax, importUserWindowMs, importUserMax } =
getEnvironmentVariables();
const ipLimiterOptions = {
const importIpLimiter = rateLimit({
windowMs: importIpWindowMs,
max: importIpMax,
handler: createImportHandler(),
};
const userLimiterOptions = {
});
const importUserLimiter = rateLimit({
windowMs: importUserWindowMs,
max: importUserMax,
handler: createImportHandler(false),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
};
});
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for import rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'import_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'import_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const importIpLimiter = rateLimit(ipLimiterOptions);
const importUserLimiter = rateLimit(userLimiterOptions);
return { importIpLimiter, importUserLimiter };
};
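
The Redis wiring shown above is the same pattern dropped from the login, message, register, reset-password, STT, tool-call, TTS, file-upload, and verify-email limiters in the files below. A condensed sketch of that pattern, assuming `isEnabled` and `keyvRedis` are the project modules imported in the removed lines and that `keyvRedis` exposes an ioredis-compatible client at `keyv.opts.store.redis`:

const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');

const limiterOptions = { windowMs: 60 * 1000, max: 10, keyGenerator: (req) => req.user?.id };
if (isEnabled(process.env.USE_REDIS)) {
  const keyv = new Keyv({ store: keyvRedis });
  const client = keyv.opts.store.redis;
  // rate-limit-redis issues raw Redis commands through the provided client
  limiterOptions.store = new RedisStore({
    sendCommand: (...args) => client.call(...args),
    prefix: 'example_limiter:',
  });
}
const limiter = rateLimit(limiterOptions);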

View File

@@ -1,10 +1,6 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
const windowMs = LOGIN_WINDOW * 60 * 1000;
@@ -24,25 +20,11 @@ const handler = async (req, res) => {
return res.status(429).json({ message });
};
const limiterOptions = {
const loginLimiter = rateLimit({
windowMs,
max,
handler,
keyGenerator: removePorts,
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for login rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
prefix: 'login_limiter:',
});
limiterOptions.store = store;
}
const loginLimiter = rateLimit(limiterOptions);
});
module.exports = loginLimiter;

View File

@@ -1,11 +1,6 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const denyRequest = require('~/server/middleware/denyRequest');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const {
MESSAGE_IP_MAX = 40,
@@ -46,49 +41,25 @@ const createHandler = (ip = true) => {
};
/**
* Message request rate limiters
* Message request rate limiter by IP
*/
const ipLimiterOptions = {
const messageIpLimiter = rateLimit({
windowMs: ipWindowMs,
max: ipMax,
handler: createHandler(),
};
});
const userLimiterOptions = {
/**
* Message request rate limiter by userId
*/
const messageUserLimiter = rateLimit({
windowMs: userWindowMs,
max: userMax,
handler: createHandler(false),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for message rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'message_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'message_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
/**
* Message request rate limiter by IP
*/
const messageIpLimiter = rateLimit(ipLimiterOptions);
/**
* Message request rate limiter by userId
*/
const messageUserLimiter = rateLimit(userLimiterOptions);
});
module.exports = {
messageIpLimiter,

View File

@@ -1,10 +1,6 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
const windowMs = REGISTER_WINDOW * 60 * 1000;
@@ -24,25 +20,11 @@ const handler = async (req, res) => {
return res.status(429).json({ message });
};
const limiterOptions = {
const registerLimiter = rateLimit({
windowMs,
max,
handler,
keyGenerator: removePorts,
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for register rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
prefix: 'register_limiter:',
});
limiterOptions.store = store;
}
const registerLimiter = rateLimit(limiterOptions);
});
module.exports = registerLimiter;

View File

@@ -1,11 +1,7 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const {
RESET_PASSWORD_WINDOW = 2,
@@ -29,25 +25,11 @@ const handler = async (req, res) => {
return res.status(429).json({ message });
};
const limiterOptions = {
const resetPasswordLimiter = rateLimit({
windowMs,
max,
handler,
keyGenerator: removePorts,
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for reset password rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
prefix: 'reset_password_limiter:',
});
limiterOptions.store = store;
}
const resetPasswordLimiter = rateLimit(limiterOptions);
});
module.exports = resetPasswordLimiter;

View File

@@ -1,11 +1,6 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
@@ -52,40 +47,20 @@ const createSTTHandler = (ip = true) => {
const createSTTLimiters = () => {
const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables();
const ipLimiterOptions = {
const sttIpLimiter = rateLimit({
windowMs: sttIpWindowMs,
max: sttIpMax,
handler: createSTTHandler(),
};
});
const userLimiterOptions = {
const sttUserLimiter = rateLimit({
windowMs: sttUserWindowMs,
max: sttUserMax,
handler: createSTTHandler(false),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for STT rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'stt_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'stt_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const sttIpLimiter = rateLimit(ipLimiterOptions);
const sttUserLimiter = rateLimit(userLimiterOptions);
});
return { sttIpLimiter, sttUserLimiter };
};

View File

@@ -1,46 +1,25 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const handler = async (req, res) => {
const type = ViolationTypes.TOOL_CALL_LIMIT;
const errorMessage = {
type,
max: 1,
limiter: 'user',
windowInMinutes: 1,
};
await logViolation(req, res, type, errorMessage, 0);
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
};
const limiterOptions = {
const toolCallLimiter = rateLimit({
windowMs: 1000,
max: 1,
handler,
handler: async (req, res) => {
const type = ViolationTypes.TOOL_CALL_LIMIT;
const errorMessage = {
type,
max: 1,
limiter: 'user',
windowInMinutes: 1,
};
await logViolation(req, res, type, errorMessage, 0);
res.status(429).json({ message: 'Too many tool call requests. Try again later' });
},
keyGenerator: function (req) {
return req.user?.id;
},
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for tool call rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
prefix: 'tool_call_limiter:',
});
limiterOptions.store = store;
}
const toolCallLimiter = rateLimit(limiterOptions);
});
module.exports = toolCallLimiter;

View File

@@ -1,11 +1,6 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
@@ -52,40 +47,20 @@ const createTTSHandler = (ip = true) => {
const createTTSLimiters = () => {
const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables();
const ipLimiterOptions = {
const ttsIpLimiter = rateLimit({
windowMs: ttsIpWindowMs,
max: ttsIpMax,
handler: createTTSHandler(),
};
});
const userLimiterOptions = {
const ttsUserLimiter = rateLimit({
windowMs: ttsUserWindowMs,
max: ttsUserMax,
handler: createTTSHandler(false),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for TTS rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'tts_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'tts_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const ttsIpLimiter = rateLimit(ipLimiterOptions);
const ttsUserLimiter = rateLimit(userLimiterOptions);
});
return { ttsIpLimiter, ttsUserLimiter };
};

View File

@@ -1,11 +1,6 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');
const { isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { logger } = require('~/config');
const getEnvironmentVariables = () => {
const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100;
@@ -57,40 +52,20 @@ const createFileLimiters = () => {
const { fileUploadIpWindowMs, fileUploadIpMax, fileUploadUserWindowMs, fileUploadUserMax } =
getEnvironmentVariables();
const ipLimiterOptions = {
const fileUploadIpLimiter = rateLimit({
windowMs: fileUploadIpWindowMs,
max: fileUploadIpMax,
handler: createFileUploadHandler(),
};
});
const userLimiterOptions = {
const fileUploadUserLimiter = rateLimit({
windowMs: fileUploadUserWindowMs,
max: fileUploadUserMax,
handler: createFileUploadHandler(false),
keyGenerator: function (req) {
return req.user?.id; // Use the user ID or NULL if not available
},
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for file upload rate limiters.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const ipStore = new RedisStore({
sendCommand,
prefix: 'file_upload_ip_limiter:',
});
const userStore = new RedisStore({
sendCommand,
prefix: 'file_upload_user_limiter:',
});
ipLimiterOptions.store = ipStore;
userLimiterOptions.store = userStore;
}
const fileUploadIpLimiter = rateLimit(ipLimiterOptions);
const fileUploadUserLimiter = rateLimit(userLimiterOptions);
});
return { fileUploadIpLimiter, fileUploadUserLimiter };
};

View File

@@ -1,11 +1,7 @@
const Keyv = require('keyv');
const rateLimit = require('express-rate-limit');
const { RedisStore } = require('rate-limit-redis');
const { ViolationTypes } = require('librechat-data-provider');
const { removePorts, isEnabled } = require('~/server/utils');
const keyvRedis = require('~/cache/keyvRedis');
const { removePorts } = require('~/server/utils');
const { logViolation } = require('~/cache');
const { logger } = require('~/config');
const {
VERIFY_EMAIL_WINDOW = 2,
@@ -29,25 +25,11 @@ const handler = async (req, res) => {
return res.status(429).json({ message });
};
const limiterOptions = {
const verifyEmailLimiter = rateLimit({
windowMs,
max,
handler,
keyGenerator: removePorts,
};
if (isEnabled(process.env.USE_REDIS)) {
logger.debug('Using Redis for verify email rate limiter.');
const keyv = new Keyv({ store: keyvRedis });
const client = keyv.opts.store.redis;
const sendCommand = (...args) => client.call(...args);
const store = new RedisStore({
sendCommand,
prefix: 'verify_email_limiter:',
});
limiterOptions.store = store;
}
const verifyEmailLimiter = rateLimit(limiterOptions);
});
module.exports = verifyEmailLimiter;

View File

@@ -1,32 +0,0 @@
const { logger } = require('~/config');
/**
* Middleware to log Forwarded Headers
* @function
* @param {ServerRequest} req - Express request object containing user information.
* @param {ServerResponse} res - Express response object.
* @param {import('express').NextFunction} next - Next middleware function.
*/
const logHeaders = (req, res, next) => {
try {
const forwardedHeaders = {};
if (req.headers['x-forwarded-for']) {
forwardedHeaders['x-forwarded-for'] = req.headers['x-forwarded-for'];
}
if (req.headers['x-forwarded-host']) {
forwardedHeaders['x-forwarded-host'] = req.headers['x-forwarded-host'];
}
if (req.headers['x-forwarded-proto']) {
forwardedHeaders['x-forwarded-proto'] = req.headers['x-forwarded-proto'];
}
if (Object.keys(forwardedHeaders).length > 0) {
logger.debug('X-Forwarded headers detected in OAuth request:', forwardedHeaders);
}
} catch (error) {
logger.error('Error logging X-Forwarded headers:', error);
}
next();
};
module.exports = logHeaders;

View File

@@ -1,91 +0,0 @@
const { getBalanceConfig } = require('~/server/services/Config');
const Balance = require('~/models/Balance');
const { logger } = require('~/config');
/**
* Middleware to synchronize user balance settings with current balance configuration.
* @function
* @param {Object} req - Express request object containing user information.
* @param {Object} res - Express response object.
* @param {import('express').NextFunction} next - Next middleware function.
*/
const setBalanceConfig = async (req, res, next) => {
try {
const balanceConfig = await getBalanceConfig();
if (!balanceConfig?.enabled) {
return next();
}
if (balanceConfig.startBalance == null) {
return next();
}
const userId = req.user._id;
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord);
if (Object.keys(updateFields).length === 0) {
return next();
}
await Balance.findOneAndUpdate(
{ user: userId },
{ $set: updateFields },
{ upsert: true, new: true },
);
next();
} catch (error) {
logger.error('Error setting user balance:', error);
next(error);
}
};
/**
* Build an object containing fields that need updating
* @param {Object} config - The balance configuration
* @param {Object|null} userRecord - The user's current balance record, if any
* @returns {Object} Fields that need updating
*/
function buildUpdateFields(config, userRecord) {
const updateFields = {};
// Ensure user record has the required fields
if (!userRecord) {
updateFields.user = userRecord?.user;
updateFields.tokenCredits = config.startBalance;
}
if (userRecord?.tokenCredits == null && config.startBalance != null) {
updateFields.tokenCredits = config.startBalance;
}
const isAutoRefillConfigValid =
config.autoRefillEnabled &&
config.refillIntervalValue != null &&
config.refillIntervalUnit != null &&
config.refillAmount != null;
if (!isAutoRefillConfigValid) {
return updateFields;
}
if (userRecord?.autoRefillEnabled !== config.autoRefillEnabled) {
updateFields.autoRefillEnabled = config.autoRefillEnabled;
}
if (userRecord?.refillIntervalValue !== config.refillIntervalValue) {
updateFields.refillIntervalValue = config.refillIntervalValue;
}
if (userRecord?.refillIntervalUnit !== config.refillIntervalUnit) {
updateFields.refillIntervalUnit = config.refillIntervalUnit;
}
if (userRecord?.refillAmount !== config.refillAmount) {
updateFields.refillAmount = config.refillAmount;
}
return updateFields;
}
module.exports = setBalanceConfig;
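
Two illustrative calls through `buildUpdateFields` above (the values are hypothetical):

buildUpdateFields({ enabled: true, startBalance: 20000 }, null);
// -> { user: undefined, tokenCredits: 20000 }  (new record seeded by the upsert)

buildUpdateFields({ enabled: true, startBalance: 20000 }, { tokenCredits: 500 });
// -> {}  (credits already set; auto-refill config incomplete, so nothing to sync)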

View File

@@ -7,23 +7,20 @@ const {
} = require('~/server/controllers/AuthController');
const { loginController } = require('~/server/controllers/auth/LoginController');
const { logoutController } = require('~/server/controllers/auth/LogoutController');
const { verify2FAWithTempToken } = require('~/server/controllers/auth/TwoFactorAuthController');
const { verify2FA } = require('~/server/controllers/auth/TwoFactorAuthController');
const {
enable2FA,
verify2FA,
disable2FA,
regenerateBackupCodes,
confirm2FA,
enable2FAController,
verify2FAController,
disable2FAController,
regenerateBackupCodesController,
confirm2FAController,
} = require('~/server/controllers/TwoFactorController');
const {
checkBan,
logHeaders,
loginLimiter,
requireJwtAuth,
checkInviteUser,
registerLimiter,
requireLdapAuth,
setBalanceConfig,
requireLocalAuth,
resetPasswordLimiter,
validateRegistration,
@@ -37,11 +34,9 @@ const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
router.post('/logout', requireJwtAuth, logoutController);
router.post(
'/login',
logHeaders,
loginLimiter,
checkBan,
ldapAuth ? requireLdapAuth : requireLocalAuth,
setBalanceConfig,
loginController,
);
router.post('/refresh', refreshController);
@@ -62,11 +57,11 @@ router.post(
);
router.post('/resetPassword', checkBan, validatePasswordReset, resetPasswordController);
router.get('/2fa/enable', requireJwtAuth, enable2FA);
router.post('/2fa/verify', requireJwtAuth, verify2FA);
router.post('/2fa/verify-temp', checkBan, verify2FAWithTempToken);
router.post('/2fa/confirm', requireJwtAuth, confirm2FA);
router.post('/2fa/disable', requireJwtAuth, disable2FA);
router.post('/2fa/backup/regenerate', requireJwtAuth, regenerateBackupCodes);
router.get('/2fa/enable', requireJwtAuth, enable2FAController);
router.post('/2fa/verify', requireJwtAuth, verify2FAController);
router.post('/2fa/verify-temp', checkBan, verify2FA);
router.post('/2fa/confirm', requireJwtAuth, confirm2FAController);
router.post('/2fa/disable', requireJwtAuth, disable2FAController);
router.post('/2fa/backup/regenerate', requireJwtAuth, regenerateBackupCodesController);
module.exports = router;

View File

@@ -69,6 +69,7 @@ router.get('/', async function (req, res) {
!!process.env.EMAIL_PASSWORD &&
!!process.env.EMAIL_FROM,
passwordResetEnabled,
checkBalance: isEnabled(process.env.CHECK_BALANCE),
showBirthdayIcon:
isBirthday() ||
isEnabled(process.env.SHOW_BIRTHDAY_ICON) ||
@@ -76,7 +77,6 @@ router.get('/', async function (req, res) {
helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai',
interface: req.app.locals.interfaceConfig,
modelSpecs: req.app.locals.modelSpecs,
balance: req.app.locals.balance,
sharedLinksEnabled,
publicSharedLinksEnabled,
analyticsGtmId: process.env.ANALYTICS_GTM_ID,

View File

@@ -2,9 +2,7 @@ const fs = require('fs').promises;
const express = require('express');
const { EnvVar } = require('@librechat/agents');
const {
Time,
isUUID,
CacheKeys,
FileSources,
EModelEndpoint,
isAgentsEndpoint,
@@ -18,11 +16,9 @@ const {
} = require('~/server/services/Files/process');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
const { getFiles, batchUpdateFiles } = require('~/models/File');
const { loadAuthValues } = require('~/app/clients/tools/util');
const { getAgent } = require('~/models/Agent');
const { getLogStores } = require('~/cache');
const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
const router = express.Router();
@@ -30,18 +26,6 @@ const router = express.Router();
router.get('/', async (req, res) => {
try {
const files = await getFiles({ user: req.user.id });
if (req.app.locals.fileStrategy === FileSources.s3) {
try {
const cache = getLogStores(CacheKeys.S3_EXPIRY_INTERVAL);
const alreadyChecked = await cache.get(req.user.id);
if (!alreadyChecked) {
await refreshS3FileUrls(files, batchUpdateFiles);
await cache.set(req.user.id, true, Time.THIRTY_MINUTES);
}
} catch (error) {
logger.warn('[/files] Error refreshing S3 file URLs:', error);
}
}
res.status(200).send(files);
} catch (error) {
logger.error('[/files] Error getting files:', error);

View File

@@ -3,7 +3,7 @@ const router = express.Router();
const { getCustomConfigSpeech } = require('~/server/services/Files/Audio');
router.get('/', async (req, res) => {
router.get('/get', async (req, res) => {
await getCustomConfigSpeech(req, res);
});

View File

@@ -4,7 +4,6 @@ const { createTTSLimiters, createSTTLimiters } = require('~/server/middleware');
const stt = require('./stt');
const tts = require('./tts');
const customConfigSpeech = require('./customConfigSpeech');
const realtime = require('./realtime');
const router = express.Router();
@@ -15,6 +14,4 @@ router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);
router.use('/config', customConfigSpeech);
router.use('/realtime', realtime);
module.exports = router;

View File

@@ -1,10 +0,0 @@
const express = require('express');
const router = express.Router();
const { getRealtimeConfig } = require('~/server/services/Files/Audio');
router.get('/', async (req, res) => {
await getRealtimeConfig(req, res);
});
module.exports = router;

View File

@@ -2,7 +2,6 @@ const assistants = require('./assistants');
const categories = require('./categories');
const tokenizer = require('./tokenizer');
const endpoints = require('./endpoints');
const websocket = require('./websocket');
const staticRoute = require('./static');
const messages = require('./messages');
const presets = require('./presets');
@@ -16,7 +15,6 @@ const models = require('./models');
const convos = require('./convos');
const config = require('./config');
const agents = require('./agents');
const banner = require('./banner');
const roles = require('./roles');
const oauth = require('./oauth');
const files = require('./files');
@@ -27,6 +25,7 @@ const edit = require('./edit');
const keys = require('./keys');
const user = require('./user');
const ask = require('./ask');
const banner = require('./banner');
module.exports = {
ask,
@@ -40,7 +39,6 @@ module.exports = {
files,
share,
agents,
banner,
bedrock,
convos,
search,
@@ -52,10 +50,10 @@ module.exports = {
presets,
balance,
messages,
websocket,
endpoints,
tokenizer,
assistants,
categories,
staticRoute,
banner,
};

View File

@@ -1,13 +1,7 @@
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
const express = require('express');
const passport = require('passport');
const {
checkBan,
logHeaders,
loginLimiter,
setBalanceConfig,
checkDomainAllowed,
} = require('~/server/middleware');
const { loginLimiter, checkBan, checkDomainAllowed } = require('~/server/middleware');
const { setAuthTokens } = require('~/server/services/AuthService');
const { logger } = require('~/config');
@@ -18,7 +12,6 @@ const domains = {
server: process.env.DOMAIN_SERVER,
};
router.use(logHeaders);
router.use(loginLimiter);
const oauthHandler = async (req, res) => {
@@ -62,7 +55,6 @@ router.get(
session: false,
scope: ['openid', 'profile', 'email'],
}),
setBalanceConfig,
oauthHandler,
);
@@ -87,7 +79,6 @@ router.get(
scope: ['public_profile'],
profileFields: ['id', 'email', 'name'],
}),
setBalanceConfig,
oauthHandler,
);
@@ -108,7 +99,6 @@ router.get(
failureMessage: true,
session: false,
}),
setBalanceConfig,
oauthHandler,
);
@@ -131,7 +121,6 @@ router.get(
session: false,
scope: ['user:email', 'read:user'],
}),
setBalanceConfig,
oauthHandler,
);
@@ -154,7 +143,6 @@ router.get(
session: false,
scope: ['identify', 'email'],
}),
setBalanceConfig,
oauthHandler,
);
@@ -175,7 +163,6 @@ router.post(
failureMessage: true,
session: false,
}),
setBalanceConfig,
oauthHandler,
);

View File

@@ -1,19 +0,0 @@
const express = require('express');
const optionalJwtAuth = require('~/server/middleware/optionalJwtAuth');
const router = express.Router();
router.get('/', optionalJwtAuth, async (req, res) => {
const isProduction = process.env.NODE_ENV === 'production';
const protocol = isProduction && req.secure ? 'https' : 'http';
const serverDomain = process.env.SERVER_DOMAIN
? process.env.SERVER_DOMAIN.replace(/^https?:\/\//, '')
: req.headers.host;
const socketIoUrl = `${protocol}://${serverDomain}`;
res.json({ url: socketIoUrl });
});
module.exports = router;

View File

@@ -13,6 +13,7 @@ const {
actionDomainSeparator,
} = require('librechat-data-provider');
const { refreshAccessToken } = require('~/server/services/TokenService');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { logger, getFlowStateManager, sendEvent } = require('~/config');
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
const { getActions, deleteActions } = require('~/models/Action');
@@ -129,7 +130,6 @@ async function loadActionSets(searchParams) {
* @param {string | undefined} [params.name] - The name of the tool.
* @param {string | undefined} [params.description] - The description for the tool.
* @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition
* @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action.
* @returns { Promise<typeof tool | { _call: (toolInput: Object | string) => unknown}> } An object with `_call` method to execute the tool input.
*/
async function createActionTool({
@@ -140,8 +140,17 @@ async function createActionTool({
zodSchema,
name,
description,
encrypted,
}) {
const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain);
if (!isDomainAllowed) {
return null;
}
const encrypted = {
oauth_client_id: action.metadata.oauth_client_id,
oauth_client_secret: action.metadata.oauth_client_secret,
};
action.metadata = await decryptMetadata(action.metadata);
/** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise<unknown>} */
const _call = async (toolInput, config) => {
try {
@@ -152,9 +161,9 @@ async function createActionTool({
if (metadata.auth && metadata.auth.type !== AuthTypeEnum.None) {
try {
const action_id = action.action_id;
const identifier = `${req.user.id}:${action.action_id}`;
if (metadata.auth.type === AuthTypeEnum.OAuth && metadata.auth.authorization_url) {
const action_id = action.action_id;
const identifier = `${req.user.id}:${action.action_id}`;
const requestLogin = async () => {
const { args: _args, stepId, ...toolCall } = config.toolCall ?? {};
if (!stepId) {
@@ -299,8 +308,9 @@ async function createActionTool({
}
return response.data;
} catch (error) {
const message = `API call to ${action.metadata.domain} failed:`;
return logAxiosError({ message, error });
const logMessage = `API call to ${action.metadata.domain} failed`;
logAxiosError({ message: logMessage, error });
throw error;
}
};
@@ -317,27 +327,6 @@ async function createActionTool({
};
}
/**
* Encrypts a sensitive value.
* @param {string} value
* @returns {Promise<string>}
*/
async function encryptSensitiveValue(value) {
// Encode API key to handle special characters like ":"
const encodedValue = encodeURIComponent(value);
return await encryptV2(encodedValue);
}
/**
* Decrypts a sensitive value.
* @param {string} value
* @returns {Promise<string>}
*/
async function decryptSensitiveValue(value) {
const decryptedValue = await decryptV2(value);
return decodeURIComponent(decryptedValue);
}
/**
* Encrypts sensitive metadata values for an action.
*
@@ -350,19 +339,17 @@ async function encryptMetadata(metadata) {
// ServiceHttp
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
if (metadata.api_key) {
encryptedMetadata.api_key = await encryptSensitiveValue(metadata.api_key);
encryptedMetadata.api_key = await encryptV2(metadata.api_key);
}
}
// OAuth
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
if (metadata.oauth_client_id) {
encryptedMetadata.oauth_client_id = await encryptSensitiveValue(metadata.oauth_client_id);
encryptedMetadata.oauth_client_id = await encryptV2(metadata.oauth_client_id);
}
if (metadata.oauth_client_secret) {
encryptedMetadata.oauth_client_secret = await encryptSensitiveValue(
metadata.oauth_client_secret,
);
encryptedMetadata.oauth_client_secret = await encryptV2(metadata.oauth_client_secret);
}
}
@@ -381,19 +368,17 @@ async function decryptMetadata(metadata) {
// ServiceHttp
if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) {
if (metadata.api_key) {
decryptedMetadata.api_key = await decryptSensitiveValue(metadata.api_key);
decryptedMetadata.api_key = await decryptV2(metadata.api_key);
}
}
// OAuth
else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) {
if (metadata.oauth_client_id) {
decryptedMetadata.oauth_client_id = await decryptSensitiveValue(metadata.oauth_client_id);
decryptedMetadata.oauth_client_id = await decryptV2(metadata.oauth_client_id);
}
if (metadata.oauth_client_secret) {
decryptedMetadata.oauth_client_secret = await decryptSensitiveValue(
metadata.oauth_client_secret,
);
decryptedMetadata.oauth_client_secret = await decryptV2(metadata.oauth_client_secret);
}
}
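
The `encodeURIComponent` round-trip removed above exists so that secrets containing reserved characters survive encryption intact; a small illustration (the key value is made up):

const raw = 'sk-test:with:colons';
const encoded = encodeURIComponent(raw); // 'sk-test%3Awith%3Acolons'
// encryptV2(encoded) / decryptV2(...) round-trips the encoded form,
// and decodeURIComponent(decrypted) restores the original value.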

View File

@@ -1,24 +1,15 @@
const {
FileSources,
EModelEndpoint,
loadOCRConfig,
processMCPEnv,
getConfigDefaults,
} = require('librechat-data-provider');
const { FileSources, EModelEndpoint, getConfigDefaults } = require('librechat-data-provider');
const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = require('./start/checks');
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
const { initializeAzureBlobService } = require('./Files/Azure/initialize');
const { initializeFirebase } = require('./Files/Firebase/initialize');
const loadCustomConfig = require('./Config/loadCustomConfig');
const handleRateLimits = require('./Config/handleRateLimits');
const { loadDefaultInterface } = require('./start/interface');
const { azureConfigSetup } = require('./start/azureOpenAI');
const { processModelSpecs } = require('./start/modelSpecs');
const { initializeS3 } = require('./Files/S3/initialize');
const { loadAndFormatTools } = require('./ToolService');
const { agentsConfigSetup } = require('./start/agents');
const { initializeRoles } = require('~/models/Role');
const { isEnabled } = require('~/server/utils');
const { getMCPManager } = require('~/config');
const paths = require('~/config/paths');
@@ -30,19 +21,13 @@ const paths = require('~/config/paths');
*/
const AppService = async (app) => {
await initializeRoles();
/** @type {TCustomConfig} */
/** @type {TCustomConfig}*/
const config = (await loadCustomConfig()) ?? {};
const configDefaults = getConfigDefaults();
const ocr = loadOCRConfig(config.ocr);
const filteredTools = config.filteredTools;
const includedTools = config.includedTools;
const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy;
const startBalance = process.env.START_BALANCE;
const balance = config.balance ?? {
enabled: isEnabled(process.env.CHECK_BALANCE),
startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
};
const imageOutputType = config?.imageOutputType ?? configDefaults.imageOutputType;
process.env.CDN_PROVIDER = fileStrategy;
@@ -52,13 +37,9 @@ const AppService = async (app) => {
if (fileStrategy === FileSources.firebase) {
initializeFirebase();
} else if (fileStrategy === FileSources.azure_blob) {
initializeAzureBlobService();
} else if (fileStrategy === FileSources.s3) {
initializeS3();
}
/** @type {Record<string, FunctionTool>} */
/** @type {Record<string, FunctionTool} */
const availableTools = loadAndFormatTools({
adminFilter: filteredTools,
adminIncluded: includedTools,
@@ -67,7 +48,7 @@ const AppService = async (app) => {
if (config.mcpServers != null) {
const mcpManager = await getMCPManager();
await mcpManager.initializeMCP(config.mcpServers, processMCPEnv);
await mcpManager.initializeMCP(config.mcpServers);
await mcpManager.mapAvailableTools(availableTools);
}
@@ -76,7 +57,6 @@ const AppService = async (app) => {
const interfaceConfig = await loadDefaultInterface(config, configDefaults);
const defaultLocals = {
ocr,
paths,
fileStrategy,
socialLogins,
@@ -85,7 +65,6 @@ const AppService = async (app) => {
availableTools,
imageOutputType,
interfaceConfig,
balance,
};
if (!Object.keys(config).length) {
@@ -146,7 +125,7 @@ const AppService = async (app) => {
...defaultLocals,
fileConfig: config?.fileConfig,
secureImageLinks: config?.secureImageLinks,
modelSpecs: processModelSpecs(endpoints, config.modelSpecs, interfaceConfig),
modelSpecs: processModelSpecs(endpoints, config.modelSpecs),
...endpointLocals,
};
};

View File

@@ -15,9 +15,6 @@ jest.mock('./Config/loadCustomConfig', () => {
Promise.resolve({
registration: { socialLogins: ['testLogin'] },
fileStrategy: 'testStrategy',
balance: {
enabled: true,
},
}),
);
});
@@ -123,13 +120,9 @@ describe('AppService', () => {
},
},
paths: expect.anything(),
ocr: expect.anything(),
imageOutputType: expect.any(String),
fileConfig: undefined,
secureImageLinks: undefined,
balance: { enabled: true },
filteredTools: undefined,
includedTools: undefined,
});
});
@@ -347,6 +340,9 @@ describe('AppService', () => {
process.env.FILE_UPLOAD_USER_MAX = 'initialUserMax';
process.env.FILE_UPLOAD_USER_WINDOW = 'initialUserWindow';
// Mock a custom configuration without specific rate limits
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
await AppService(app);
// Verify that process.env falls back to the initial values
@@ -407,6 +403,9 @@ describe('AppService', () => {
process.env.IMPORT_USER_MAX = 'initialUserMax';
process.env.IMPORT_USER_WINDOW = 'initialUserWindow';
// Mock a custom configuration without specific rate limits
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
await AppService(app);
// Verify that process.env falls back to the initial values
@@ -445,27 +444,13 @@ describe('AppService updating app.locals and issuing warnings', () => {
expect(app.locals.availableTools).toBeDefined();
expect(app.locals.fileStrategy).toEqual(FileSources.local);
expect(app.locals.socialLogins).toEqual(defaultSocialLogins);
expect(app.locals.balance).toEqual(
expect.objectContaining({
enabled: false,
startBalance: undefined,
}),
);
});
it('should update app.locals with values from loadCustomConfig', async () => {
// Mock loadCustomConfig to return a specific config object with a complete balance config
// Mock loadCustomConfig to return a specific config object
const customConfig = {
fileStrategy: 'firebase',
registration: { socialLogins: ['testLogin'] },
balance: {
enabled: false,
startBalance: 5000,
autoRefillEnabled: true,
refillIntervalValue: 15,
refillIntervalUnit: 'hours',
refillAmount: 5000,
},
};
require('./Config/loadCustomConfig').mockImplementationOnce(() =>
Promise.resolve(customConfig),
@@ -478,7 +463,6 @@ describe('AppService updating app.locals and issuing warnings', () => {
expect(app.locals.availableTools).toBeDefined();
expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy);
expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins);
expect(app.locals.balance).toEqual(customConfig.balance);
});
it('should apply the assistants endpoint configuration correctly to app.locals', async () => {
@@ -604,33 +588,4 @@ describe('AppService updating app.locals and issuing warnings', () => {
);
});
});
it('should not parse environment variable references in OCR config', async () => {
// Mock custom configuration with env variable references in OCR config
const mockConfig = {
ocr: {
apiKey: '${OCR_API_KEY_CUSTOM_VAR_NAME}',
baseURL: '${OCR_BASEURL_CUSTOM_VAR_NAME}',
strategy: 'mistral_ocr',
mistralModel: 'mistral-medium',
},
};
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig));
// Set actual environment variables with different values
process.env.OCR_API_KEY_CUSTOM_VAR_NAME = 'actual-api-key';
process.env.OCR_BASEURL_CUSTOM_VAR_NAME = 'https://actual-ocr-url.com';
// Initialize app
const app = { locals: {} };
await AppService(app);
// Verify that the raw string references were preserved and not interpolated
expect(app.locals.ocr).toBeDefined();
expect(app.locals.ocr.apiKey).toEqual('${OCR_API_KEY_CUSTOM_VAR_NAME}');
expect(app.locals.ocr.baseURL).toEqual('${OCR_BASEURL_CUSTOM_VAR_NAME}');
expect(app.locals.ocr.strategy).toEqual('mistral_ocr');
expect(app.locals.ocr.mistralModel).toEqual('mistral-medium');
});
});

View File

@@ -91,7 +91,7 @@ const sendVerificationEmail = async (user) => {
subject: 'Verify your email',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name || user.username || user.email,
name: user.name,
verificationLink: verificationLink,
year: new Date().getFullYear(),
},
@@ -278,7 +278,7 @@ const requestPasswordReset = async (req) => {
subject: 'Password Reset Request',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name || user.username || user.email,
name: user.name,
link: link,
year: new Date().getFullYear(),
},
@@ -331,7 +331,7 @@ const resetPassword = async (userId, token, password) => {
subject: 'Password Reset Successfully',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name || user.username || user.email,
name: user.name,
year: new Date().getFullYear(),
},
template: 'passwordReset.handlebars',
@@ -414,7 +414,7 @@ const resendVerificationEmail = async (req) => {
subject: 'Verify your email',
payload: {
appName: process.env.APP_TITLE || 'LibreChat',
name: user.name || user.username || user.email,
name: user.name,
verificationLink: verificationLink,
year: new Date().getFullYear(),
},

View File

@@ -1,5 +1,5 @@
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName, isEnabled } = require('~/server/utils');
const { normalizeEndpointName } = require('~/server/utils');
const loadCustomConfig = require('./loadCustomConfig');
const getLogStores = require('~/cache/getLogStores');
@@ -23,26 +23,6 @@ async function getCustomConfig() {
return customConfig;
}
/**
* Retrieves the balance configuration object
* @function getBalanceConfig
* @returns {Promise<TCustomConfig['balance'] | null>}
*/
async function getBalanceConfig() {
const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE);
const startBalance = process.env.START_BALANCE;
/** @type {TCustomConfig['balance']} */
const config = {
enabled: isLegacyEnabled,
startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined,
};
const customConfig = await getCustomConfig();
if (!customConfig) {
return config;
}
return { ...config, ...(customConfig?.['balance'] ?? {}) };
}
/**
*
* @param {string | EModelEndpoint} endpoint
@@ -60,4 +40,4 @@ const getCustomEndpointConfig = async (endpoint) => {
);
};
module.exports = { getCustomConfig, getBalanceConfig, getCustomEndpointConfig };
module.exports = { getCustomConfig, getCustomEndpointConfig };
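
The precedence in the removed `getBalanceConfig` is worth restating: the legacy env vars form the base, and the `balance` block from the custom config overrides it field by field:

// { ...envDerivedConfig, ...(customConfig?.balance ?? {}) }
// e.g. CHECK_BALANCE=false plus `balance: { enabled: true }` in librechat.yaml
// resolves to enabled: true.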

View File

@@ -33,12 +33,10 @@ async function getEndpointsConfig(req) {
};
}
if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
const { disableBuilder, capabilities, allowedProviders, ..._rest } =
req.app.locals[EModelEndpoint.agents];
const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];
mergedConfig[EModelEndpoint.agents] = {
...mergedConfig[EModelEndpoint.agents],
allowedProviders,
disableBuilder,
capabilities,
};
@@ -74,15 +72,4 @@ async function getEndpointsConfig(req) {
return endpointsConfig;
}
/**
* @param {ServerRequest} req
* @param {import('librechat-data-provider').AgentCapabilities} capability
* @returns {Promise<boolean>}
*/
const checkCapability = async (req, capability) => {
const endpointsConfig = await getEndpointsConfig(req);
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
return capabilities.includes(capability);
};
module.exports = { getEndpointsConfig, checkCapability };
module.exports = { getEndpointsConfig };

View File

@@ -2,8 +2,15 @@ const { loadAgent } = require('~/models/Agent');
const { logger } = require('~/config');
const buildOptions = (req, endpoint, parsedBody) => {
const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } =
parsedBody;
const {
spec,
iconURL,
agent_id,
instructions,
maxContextTokens,
resendFiles = true,
...model_parameters
} = parsedBody;
const agentPromise = loadAgent({
req,
agent_id,
@@ -17,6 +24,7 @@ const buildOptions = (req, endpoint, parsedBody) => {
iconURL,
endpoint,
agent_id,
resendFiles,
instructions,
maxContextTokens,
model_parameters,

View File

@@ -1,9 +1,7 @@
const { createContentAggregator, Providers } = require('@librechat/agents');
const {
ErrorTypes,
EModelEndpoint,
getResponseSender,
AgentCapabilities,
providerEndpointMap,
} = require('librechat-data-provider');
const {
@@ -17,14 +15,10 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize');
const initGoogle = require('~/server/services/Endpoints/google/initialize');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { processFiles } = require('~/server/services/Files/process');
const { loadAgentTools } = require('~/server/services/ToolService');
const AgentClient = require('~/server/controllers/agents/client');
const { getConvoFiles } = require('~/models/Conversation');
const { getToolFilesByIds } = require('~/models/File');
const { getModelMaxTokens } = require('~/utils');
const { getAgent } = require('~/models/Agent');
const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
const providerConfigMap = {
@@ -40,38 +34,20 @@ const providerConfigMap = {
};
/**
* @param {ServerRequest} req
*
* @param {Promise<Array<MongoFile | null>> | undefined} _attachments
* @param {AgentToolResources | undefined} _tool_resources
* @returns {Promise<{ attachments: Array<MongoFile | undefined> | undefined, tool_resources: AgentToolResources | undefined }>}
*/
const primeResources = async (req, _attachments, _tool_resources) => {
const primeResources = async (_attachments, _tool_resources) => {
try {
/** @type {Array<MongoFile | undefined> | undefined} */
let attachments;
const tool_resources = _tool_resources ?? {};
const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
AgentCapabilities.ocr,
);
if (tool_resources.ocr?.file_ids && isOCREnabled) {
const context = await getFiles(
{
file_id: { $in: tool_resources.ocr.file_ids },
},
{},
{},
);
attachments = (attachments ?? []).concat(context);
}
if (!_attachments) {
return { attachments, tool_resources };
return { attachments: undefined, tool_resources: _tool_resources };
}
/** @type {Array<MongoFile | undefined> | undefined} */
const files = await _attachments;
if (!attachments) {
/** @type {Array<MongoFile | undefined>} */
attachments = [];
}
const attachments = [];
const tool_resources = _tool_resources ?? {};
for (const file of files) {
if (!file) {
@@ -100,26 +76,13 @@ const primeResources = async (req, _attachments, _tool_resources) => {
}
};
/**
* @param {...string | number} values
* @returns {string | number | undefined}
*/
function optionalChainWithEmptyCheck(...values) {
for (const value of values) {
if (value !== undefined && value !== null && value !== '') {
return value;
}
}
return values[values.length - 1];
}
/**
* @param {object} params
* @param {ServerRequest} params.req
* @param {ServerResponse} params.res
* @param {Agent} params.agent
* @param {Set<string>} [params.allowedProviders]
* @param {object} [params.endpointOption]
* @param {AgentToolResources} [params.tool_resources]
* @param {boolean} [params.isInitialAgent]
* @returns {Promise<Agent>}
*/
@@ -128,36 +91,9 @@ const initializeAgentOptions = async ({
res,
agent,
endpointOption,
allowedProviders,
tool_resources,
isInitialAgent = false,
}) => {
if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
throw new Error(
`{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
);
}
let currentFiles;
/** @type {Array<MongoFile>} */
const requestFiles = req.body.files ?? [];
if (
isInitialAgent &&
req.body.conversationId != null &&
(agent.model_parameters?.resendFiles ?? true) === true
) {
const fileIds = (await getConvoFiles(req.body.conversationId)) ?? [];
const toolFiles = await getToolFilesByIds(fileIds);
if (requestFiles.length || toolFiles.length) {
currentFiles = await processFiles(requestFiles.concat(toolFiles));
}
} else if (isInitialAgent && requestFiles.length) {
currentFiles = await processFiles(requestFiles);
}
const { attachments, tool_resources } = await primeResources(
req,
currentFiles,
agent.tool_resources,
);
const { tools, toolContextMap } = await loadAgentTools({
req,
res,
@@ -202,7 +138,6 @@ const initializeAgentOptions = async ({
agent.provider = options.provider;
}
/** @type {import('@librechat/agents').ClientOptions} */
agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
if (options.configOptions) {
agent.model_parameters.configuration = options.configOptions;
@@ -221,23 +156,15 @@ const initializeAgentOptions = async ({
const tokensModel =
agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model;
const maxTokens = optionalChainWithEmptyCheck(
agent.model_parameters.maxOutputTokens,
agent.model_parameters.maxTokens,
0,
);
const maxContextTokens = optionalChainWithEmptyCheck(
agent.model_parameters.maxContextTokens,
agent.max_context_tokens,
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
4096,
);
return {
...agent,
tools,
attachments,
toolContextMap,
maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
maxContextTokens:
agent.max_context_tokens ??
getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ??
4000,
};
};
@@ -270,9 +197,12 @@ const initializeClient = async ({ req, res, endpointOption }) => {
throw new Error('Agent not found');
}
const { attachments, tool_resources } = await primeResources(
endpointOption.attachments,
primaryAgent.tool_resources,
);
const agentConfigs = new Map();
/** @type {Set<string>} */
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
// Handle primary agent
const primaryConfig = await initializeAgentOptions({
@@ -280,7 +210,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
agent: primaryAgent,
endpointOption,
allowedProviders,
tool_resources,
isInitialAgent: true,
});
@@ -296,7 +226,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
res,
agent,
endpointOption,
allowedProviders,
});
agentConfigs.set(agentId, config);
}
@@ -311,21 +240,18 @@ const initializeClient = async ({ req, res, endpointOption }) => {
const client = new AgentClient({
req,
res,
agent: primaryConfig,
sender,
attachments,
contentParts,
agentConfigs,
eventHandlers,
collectedUsage,
aggregateContent,
artifactPromises,
agent: primaryConfig,
spec: endpointOption.spec,
iconURL: endpointOption.iconURL,
agentConfigs,
endpoint: EModelEndpoint.agents,
attachments: primaryConfig.attachments,
maxContextTokens: primaryConfig.maxContextTokens,
resendFiles: primaryConfig.model_parameters?.resendFiles ?? true,
});
return { client };

View File
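
A worked example of the token bookkeeping above, with illustrative numbers: optionalChainWithEmptyCheck picks the first defined, non-null, non-empty candidate, and the final budget reserves output tokens before applying a 10% safety margin.

// optionalChainWithEmptyCheck(undefined, '', 4096, 0) === 4096
const maxTokens = 4096;      // e.g. model_parameters.maxOutputTokens
const modelWindow = 128000;  // e.g. getModelMaxTokens(...) result (illustrative)
const promptBudget = (modelWindow - maxTokens) * 0.9; // (128000 - 4096) * 0.9 = 111513.6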

@@ -23,9 +23,8 @@ const initializeClient = async ({ req, res, endpointOption }) => {
const agent = {
id: EModelEndpoint.bedrock,
name: endpointOption.name,
provider: EModelEndpoint.bedrock,
endpoint: EModelEndpoint.bedrock,
instructions: endpointOption.promptPrefix,
provider: EModelEndpoint.bedrock,
model: endpointOption.model_parameters.model,
model_parameters: endpointOption.model_parameters,
};
@@ -55,7 +54,6 @@ const initializeClient = async ({ req, res, endpointOption }) => {
const client = new AgentClient({
req,
res,
agent,
sender,
// tools,

View File

@@ -135,9 +135,12 @@ const initializeClient = async ({
}
if (optionsOnly) {
const modelOptions = endpointOption.model_parameters;
modelOptions.model = modelName;
clientOptions = Object.assign({ modelOptions }, clientOptions);
clientOptions = Object.assign(
{
modelOptions: endpointOption.model_parameters,
},
clientOptions,
);
clientOptions.modelOptions.user = req.user.id;
const options = getLLMConfig(apiKey, clientOptions);
if (!clientOptions.streamRate) {

View File

@@ -28,7 +28,7 @@ const { isEnabled } = require('~/server/utils');
* @returns {Object} Configuration options for creating an LLM instance.
*/
function getLLMConfig(apiKey, options = {}, endpoint = null) {
let {
const {
modelOptions = {},
reverseProxyUrl,
defaultQuery,
@@ -50,32 +50,10 @@ function getLLMConfig(apiKey, options = {}, endpoint = null) {
if (addParams && typeof addParams === 'object') {
Object.assign(llmConfig, addParams);
}
/** Note: OpenAI Web Search models do not support any known parameters besides `max_tokens` */
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
const searchExcludeParams = [
'frequency_penalty',
'presence_penalty',
'temperature',
'top_p',
'top_k',
'stop',
'logit_bias',
'seed',
'response_format',
'n',
'logprobs',
'user',
];
dropParams = dropParams || [];
dropParams = [...new Set([...dropParams, ...searchExcludeParams])];
}
if (dropParams && Array.isArray(dropParams)) {
dropParams.forEach((param) => {
if (llmConfig[param]) {
llmConfig[param] = undefined;
}
delete llmConfig[param];
});
}

View File
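
A small sketch of the parameter stripping above (model name and params illustrative): the search-model exclusions are merged into dropParams via a Set to avoid duplicates, then each key is deleted from the final config.

const llmConfig = { model: 'gpt-4o-search-preview', temperature: 0.7, user: 'u1' };
const userDropParams = [];                             // e.g. from endpoint config
const searchExcludeParams = ['temperature', 'user'];   // subset, illustrative
const dropParams = [...new Set([...userDropParams, ...searchExcludeParams])];
dropParams.forEach((param) => {
  delete llmConfig[param];
});
// llmConfig => { model: 'gpt-4o-search-preview' }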

@@ -1,40 +0,0 @@
const { AudioHandler } = require('./WebRTCHandler');
const { logger } = require('~/config');
class AudioSocketModule {
constructor(socketIOService) {
this.socketIOService = socketIOService;
this.audioHandler = new AudioHandler();
this.moduleId = 'audio-handler';
this.registerHandlers();
}
registerHandlers() {
this.socketIOService.registerModule(this.moduleId, {
connection: (socket) => this.handleConnection(socket),
disconnect: (socket) => this.handleDisconnect(socket),
});
}
handleConnection(socket) {
// Register WebRTC-specific event handlers for this socket
this.audioHandler.registerSocketHandlers(socket);
logger.debug(`Audio handler registered for client: ${socket.id}`);
}
handleDisconnect(socket) {
// Cleanup audio resources for disconnected client
this.audioHandler.cleanup(socket.id);
logger.debug(`Audio handler cleaned up for client: ${socket.id}`);
}
// Used for app shutdown
cleanup() {
this.audioHandler.cleanupAll();
this.socketIOService.unregisterModule(this.moduleId);
}
}
module.exports = { AudioSocketModule };

View File

@@ -7,78 +7,6 @@ const { getCustomConfig } = require('~/server/services/Config');
const { genAzureEndpoint } = require('~/utils');
const { logger } = require('~/config');
/**
* Maps MIME types to their corresponding file extensions for audio files.
* @type {Object}
*/
const MIME_TO_EXTENSION_MAP = {
// MP4 container formats
'audio/mp4': 'm4a',
'audio/x-m4a': 'm4a',
// Ogg formats
'audio/ogg': 'ogg',
'audio/vorbis': 'ogg',
'application/ogg': 'ogg',
// Wave formats
'audio/wav': 'wav',
'audio/x-wav': 'wav',
'audio/wave': 'wav',
// MP3 formats
'audio/mp3': 'mp3',
'audio/mpeg': 'mp3',
'audio/mpeg3': 'mp3',
// WebM formats
'audio/webm': 'webm',
// Additional formats
'audio/flac': 'flac',
'audio/x-flac': 'flac',
};
/**
* Gets the file extension from the MIME type.
* @param {string} mimeType - The MIME type.
* @returns {string} The file extension.
*/
function getFileExtensionFromMime(mimeType) {
// Default fallback
if (!mimeType) {
return 'webm';
}
// Direct lookup (fastest)
const extension = MIME_TO_EXTENSION_MAP[mimeType];
if (extension) {
return extension;
}
// Try to extract subtype as fallback
const subtype = mimeType.split('/')[1]?.toLowerCase();
// If subtype matches a known extension
if (['mp3', 'mp4', 'ogg', 'wav', 'webm', 'm4a', 'flac'].includes(subtype)) {
return subtype === 'mp4' ? 'm4a' : subtype;
}
// Generic checks for partial matches
if (subtype?.includes('mp4') || subtype?.includes('m4a')) {
return 'm4a';
}
if (subtype?.includes('ogg')) {
return 'ogg';
}
if (subtype?.includes('wav')) {
return 'wav';
}
if (subtype?.includes('mp3') || subtype?.includes('mpeg')) {
return 'mp3';
}
if (subtype?.includes('webm')) {
return 'webm';
}
return 'webm'; // Default fallback
}
/**
* Service class for handling Speech-to-Text (STT) operations.
* @class
@@ -242,10 +170,8 @@ class STTService {
throw new Error('Invalid provider');
}
const fileExtension = getFileExtensionFromMime(audioFile.mimetype);
const audioReadStream = Readable.from(audioBuffer);
audioReadStream.path = `audio.${fileExtension}`;
audioReadStream.path = 'audio.wav';
const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile);

View File
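
Representative lookups for the MIME-to-extension logic above; the expected results follow directly from the map and its fallbacks.

// getFileExtensionFromMime('audio/x-m4a')  => 'm4a'   (direct map hit)
// getFileExtensionFromMime('audio/mp4')    => 'm4a'   (mp4 audio normalized to m4a)
// getFileExtensionFromMime('audio/opus')   => 'webm'  (unknown subtype, default)
// getFileExtensionFromMime(undefined)      => 'webm'  (missing MIME type, default)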

@@ -1,178 +0,0 @@
const { RTCPeerConnection, RTCIceCandidate, MediaStream } = require('wrtc');
const { logger } = require('~/config');
class WebRTCConnection {
constructor(socket, config) {
this.socket = socket;
this.config = config;
this.peerConnection = null;
this.audioTransceiver = null;
this.pendingCandidates = [];
this.state = 'idle';
}
async handleOffer(offer) {
try {
if (!this.peerConnection) {
this.peerConnection = new RTCPeerConnection(this.config.rtcConfig);
this.setupPeerConnectionListeners();
}
await this.peerConnection.setRemoteDescription(offer);
const mediaStream = new MediaStream();
this.audioTransceiver = this.peerConnection.addTransceiver('audio', {
direction: 'sendrecv',
streams: [mediaStream],
});
const answer = await this.peerConnection.createAnswer();
await this.peerConnection.setLocalDescription(answer);
this.socket.emit('webrtc-answer', answer);
} catch (error) {
logger.error(`Error handling offer: ${error}`);
this.socket.emit('webrtc-error', {
message: error.message,
code: 'OFFER_ERROR',
});
}
}
setupPeerConnectionListeners() {
if (!this.peerConnection) {
return;
}
this.peerConnection.ontrack = ({ track }) => {
logger.info(`Received ${track.kind} track from client`);
if (track.kind === 'audio') {
this.handleIncomingAudio(track);
}
track.onended = () => {
logger.info(`${track.kind} track ended`);
};
};
this.peerConnection.onicecandidate = ({ candidate }) => {
if (candidate) {
this.socket.emit('icecandidate', candidate);
}
};
this.peerConnection.onconnectionstatechange = () => {
if (!this.peerConnection) {
return;
}
const state = this.peerConnection.connectionState;
logger.info(`Connection state changed to ${state}`);
if (state === 'failed' || state === 'closed' || state === 'disconnected') {
this.cleanup();
}
};
}
handleIncomingAudio(track) {
if (this.peerConnection) {
const stream = new MediaStream([track]);
this.peerConnection.addTrack(track, stream);
}
}
async addIceCandidate(candidate) {
try {
if (this.peerConnection?.remoteDescription) {
if (candidate && candidate.candidate) {
await this.peerConnection.addIceCandidate(new RTCIceCandidate(candidate));
} else {
logger.warn('Invalid ICE candidate');
}
} else {
this.pendingCandidates.push(candidate);
}
} catch (error) {
logger.error(`Error adding ICE candidate: ${error}`);
}
}
cleanup() {
if (this.peerConnection) {
try {
this.peerConnection.close();
} catch (error) {
logger.error(`Error closing peer connection: ${error}`);
}
this.peerConnection = null;
}
this.audioTransceiver = null;
this.pendingCandidates = [];
this.state = 'idle';
}
}
class AudioHandler {
constructor() {
this.connections = new Map();
this.defaultRTCConfig = {
iceServers: [
{
urls: ['stun:stun.l.google.com:19302', 'stun:stun1.l.google.com:19302'],
},
],
iceCandidatePoolSize: 10,
bundlePolicy: 'max-bundle',
rtcpMuxPolicy: 'require',
};
}
registerSocketHandlers(socket) {
const rtcConfig = {
rtcConfig: this.defaultRTCConfig,
};
const rtcConnection = new WebRTCConnection(socket, rtcConfig);
this.connections.set(socket.id, rtcConnection);
socket.on('webrtc-offer', (offer) => {
logger.debug(`Received WebRTC offer from ${socket.id}`);
rtcConnection.handleOffer(offer);
});
socket.on('icecandidate', (candidate) => {
rtcConnection.addIceCandidate(candidate);
});
socket.on('vad-status', (speaking) => {
logger.debug(`VAD status from ${socket.id}: ${JSON.stringify(speaking)}`);
});
socket.on('disconnect', () => {
rtcConnection.cleanup();
this.connections.delete(socket.id);
});
return rtcConnection;
}
cleanup(socketId) {
const connection = this.connections.get(socketId);
if (connection) {
connection.cleanup();
this.connections.delete(socketId);
}
}
cleanupAll() {
for (const connection of this.connections.values()) {
connection.cleanup();
}
this.connections.clear();
}
}
module.exports = { AudioHandler, WebRTCConnection };

View File
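
A hedged browser-side counterpart to the handler above, assuming a connected socket.io client; the event names ('webrtc-offer', 'webrtc-answer', 'icecandidate') mirror the server code, everything else is illustrative.

// Inside an async setup function, with `socket` already connected:
const pc = new RTCPeerConnection({ iceServers: [{ urls: 'stun:stun.l.google.com:19302' }] });
pc.onicecandidate = ({ candidate }) => candidate && socket.emit('icecandidate', candidate);
socket.on('webrtc-answer', (answer) => pc.setRemoteDescription(answer));
socket.on('icecandidate', (candidate) => pc.addIceCandidate(candidate));
const offer = await pc.createOffer({ offerToReceiveAudio: true });
await pc.setLocalDescription(offer);
socket.emit('webrtc-offer', offer);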

@@ -1,102 +0,0 @@
const { extractEnvVariable, RealtimeVoiceProviders } = require('librechat-data-provider');
const { getCustomConfig } = require('~/server/services/Config');
const { logger } = require('~/config');
class RealtimeService {
constructor(customConfig) {
this.customConfig = customConfig;
this.providerStrategies = {
[RealtimeVoiceProviders.OPENAI]: this.openaiProvider.bind(this),
};
}
static async getInstance() {
const customConfig = await getCustomConfig();
if (!customConfig) {
throw new Error('Custom config not found');
}
return new RealtimeService(customConfig);
}
async getProviderSchema() {
const realtimeSchema = this.customConfig.speech.realtime;
if (!realtimeSchema) {
throw new Error('No Realtime schema is set in config');
}
const providers = Object.entries(realtimeSchema).filter(
([, value]) => Object.keys(value).length > 0,
);
if (providers.length !== 1) {
throw new Error(providers.length > 1 ? 'Multiple providers set' : 'No provider set');
}
return providers[0];
}
async openaiProvider(schema, voice) {
const defaultRealtimeUrl = 'https://api.openai.com/v1/realtime';
const allowedVoices = ['alloy', 'ash', 'ballad', 'coral', 'echo', 'sage', 'shimmer', 'verse'];
if (!voice) {
throw new Error('Voice not specified');
}
if (!allowedVoices.includes(voice)) {
throw new Error(`Invalid voice: ${voice}`);
}
const apiKey = extractEnvVariable(schema.apiKey);
if (!apiKey) {
throw new Error('OpenAI API key not configured');
}
const response = await fetch('https://api.openai.com/v1/realtime/sessions', {
method: 'POST',
headers: {
Authorization: `Bearer ${apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: 'gpt-4o-realtime-preview-2024-12-17',
modalities: ['audio', 'text'],
voice: voice,
}),
});
const token = await response.json();
return {
provider: RealtimeVoiceProviders.OPENAI,
token: token,
url: schema.url || defaultRealtimeUrl,
};
}
async getRealtimeConfig(req, res) {
try {
const [provider, schema] = await this.getProviderSchema();
const strategy = this.providerStrategies[provider];
if (!strategy) {
throw new Error(`Unsupported provider: ${provider}`);
}
const voice = req.query.voice;
const config = await strategy(schema, voice);
res.json(config);
} catch (error) {
logger.error('[RealtimeService] Config generation failed:', error);
res.status(500).json({ error: error.message });
}
}
}
async function getRealtimeConfig(req, res) {
const service = await RealtimeService.getInstance();
await service.getRealtimeConfig(req, res);
}
module.exports = getRealtimeConfig;

View File

@@ -1,5 +1,4 @@
const getCustomConfigSpeech = require('./getCustomConfigSpeech');
const getRealtimeConfig = require('./getRealtimeConfig');
const TTSService = require('./TTSService');
const STTService = require('./STTService');
const getVoices = require('./getVoices');
@@ -7,7 +6,6 @@ const getVoices = require('./getVoices');
module.exports = {
getVoices,
getCustomConfigSpeech,
getRealtimeConfig,
...STTService,
...TTSService,
};

View File

@@ -1,10 +1,4 @@
const {
Time,
CacheKeys,
SEPARATORS,
parseTextParts,
findLastSeparatorIndex,
} = require('librechat-data-provider');
const { CacheKeys, findLastSeparatorIndex, SEPARATORS, Time } = require('librechat-data-provider');
const { getMessage } = require('~/models/Message');
const { getLogStores } = require('~/cache');
@@ -90,11 +84,10 @@ function createChunkProcessor(user, messageId) {
notFoundCount++;
return [];
} else {
const text = message.content?.length > 0 ? parseTextParts(message.content) : message.text;
messageCache.set(
messageId,
{
text,
text: message.text,
complete: true,
},
Time.FIVE_MINUTES,
@@ -102,7 +95,7 @@ function createChunkProcessor(user, messageId) {
}
const text = typeof message === 'string' ? message : message.text;
const complete = typeof message === 'string' ? false : (message.complete ?? true);
const complete = typeof message === 'string' ? false : message.complete ?? true;
if (text === processedText) {
noChangeCount++;

View File
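
For context on the change above: parseTextParts flattens a message's content-part array into plain text, so multi-part messages cache their full text instead of falling back to an empty `text` field (the part shape below is illustrative).

// parseTextParts([{ type: 'text', text: 'Hello ' }, { type: 'text', text: 'world' }])
// => 'Hello world'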

@@ -1,253 +0,0 @@
const fs = require('fs');
const path = require('path');
const mime = require('mime');
const axios = require('axios');
const fetch = require('node-fetch');
const { logger } = require('~/config');
const { getAzureContainerClient } = require('./initialize');
const defaultBasePath = 'images';
const { AZURE_STORAGE_PUBLIC_ACCESS = 'true', AZURE_CONTAINER_NAME = 'files' } = process.env;
/**
* Uploads a buffer to Azure Blob Storage.
*
* Files will be stored at the path: {basePath}/{userId}/{fileName} within the container.
*
* @param {Object} params
* @param {string} params.userId - The user's id.
* @param {Buffer} params.buffer - The buffer to upload.
* @param {string} params.fileName - The name of the file.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the uploaded blob.
*/
async function saveBufferToAzure({
userId,
buffer,
fileName,
basePath = defaultBasePath,
containerName,
}) {
try {
const containerClient = getAzureContainerClient(containerName);
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
// Create the container if it doesn't exist. This is done per operation.
await containerClient.createIfNotExists({ access });
const blobPath = `${basePath}/${userId}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
await blockBlobClient.uploadData(buffer);
return blockBlobClient.url;
} catch (error) {
logger.error('[saveBufferToAzure] Error uploading buffer:', error);
throw error;
}
}
/**
* Saves a file from a URL to Azure Blob Storage.
*
* @param {Object} params
* @param {string} params.userId - The user's id.
* @param {string} params.URL - The URL of the file.
* @param {string} params.fileName - The name of the file.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the uploaded blob.
*/
async function saveURLToAzure({
userId,
URL,
fileName,
basePath = defaultBasePath,
containerName,
}) {
try {
const response = await fetch(URL);
const buffer = await response.buffer();
return await saveBufferToAzure({ userId, buffer, fileName, basePath, containerName });
} catch (error) {
logger.error('[saveURLToAzure] Error uploading file from URL:', error);
throw error;
}
}
/**
* Retrieves a blob URL from Azure Blob Storage.
*
* @param {Object} params
* @param {string} params.fileName - The file name.
* @param {string} [params.basePath='images'] - The base folder used during upload.
* @param {string} [params.userId] - If files are stored in a user-specific directory.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The blob's URL.
*/
async function getAzureURL({ fileName, basePath = defaultBasePath, userId, containerName }) {
try {
const containerClient = getAzureContainerClient(containerName);
const blobPath = userId ? `${basePath}/${userId}/${fileName}` : `${basePath}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
return blockBlobClient.url;
} catch (error) {
logger.error('[getAzureURL] Error retrieving blob URL:', error);
throw error;
}
}
/**
* Deletes a blob from Azure Blob Storage.
*
* @param {Object} params
* @param {ServerRequest} params.req - The Express request object.
* @param {MongoFile} params.file - The file object.
*/
async function deleteFileFromAzure(req, file) {
try {
const containerClient = getAzureContainerClient(AZURE_CONTAINER_NAME);
const blobPath = file.filepath.split(`${AZURE_CONTAINER_NAME}/`)[1];
if (!blobPath.includes(req.user.id)) {
throw new Error('User ID not found in blob path');
}
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
await blockBlobClient.delete();
logger.debug('[deleteFileFromAzure] Blob deleted successfully from Azure Blob Storage');
} catch (error) {
logger.error('[deleteFileFromAzure] Error deleting blob:', error);
if (error.statusCode === 404) {
return;
}
throw error;
}
}
/**
* Streams a file from disk directly to Azure Blob Storage without loading
* the entire file into memory.
*
* @param {Object} params
* @param {string} params.userId - The user's id.
* @param {string} params.filePath - The local file path to upload.
* @param {string} params.fileName - The name of the file in Azure.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the uploaded blob.
*/
async function streamFileToAzure({
userId,
filePath,
fileName,
basePath = defaultBasePath,
containerName,
}) {
try {
const containerClient = getAzureContainerClient(containerName);
const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
// Create the container if it doesn't exist
await containerClient.createIfNotExists({ access });
const blobPath = `${basePath}/${userId}/${fileName}`;
const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
// Get file size for proper content length
const stats = await fs.promises.stat(filePath);
// Create read stream from the file
const fileStream = fs.createReadStream(filePath);
const blobContentType = mime.getType(fileName);
await blockBlobClient.uploadStream(
fileStream,
undefined, // Use default concurrency (5)
undefined, // Use default buffer size (8MB)
{
blobHTTPHeaders: {
blobContentType,
},
onProgress: (progress) => {
logger.debug(
`[streamFileToAzure] Upload progress: ${progress.loadedBytes} bytes of ${stats.size}`,
);
},
},
);
return blockBlobClient.url;
} catch (error) {
logger.error('[streamFileToAzure] Error streaming file:', error);
throw error;
}
}
/**
* Uploads a file from the local file system to Azure Blob Storage.
*
* This function reads the file from disk and then uploads it to Azure Blob Storage
* at the path: {basePath}/{userId}/{fileName}.
*
* @param {Object} params
* @param {object} params.req - The Express request object.
* @param {Express.Multer.File} params.file - The file object.
* @param {string} params.file_id - The file id.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<{ filepath: string, bytes: number }>} An object containing the blob URL and its byte size.
*/
async function uploadFileToAzure({
req,
file,
file_id,
basePath = defaultBasePath,
containerName,
}) {
try {
const inputFilePath = file.path;
const stats = await fs.promises.stat(inputFilePath);
const bytes = stats.size;
const userId = req.user.id;
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const fileURL = await streamFileToAzure({
userId,
filePath: inputFilePath,
fileName,
basePath,
containerName,
});
return { filepath: fileURL, bytes };
} catch (error) {
logger.error('[uploadFileToAzure] Error uploading file:', error);
throw error;
}
}
/**
* Retrieves a readable stream for a blob from Azure Blob Storage.
*
* @param {object} _req - The Express request object.
* @param {string} fileURL - The URL of the blob.
* @returns {Promise<ReadableStream>} A readable stream of the blob.
*/
async function getAzureFileStream(_req, fileURL) {
try {
const response = await axios({
method: 'get',
url: fileURL,
responseType: 'stream',
});
return response.data;
} catch (error) {
logger.error('[getAzureFileStream] Error getting blob stream:', error);
throw error;
}
}
module.exports = {
saveBufferToAzure,
saveURLToAzure,
getAzureURL,
deleteFileFromAzure,
uploadFileToAzure,
getAzureFileStream,
};

View File
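
A minimal usage sketch for the Azure helpers above; the require path and all values are assumptions for illustration.

const { saveBufferToAzure, getAzureURL } = require('~/server/services/Files/Azure'); // assumed index path

async function azureExample() {
  // Stored at {container}/images/user123/hello.txt
  const url = await saveBufferToAzure({
    userId: 'user123',
    buffer: Buffer.from('hello'),
    fileName: 'hello.txt',
  });
  // The blob path is deterministic, so the URL can be re-derived later:
  const sameUrl = await getAzureURL({ userId: 'user123', fileName: 'hello.txt' });
  return url === sameUrl;
}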

@@ -1,124 +0,0 @@
const fs = require('fs');
const path = require('path');
const sharp = require('sharp');
const { resizeImageBuffer } = require('../images/resize');
const { updateUser } = require('~/models/userMethods');
const { updateFile } = require('~/models/File');
const { logger } = require('~/config');
const { saveBufferToAzure } = require('./crud');
/**
* Uploads an image file to Azure Blob Storage.
* It resizes and converts the image similar to your Firebase implementation.
*
* @param {Object} params
* @param {object} params.req - The Express request object.
* @param {Express.Multer.File} params.file - The file object.
* @param {string} params.file_id - The file id.
* @param {EModelEndpoint} params.endpoint - The endpoint parameters.
* @param {string} [params.resolution='high'] - The image resolution.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>}
*/
async function uploadImageToAzure({
req,
file,
file_id,
endpoint,
resolution = 'high',
basePath = 'images',
containerName,
}) {
try {
const inputFilePath = file.path;
const inputBuffer = await fs.promises.readFile(inputFilePath);
const {
buffer: resizedBuffer,
width,
height,
} = await resizeImageBuffer(inputBuffer, resolution, endpoint);
const extension = path.extname(inputFilePath);
const userId = req.user.id;
let webPBuffer;
let fileName = `${file_id}__${path.basename(inputFilePath)}`;
const targetExtension = `.${req.app.locals.imageOutputType}`;
if (extension.toLowerCase() === targetExtension) {
webPBuffer = resizedBuffer;
} else {
webPBuffer = await sharp(resizedBuffer).toFormat(req.app.locals.imageOutputType).toBuffer();
const extRegExp = new RegExp(path.extname(fileName) + '$');
fileName = fileName.replace(extRegExp, targetExtension);
if (!path.extname(fileName)) {
fileName += targetExtension;
}
}
const downloadURL = await saveBufferToAzure({
userId,
buffer: webPBuffer,
fileName,
basePath,
containerName,
});
await fs.promises.unlink(inputFilePath);
const bytes = Buffer.byteLength(webPBuffer);
return { filepath: downloadURL, bytes, width, height };
} catch (error) {
logger.error('[uploadImageToAzure] Error uploading image:', error);
throw error;
}
}
/**
* Prepares the image URL and updates the file record.
*
* @param {object} req - The Express request object.
* @param {MongoFile} file - The file object.
* @returns {Promise<[MongoFile, string]>}
*/
async function prepareAzureImageURL(req, file) {
const { filepath } = file;
const promises = [];
promises.push(updateFile({ file_id: file.file_id }));
promises.push(filepath);
return await Promise.all(promises);
}
/**
* Uploads and processes a user's avatar to Azure Blob Storage.
*
* @param {Object} params
* @param {Buffer} params.buffer - The avatar image buffer.
* @param {string} params.userId - The user's id.
* @param {string} params.manual - Flag to indicate manual update.
* @param {string} [params.basePath='images'] - The base folder within the container.
* @param {string} [params.containerName] - The Azure Blob container name.
* @returns {Promise<string>} The URL of the avatar.
*/
async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) {
try {
const downloadURL = await saveBufferToAzure({
userId,
buffer,
fileName: 'avatar.png',
basePath,
containerName,
});
const isManual = manual === 'true';
const url = `${downloadURL}?manual=${isManual}`;
if (isManual) {
await updateUser(userId, { avatar: url });
}
return url;
} catch (error) {
logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error);
throw error;
}
}
module.exports = {
uploadImageToAzure,
prepareAzureImageURL,
processAzureAvatar,
};

View File

@@ -1,9 +0,0 @@
const crud = require('./crud');
const images = require('./images');
const initialize = require('./initialize');
module.exports = {
...crud,
...images,
...initialize,
};

View File

@@ -1,55 +0,0 @@
const { BlobServiceClient } = require('@azure/storage-blob');
const { logger } = require('~/config');
let blobServiceClient = null;
let azureWarningLogged = false;
/**
* Initializes the Azure Blob Service client.
* This function establishes a connection by checking if a connection string is provided.
* If available, the connection string is used; otherwise, Managed Identity (via DefaultAzureCredential) is utilized.
* Note: Container creation (and its public access settings) is handled later in the CRUD functions.
* @returns {BlobServiceClient|null} The initialized client, or null if the required configuration is missing.
*/
const initializeAzureBlobService = () => {
if (blobServiceClient) {
return blobServiceClient;
}
const connectionString = process.env.AZURE_STORAGE_CONNECTION_STRING;
if (connectionString) {
blobServiceClient = BlobServiceClient.fromConnectionString(connectionString);
logger.info('Azure Blob Service initialized using connection string');
} else {
const { DefaultAzureCredential } = require('@azure/identity');
const accountName = process.env.AZURE_STORAGE_ACCOUNT_NAME;
if (!accountName) {
if (!azureWarningLogged) {
logger.error(
'[initializeAzureBlobService] Azure Blob Service not initialized. Connection string missing and AZURE_STORAGE_ACCOUNT_NAME not provided.',
);
azureWarningLogged = true;
}
return null;
}
const url = `https://${accountName}.blob.core.windows.net`;
const credential = new DefaultAzureCredential();
blobServiceClient = new BlobServiceClient(url, credential);
logger.info('Azure Blob Service initialized using Managed Identity');
}
return blobServiceClient;
};
/**
* Retrieves the Azure ContainerClient for the given container name.
* @param {string} [containerName=process.env.AZURE_CONTAINER_NAME || 'files'] - The container name.
* @returns {ContainerClient|null} The Azure ContainerClient.
*/
const getAzureContainerClient = (containerName = process.env.AZURE_CONTAINER_NAME || 'files') => {
const serviceClient = initializeAzureBlobService();
return serviceClient ? serviceClient.getContainerClient(containerName) : null;
};
module.exports = {
initializeAzureBlobService,
getAzureContainerClient,
};

View File
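
Environment assumptions for the two initialization paths above, plus a usage line (the account and container names are illustrative):

// 1) Connection string:  AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;...
// 2) Managed Identity:   no connection string, AZURE_STORAGE_ACCOUNT_NAME=myaccount,
//    credentials resolved by DefaultAzureCredential (env vars, managed identity, CLI login).
const defaultClient = getAzureContainerClient();        // AZURE_CONTAINER_NAME or 'files'
const custom = getAzureContainerClient('user-uploads'); // explicit container name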

@@ -1,10 +1,8 @@
const axios = require('axios');
const FormData = require('form-data');
const { getCodeBaseURL } = require('@librechat/agents');
const { createAxiosInstance } = require('~/config');
const { logAxiosError } = require('~/utils');
const axios = createAxiosInstance();
const MAX_FILE_SIZE = 150 * 1024 * 1024;
/**
@@ -29,15 +27,21 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) {
timeout: 15000,
};
if (process.env.PROXY) {
options.proxy = {
host: process.env.PROXY,
protocol: process.env.PROXY.startsWith('https') ? 'https' : 'http',
};
}
const response = await axios(options);
return response;
} catch (error) {
throw new Error(
logAxiosError({
message: `Error downloading code environment file stream: ${error.message}`,
error,
}),
);
logAxiosError({
message: `Error downloading code environment file stream: ${error.message}`,
error,
});
throw new Error(`Error downloading file: ${error.message}`);
}
}
@@ -75,6 +79,13 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
maxBodyLength: MAX_FILE_SIZE,
};
if (process.env.PROXY) {
options.proxy = {
host: process.env.PROXY,
protocol: process.env.PROXY.startsWith('https') ? 'https' : 'http',
};
}
const response = await axios.post(`${baseURL}/upload`, form, options);
/** @type {{ message: string; session_id: string; files: Array<{ fileId: string; filename: string }> }} */
@@ -90,12 +101,11 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = ''
return `${fileIdentifier}?entity_id=${entity_id}`;
} catch (error) {
throw new Error(
logAxiosError({
message: `Error uploading code environment file: ${error.message}`,
error,
}),
);
logAxiosError({
message: `Error uploading code environment file: ${error.message}`,
error,
});
throw new Error(`Error uploading code environment file: ${error.message}`);
}
}

View File
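
The PROXY handling above maps a single env value onto axios's proxy option; a sketch of the resulting request options (proxy URL illustrative):

process.env.PROXY = 'http://proxy.internal:3128';
const options = { method: 'GET', url: 'https://example.com/download', timeout: 15000 };
if (process.env.PROXY) {
  options.proxy = {
    host: process.env.PROXY,
    protocol: process.env.PROXY.startsWith('https') ? 'https' : 'http',
  };
}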

@@ -1,206 +0,0 @@
// ~/server/services/Files/MistralOCR/crud.js
const fs = require('fs');
const path = require('path');
const FormData = require('form-data');
const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { logger, createAxiosInstance } = require('~/config');
const { logAxiosError } = require('~/utils/axios');
const axios = createAxiosInstance();
/**
* Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory
*
* @param {Object} params Upload parameters
* @param {string} params.filePath The path to the file on disk
* @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath)
* @param {string} params.apiKey Mistral API key
* @param {string} [params.baseURL='https://api.mistral.ai/v1'] Mistral API base URL
* @returns {Promise<Object>} The response from Mistral API
*/
async function uploadDocumentToMistral({
filePath,
fileName = '',
apiKey,
baseURL = 'https://api.mistral.ai/v1',
}) {
const form = new FormData();
form.append('purpose', 'ocr');
const actualFileName = fileName || path.basename(filePath);
const fileStream = fs.createReadStream(filePath);
form.append('file', fileStream, { filename: actualFileName });
return axios
.post(`${baseURL}/files`, form, {
headers: {
Authorization: `Bearer ${apiKey}`,
...form.getHeaders(),
},
maxBodyLength: Infinity,
maxContentLength: Infinity,
})
.then((res) => res.data)
.catch((error) => {
logger.error('Error uploading document to Mistral:', error.message);
throw error;
});
}
async function getSignedUrl({
apiKey,
fileId,
expiry = 24,
baseURL = 'https://api.mistral.ai/v1',
}) {
return axios
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
})
.then((res) => res.data)
.catch((error) => {
logger.error('Error fetching signed URL:', error.message);
throw error;
});
}
/**
* @param {Object} params
* @param {string} params.apiKey
* @param {string} params.documentUrl
* @param {string} [params.baseURL]
* @returns {Promise<OCRResult>}
*/
async function performOCR({
apiKey,
documentUrl,
model = 'mistral-ocr-latest',
baseURL = 'https://api.mistral.ai/v1',
}) {
return axios
.post(
`${baseURL}/ocr`,
{
model,
include_image_base64: false,
document: {
type: 'document_url',
document_url: documentUrl,
},
},
{
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
},
},
)
.then((res) => res.data)
.catch((error) => {
logger.error('Error performing OCR:', error.message);
throw error;
});
}
function extractVariableName(str) {
const match = str.match(envVarRegex);
return match ? match[1] : null;
}
const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => {
try {
/** @type {TCustomConfig['ocr']} */
const ocrConfig = req.app.locals?.ocr;
const apiKeyConfig = ocrConfig.apiKey || '';
const baseURLConfig = ocrConfig.baseURL || '';
const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig);
const isBaseURLEnvVar = envVarRegex.test(baseURLConfig);
const isApiKeyEmpty = !apiKeyConfig.trim();
const isBaseURLEmpty = !baseURLConfig.trim();
let apiKey, baseURL;
if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) {
const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY';
const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL';
const authValues = await loadAuthValues({
userId: req.user.id,
authFields: [baseURLVarName, apiKeyVarName],
optional: new Set([baseURLVarName]),
});
apiKey = authValues[apiKeyVarName];
baseURL = authValues[baseURLVarName];
} else {
apiKey = apiKeyConfig;
baseURL = baseURLConfig;
}
const mistralFile = await uploadDocumentToMistral({
filePath: file.path,
fileName: file.originalname,
apiKey,
baseURL,
});
const modelConfig = ocrConfig.mistralModel || '';
const model = envVarRegex.test(modelConfig)
? extractEnvVariable(modelConfig)
: modelConfig.trim() || 'mistral-ocr-latest';
const signedUrlResponse = await getSignedUrl({
apiKey,
baseURL,
fileId: mistralFile.id,
});
const ocrResult = await performOCR({
apiKey,
baseURL,
model,
documentUrl: signedUrlResponse.url,
});
let aggregatedText = '';
const images = [];
ocrResult.pages.forEach((page, index) => {
if (ocrResult.pages.length > 1) {
aggregatedText += `# PAGE ${index + 1}\n`;
}
aggregatedText += page.markdown + '\n\n';
if (page.images && page.images.length > 0) {
page.images.forEach((image) => {
if (image.image_base64) {
images.push(image.image_base64);
}
});
}
});
return {
filename: file.originalname,
bytes: aggregatedText.length * 4,
filepath: FileSources.mistral_ocr,
text: aggregatedText,
images,
};
} catch (error) {
const message = 'Error uploading document to Mistral OCR API';
throw new Error(logAxiosError({ error, message }));
}
};
module.exports = {
uploadDocumentToMistral,
uploadMistralOCR,
getSignedUrl,
performOCR,
};

View File
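
An end-to-end sketch of the three-step flow above using the exported helpers; the key source and file path are illustrative.

async function ocrSketch() {
  const apiKey = process.env.OCR_API_KEY; // assumption: key supplied via env
  const uploaded = await uploadDocumentToMistral({ filePath: '/tmp/doc.pdf', apiKey });
  const { url } = await getSignedUrl({ apiKey, fileId: uploaded.id });
  const result = await performOCR({ apiKey, documentUrl: url });
  return result.pages.map((page) => page.markdown).join('\n\n');
}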

@@ -1,731 +0,0 @@
const fs = require('fs');
const mockAxios = {
interceptors: {
request: { use: jest.fn(), eject: jest.fn() },
response: { use: jest.fn(), eject: jest.fn() },
},
create: jest.fn().mockReturnValue({
defaults: {
proxy: null,
},
get: jest.fn().mockResolvedValue({ data: {} }),
post: jest.fn().mockResolvedValue({ data: {} }),
put: jest.fn().mockResolvedValue({ data: {} }),
delete: jest.fn().mockResolvedValue({ data: {} }),
}),
get: jest.fn().mockResolvedValue({ data: {} }),
post: jest.fn().mockResolvedValue({ data: {} }),
put: jest.fn().mockResolvedValue({ data: {} }),
delete: jest.fn().mockResolvedValue({ data: {} }),
reset: jest.fn().mockImplementation(function () {
this.get.mockClear();
this.post.mockClear();
this.put.mockClear();
this.delete.mockClear();
this.create.mockClear();
}),
};
jest.mock('axios', () => mockAxios);
jest.mock('fs');
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
},
createAxiosInstance: () => mockAxios,
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud');
describe('MistralOCR Service', () => {
afterEach(() => {
mockAxios.reset();
jest.clearAllMocks();
});
describe('uploadDocumentToMistral', () => {
beforeEach(() => {
// Create a more complete mock for file streams that FormData can work with
const mockReadStream = {
on: jest.fn().mockImplementation(function (event, handler) {
// Simulate immediate 'end' event to make FormData complete processing
if (event === 'end') {
handler();
}
return this;
}),
pipe: jest.fn().mockImplementation(function () {
return this;
}),
pause: jest.fn(),
resume: jest.fn(),
emit: jest.fn(),
once: jest.fn(),
destroy: jest.fn(),
};
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
// Mock FormData's append to avoid actual stream processing
jest.mock('form-data', () => {
const mockFormData = function () {
return {
append: jest.fn(),
getHeaders: jest
.fn()
.mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }),
getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')),
getLength: jest.fn().mockReturnValue(100),
};
};
return mockFormData;
});
});
it('should upload a document to Mistral API using file streaming', async () => {
const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } };
mockAxios.post.mockResolvedValueOnce(mockResponse);
const result = await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
// Check that createReadStream was called with the correct file path
expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf');
// Since we're mocking FormData, we'll just check that axios was called correctly
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer test-api-key',
}),
maxBodyLength: Infinity,
maxContentLength: Infinity,
}),
);
expect(result).toEqual(mockResponse.data);
});
it('should handle errors during document upload', async () => {
const errorMessage = 'API error';
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
await expect(
uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
}),
).rejects.toThrow();
const { logger } = require('~/config');
expect(logger.error).toHaveBeenCalledWith(
expect.stringContaining('Error uploading document to Mistral:'),
expect.any(String),
);
});
});
describe('getSignedUrl', () => {
it('should fetch signed URL from Mistral API', async () => {
const mockResponse = { data: { url: 'https://document-url.com' } };
mockAxios.get.mockResolvedValueOnce(mockResponse);
const result = await getSignedUrl({
fileId: 'file-123',
apiKey: 'test-api-key',
});
expect(mockAxios.get).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
{
headers: {
Authorization: 'Bearer test-api-key',
},
},
);
expect(result).toEqual(mockResponse.data);
});
it('should handle errors when fetching signed URL', async () => {
const errorMessage = 'API error';
mockAxios.get.mockRejectedValueOnce(new Error(errorMessage));
await expect(
getSignedUrl({
fileId: 'file-123',
apiKey: 'test-api-key',
}),
).rejects.toThrow();
const { logger } = require('~/config');
expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage);
});
});
describe('performOCR', () => {
it('should perform OCR using Mistral API', async () => {
const mockResponse = {
data: {
pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }],
},
};
mockAxios.post.mockResolvedValueOnce(mockResponse);
const result = await performOCR({
apiKey: 'test-api-key',
documentUrl: 'https://document-url.com',
model: 'mistral-ocr-latest',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
{
model: 'mistral-ocr-latest',
include_image_base64: false,
document: {
type: 'document_url',
document_url: 'https://document-url.com',
},
},
{
headers: {
'Content-Type': 'application/json',
Authorization: 'Bearer test-api-key',
},
},
);
expect(result).toEqual(mockResponse.data);
});
it('should handle errors during OCR processing', async () => {
const errorMessage = 'OCR processing error';
mockAxios.post.mockRejectedValueOnce(new Error(errorMessage));
await expect(
performOCR({
apiKey: 'test-api-key',
documentUrl: 'https://document-url.com',
}),
).rejects.toThrow();
const { logger } = require('~/config');
expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage);
});
});
describe('uploadMistralOCR', () => {
beforeEach(() => {
const mockReadStream = {
on: jest.fn().mockImplementation(function (event, handler) {
if (event === 'end') {
handler();
}
return this;
}),
pipe: jest.fn().mockImplementation(function () {
return this;
}),
pause: jest.fn(),
resume: jest.fn(),
emit: jest.fn(),
once: jest.fn(),
destroy: jest.fn(),
};
fs.createReadStream = jest.fn().mockReturnValue(mockReadStream);
});
it('should process OCR for a file with standard configuration', async () => {
// Setup mocks
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'test-api-key',
OCR_BASEURL: 'https://api.mistral.ai/v1',
});
// Mock file upload response
mockAxios.post.mockResolvedValueOnce({
data: { id: 'file-123', purpose: 'ocr' },
});
// Mock signed URL response
mockAxios.get.mockResolvedValueOnce({
data: { url: 'https://signed-url.com' },
});
// Mock OCR response with text and images
mockAxios.post.mockResolvedValueOnce({
data: {
pages: [
{
markdown: 'Page 1 content',
images: [{ image_base64: 'base64image1' }],
},
{
markdown: 'Page 2 content',
images: [{ image_base64: 'base64image2' }],
},
],
},
});
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
// Use environment variable syntax to ensure loadAuthValues is called
apiKey: '${OCR_API_KEY}',
baseURL: '${OCR_BASEURL}',
mistralModel: 'mistral-medium',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'document.pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user123',
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
optional: expect.any(Set),
});
// Verify OCR result
expect(result).toEqual({
filename: 'document.pdf',
bytes: expect.any(Number),
filepath: 'mistral_ocr',
text: expect.stringContaining('# PAGE 1'),
images: ['base64image1', 'base64image2'],
});
});
it('should process variable references in configuration', async () => {
// Setup mocks with environment variables
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
CUSTOM_API_KEY: 'custom-api-key',
CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1',
});
// Mock API responses
mockAxios.post.mockResolvedValueOnce({
data: { id: 'file-123', purpose: 'ocr' },
});
mockAxios.get.mockResolvedValueOnce({
data: { url: 'https://signed-url.com' },
});
mockAxios.post.mockResolvedValueOnce({
data: {
pages: [{ markdown: 'Content from custom API' }],
},
});
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
apiKey: '${CUSTOM_API_KEY}',
baseURL: '${CUSTOM_BASEURL}',
mistralModel: '${CUSTOM_MODEL}',
},
},
},
};
// Set environment variable for model
process.env.CUSTOM_MODEL = 'mistral-large';
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'document.pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Verify that custom environment variables were extracted and used
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user123',
authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'],
optional: expect.any(Set),
});
// Check that mistral-large was used in the OCR API call
expect(mockAxios.post).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({
model: 'mistral-large',
}),
expect.anything(),
);
expect(result.text).toEqual('Content from custom API\n\n');
});
it('should fall back to default values when variables are not properly formatted', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'default-api-key',
OCR_BASEURL: undefined, // Testing optional parameter
});
mockAxios.post.mockResolvedValueOnce({
data: { id: 'file-123', purpose: 'ocr' },
});
mockAxios.get.mockResolvedValueOnce({
data: { url: 'https://signed-url.com' },
});
mockAxios.post.mockResolvedValueOnce({
data: {
pages: [{ markdown: 'Default API result' }],
},
});
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
// Use environment variable syntax to ensure loadAuthValues is called
apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name
baseURL: '${OCR_BASEURL}', // Using valid env var format
mistralModel: 'mistral-ocr-latest', // Plain string value
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'document.pdf',
};
await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Should use the default values
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user123',
authFields: ['OCR_BASEURL', 'INVALID_FORMAT'],
optional: expect.any(Set),
});
// Should use the default model when not using environment variable format
expect(mockAxios.post).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({
model: 'mistral-ocr-latest',
}),
expect.anything(),
);
});
it('should handle API errors during OCR process', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'test-api-key',
});
// Mock file upload to fail
mockAxios.post.mockRejectedValueOnce(new Error('Upload failed'));
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
apiKey: 'OCR_API_KEY',
baseURL: 'OCR_BASEURL',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'document.pdf',
};
await expect(
uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
}),
).rejects.toThrow('Error uploading document to Mistral OCR API');
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
});
it('should handle single page documents without page numbering', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'test-api-key',
OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included
});
// Clear all previous mocks
mockAxios.post.mockClear();
mockAxios.get.mockClear();
// 1. First mock: File upload response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
);
// 2. Second mock: Signed URL response
mockAxios.get.mockImplementationOnce(() =>
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
);
// 3. Third mock: OCR response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({
data: {
pages: [{ markdown: 'Single page content' }],
},
}),
);
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
apiKey: 'OCR_API_KEY',
baseURL: 'OCR_BASEURL',
mistralModel: 'mistral-ocr-latest',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'single-page.pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Verify that single page documents don't include page numbering
expect(result.text).not.toContain('# PAGE');
expect(result.text).toEqual('Single page content\n\n');
});
it('should use literal values in configuration when provided directly', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
// We'll still mock this but it should not be used for literal values
loadAuthValues.mockResolvedValue({});
// Clear all previous mocks
mockAxios.post.mockClear();
mockAxios.get.mockClear();
// 1. First mock: File upload response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
);
// 2. Second mock: Signed URL response
mockAxios.get.mockImplementationOnce(() =>
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
);
// 3. Third mock: OCR response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({
data: {
pages: [{ markdown: 'Processed with literal config values' }],
},
}),
);
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
// Direct values that should be used as-is, without variable substitution
apiKey: 'actual-api-key-value',
baseURL: 'https://direct-api-url.mistral.ai/v1',
mistralModel: 'mistral-direct-model',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'direct-values.pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Verify the correct URL was used with the direct baseURL value
expect(mockAxios.post).toHaveBeenCalledWith(
'https://direct-api-url.mistral.ai/v1/files',
expect.any(Object),
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer actual-api-key-value',
}),
}),
);
// Check the OCR call was made with the direct model value
expect(mockAxios.post).toHaveBeenCalledWith(
'https://direct-api-url.mistral.ai/v1/ocr',
expect.objectContaining({
model: 'mistral-direct-model',
}),
expect.any(Object),
);
// Verify the result
expect(result.text).toEqual('Processed with literal config values\n\n');
// Verify loadAuthValues was never called since we used direct values
expect(loadAuthValues).not.toHaveBeenCalled();
});
it('should handle empty configuration values and use defaults', async () => {
const { loadAuthValues } = require('~/server/services/Tools/credentials');
// Set up the mock values to be returned by loadAuthValues
loadAuthValues.mockResolvedValue({
OCR_API_KEY: 'default-from-env-key',
OCR_BASEURL: 'https://default-from-env.mistral.ai/v1',
});
// Clear all previous mocks
mockAxios.post.mockClear();
mockAxios.get.mockClear();
// 1. First mock: File upload response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }),
);
// 2. Second mock: Signed URL response
mockAxios.get.mockImplementationOnce(() =>
Promise.resolve({ data: { url: 'https://signed-url.com' } }),
);
// 3. Third mock: OCR response
mockAxios.post.mockImplementationOnce(() =>
Promise.resolve({
data: {
pages: [{ markdown: 'Content from default configuration' }],
},
}),
);
const req = {
user: { id: 'user123' },
app: {
locals: {
ocr: {
// Empty string values - should fall back to defaults
apiKey: '',
baseURL: '',
mistralModel: '',
},
},
},
};
const file = {
path: '/tmp/upload/file.pdf',
originalname: 'empty-config.pdf',
};
const result = await uploadMistralOCR({
req,
file,
file_id: 'file123',
entity_id: 'entity123',
});
expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf');
// Verify loadAuthValues was called with the default variable names
expect(loadAuthValues).toHaveBeenCalledWith({
userId: 'user123',
authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
optional: expect.any(Set),
});
// Verify the API calls used the default values from loadAuthValues
expect(mockAxios.post).toHaveBeenCalledWith(
'https://default-from-env.mistral.ai/v1/files',
expect.any(Object),
expect.objectContaining({
headers: expect.objectContaining({
Authorization: 'Bearer default-from-env-key',
}),
}),
);
// Verify the OCR model defaulted to mistral-ocr-latest
expect(mockAxios.post).toHaveBeenCalledWith(
'https://default-from-env.mistral.ai/v1/ocr',
expect.objectContaining({
model: 'mistral-ocr-latest',
}),
expect.any(Object),
);
// Check result
expect(result.text).toEqual('Content from default configuration\n\n');
});
});
});
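The two tests above pin down the config-resolution contract: literal (non-empty) values in req.app.locals.ocr are used verbatim and no credential lookup happens, while empty strings fall back to environment-backed credentials and the mistral-ocr-latest default model. A minimal sketch of that resolution logic, assuming a hypothetical helper name (resolveOcrConfig) rather than the service's actual implementation:

// Sketch of the fallback behavior the tests assert; `resolveOcrConfig` is a
// hypothetical name, not necessarily a function in the real OCR service.
async function resolveOcrConfig(req, loadAuthValues) {
  const ocr = req.app?.locals?.ocr ?? {};
  // Literal config values win; no credential lookup is performed (test 1).
  if (ocr.apiKey && ocr.baseURL) {
    return {
      apiKey: ocr.apiKey,
      baseURL: ocr.baseURL,
      model: ocr.mistralModel || 'mistral-ocr-latest',
    };
  }
  // Empty values fall back to environment-backed credentials (test 2).
  const auth = await loadAuthValues({
    userId: req.user.id,
    authFields: ['OCR_BASEURL', 'OCR_API_KEY'],
    // The test only asserts that a Set is passed; its contents here are an assumption.
    optional: new Set(['OCR_BASEURL']),
  });
  return {
    apiKey: ocr.apiKey || auth.OCR_API_KEY,
    baseURL: ocr.baseURL || auth.OCR_BASEURL,
    model: ocr.mistralModel || 'mistral-ocr-latest',
  };
}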

View File

@@ -1,5 +0,0 @@
const crud = require('./crud');
module.exports = {
...crud,
};

View File

@@ -1,445 +0,0 @@
const fs = require('fs');
const path = require('path');
const fetch = require('node-fetch');
const { FileSources } = require('librechat-data-provider');
const {
PutObjectCommand,
GetObjectCommand,
HeadObjectCommand,
DeleteObjectCommand,
} = require('@aws-sdk/client-s3');
const { getSignedUrl } = require('@aws-sdk/s3-request-presigner');
const { initializeS3 } = require('./initialize');
const { logger } = require('~/config');
const bucketName = process.env.AWS_BUCKET_NAME;
const defaultBasePath = 'images';
let s3UrlExpirySeconds = 7 * 24 * 60 * 60;
if (process.env.S3_URL_EXPIRY_SECONDS !== undefined) {
const parsed = parseInt(process.env.S3_URL_EXPIRY_SECONDS, 10);
if (!isNaN(parsed) && parsed > 0) {
s3UrlExpirySeconds = Math.min(parsed, 7 * 24 * 60 * 60);
} else {
logger.warn(
`[S3] Invalid S3_URL_EXPIRY_SECONDS value: "${process.env.S3_URL_EXPIRY_SECONDS}". Using 7-day expiry.`,
);
}
}
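// For example (illustrative values): S3_URL_EXPIRY_SECONDS=3600 yields 1-hour links,
// while 999999999 exceeds the cap and is clamped to 604800 seconds (7 days).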
/**
* Constructs the S3 key based on the base path, user ID, and file name.
*/
const getS3Key = (basePath, userId, fileName) => `${basePath}/${userId}/${fileName}`;
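// e.g. getS3Key('images', 'user123', 'abc__file.png') -> 'images/user123/abc__file.png'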
/**
* Uploads a buffer to S3 and returns a signed URL.
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {Buffer} params.buffer - The buffer containing file data.
* @param {string} params.fileName - The file name to use in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<string>} Signed URL of the uploaded file.
*/
async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBasePath }) {
const key = getS3Key(basePath, userId, fileName);
const params = { Bucket: bucketName, Key: key, Body: buffer };
try {
const s3 = initializeS3();
await s3.send(new PutObjectCommand(params));
return await getS3URL({ userId, fileName, basePath });
} catch (error) {
logger.error('[saveBufferToS3] Error uploading buffer to S3:', error.message);
throw error;
}
}
/**
* Retrieves a URL for a file stored in S3.
* Returns a signed URL with expiration time or a proxy URL based on config
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {string} params.fileName - The file name in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<string>} A URL to access the S3 object
*/
async function getS3URL({ userId, fileName, basePath = defaultBasePath }) {
const key = getS3Key(basePath, userId, fileName);
const params = { Bucket: bucketName, Key: key };
try {
const s3 = initializeS3();
return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: s3UrlExpirySeconds });
} catch (error) {
logger.error('[getS3URL] Error getting signed URL from S3:', error.message);
throw error;
}
}
/**
* Saves a file from a given URL to S3.
*
* @param {Object} params
* @param {string} params.userId - The user's unique identifier.
* @param {string} params.URL - The source URL of the file.
* @param {string} params.fileName - The file name to use in S3.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<string>} Signed URL of the uploaded file.
*/
async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }) {
try {
const response = await fetch(URL);
const buffer = await response.buffer();
// Optionally you can call getBufferMetadata(buffer) if needed.
return await saveBufferToS3({ userId, buffer, fileName, basePath });
} catch (error) {
logger.error('[saveURLToS3] Error uploading file from URL to S3:', error.message);
throw error;
}
}
/**
* Deletes a file from S3.
*
* @param {Object} params
* @param {ServerRequest} params.req
* @param {MongoFile} params.file - The file object to delete.
* @returns {Promise<void>}
*/
async function deleteFileFromS3(req, file) {
const key = extractKeyFromS3Url(file.filepath);
const params = { Bucket: bucketName, Key: key };
if (!key.includes(req.user.id)) {
const message = `[deleteFileFromS3] User ID mismatch: ${req.user.id} vs ${key}`;
logger.error(message);
throw new Error(message);
}
try {
const s3 = initializeS3();
try {
const headCommand = new HeadObjectCommand(params);
await s3.send(headCommand);
logger.debug('[deleteFileFromS3] File exists, proceeding with deletion');
    } catch (headErr) {
      if (headErr.name === 'NotFound') {
        logger.warn(`[deleteFileFromS3] File does not exist: ${key}`);
        return;
      }
      // Any other HEAD failure (e.g. permissions) is logged but does not block the delete attempt.
      logger.warn('[deleteFileFromS3] Existence check failed, proceeding with deletion:', headErr.message);
    }
const deleteResult = await s3.send(new DeleteObjectCommand(params));
logger.debug('[deleteFileFromS3] Delete command response:', JSON.stringify(deleteResult));
try {
await s3.send(new HeadObjectCommand(params));
logger.error('[deleteFileFromS3] File still exists after deletion!');
} catch (verifyErr) {
if (verifyErr.name === 'NotFound') {
logger.debug(`[deleteFileFromS3] Verified file is deleted: ${key}`);
} else {
logger.error('[deleteFileFromS3] Error verifying deletion:', verifyErr);
}
}
logger.debug('[deleteFileFromS3] S3 File deletion completed');
} catch (error) {
logger.error(`[deleteFileFromS3] Error deleting file from S3: ${error.message}`);
logger.error(error.stack);
    // If the file is not found, we can safely return.
    // AWS SDK v3 surfaces the error type on `name` (v2 used `code`), so check both.
    if (error.name === 'NoSuchKey' || error.code === 'NoSuchKey') {
      return;
    }
throw error;
}
}
/**
* Uploads a local file to S3 by streaming it directly without loading into memory.
*
* @param {Object} params
* @param {import('express').Request} params.req - The Express request (must include user).
* @param {Express.Multer.File} params.file - The file object from Multer.
* @param {string} params.file_id - Unique file identifier.
* @param {string} [params.basePath='images'] - The base path in the bucket.
* @returns {Promise<{ filepath: string, bytes: number }>}
*/
async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) {
try {
const inputFilePath = file.path;
const userId = req.user.id;
const fileName = `${file_id}__${path.basename(inputFilePath)}`;
const key = getS3Key(basePath, userId, fileName);
const stats = await fs.promises.stat(inputFilePath);
const bytes = stats.size;
const fileStream = fs.createReadStream(inputFilePath);
const s3 = initializeS3();
const uploadParams = {
Bucket: bucketName,
Key: key,
Body: fileStream,
};
await s3.send(new PutObjectCommand(uploadParams));
const fileURL = await getS3URL({ userId, fileName, basePath });
return { filepath: fileURL, bytes };
} catch (error) {
logger.error('[uploadFileToS3] Error streaming file to S3:', error);
try {
if (file && file.path) {
await fs.promises.unlink(file.path);
}
} catch (unlinkError) {
logger.error(
'[uploadFileToS3] Error deleting temporary file, likely already deleted:',
unlinkError.message,
);
}
throw error;
}
}
/**
* Extracts the S3 key from a URL or returns the key if already properly formatted
*
* @param {string} fileUrlOrKey - The file URL or key
* @returns {string} The S3 key
*/
function extractKeyFromS3Url(fileUrlOrKey) {
if (!fileUrlOrKey) {
throw new Error('Invalid input: URL or key is empty');
}
try {
const url = new URL(fileUrlOrKey);
return url.pathname.substring(1);
} catch (error) {
const parts = fileUrlOrKey.split('/');
if (parts.length >= 3 && !fileUrlOrKey.startsWith('http') && !fileUrlOrKey.startsWith('/')) {
return fileUrlOrKey;
}
return fileUrlOrKey.startsWith('/') ? fileUrlOrKey.substring(1) : fileUrlOrKey;
}
}
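// Illustrative inputs: a full signed URL such as
// 'https://my-bucket.s3.amazonaws.com/images/user123/file.png?X-Amz-Signature=...'
// resolves via the URL branch to 'images/user123/file.png', while a bare key like
// 'images/user123/file.png' fails URL parsing and is returned unchanged.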
/**
* Retrieves a readable stream for a file stored in S3.
*
 * @param {ServerRequest} _req - Server request object (unused).
* @param {string} filePath - The S3 key of the file.
* @returns {Promise<NodeJS.ReadableStream>}
*/
async function getS3FileStream(_req, filePath) {
try {
const Key = extractKeyFromS3Url(filePath);
const params = { Bucket: bucketName, Key };
const s3 = initializeS3();
const data = await s3.send(new GetObjectCommand(params));
return data.Body; // Returns a Node.js ReadableStream.
} catch (error) {
logger.error('[getS3FileStream] Error retrieving S3 file stream:', error);
throw error;
}
}
/**
* Determines if a signed S3 URL is close to expiration
*
* @param {string} signedUrl - The signed S3 URL
* @param {number} bufferSeconds - Buffer time in seconds
* @returns {boolean} True if the URL needs refreshing
*/
function needsRefresh(signedUrl, bufferSeconds) {
try {
// Parse the URL
const url = new URL(signedUrl);
// Check if it has the signature parameters that indicate it's a signed URL
// X-Amz-Signature is the most reliable indicator for AWS signed URLs
if (!url.searchParams.has('X-Amz-Signature')) {
// Not a signed URL, so no expiration to check (or it's already a proxy URL)
return false;
}
// Extract the expiration time from the URL
const expiresParam = url.searchParams.get('X-Amz-Expires');
const dateParam = url.searchParams.get('X-Amz-Date');
if (!expiresParam || !dateParam) {
// Missing expiration information, assume it needs refresh to be safe
return true;
}
// Parse the AWS date format (YYYYMMDDTHHMMSSZ)
const year = dateParam.substring(0, 4);
const month = dateParam.substring(4, 6);
const day = dateParam.substring(6, 8);
const hour = dateParam.substring(9, 11);
const minute = dateParam.substring(11, 13);
const second = dateParam.substring(13, 15);
const dateObj = new Date(`${year}-${month}-${day}T${hour}:${minute}:${second}Z`);
    const expiresAtDate = new Date(dateObj.getTime() + parseInt(expiresParam, 10) * 1000);
// Check if it's close to expiration
const now = new Date();
const bufferTime = new Date(now.getTime() + bufferSeconds * 1000);
return expiresAtDate <= bufferTime;
} catch (error) {
logger.error('Error checking URL expiration:', error);
// If we can't determine, assume it needs refresh to be safe
return true;
}
}
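// Worked example (illustrative values): a URL carrying X-Amz-Date=20250310T120000Z and
// X-Amz-Expires=604800 expires at 2025-03-17T12:00:00Z; with a one-hour buffer,
// needsRefresh returns true from 2025-03-17T11:00:00Z onward.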
/**
* Generates a new URL for an expired S3 URL
* @param {string} currentURL - The current file URL
* @returns {Promise<string | undefined>}
*/
async function getNewS3URL(currentURL) {
try {
const s3Key = extractKeyFromS3Url(currentURL);
if (!s3Key) {
return;
}
const keyParts = s3Key.split('/');
if (keyParts.length < 3) {
return;
}
const basePath = keyParts[0];
const userId = keyParts[1];
const fileName = keyParts.slice(2).join('/');
return await getS3URL({
userId,
fileName,
basePath,
});
} catch (error) {
logger.error('Error getting new S3 URL:', error);
}
}
/**
* Refreshes S3 URLs for an array of files if they're expired or close to expiring
*
 * @param {MongoFile[]} files - Array of file documents
 * @param {(files: MongoFile[]) => Promise<void>} batchUpdateFiles - Function to update files in the database
 * @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
 * @returns {Promise<MongoFile[]>} The files with refreshed URLs if needed
*/
async function refreshS3FileUrls(files, batchUpdateFiles, bufferSeconds = 3600) {
if (!files || !Array.isArray(files) || files.length === 0) {
return files;
}
const filesToUpdate = [];
for (let i = 0; i < files.length; i++) {
const file = files[i];
if (!file?.file_id) {
continue;
}
if (file.source !== FileSources.s3) {
continue;
}
if (!file.filepath) {
continue;
}
if (!needsRefresh(file.filepath, bufferSeconds)) {
continue;
}
try {
const newURL = await getNewS3URL(file.filepath);
if (!newURL) {
continue;
}
filesToUpdate.push({
file_id: file.file_id,
filepath: newURL,
});
files[i].filepath = newURL;
} catch (error) {
logger.error(`Error refreshing S3 URL for file ${file.file_id}:`, error);
}
}
if (filesToUpdate.length > 0) {
await batchUpdateFiles(filesToUpdate);
}
return files;
}
/**
* Refreshes a single S3 URL if it's expired or close to expiring
*
* @param {{ filepath: string, source: string }} fileObj - Simple file object containing filepath and source
* @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration
* @returns {Promise<string>} The refreshed URL or the original URL if no refresh needed
*/
async function refreshS3Url(fileObj, bufferSeconds = 3600) {
if (!fileObj || fileObj.source !== FileSources.s3 || !fileObj.filepath) {
return fileObj?.filepath || '';
}
if (!needsRefresh(fileObj.filepath, bufferSeconds)) {
return fileObj.filepath;
}
try {
const s3Key = extractKeyFromS3Url(fileObj.filepath);
if (!s3Key) {
logger.warn(`Unable to extract S3 key from URL: ${fileObj.filepath}`);
return fileObj.filepath;
}
const keyParts = s3Key.split('/');
if (keyParts.length < 3) {
logger.warn(`Invalid S3 key format: ${s3Key}`);
return fileObj.filepath;
}
const basePath = keyParts[0];
const userId = keyParts[1];
const fileName = keyParts.slice(2).join('/');
const newUrl = await getS3URL({
userId,
fileName,
basePath,
});
logger.debug(`Refreshed S3 URL for key: ${s3Key}`);
return newUrl;
} catch (error) {
logger.error(`Error refreshing S3 URL: ${error.message}`);
return fileObj.filepath;
}
}
module.exports = {
saveBufferToS3,
saveURLToS3,
getS3URL,
deleteFileFromS3,
uploadFileToS3,
getS3FileStream,
refreshS3FileUrls,
refreshS3Url,
needsRefresh,
getNewS3URL,
};
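Taken together, the module covers the full lifecycle of an S3-backed file. A minimal usage sketch, assuming the module is exposed as ./crud, an Express handler with req.user and a multer file populated, and the AWS env vars above configured; route wiring and the file_id value are illustrative, not part of the module:

const { FileSources } = require('librechat-data-provider');
const { uploadFileToS3, refreshS3Url } = require('./crud');

// Illustrative Express handler; `file_id` generation is an assumption.
async function handleUpload(req, res) {
  const { filepath, bytes } = await uploadFileToS3({
    req,
    file: req.file, // multer-populated file
    file_id: 'file123', // hypothetical identifier
  });
  // `filepath` is a signed URL, valid for up to 7 days by default.
  res.json({ filepath, bytes });
}

// Before serving a stored URL, refresh it if it is within an hour of expiring.
async function freshUrlFor(fileDoc) {
  return refreshS3Url({ filepath: fileDoc.filepath, source: FileSources.s3 }, 3600);
}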

Some files were not shown because too many files have changed in this diff.