Compare commits
36 Commits
v0.7.6...feat/conve

| Author | SHA1 | Date |
|---|---|---|
|  | 672d12bf85 |  |
|  | 1972162970 |  |
|  | d048a10b2e |  |
|  | e6670cd411 |  |
|  | b35a8b78e2 |  |
|  | e309c6abef |  |
|  | b55e695541 |  |
|  | 24d30d7428 |  |
|  | aa80e4594e |  |
|  | 24beda3d69 |  |
|  | 0855677a36 |  |
|  | ea1a5c8a30 |  |
|  | 0f95604a67 |  |
|  | 687ab32bd3 |  |
|  | dd927583a7 |  |
|  | 69a9b8b911 |  |
|  | 916faf6447 |  |
|  | 8aa1e731ca |  |
|  | b01c744eb8 |  |
|  | 7987e04a2c |  |
|  | 766657da83 |  |
|  | 7c61115a88 |  |
|  | c26b54c74d |  |
|  | bf0a84e45a |  |
|  | 28966e3ddc |  |
|  | 65b2d647a1 |  |
|  | 6c9a468b8e |  |
|  | cb1921626e |  |
|  | d9c59b08e6 |  |
|  | 24cad6bbd4 |  |
|  | a423eb8c7b |  |
|  | d6f1ecf75c |  |
|  | 04923dd185 |  |
|  | dfe5498301 |  |
|  | bdb222d5f4 |  |
|  | 9bca2ae953 |  |
@@ -256,6 +256,7 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
# DALLE3_AZURE_API_VERSION=
# DALLE2_AZURE_API_VERSION=


# Google
#-----------------
GOOGLE_SEARCH_API_KEY=
@@ -514,4 +515,9 @@ HELP_AND_FAQ_URL=https://librechat.ai

# no-cache: Forces validation with server before using cached version
# no-store: Prevents storing the response entirely
# must-revalidate: Prevents using stale content when offline
# must-revalidate: Prevents using stale content when offline

#=====================================================#
# OpenWeather #
#=====================================================#
OPENWEATHER_API_KEY=
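The comment lines added above document standard HTTP `Cache-Control` directives. As an illustrative sketch only (the middleware wiring is not part of this compare view, and the variable name `STATIC_CACHE_CONTROL` below is hypothetical), this is roughly how such a directive could be applied to static responses in an Express app:

```js
// Hypothetical illustration, not code from this PR: apply a Cache-Control
// directive (e.g. 'no-cache' or 'no-store, must-revalidate') to static assets.
const express = require('express');

const app = express();
const cacheControl = process.env.STATIC_CACHE_CONTROL || 'no-cache'; // hypothetical variable name

app.use(
  express.static('client/dist', {
    setHeaders: (res) => {
      // 'no-cache' forces revalidation before a cached copy is used;
      // 'no-store' prevents storing the response; 'must-revalidate' forbids stale reuse.
      res.setHeader('Cache-Control', cacheControl);
    },
  }),
);

app.listen(3080);
```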
@@ -2,7 +2,7 @@
# v0.7.6

# Base for all builds
FROM node:20-alpine AS base
FROM node:20-alpine AS base-min
WORKDIR /app
RUN apk --no-cache add curl
RUN npm config set fetch-retry-maxtimeout 600000 && \
@@ -13,6 +13,10 @@ COPY packages/data-provider/package*.json ./packages/data-provider/
COPY packages/mcp/package*.json ./packages/mcp/
COPY client/package*.json ./client/
COPY api/package*.json ./api/

# Install all dependencies for every build
FROM base-min AS base
WORKDIR /app
RUN npm ci

# Build data-provider
@@ -20,7 +24,6 @@ FROM base AS data-provider-build
WORKDIR /app/packages/data-provider
COPY packages/data-provider ./
RUN npm run build
RUN npm prune --production

# Build mcp package
FROM base AS mcp-build
@@ -28,7 +31,6 @@ WORKDIR /app/packages/mcp
COPY packages/mcp ./
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
RUN npm run build
RUN npm prune --production

# Client build
FROM base AS client-build
@@ -37,18 +39,18 @@ COPY client ./
COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist
ENV NODE_OPTIONS="--max-old-space-size=2048"
RUN npm run build
RUN npm prune --production

# API setup (including client dist)
FROM base AS api-build
FROM base-min AS api-build
WORKDIR /app
# Install only production deps
RUN npm ci --omit=dev
COPY api ./api
COPY config ./config
COPY --from=data-provider-build /app/packages/data-provider/dist ./packages/data-provider/dist
COPY --from=mcp-build /app/packages/mcp/dist ./packages/mcp/dist
COPY --from=client-build /app/client/dist ./client/dist
WORKDIR /app/api
RUN npm prune --production
EXPOSE 3080
ENV HOST=0.0.0.0
CMD ["node", "server/index.js"]
CMD ["node", "server/index.js"]

@@ -114,7 +114,8 @@ LibreChat brings together the future of assistant AIs with the revolutionary tec

With LibreChat, you no longer need to opt for ChatGPT Plus and can instead use free or pay-per-call APIs. We welcome contributions, cloning, and forking to enhance the capabilities of this advanced chatbot platform.

[](https://www.youtube.com/watch?v=IDukQ7a2f3U)
[](https://www.youtube.com/watch?v=ilfwGQtJNlI)

Click on the thumbnail to open the video☝️

---

@@ -1,6 +1,5 @@
const Anthropic = require('@anthropic-ai/sdk');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
Constants,
EModelEndpoint,
@@ -19,6 +18,7 @@ const {
} = require('./prompts');
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const Tokenizer = require('~/server/services/Tokenizer');
const { sleep } = require('~/server/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
@@ -26,8 +26,6 @@ const { logger } = require('~/config');
const HUMAN_PROMPT = '\n\nHuman:';
const AI_PROMPT = '\n\nAssistant:';

const tokenizersCache = {};

/** Helper function to introduce a delay before retrying */
function delayBeforeRetry(attempts, baseDelay = 1000) {
return new Promise((resolve) => setTimeout(resolve, baseDelay * attempts));
@@ -149,7 +147,6 @@ class AnthropicClient extends BaseClient {

this.startToken = '||>';
this.endToken = '';
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');

return this;
}
@@ -849,22 +846,18 @@ class AnthropicClient extends BaseClient {
logger.debug('AnthropicClient doesn\'t use getBuildMessagesOptions');
}

static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
if (tokenizersCache[encoding]) {
return tokenizersCache[encoding];
}
let tokenizer;
if (isModelName) {
tokenizer = encodingForModel(encoding, extendSpecialTokens);
} else {
tokenizer = getEncoding(encoding, extendSpecialTokens);
}
tokenizersCache[encoding] = tokenizer;
return tokenizer;
getEncoding() {
return 'cl100k_base';
}

/**
* Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
* @param {string} text - The text to get the token count for.
* @returns {number} The token count of the given text.
*/
getTokenCount(text) {
return this.gptEncoder.encode(text, 'all').length;
const encoding = this.getEncoding();
return Tokenizer.getTokenCount(text, encoding);
}

/**

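Across this compare view, the per-client `getTokenizer`/`tokenizersCache` logic removed above is replaced by calls into a shared `~/server/services/Tokenizer` module. That service's implementation is not included in this diff, so the following is only a hedged sketch of what such a module might look like, inferred from the call sites shown (`Tokenizer.getTokenCount(text, encoding)`) and the caching/reset behavior the old per-client code had:

```js
// Hedged sketch only; the real ~/server/services/Tokenizer is not shown in this compare view.
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');

class TokenizerService {
  constructor() {
    this.tokenizersCache = {};
    this.tokenizerCallsCount = 0;
  }

  getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
    if (this.tokenizersCache[encoding]) {
      return this.tokenizersCache[encoding];
    }
    const tokenizer = isModelName
      ? encodingForModel(encoding, extendSpecialTokens)
      : getEncoding(encoding, extendSpecialTokens);
    this.tokenizersCache[encoding] = tokenizer;
    return tokenizer;
  }

  freeAndResetAllEncoders() {
    for (const key of Object.keys(this.tokenizersCache)) {
      this.tokenizersCache[key].free();
      delete this.tokenizersCache[key];
    }
    this.tokenizerCallsCount = 0;
  }

  getTokenCount(text, encoding = 'cl100k_base') {
    // Periodically free the WASM encoders so long-lived processes don't leak memory.
    if (++this.tokenizerCallsCount >= 25) {
      this.freeAndResetAllEncoders();
    }
    return this.getTokenizer(encoding).encode(text, 'all').length;
  }
}

module.exports = new TokenizerService();
```

With a shared service like this, each client only declares which encoding it wants via `getEncoding()` (for example `'cl100k_base'` above), and the caching and reset policy lives in one place instead of being duplicated in three clients.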
@@ -4,6 +4,7 @@ const {
|
||||
supportsBalanceCheck,
|
||||
isAgentsEndpoint,
|
||||
isParamEndpoint,
|
||||
EModelEndpoint,
|
||||
ErrorTypes,
|
||||
Constants,
|
||||
CacheKeys,
|
||||
@@ -11,6 +12,7 @@ const {
|
||||
} = require('librechat-data-provider');
|
||||
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
|
||||
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
|
||||
const { truncateToolCallOutputs } = require('./prompts');
|
||||
const checkBalance = require('~/models/checkBalance');
|
||||
const { getFiles } = require('~/models/File');
|
||||
const { getLogStores } = require('~/cache');
|
||||
@@ -95,7 +97,7 @@ class BaseClient {
|
||||
* @returns {number}
|
||||
*/
|
||||
getTokenCountForResponse(responseMessage) {
|
||||
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', responseMessage);
|
||||
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', responseMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -106,7 +108,7 @@ class BaseClient {
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async recordTokenUsage({ promptTokens, completionTokens }) {
|
||||
logger.debug('`[BaseClient] recordTokenUsage` not implemented.', {
|
||||
logger.debug('[BaseClient] `recordTokenUsage` not implemented.', {
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
});
|
||||
@@ -287,6 +289,9 @@ class BaseClient {
|
||||
}
|
||||
|
||||
async handleTokenCountMap(tokenCountMap) {
|
||||
if (this.clientName === EModelEndpoint.agents) {
|
||||
return;
|
||||
}
|
||||
if (this.currentMessages.length === 0) {
|
||||
return;
|
||||
}
|
||||
@@ -394,6 +399,21 @@ class BaseClient {
|
||||
_instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount);
|
||||
let payload = this.addInstructions(formattedMessages, _instructions);
|
||||
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
|
||||
if (this.clientName === EModelEndpoint.agents) {
|
||||
const { dbMessages, editedIndices } = truncateToolCallOutputs(
|
||||
orderedWithInstructions,
|
||||
this.maxContextTokens,
|
||||
this.getTokenCountForMessage.bind(this),
|
||||
);
|
||||
|
||||
if (editedIndices.length > 0) {
|
||||
logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices);
|
||||
for (const index of editedIndices) {
|
||||
payload[index].content = dbMessages[index].content;
|
||||
}
|
||||
orderedWithInstructions = dbMessages;
|
||||
}
|
||||
}
|
||||
|
||||
let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
|
||||
await this.getMessagesWithinTokenLimit(orderedWithInstructions);
|
||||
@@ -625,7 +645,7 @@ class BaseClient {
|
||||
await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts });
|
||||
} else {
|
||||
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
|
||||
completionTokens = this.getTokenCount(completion);
|
||||
completionTokens = responseMessage.tokenCount;
|
||||
}
|
||||
|
||||
await this.recordTokenUsage({ promptTokens, completionTokens, usage });
|
||||
@@ -649,15 +669,17 @@ class BaseClient {
|
||||
|
||||
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
|
||||
this.savedMessageIds.add(responseMessage.messageId);
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
messageCache.set(
|
||||
responseMessageId,
|
||||
{
|
||||
text: responseMessage.text,
|
||||
complete: true,
|
||||
},
|
||||
Time.FIVE_MINUTES,
|
||||
);
|
||||
if (responseMessage.text) {
|
||||
const messageCache = getLogStores(CacheKeys.MESSAGES);
|
||||
messageCache.set(
|
||||
responseMessageId,
|
||||
{
|
||||
text: responseMessage.text,
|
||||
complete: true,
|
||||
},
|
||||
Time.FIVE_MINUTES,
|
||||
);
|
||||
}
|
||||
delete responseMessage.tokenCount;
|
||||
return responseMessage;
|
||||
}
|
||||
@@ -929,6 +951,24 @@ class BaseClient {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.type === 'tool_call' && item.tool_call != null) {
|
||||
const toolName = item.tool_call?.name || '';
|
||||
if (toolName != null && toolName && typeof toolName === 'string') {
|
||||
numTokens += this.getTokenCount(toolName);
|
||||
}
|
||||
|
||||
const args = item.tool_call?.args || '';
|
||||
if (args != null && args && typeof args === 'string') {
|
||||
numTokens += this.getTokenCount(args);
|
||||
}
|
||||
|
||||
const output = item.tool_call?.output || '';
|
||||
if (output != null && output && typeof output === 'string') {
|
||||
numTokens += this.getTokenCount(output);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const nestedValue = item[item.type];
|
||||
|
||||
if (!nestedValue) {
|
||||
|
||||
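The new `tool_call` branch added to BaseClient above counts tokens for a tool call's name, arguments, and output whenever they are strings. An illustrative content item (shape assumed from the code shown, not taken verbatim from the PR) would be counted roughly like this:

```js
// Illustrative only: a content item of the shape handled by the new tool_call branch above.
const item = {
  type: 'tool_call',
  tool_call: {
    name: 'open_weather',
    args: '{"action":"current_forecast","city":"London"}',
    output: '{"current":{"temp":8}}',
  },
};

// Per the diff, each string field contributes its own token count:
// numTokens += getTokenCount(item.tool_call.name)
//            + getTokenCount(item.tool_call.args)
//            + getTokenCount(item.tool_call.output);
```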
@@ -6,7 +6,6 @@ const { ChatGoogleVertexAI } = require('@langchain/google-vertexai');
|
||||
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
|
||||
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
|
||||
const { AIMessage, HumanMessage, SystemMessage } = require('@langchain/core/messages');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const {
|
||||
validateVisionModel,
|
||||
getResponseSender,
|
||||
@@ -17,6 +16,7 @@ const {
|
||||
AuthKeys,
|
||||
} = require('librechat-data-provider');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
@@ -31,7 +31,6 @@ const BaseClient = require('./BaseClient');
|
||||
const loc = process.env.GOOGLE_LOC || 'us-central1';
|
||||
const publisher = 'google';
|
||||
const endpointPrefix = `${loc}-aiplatform.googleapis.com`;
|
||||
const tokenizersCache = {};
|
||||
|
||||
const settings = endpointSettings[EModelEndpoint.google];
|
||||
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
|
||||
@@ -177,25 +176,15 @@ class GoogleClient extends BaseClient {
|
||||
// without tripping the stop sequences, so I'm using "||>" instead.
|
||||
this.startToken = '||>';
|
||||
this.endToken = '';
|
||||
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
|
||||
} else if (isTextModel) {
|
||||
this.startToken = '||>';
|
||||
this.endToken = '';
|
||||
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
|
||||
'<|im_start|>': 100264,
|
||||
'<|im_end|>': 100265,
|
||||
});
|
||||
} else {
|
||||
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
|
||||
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
|
||||
// as a single token. So we're using this instead.
|
||||
this.startToken = '||>';
|
||||
this.endToken = '';
|
||||
try {
|
||||
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
|
||||
} catch {
|
||||
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
|
||||
}
|
||||
}
|
||||
|
||||
if (!this.modelOptions.stop) {
|
||||
@@ -873,6 +862,7 @@ class GoogleClient extends BaseClient {
|
||||
|
||||
getSaveOptions() {
|
||||
return {
|
||||
endpointType: null,
|
||||
artifacts: this.options.artifacts,
|
||||
promptPrefix: this.options.promptPrefix,
|
||||
modelLabel: this.options.modelLabel,
|
||||
@@ -896,53 +886,58 @@ class GoogleClient extends BaseClient {
|
||||
}
|
||||
|
||||
getSafetySettings() {
|
||||
const isGemini2 = this.modelOptions.model.includes('gemini-2.0');
|
||||
const mapThreshold = (value) => {
|
||||
if (isGemini2 && value === 'BLOCK_NONE') {
|
||||
return 'OFF';
|
||||
}
|
||||
return value;
|
||||
};
|
||||
|
||||
return [
|
||||
{
|
||||
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
|
||||
threshold:
|
||||
threshold: mapThreshold(
|
||||
process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
),
|
||||
},
|
||||
{
|
||||
category: 'HARM_CATEGORY_HATE_SPEECH',
|
||||
threshold: process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
threshold: mapThreshold(
|
||||
process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
),
|
||||
},
|
||||
{
|
||||
category: 'HARM_CATEGORY_HARASSMENT',
|
||||
threshold: process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
threshold: mapThreshold(
|
||||
process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
),
|
||||
},
|
||||
{
|
||||
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
|
||||
threshold:
|
||||
threshold: mapThreshold(
|
||||
process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
),
|
||||
},
|
||||
{
|
||||
category: 'HARM_CATEGORY_CIVIC_INTEGRITY',
|
||||
/**
|
||||
* Note: this was added since `gemini-2.0-flash-thinking-exp-1219` does not
|
||||
* accept 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' for 'HARM_CATEGORY_CIVIC_INTEGRITY'
|
||||
* */
|
||||
threshold: process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE',
|
||||
threshold: mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE'),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
/* TO-DO: Handle tokens with Google tokenization NOTE: these are required */
|
||||
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
||||
if (tokenizersCache[encoding]) {
|
||||
return tokenizersCache[encoding];
|
||||
}
|
||||
let tokenizer;
|
||||
if (isModelName) {
|
||||
tokenizer = encodingForModel(encoding, extendSpecialTokens);
|
||||
} else {
|
||||
tokenizer = getEncoding(encoding, extendSpecialTokens);
|
||||
}
|
||||
tokenizersCache[encoding] = tokenizer;
|
||||
return tokenizer;
|
||||
getEncoding() {
|
||||
return 'cl100k_base';
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the token count of a given text. It also checks and resets the tokenizers if necessary.
|
||||
* @param {string} text - The text to get the token count for.
|
||||
* @returns {number} The token count of the given text.
|
||||
*/
|
||||
getTokenCount(text) {
|
||||
return this.gptEncoder.encode(text, 'all').length;
|
||||
const encoding = this.getEncoding();
|
||||
return Tokenizer.getTokenCount(text, encoding);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
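The `mapThreshold` helper added to GoogleClient above only rewrites `'BLOCK_NONE'` to `'OFF'` when the model name contains `gemini-2.0`. A small worked illustration (values follow from the code shown, not from runtime output):

```js
// Illustration of the mapping shown above, assuming no GOOGLE_SAFETY_* env vars are set.
// HARM_CATEGORY_CIVIC_INTEGRITY defaults to 'BLOCK_NONE'; for a gemini-2.0 model that
// default is mapped to 'OFF', while the other categories keep
// 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' unchanged.
const isGemini2 = 'gemini-2.0-flash'.includes('gemini-2.0'); // true
const mapThreshold = (value) => (isGemini2 && value === 'BLOCK_NONE' ? 'OFF' : value);

console.log(mapThreshold('BLOCK_NONE')); // 'OFF'
console.log(mapThreshold('HARM_BLOCK_THRESHOLD_UNSPECIFIED')); // unchanged
```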
@@ -13,7 +13,6 @@ const {
|
||||
validateVisionModel,
|
||||
mapModelToAzureConfig,
|
||||
} = require('librechat-data-provider');
|
||||
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
|
||||
const {
|
||||
extractBaseURL,
|
||||
constructAzureURL,
|
||||
@@ -29,6 +28,7 @@ const {
|
||||
createContextHandlers,
|
||||
} = require('./prompts');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const Tokenizer = require('~/server/services/Tokenizer');
|
||||
const { spendTokens } = require('~/models/spendTokens');
|
||||
const { isEnabled, sleep } = require('~/server/utils');
|
||||
const { handleOpenAIErrors } = require('./tools/util');
|
||||
@@ -40,11 +40,6 @@ const { tokenSplit } = require('./document');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
// Cache to store Tiktoken instances
|
||||
const tokenizersCache = {};
|
||||
// Counter for keeping track of the number of tokenizer calls
|
||||
let tokenizerCallsCount = 0;
|
||||
|
||||
class OpenAIClient extends BaseClient {
|
||||
constructor(apiKey, options = {}) {
|
||||
super(apiKey, options);
|
||||
@@ -307,75 +302,8 @@ class OpenAIClient extends BaseClient {
|
||||
}
|
||||
}
|
||||
|
||||
// Selects an appropriate tokenizer based on the current configuration of the client instance.
|
||||
// It takes into account factors such as whether it's a chat completion, an unofficial chat GPT model, etc.
|
||||
selectTokenizer() {
|
||||
let tokenizer;
|
||||
this.encoding = 'text-davinci-003';
|
||||
if (this.isChatCompletion) {
|
||||
this.encoding = this.modelOptions.model.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base';
|
||||
tokenizer = this.constructor.getTokenizer(this.encoding);
|
||||
} else if (this.isUnofficialChatGptModel) {
|
||||
const extendSpecialTokens = {
|
||||
'<|im_start|>': 100264,
|
||||
'<|im_end|>': 100265,
|
||||
};
|
||||
tokenizer = this.constructor.getTokenizer(this.encoding, true, extendSpecialTokens);
|
||||
} else {
|
||||
try {
|
||||
const { model } = this.modelOptions;
|
||||
this.encoding = model.includes('instruct') ? 'text-davinci-003' : model;
|
||||
tokenizer = this.constructor.getTokenizer(this.encoding, true);
|
||||
} catch {
|
||||
tokenizer = this.constructor.getTokenizer('text-davinci-003', true);
|
||||
}
|
||||
}
|
||||
|
||||
return tokenizer;
|
||||
}
|
||||
|
||||
// Retrieves a tokenizer either from the cache or creates a new one if one doesn't exist in the cache.
|
||||
// If a tokenizer is being created, it's also added to the cache.
|
||||
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
|
||||
let tokenizer;
|
||||
if (tokenizersCache[encoding]) {
|
||||
tokenizer = tokenizersCache[encoding];
|
||||
} else {
|
||||
if (isModelName) {
|
||||
tokenizer = encodingForModel(encoding, extendSpecialTokens);
|
||||
} else {
|
||||
tokenizer = getEncoding(encoding, extendSpecialTokens);
|
||||
}
|
||||
tokenizersCache[encoding] = tokenizer;
|
||||
}
|
||||
return tokenizer;
|
||||
}
|
||||
|
||||
// Frees all encoders in the cache and resets the count.
|
||||
static freeAndResetAllEncoders() {
|
||||
try {
|
||||
Object.keys(tokenizersCache).forEach((key) => {
|
||||
if (tokenizersCache[key]) {
|
||||
tokenizersCache[key].free();
|
||||
delete tokenizersCache[key];
|
||||
}
|
||||
});
|
||||
// Reset count
|
||||
tokenizerCallsCount = 1;
|
||||
} catch (error) {
|
||||
logger.error('[OpenAIClient] Free and reset encoders error', error);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if the cache of tokenizers has reached a certain size. If it has, it frees and resets all tokenizers.
|
||||
resetTokenizersIfNecessary() {
|
||||
if (tokenizerCallsCount >= 25) {
|
||||
if (this.options.debug) {
|
||||
logger.debug('[OpenAIClient] freeAndResetAllEncoders: reached 25 encodings, resetting...');
|
||||
}
|
||||
this.constructor.freeAndResetAllEncoders();
|
||||
}
|
||||
tokenizerCallsCount++;
|
||||
getEncoding() {
|
||||
return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base';
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -384,15 +312,8 @@ class OpenAIClient extends BaseClient {
|
||||
* @returns {number} The token count of the given text.
|
||||
*/
|
||||
getTokenCount(text) {
|
||||
this.resetTokenizersIfNecessary();
|
||||
try {
|
||||
const tokenizer = this.selectTokenizer();
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
} catch (error) {
|
||||
this.constructor.freeAndResetAllEncoders();
|
||||
const tokenizer = this.selectTokenizer();
|
||||
return tokenizer.encode(text, 'all').length;
|
||||
}
|
||||
const encoding = this.getEncoding();
|
||||
return Tokenizer.getTokenCount(text, encoding);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1372,7 +1293,7 @@ ${convo}
|
||||
});
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const token = chunk.choices[0]?.delta?.content || '';
|
||||
const token = chunk?.choices?.[0]?.delta?.content || '';
|
||||
intermediateReply.push(token);
|
||||
onProgress(token);
|
||||
if (abortController.signal.aborted) {
|
||||
|
||||
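The last OpenAIClient change above hardens the title-generation stream loop: `chunk?.choices?.[0]?.delta?.content` tolerates chunks that arrive without a `choices` array. A minimal illustration (not from the PR) of why the added optional chaining matters:

```js
// Illustration: the old `chunk.choices[0]?.delta?.content` throws a TypeError on the
// second chunk below; the new fully optional-chained form simply yields an empty string.
const chunks = [
  { choices: [{ delta: { content: 'Hel' } }] },
  {}, // e.g. a keep-alive or error chunk with no choices
  { choices: [] },
];

for (const chunk of chunks) {
  const token = chunk?.choices?.[0]?.delta?.content || '';
  process.stdout.write(token);
}
```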
@@ -256,15 +256,17 @@ class PluginsClient extends OpenAIClient {
}

this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessage.messageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
if (responseMessage.text) {
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessage.messageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
}
delete responseMessage.tokenCount;
return { ...responseMessage, ...result };
}

@@ -4,7 +4,7 @@ const summaryPrompts = require('./summaryPrompts');
const handleInputs = require('./handleInputs');
const instructions = require('./instructions');
const titlePrompts = require('./titlePrompts');
const truncateText = require('./truncateText');
const truncate = require('./truncate');
const createVisionPrompt = require('./createVisionPrompt');
const createContextHandlers = require('./createContextHandlers');

@@ -15,7 +15,7 @@ module.exports = {
...handleInputs,
...instructions,
...titlePrompts,
...truncateText,
...truncate,
createVisionPrompt,
createContextHandlers,
};

api/app/clients/prompts/truncate.js (new file, 115 lines)
@@ -0,0 +1,115 @@
|
||||
const MAX_CHAR = 255;
|
||||
|
||||
/**
|
||||
* Truncates a given text to a specified maximum length, appending ellipsis and a notification
|
||||
* if the original text exceeds the maximum length.
|
||||
*
|
||||
* @param {string} text - The text to be truncated.
|
||||
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the text after truncation. Defaults to MAX_CHAR.
|
||||
* @returns {string} The truncated text if the original text length exceeds maxLength, otherwise returns the original text.
|
||||
*/
|
||||
function truncateText(text, maxLength = MAX_CHAR) {
|
||||
if (text.length > maxLength) {
|
||||
return `${text.slice(0, maxLength)}... [text truncated for brevity]`;
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncates a given text to a specified maximum length by showing the first half and the last half of the text,
|
||||
* separated by ellipsis. This method ensures the output does not exceed the maximum length, including the addition
|
||||
* of ellipsis and notification if the original text exceeds the maximum length.
|
||||
*
|
||||
* @param {string} text - The text to be truncated.
|
||||
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR.
|
||||
* @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength.
|
||||
*/
|
||||
function smartTruncateText(text, maxLength = MAX_CHAR) {
|
||||
const ellipsis = '...';
|
||||
const notification = ' [text truncated for brevity]';
|
||||
const halfMaxLength = Math.floor((maxLength - ellipsis.length - notification.length) / 2);
|
||||
|
||||
if (text.length > maxLength) {
|
||||
const startLastHalf = text.length - halfMaxLength;
|
||||
return `${text.slice(0, halfMaxLength)}${ellipsis}${text.slice(startLastHalf)}${notification}`;
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {TMessage[]} _messages
|
||||
* @param {number} maxContextTokens
|
||||
* @param {function({role: string, content: TMessageContent[]}): number} getTokenCountForMessage
|
||||
*
|
||||
* @returns {{
|
||||
* dbMessages: TMessage[],
|
||||
* editedIndices: number[]
|
||||
* }}
|
||||
*/
|
||||
function truncateToolCallOutputs(_messages, maxContextTokens, getTokenCountForMessage) {
|
||||
const THRESHOLD_PERCENTAGE = 0.5;
|
||||
const targetTokenLimit = maxContextTokens * THRESHOLD_PERCENTAGE;
|
||||
|
||||
let currentTokenCount = 3;
|
||||
const messages = [..._messages];
|
||||
const processedMessages = [];
|
||||
let currentIndex = messages.length;
|
||||
const editedIndices = new Set();
|
||||
while (messages.length > 0) {
|
||||
currentIndex--;
|
||||
const message = messages.pop();
|
||||
currentTokenCount += message.tokenCount;
|
||||
if (currentTokenCount < targetTokenLimit) {
|
||||
processedMessages.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!message.content || !Array.isArray(message.content)) {
|
||||
processedMessages.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolCallIndices = message.content
|
||||
.map((item, index) => (item.type === 'tool_call' ? index : -1))
|
||||
.filter((index) => index !== -1)
|
||||
.reverse();
|
||||
|
||||
if (toolCallIndices.length === 0) {
|
||||
processedMessages.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
const newContent = [...message.content];
|
||||
|
||||
// Truncate all tool outputs since we're over threshold
|
||||
for (const index of toolCallIndices) {
|
||||
const toolCall = newContent[index].tool_call;
|
||||
if (!toolCall || !toolCall.output) {
|
||||
continue;
|
||||
}
|
||||
|
||||
editedIndices.add(currentIndex);
|
||||
|
||||
newContent[index] = {
|
||||
...newContent[index],
|
||||
tool_call: {
|
||||
...toolCall,
|
||||
output: '[OUTPUT_OMITTED_FOR_BREVITY]',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const truncatedMessage = {
|
||||
...message,
|
||||
content: newContent,
|
||||
tokenCount: getTokenCountForMessage({ role: 'assistant', content: newContent }),
|
||||
};
|
||||
|
||||
processedMessages.push(truncatedMessage);
|
||||
}
|
||||
|
||||
return { dbMessages: processedMessages.reverse(), editedIndices: Array.from(editedIndices) };
|
||||
}
|
||||
|
||||
module.exports = { truncateText, smartTruncateText, truncateToolCallOutputs };
|
||||
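For the helpers defined in the new truncate.js above, the length arithmetic works out as follows for the default `MAX_CHAR` of 255. A small usage sketch (the `require` path assumes the repo layout shown in this diff):

```js
// Usage sketch for the helpers defined above (path per the repo layout in this diff).
const { truncateText, smartTruncateText } = require('./api/app/clients/prompts/truncate');

const long = 'x'.repeat(300);

// halfMaxLength = floor((255 - 3 - 29) / 2) = 111, so the result keeps the first 111
// and last 111 characters around '...' plus ' [text truncated for brevity]' (254 chars total).
console.log(smartTruncateText(long).length); // 254

// truncateText simply keeps the first 255 characters and appends the notice.
console.log(truncateText(long).startsWith('x'.repeat(255))); // true
```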
@@ -1,40 +0,0 @@
|
||||
const MAX_CHAR = 255;
|
||||
|
||||
/**
|
||||
* Truncates a given text to a specified maximum length, appending ellipsis and a notification
|
||||
* if the original text exceeds the maximum length.
|
||||
*
|
||||
* @param {string} text - The text to be truncated.
|
||||
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the text after truncation. Defaults to MAX_CHAR.
|
||||
* @returns {string} The truncated text if the original text length exceeds maxLength, otherwise returns the original text.
|
||||
*/
|
||||
function truncateText(text, maxLength = MAX_CHAR) {
|
||||
if (text.length > maxLength) {
|
||||
return `${text.slice(0, maxLength)}... [text truncated for brevity]`;
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
/**
|
||||
* Truncates a given text to a specified maximum length by showing the first half and the last half of the text,
|
||||
* separated by ellipsis. This method ensures the output does not exceed the maximum length, including the addition
|
||||
* of ellipsis and notification if the original text exceeds the maximum length.
|
||||
*
|
||||
* @param {string} text - The text to be truncated.
|
||||
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR.
|
||||
* @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength.
|
||||
*/
|
||||
function smartTruncateText(text, maxLength = MAX_CHAR) {
|
||||
const ellipsis = '...';
|
||||
const notification = ' [text truncated for brevity]';
|
||||
const halfMaxLength = Math.floor((maxLength - ellipsis.length - notification.length) / 2);
|
||||
|
||||
if (text.length > maxLength) {
|
||||
const startLastHalf = text.length - halfMaxLength;
|
||||
return `${text.slice(0, halfMaxLength)}${ellipsis}${text.slice(startLastHalf)}${notification}`;
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
module.exports = { truncateText, smartTruncateText };
|
||||
@@ -615,9 +615,9 @@ describe('BaseClient', () => {
test('getTokenCount for response is called with the correct arguments', async () => {
const tokenCountMap = {}; // Mock tokenCountMap
TestClient.buildMessages.mockReturnValue({ prompt: [], tokenCountMap });
TestClient.getTokenCount = jest.fn();
TestClient.getTokenCountForResponse = jest.fn();
const response = await TestClient.sendMessage('Hello, world!', {});
expect(TestClient.getTokenCount).toHaveBeenCalledWith(response.text);
expect(TestClient.getTokenCountForResponse).toHaveBeenCalledWith(response);
});

test('returns an object with the correct shape', async () => {

@@ -1,5 +1,7 @@
|
||||
jest.mock('~/cache/getLogStores');
|
||||
require('dotenv').config();
|
||||
const OpenAI = require('openai');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
|
||||
const { genAzureChatCompletion } = require('~/utils/azureUtils');
|
||||
const OpenAIClient = require('../OpenAIClient');
|
||||
@@ -134,7 +136,13 @@ OpenAI.mockImplementation(() => ({
|
||||
}));
|
||||
|
||||
describe('OpenAIClient', () => {
|
||||
let client, client2;
|
||||
const mockSet = jest.fn();
|
||||
const mockCache = { set: mockSet };
|
||||
|
||||
beforeEach(() => {
|
||||
getLogStores.mockReturnValue(mockCache);
|
||||
});
|
||||
let client;
|
||||
const model = 'gpt-4';
|
||||
const parentMessageId = '1';
|
||||
const messages = [
|
||||
@@ -176,7 +184,6 @@ describe('OpenAIClient', () => {
|
||||
beforeEach(() => {
|
||||
const options = { ...defaultOptions };
|
||||
client = new OpenAIClient('test-api-key', options);
|
||||
client2 = new OpenAIClient('test-api-key', options);
|
||||
client.summarizeMessages = jest.fn().mockResolvedValue({
|
||||
role: 'assistant',
|
||||
content: 'Refined answer',
|
||||
@@ -185,7 +192,6 @@ describe('OpenAIClient', () => {
|
||||
client.buildPrompt = jest
|
||||
.fn()
|
||||
.mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') });
|
||||
client.constructor.freeAndResetAllEncoders();
|
||||
client.getMessages = jest.fn().mockResolvedValue([]);
|
||||
});
|
||||
|
||||
@@ -335,77 +341,11 @@ describe('OpenAIClient', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('selectTokenizer', () => {
|
||||
it('should get the correct tokenizer based on the instance state', () => {
|
||||
const tokenizer = client.selectTokenizer();
|
||||
expect(tokenizer).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('freeAllTokenizers', () => {
|
||||
it('should free all tokenizers', () => {
|
||||
// Create a tokenizer
|
||||
const tokenizer = client.selectTokenizer();
|
||||
|
||||
// Mock 'free' method on the tokenizer
|
||||
tokenizer.free = jest.fn();
|
||||
|
||||
client.constructor.freeAndResetAllEncoders();
|
||||
|
||||
// Check if 'free' method has been called on the tokenizer
|
||||
expect(tokenizer.free).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTokenCount', () => {
|
||||
it('should return the correct token count', () => {
|
||||
const count = client.getTokenCount('Hello, world!');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should reset the encoder and count when count reaches 25', () => {
|
||||
const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
|
||||
|
||||
// Call getTokenCount 25 times
|
||||
for (let i = 0; i < 25; i++) {
|
||||
client.getTokenCount('test text');
|
||||
}
|
||||
|
||||
expect(freeAndResetEncoderSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not reset the encoder and count when count is less than 25', () => {
|
||||
const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
|
||||
freeAndResetEncoderSpy.mockClear();
|
||||
|
||||
// Call getTokenCount 24 times
|
||||
for (let i = 0; i < 24; i++) {
|
||||
client.getTokenCount('test text');
|
||||
}
|
||||
|
||||
expect(freeAndResetEncoderSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle errors and reset the encoder', () => {
|
||||
const freeAndResetEncoderSpy = jest.spyOn(client.constructor, 'freeAndResetAllEncoders');
|
||||
|
||||
// Mock encode function to throw an error
|
||||
client.selectTokenizer().encode = jest.fn().mockImplementation(() => {
|
||||
throw new Error('Test error');
|
||||
});
|
||||
|
||||
client.getTokenCount('test text');
|
||||
|
||||
expect(freeAndResetEncoderSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not throw null pointer error when freeing the same encoder twice', () => {
|
||||
client.constructor.freeAndResetAllEncoders();
|
||||
client2.constructor.freeAndResetAllEncoders();
|
||||
|
||||
const count = client2.getTokenCount('test text');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSaveOptions', () => {
|
||||
@@ -548,7 +488,6 @@ describe('OpenAIClient', () => {
|
||||
testCases.forEach((testCase) => {
|
||||
it(`should return ${testCase.expected} tokens for model ${testCase.model}`, () => {
|
||||
client.modelOptions.model = testCase.model;
|
||||
client.selectTokenizer();
|
||||
// 3 tokens for assistant label
|
||||
let totalTokens = 3;
|
||||
for (let message of example_messages) {
|
||||
@@ -582,7 +521,6 @@ describe('OpenAIClient', () => {
|
||||
|
||||
it(`should return ${expectedTokens} tokens for model ${visionModel} (Vision Request)`, () => {
|
||||
client.modelOptions.model = visionModel;
|
||||
client.selectTokenizer();
|
||||
// 3 tokens for assistant label
|
||||
let totalTokens = 3;
|
||||
for (let message of vision_request) {
|
||||
|
||||
@@ -8,6 +8,7 @@ const StructuredSD = require('./structured/StableDiffusion');
const GoogleSearchAPI = require('./structured/GoogleSearch');
const TraversaalSearch = require('./structured/TraversaalSearch');
const TavilySearchResults = require('./structured/TavilySearchResults');
const OpenWeather = require('./structured/OpenWeather');

module.exports = {
availableTools,
@@ -19,4 +20,5 @@ module.exports = {
TraversaalSearch,
StructuredWolfram,
TavilySearchResults,
OpenWeather,
};
|
||||
|
||||
@@ -100,7 +100,6 @@
"pluginKey": "calculator",
"description": "Perform simple and complex mathematical calculations.",
"icon": "https://i.imgur.com/RHsSG5h.png",
"isAuthRequired": "false",
"authConfig": []
},
{
@@ -135,7 +134,20 @@
{
"authField": "AZURE_AI_SEARCH_API_KEY",
"label": "Azure AI Search API Key",
"description": "You need to provideq your API Key for Azure AI Search."
"description": "You need to provide your API Key for Azure AI Search."
}
]
},
{
"name": "OpenWeather",
"pluginKey": "open_weather",
"description": "Get weather forecasts and historical data from the OpenWeather API",
"icon": "/assets/openweather.png",
"authConfig": [
{
"authField": "OPENWEATHER_API_KEY",
"label": "OpenWeather API Key",
"description": "Sign up at <a href=\"https://home.openweathermap.org/users/sign_up\" target=\"_blank\">OpenWeather</a>, then get your key at <a href=\"https://home.openweathermap.org/api_keys\" target=\"_blank\">API keys</a>."
}
]
}

api/app/clients/tools/structured/OpenWeather.js (new file, 317 lines)
@@ -0,0 +1,317 @@
|
||||
const { Tool } = require('@langchain/core/tools');
|
||||
const { z } = require('zod');
|
||||
const { getEnvironmentVariable } = require('@langchain/core/utils/env');
|
||||
const fetch = require('node-fetch');
|
||||
|
||||
/**
|
||||
* Map user-friendly units to OpenWeather units.
|
||||
* Defaults to Celsius if not specified.
|
||||
*/
|
||||
function mapUnitsToOpenWeather(unit) {
|
||||
if (!unit) {
|
||||
return 'metric';
|
||||
} // Default to Celsius
|
||||
switch (unit) {
|
||||
case 'Celsius':
|
||||
return 'metric';
|
||||
case 'Kelvin':
|
||||
return 'standard';
|
||||
case 'Fahrenheit':
|
||||
return 'imperial';
|
||||
default:
|
||||
return 'metric'; // fallback
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively round temperature fields in the API response.
|
||||
*/
|
||||
function roundTemperatures(obj) {
|
||||
const tempKeys = new Set([
|
||||
'temp',
|
||||
'feels_like',
|
||||
'dew_point',
|
||||
'day',
|
||||
'min',
|
||||
'max',
|
||||
'night',
|
||||
'eve',
|
||||
'morn',
|
||||
'afternoon',
|
||||
'morning',
|
||||
'evening',
|
||||
]);
|
||||
|
||||
if (Array.isArray(obj)) {
|
||||
return obj.map((item) => roundTemperatures(item));
|
||||
} else if (obj && typeof obj === 'object') {
|
||||
for (const key of Object.keys(obj)) {
|
||||
const value = obj[key];
|
||||
if (value && typeof value === 'object') {
|
||||
obj[key] = roundTemperatures(value);
|
||||
} else if (typeof value === 'number' && tempKeys.has(key)) {
|
||||
obj[key] = Math.round(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
||||
class OpenWeather extends Tool {
|
||||
name = 'open_weather';
|
||||
description =
|
||||
'Provides weather data from OpenWeather One Call API 3.0. ' +
|
||||
'Actions: help, current_forecast, timestamp, daily_aggregation, overview. ' +
|
||||
'If lat/lon not provided, specify "city" for geocoding. ' +
|
||||
'Units: "Celsius", "Kelvin", or "Fahrenheit" (default: Celsius). ' +
|
||||
'For timestamp action, use "date" in YYYY-MM-DD format.';
|
||||
|
||||
schema = z.object({
|
||||
action: z.enum(['help', 'current_forecast', 'timestamp', 'daily_aggregation', 'overview']),
|
||||
city: z.string().optional(),
|
||||
lat: z.number().optional(),
|
||||
lon: z.number().optional(),
|
||||
exclude: z.string().optional(),
|
||||
units: z.enum(['Celsius', 'Kelvin', 'Fahrenheit']).optional(),
|
||||
lang: z.string().optional(),
|
||||
date: z.string().optional(), // For timestamp and daily_aggregation
|
||||
tz: z.string().optional(),
|
||||
});
|
||||
|
||||
constructor(fields = {}) {
|
||||
super();
|
||||
this.envVar = 'OPENWEATHER_API_KEY';
|
||||
this.override = fields.override ?? false;
|
||||
this.apiKey = fields[this.envVar] ?? this.getApiKey();
|
||||
}
|
||||
|
||||
getApiKey() {
|
||||
const key = getEnvironmentVariable(this.envVar);
|
||||
if (!key && !this.override) {
|
||||
throw new Error(`Missing ${this.envVar} environment variable.`);
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
async geocodeCity(city) {
|
||||
const geocodeUrl = `https://api.openweathermap.org/geo/1.0/direct?q=${encodeURIComponent(
|
||||
city,
|
||||
)}&limit=1&appid=${this.apiKey}`;
|
||||
const res = await fetch(geocodeUrl);
|
||||
const data = await res.json();
|
||||
if (!res.ok || !Array.isArray(data) || data.length === 0) {
|
||||
throw new Error(`Could not find coordinates for city: ${city}`);
|
||||
}
|
||||
return { lat: data[0].lat, lon: data[0].lon };
|
||||
}
|
||||
|
||||
convertDateToUnix(dateStr) {
|
||||
const parts = dateStr.split('-');
|
||||
if (parts.length !== 3) {
|
||||
throw new Error('Invalid date format. Expected YYYY-MM-DD.');
|
||||
}
|
||||
const year = parseInt(parts[0], 10);
|
||||
const month = parseInt(parts[1], 10);
|
||||
const day = parseInt(parts[2], 10);
|
||||
if (isNaN(year) || isNaN(month) || isNaN(day)) {
|
||||
throw new Error('Invalid date format. Expected YYYY-MM-DD with valid numbers.');
|
||||
}
|
||||
|
||||
const dateObj = new Date(Date.UTC(year, month - 1, day, 0, 0, 0));
|
||||
if (isNaN(dateObj.getTime())) {
|
||||
throw new Error('Invalid date provided. Cannot parse into a valid date.');
|
||||
}
|
||||
|
||||
return Math.floor(dateObj.getTime() / 1000);
|
||||
}
|
||||
|
||||
async _call(args) {
|
||||
try {
|
||||
const { action, city, lat, lon, exclude, units, lang, date, tz } = args;
|
||||
const owmUnits = mapUnitsToOpenWeather(units);
|
||||
|
||||
if (action === 'help') {
|
||||
return JSON.stringify(
|
||||
{
|
||||
title: 'OpenWeather One Call API 3.0 Help',
|
||||
description: 'Guidance on using the OpenWeather One Call API 3.0.',
|
||||
endpoints: {
|
||||
current_and_forecast: {
|
||||
endpoint: 'data/3.0/onecall',
|
||||
data_provided: [
|
||||
'Current weather',
|
||||
'Minute forecast (1h)',
|
||||
'Hourly forecast (48h)',
|
||||
'Daily forecast (8 days)',
|
||||
'Government weather alerts',
|
||||
],
|
||||
required_params: [['lat', 'lon'], ['city']],
|
||||
optional_params: ['exclude', 'units (Celsius/Kelvin/Fahrenheit)', 'lang'],
|
||||
usage_example: {
|
||||
city: 'Knoxville, Tennessee',
|
||||
units: 'Fahrenheit',
|
||||
lang: 'en',
|
||||
},
|
||||
},
|
||||
weather_for_timestamp: {
|
||||
endpoint: 'data/3.0/onecall/timemachine',
|
||||
data_provided: [
|
||||
'Historical weather (since 1979-01-01)',
|
||||
'Future forecast up to 4 days ahead',
|
||||
],
|
||||
required_params: [
|
||||
['lat', 'lon', 'date (YYYY-MM-DD)'],
|
||||
['city', 'date (YYYY-MM-DD)'],
|
||||
],
|
||||
optional_params: ['units (Celsius/Kelvin/Fahrenheit)', 'lang'],
|
||||
usage_example: {
|
||||
city: 'Knoxville, Tennessee',
|
||||
date: '2020-03-04',
|
||||
units: 'Fahrenheit',
|
||||
lang: 'en',
|
||||
},
|
||||
},
|
||||
daily_aggregation: {
|
||||
endpoint: 'data/3.0/onecall/day_summary',
|
||||
data_provided: [
|
||||
'Aggregated weather data for a specific date (1979-01-02 to 1.5 years ahead)',
|
||||
],
|
||||
required_params: [
|
||||
['lat', 'lon', 'date (YYYY-MM-DD)'],
|
||||
['city', 'date (YYYY-MM-DD)'],
|
||||
],
|
||||
optional_params: ['units (Celsius/Kelvin/Fahrenheit)', 'lang', 'tz'],
|
||||
usage_example: {
|
||||
city: 'Knoxville, Tennessee',
|
||||
date: '2020-03-04',
|
||||
units: 'Celsius',
|
||||
lang: 'en',
|
||||
},
|
||||
},
|
||||
weather_overview: {
|
||||
endpoint: 'data/3.0/onecall/overview',
|
||||
data_provided: ['Human-readable weather summary (today/tomorrow)'],
|
||||
required_params: [['lat', 'lon'], ['city']],
|
||||
optional_params: ['date (YYYY-MM-DD)', 'units (Celsius/Kelvin/Fahrenheit)'],
|
||||
usage_example: {
|
||||
city: 'Knoxville, Tennessee',
|
||||
date: '2024-05-13',
|
||||
units: 'Celsius',
|
||||
},
|
||||
},
|
||||
},
|
||||
notes: [
|
||||
'If lat/lon not provided, you can specify a city name and it will be geocoded.',
|
||||
'For the timestamp action, provide a date in YYYY-MM-DD format instead of a Unix timestamp.',
|
||||
'By default, temperatures are returned in Celsius.',
|
||||
'You can specify units as Celsius, Kelvin, or Fahrenheit.',
|
||||
'All temperatures are rounded to the nearest degree.',
|
||||
],
|
||||
errors: [
|
||||
'400: Bad Request (missing/invalid params)',
|
||||
'401: Unauthorized (check API key)',
|
||||
'404: Not Found (no data or city)',
|
||||
'429: Too many requests',
|
||||
'5xx: Internal error',
|
||||
],
|
||||
},
|
||||
null,
|
||||
2,
|
||||
);
|
||||
}
|
||||
|
||||
let finalLat = lat;
|
||||
let finalLon = lon;
|
||||
|
||||
// If lat/lon not provided but city is given, geocode it
|
||||
if ((finalLat == null || finalLon == null) && city) {
|
||||
const coords = await this.geocodeCity(city);
|
||||
finalLat = coords.lat;
|
||||
finalLon = coords.lon;
|
||||
}
|
||||
|
||||
if (['current_forecast', 'timestamp', 'daily_aggregation', 'overview'].includes(action)) {
|
||||
if (typeof finalLat !== 'number' || typeof finalLon !== 'number') {
|
||||
return 'Error: lat and lon are required and must be numbers for this action (or specify \'city\').';
|
||||
}
|
||||
}
|
||||
|
||||
const baseUrl = 'https://api.openweathermap.org/data/3.0';
|
||||
let endpoint = '';
|
||||
const params = new URLSearchParams({ appid: this.apiKey, units: owmUnits });
|
||||
|
||||
let dt;
|
||||
if (action === 'timestamp') {
|
||||
if (!date) {
|
||||
return 'Error: For timestamp action, a \'date\' in YYYY-MM-DD format is required.';
|
||||
}
|
||||
dt = this.convertDateToUnix(date);
|
||||
}
|
||||
|
||||
if (action === 'daily_aggregation' && !date) {
|
||||
return 'Error: date (YYYY-MM-DD) is required for daily_aggregation action.';
|
||||
}
|
||||
|
||||
switch (action) {
|
||||
case 'current_forecast':
|
||||
endpoint = '/onecall';
|
||||
params.append('lat', String(finalLat));
|
||||
params.append('lon', String(finalLon));
|
||||
if (exclude) {
|
||||
params.append('exclude', exclude);
|
||||
}
|
||||
if (lang) {
|
||||
params.append('lang', lang);
|
||||
}
|
||||
break;
|
||||
case 'timestamp':
|
||||
endpoint = '/onecall/timemachine';
|
||||
params.append('lat', String(finalLat));
|
||||
params.append('lon', String(finalLon));
|
||||
params.append('dt', String(dt));
|
||||
if (lang) {
|
||||
params.append('lang', lang);
|
||||
}
|
||||
break;
|
||||
case 'daily_aggregation':
|
||||
endpoint = '/onecall/day_summary';
|
||||
params.append('lat', String(finalLat));
|
||||
params.append('lon', String(finalLon));
|
||||
params.append('date', date);
|
||||
if (lang) {
|
||||
params.append('lang', lang);
|
||||
}
|
||||
if (tz) {
|
||||
params.append('tz', tz);
|
||||
}
|
||||
break;
|
||||
case 'overview':
|
||||
endpoint = '/onecall/overview';
|
||||
params.append('lat', String(finalLat));
|
||||
params.append('lon', String(finalLon));
|
||||
if (date) {
|
||||
params.append('date', date);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return `Error: Unknown action: ${action}`;
|
||||
}
|
||||
|
||||
const url = `${baseUrl}${endpoint}?${params.toString()}`;
|
||||
const response = await fetch(url);
|
||||
const json = await response.json();
|
||||
if (!response.ok) {
|
||||
return `Error: OpenWeather API request failed with status ${response.status}: ${
|
||||
json.message || JSON.stringify(json)
|
||||
}`;
|
||||
}
|
||||
|
||||
const roundedJson = roundTemperatures(json);
|
||||
return JSON.stringify(roundedJson);
|
||||
} catch (err) {
|
||||
return `Error: ${err.message}`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = OpenWeather;
|
||||
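A small usage sketch for the tool defined above. Hedged: the `require` path assumes the repo layout shown in this diff, a real `OPENWEATHER_API_KEY` is needed, and the call shape mirrors the tests that follow.

```js
// Usage sketch based on the OpenWeather tool defined above.
const OpenWeather = require('./api/app/clients/tools/structured/OpenWeather');

async function main() {
  const tool = new OpenWeather({ OPENWEATHER_API_KEY: process.env.OPENWEATHER_API_KEY });

  // 'help' returns a JSON description of the supported actions.
  console.log(JSON.parse(await tool.call({ action: 'help' })).title);

  // City names are geocoded when lat/lon are omitted; temperatures default to Celsius.
  const result = await tool.call({
    action: 'current_forecast',
    city: 'Knoxville, Tennessee',
    units: 'Fahrenheit',
  });
  console.log(JSON.parse(result).current);
}

main().catch(console.error);
```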
@@ -0,0 +1,224 @@
|
||||
// __tests__/openWeather.integration.test.js
|
||||
const OpenWeather = require('../OpenWeather');
|
||||
|
||||
describe('OpenWeather Tool (Integration Test)', () => {
|
||||
let tool;
|
||||
|
||||
beforeAll(() => {
|
||||
tool = new OpenWeather({ override: true });
|
||||
console.log('API Key present:', !!process.env.OPENWEATHER_API_KEY);
|
||||
});
|
||||
|
||||
test('current_forecast with a real API key returns current weather', async () => {
|
||||
// Check if API key is available
|
||||
if (!process.env.OPENWEATHER_API_KEY) {
|
||||
console.warn('Skipping real API test, no OPENWEATHER_API_KEY found.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'London',
|
||||
units: 'Celsius',
|
||||
});
|
||||
|
||||
console.log('Raw API response:', result);
|
||||
|
||||
const parsed = JSON.parse(result);
|
||||
expect(parsed).toHaveProperty('current');
|
||||
expect(typeof parsed.current.temp).toBe('number');
|
||||
} catch (error) {
|
||||
console.error('Test failed with error:', error);
|
||||
throw error;
|
||||
}
|
||||
});
|
||||
|
||||
test('timestamp action with real API key returns historical data', async () => {
|
||||
if (!process.env.OPENWEATHER_API_KEY) {
|
||||
console.warn('Skipping real API test, no OPENWEATHER_API_KEY found.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Use a date from yesterday to ensure data availability
|
||||
const yesterday = new Date();
|
||||
yesterday.setDate(yesterday.getDate() - 1);
|
||||
const dateStr = yesterday.toISOString().split('T')[0];
|
||||
|
||||
const result = await tool.call({
|
||||
action: 'timestamp',
|
||||
city: 'London',
|
||||
date: dateStr,
|
||||
units: 'Celsius',
|
||||
});
|
||||
|
||||
console.log('Timestamp API response:', result);
|
||||
|
||||
const parsed = JSON.parse(result);
|
||||
expect(parsed).toHaveProperty('data');
|
||||
expect(Array.isArray(parsed.data)).toBe(true);
|
||||
expect(parsed.data[0]).toHaveProperty('temp');
|
||||
} catch (error) {
|
||||
console.error('Timestamp test failed with error:', error);
|
||||
throw error;
|
||||
}
|
||||
});
|
||||
|
||||
test('daily_aggregation action with real API key returns aggregated data', async () => {
|
||||
if (!process.env.OPENWEATHER_API_KEY) {
|
||||
console.warn('Skipping real API test, no OPENWEATHER_API_KEY found.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Use yesterday's date for aggregation
|
||||
const yesterday = new Date();
|
||||
yesterday.setDate(yesterday.getDate() - 1);
|
||||
const dateStr = yesterday.toISOString().split('T')[0];
|
||||
|
||||
const result = await tool.call({
|
||||
action: 'daily_aggregation',
|
||||
city: 'London',
|
||||
date: dateStr,
|
||||
units: 'Celsius',
|
||||
});
|
||||
|
||||
console.log('Daily aggregation API response:', result);
|
||||
|
||||
const parsed = JSON.parse(result);
|
||||
expect(parsed).toHaveProperty('temperature');
|
||||
expect(parsed.temperature).toHaveProperty('morning');
|
||||
expect(parsed.temperature).toHaveProperty('afternoon');
|
||||
expect(parsed.temperature).toHaveProperty('evening');
|
||||
} catch (error) {
|
||||
console.error('Daily aggregation test failed with error:', error);
|
||||
throw error;
|
||||
}
|
||||
});
|
||||
|
||||
test('overview action with real API key returns weather summary', async () => {
|
||||
if (!process.env.OPENWEATHER_API_KEY) {
|
||||
console.warn('Skipping real API test, no OPENWEATHER_API_KEY found.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await tool.call({
|
||||
action: 'overview',
|
||||
city: 'London',
|
||||
units: 'Celsius',
|
||||
});
|
||||
|
||||
console.log('Overview API response:', result);
|
||||
|
||||
const parsed = JSON.parse(result);
|
||||
expect(parsed).toHaveProperty('weather_overview');
|
||||
expect(typeof parsed.weather_overview).toBe('string');
|
||||
expect(parsed.weather_overview.length).toBeGreaterThan(0);
|
||||
expect(parsed).toHaveProperty('date');
|
||||
expect(parsed).toHaveProperty('units');
|
||||
expect(parsed.units).toBe('metric');
|
||||
} catch (error) {
|
||||
console.error('Overview test failed with error:', error);
|
||||
throw error;
|
||||
}
|
||||
});
|
||||
|
||||
test('different temperature units return correct values', async () => {
|
||||
if (!process.env.OPENWEATHER_API_KEY) {
|
||||
console.warn('Skipping real API test, no OPENWEATHER_API_KEY found.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Test Celsius
|
||||
let result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'London',
|
||||
units: 'Celsius',
|
||||
});
|
||||
let parsed = JSON.parse(result);
|
||||
const celsiusTemp = parsed.current.temp;
|
||||
|
||||
// Test Kelvin
|
||||
result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'London',
|
||||
units: 'Kelvin',
|
||||
});
|
||||
parsed = JSON.parse(result);
|
||||
const kelvinTemp = parsed.current.temp;
|
||||
|
||||
// Test Fahrenheit
|
||||
result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'London',
|
||||
units: 'Fahrenheit',
|
||||
});
|
||||
parsed = JSON.parse(result);
|
||||
const fahrenheitTemp = parsed.current.temp;
|
||||
|
||||
// Verify temperature conversions are roughly correct
|
||||
// K = C + 273.15
|
||||
// F = (C * 9/5) + 32
|
||||
const celsiusToKelvin = Math.round(celsiusTemp + 273.15);
|
||||
const celsiusToFahrenheit = Math.round((celsiusTemp * 9) / 5 + 32);
|
||||
|
||||
console.log('Temperature comparisons:', {
|
||||
celsius: celsiusTemp,
|
||||
kelvin: kelvinTemp,
|
||||
fahrenheit: fahrenheitTemp,
|
||||
calculatedKelvin: celsiusToKelvin,
|
||||
calculatedFahrenheit: celsiusToFahrenheit,
|
||||
});
|
||||
|
||||
// Allow for some rounding differences
|
||||
expect(Math.abs(kelvinTemp - celsiusToKelvin)).toBeLessThanOrEqual(1);
|
||||
expect(Math.abs(fahrenheitTemp - celsiusToFahrenheit)).toBeLessThanOrEqual(1);
|
||||
} catch (error) {
|
||||
console.error('Temperature units test failed with error:', error);
|
||||
throw error;
|
||||
}
|
||||
});
|
||||
|
||||
test('language parameter returns localized data', async () => {
|
||||
if (!process.env.OPENWEATHER_API_KEY) {
|
||||
console.warn('Skipping real API test, no OPENWEATHER_API_KEY found.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Test with English
|
||||
let result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'Paris',
|
||||
units: 'Celsius',
|
||||
lang: 'en',
|
||||
});
|
||||
let parsed = JSON.parse(result);
|
||||
const englishDescription = parsed.current.weather[0].description;
|
||||
|
||||
// Test with French
|
||||
result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'Paris',
|
||||
units: 'Celsius',
|
||||
lang: 'fr',
|
||||
});
|
||||
parsed = JSON.parse(result);
|
||||
const frenchDescription = parsed.current.weather[0].description;
|
||||
|
||||
console.log('Language comparison:', {
|
||||
english: englishDescription,
|
||||
french: frenchDescription,
|
||||
});
|
||||
|
||||
// Verify descriptions are different (indicating translation worked)
|
||||
expect(englishDescription).not.toBe(frenchDescription);
|
||||
} catch (error) {
|
||||
console.error('Language test failed with error:', error);
|
||||
throw error;
|
||||
}
|
||||
});
|
||||
});
|
||||
api/app/clients/tools/structured/specs/openweather.test.js (new file, 358 lines)
@@ -0,0 +1,358 @@
|
||||
// __tests__/openweather.test.js
|
||||
const OpenWeather = require('../OpenWeather');
|
||||
const fetch = require('node-fetch');
|
||||
|
||||
// Mock environment variable
|
||||
process.env.OPENWEATHER_API_KEY = 'test-api-key';
|
||||
|
||||
// Mock the fetch function globally
|
||||
jest.mock('node-fetch', () => jest.fn());
|
||||
|
||||
describe('OpenWeather Tool', () => {
|
||||
let tool;
|
||||
|
||||
beforeAll(() => {
|
||||
tool = new OpenWeather();
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
fetch.mockReset();
|
||||
});
|
||||
|
||||
test('action=help returns help instructions', async () => {
|
||||
const result = await tool.call({
|
||||
action: 'help',
|
||||
});
|
||||
|
||||
expect(typeof result).toBe('string');
|
||||
const parsed = JSON.parse(result);
|
||||
expect(parsed.title).toBe('OpenWeather One Call API 3.0 Help');
|
||||
});
|
||||
|
||||
test('current_forecast with a city and successful geocoding + forecast', async () => {
|
||||
// Mock geocoding response
|
||||
fetch.mockImplementationOnce((url) => {
|
||||
if (url.includes('geo/1.0/direct')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => [{ lat: 35.9606, lon: -83.9207 }],
|
||||
});
|
||||
}
|
||||
return Promise.reject('Unexpected fetch call for geocoding');
|
||||
});
|
||||
|
||||
// Mock forecast response
|
||||
fetch.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
current: { temp: 293.15, feels_like: 295.15 },
|
||||
daily: [{ temp: { day: 293.15, night: 283.15 } }],
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'Knoxville, Tennessee',
|
||||
units: 'Kelvin',
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(result);
|
||||
expect(parsed.current.temp).toBe(293);
|
||||
expect(parsed.current.feels_like).toBe(295);
|
||||
expect(parsed.daily[0].temp.day).toBe(293);
|
||||
expect(parsed.daily[0].temp.night).toBe(283);
|
||||
});
|
||||
|
||||
test('timestamp action with valid date returns mocked historical data', async () => {
|
||||
// Mock geocoding response
|
||||
fetch.mockImplementationOnce((url) => {
|
||||
if (url.includes('geo/1.0/direct')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => [{ lat: 35.9606, lon: -83.9207 }],
|
||||
});
|
||||
}
|
||||
return Promise.reject('Unexpected fetch call for geocoding');
|
||||
});
|
||||
|
||||
// Mock historical weather response
|
||||
fetch.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [
|
||||
{
|
||||
dt: 1583280000,
|
||||
temp: 283.15,
|
||||
feels_like: 280.15,
|
||||
humidity: 75,
|
||||
weather: [{ description: 'clear sky' }],
|
||||
},
|
||||
],
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await tool.call({
|
||||
action: 'timestamp',
|
||||
city: 'Knoxville, Tennessee',
|
||||
date: '2020-03-04',
|
||||
units: 'Kelvin',
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(result);
|
||||
expect(parsed.data[0].temp).toBe(283);
|
||||
expect(parsed.data[0].feels_like).toBe(280);
|
||||
});
|
||||
|
||||
test('daily_aggregation action returns aggregated weather data', async () => {
|
||||
// Mock geocoding response
|
||||
fetch.mockImplementationOnce((url) => {
|
||||
if (url.includes('geo/1.0/direct')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => [{ lat: 35.9606, lon: -83.9207 }],
|
||||
});
|
||||
}
|
||||
return Promise.reject('Unexpected fetch call for geocoding');
|
||||
});
|
||||
|
||||
// Mock daily aggregation response
|
||||
fetch.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
date: '2020-03-04',
|
||||
temperature: {
|
||||
morning: 283.15,
|
||||
afternoon: 293.15,
|
||||
evening: 288.15,
|
||||
},
|
||||
humidity: {
|
||||
morning: 75,
|
||||
afternoon: 60,
|
||||
evening: 70,
|
||||
},
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await tool.call({
|
||||
action: 'daily_aggregation',
|
||||
city: 'Knoxville, Tennessee',
|
||||
date: '2020-03-04',
|
||||
units: 'Kelvin',
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(result);
|
||||
expect(parsed.temperature.morning).toBe(283);
|
||||
expect(parsed.temperature.afternoon).toBe(293);
|
||||
expect(parsed.temperature.evening).toBe(288);
|
||||
});
|
||||
|
||||
test('overview action returns weather summary', async () => {
|
||||
// Mock geocoding response
|
||||
fetch.mockImplementationOnce((url) => {
|
||||
if (url.includes('geo/1.0/direct')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => [{ lat: 35.9606, lon: -83.9207 }],
|
||||
});
|
||||
}
|
||||
return Promise.reject('Unexpected fetch call for geocoding');
|
||||
});
|
||||
|
||||
// Mock overview response
|
||||
fetch.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
date: '2024-01-07',
|
||||
lat: 35.9606,
|
||||
lon: -83.9207,
|
||||
tz: '+00:00',
|
||||
units: 'metric',
|
||||
weather_overview:
|
||||
'Currently, the temperature is 2°C with a real feel of -2°C. The sky is clear with moderate wind.',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await tool.call({
|
||||
action: 'overview',
|
||||
city: 'Knoxville, Tennessee',
|
||||
units: 'Celsius',
|
||||
});
|
||||
|
||||
const parsed = JSON.parse(result);
|
||||
expect(parsed).toHaveProperty('weather_overview');
|
||||
expect(typeof parsed.weather_overview).toBe('string');
|
||||
expect(parsed.weather_overview.length).toBeGreaterThan(0);
|
||||
expect(parsed).toHaveProperty('date');
|
||||
expect(parsed).toHaveProperty('units');
|
||||
expect(parsed.units).toBe('metric');
|
||||
});
|
||||
|
||||
test('temperature units are correctly converted', async () => {
|
||||
// Mock geocoding response for all three calls
|
||||
const geocodingMock = Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => [{ lat: 35.9606, lon: -83.9207 }],
|
||||
});
|
||||
|
||||
// Mock weather response for Kelvin
|
||||
const kelvinMock = Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
current: { temp: 293.15 },
|
||||
}),
|
||||
});
|
||||
|
||||
// Mock weather response for Celsius
|
||||
const celsiusMock = Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
current: { temp: 20 },
|
||||
}),
|
||||
});
|
||||
|
||||
// Mock weather response for Fahrenheit
|
||||
const fahrenheitMock = Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
current: { temp: 68 },
|
||||
}),
|
||||
});
|
||||
|
||||
// Test Kelvin
|
||||
fetch.mockImplementationOnce(() => geocodingMock).mockImplementationOnce(() => kelvinMock);
|
||||
|
||||
let result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'Knoxville, Tennessee',
|
||||
units: 'Kelvin',
|
||||
});
|
||||
let parsed = JSON.parse(result);
|
||||
expect(parsed.current.temp).toBe(293);
|
||||
|
||||
// Test Celsius
|
||||
fetch.mockImplementationOnce(() => geocodingMock).mockImplementationOnce(() => celsiusMock);
|
||||
|
||||
result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'Knoxville, Tennessee',
|
||||
units: 'Celsius',
|
||||
});
|
||||
parsed = JSON.parse(result);
|
||||
expect(parsed.current.temp).toBe(20);
|
||||
|
||||
// Test Fahrenheit
|
||||
fetch.mockImplementationOnce(() => geocodingMock).mockImplementationOnce(() => fahrenheitMock);
|
||||
|
||||
result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'Knoxville, Tennessee',
|
||||
units: 'Fahrenheit',
|
||||
});
|
||||
parsed = JSON.parse(result);
|
||||
expect(parsed.current.temp).toBe(68);
|
||||
});
|
||||
|
||||
test('timestamp action without a date returns an error message', async () => {
|
||||
const result = await tool.call({
|
||||
action: 'timestamp',
|
||||
lat: 35.9606,
|
||||
lon: -83.9207,
|
||||
});
|
||||
expect(result).toMatch(
|
||||
/Error: For timestamp action, a 'date' in YYYY-MM-DD format is required./,
|
||||
);
|
||||
});
|
||||
|
||||
test('daily_aggregation action without a date returns an error message', async () => {
|
||||
const result = await tool.call({
|
||||
action: 'daily_aggregation',
|
||||
lat: 35.9606,
|
||||
lon: -83.9207,
|
||||
});
|
||||
expect(result).toMatch(/Error: date \(YYYY-MM-DD\) is required for daily_aggregation action./);
|
||||
});
|
||||
|
||||
test('unknown action returns an error due to schema validation', async () => {
|
||||
await expect(
|
||||
tool.call({
|
||||
action: 'unknown_action',
|
||||
}),
|
||||
).rejects.toThrow(/Received tool input did not match expected schema/);
|
||||
});
|
||||
|
||||
test('geocoding failure returns a descriptive error', async () => {
|
||||
fetch.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => [],
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'NowhereCity',
|
||||
});
|
||||
expect(result).toMatch(/Error: Could not find coordinates for city: NowhereCity/);
|
||||
});
|
||||
|
||||
test('API request failure returns an error', async () => {
|
||||
// Mock geocoding success
|
||||
fetch.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => [{ lat: 35.9606, lon: -83.9207 }],
|
||||
}),
|
||||
);
|
||||
|
||||
// Mock weather request failure
|
||||
fetch.mockImplementationOnce(() =>
|
||||
Promise.resolve({
|
||||
ok: false,
|
||||
status: 404,
|
||||
json: async () => ({ message: 'Not found' }),
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await tool.call({
|
||||
action: 'current_forecast',
|
||||
city: 'Knoxville, Tennessee',
|
||||
});
|
||||
expect(result).toMatch(/Error: OpenWeather API request failed with status 404: Not found/);
|
||||
});
|
||||
|
||||
test('invalid date format returns an error', async () => {
|
||||
// Mock geocoding response first
|
||||
fetch.mockImplementationOnce((url) => {
|
||||
if (url.includes('geo/1.0/direct')) {
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
json: async () => [{ lat: 35.9606, lon: -83.9207 }],
|
||||
});
|
||||
}
|
||||
return Promise.reject('Unexpected fetch call for geocoding');
|
||||
});
|
||||
|
||||
// Mock timestamp API response
|
||||
fetch.mockImplementationOnce((url) => {
|
||||
if (url.includes('onecall/timemachine')) {
|
||||
throw new Error('Invalid date format. Expected YYYY-MM-DD.');
|
||||
}
|
||||
return Promise.reject('Unexpected fetch call');
|
||||
});
|
||||
|
||||
const result = await tool.call({
|
||||
action: 'timestamp',
|
||||
city: 'Knoxville, Tennessee',
|
||||
date: '03-04-2020', // Wrong format
|
||||
});
|
||||
expect(result).toMatch(/Error: Invalid date format. Expected YYYY-MM-DD./);
|
||||
});
|
||||
});
|
||||
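
A minimal usage sketch of the tool exercised by the suite above (not part of the diff; the require path is an assumption — inside the repo the tests load it via '../OpenWeather'). The suite only guarantees that the constructor reads OPENWEATHER_API_KEY and that tool.call(...) resolves to a JSON string:

// Sketch only: stand-alone invocation mirroring the call shapes asserted above.
const OpenWeather = require('./api/app/clients/tools/structured/OpenWeather'); // assumed path

async function demo() {
  // The constructor picks up OPENWEATHER_API_KEY from the environment.
  const tool = new OpenWeather();

  // Every action returns a JSON string, so callers parse it themselves.
  const raw = await tool.call({
    action: 'current_forecast',
    city: 'Knoxville, Tennessee',
    units: 'Celsius',
  });
  const forecast = JSON.parse(raw);
  console.log(forecast.current?.temp);
}

demo().catch(console.error);
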
@@ -23,6 +23,8 @@ async function handleOpenAIErrors(err, errorCallback, context = 'stream') {
    logger.warn(`[OpenAIClient.chatCompletion][${context}] Unhandled error type`);
  }

  logger.error(err);

  if (errorCallback) {
    errorCallback(err);
  }
@@ -14,6 +14,7 @@ const {
  TraversaalSearch,
  StructuredWolfram,
  TavilySearchResults,
  OpenWeather,
} = require('../');
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
@@ -178,6 +179,7 @@ const loadTools = async ({
    'azure-ai-search': StructuredACS,
    traversaal_search: TraversaalSearch,
    tavily_search_results_json: TavilySearchResults,
    open_weather: OpenWeather,
  };

  const customConstructors = {
4  api/cache/banViolation.js  vendored
@@ -1,7 +1,7 @@
const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, math, removePorts } = require('~/server/utils');
const { deleteAllUserSessions } = require('~/models');
const getLogStores = require('./getLogStores');
const Session = require('~/models/Session');
const { logger } = require('~/config');

const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};
@@ -46,7 +46,7 @@ const banViolation = async (req, res, errorMessage) => {
    return;
  }

  await Session.deleteAllUserSessions(user_id);
  await deleteAllUserSessions({ userId: user_id });
  res.clearCookie('refreshToken');

  const banLogs = getLogStores(ViolationTypes.BAN);
179  api/cache/getLogStores.js  vendored
@@ -5,41 +5,43 @@ const { math, isEnabled } = require('~/server/utils');
|
||||
const keyvRedis = require('./keyvRedis');
|
||||
const keyvMongo = require('./keyvMongo');
|
||||
|
||||
const { BAN_DURATION, USE_REDIS } = process.env ?? {};
|
||||
const { BAN_DURATION, USE_REDIS, DEBUG_MEMORY_CACHE, CI } = process.env ?? {};
|
||||
|
||||
const duration = math(BAN_DURATION, 7200000);
|
||||
const isRedisEnabled = isEnabled(USE_REDIS);
|
||||
const debugMemoryCache = isEnabled(DEBUG_MEMORY_CACHE);
|
||||
|
||||
const createViolationInstance = (namespace) => {
|
||||
const config = isEnabled(USE_REDIS) ? { store: keyvRedis } : { store: violationFile, namespace };
|
||||
const config = isRedisEnabled ? { store: keyvRedis } : { store: violationFile, namespace };
|
||||
return new Keyv(config);
|
||||
};
|
||||
|
||||
// Serve cache from memory so no need to clear it on startup/exit
|
||||
const pending_req = isEnabled(USE_REDIS)
|
||||
const pending_req = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: 'pending_req' });
|
||||
|
||||
const config = isEnabled(USE_REDIS)
|
||||
const config = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.CONFIG_STORE });
|
||||
|
||||
const roles = isEnabled(USE_REDIS)
|
||||
const roles = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ROLES });
|
||||
|
||||
const audioRuns = isEnabled(USE_REDIS)
|
||||
const audioRuns = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
const messages = isEnabled(USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.FIVE_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.FIVE_MINUTES });
|
||||
const messages = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.ONE_MINUTE })
|
||||
: new Keyv({ namespace: CacheKeys.MESSAGES, ttl: Time.ONE_MINUTE });
|
||||
|
||||
const tokenConfig = isEnabled(USE_REDIS)
|
||||
const tokenConfig = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.TOKEN_CONFIG, ttl: Time.THIRTY_MINUTES });
|
||||
|
||||
const genTitle = isEnabled(USE_REDIS)
|
||||
const genTitle = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES })
|
||||
: new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES });
|
||||
|
||||
@@ -47,7 +49,7 @@ const modelQueries = isEnabled(process.env.USE_REDIS)
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.MODEL_QUERIES });
|
||||
|
||||
const abortKeys = isEnabled(USE_REDIS)
|
||||
const abortKeys = isRedisEnabled
|
||||
? new Keyv({ store: keyvRedis })
|
||||
: new Keyv({ namespace: CacheKeys.ABORT_KEYS, ttl: Time.TEN_MINUTES });
|
||||
|
||||
@@ -88,6 +90,159 @@ const namespaces = {
|
||||
[CacheKeys.MESSAGES]: messages,
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets all cache stores that have TTL configured
|
||||
* @returns {Keyv[]}
|
||||
*/
|
||||
function getTTLStores() {
|
||||
return Object.values(namespaces).filter(
|
||||
(store) => store instanceof Keyv && typeof store.opts?.ttl === 'number' && store.opts.ttl > 0,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears entries older than the cache's TTL
|
||||
* @param {Keyv} cache
|
||||
*/
|
||||
async function clearExpiredFromCache(cache) {
|
||||
if (!cache?.opts?.store?.entries) {
|
||||
return;
|
||||
}
|
||||
|
||||
const ttl = cache.opts.ttl;
|
||||
if (!ttl) {
|
||||
return;
|
||||
}
|
||||
|
||||
const expiryTime = Date.now() - ttl;
|
||||
let cleared = 0;
|
||||
|
||||
// Get all keys first to avoid modification during iteration
|
||||
const keys = Array.from(cache.opts.store.keys());
|
||||
|
||||
for (const key of keys) {
|
||||
try {
|
||||
const raw = cache.opts.store.get(key);
|
||||
if (!raw) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const data = cache.opts.deserialize(raw);
|
||||
// Check if the entry is older than TTL
|
||||
if (data?.expires && data.expires <= expiryTime) {
|
||||
const deleted = await cache.opts.store.delete(key);
|
||||
if (!deleted) {
|
||||
debugMemoryCache &&
|
||||
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
|
||||
continue;
|
||||
}
|
||||
cleared++;
|
||||
}
|
||||
} catch (error) {
|
||||
debugMemoryCache &&
|
||||
console.log(`[Cache] Error processing entry from ${cache.opts.namespace}:`, error);
|
||||
const deleted = await cache.opts.store.delete(key);
|
||||
if (!deleted) {
|
||||
debugMemoryCache &&
|
||||
console.warn(`[Cache] Error deleting entry: ${key} from ${cache.opts.namespace}`);
|
||||
continue;
|
||||
}
|
||||
cleared++;
|
||||
}
|
||||
}
|
||||
|
||||
if (cleared > 0) {
|
||||
debugMemoryCache &&
|
||||
console.log(
|
||||
`[Cache] Cleared ${cleared} entries older than ${ttl}ms from ${cache.opts.namespace}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const auditCache = () => {
|
||||
const ttlStores = getTTLStores();
|
||||
console.log('[Cache] Starting audit');
|
||||
|
||||
ttlStores.forEach((store) => {
|
||||
if (!store?.opts?.store?.entries) {
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`[Cache] ${store.opts.namespace} entries:`, {
|
||||
count: store.opts.store.size,
|
||||
ttl: store.opts.ttl,
|
||||
keys: Array.from(store.opts.store.keys()),
|
||||
entriesWithTimestamps: Array.from(store.opts.store.entries()).map(([key, value]) => ({
|
||||
key,
|
||||
value,
|
||||
})),
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Clears expired entries from all TTL-enabled stores
|
||||
*/
|
||||
async function clearAllExpiredFromCache() {
|
||||
const ttlStores = getTTLStores();
|
||||
await Promise.all(ttlStores.map((store) => clearExpiredFromCache(store)));
|
||||
|
||||
// Force garbage collection if available (Node.js with --expose-gc flag)
|
||||
if (global.gc) {
|
||||
global.gc();
|
||||
}
|
||||
}
|
||||
|
||||
if (!isRedisEnabled && !isEnabled(CI)) {
|
||||
/** @type {Set<NodeJS.Timeout>} */
|
||||
const cleanupIntervals = new Set();
|
||||
|
||||
// Clear expired entries every 30 seconds
|
||||
const cleanup = setInterval(() => {
|
||||
clearAllExpiredFromCache();
|
||||
}, Time.THIRTY_SECONDS);
|
||||
|
||||
cleanupIntervals.add(cleanup);
|
||||
|
||||
if (debugMemoryCache) {
|
||||
const monitor = setInterval(() => {
|
||||
const ttlStores = getTTLStores();
|
||||
const memory = process.memoryUsage();
|
||||
const totalSize = ttlStores.reduce((sum, store) => sum + (store.opts?.store?.size ?? 0), 0);
|
||||
|
||||
console.log('[Cache] Memory usage:', {
|
||||
heapUsed: `${(memory.heapUsed / 1024 / 1024).toFixed(2)} MB`,
|
||||
heapTotal: `${(memory.heapTotal / 1024 / 1024).toFixed(2)} MB`,
|
||||
rss: `${(memory.rss / 1024 / 1024).toFixed(2)} MB`,
|
||||
external: `${(memory.external / 1024 / 1024).toFixed(2)} MB`,
|
||||
totalCacheEntries: totalSize,
|
||||
});
|
||||
|
||||
auditCache();
|
||||
}, Time.ONE_MINUTE);
|
||||
|
||||
cleanupIntervals.add(monitor);
|
||||
}
|
||||
|
||||
const dispose = () => {
|
||||
debugMemoryCache && console.log('[Cache] Cleaning up and shutting down...');
|
||||
cleanupIntervals.forEach((interval) => clearInterval(interval));
|
||||
cleanupIntervals.clear();
|
||||
|
||||
// One final cleanup before exit
|
||||
clearAllExpiredFromCache().then(() => {
|
||||
debugMemoryCache && console.log('[Cache] Final cleanup completed');
|
||||
process.exit(0);
|
||||
});
|
||||
};
|
||||
|
||||
// Handle various termination signals
|
||||
process.on('SIGTERM', dispose);
|
||||
process.on('SIGINT', dispose);
|
||||
process.on('SIGQUIT', dispose);
|
||||
process.on('SIGHUP', dispose);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the keyv cache specified by type.
|
||||
* If an invalid type is passed, an error will be thrown.
|
||||
|
||||
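
The in-memory sweep introduced above leans on keyv's default Map store, where each raw entry deserializes to an object carrying value and expires. A small sketch of that assumption (illustrative only, not part of the diff):

// With no external store, cache.opts.store is a Map and cache.opts.deserialize
// yields { value, expires } — exactly what clearExpiredFromCache compares against
// Date.now() - ttl before deleting the key.
const Keyv = require('keyv');

async function inspect() {
  const cache = new Keyv({ namespace: 'demo', ttl: 5000 });
  await cache.set('greeting', 'hello');

  for (const [key, raw] of cache.opts.store.entries()) {
    const data = cache.opts.deserialize(raw);
    console.log(key, data.value, new Date(data.expires));
  }
}

inspect();
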
@@ -82,7 +82,7 @@ const loadAgent = async ({ req, agent_id }) => {
|
||||
*/
|
||||
const updateAgent = async (searchParameter, updateData) => {
|
||||
const options = { new: true, upsert: false };
|
||||
return await Agent.findOneAndUpdate(searchParameter, updateData, options).lean();
|
||||
return Agent.findOneAndUpdate(searchParameter, updateData, options).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -96,25 +96,18 @@ const updateAgent = async (searchParameter, updateData) => {
|
||||
*/
|
||||
const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) => {
|
||||
const searchParameter = { id: agent_id };
|
||||
const agent = await getAgent(searchParameter);
|
||||
|
||||
if (!agent) {
|
||||
// build the update to push or create the file ids set
|
||||
const fileIdsPath = `tool_resources.${tool_resource}.file_ids`;
|
||||
const updateData = { $addToSet: { [fileIdsPath]: file_id } };
|
||||
|
||||
// return the updated agent or throw if no agent matches
|
||||
const updatedAgent = await updateAgent(searchParameter, updateData);
|
||||
if (updatedAgent) {
|
||||
return updatedAgent;
|
||||
} else {
|
||||
throw new Error('Agent not found for adding resource file');
|
||||
}
|
||||
|
||||
const tool_resources = agent.tool_resources || {};
|
||||
|
||||
if (!tool_resources[tool_resource]) {
|
||||
tool_resources[tool_resource] = { file_ids: [] };
|
||||
}
|
||||
|
||||
if (!tool_resources[tool_resource].file_ids.includes(file_id)) {
|
||||
tool_resources[tool_resource].file_ids.push(file_id);
|
||||
}
|
||||
|
||||
const updateData = { tool_resources };
|
||||
|
||||
return await updateAgent(searchParameter, updateData);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -126,36 +119,52 @@ const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) => {
|
||||
*/
|
||||
const removeAgentResourceFiles = async ({ agent_id, files }) => {
|
||||
const searchParameter = { id: agent_id };
|
||||
const agent = await getAgent(searchParameter);
|
||||
|
||||
if (!agent) {
|
||||
throw new Error('Agent not found for removing resource files');
|
||||
}
|
||||
|
||||
const tool_resources = { ...agent.tool_resources } || {};
|
||||
|
||||
// associate each tool resource with the respective file ids array
|
||||
const filesByResource = files.reduce((acc, { tool_resource, file_id }) => {
|
||||
if (!acc[tool_resource]) {
|
||||
acc[tool_resource] = new Set();
|
||||
acc[tool_resource] = [];
|
||||
}
|
||||
acc[tool_resource].add(file_id);
|
||||
acc[tool_resource].push(file_id);
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// build the update aggregation pipeline which removes file ids from tool resources array
|
||||
// and eventually deletes empty tool resources
|
||||
const updateData = [];
|
||||
Object.entries(filesByResource).forEach(([resource, fileIds]) => {
|
||||
if (tool_resources[resource] && tool_resources[resource].file_ids) {
|
||||
tool_resources[resource].file_ids = tool_resources[resource].file_ids.filter(
|
||||
(id) => !fileIds.has(id),
|
||||
);
|
||||
const toolResourcePath = `tool_resources.${resource}`;
|
||||
const fileIdsPath = `${toolResourcePath}.file_ids`;
|
||||
|
||||
if (tool_resources[resource].file_ids.length === 0) {
|
||||
delete tool_resources[resource];
|
||||
}
|
||||
}
|
||||
// file ids removal stage
|
||||
updateData.push({
|
||||
$set: {
|
||||
[fileIdsPath]: {
|
||||
$filter: {
|
||||
input: `$${fileIdsPath}`,
|
||||
cond: { $not: [{ $in: ['$$this', fileIds] }] },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// empty tool resource deletion stage
|
||||
updateData.push({
|
||||
$set: {
|
||||
[toolResourcePath]: {
|
||||
$cond: [{ $eq: [`$${fileIdsPath}`, []] }, '$$REMOVE', `$${toolResourcePath}`],
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
const updateData = { tool_resources };
|
||||
return await updateAgent(searchParameter, updateData);
|
||||
// return the updated agent or throw if no agent matches
|
||||
const updatedAgent = await updateAgent(searchParameter, updateData);
|
||||
if (updatedAgent) {
|
||||
return updatedAgent;
|
||||
} else {
|
||||
throw new Error('Agent not found for removing resource files');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
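
For reference, the pipeline-style update that removeAgentResourceFiles now builds, written out for a single hypothetical resource ('file_search') with two made-up file ids — the array is handed to findOneAndUpdate as an aggregation pipeline, so the id filtering and the empty-resource cleanup run atomically on the server:

// Illustrative shape only; field paths follow the code above, ids are placeholders.
const updateData = [
  {
    $set: {
      'tool_resources.file_search.file_ids': {
        $filter: {
          input: '$tool_resources.file_search.file_ids',
          cond: { $not: [{ $in: ['$$this', ['file-1', 'file-2']] }] },
        },
      },
    },
  },
  {
    $set: {
      'tool_resources.file_search': {
        $cond: [
          { $eq: ['$tool_resources.file_search.file_ids', []] },
          '$$REMOVE',
          '$tool_resources.file_search',
        ],
      },
    },
  },
];
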
@@ -1,75 +1,275 @@
|
||||
const mongoose = require('mongoose');
|
||||
const signPayload = require('~/server/services/signPayload');
|
||||
const { hashToken } = require('~/server/utils/crypto');
|
||||
const sessionSchema = require('./schema/session');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const Session = mongoose.model('Session', sessionSchema);
|
||||
|
||||
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
|
||||
const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7;
|
||||
const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default
|
||||
|
||||
const sessionSchema = mongoose.Schema({
|
||||
refreshTokenHash: {
|
||||
type: String,
|
||||
required: true,
|
||||
},
|
||||
expiration: {
|
||||
type: Date,
|
||||
required: true,
|
||||
expires: 0,
|
||||
},
|
||||
user: {
|
||||
type: mongoose.Schema.Types.ObjectId,
|
||||
ref: 'User',
|
||||
required: true,
|
||||
},
|
||||
});
|
||||
/**
|
||||
* Error class for Session-related errors
|
||||
*/
|
||||
class SessionError extends Error {
|
||||
constructor(message, code = 'SESSION_ERROR') {
|
||||
super(message);
|
||||
this.name = 'SessionError';
|
||||
this.code = code;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new session for a user
|
||||
* @param {string} userId - The ID of the user
|
||||
* @param {Object} options - Additional options for session creation
|
||||
* @param {Date} options.expiration - Custom expiration date
|
||||
* @returns {Promise<{session: Session, refreshToken: string}>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const createSession = async (userId, options = {}) => {
|
||||
if (!userId) {
|
||||
throw new SessionError('User ID is required', 'INVALID_USER_ID');
|
||||
}
|
||||
|
||||
sessionSchema.methods.generateRefreshToken = async function () {
|
||||
try {
|
||||
let expiresIn;
|
||||
if (this.expiration) {
|
||||
expiresIn = this.expiration.getTime();
|
||||
} else {
|
||||
expiresIn = Date.now() + expires;
|
||||
this.expiration = new Date(expiresIn);
|
||||
const session = new Session({
|
||||
user: userId,
|
||||
expiration: options.expiration || new Date(Date.now() + expires),
|
||||
});
|
||||
const refreshToken = await generateRefreshToken(session);
|
||||
return { session, refreshToken };
|
||||
} catch (error) {
|
||||
logger.error('[createSession] Error creating session:', error);
|
||||
throw new SessionError('Failed to create session', 'CREATE_SESSION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Finds a session by various parameters
|
||||
* @param {Object} params - Search parameters
|
||||
* @param {string} [params.refreshToken] - The refresh token to search by
|
||||
* @param {string} [params.userId] - The user ID to search by
|
||||
* @param {string} [params.sessionId] - The session ID to search by
|
||||
* @param {Object} [options] - Additional options
|
||||
* @param {boolean} [options.lean=true] - Whether to return plain objects instead of documents
|
||||
* @returns {Promise<Session|null>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const findSession = async (params, options = { lean: true }) => {
|
||||
try {
|
||||
const query = {};
|
||||
|
||||
if (!params.refreshToken && !params.userId && !params.sessionId) {
|
||||
throw new SessionError('At least one search parameter is required', 'INVALID_SEARCH_PARAMS');
|
||||
}
|
||||
|
||||
if (params.refreshToken) {
|
||||
const tokenHash = await hashToken(params.refreshToken);
|
||||
query.refreshTokenHash = tokenHash;
|
||||
}
|
||||
|
||||
if (params.userId) {
|
||||
query.user = params.userId;
|
||||
}
|
||||
|
||||
if (params.sessionId) {
|
||||
const sessionId = params.sessionId.sessionId || params.sessionId;
|
||||
if (!mongoose.Types.ObjectId.isValid(sessionId)) {
|
||||
throw new SessionError('Invalid session ID format', 'INVALID_SESSION_ID');
|
||||
}
|
||||
query._id = sessionId;
|
||||
}
|
||||
|
||||
// Add expiration check to only return valid sessions
|
||||
query.expiration = { $gt: new Date() };
|
||||
|
||||
const sessionQuery = Session.findOne(query);
|
||||
|
||||
if (options.lean) {
|
||||
return await sessionQuery.lean();
|
||||
}
|
||||
|
||||
return await sessionQuery.exec();
|
||||
} catch (error) {
|
||||
logger.error('[findSession] Error finding session:', error);
|
||||
throw new SessionError('Failed to find session', 'FIND_SESSION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Updates session expiration
|
||||
* @param {Session|string} session - The session or session ID to update
|
||||
* @param {Date} [newExpiration] - Optional new expiration date
|
||||
* @returns {Promise<Session>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const updateExpiration = async (session, newExpiration) => {
|
||||
try {
|
||||
const sessionDoc = typeof session === 'string' ? await Session.findById(session) : session;
|
||||
|
||||
if (!sessionDoc) {
|
||||
throw new SessionError('Session not found', 'SESSION_NOT_FOUND');
|
||||
}
|
||||
|
||||
sessionDoc.expiration = newExpiration || new Date(Date.now() + expires);
|
||||
return await sessionDoc.save();
|
||||
} catch (error) {
|
||||
logger.error('[updateExpiration] Error updating session:', error);
|
||||
throw new SessionError('Failed to update session expiration', 'UPDATE_EXPIRATION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes a session by refresh token or session ID
|
||||
* @param {Object} params - Delete parameters
|
||||
* @param {string} [params.refreshToken] - The refresh token of the session to delete
|
||||
* @param {string} [params.sessionId] - The ID of the session to delete
|
||||
* @returns {Promise<Object>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const deleteSession = async (params) => {
|
||||
try {
|
||||
if (!params.refreshToken && !params.sessionId) {
|
||||
throw new SessionError(
|
||||
'Either refreshToken or sessionId is required',
|
||||
'INVALID_DELETE_PARAMS',
|
||||
);
|
||||
}
|
||||
|
||||
const query = {};
|
||||
|
||||
if (params.refreshToken) {
|
||||
query.refreshTokenHash = await hashToken(params.refreshToken);
|
||||
}
|
||||
|
||||
if (params.sessionId) {
|
||||
query._id = params.sessionId;
|
||||
}
|
||||
|
||||
const result = await Session.deleteOne(query);
|
||||
|
||||
if (result.deletedCount === 0) {
|
||||
logger.warn('[deleteSession] No session found to delete');
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[deleteSession] Error deleting session:', error);
|
||||
throw new SessionError('Failed to delete session', 'DELETE_SESSION_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Deletes all sessions for a user
|
||||
* @param {string} userId - The ID of the user
|
||||
* @param {Object} [options] - Additional options
|
||||
* @param {boolean} [options.excludeCurrentSession] - Whether to exclude the current session
|
||||
* @param {string} [options.currentSessionId] - The ID of the current session to exclude
|
||||
* @returns {Promise<Object>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const deleteAllUserSessions = async (userId, options = {}) => {
|
||||
try {
|
||||
if (!userId) {
|
||||
throw new SessionError('User ID is required', 'INVALID_USER_ID');
|
||||
}
|
||||
|
||||
// Extract userId if it's passed as an object
|
||||
const userIdString = userId.userId || userId;
|
||||
|
||||
if (!mongoose.Types.ObjectId.isValid(userIdString)) {
|
||||
throw new SessionError('Invalid user ID format', 'INVALID_USER_ID_FORMAT');
|
||||
}
|
||||
|
||||
const query = { user: userIdString };
|
||||
|
||||
if (options.excludeCurrentSession && options.currentSessionId) {
|
||||
query._id = { $ne: options.currentSessionId };
|
||||
}
|
||||
|
||||
const result = await Session.deleteMany(query);
|
||||
|
||||
if (result.deletedCount > 0) {
|
||||
logger.debug(
|
||||
`[deleteAllUserSessions] Deleted ${result.deletedCount} sessions for user ${userIdString}.`,
|
||||
);
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllUserSessions] Error deleting user sessions:', error);
|
||||
throw new SessionError('Failed to delete user sessions', 'DELETE_ALL_SESSIONS_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Generates a refresh token for a session
|
||||
* @param {Session} session - The session to generate a token for
|
||||
* @returns {Promise<string>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const generateRefreshToken = async (session) => {
|
||||
if (!session || !session.user) {
|
||||
throw new SessionError('Invalid session object', 'INVALID_SESSION');
|
||||
}
|
||||
|
||||
try {
|
||||
const expiresIn = session.expiration ? session.expiration.getTime() : Date.now() + expires;
|
||||
|
||||
if (!session.expiration) {
|
||||
session.expiration = new Date(expiresIn);
|
||||
}
|
||||
|
||||
const refreshToken = await signPayload({
|
||||
payload: { id: this.user },
|
||||
payload: {
|
||||
id: session.user,
|
||||
sessionId: session._id,
|
||||
},
|
||||
secret: process.env.JWT_REFRESH_SECRET,
|
||||
expirationTime: Math.floor((expiresIn - Date.now()) / 1000),
|
||||
});
|
||||
|
||||
this.refreshTokenHash = await hashToken(refreshToken);
|
||||
|
||||
await this.save();
|
||||
session.refreshTokenHash = await hashToken(refreshToken);
|
||||
await session.save();
|
||||
|
||||
return refreshToken;
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'Error generating refresh token. Is a `JWT_REFRESH_SECRET` set in the .env file?\n\n',
|
||||
error,
|
||||
);
|
||||
throw error;
|
||||
logger.error('[generateRefreshToken] Error generating refresh token:', error);
|
||||
throw new SessionError('Failed to generate refresh token', 'GENERATE_TOKEN_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
sessionSchema.statics.deleteAllUserSessions = async function (userId) {
|
||||
/**
|
||||
* Counts active sessions for a user
|
||||
* @param {string} userId - The ID of the user
|
||||
* @returns {Promise<number>}
|
||||
* @throws {SessionError}
|
||||
*/
|
||||
const countActiveSessions = async (userId) => {
|
||||
try {
|
||||
if (!userId) {
|
||||
return;
|
||||
}
|
||||
const result = await this.deleteMany({ user: userId });
|
||||
if (result && result?.deletedCount > 0) {
|
||||
logger.debug(
|
||||
`[deleteAllUserSessions] Deleted ${result.deletedCount} sessions for user ${userId}.`,
|
||||
);
|
||||
throw new SessionError('User ID is required', 'INVALID_USER_ID');
|
||||
}
|
||||
|
||||
return await Session.countDocuments({
|
||||
user: userId,
|
||||
expiration: { $gt: new Date() },
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[deleteAllUserSessions] Error in deleting user sessions:', error);
|
||||
throw error;
|
||||
logger.error('[countActiveSessions] Error counting active sessions:', error);
|
||||
throw new SessionError('Failed to count active sessions', 'COUNT_SESSIONS_FAILED');
|
||||
}
|
||||
};
|
||||
|
||||
const Session = mongoose.model('Session', sessionSchema);
|
||||
|
||||
module.exports = Session;
|
||||
module.exports = {
|
||||
createSession,
|
||||
findSession,
|
||||
updateExpiration,
|
||||
deleteSession,
|
||||
deleteAllUserSessions,
|
||||
generateRefreshToken,
|
||||
countActiveSessions,
|
||||
SessionError,
|
||||
};
|
||||
|
||||
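
A short consumption sketch of the refactored Session API (assumptions: the functions are imported from '~/models' as re-exported later in this diff, and userId is a valid user ObjectId):

// Sketch only: issue a session, look it up by refresh token, then revoke it.
const { createSession, findSession, deleteSession } = require('~/models');

async function issueAndRevoke(userId) {
  // createSession returns the persisted session and the signed refresh token together.
  const { session, refreshToken } = await createSession(userId);

  // Later, the raw refresh token is enough to find the session (it is re-hashed internally).
  const found = await findSession({ refreshToken });

  if (found) {
    await deleteSession({ sessionId: found._id });
  }
  return session;
}
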
@@ -27,6 +27,9 @@ transactionSchema.methods.calculateTokenValue = function () {
 */
transactionSchema.statics.create = async function (txData) {
  const Transaction = this;
  if (txData.rawAmount != null && isNaN(txData.rawAmount)) {
    return;
  }

  const transaction = new Transaction(txData);
  transaction.endpointTokenConfig = txData.endpointTokenConfig;
@@ -1,5 +1,6 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { Transaction } = require('./Transaction');
const Balance = require('./Balance');
const { spendTokens, spendStructuredTokens } = require('./spendTokens');
const { getMultiplier, getCacheMultiplier } = require('./tx');
@@ -346,3 +347,28 @@ describe('Structured Token Spending Tests', () => {
    expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0); // Assuming multiplier is 15 and cancelRate is 1.15
  });
});

describe('NaN Handling Tests', () => {
  test('should skip transaction creation when rawAmount is NaN', async () => {
    const userId = new mongoose.Types.ObjectId();
    const initialBalance = 10000000;
    await Balance.create({ user: userId, tokenCredits: initialBalance });

    const model = 'gpt-3.5-turbo';
    const txData = {
      user: userId,
      conversationId: 'test-conversation-id',
      model,
      context: 'test',
      endpointTokenConfig: null,
      rawAmount: NaN,
      tokenType: 'prompt',
    };

    const result = await Transaction.create(txData);
    expect(result).toBeUndefined();

    const balance = await Balance.findOne({ user: userId });
    expect(balance.tokenCredits).toBe(initialBalance);
  });
});
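
The guard this test pins down exists because NaN propagates silently through arithmetic; a tiny illustration (not part of the diff):

// Why Transaction.create bails out early on NaN amounts:
const rawAmount = NaN;
console.log(rawAmount != null && isNaN(rawAmount)); // true -> create() returns undefined
console.log(10000000 + rawAmount * 15); // NaN -> the balance corruption the guard prevents
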
@@ -26,10 +26,18 @@ const {
  deleteMessagesSince,
  deleteMessages,
} = require('./Message');
const {
  createSession,
  findSession,
  updateExpiration,
  deleteSession,
  deleteAllUserSessions,
  generateRefreshToken,
  countActiveSessions,
} = require('./Session');
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
const { createToken, findToken, updateToken, deleteTokens } = require('./Token');
const Session = require('./Session');
const Balance = require('./Balance');
const User = require('./User');
const Key = require('./Key');
@@ -75,8 +83,15 @@ module.exports = {
  updateToken,
  deleteTokens,

  createSession,
  findSession,
  updateExpiration,
  deleteSession,
  deleteAllUserSessions,
  generateRefreshToken,
  countActiveSessions,

  User,
  Key,
  Session,
  Balance,
};
@@ -16,7 +16,6 @@ const keySchema = mongoose.Schema({
  },
  expiresAt: {
    type: Date,
    expires: 0,
  },
});

20  api/models/schema/session.js  Normal file
@@ -0,0 +1,20 @@
const mongoose = require('mongoose');

const sessionSchema = mongoose.Schema({
  refreshTokenHash: {
    type: String,
    required: true,
  },
  expiration: {
    type: Date,
    required: true,
    expires: 0,
  },
  user: {
    type: mongoose.Schema.Types.ObjectId,
    ref: 'User',
    required: true,
  },
});

module.exports = sessionSchema;
@@ -41,10 +41,10 @@
    "@keyv/redis": "^2.8.1",
    "@langchain/community": "^0.3.14",
    "@langchain/core": "^0.3.18",
    "@langchain/google-genai": "^0.1.4",
    "@langchain/google-vertexai": "^0.1.4",
    "@langchain/google-genai": "^0.1.6",
    "@langchain/google-vertexai": "^0.1.6",
    "@langchain/textsplitters": "^0.1.0",
    "@librechat/agents": "^1.8.8",
    "@librechat/agents": "^1.9.94",
    "axios": "^1.7.7",
    "bcryptjs": "^2.4.3",
    "cheerio": "^1.0.0-rc.12",
@@ -76,9 +76,10 @@
    "librechat-mcp": "*",
    "lodash": "^4.17.21",
    "meilisearch": "^0.38.0",
    "memorystore": "^1.6.7",
    "mime": "^3.0.0",
    "module-alias": "^2.2.3",
    "mongoose": "^8.8.3",
    "mongoose": "^8.9.5",
    "multer": "^1.4.5-lts.1",
    "nanoid": "^3.3.7",
    "nodejs-gpt": "^1.37.4",
@@ -6,8 +6,7 @@ const {
  setAuthTokens,
  requestPasswordReset,
} = require('~/server/services/AuthService');
const { hashToken } = require('~/server/utils/crypto');
const { Session, getUserById } = require('~/models');
const { findSession, getUserById, deleteAllUserSessions } = require('~/models');
const { logger } = require('~/config');

const registrationController = async (req, res) => {
@@ -45,6 +44,7 @@ const resetPasswordController = async (req, res) => {
    if (resetPasswordService instanceof Error) {
      return res.status(400).json(resetPasswordService);
    } else {
      await deleteAllUserSessions({ userId: req.body.userId });
      return res.status(200).json(resetPasswordService);
    }
  } catch (e) {
@@ -73,11 +73,9 @@ const refreshController = async (req, res) => {
      return res.status(200).send({ token, user });
    }

    // Hash the refresh token
    const hashedToken = await hashToken(refreshToken);

    // Find the session with the hashed refresh token
    const session = await Session.findOne({ user: userId, refreshTokenHash: hashedToken });
    const session = await findSession({ userId: userId, refreshToken: refreshToken });

    if (session && session.expiration > new Date()) {
      const token = await setAuthTokens(userId, res, session._id);
      res.status(200).send({ token, user });
@@ -1,5 +1,4 @@
const {
  Session,
  Balance,
  getFiles,
  deleteFiles,
@@ -7,6 +6,7 @@ const {
  deletePresets,
  deleteMessages,
  deleteUserById,
  deleteAllUserSessions,
} = require('~/models');
const User = require('~/models/User');
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
@@ -112,7 +112,7 @@ const deleteUserController = async (req, res) => {

  try {
    await deleteMessages({ user: user.id }); // delete user messages
    await Session.deleteMany({ user: user.id }); // delete user sessions
    await deleteAllUserSessions({ userId: user.id }); // delete user sessions
    await Transaction.deleteMany({ user: user.id }); // delete user transactions
    await deleteUserKey({ userId: user.id, all: true }); // delete user keys
    await Balance.deleteMany({ user: user._id }); // delete user balances
@@ -1,8 +1,10 @@
|
||||
const { Tools, StepTypes, imageGenTools, FileContext } = require('librechat-data-provider');
|
||||
const {
|
||||
EnvVar,
|
||||
Providers,
|
||||
GraphEvents,
|
||||
ToolEndHandler,
|
||||
handleToolCalls,
|
||||
ChatModelStreamHandler,
|
||||
} = require('@librechat/agents');
|
||||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
@@ -57,13 +59,22 @@ class ModelEndHandler {
|
||||
return;
|
||||
}
|
||||
|
||||
const usage = data?.output?.usage_metadata;
|
||||
if (metadata?.model) {
|
||||
usage.model = metadata.model;
|
||||
}
|
||||
try {
|
||||
if (metadata.provider === Providers.GOOGLE) {
|
||||
handleToolCalls(data?.output?.tool_calls, metadata, graph);
|
||||
}
|
||||
|
||||
const usage = data?.output?.usage_metadata;
|
||||
if (!usage) {
|
||||
return;
|
||||
}
|
||||
if (metadata?.model) {
|
||||
usage.model = metadata.model;
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
this.collectedUsage.push(usage);
|
||||
} catch (error) {
|
||||
logger.error('Error handling model end event:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ const { createRun } = require('./run');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */
|
||||
/** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */
|
||||
|
||||
const providerParsers = {
|
||||
[EModelEndpoint.openAI]: openAISchema,
|
||||
@@ -59,6 +60,9 @@ const noSystemModelRegex = [/\bo1\b/gi];
|
||||
class AgentClient extends BaseClient {
|
||||
constructor(options = {}) {
|
||||
super(null, options);
|
||||
/** The current client class
|
||||
* @type {string} */
|
||||
this.clientName = EModelEndpoint.agents;
|
||||
|
||||
/** @type {'discard' | 'summarize'} */
|
||||
this.contextStrategy = 'discard';
|
||||
@@ -90,6 +94,14 @@ class AgentClient extends BaseClient {
|
||||
this.options = Object.assign({ endpoint: options.endpoint }, clientOptions);
|
||||
/** @type {string} */
|
||||
this.model = this.options.agent.model_parameters.model;
|
||||
/** The key for the usage object's input tokens
|
||||
* @type {string} */
|
||||
this.inputTokensKey = 'input_tokens';
|
||||
/** The key for the usage object's output tokens
|
||||
* @type {string} */
|
||||
this.outputTokensKey = 'output_tokens';
|
||||
/** @type {UsageMetadata} */
|
||||
this.usage;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -192,6 +204,7 @@ class AgentClient extends BaseClient {
|
||||
resendFiles: this.options.resendFiles,
|
||||
imageDetail: this.options.imageDetail,
|
||||
spec: this.options.spec,
|
||||
iconURL: this.options.iconURL,
|
||||
},
|
||||
// TODO: PARSE OPTIONS BY PROVIDER, MAY CONTAIN SENSITIVE DATA
|
||||
runOptions,
|
||||
@@ -327,16 +340,18 @@ class AgentClient extends BaseClient {
|
||||
this.options.agent.instructions = systemContent;
|
||||
}
|
||||
|
||||
/** @type {Record<string, number> | undefined} */
|
||||
let tokenCountMap;
|
||||
|
||||
if (this.contextStrategy) {
|
||||
({ payload, promptTokens, messages } = await this.handleContextStrategy({
|
||||
({ payload, promptTokens, tokenCountMap, messages } = await this.handleContextStrategy({
|
||||
orderedMessages,
|
||||
formattedMessages,
|
||||
/* prefer usage_metadata from final message */
|
||||
buildTokenMap: false,
|
||||
}));
|
||||
}
|
||||
|
||||
const result = {
|
||||
tokenCountMap,
|
||||
prompt: payload,
|
||||
promptTokens,
|
||||
messages,
|
||||
@@ -366,8 +381,26 @@ class AgentClient extends BaseClient {
|
||||
* @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage]
|
||||
*/
|
||||
async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) {
|
||||
for (const usage of collectedUsage) {
|
||||
await spendTokens(
|
||||
if (!collectedUsage || !collectedUsage.length) {
|
||||
return;
|
||||
}
|
||||
const input_tokens = collectedUsage[0]?.input_tokens || 0;
|
||||
|
||||
let output_tokens = 0;
|
||||
let previousTokens = input_tokens; // Start with original input
|
||||
for (let i = 0; i < collectedUsage.length; i++) {
|
||||
const usage = collectedUsage[i];
|
||||
if (i > 0) {
|
||||
// Count new tokens generated (input_tokens minus previous accumulated tokens)
|
||||
output_tokens += (Number(usage.input_tokens) || 0) - previousTokens;
|
||||
}
|
||||
|
||||
// Add this message's output tokens
|
||||
output_tokens += Number(usage.output_tokens) || 0;
|
||||
|
||||
// Update previousTokens to include this message's output
|
||||
previousTokens += Number(usage.output_tokens) || 0;
|
||||
spendTokens(
|
||||
{
|
||||
context,
|
||||
conversationId: this.conversationId,
|
||||
@@ -376,8 +409,66 @@ class AgentClient extends BaseClient {
|
||||
model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model,
|
||||
},
|
||||
{ promptTokens: usage.input_tokens, completionTokens: usage.output_tokens },
|
||||
);
|
||||
).catch((err) => {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens',
|
||||
err,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
this.usage = {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get stream usage as returned by this client's API response.
|
||||
* @returns {UsageMetadata} The stream usage object.
|
||||
*/
|
||||
getStreamUsage() {
|
||||
return this.usage;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {TMessage} responseMessage
|
||||
* @returns {number}
|
||||
*/
|
||||
getTokenCountForResponse({ content }) {
|
||||
return this.getTokenCountForMessage({
|
||||
role: 'assistant',
|
||||
content,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the correct token count for the current user message based on the token count map and API usage.
|
||||
* Edge case: If the calculation results in a negative value, it returns the original estimate.
|
||||
* If revisiting a conversation with a chat history entirely composed of token estimates,
|
||||
* the cumulative token count going forward should become more accurate as the conversation progresses.
|
||||
* @param {Object} params - The parameters for the calculation.
|
||||
* @param {Record<string, number>} params.tokenCountMap - A map of message IDs to their token counts.
|
||||
* @param {string} params.currentMessageId - The ID of the current message to calculate.
|
||||
* @param {OpenAIUsageMetadata} params.usage - The usage object returned by the API.
|
||||
* @returns {number} The correct token count for the current user message.
|
||||
*/
|
||||
calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
|
||||
const originalEstimate = tokenCountMap[currentMessageId] || 0;
|
||||
|
||||
if (!usage || typeof usage[this.inputTokensKey] !== 'number') {
|
||||
return originalEstimate;
|
||||
}
|
||||
|
||||
tokenCountMap[currentMessageId] = 0;
|
||||
const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
|
||||
const numCount = Number(count);
|
||||
return sum + (isNaN(numCount) ? 0 : numCount);
|
||||
}, 0);
|
||||
const totalInputTokens = usage[this.inputTokensKey] ?? 0;
|
||||
|
||||
const currentMessageTokens = totalInputTokens - totalTokensFromMap;
|
||||
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
|
||||
}
|
||||
|
||||
async chatCompletion({ payload, abortController = null }) {
|
||||
@@ -488,12 +579,14 @@ class AgentClient extends BaseClient {
|
||||
// });
|
||||
// }
|
||||
|
||||
/** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */
|
||||
const config = {
|
||||
configurable: {
|
||||
thread_id: this.conversationId,
|
||||
last_agent_index: this.agentConfigs?.size ?? 0,
|
||||
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
||||
},
|
||||
recursionLimit: this.options.req.app.locals[EModelEndpoint.agents]?.recursionLimit,
|
||||
signal: abortController.signal,
|
||||
streamMode: 'values',
|
||||
version: 'v2',
|
||||
@@ -672,12 +765,14 @@ class AgentClient extends BaseClient {
|
||||
);
|
||||
});
|
||||
|
||||
this.recordCollectedUsage({ context: 'message' }).catch((err) => {
|
||||
try {
|
||||
await this.recordCollectedUsage({ context: 'message' });
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage',
|
||||
err,
|
||||
);
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
if (!abortController.signal.aborted) {
|
||||
logger.error(
|
||||
@@ -763,8 +858,11 @@ class AgentClient extends BaseClient {
|
||||
}
|
||||
}
|
||||
|
||||
/** Silent method, as `recordCollectedUsage` is used instead */
|
||||
async recordTokenUsage() {}
|
||||
|
||||
getEncoding() {
|
||||
return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base';
|
||||
return 'o200k_base';
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
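
A worked example of the accounting recordCollectedUsage now performs (numbers are made up): each later step's input_tokens already include everything generated so far, so the new output attributed to a step is its input_tokens minus the running total, plus its own output_tokens.

// Illustrative only — mirrors the loop in recordCollectedUsage above.
const collectedUsage = [
  { input_tokens: 1000, output_tokens: 50 }, // initial prompt
  { input_tokens: 1080, output_tokens: 40 }, // e.g. after a tool-call round trip
];

const input_tokens = collectedUsage[0].input_tokens; // 1000
let output_tokens = 0;
let previousTokens = input_tokens;

for (let i = 0; i < collectedUsage.length; i++) {
  const usage = collectedUsage[i];
  if (i > 0) {
    output_tokens += usage.input_tokens - previousTokens; // 1080 - 1050 = 30 new tokens
  }
  output_tokens += usage.output_tokens;
  previousTokens += usage.output_tokens;
}

console.log({ input_tokens, output_tokens }); // { input_tokens: 1000, output_tokens: 120 }
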
@@ -398,15 +398,17 @@ const chatV2 = async (req, res) => {
      response = streamRunManager;
      response.text = streamRunManager.intermediateText;

      const messageCache = getLogStores(CacheKeys.MESSAGES);
      messageCache.set(
        responseMessageId,
        {
          complete: true,
          text: response.text,
        },
        Time.FIVE_MINUTES,
      );
      if (response.text) {
        const messageCache = getLogStores(CacheKeys.MESSAGES);
        messageCache.set(
          responseMessageId,
          {
            complete: true,
            text: response.text,
          },
          Time.FIVE_MINUTES,
        );
      }
    };

    await processRun();
@@ -1,3 +1,4 @@
jest.mock('~/cache/getLogStores');
const request = require('supertest');
const express = require('express');
const routes = require('../');
@@ -28,6 +28,12 @@ const oauthHandler = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
router.get('/error', (req, res) => {
|
||||
// A single error message is pushed by passport when authentication fails.
|
||||
logger.error('Error in OAuth authentication:', { message: req.session.messages.pop() });
|
||||
res.redirect(`${domains.client}/login`);
|
||||
});
|
||||
|
||||
/**
|
||||
* Google Routes
|
||||
*/
|
||||
@@ -42,7 +48,7 @@ router.get(
|
||||
router.get(
|
||||
'/google/callback',
|
||||
passport.authenticate('google', {
|
||||
failureRedirect: `${domains.client}/login`,
|
||||
failureRedirect: `${domains.client}/oauth/error`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
scope: ['openid', 'profile', 'email'],
|
||||
@@ -62,7 +68,7 @@ router.get(
|
||||
router.get(
|
||||
'/facebook/callback',
|
||||
passport.authenticate('facebook', {
|
||||
failureRedirect: `${domains.client}/login`,
|
||||
failureRedirect: `${domains.client}/oauth/error`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
scope: ['public_profile'],
|
||||
@@ -81,7 +87,7 @@ router.get(
|
||||
router.get(
|
||||
'/openid/callback',
|
||||
passport.authenticate('openid', {
|
||||
failureRedirect: `${domains.client}/login`,
|
||||
failureRedirect: `${domains.client}/oauth/error`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
@@ -99,7 +105,7 @@ router.get(
|
||||
router.get(
|
||||
'/github/callback',
|
||||
passport.authenticate('github', {
|
||||
failureRedirect: `${domains.client}/login`,
|
||||
failureRedirect: `${domains.client}/oauth/error`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
scope: ['user:email', 'read:user'],
|
||||
@@ -117,7 +123,7 @@ router.get(
|
||||
router.get(
|
||||
'/discord/callback',
|
||||
passport.authenticate('discord', {
|
||||
failureRedirect: `${domains.client}/login`,
|
||||
failureRedirect: `${domains.client}/oauth/error`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
scope: ['identify', 'email'],
|
||||
|
||||
@@ -11,6 +11,7 @@ const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||
const { encryptV2, decryptV2 } = require('~/server/utils/crypto');
|
||||
const { getActions, deleteActions } = require('~/models/Action');
|
||||
const { deleteAssistant } = require('~/models/Assistant');
|
||||
const { logAxiosError } = require('~/utils');
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
@@ -146,15 +147,8 @@ async function createActionTool({ action, requestBuilder, zodSchema, name, descr
|
||||
}
|
||||
return res.data;
|
||||
} catch (error) {
|
||||
logger.error(`API call to ${action.metadata.domain} failed`, error);
|
||||
if (error.response) {
|
||||
const { status, data } = error.response;
|
||||
return `API call to ${
|
||||
action.metadata.domain
|
||||
} failed with status ${status}: ${JSON.stringify(data)}`;
|
||||
}
|
||||
|
||||
return `API call to ${action.metadata.domain} failed.`;
|
||||
const logMessage = `API call to ${action.metadata.domain} failed`;
|
||||
logAxiosError({ message: logMessage, error });
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ const loadCustomConfig = require('./Config/loadCustomConfig');
const handleRateLimits = require('./Config/handleRateLimits');
const { loadDefaultInterface } = require('./start/interface');
const { azureConfigSetup } = require('./start/azureOpenAI');
const { processModelSpecs } = require('./start/modelSpecs');
const { loadAndFormatTools } = require('./ToolService');
const { agentsConfigSetup } = require('./start/agents');
const { initializeRoles } = require('~/models/Role');
@@ -122,9 +123,9 @@ const AppService = async (app) => {

  app.locals = {
    ...defaultLocals,
    modelSpecs: config.modelSpecs,
    fileConfig: config?.fileConfig,
    secureImageLinks: config?.secureImageLinks,
    modelSpecs: processModelSpecs(endpoints, config.modelSpecs),
    ...endpointLocals,
  };
};
@@ -10,11 +10,18 @@ const {
   generateToken,
   deleteUserById,
 } = require('~/models/userMethods');
-const { createToken, findToken, deleteTokens, Session } = require('~/models');
+const {
+  createToken,
+  findToken,
+  deleteTokens,
+  findSession,
+  deleteSession,
+  createSession,
+  generateRefreshToken,
+} = require('~/models');
 const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils');
 const { isEmailDomainAllowed } = require('~/server/services/domains');
 const { registerSchema } = require('~/strategies/validators');
 const { hashToken } = require('~/server/utils/crypto');
 const { logger } = require('~/config');

 const domains = {
@@ -34,13 +41,11 @@ const genericVerificationMessage = 'Please check your email to verify your email
  */
 const logoutUser = async (userId, refreshToken) => {
   try {
-    const hash = await hashToken(refreshToken);
+    const session = await findSession({ userId: userId, refreshToken: refreshToken });

-    // Find the session with the matching user and refreshTokenHash
-    const session = await Session.findOne({ user: userId, refreshTokenHash: hash });
     if (session) {
       try {
-        await Session.deleteOne({ _id: session._id });
+        await deleteSession({ sessionId: session._id });
       } catch (deleteErr) {
         logger.error('[logoutUser] Failed to delete session.', deleteErr);
         return { status: 500, message: 'Failed to delete session.' };
@@ -330,18 +335,19 @@ const setAuthTokens = async (userId, res, sessionId = null) => {
   const token = await generateToken(user);

   let session;
+  let refreshToken;
   let refreshTokenExpires;
-  if (sessionId) {
-    session = await Session.findById(sessionId);
-    refreshTokenExpires = session.expiration.getTime();
-  } else {
-    session = new Session({ user: userId });
-    const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
-    const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7;
-    refreshTokenExpires = Date.now() + expires;
-  }
-
-  const refreshToken = await session.generateRefreshToken();
+  if (sessionId) {
+    session = await findSession({ sessionId: sessionId }, { lean: false });
+    refreshTokenExpires = session.expiration.getTime();
+    refreshToken = await generateRefreshToken(session);
+  } else {
+    const result = await createSession(userId);
+    session = result.session;
+    refreshToken = result.refreshToken;
+    refreshTokenExpires = session.expiration.getTime();
+  }

   res.cookie('refreshToken', refreshToken, {
     expires: new Date(refreshTokenExpires),

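The setAuthTokens hunk above moves all session handling behind the ~/models helpers. A rough sketch of the contract those helpers appear to have, inferred only from how they are called here (actual signatures live in ~/models and are not part of this diff):

// Inferred usage, not the actual implementation:
// createSession(userId) resolves to { session, refreshToken } for a brand-new session.
// findSession({ sessionId }, { lean: false }) returns a full document so that
// session.expiration and generateRefreshToken(session) can be used on it.
const { createSession, findSession, generateRefreshToken } = require('~/models');

async function refreshTokensFor(userId, sessionId = null) {
  if (sessionId) {
    const session = await findSession({ sessionId }, { lean: false });
    return { refreshToken: await generateRefreshToken(session), expires: session.expiration.getTime() };
  }
  const { session, refreshToken } = await createSession(userId);
  return { refreshToken, expires: session.expiration.getTime() };
}
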
@@ -3,9 +3,10 @@ const { logger } = require('~/config');

 const buildOptions = (req, endpoint, parsedBody) => {
   const {
+    spec,
+    iconURL,
     agent_id,
     instructions,
-    spec,
     maxContextTokens,
     resendFiles = true,
     ...model_parameters
@@ -20,6 +21,7 @@ const buildOptions = (req, endpoint, parsedBody) => {

   const endpointOption = {
     spec,
+    iconURL,
     endpoint,
     agent_id,
     resendFiles,

@@ -12,6 +12,7 @@ const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'
 const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
 const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
 const initCustom = require('~/server/services/Endpoints/custom/initialize');
+const initGoogle = require('~/server/services/Endpoints/google/initialize');
 const { getCustomEndpointConfig } = require('~/server/services/Config');
 const { loadAgentTools } = require('~/server/services/ToolService');
 const AgentClient = require('~/server/controllers/agents/client');
@@ -24,6 +25,7 @@ const providerConfigMap = {
   [EModelEndpoint.azureOpenAI]: initOpenAI,
   [EModelEndpoint.anthropic]: initAnthropic,
   [EModelEndpoint.bedrock]: getBedrockOptions,
+  [EModelEndpoint.google]: initGoogle,
   [Providers.OLLAMA]: initCustom,
 };

@@ -116,6 +118,10 @@ const initializeAgentOptions = async ({
     endpointOption: _endpointOption,
   });

+  if (options.provider != null) {
+    agent.provider = options.provider;
+  }
+
   agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
   if (options.configOptions) {
     agent.model_parameters.configuration = options.configOptions;
@@ -219,6 +225,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     collectedUsage,
     artifactPromises,
+    spec: endpointOption.spec,
     iconURL: endpointOption.iconURL,
     agentConfigs,
     endpoint: EModelEndpoint.agents,
     maxContextTokens: primaryConfig.maxContextTokens,

@@ -20,7 +20,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||
checkUserKeyExpiry(expiresAt, EModelEndpoint.anthropic);
|
||||
}
|
||||
|
||||
const clientOptions = {};
|
||||
let clientOptions = {};
|
||||
|
||||
/** @type {undefined | TBaseEndpoint} */
|
||||
const anthropicConfig = req.app.locals[EModelEndpoint.anthropic];
|
||||
@@ -36,7 +36,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||
}
|
||||
|
||||
if (optionsOnly) {
|
||||
const requestOptions = Object.assign(
|
||||
clientOptions = Object.assign(
|
||||
{
|
||||
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
|
||||
proxy: PROXY ?? null,
|
||||
@@ -45,9 +45,9 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||
clientOptions,
|
||||
);
|
||||
if (overrideModel) {
|
||||
requestOptions.modelOptions.model = overrideModel;
|
||||
clientOptions.modelOptions.model = overrideModel;
|
||||
}
|
||||
return getLLMConfig(anthropicApiKey, requestOptions);
|
||||
return getLLMConfig(anthropicApiKey, clientOptions);
|
||||
}
|
||||
|
||||
const client = new AnthropicClient(anthropicApiKey, {
|
||||
|
||||
@@ -28,28 +28,32 @@ function getLLMConfig(apiKey, options = {}) {

   const mergedOptions = Object.assign(defaultOptions, options.modelOptions);

+  /** @type {AnthropicClientOptions} */
   const requestOptions = {
     apiKey,
     model: mergedOptions.model,
     stream: mergedOptions.stream,
     temperature: mergedOptions.temperature,
-    top_p: mergedOptions.topP,
-    top_k: mergedOptions.topK,
-    stop_sequences: mergedOptions.stop,
-    max_tokens:
+    topP: mergedOptions.topP,
+    topK: mergedOptions.topK,
+    stopSequences: mergedOptions.stop,
+    maxTokens:
       mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model),
+    clientOptions: {},
   };

-  const configOptions = {};
   if (options.proxy) {
-    configOptions.httpAgent = new HttpsProxyAgent(options.proxy);
+    requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy);
   }

   if (options.reverseProxyUrl) {
-    configOptions.baseURL = options.reverseProxyUrl;
+    requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
   }

-  return { llmConfig: removeNullishValues(requestOptions), configOptions };
+  return {
+    /** @type {AnthropicClientOptions} */
+    llmConfig: removeNullishValues(requestOptions),
+  };
 }

 module.exports = { getLLMConfig };
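With the rename to camelCase keys and the move of proxy settings under clientOptions, the returned llmConfig is now self-contained. A hedged usage sketch based only on the function shown above (the proxy URL and model name are placeholders, not values from the diff):

const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');

// llmConfig now carries topP/topK/stopSequences/maxTokens plus a nested clientOptions
// object holding httpAgent/baseURL; the separate configOptions return value is gone.
const { llmConfig } = getLLMConfig(process.env.ANTHROPIC_API_KEY, {
  modelOptions: { model: 'claude-3-5-sonnet-latest', temperature: 0.7 }, // placeholder model
  proxy: 'http://localhost:8888', // placeholder proxy
  reverseProxyUrl: 'https://my-anthropic-proxy.example.com', // placeholder reverse proxy
});
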
|
||||
@@ -61,6 +61,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     eventHandlers,
     collectedUsage,
+    spec: endpointOption.spec,
     iconURL: endpointOption.iconURL,
     endpoint: EModelEndpoint.bedrock,
     resendFiles: endpointOption.resendFiles,
     maxContextTokens:

|
||||
@@ -60,42 +60,41 @@ const getOptions = async ({ req, endpointOption }) => {
     streamRate = allConfig.streamRate;
   }

-  /** @type {import('@librechat/agents').BedrockConverseClientOptions} */
-  const requestOptions = Object.assign(
-    {
-      model: endpointOption.model,
-      region: BEDROCK_AWS_DEFAULT_REGION,
-      streaming: true,
-      streamUsage: true,
-      callbacks: [
-        {
-          handleLLMNewToken: async () => {
-            if (!streamRate) {
-              return;
-            }
-            await sleep(streamRate);
-          },
-        },
-      ],
-    },
-    endpointOption.model_parameters,
-  );
+  /** @type {BedrockClientOptions} */
+  const requestOptions = {
+    model: endpointOption.model,
+    region: BEDROCK_AWS_DEFAULT_REGION,
+    streaming: true,
+    streamUsage: true,
+    callbacks: [
+      {
+        handleLLMNewToken: async () => {
+          if (!streamRate) {
+            return;
+          }
+          await sleep(streamRate);
+        },
+      },
+    ],
+  };

   if (credentials) {
     requestOptions.credentials = credentials;
   }

+  if (BEDROCK_REVERSE_PROXY) {
+    requestOptions.endpointHost = BEDROCK_REVERSE_PROXY;
+  }
+
   const configOptions = {};
   if (PROXY) {
     /** NOTE: NOT SUPPORTED BY BEDROCK */
     configOptions.httpAgent = new HttpsProxyAgent(PROXY);
   }

-  if (BEDROCK_REVERSE_PROXY) {
-    configOptions.endpointHost = BEDROCK_REVERSE_PROXY;
-  }
-
   return {
-    llmConfig: removeNullishValues(requestOptions),
+    /** @type {BedrockClientOptions} */
+    llmConfig: removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
     configOptions,
   };
 };
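Two behavioural notes fall out of the hunk above: endpointOption.model_parameters is now merged into requestOptions at return time rather than at construction time, and BEDROCK_REVERSE_PROXY now sets endpointHost on the request options instead of on configOptions. A small illustration of the merge order, using placeholder values that are not taken from the diff:

// Later arguments win in Object.assign, so user-supplied model_parameters
// still override the defaults built up in requestOptions:
const requestOptions = { model: 'example-model-id', streaming: true, streamUsage: true };
const model_parameters = { temperature: 0.2, maxTokens: 1024 };
const llmConfig = Object.assign(requestOptions, model_parameters);
// => { model: 'example-model-id', streaming: true, streamUsage: true, temperature: 0.2, maxTokens: 1024 }
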
|
||||
@@ -123,7 +123,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
||||
customOptions.streamRate = allConfig.streamRate;
|
||||
}
|
||||
|
||||
const clientOptions = {
|
||||
let clientOptions = {
|
||||
reverseProxyUrl: baseURL ?? null,
|
||||
proxy: PROXY ?? null,
|
||||
req,
|
||||
@@ -135,13 +135,13 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
||||
if (optionsOnly) {
|
||||
const modelOptions = endpointOption.model_parameters;
|
||||
if (endpoint !== Providers.OLLAMA) {
|
||||
const requestOptions = Object.assign(
|
||||
clientOptions = Object.assign(
|
||||
{
|
||||
modelOptions,
|
||||
},
|
||||
clientOptions,
|
||||
);
|
||||
const options = getLLMConfig(apiKey, requestOptions);
|
||||
const options = getLLMConfig(apiKey, clientOptions);
|
||||
if (!customOptions.streamRate) {
|
||||
return options;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
 const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
 const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
-const { GoogleClient } = require('~/app');
+const { getLLMConfig } = require('~/server/services/Endpoints/google/llm');
 const { isEnabled } = require('~/server/utils');
+const { GoogleClient } = require('~/app');

-const initializeClient = async ({ req, res, endpointOption }) => {
+const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
   const {
     GOOGLE_KEY,
     GOOGLE_REVERSE_PROXY,
@@ -33,7 +34,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY,
   };

-  const clientOptions = {};
+  let clientOptions = {};

   /** @type {undefined | TBaseEndpoint} */
   const allConfig = req.app.locals.all;
@@ -48,7 +49,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     clientOptions.streamRate = allConfig.streamRate;
   }

-  const client = new GoogleClient(credentials, {
+  clientOptions = {
     req,
     res,
     reverseProxyUrl: GOOGLE_REVERSE_PROXY ?? null,
@@ -56,7 +57,22 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     proxy: PROXY ?? null,
     ...clientOptions,
     ...endpointOption,
-  });
+  };
+
+  if (optionsOnly) {
+    clientOptions = Object.assign(
+      {
+        modelOptions: endpointOption.model_parameters,
+      },
+      clientOptions,
+    );
+    if (overrideModel) {
+      clientOptions.modelOptions.model = overrideModel;
+    }
+    return getLLMConfig(credentials, clientOptions);
+  }
+
+  const client = new GoogleClient(credentials, clientOptions);

   return {
     client,

168  api/server/services/Endpoints/google/llm.js  (new file)
@@ -0,0 +1,168 @@
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { AuthKeys } = require('librechat-data-provider');
|
||||
|
||||
// Example internal constant from your code
|
||||
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {boolean} isGemini2
|
||||
* @returns {Array<{category: string, threshold: string}>}
|
||||
*/
|
||||
function getSafetySettings(isGemini2) {
|
||||
const mapThreshold = (value) => {
|
||||
if (isGemini2 && value === 'BLOCK_NONE') {
|
||||
return 'OFF';
|
||||
}
|
||||
return value;
|
||||
};
|
||||
|
||||
return [
|
||||
{
|
||||
category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
|
||||
threshold: mapThreshold(
|
||||
process.env.GOOGLE_SAFETY_SEXUALLY_EXPLICIT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
),
|
||||
},
|
||||
{
|
||||
category: 'HARM_CATEGORY_HATE_SPEECH',
|
||||
threshold: mapThreshold(
|
||||
process.env.GOOGLE_SAFETY_HATE_SPEECH || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
),
|
||||
},
|
||||
{
|
||||
category: 'HARM_CATEGORY_HARASSMENT',
|
||||
threshold: mapThreshold(
|
||||
process.env.GOOGLE_SAFETY_HARASSMENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
),
|
||||
},
|
||||
{
|
||||
category: 'HARM_CATEGORY_DANGEROUS_CONTENT',
|
||||
threshold: mapThreshold(
|
||||
process.env.GOOGLE_SAFETY_DANGEROUS_CONTENT || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
|
||||
),
|
||||
},
|
||||
{
|
||||
category: 'HARM_CATEGORY_CIVIC_INTEGRITY',
|
||||
threshold: mapThreshold(process.env.GOOGLE_SAFETY_CIVIC_INTEGRITY || 'BLOCK_NONE'),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Replicates core logic from GoogleClient's constructor and setOptions, plus client determination.
|
||||
* Returns an object with the provider label and the final options that would be passed to createLLM.
|
||||
*
|
||||
* @param {string | object} credentials - Either a JSON string or an object containing Google keys
|
||||
* @param {object} [options={}] - The same shape as the "GoogleClient" constructor options
|
||||
*/
|
||||
|
||||
function getLLMConfig(credentials, options = {}) {
|
||||
// 1. Parse credentials
|
||||
let creds = {};
|
||||
if (typeof credentials === 'string') {
|
||||
try {
|
||||
creds = JSON.parse(credentials);
|
||||
} catch (err) {
|
||||
throw new Error(`Error parsing string credentials: ${err.message}`);
|
||||
}
|
||||
} else if (credentials && typeof credentials === 'object') {
|
||||
creds = credentials;
|
||||
}
|
||||
|
||||
// Extract from credentials
|
||||
const serviceKeyRaw = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
|
||||
const serviceKey =
|
||||
typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : serviceKeyRaw ?? {};
|
||||
|
||||
const project_id = serviceKey?.project_id ?? null;
|
||||
const apiKey = creds[AuthKeys.GOOGLE_API_KEY] ?? null;
|
||||
|
||||
const reverseProxyUrl = options.reverseProxyUrl;
|
||||
const authHeader = options.authHeader;
|
||||
|
||||
/** @type {GoogleClientOptions | VertexAIClientOptions} */
|
||||
let llmConfig = {
|
||||
...(options.modelOptions || {}),
|
||||
maxRetries: 2,
|
||||
};
|
||||
|
||||
const isGemini2 = llmConfig.model.includes('gemini-2.0');
|
||||
const isGenerativeModel = llmConfig.model.includes('gemini');
|
||||
const isChatModel = !isGenerativeModel && llmConfig.model.includes('chat');
|
||||
const isTextModel = !isGenerativeModel && !isChatModel && /code|text/.test(llmConfig.model);
|
||||
|
||||
llmConfig.safetySettings = getSafetySettings(isGemini2);
|
||||
|
||||
let provider;
|
||||
|
||||
if (project_id && isTextModel) {
|
||||
provider = Providers.VERTEXAI;
|
||||
} else if (project_id && isChatModel) {
|
||||
provider = Providers.VERTEXAI;
|
||||
} else if (project_id) {
|
||||
provider = Providers.VERTEXAI;
|
||||
} else if (!EXCLUDED_GENAI_MODELS.test(llmConfig.model)) {
|
||||
provider = Providers.GOOGLE;
|
||||
} else {
|
||||
provider = Providers.GOOGLE;
|
||||
}
|
||||
|
||||
// If we have a GCP project => Vertex AI
|
||||
if (project_id && provider === Providers.VERTEXAI) {
|
||||
/** @type {VertexAIClientOptions['authOptions']} */
|
||||
llmConfig.authOptions = {
|
||||
credentials: { ...serviceKey },
|
||||
projectId: project_id,
|
||||
};
|
||||
llmConfig.location = process.env.GOOGLE_LOC || 'us-central1';
|
||||
} else if (apiKey && provider === Providers.GOOGLE) {
|
||||
llmConfig.apiKey = apiKey;
|
||||
}
|
||||
|
||||
/*
|
||||
let legacyOptions = {};
|
||||
// Filter out any "examples" that are empty
|
||||
legacyOptions.examples = (legacyOptions.examples ?? [])
|
||||
.filter(Boolean)
|
||||
.filter((obj) => obj?.input?.content !== '' && obj?.output?.content !== '');
|
||||
|
||||
// If user has "examples" from legacyOptions, push them onto llmConfig
|
||||
if (legacyOptions.examples?.length) {
|
||||
llmConfig.examples = legacyOptions.examples.map((ex) => {
|
||||
const { input, output } = ex;
|
||||
if (!input?.content || !output?.content) {return undefined;}
|
||||
return {
|
||||
input: new HumanMessage(input.content),
|
||||
output: new AIMessage(output.content),
|
||||
};
|
||||
}).filter(Boolean);
|
||||
}
|
||||
*/
|
||||
|
||||
if (reverseProxyUrl) {
|
||||
llmConfig.baseUrl = reverseProxyUrl;
|
||||
}
|
||||
|
||||
if (authHeader) {
|
||||
/**
|
||||
* NOTE: NOT SUPPORTED BY LANGCHAIN GENAI CLIENT,
|
||||
* REQUIRES PR IN https://github.com/langchain-ai/langchainjs
|
||||
*/
|
||||
llmConfig.customHeaders = {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
};
|
||||
}
|
||||
|
||||
// Return the final shape
|
||||
return {
|
||||
/** @type {Providers.GOOGLE | Providers.VERTEXAI} */
|
||||
provider,
|
||||
/** @type {GoogleClientOptions | VertexAIClientOptions} */
|
||||
llmConfig,
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getLLMConfig,
|
||||
};
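A hedged usage sketch for the new getLLMConfig above, using placeholder credentials (the AuthKeys constants and the { provider, llmConfig } return shape are taken from the file itself; the key value and model are invented for illustration):

const { getLLMConfig } = require('~/server/services/Endpoints/google/llm');
const { AuthKeys } = require('librechat-data-provider');

// With only an API key, provider resolves to Providers.GOOGLE and llmConfig.apiKey is set;
// when a service-account key containing a project_id is supplied instead, provider becomes
// Providers.VERTEXAI and llmConfig.authOptions / llmConfig.location are filled in.
const { provider, llmConfig } = getLLMConfig(
  { [AuthKeys.GOOGLE_API_KEY]: 'placeholder-api-key' },
  { modelOptions: { model: 'gemini-2.0-flash' }, reverseProxyUrl: null },
);
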
|
||||
@@ -1,4 +1,5 @@
|
||||
// gptPlugins/initializeClient.spec.js
|
||||
jest.mock('~/cache/getLogStores');
|
||||
const { EModelEndpoint, ErrorTypes, validateAzureGroups } = require('librechat-data-provider');
|
||||
const { getUserKey, getUserKeyValues } = require('~/server/services/UserService');
|
||||
const initializeClient = require('./initialize');
|
||||
|
||||
@@ -54,7 +54,7 @@ const initializeClient = async ({
|
||||
let apiKey = userProvidesKey ? userValues?.apiKey : credentials[endpoint];
|
||||
let baseURL = userProvidesURL ? userValues?.baseURL : baseURLOptions[endpoint];
|
||||
|
||||
const clientOptions = {
|
||||
let clientOptions = {
|
||||
contextStrategy,
|
||||
proxy: PROXY ?? null,
|
||||
debug: isEnabled(DEBUG_OPENAI),
|
||||
@@ -134,13 +134,13 @@ const initializeClient = async ({
|
||||
}
|
||||
|
||||
if (optionsOnly) {
|
||||
const requestOptions = Object.assign(
|
||||
clientOptions = Object.assign(
|
||||
{
|
||||
modelOptions: endpointOption.model_parameters,
|
||||
},
|
||||
clientOptions,
|
||||
);
|
||||
const options = getLLMConfig(apiKey, requestOptions);
|
||||
const options = getLLMConfig(apiKey, clientOptions);
|
||||
if (!clientOptions.streamRate) {
|
||||
return options;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
+jest.mock('~/cache/getLogStores');
 const { EModelEndpoint, ErrorTypes, validateAzureGroups } = require('librechat-data-provider');
 const { getUserKey, getUserKeyValues } = require('~/server/services/UserService');
 const initializeClient = require('./initialize');

@@ -38,6 +38,7 @@ function getLLMConfig(apiKey, options = {}) {
     dropParams,
   } = options;

+  /** @type {OpenAIClientOptions} */
   let llmConfig = {
     streaming,
   };
@@ -54,29 +55,28 @@ function getLLMConfig(apiKey, options = {}) {
     });
   }

+  /** @type {OpenAIClientOptions['configuration']} */
   const configOptions = {};

   // Handle OpenRouter or custom reverse proxy
   if (useOpenRouter || reverseProxyUrl === 'https://openrouter.ai/api/v1') {
-    configOptions.basePath = 'https://openrouter.ai/api/v1';
-    configOptions.baseOptions = {
-      headers: Object.assign(
-        {
-          'HTTP-Referer': 'https://librechat.ai',
-          'X-Title': 'LibreChat',
-        },
-        headers,
-      ),
-    };
+    configOptions.baseURL = 'https://openrouter.ai/api/v1';
+    configOptions.defaultHeaders = Object.assign(
+      {
+        'HTTP-Referer': 'https://librechat.ai',
+        'X-Title': 'LibreChat',
+      },
+      headers,
+    );
   } else if (reverseProxyUrl) {
-    configOptions.basePath = reverseProxyUrl;
+    configOptions.baseURL = reverseProxyUrl;
     if (headers) {
-      configOptions.baseOptions = { headers };
+      configOptions.defaultHeaders = headers;
     }
   }

   if (defaultQuery) {
-    configOptions.baseOptions.defaultQuery = defaultQuery;
+    configOptions.defaultQuery = defaultQuery;
   }

   if (proxy) {
@@ -97,9 +97,9 @@ function getLLMConfig(apiKey, options = {}) {
     llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL;
   }

-  if (configOptions.basePath) {
+  if (configOptions.baseURL) {
     const azureURL = constructAzureURL({
-      baseURL: configOptions.basePath,
+      baseURL: configOptions.baseURL,
       azureOptions: azure,
     });
     azure.azureOpenAIBasePath = azureURL.split(`/${azure.azureOpenAIApiDeploymentName}`)[0];
@@ -118,7 +118,12 @@ function getLLMConfig(apiKey, options = {}) {
     llmConfig.organization = process.env.OPENAI_ORGANIZATION;
   }

-  return { llmConfig, configOptions };
+  return {
+    /** @type {OpenAIClientOptions} */
+    llmConfig,
+    /** @type {OpenAIClientOptions['configuration']} */
+    configOptions,
+  };
 }

 module.exports = { getLLMConfig };
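The OpenRouter branch above swaps the old basePath/baseOptions keys for the baseURL/defaultHeaders shape. A sketch of the object the function would now build for an OpenRouter reverse proxy, grounded in the hunk above (values shown are the ones hard-coded there):

// configOptions as produced by the branch above when useOpenRouter is true:
const configOptions = {
  baseURL: 'https://openrouter.ai/api/v1',
  defaultHeaders: {
    'HTTP-Referer': 'https://librechat.ai',
    'X-Title': 'LibreChat',
    // ...plus any caller-supplied headers merged in by Object.assign
  },
  // defaultQuery, when provided, is now set directly on configOptions
};
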
|
||||
|
||||
@@ -3,7 +3,7 @@ const axios = require('axios');
 const FormData = require('form-data');
 const { getCodeBaseURL } = require('@librechat/agents');

-const MAX_FILE_SIZE = 25 * 1024 * 1024;
+const MAX_FILE_SIZE = 150 * 1024 * 1024;

 /**
  * Retrieves a download stream for a specified file.

@@ -59,6 +59,6 @@ class Tokenizer {
   }
 }

-const tokenizerService = new Tokenizer();
+const TokenizerSingleton = new Tokenizer();

-module.exports = tokenizerService;
+module.exports = TokenizerSingleton;
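The module still exports a single shared instance, now named TokenizerSingleton; the spec file added below exercises it directly. A short usage sketch based on the methods that spec calls (the encoding name is one of the tiktoken encodings used in those tests):

const Tokenizer = require('~/server/services/Tokenizer');

// The same cached encoder is reused across calls; encoders are freed and the internal
// call counter reset automatically after 25 getTokenCount calls (per the spec below).
const tokens = Tokenizer.getTokenCount('Hello, world!', 'cl100k_base');
const encoder = Tokenizer.getTokenizer('gpt-4', true); // true => treat the argument as a model name
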
|
||||
|
||||
136
api/server/services/Tokenizer.spec.js
Normal file
136
api/server/services/Tokenizer.spec.js
Normal file
@@ -0,0 +1,136 @@
|
||||
/**
|
||||
* @file Tokenizer.spec.cjs
|
||||
*
|
||||
* Tests the real TokenizerSingleton (no mocking of `tiktoken`).
|
||||
* Make sure to install `tiktoken` and have it configured properly.
|
||||
*/
|
||||
|
||||
const Tokenizer = require('./Tokenizer'); // <-- Adjust path to your singleton file
|
||||
const { logger } = require('~/config');
|
||||
|
||||
describe('Tokenizer', () => {
|
||||
it('should be a singleton (same instance)', () => {
|
||||
const AnotherTokenizer = require('./Tokenizer'); // same path
|
||||
expect(Tokenizer).toBe(AnotherTokenizer);
|
||||
});
|
||||
|
||||
describe('getTokenizer', () => {
|
||||
it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
|
||||
// The real `encoding_for_model` will be called internally
|
||||
// as soon as we pass isModelName = true.
|
||||
const tokenizer = Tokenizer.getTokenizer('gpt-4', true);
|
||||
|
||||
// Basic sanity checks
|
||||
expect(tokenizer).toBeDefined();
|
||||
// You can optionally check certain properties from `tiktoken` if they exist
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
});
|
||||
|
||||
it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
|
||||
// The real `get_encoding` will be called internally
|
||||
// as soon as we pass isModelName = false.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
|
||||
expect(tokenizer).toBeDefined();
|
||||
// e.g., expect(typeof tokenizer.encode).toBe('function');
|
||||
});
|
||||
|
||||
it('should return cached tokenizer if previously fetched', () => {
|
||||
const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
// Should be the exact same instance from the cache
|
||||
expect(tokenizer1).toBe(tokenizer2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('freeAndResetAllEncoders', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should free all encoders and reset tokenizerCallsCount to 1', () => {
|
||||
// By creating two different encodings, we populate the cache
|
||||
Tokenizer.getTokenizer('cl100k_base', false);
|
||||
Tokenizer.getTokenizer('r50k_base', false);
|
||||
|
||||
// Now free them
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// The internal cache is cleared
|
||||
expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
|
||||
expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();
|
||||
|
||||
// tokenizerCallsCount is reset to 1
|
||||
expect(Tokenizer.tokenizerCallsCount).toBe(1);
|
||||
});
|
||||
|
||||
it('should catch and log errors if freeing fails', () => {
|
||||
// Mock logger.error before the test
|
||||
const mockLoggerError = jest.spyOn(logger, 'error');
|
||||
|
||||
// Set up a problematic tokenizer in the cache
|
||||
Tokenizer.tokenizersCache['cl100k_base'] = {
|
||||
free() {
|
||||
throw new Error('Intentional free error');
|
||||
},
|
||||
};
|
||||
|
||||
// Should not throw uncaught errors
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
|
||||
// Verify logger.error was called with correct arguments
|
||||
expect(mockLoggerError).toHaveBeenCalledWith(
|
||||
'[Tokenizer] Free and reset encoders error',
|
||||
expect.any(Error),
|
||||
);
|
||||
|
||||
// Clean up
|
||||
mockLoggerError.mockRestore();
|
||||
Tokenizer.tokenizersCache = {};
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTokenCount', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
Tokenizer.freeAndResetAllEncoders();
|
||||
});
|
||||
|
||||
it('should return the number of tokens in the given text', () => {
|
||||
const text = 'Hello, world!';
|
||||
const count = Tokenizer.getTokenCount(text, 'cl100k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should reset encoders if an error is thrown', () => {
|
||||
// We can simulate an error by temporarily overriding the selected tokenizer’s `encode` method.
|
||||
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
|
||||
const originalEncode = tokenizer.encode;
|
||||
tokenizer.encode = () => {
|
||||
throw new Error('Forced error');
|
||||
};
|
||||
|
||||
// Despite the forced error, the code should catch and reset, then re-encode
|
||||
const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
|
||||
expect(count).toBeGreaterThan(0);
|
||||
|
||||
// Restore the original encode
|
||||
tokenizer.encode = originalEncode;
|
||||
});
|
||||
|
||||
it('should reset tokenizers after 25 calls', () => {
|
||||
// Spy on freeAndResetAllEncoders
|
||||
const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');
|
||||
|
||||
// Make 24 calls; should NOT reset yet
|
||||
for (let i = 0; i < 24; i++) {
|
||||
Tokenizer.getTokenCount('test text', 'cl100k_base');
|
||||
}
|
||||
expect(resetSpy).not.toHaveBeenCalled();
|
||||
|
||||
// 25th call triggers the reset
|
||||
Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
|
||||
expect(resetSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
61  api/server/services/start/modelSpecs.js  (new file)
@@ -0,0 +1,61 @@
const { EModelEndpoint } = require('librechat-data-provider');
const { normalizeEndpointName } = require('~/server/utils');
const { logger } = require('~/config');

/**
 * Sets up Model Specs from the config (`librechat.yaml`) file.
 * @param {TCustomConfig['endpoints']} [endpoints] - The loaded custom configuration for endpoints.
 * @param {TCustomConfig['modelSpecs'] | undefined} [modelSpecs] - The loaded custom configuration for model specs.
 * @returns {TCustomConfig['modelSpecs'] | undefined} The processed model specs, if any.
 */
function processModelSpecs(endpoints, _modelSpecs) {
  if (!_modelSpecs) {
    return undefined;
  }

  /** @type {TCustomConfig['modelSpecs']['list']} */
  const modelSpecs = [];
  /** @type {TCustomConfig['modelSpecs']['list']} */
  const list = _modelSpecs.list;

  const customEndpoints = endpoints?.[EModelEndpoint.custom] ?? [];

  for (const spec of list) {
    if (EModelEndpoint[spec.preset.endpoint] && spec.preset.endpoint !== EModelEndpoint.custom) {
      modelSpecs.push(spec);
      continue;
    } else if (spec.preset.endpoint === EModelEndpoint.custom) {
      logger.warn(
        `Model Spec with endpoint "${spec.preset.endpoint}" is not supported. You must specify the name of the custom endpoint (case-sensitive, as defined in your config). Skipping model spec...`,
      );
      continue;
    }

    const normalizedName = normalizeEndpointName(spec.preset.endpoint);
    const endpoint = customEndpoints.find(
      (customEndpoint) => normalizedName === normalizeEndpointName(customEndpoint.name),
    );

    if (!endpoint) {
      logger.warn(`Model spec with endpoint "${spec.preset.endpoint}" was skipped: Endpoint not found in configuration. The \`endpoint\` value must exactly match either a system-defined endpoint or a custom endpoint defined by the user.

For more information, see the documentation at https://www.librechat.ai/docs/configuration/librechat_yaml/object_structure/model_specs#endpoint`);
      continue;
    }

    modelSpecs.push({
      ...spec,
      preset: {
        ...spec.preset,
        endpoint: normalizedName,
      },
    });
  }

  return {
    ..._modelSpecs,
    list: modelSpecs,
  };
}

module.exports = { processModelSpecs };
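A hedged sketch of how the helper above behaves, with made-up endpoint and spec names (the shapes mirror the parameters described in the JSDoc, not a real configuration):

const { processModelSpecs } = require('~/server/services/start/modelSpecs');

const endpoints = {
  custom: [{ name: 'My OpenRouter' }], // hypothetical custom endpoint from librechat.yaml
};
const modelSpecs = {
  list: [
    // kept as-is: a system-defined endpoint key
    { name: 'fast', preset: { endpoint: 'openAI', model: 'gpt-4o-mini' } },
    // matched against endpoints.custom after normalizeEndpointName is applied to both names
    { name: 'router', preset: { endpoint: 'My OpenRouter', model: 'example/model' } },
    // skipped with a warning: endpoint not found in configuration
    { name: 'missing', preset: { endpoint: 'Unknown Endpoint', model: 'example-model' } },
  ],
};

const processed = processModelSpecs(endpoints, modelSpecs);
// processed.list keeps the first two specs; the second has its endpoint rewritten to its normalized form.
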
@@ -1,6 +1,7 @@
|
||||
const Redis = require('ioredis');
|
||||
const passport = require('passport');
|
||||
const session = require('express-session');
|
||||
const MemoryStore = require('memorystore')(session);
|
||||
const RedisStore = require('connect-redis').default;
|
||||
const {
|
||||
setupOpenId,
|
||||
@@ -48,6 +49,10 @@ const configureSocialLogins = (app) => {
|
||||
.on('ready', () => logger.info('ioredis successfully initialized.'))
|
||||
.on('reconnecting', () => logger.info('ioredis reconnecting...'));
|
||||
sessionOptions.store = new RedisStore({ client, prefix: 'librechat' });
|
||||
} else {
|
||||
sessionOptions.store = new MemoryStore({
|
||||
checkPeriod: 86400000, // prune expired entries every 24h
|
||||
});
|
||||
}
|
||||
app.use(session(sessionOptions));
|
||||
app.use(passport.session());
|
||||
|
||||
@@ -113,7 +113,7 @@ async function importLibreChatConvo(
|
||||
*/
|
||||
const traverseMessages = async (messages, parentMessageId = null) => {
|
||||
for (const message of messages) {
|
||||
if (!message.text) {
|
||||
if (!message.text && !message.content) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -121,6 +121,7 @@ async function importLibreChatConvo(
|
||||
if (message.sender?.toLowerCase() === 'user' || message.isCreatedByUser) {
|
||||
savedMessage = await importBatchBuilder.saveMessage({
|
||||
text: message.text,
|
||||
content: message.content,
|
||||
sender: 'user',
|
||||
isCreatedByUser: true,
|
||||
parentMessageId: parentMessageId,
|
||||
@@ -128,6 +129,7 @@ async function importLibreChatConvo(
|
||||
} else {
|
||||
savedMessage = await importBatchBuilder.saveMessage({
|
||||
text: message.text,
|
||||
content: message.content,
|
||||
sender: message.sender,
|
||||
isCreatedByUser: false,
|
||||
model: options.model,
|
||||
|
||||
@@ -5,3 +5,4 @@ process.env.MONGO_URI = 'mongodb://127.0.0.1:27017/dummy-uri';
|
||||
process.env.BAN_VIOLATIONS = 'true';
|
||||
process.env.BAN_DURATION = '7200000';
|
||||
process.env.BAN_INTERVAL = '20';
|
||||
process.env.CI = 'true';
|
||||
|
||||
@@ -38,12 +38,36 @@
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports OpenAIClientOptions
|
||||
* @typedef {import('@librechat/agents').OpenAIClientOptions} OpenAIClientOptions
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports AnthropicClientOptions
|
||||
* @typedef {import('@librechat/agents').AnthropicClientOptions} AnthropicClientOptions
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports BedrockClientOptions
|
||||
* @typedef {import('@librechat/agents').BedrockConverseClientOptions} BedrockClientOptions
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports VertexAIClientOptions
|
||||
* @typedef {import('@librechat/agents').VertexAIClientOptions} VertexAIClientOptions
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports GoogleClientOptions
|
||||
* @typedef {import('@librechat/agents').GoogleClientOptions} GoogleClientOptions
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports StreamEventData
|
||||
* @typedef {import('@librechat/agents').StreamEventData} StreamEventData
|
||||
|
||||
@@ -139,6 +139,6 @@
|
||||
"typescript": "^5.0.4",
|
||||
"vite": "^5.1.1",
|
||||
"vite-plugin-node-polyfills": "^0.17.0",
|
||||
"vite-plugin-pwa": "^0.20.5"
|
||||
"vite-plugin-pwa": "^0.21.1"
|
||||
}
|
||||
}
|
||||
|
||||
BIN  client/public/assets/openweather.png  (new file, 6.6 KiB; binary not shown)

1  client/public/assets/r.svg  (new file, 894 B)
@@ -0,0 +1 @@
<svg role="img" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><title>R</title><path d="M12 2.746c-6.627 0-12 3.599-12 8.037 0 3.897 4.144 7.144 9.64 7.88V16.26c-2.924-.915-4.925-2.755-4.925-4.877 0-3.035 4.084-5.494 9.12-5.494 5.038 0 8.757 1.683 8.757 5.494 0 1.976-.999 3.379-2.662 4.272.09.066.174.128.258.216.169.149.25.363.372.544 2.128-1.45 3.44-3.437 3.44-5.631 0-4.44-5.373-8.038-12-8.038zm-2.111 4.99v13.516l4.093-.002-.002-5.291h1.1c.225 0 .321.066.549.25.272.22.715.982.715.982l2.164 4.063 4.627-.002-2.864-4.826s-.086-.193-.265-.383a2.22 2.22 0 00-.582-.416c-.422-.214-1.149-.434-1.149-.434s3.578-.264 3.578-3.826c0-3.562-3.744-3.63-3.744-3.63zm4.127 2.93l2.478.002s1.149-.062 1.149 1.127c0 1.165-1.149 1.17-1.149 1.17h-2.478zm1.754 6.119c-.494.049-1.012.079-1.54.088v1.807a16.622 16.622 0 002.37-.473l-.471-.891s-.108-.183-.248-.394c-.039-.054-.08-.098-.111-.137z"/></svg>
14
client/src/Providers/SetConvoContext.tsx
Normal file
14
client/src/Providers/SetConvoContext.tsx
Normal file
@@ -0,0 +1,14 @@
|
||||
import { createContext, useContext, useRef } from 'react';
|
||||
import type { MutableRefObject } from 'react';
|
||||
|
||||
type SetConvoContext = MutableRefObject<boolean>;
|
||||
|
||||
export const SetConvoContext = createContext<SetConvoContext>({} as SetConvoContext);
|
||||
|
||||
export const SetConvoProvider = ({ children }: { children: React.ReactNode }) => {
|
||||
const hasSetConversation = useRef<boolean>(false);
|
||||
|
||||
return <SetConvoContext.Provider value={hasSetConversation}>{children}</SetConvoContext.Provider>;
|
||||
};
|
||||
|
||||
export const useSetConvoContext = () => useContext(SetConvoContext);
|
||||
@@ -18,3 +18,4 @@ export * from './AnnouncerContext';
|
||||
export * from './AgentsMapContext';
|
||||
export * from './CodeBlockContext';
|
||||
export * from './ToolCallsMapContext';
|
||||
export * from './SetConvoContext';
|
||||
|
||||
@@ -23,6 +23,7 @@ export type AgentForm = {
|
||||
instructions: string | null;
|
||||
model: string | null;
|
||||
model_parameters: AgentModelParameters;
|
||||
conversation_starters: string[];
|
||||
tools?: string[];
|
||||
provider?: AgentProvider | OptionWithIcon;
|
||||
agent_ids?: string[];
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
export * from './a11y';
|
||||
export * from './artifacts';
|
||||
export * from './types';
|
||||
export * from './menus';
|
||||
export * from './tools';
|
||||
export * from './assistants-types';
|
||||
export * from './agents-types';
|
||||
|
||||
24
client/src/common/menus.ts
Normal file
24
client/src/common/menus.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
export type RenderProp<
|
||||
P = React.HTMLAttributes<any> & {
|
||||
ref?: React.Ref<any>;
|
||||
},
|
||||
> = (props: P) => React.ReactNode;
|
||||
|
||||
export interface MenuItemProps {
|
||||
id?: string;
|
||||
label?: string;
|
||||
onClick?: (e: React.MouseEvent<HTMLButtonElement | HTMLDivElement>) => void;
|
||||
icon?: React.ReactNode;
|
||||
kbd?: string;
|
||||
show?: boolean;
|
||||
disabled?: boolean;
|
||||
separate?: boolean;
|
||||
hideOnClick?: boolean;
|
||||
dialog?: React.ReactElement;
|
||||
ref?: React.Ref<any>;
|
||||
render?:
|
||||
| RenderProp<React.HTMLAttributes<any> & { ref?: React.Ref<any> | undefined }>
|
||||
| React.ReactElement<any, string | React.JSXElementConstructor<any>>
|
||||
| undefined;
|
||||
}
|
||||
@@ -91,7 +91,14 @@ export type IconMapProps = {
|
||||
size?: number;
|
||||
};
|
||||
|
||||
export type AgentIconMapProps = IconMapProps & { agentName: string };
|
||||
export type IconComponent = React.ComponentType<IconMapProps>;
|
||||
export type AgentIconComponent = React.ComponentType<AgentIconMapProps>;
|
||||
export type IconComponentTypes = IconComponent | AgentIconComponent;
|
||||
export type IconsRecord = {
|
||||
[key in t.EModelEndpoint | 'unknown' | string]: IconComponentTypes | null | undefined;
|
||||
};
|
||||
|
||||
export type AgentIconMapProps = IconMapProps & { agentName?: string };
|
||||
|
||||
export type NavLink = {
|
||||
title: string;
|
||||
@@ -307,6 +314,12 @@ export type TMessageProps = {
|
||||
setSiblingIdx?: ((value: number) => void | React.Dispatch<React.SetStateAction<number>>) | null;
|
||||
};
|
||||
|
||||
export type TMessageIcon = { endpoint?: string | null; isCreatedByUser?: boolean } & Pick<
|
||||
t.TConversation,
|
||||
'modelLabel'
|
||||
> &
|
||||
Pick<t.TMessage, 'model' | 'iconURL'>;
|
||||
|
||||
export type TInitialProps = {
|
||||
text: string;
|
||||
edit: boolean;
|
||||
|
||||
@@ -47,7 +47,6 @@ const MermaidDiagram: React.FC<MermaidDiagramProps> = ({ content }) => {
|
||||
diagramPadding: 8,
|
||||
htmlLabels: true,
|
||||
useMaxWidth: true,
|
||||
defaultRenderer: 'dagre-d3',
|
||||
padding: 15,
|
||||
wrappingWidth: 200,
|
||||
},
|
||||
|
||||
@@ -2,7 +2,7 @@ export const ErrorMessage = ({ children }: { children: React.ReactNode }) => (
|
||||
<div
|
||||
role="alert"
|
||||
aria-live="assertive"
|
||||
className="rounded-md border border-red-500 bg-red-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-200"
|
||||
className="relative mt-6 rounded-lg border border-red-500/20 bg-red-50/50 px-6 py-4 text-red-700 shadow-sm transition-all dark:bg-red-950/30 dark:text-red-100"
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
@@ -26,7 +26,10 @@ function Login() {
|
||||
<p className="my-4 text-center text-sm font-light text-gray-700 dark:text-white">
|
||||
{' '}
|
||||
{localize('com_auth_no_account')}{' '}
|
||||
<a href="/register" className="p-1 text-green-500">
|
||||
<a
|
||||
href="/register"
|
||||
className="inline-flex p-1 text-sm font-medium text-green-600 transition-colors hover:text-green-700 dark:text-green-400 dark:hover:text-green-300"
|
||||
>
|
||||
{localize('com_auth_sign_up')}
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -153,16 +153,24 @@ const LoginForm: React.FC<TLoginFormProps> = ({ onSubmit, startupConfig, error,
|
||||
{renderError('password')}
|
||||
</div>
|
||||
{startupConfig.passwordResetEnabled && (
|
||||
<a href="/forgot-password" className="text-sm text-green-500">
|
||||
<a
|
||||
href="/forgot-password"
|
||||
className="inline-flex p-1 text-sm font-medium text-green-600 transition-colors hover:text-green-700 dark:text-green-400 dark:hover:text-green-300"
|
||||
>
|
||||
{localize('com_auth_password_forgot')}
|
||||
</a>
|
||||
)}
|
||||
<div className="mt-6">
|
||||
<button
|
||||
aria-label="Sign in"
|
||||
aria-label={localize('com_auth_continue')}
|
||||
data-testid="login-button"
|
||||
type="submit"
|
||||
className="btn-primary w-full transform rounded-2xl px-4 py-3 tracking-wide transition-colors duration-200"
|
||||
className="
|
||||
w-full rounded-2xl bg-green-600 px-4 py-3 text-sm font-medium text-white
|
||||
transition-colors hover:bg-green-700 focus:outline-none focus:ring-2
|
||||
focus:ring-green-500 focus:ring-offset-2 disabled:opacity-50
|
||||
disabled:hover:bg-green-600 dark:bg-green-600 dark:hover:bg-green-700
|
||||
"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
|
||||
@@ -183,7 +183,12 @@ const Registration: React.FC = () => {
|
||||
disabled={Object.keys(errors).length > 0}
|
||||
type="submit"
|
||||
aria-label="Submit registration"
|
||||
className="btn-primary w-full transform rounded-2xl px-4 py-3 tracking-wide transition-colors duration-200"
|
||||
className="
|
||||
w-full rounded-2xl bg-green-600 px-4 py-3 text-sm font-medium text-white
|
||||
transition-colors hover:bg-green-700 focus:outline-none focus:ring-2
|
||||
focus:ring-green-500 focus:ring-offset-2 disabled:opacity-50
|
||||
disabled:hover:bg-green-600 dark:bg-green-600 dark:hover:bg-green-700
|
||||
"
|
||||
>
|
||||
{isSubmitting ? <Spinner /> : localize('com_auth_continue')}
|
||||
</button>
|
||||
@@ -192,7 +197,11 @@ const Registration: React.FC = () => {
|
||||
|
||||
<p className="my-4 text-center text-sm font-light text-gray-700 dark:text-white">
|
||||
{localize('com_auth_already_have_account')}{' '}
|
||||
<a href="/login" aria-label="Login" className="p-1 text-green-500">
|
||||
<a
|
||||
href="/login"
|
||||
aria-label="Login"
|
||||
className="inline-flex p-1 text-sm font-medium text-green-600 transition-colors hover:text-green-700 dark:text-green-400 dark:hover:text-green-300"
|
||||
>
|
||||
{localize('com_auth_login')}
|
||||
</a>
|
||||
</p>
|
||||
|
||||
@@ -10,7 +10,7 @@ import { useLocalize } from '~/hooks';
|
||||
const BodyTextWrapper: FC<{ children: ReactNode }> = ({ children }) => {
|
||||
return (
|
||||
<div
|
||||
className="relative mt-4 rounded border border-green-400 bg-green-100 px-4 py-3 text-green-700 dark:bg-green-900 dark:text-white"
|
||||
className="relative mt-6 rounded-lg border border-green-500/20 bg-green-50/50 px-6 py-4 text-green-700 shadow-sm transition-all dark:bg-green-950/30 dark:text-green-100"
|
||||
role="alert"
|
||||
>
|
||||
{children}
|
||||
@@ -21,13 +21,14 @@ const BodyTextWrapper: FC<{ children: ReactNode }> = ({ children }) => {
|
||||
const ResetPasswordBodyText = () => {
|
||||
const localize = useLocalize();
|
||||
return (
|
||||
<div className="flex flex-col">
|
||||
{localize('com_auth_reset_password_if_email_exists')}
|
||||
<span>
|
||||
<a className="text-sm text-green-500 hover:underline" href="/login">
|
||||
{localize('com_auth_back_to_login')}
|
||||
</a>
|
||||
</span>
|
||||
<div className="flex flex-col space-y-4">
|
||||
<p>{localize('com_auth_reset_password_if_email_exists')}</p>
|
||||
<a
|
||||
className="inline-flex text-sm font-medium text-green-600 transition-colors hover:text-green-700 dark:text-green-400 dark:hover:text-green-300"
|
||||
href="/login"
|
||||
>
|
||||
{localize('com_auth_back_to_login')}
|
||||
</a>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -76,12 +77,12 @@ function RequestPasswordReset() {
|
||||
|
||||
return (
|
||||
<form
|
||||
className="mt-6"
|
||||
className="mt-8 space-y-6"
|
||||
aria-label="Password reset form"
|
||||
method="POST"
|
||||
onSubmit={handleSubmit(onSubmit)}
|
||||
>
|
||||
<div className="mb-2">
|
||||
<div className="space-y-2">
|
||||
<div className="relative">
|
||||
<input
|
||||
type="email"
|
||||
@@ -105,42 +106,51 @@ function RequestPasswordReset() {
|
||||
})}
|
||||
aria-invalid={!!errors.email}
|
||||
className="
|
||||
webkit-dark-styles transition-color peer w-full rounded-2xl border border-border-light
|
||||
bg-surface-primary px-3.5 pb-2.5 pt-3 text-text-primary duration-200 focus:border-green-500 focus:outline-none
|
||||
peer w-full rounded-lg border border-gray-300 bg-transparent px-4 py-3
|
||||
text-base text-gray-900 placeholder-transparent transition-all
|
||||
focus:border-green-500 focus:outline-none focus:ring-2 focus:ring-green-500/20
|
||||
dark:border-gray-700 dark:text-white dark:focus:border-green-500
|
||||
"
|
||||
placeholder=" "
|
||||
placeholder="email@example.com"
|
||||
/>
|
||||
<label
|
||||
htmlFor="email"
|
||||
className="
|
||||
absolute start-3 top-1.5 z-10 origin-[0] -translate-y-4 scale-75 transform bg-surface-primary px-2 text-sm text-text-secondary-alt duration-200
|
||||
peer-placeholder-shown:top-1/2 peer-placeholder-shown:-translate-y-1/2 peer-placeholder-shown:scale-100
|
||||
peer-focus:top-1.5 peer-focus:-translate-y-4 peer-focus:scale-75 peer-focus:px-2 peer-focus:text-green-500
|
||||
rtl:peer-focus:left-auto rtl:peer-focus:translate-x-1/4
|
||||
absolute -top-2 left-2 z-10 bg-white px-2 text-sm text-gray-600
|
||||
transition-all peer-placeholder-shown:top-3 peer-placeholder-shown:text-base
|
||||
peer-placeholder-shown:text-gray-500 peer-focus:-top-2 peer-focus:text-sm
|
||||
peer-focus:text-green-600 dark:bg-gray-900 dark:text-gray-400
|
||||
dark:peer-focus:text-green-500
|
||||
"
|
||||
>
|
||||
{localize('com_auth_email_address')}
|
||||
</label>
|
||||
</div>
|
||||
{errors.email && (
|
||||
<span role="alert" className="mt-1 text-sm text-red-500 dark:text-red-900">
|
||||
<p role="alert" className="text-sm font-medium text-red-600 dark:text-red-400">
|
||||
{errors.email.message}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-6">
|
||||
<div className="space-y-4">
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!!errors.email}
|
||||
className="btn-primary w-full transform rounded-2xl px-4 py-3 tracking-wide transition-colors duration-200"
|
||||
className="
|
||||
w-full rounded-2xl bg-green-600 px-4 py-3 text-sm font-medium text-white
|
||||
transition-colors hover:bg-green-700 focus:outline-none focus:ring-2
|
||||
focus:ring-green-500 focus:ring-offset-2 disabled:opacity-50
|
||||
disabled:hover:bg-green-600 dark:bg-green-600 dark:hover:bg-green-700
|
||||
"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
<div className="mt-4 flex justify-center">
|
||||
<a href="/login" className="text-sm text-green-500">
|
||||
{localize('com_auth_back_to_login')}
|
||||
</a>
|
||||
</div>
|
||||
<a
|
||||
href="/login"
|
||||
className="block text-center text-sm font-medium text-green-600 transition-colors hover:text-green-700 dark:text-green-400 dark:hover:text-green-300"
|
||||
>
|
||||
{localize('com_auth_back_to_login')}
|
||||
</a>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
|
||||
@@ -35,7 +35,7 @@ function ResetPassword() {
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
className="relative mb-8 mt-4 rounded border border-green-400 bg-green-100 px-4 py-3 text-center text-green-700 dark:bg-gray-900 dark:text-white"
|
||||
className="relative mb-8 mt-4 rounded-2xl border border-green-400 bg-green-100 px-4 py-3 text-center text-green-700 dark:bg-gray-900 dark:text-white"
|
||||
role="alert"
|
||||
>
|
||||
{localize('com_auth_login_with_new_password')}
|
||||
@@ -43,7 +43,7 @@ function ResetPassword() {
|
||||
<button
|
||||
onClick={() => navigate('/login')}
|
||||
aria-label={localize('com_auth_sign_in')}
|
||||
className="w-full transform rounded-md bg-green-500 px-4 py-3 tracking-wide text-white transition-colors duration-200 hover:bg-green-600 focus:bg-green-600 focus:outline-none"
|
||||
className="w-full transform rounded-2xl bg-green-500 px-4 py-3 tracking-wide text-white transition-colors duration-200 hover:bg-green-600 focus:bg-green-600 focus:outline-none"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
@@ -163,7 +163,12 @@ function ResetPassword() {
|
||||
disabled={!!errors.password || !!errors.confirm_password}
|
||||
type="submit"
|
||||
aria-label={localize('com_auth_submit_registration')}
|
||||
className="btn-primary w-full transform rounded-2xl px-4 py-3 tracking-wide transition-colors duration-200"
|
||||
className="
|
||||
w-full rounded-2xl bg-green-600 px-4 py-3 text-sm font-medium text-white
|
||||
transition-colors hover:bg-green-700 focus:outline-none focus:ring-2
|
||||
focus:ring-green-500 focus:ring-offset-2 disabled:opacity-50
|
||||
disabled:hover:bg-green-600 dark:bg-green-600 dark:hover:bg-green-700
|
||||
"
|
||||
>
|
||||
{localize('com_auth_continue')}
|
||||
</button>
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import reactRouter from 'react-router-dom';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
import { render, waitFor } from 'test/layout-test-utils';
|
||||
import { getByTestId, render, waitFor } from 'test/layout-test-utils';
|
||||
import * as mockDataProvider from 'librechat-data-provider/react-query';
|
||||
import type { TStartupConfig } from 'librechat-data-provider';
|
||||
import * as authDataProvider from '~/data-provider/Auth/mutations';
|
||||
import AuthLayout from '~/components/Auth/AuthLayout';
|
||||
import Login from '~/components/Auth/Login';
|
||||
|
||||
@@ -61,7 +62,7 @@ const setup = ({
|
||||
},
|
||||
} = {}) => {
|
||||
const mockUseLoginUser = jest
|
||||
.spyOn(mockDataProvider, 'useLoginUserMutation')
|
||||
.spyOn(authDataProvider, 'useLoginUserMutation')
|
||||
//@ts-ignore - we don't need all parameters of the QueryObserverSuccessResult
|
||||
.mockReturnValue(useLoginUserReturnValue);
|
||||
const mockUseGetUserQuery = jest
|
||||
@@ -117,7 +118,7 @@ test('renders login form', () => {
|
||||
const { getByLabelText, getByRole } = setup();
|
||||
expect(getByLabelText(/email/i)).toBeInTheDocument();
|
||||
expect(getByLabelText(/password/i)).toBeInTheDocument();
|
||||
expect(getByRole('button', { name: /Sign in/i })).toBeInTheDocument();
|
||||
expect(getByTestId(document.body, 'login-button')).toBeInTheDocument();
|
||||
expect(getByRole('link', { name: /Sign up/i })).toBeInTheDocument();
|
||||
expect(getByRole('link', { name: /Sign up/i })).toHaveAttribute('href', '/register');
|
||||
expect(getByRole('link', { name: /Continue with Google/i })).toBeInTheDocument();
|
||||
@@ -144,7 +145,7 @@ test('renders login form', () => {
|
||||
|
||||
test('calls loginUser.mutate on login', async () => {
|
||||
const mutate = jest.fn();
|
||||
const { getByLabelText, getByRole } = setup({
|
||||
const { getByLabelText } = setup({
|
||||
// @ts-ignore - we don't need all parameters of the QueryObserverResult
|
||||
useLoginUserReturnValue: {
|
||||
isLoading: false,
|
||||
@@ -155,7 +156,7 @@ test('calls loginUser.mutate on login', async () => {
|
||||
|
||||
const emailInput = getByLabelText(/email/i);
|
||||
const passwordInput = getByLabelText(/password/i);
|
||||
const submitButton = getByRole('button', { name: /Sign in/i });
|
||||
const submitButton = getByTestId(document.body, 'login-button');
|
||||
|
||||
await userEvent.type(emailInput, 'test@test.com');
|
||||
await userEvent.type(passwordInput, 'password');
|
||||
@@ -165,7 +166,7 @@ test('calls loginUser.mutate on login', async () => {
|
||||
});
|
||||
|
||||
test('Navigates to / on successful login', async () => {
|
||||
const { getByLabelText, getByRole, history } = setup({
|
||||
const { getByLabelText, history } = setup({
|
||||
// @ts-ignore - we don't need all parameters of the QueryObserverResult
|
||||
useLoginUserReturnValue: {
|
||||
isLoading: false,
|
||||
@@ -185,7 +186,7 @@ test('Navigates to / on successful login', async () => {
|
||||
|
||||
const emailInput = getByLabelText(/email/i);
|
||||
const passwordInput = getByLabelText(/password/i);
|
||||
const submitButton = getByRole('button', { name: /Sign in/i });
|
||||
const submitButton = getByTestId(document.body, 'login-button');
|
||||
|
||||
await userEvent.type(emailInput, 'test@test.com');
|
||||
await userEvent.type(passwordInput, 'password');
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { render } from 'test/layout-test-utils';
|
||||
import { render, getByTestId } from 'test/layout-test-utils';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
import * as mockDataProvider from 'librechat-data-provider/react-query';
|
||||
import type { TStartupConfig } from 'librechat-data-provider';
|
||||
import * as authDataProvider from '~/data-provider/Auth/mutations';
|
||||
import Login from '../LoginForm';
|
||||
|
||||
jest.mock('librechat-data-provider/react-query');
|
||||
@@ -66,7 +67,7 @@ const setup = ({
|
||||
},
|
||||
} = {}) => {
|
||||
const mockUseLoginUser = jest
|
||||
.spyOn(mockDataProvider, 'useLoginUserMutation')
|
||||
.spyOn(authDataProvider, 'useLoginUserMutation')
|
||||
//@ts-ignore - we don't need all parameters of the QueryObserverSuccessResult
|
||||
.mockReturnValue(useLoginUserReturnValue);
|
||||
const mockUseGetUserQuery = jest
|
||||
@@ -112,7 +113,7 @@ test('submits login form', async () => {
|
||||
);
|
||||
const emailInput = getByLabelText(/email/i);
|
||||
const passwordInput = getByLabelText(/password/i);
|
||||
const submitButton = getByRole('button', { name: /Sign in/i });
|
||||
const submitButton = getByTestId(document.body, 'login-button');
|
||||
|
||||
await userEvent.type(emailInput, 'test@example.com');
|
||||
await userEvent.type(passwordInput, 'password');
|
||||
@@ -127,7 +128,7 @@ test('displays validation error messages', async () => {
|
||||
);
|
||||
const emailInput = getByLabelText(/email/i);
|
||||
const passwordInput = getByLabelText(/password/i);
|
||||
const submitButton = getByRole('button', { name: /Sign in/i });
|
||||
const submitButton = getByTestId(document.body, 'login-button');
|
||||
|
||||
await userEvent.type(emailInput, 'test');
|
||||
await userEvent.type(passwordInput, 'pass');
|
||||
|
||||
@@ -1,5 +1,5 @@
import React, { useRef, Dispatch, SetStateAction } from 'react';
import { TConversationTag, TConversation } from 'librechat-data-provider';
import { TConversationTag } from 'librechat-data-provider';
import OGDialogTemplate from '~/components/ui/OGDialogTemplate';
import { useConversationTagMutation } from '~/data-provider';
import { OGDialog, Button, Spinner } from '~/components';
@@ -10,23 +10,27 @@ import { useLocalize } from '~/hooks';
import { logger } from '~/utils';

type BookmarkEditDialogProps = {
context: string;
bookmark?: TConversationTag;
conversation?: TConversation;
tags?: string[];
setTags?: (tags: string[]) => void;
open: boolean;
setOpen: Dispatch<SetStateAction<boolean>>;
tags?: string[];
setTags?: (tags: string[]) => void;
context: string;
bookmark?: TConversationTag;
conversationId?: string;
children?: React.ReactNode;
triggerRef?: React.RefObject<HTMLButtonElement>;
};

const BookmarkEditDialog = ({
context,
bookmark,
conversation,
tags,
setTags,
open,
setOpen,
tags,
setTags,
context,
bookmark,
children,
triggerRef,
conversationId,
}: BookmarkEditDialogProps) => {
const localize = useLocalize();
const formRef = useRef<HTMLFormElement>(null);
@@ -44,12 +48,26 @@ const BookmarkEditDialog = ({
});
setOpen(false);
logger.log('tag_mutation', 'tags before setting', tags);

if (setTags && vars.addToConversation === true) {
const newTags = [...(tags || []), vars.tag].filter(
(tag) => tag !== undefined,
) as string[];
setTags(newTags);

logger.log('tag_mutation', 'tags after', newTags);
if (vars.tag == null || vars.tag === '') {
return;
}

setTimeout(() => {
const tagElement = document.getElementById(vars.tag ?? '');
console.log('tagElement', tagElement);
if (!tagElement) {
return;
}
tagElement.focus();
}, 5);
}
},
onError: () => {
@@ -70,7 +88,8 @@ const BookmarkEditDialog = ({
};

return (
<OGDialog open={open} onOpenChange={setOpen}>
<OGDialog open={open} onOpenChange={setOpen} triggerRef={triggerRef}>
{children}
<OGDialogTemplate
title="Bookmark"
showCloseButton={false}
@@ -80,7 +99,7 @@ const BookmarkEditDialog = ({
tags={tags}
setOpen={setOpen}
mutation={mutation}
conversation={conversation}
conversationId={conversationId}
bookmark={bookmark}
formRef={formRef}
/>
@@ -91,6 +110,7 @@ const BookmarkEditDialog = ({
type="submit"
disabled={mutation.isLoading}
onClick={handleSubmitForm}
className="text-white"
>
{mutation.isLoading ? <Spinner /> : localize('com_ui_save')}
</Button>

@@ -2,11 +2,7 @@ import React, { useEffect } from 'react';
import { QueryKeys } from 'librechat-data-provider';
import { Controller, useForm } from 'react-hook-form';
import { useQueryClient } from '@tanstack/react-query';
import type {
TConversation,
TConversationTag,
TConversationTagRequest,
} from 'librechat-data-provider';
import type { TConversationTag, TConversationTagRequest } from 'librechat-data-provider';
import { Checkbox, Label, TextareaAutosize, Input } from '~/components';
import { useBookmarkContext } from '~/Providers/BookmarkContext';
import { useConversationTagMutation } from '~/data-provider';
@@ -17,7 +13,7 @@ import { cn, logger } from '~/utils';
type TBookmarkFormProps = {
tags?: string[];
bookmark?: TConversationTag;
conversation?: TConversation;
conversationId?: string;
formRef: React.RefObject<HTMLFormElement>;
setOpen: React.Dispatch<React.SetStateAction<boolean>>;
mutation: ReturnType<typeof useConversationTagMutation>;
@@ -26,7 +22,7 @@ const BookmarkForm = ({
tags,
bookmark,
mutation,
conversation,
conversationId,
setOpen,
formRef,
}: TBookmarkFormProps) => {
@@ -46,8 +42,8 @@ const BookmarkForm = ({
defaultValues: {
tag: bookmark?.tag ?? '',
description: bookmark?.description ?? '',
conversationId: conversation?.conversationId ?? '',
addToConversation: conversation ? true : false,
conversationId: conversationId ?? '',
addToConversation: conversationId != null && conversationId ? true : false,
},
});

@@ -142,7 +138,7 @@ const BookmarkForm = ({
)}
/>
</div>
{conversation && (
{conversationId != null && conversationId && (
<div className="mt-2 flex w-full items-center">
<Controller
name="addToConversation"

@@ -3,7 +3,6 @@ import { MenuItem } from '@headlessui/react';
import { BookmarkFilledIcon, BookmarkIcon } from '@radix-ui/react-icons';
import type { FC } from 'react';
import { Spinner } from '~/components/svg';
import { cn } from '~/utils';

type MenuItemProps = {
tag: string | React.ReactNode;
@@ -47,10 +46,7 @@ const BookmarkItem: FC<MenuItemProps> = ({ tag, selected, handleSubmit, icon, ..
return (
<MenuItem
aria-label={tag as string}
className={cn(
'group flex w-full gap-2 rounded-lg p-2.5 text-sm text-text-primary transition-colors duration-200',
selected ? 'bg-surface-hover' : 'data-[focus]:bg-surface-hover',
)}
className="group flex w-full gap-2 rounded-lg p-2.5 text-sm text-text-primary transition-colors duration-200 focus:outline-none data-[focus]:bg-surface-secondary data-[focus]:ring-2 data-[focus]:ring-primary"
{...rest}
as="button"
onClick={clickHandler}

@@ -37,7 +37,7 @@ const DeleteBookmarkButton: FC<{
}, [bookmark, deleteBookmarkMutation]);

const handleKeyDown = (event: React.KeyboardEvent<HTMLDivElement>) => {
if (event.key === 'Enter') {
if (event.key === 'Enter' || event.key === ' ') {
event.preventDefault();
event.stopPropagation();
setOpen(!open);
@@ -49,6 +49,8 @@ const DeleteBookmarkButton: FC<{
<OGDialog open={open} onOpenChange={setOpen}>
<OGDialogTrigger asChild>
<TooltipAnchor
role="button"
aria-label={localize('com_ui_bookmarks_delete')}
description={localize('com_ui_delete')}
className="flex size-7 items-center justify-center rounded-lg transition-colors duration-200 hover:bg-surface-hover"
tabIndex={tabIndex}

@@ -1,8 +1,8 @@
import { useState } from 'react';
import type { FC } from 'react';
import type { TConversationTag } from 'librechat-data-provider';
import { TooltipAnchor, OGDialogTrigger } from '~/components/ui';
import BookmarkEditDialog from './BookmarkEditDialog';
import { TooltipAnchor } from '~/components/ui';
import { EditIcon } from '~/components/svg';
import { useLocalize } from '~/hooks';

@@ -16,31 +16,34 @@ const EditBookmarkButton: FC<{
const [open, setOpen] = useState(false);

const handleKeyDown = (event: React.KeyboardEvent<HTMLDivElement>) => {
if (event.key === 'Enter') {
if (event.key === 'Enter' || event.key === ' ') {
setOpen(!open);
}
};

return (
<>
<BookmarkEditDialog
context="EditBookmarkButton"
bookmark={bookmark}
open={open}
setOpen={setOpen}
/>
<TooltipAnchor
description={localize('com_ui_edit')}
tabIndex={tabIndex}
onFocus={onFocus}
onBlur={onBlur}
onClick={() => setOpen(!open)}
className="flex size-7 items-center justify-center rounded-lg transition-colors duration-200 hover:bg-surface-hover"
onKeyDown={handleKeyDown}
>
<EditIcon />
</TooltipAnchor>
</>
<BookmarkEditDialog
context="EditBookmarkButton"
bookmark={bookmark}
open={open}
setOpen={setOpen}
>
<OGDialogTrigger asChild>
<TooltipAnchor
role="button"
aria-label={localize('com_ui_bookmarks_edit')}
description={localize('com_ui_edit')}
tabIndex={tabIndex}
onFocus={onFocus}
onBlur={onBlur}
onClick={() => setOpen(!open)}
className="flex size-7 items-center justify-center rounded-lg transition-colors duration-200 hover:bg-surface-hover"
onKeyDown={handleKeyDown}
>
<EditIcon />
</TooltipAnchor>
</OGDialogTrigger>
</BookmarkEditDialog>
);
};


@@ -1,8 +1,9 @@
import { memo } from 'react';
import { memo, useCallback } from 'react';
import { useRecoilValue } from 'recoil';
import { useForm } from 'react-hook-form';
import { useParams } from 'react-router-dom';
import { useGetMessagesByConvoId } from 'librechat-data-provider/react-query';
import type { TMessage } from 'librechat-data-provider';
import type { ChatFormValues } from '~/common';
import { ChatContext, AddedChatContext, useFileMapContext, ChatFormProvider } from '~/Providers';
import { useChatHelpers, useAddedResponse, useSSE } from '~/hooks';
@@ -24,10 +25,13 @@ function ChatView({ index = 0 }: { index?: number }) {
const fileMap = useFileMapContext();

const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', {
select: (data) => {
const dataTree = buildTree({ messages: data, fileMap });
return dataTree?.length === 0 ? null : dataTree ?? null;
},
select: useCallback(
(data: TMessage[]) => {
const dataTree = buildTree({ messages: data, fileMap });
return dataTree?.length === 0 ? null : dataTree ?? null;
},
[fileMap],
),
enabled: !!fileMap,
});


@@ -2,10 +2,11 @@ import { useState, useId, useRef } from 'react';
import { useRecoilValue } from 'recoil';
import * as Ariakit from '@ariakit/react';
import { Upload, Share2 } from 'lucide-react';
import { ShareButton } from '~/components/Conversations/ConvoOptions';
import { useMediaQuery, useLocalize } from '~/hooks';
import type * as t from '~/common';
import ExportModal from '~/components/Nav/ExportConversation/ExportModal';
import { DropdownPopup } from '~/components/ui';
import { ShareButton } from '~/components/Conversations/ConvoOptions';
import { DropdownPopup, TooltipAnchor } from '~/components/ui';
import { useMediaQuery, useLocalize } from '~/hooks';
import store from '~/store';

export default function ExportAndShareMenu({
@@ -19,6 +20,7 @@ export default function ExportAndShareMenu({
const [showShareDialog, setShowShareDialog] = useState(false);

const menuId = useId();
const shareButtonRef = useRef<HTMLButtonElement>(null);
const exportButtonRef = useRef<HTMLButtonElement>(null);
const isSmallScreen = useMediaQuery('(max-width: 768px)');
const conversation = useRecoilValue(store.conversationByIndex(0));
@@ -33,31 +35,33 @@ export default function ExportAndShareMenu({
return null;
}

const onOpenChange = (value: boolean) => {
setShowExports(value);
};

const shareHandler = () => {
setIsPopoverActive(false);
setShowShareDialog(true);
};

const exportHandler = () => {
setIsPopoverActive(false);
setShowExports(true);
};

const dropdownItems = [
const dropdownItems: t.MenuItemProps[] = [
{
label: localize('com_endpoint_export'),
onClick: exportHandler,
icon: <Upload className="icon-md mr-2 text-text-secondary" />,
/** NOTE: THE FOLLOWING PROPS ARE REQUIRED FOR MENU ITEMS THAT OPEN DIALOGS */
hideOnClick: false,
ref: exportButtonRef,
render: (props) => <button {...props} />,
},
{
label: localize('com_ui_share'),
onClick: shareHandler,
icon: <Share2 className="icon-md mr-2 text-text-secondary" />,
show: isSharedButtonEnabled,
/** NOTE: THE FOLLOWING PROPS ARE REQUIRED FOR MENU ITEMS THAT OPEN DIALOGS */
hideOnClick: false,
ref: shareButtonRef,
render: (props) => <button {...props} />,
},
];

@@ -65,38 +69,44 @@ export default function ExportAndShareMenu({
<>
<DropdownPopup
menuId={menuId}
focusLoop={true}
isOpen={isPopoverActive}
setIsOpen={setIsPopoverActive}
trigger={
<Ariakit.MenuButton
ref={exportButtonRef}
id="export-menu-button"
aria-label="Export options"
className="inline-flex size-10 items-center justify-center rounded-lg border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
>
<Upload className="icon-md text-text-secondary" aria-hidden="true" focusable="false" />
</Ariakit.MenuButton>
<TooltipAnchor
description={localize('com_endpoint_export_share')}
render={
<Ariakit.MenuButton
id="export-menu-button"
aria-label="Export options"
className="inline-flex size-10 items-center justify-center rounded-lg border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
>
<Upload
className="icon-md text-text-secondary"
aria-hidden="true"
focusable="false"
/>
</Ariakit.MenuButton>
}
/>
}
items={dropdownItems}
className={isSmallScreen ? '' : 'absolute right-0 top-0 mt-2'}
/>
{showShareDialog && conversation.conversationId != null && (
<ShareButton
conversationId={conversation.conversationId}
title={conversation.title ?? ''}
showShareDialog={showShareDialog}
setShowShareDialog={setShowShareDialog}
/>
)}
{showExports && (
<ExportModal
open={showExports}
onOpenChange={onOpenChange}
conversation={conversation}
triggerRef={exportButtonRef}
aria-label={localize('com_ui_export_convo_modal')}
/>
)}
<ExportModal
open={showExports}
onOpenChange={setShowExports}
conversation={conversation}
triggerRef={exportButtonRef}
aria-label={localize('com_ui_export_convo_modal')}
/>
<ShareButton
triggerRef={shareButtonRef}
conversationId={conversation.conversationId ?? ''}
title={conversation.title ?? ''}
open={showShareDialog}
onOpenChange={setShowShareDialog}
/>
</>
);
}

@@ -1,73 +1,79 @@
import { useEffect } from 'react';
import { useCallback } from 'react';
import { useChatFormContext, useToastContext } from '~/Providers';
import { ListeningIcon, Spinner } from '~/components/svg';
import { useLocalize, useSpeechToText } from '~/hooks';
import { useChatFormContext } from '~/Providers';
import { TooltipAnchor } from '~/components/ui';
import { globalAudioId } from '~/common';
import { cn } from '~/utils';

export default function AudioRecorder({
textAreaRef,
methods,
ask,
isRTL,
disabled,
ask,
methods,
textAreaRef,
isSubmitting,
}: {
textAreaRef: React.RefObject<HTMLTextAreaElement>;
methods: ReturnType<typeof useChatFormContext>;
ask: (data: { text: string }) => void;
isRTL: boolean;
disabled: boolean;
ask: (data: { text: string }) => void;
methods: ReturnType<typeof useChatFormContext>;
textAreaRef: React.RefObject<HTMLTextAreaElement>;
isSubmitting: boolean;
}) {
const { setValue, reset } = methods;
const localize = useLocalize();
const { showToast } = useToastContext();

const handleTranscriptionComplete = (text: string) => {
if (text) {
const globalAudio = document.getElementById(globalAudioId) as HTMLAudioElement;
if (globalAudio) {
console.log('Unmuting global audio');
globalAudio.muted = false;
const onTranscriptionComplete = useCallback(
(text: string) => {
if (isSubmitting) {
showToast({
message: localize('com_ui_speech_while_submitting'),
status: 'error',
});
return;
}
ask({ text });
methods.reset({ text: '' });
clearText();
}
};
if (text) {
const globalAudio = document.getElementById(globalAudioId) as HTMLAudioElement | null;
if (globalAudio) {
console.log('Unmuting global audio');
globalAudio.muted = false;
}
ask({ text });
reset({ text: '' });
}
},
[ask, reset, showToast, localize, isSubmitting],
);

const {
isListening,
isLoading,
startRecording,
stopRecording,
interimTranscript,
speechText,
clearText,
} = useSpeechToText(handleTranscriptionComplete);

useEffect(() => {
if (isListening && textAreaRef.current) {
methods.setValue('text', interimTranscript, {
const setText = useCallback(
(text: string) => {
setValue('text', text, {
shouldValidate: true,
});
} else if (textAreaRef.current) {
textAreaRef.current.value = speechText;
methods.setValue('text', speechText, { shouldValidate: true });
}
}, [interimTranscript, speechText, methods, textAreaRef]);
},
[setValue],
);

const handleStartRecording = async () => {
await startRecording();
};
const { isListening, isLoading, startRecording, stopRecording } = useSpeechToText(
setText,
onTranscriptionComplete,
);

const handleStopRecording = async () => {
await stopRecording();
};
if (!textAreaRef.current) {
return null;
}

const handleStartRecording = async () => startRecording();

const handleStopRecording = async () => stopRecording();

const renderIcon = () => {
if (isListening) {
if (isListening === true) {
return <ListeningIcon className="stroke-red-500" />;
}
if (isLoading) {
if (isLoading === true) {
return <Spinner className="stroke-gray-700 dark:stroke-gray-300" />;
}
return <ListeningIcon className="stroke-gray-700 dark:stroke-gray-300" />;
@@ -77,7 +83,7 @@ export default function AudioRecorder({
<TooltipAnchor
id="audio-recorder"
aria-label={localize('com_ui_use_micrphone')}
onClick={isListening ? handleStopRecording : handleStartRecording}
onClick={isListening === true ? handleStopRecording : handleStartRecording}
disabled={disabled}
className={cn(
'absolute flex size-[35px] items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover',

@@ -1,4 +1,4 @@
import { memo, useRef, useMemo, useEffect } from 'react';
import { memo, useRef, useMemo, useEffect, useState } from 'react';
import { useRecoilState, useRecoilValue } from 'recoil';
import {
supportsFiles,
@@ -20,14 +20,15 @@ import {
useQueryParams,
useSubmitMessage,
} from '~/hooks';
import { cn, removeFocusRings, checkIfScrollable } from '~/utils';
import FileFormWrapper from './Files/FileFormWrapper';
import { TextareaAutosize } from '~/components/ui';
import { useGetFileConfig } from '~/data-provider';
import { cn, removeFocusRings } from '~/utils';
import TextareaHeader from './TextareaHeader';
import PromptsCommand from './PromptsCommand';
import AudioRecorder from './AudioRecorder';
import { mainTextareaId } from '~/common';
import CollapseChat from './CollapseChat';
import StreamAudio from './StreamAudio';
import StopButton from './StopButton';
import SendButton from './SendButton';
@@ -39,9 +40,13 @@ const ChatForm = ({ index = 0 }) => {
const textAreaRef = useRef<HTMLTextAreaElement | null>(null);
useQueryParams({ textAreaRef });

const [isCollapsed, setIsCollapsed] = useState(false);
const [isScrollable, setIsScrollable] = useState(false);

const SpeechToText = useRecoilValue(store.speechToText);
const TextToSpeech = useRecoilValue(store.textToSpeech);
const automaticPlayback = useRecoilValue(store.automaticPlayback);
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);

const isSearching = useRecoilValue(store.isSearching);
const [showStopButton, setShowStopButton] = useRecoilState(store.showStopButtonByIndex(index));
@@ -63,6 +68,7 @@ const ChatForm = ({ index = 0 }) => {
const { handlePaste, handleKeyDown, handleCompositionStart, handleCompositionEnd } = useTextarea({
textAreaRef,
submitButtonRef,
setIsScrollable,
disabled: !!(requiresKey ?? false),
});

@@ -128,11 +134,19 @@ const ChatForm = ({ index = 0 }) => {
}
}, [isSearching, disableInputs]);

useEffect(() => {
if (textAreaRef.current) {
checkIfScrollable(textAreaRef.current);
}
}, []);

const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? endpoint ?? ''] ?? false;
const isUploadDisabled: boolean = endpointFileConfig?.disabled ?? false;

const baseClasses =
'md:py-3.5 m-0 w-full resize-none bg-surface-tertiary py-[13px] placeholder-black/50 dark:placeholder-white/50 [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)] max-h-[65vh] md:max-h-[75vh]';
const baseClasses = cn(
'md:py-3.5 m-0 w-full resize-none bg-surface-tertiary py-[13px] placeholder-black/50 dark:placeholder-white/50 [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)]',
isCollapsed ? 'max-h-[52px]' : 'max-h-[65vh] md:max-h-[75vh]',
);

const uploadActive = endpointSupportsFiles && !isUploadDisabled;
const speechClass = isRTL
@@ -142,7 +156,10 @@ const ChatForm = ({ index = 0 }) => {
return (
<form
onSubmit={methods.handleSubmit((data) => submitMessage(data))}
className="stretch mx-2 flex flex-row gap-3 last:mb-2 md:mx-4 md:last:mb-6 lg:mx-auto lg:max-w-2xl xl:max-w-3xl"
className={cn(
'mx-auto flex flex-row gap-3 pl-2 transition-all duration-200 last:mb-2',
maximizeChatSpace ? 'w-full max-w-full' : 'md:max-w-2xl xl:max-w-3xl',
)}
>
<div className="relative flex h-full flex-1 items-stretch md:flex-col">
<div className="flex w-full items-center">
@@ -168,34 +185,55 @@ const ChatForm = ({ index = 0 }) => {
<TextareaHeader addedConvo={addedConvo} setAddedConvo={setAddedConvo} />
<FileFormWrapper disableInputs={disableInputs}>
{endpoint && (
<TextareaAutosize
{...registerProps}
ref={(e) => {
ref(e);
textAreaRef.current = e;
}}
disabled={disableInputs}
onPaste={handlePaste}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
onCompositionStart={handleCompositionStart}
onCompositionEnd={handleCompositionEnd}
id={mainTextareaId}
tabIndex={0}
data-testid="text-input"
style={{ height: 44, overflowY: 'auto' }}
rows={1}
className={cn(baseClasses, speechClass, removeFocusRings)}
/>
<>
<CollapseChat
isCollapsed={isCollapsed}
isScrollable={isScrollable}
setIsCollapsed={setIsCollapsed}
/>
<TextareaAutosize
{...registerProps}
ref={(e) => {
ref(e);
textAreaRef.current = e;
}}
disabled={disableInputs}
onPaste={handlePaste}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
onHeightChange={() => {
if (textAreaRef.current) {
const scrollable = checkIfScrollable(textAreaRef.current);
setIsScrollable(scrollable);
}
}}
onCompositionStart={handleCompositionStart}
onCompositionEnd={handleCompositionEnd}
id={mainTextareaId}
tabIndex={0}
data-testid="text-input"
rows={1}
onFocus={() => isCollapsed && setIsCollapsed(false)}
onClick={() => isCollapsed && setIsCollapsed(false)}
style={{ height: 44, overflowY: 'auto' }}
className={cn(
baseClasses,
speechClass,
removeFocusRings,
'transition-[max-height] duration-200',
)}
/>
</>
)}
</FileFormWrapper>
{SpeechToText && (
<AudioRecorder
disabled={!!disableInputs}
textAreaRef={textAreaRef}
ask={submitMessage}
isRTL={isRTL}
methods={methods}
ask={submitMessage}
textAreaRef={textAreaRef}
disabled={!!disableInputs}
isSubmitting={isSubmitting}
/>
)}
{TextToSpeech && automaticPlayback && <StreamAudio index={index} />}

41 client/src/components/Chat/Input/CollapseChat.tsx Normal file
@@ -0,0 +1,41 @@
import React from 'react';
import { Minimize2 } from 'lucide-react';
import { TooltipAnchor } from '~/components/ui';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';

const CollapseChat = ({
isScrollable,
isCollapsed,
setIsCollapsed,
}: {
isScrollable: boolean;
isCollapsed: boolean;
setIsCollapsed: React.Dispatch<React.SetStateAction<boolean>>;
}) => {
const localize = useLocalize();
if (!isScrollable) {
return null;
}

if (isCollapsed) {
return null;
}

return (
<TooltipAnchor
role="button"
description={localize('com_ui_collapse_chat')}
aria-label={localize('com_ui_collapse_chat')}
onClick={() => setIsCollapsed(true)}
className={cn(
'absolute right-2 top-2 z-10 size-[35px] rounded-full p-2 transition-colors',
'hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
)}
>
<Minimize2 className="h-full w-full" />
</TooltipAnchor>
);
};

export default CollapseChat;
@@ -26,7 +26,7 @@ const AttachFile = ({
disabled={isUploadDisabled}
className={cn(
'absolute flex size-[35px] items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
isRTL ? 'bottom-2 right-2' : 'bottom-2 left-1 md:left-2',
isRTL ? 'bottom-2 right-2' : 'bottom-2 left-2',
)}
description={localize('com_sidepanel_attach_files')}
onKeyDownCapture={(e) => {

@@ -11,15 +11,15 @@ import { cn } from '~/utils';
interface AttachFileProps {
isRTL: boolean;
disabled?: boolean | null;
handleFileChange: (event: React.ChangeEvent<HTMLInputElement>) => void;
setToolResource?: React.Dispatch<React.SetStateAction<string | undefined>>;
handleFileChange: (event: React.ChangeEvent<HTMLInputElement>, toolResource?: string) => void;
}

const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: AttachFileProps) => {
const AttachFile = ({ isRTL, disabled, handleFileChange }: AttachFileProps) => {
const localize = useLocalize();
const isUploadDisabled = disabled ?? false;
const inputRef = useRef<HTMLInputElement>(null);
const [isPopoverActive, setIsPopoverActive] = useState(false);
const [toolResource, setToolResource] = useState<EToolResources | undefined>();
const { data: endpointsConfig } = useGetEndpointsQuery();

const capabilities = useMemo(
@@ -42,7 +42,7 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
{
label: localize('com_ui_upload_image_input'),
onClick: () => {
setToolResource?.(undefined);
setToolResource(undefined);
handleUploadClick(true);
},
icon: <ImageUpIcon className="icon-md" />,
@@ -53,7 +53,7 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
items.push({
label: localize('com_ui_upload_file_search'),
onClick: () => {
setToolResource?.(EToolResources.file_search);
setToolResource(EToolResources.file_search);
handleUploadClick();
},
icon: <FileSearch className="icon-md" />,
@@ -64,7 +64,7 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
items.push({
label: localize('com_ui_upload_code_files'),
onClick: () => {
setToolResource?.(EToolResources.execute_code);
setToolResource(EToolResources.execute_code);
handleUploadClick();
},
icon: <TerminalSquareIcon className="icon-md" />,
@@ -98,7 +98,12 @@ const AttachFile = ({ isRTL, disabled, setToolResource, handleFileChange }: Atta
);

return (
<FileUpload ref={inputRef} handleFileChange={handleFileChange}>
<FileUpload
ref={inputRef}
handleFileChange={(e) => {
handleFileChange(e, toolResource);
}}
>
<div className="relative select-none">
<DropdownPopup
menuId="attach-file-menu"

@@ -1,6 +1,12 @@
export default function DragDropOverlay() {
return (
<div className="absolute inset-0 flex flex-col items-center justify-center gap-2 bg-gray-200 opacity-80 dark:bg-gray-800 dark:text-gray-200">
<div
className="bg-surface-primary/85 fixed inset-0 z-[9999] flex flex-col items-center justify-center
gap-2 text-text-primary
backdrop-blur-[4px] transition-all duration-200
ease-in-out animate-in fade-in
zoom-in-95 hover:backdrop-blur-sm"
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 132 108"
@@ -50,7 +56,7 @@ export default function DragDropOverlay() {
</defs>
</svg>
<h3>Add anything</h3>
<h4 className="w-2/3">Drop any file here to add it to the conversation</h4>
<h4>Drop any file here to add it to the conversation</h4>
</div>
);
}

@@ -15,13 +15,17 @@ const FileContainer = ({

return (
<div className="group relative inline-block text-sm text-text-primary">
<div className="relative overflow-hidden rounded-xl border border-border-medium">
<div className="w-60 bg-surface-active p-2">
<div className="relative overflow-hidden rounded-2xl border border-border-light">
<div className="w-56 bg-surface-hover-alt p-1.5">
<div className="flex flex-row items-center gap-2">
<FilePreview file={file} fileType={fileType} className="relative" />
<div className="overflow-hidden">
<div className="truncate font-medium">{file.filename}</div>
<div className="truncate text-text-secondary">{fileType.title}</div>
<div className="truncate font-medium" title={file.filename}>
{file.filename}
</div>
<div className="truncate text-text-secondary" title={fileType.title}>
{fileType.title}
</div>
</div>
</div>
</div>

@@ -27,7 +27,7 @@ function FileFormWrapper({
const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null };
const isAgents = useMemo(() => isAgentsEndpoint(_endpoint), [_endpoint]);

const { handleFileChange, abortUpload, setToolResource } = useFileHandling();
const { handleFileChange, abortUpload } = useFileHandling();

const { data: fileConfig = defaultFileConfig } = useGetFileConfig({
select: (data) => mergeFileConfig(data),
@@ -48,7 +48,6 @@ function FileFormWrapper({
<AttachFileMenu
isRTL={isRTL}
disabled={disableInputs}
setToolResource={setToolResource}
handleFileChange={handleFileChange}
/>
);
@@ -70,9 +69,7 @@ function FileFormWrapper({
abortUpload={abortUpload}
setFilesLoading={setFilesLoading}
isRTL={isRTL}
Wrapper={({ children }) => (
<div className="mx-2 mt-2 flex flex-wrap gap-2 px-2.5 md:pl-0 md:pr-4">{children}</div>
)}
Wrapper={({ children }) => <div className="mx-2 mt-2 flex flex-wrap gap-2">{children}</div>}
/>
{children}
{renderAttachFile()}

@@ -21,7 +21,11 @@ const FilePreview = ({
}) => {
const radius = 55; // Radius of the SVG circle
const circumference = 2 * Math.PI * radius;
const progress = useProgress(file?.['progress'] ?? 1, 0.001, (file as ExtendedFile)?.size ?? 1);
const progress = useProgress(
file?.['progress'] ?? 1,
0.001,
(file as ExtendedFile | undefined)?.size ?? 1,
);

// Calculate the offset based on the loading progress
const offset = circumference - progress * circumference;
@@ -30,7 +34,7 @@
};

return (
<div className={cn('h-10 w-10 shrink-0 overflow-hidden rounded-md', className)}>
<div className={cn('size-10 shrink-0 overflow-hidden rounded-xl', className)}>
<FileIcon file={file} fileType={fileType} />
<SourceIcon source={file?.source} />
{progress < 1 && (

@@ -73,8 +73,22 @@ export default function FileRow({
}

const renderFiles = () => {
// Inline style for RTL
const rowStyle = isRTL ? { display: 'flex', flexDirection: 'row-reverse' } : {};
const rowStyle = isRTL
? {
display: 'flex',
flexDirection: 'row-reverse',
flexWrap: 'wrap',
gap: '4px',
width: '100%',
maxWidth: '100%',
}
: {
display: 'flex',
flexWrap: 'wrap',
gap: '4px',
width: '100%',
maxWidth: '100%',
};

return (
<div style={rowStyle as React.CSSProperties}>
@@ -97,18 +111,28 @@ export default function FileRow({
deleteFile({ file, setFiles });
};
const isImage = file.type?.startsWith('image') ?? false;
if (isImage) {
return (
<Image
key={index}
url={file.preview ?? file.filepath}
onDelete={handleDelete}
progress={file.progress}
source={file.source}
/>
);
}
return <FileContainer key={index} file={file} onDelete={handleDelete} />;

return (
<div
key={index}
style={{
flexBasis: '70px',
flexGrow: 0,
flexShrink: 0,
}}
>
{isImage ? (
<Image
url={file.preview ?? file.filepath}
onDelete={handleDelete}
progress={file.progress}
source={file.source}
/>
) : (
<FileContainer file={file} onDelete={handleDelete} />
)}
</div>
);
})}
</div>
);

@@ -21,7 +21,7 @@ export default function Files({ open, onOpenChange }) {
<OGDialog open={open} onOpenChange={onOpenChange}>
<OGDialogContent
title={localize('com_nav_my_files')}
className="w-11/12 overflow-x-auto bg-background text-text-primary shadow-2xl"
className="w-11/12 bg-background text-text-primary shadow-2xl"
>
<OGDialogHeader>
<OGDialogTitle>{localize('com_nav_my_files')}</OGDialogTitle>

@@ -17,7 +17,7 @@ const Image = ({
}) => {
return (
<div className="group relative inline-block text-sm text-black/70 dark:text-white/90">
<div className="relative overflow-hidden rounded-xl border border-gray-200 dark:border-gray-600">
<div className="relative overflow-hidden rounded-2xl border border-gray-200 dark:border-gray-600">
<ImagePreview source={source} imageBase64={imageBase64} url={url} progress={progress} />
</div>
<RemoveFile onRemove={onDelete} />

@@ -1,3 +1,6 @@
import { useState, useEffect, useCallback } from 'react';
import { Maximize2 } from 'lucide-react';
import { OGDialog, OGDialogContent } from '~/components/ui';
import { FileSources } from 'librechat-data-provider';
import ProgressCircle from './ProgressCircle';
import SourceIcon from './SourceIcon';
@@ -10,67 +13,165 @@ type styleProps = {
backgroundRepeat?: string;
};

interface CloseModalEvent {
stopPropagation: () => void;
preventDefault: () => void;
}

const ImagePreview = ({
imageBase64,
url,
progress = 1,
className = '',
source,
alt = 'Preview image',
}: {
imageBase64?: string;
url?: string;
progress?: number; // between 0 and 1
progress?: number;
className?: string;
source?: FileSources;
alt?: string;
}) => {
let style: styleProps = {
const [isModalOpen, setIsModalOpen] = useState(false);
const [isHovered, setIsHovered] = useState(false);
const [previousActiveElement, setPreviousActiveElement] = useState<Element | null>(null);

const openModal = useCallback(() => {
setPreviousActiveElement(document.activeElement);
setIsModalOpen(true);
}, []);

const closeModal = useCallback(
(e: CloseModalEvent): void => {
setIsModalOpen(false);
e.stopPropagation();
e.preventDefault();

if (
previousActiveElement instanceof HTMLElement &&
!previousActiveElement.closest('[data-skip-refocus="true"]')
) {
previousActiveElement.focus();
}
},
[previousActiveElement],
);

const handleKeyDown = useCallback(
(e: KeyboardEvent) => {
if (e.key === 'Escape') {
closeModal(e);
}
},
[closeModal],
);

useEffect(() => {
if (isModalOpen) {
document.addEventListener('keydown', handleKeyDown);
document.body.style.overflow = 'hidden';
const closeButton = document.querySelector('[aria-label="Close full view"]') as HTMLElement;
if (closeButton) {
setTimeout(() => closeButton.focus(), 0);
}
}

return () => {
document.removeEventListener('keydown', handleKeyDown);
document.body.style.overflow = 'unset';
};
}, [isModalOpen, handleKeyDown]);

const baseStyle: styleProps = {
backgroundSize: 'cover',
backgroundPosition: 'center',
backgroundRepeat: 'no-repeat',
};
if (imageBase64) {
style = {
...style,
backgroundImage: `url(${imageBase64})`,
};
} else if (url) {
style = {
...style,
backgroundImage: `url(${url})`,
};
}

if (!style.backgroundImage) {
const imageUrl = imageBase64 ?? url ?? '';

const style: styleProps = imageUrl
? {
...baseStyle,
backgroundImage: `url(${imageUrl})`,
}
: baseStyle;

if (typeof style.backgroundImage !== 'string' || style.backgroundImage.length === 0) {
return null;
}

const radius = 55; // Radius of the SVG circle
const radius = 55;
const circumference = 2 * Math.PI * radius;

// Calculate the offset based on the loading progress
const offset = circumference - progress * circumference;
const circleCSSProperties = {
transition: 'stroke-dashoffset 0.3s linear',
};

return (
<div className={cn('h-14 w-14', className)}>
<button
type="button"
aria-haspopup="dialog"
aria-expanded="false"
className="h-full w-full"
style={style}
/>
{progress < 1 && (
<ProgressCircle
circumference={circumference}
offset={offset}
circleCSSProperties={circleCSSProperties}
<>
<div
className={cn('relative size-14 rounded-xl', className)}
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
>
<button
type="button"
className="size-full overflow-hidden rounded-xl"
style={style}
aria-label={`View ${alt} in full size`}
aria-haspopup="dialog"
onClick={(e) => {
e.preventDefault();
e.stopPropagation();
openModal();
}}
/>
)}
<SourceIcon source={source} />
</div>
{progress < 1 ? (
<ProgressCircle
circumference={circumference}
offset={offset}
circleCSSProperties={circleCSSProperties}
aria-label={`Loading progress: ${Math.round(progress * 100)}%`}
/>
) : (
<div
className={cn(
'absolute inset-0 flex transform-gpu cursor-pointer items-center justify-center rounded-xl transition-opacity duration-200 ease-in-out',
isHovered ? 'bg-black/20 opacity-100' : 'opacity-0',
)}
onClick={(e) => {
e.stopPropagation();
openModal();
}}
aria-hidden="true"
>
<Maximize2
className={cn(
'size-5 transform-gpu text-white drop-shadow-lg transition-all duration-200',
isHovered ? 'scale-110' : '',
)}
/>
</div>
)}
<SourceIcon source={source} aria-label={source ? `Source: ${source}` : undefined} />
</div>

<OGDialog open={isModalOpen} onOpenChange={setIsModalOpen}>
<OGDialogContent
showCloseButton={false}
className={cn('w-11/12 overflow-x-auto bg-transparent p-0 sm:w-auto')}
disableScroll={false}
>
<img
src={imageUrl}
alt={alt}
className="max-w-screen h-full max-h-screen w-full object-contain"
/>
</OGDialogContent>
</OGDialog>
</>
);
};


@@ -2,7 +2,7 @@ export default function RemoveFile({ onRemove }: { onRemove: () => void }) {
return (
<button
type="button"
className="absolute right-1 top-1 -translate-y-1/2 translate-x-1/2 rounded-full border border-gray-500 bg-gray-500 p-0.5 text-white transition-colors hover:bg-gray-700 hover:opacity-100 group-hover:opacity-100 md:opacity-0"
className="absolute right-1 top-1 -translate-y-1/2 translate-x-1/2 rounded-full bg-surface-secondary p-0.5 transition-colors duration-200 hover:bg-surface-primary z-50"
onClick={onRemove}
>
<span>
Some files were not shown because too many files have changed in this diff