Compare commits
25 Commits: v0.8.1-rc2...dev
| SHA1 |
|---|
| f55bd6f99b |
| 754b495fb8 |
| 2d536dd0fa |
| 711d21365d |
| 8bdc808074 |
| b2387cc6fa |
| 28bdd0dfa6 |
| 1477da4987 |
| ef5540f278 |
| 745c299563 |
| 3b35fa53d9 |
| 01413eea3d |
| 6fa94d3eb8 |
| 4202db1c99 |
| 026890cd27 |
| 6c0aad423f |
| 774ebd1eaa |
| d5d362e52b |
| d7ce19e15a |
| 2ccaf6be6d |
| 90f0bcde44 |
| 801c95a829 |
| 872dbb4151 |
| cb2bee19b7 |
| 961d3b1d3b |
.github/workflows/dev-staging-images.yml (vendored, new file, +66 lines)
```diff
@@ -0,0 +1,66 @@
+name: Docker Dev Staging Images Build
+
+on:
+  workflow_dispatch:
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        include:
+          - target: api-build
+            file: Dockerfile.multi
+            image_name: lc-dev-staging-api
+          - target: node
+            file: Dockerfile
+            image_name: lc-dev-staging
+
+    steps:
+      # Check out the repository
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      # Set up QEMU
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      # Set up Docker Buildx
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      # Log in to GitHub Container Registry
+      - name: Log in to GitHub Container Registry
+        uses: docker/login-action@v2
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      # Login to Docker Hub
+      - name: Login to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+      # Prepare the environment
+      - name: Prepare environment
+        run: |
+          cp .env.example .env
+
+      # Build and push Docker images for each target
+      - name: Build and push Docker images
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          file: ${{ matrix.file }}
+          push: true
+          tags: |
+            ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:${{ github.sha }}
+            ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:latest
+            ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:${{ github.sha }}
+            ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:latest
+          platforms: linux/amd64,linux/arm64
+          target: ${{ matrix.target }}
+
```
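Each matrix entry becomes an independent build job, and every job pushes four tags (two registries, each tagged with the commit SHA and `latest`). Note the two login steps pin different major versions of docker/login-action (v2 for GHCR, v3 for Docker Hub). A quick sketch of how the tag list expands per target; the owner and Docker Hub username are illustrative placeholders, not values from this repository:

```js
// Illustration only: how the build matrix expands into pushed image tags.
const OWNER = 'example-owner'; // placeholder for github.repository_owner
const SHA = 'abc1234'; // placeholder for github.sha
const DOCKERHUB_USER = 'example-user'; // placeholder for secrets.DOCKERHUB_USERNAME

const matrix = [
  { target: 'api-build', file: 'Dockerfile.multi', image_name: 'lc-dev-staging-api' },
  { target: 'node', file: 'Dockerfile', image_name: 'lc-dev-staging' },
];

for (const { image_name } of matrix) {
  const tags = [
    `ghcr.io/${OWNER}/${image_name}:${SHA}`,
    `ghcr.io/${OWNER}/${image_name}:latest`,
    `${DOCKERHUB_USER}/${image_name}:${SHA}`,
    `${DOCKERHUB_USER}/${image_name}:latest`,
  ];
  console.log(tags.join('\n'));
}
```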
```diff
@@ -2,6 +2,7 @@ const crypto = require('crypto');
 const fetch = require('node-fetch');
 const { logger } = require('@librechat/data-schemas');
 const {
+  countTokens,
   getBalanceConfig,
   extractFileContext,
   encodeAndFormatAudios,
@@ -23,7 +24,6 @@ const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require
 const { getStrategyFunctions } = require('~/server/services/Files/strategies');
 const { checkBalance } = require('~/models/balanceMethods');
 const { truncateToolCallOutputs } = require('./prompts');
-const countTokens = require('~/server/utils/countTokens');
 const { getFiles } = require('~/models/File');
 const TextStream = require('./TextStream');
```
```diff
@@ -1,4 +1,5 @@
 const { z } = require('zod');
+const { ProxyAgent, fetch } = require('undici');
 const { tool } = require('@langchain/core/tools');
 const { getApiKey } = require('./credentials');
@@ -19,13 +20,19 @@ function createTavilySearchTool(fields = {}) {
       ...kwargs,
     };

-    const response = await fetch('https://api.tavily.com/search', {
+    const fetchOptions = {
       method: 'POST',
       headers: {
         'Content-Type': 'application/json',
       },
       body: JSON.stringify(requestBody),
-    });
+    };
+
+    if (process.env.PROXY) {
+      fetchOptions.dispatcher = new ProxyAgent(process.env.PROXY);
+    }
+
+    const response = await fetch('https://api.tavily.com/search', fetchOptions);

     const json = await response.json();
     if (!response.ok) {
```
```diff
@@ -1,4 +1,5 @@
 const { z } = require('zod');
+const { ProxyAgent, fetch } = require('undici');
 const { Tool } = require('@langchain/core/tools');
 const { getEnvironmentVariable } = require('@langchain/core/utils/env');
@@ -102,13 +103,19 @@ class TavilySearchResults extends Tool {
       ...this.kwargs,
     };

-    const response = await fetch('https://api.tavily.com/search', {
+    const fetchOptions = {
       method: 'POST',
       headers: {
         'Content-Type': 'application/json',
       },
       body: JSON.stringify(requestBody),
-    });
+    };
+
+    if (process.env.PROXY) {
+      fetchOptions.dispatcher = new ProxyAgent(process.env.PROXY);
+    }
+
+    const response = await fetch('https://api.tavily.com/search', fetchOptions);

     const json = await response.json();
     if (!response.ok) {
```
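Both Tavily hunks apply the same pattern: build the request options first, then attach undici's per-request `dispatcher` only when a `PROXY` environment variable is present. A minimal standalone sketch of the pattern (the endpoint and body shape follow the code above):

```js
// Minimal sketch of the proxy-aware fetch pattern used in both Tavily tools.
const { fetch, ProxyAgent } = require('undici');

async function postJson(url, payload) {
  const fetchOptions = {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(payload),
  };

  if (process.env.PROXY) {
    // A per-request dispatcher routes only this call through the proxy,
    // without mutating undici's global dispatcher.
    fetchOptions.dispatcher = new ProxyAgent(process.env.PROXY);
  }

  return fetch(url, fetchOptions);
}
```

The alternative, `setGlobalDispatcher(new ProxyAgent(...))`, would proxy every undici request in the process; the per-request form keeps the change scoped to the Tavily calls.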
```diff
@@ -1,6 +1,7 @@
+const { fetch, ProxyAgent } = require('undici');
 const TavilySearchResults = require('../TavilySearchResults');

-jest.mock('node-fetch');
+jest.mock('undici');
 jest.mock('@langchain/core/utils/env');

 describe('TavilySearchResults', () => {
@@ -13,6 +14,7 @@ describe('TavilySearchResults', () => {

   beforeEach(() => {
+    jest.resetModules();
     jest.clearAllMocks();
     process.env = {
       ...originalEnv,
       TAVILY_API_KEY: mockApiKey,
@@ -20,7 +22,6 @@ describe('TavilySearchResults', () => {
   });

   afterEach(() => {
-    jest.clearAllMocks();
     process.env = originalEnv;
   });

@@ -35,4 +36,49 @@ describe('TavilySearchResults', () => {
     });
     expect(instance.apiKey).toBe(mockApiKey);
   });
+
+  describe('proxy support', () => {
+    const mockResponse = {
+      ok: true,
+      json: jest.fn().mockResolvedValue({ results: [] }),
+    };
+
+    beforeEach(() => {
+      fetch.mockResolvedValue(mockResponse);
+    });
+
+    it('should use ProxyAgent when PROXY env var is set', async () => {
+      const proxyUrl = 'http://proxy.example.com:8080';
+      process.env.PROXY = proxyUrl;
+
+      const mockProxyAgent = { type: 'proxy-agent' };
+      ProxyAgent.mockImplementation(() => mockProxyAgent);
+
+      const instance = new TavilySearchResults({ TAVILY_API_KEY: mockApiKey });
+      await instance._call({ query: 'test query' });
+
+      expect(ProxyAgent).toHaveBeenCalledWith(proxyUrl);
+      expect(fetch).toHaveBeenCalledWith(
+        'https://api.tavily.com/search',
+        expect.objectContaining({
+          dispatcher: mockProxyAgent,
+        }),
+      );
+    });
+
+    it('should not use ProxyAgent when PROXY env var is not set', async () => {
+      delete process.env.PROXY;
+
+      const instance = new TavilySearchResults({ TAVILY_API_KEY: mockApiKey });
+      await instance._call({ query: 'test query' });
+
+      expect(ProxyAgent).not.toHaveBeenCalled();
+      expect(fetch).toHaveBeenCalledWith(
+        'https://api.tavily.com/search',
+        expect.not.objectContaining({
+          dispatcher: expect.anything(),
+        }),
+      );
+    });
+  });
 });
```
```diff
@@ -1,4 +1,5 @@
 const { ObjectId } = require('mongodb');
+const { escapeRegExp } = require('@librechat/api');
 const { logger } = require('@librechat/data-schemas');
 const {
   Constants,
@@ -14,7 +15,6 @@ const {
 } = require('./Project');
 const { removeAllPermissions } = require('~/server/services/PermissionService');
 const { PromptGroup, Prompt, AclEntry } = require('~/db/models');
-const { escapeRegExp } = require('~/server/utils');

 /**
  * Create a pipeline for the aggregation to get prompt groups
```
```diff
@@ -141,6 +141,7 @@ const tokenValues = Object.assign(
   'command-r': { prompt: 0.5, completion: 1.5 },
   'command-r-plus': { prompt: 3, completion: 15 },
   'command-text': { prompt: 1.5, completion: 2.0 },
+  'deepseek-chat': { prompt: 0.28, completion: 0.42 },
   'deepseek-reasoner': { prompt: 0.28, completion: 0.42 },
   'deepseek-r1': { prompt: 0.4, completion: 2.0 },
   'deepseek-v3': { prompt: 0.2, completion: 0.8 },
@@ -173,6 +174,9 @@ const tokenValues = Object.assign(
   'grok-3-mini': { prompt: 0.3, completion: 0.5 },
   'grok-3-mini-fast': { prompt: 0.6, completion: 4 },
   'grok-4': { prompt: 3.0, completion: 15.0 },
+  'grok-4-fast': { prompt: 0.2, completion: 0.5 },
+  'grok-4-1-fast': { prompt: 0.2, completion: 0.5 }, // covers reasoning & non-reasoning variants
+  'grok-code-fast': { prompt: 0.2, completion: 1.5 },
   codestral: { prompt: 0.3, completion: 0.9 },
   'ministral-3b': { prompt: 0.04, completion: 0.04 },
   'ministral-8b': { prompt: 0.1, completion: 0.1 },
@@ -243,6 +247,10 @@ const cacheTokenValues = {
   'claude-sonnet-4': { write: 3.75, read: 0.3 },
   'claude-opus-4': { write: 18.75, read: 1.5 },
   'claude-opus-4-5': { write: 6.25, read: 0.5 },
+  // DeepSeek models - cache hit: $0.028/1M, cache miss: $0.28/1M
+  deepseek: { write: 0.28, read: 0.028 },
+  'deepseek-chat': { write: 0.28, read: 0.028 },
+  'deepseek-reasoner': { write: 0.28, read: 0.028 },
 };

 /**
```
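For scale, a worked example, assuming the multipliers are USD per 1M tokens (the DeepSeek cache comment above is written in those units):

```js
// Worked example (assumption: multipliers are USD per 1M tokens).
// Cost of a single deepseek-chat request:
const rates = { prompt: 0.28, completion: 0.42 }; // from tokenValues above

const promptTokens = 12000;
const completionTokens = 800;

const cost =
  (promptTokens / 1e6) * rates.prompt + (completionTokens / 1e6) * rates.completion;

console.log(cost.toFixed(6)); // 0.003696 (USD)
```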
```diff
@@ -766,6 +766,78 @@ describe('Deepseek Model Tests', () => {
     const result = tokenValues[valueKey].prompt && multiplier === tokenValues[valueKey].prompt;
     expect(result).toBe(true);
   });
+
+  it('should return correct pricing for deepseek-chat', () => {
+    expect(getMultiplier({ model: 'deepseek-chat', tokenType: 'prompt' })).toBe(
+      tokenValues['deepseek-chat'].prompt,
+    );
+    expect(getMultiplier({ model: 'deepseek-chat', tokenType: 'completion' })).toBe(
+      tokenValues['deepseek-chat'].completion,
+    );
+    expect(tokenValues['deepseek-chat'].prompt).toBe(0.28);
+    expect(tokenValues['deepseek-chat'].completion).toBe(0.42);
+  });
+
+  it('should return correct pricing for deepseek-reasoner', () => {
+    expect(getMultiplier({ model: 'deepseek-reasoner', tokenType: 'prompt' })).toBe(
+      tokenValues['deepseek-reasoner'].prompt,
+    );
+    expect(getMultiplier({ model: 'deepseek-reasoner', tokenType: 'completion' })).toBe(
+      tokenValues['deepseek-reasoner'].completion,
+    );
+    expect(tokenValues['deepseek-reasoner'].prompt).toBe(0.28);
+    expect(tokenValues['deepseek-reasoner'].completion).toBe(0.42);
+  });
+
+  it('should handle DeepSeek model name variations with provider prefixes', () => {
+    const modelVariations = [
+      'deepseek/deepseek-chat',
+      'openrouter/deepseek-chat',
+      'deepseek/deepseek-reasoner',
+    ];
+
+    modelVariations.forEach((model) => {
+      const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' });
+      const completionMultiplier = getMultiplier({ model, tokenType: 'completion' });
+      expect(promptMultiplier).toBe(0.28);
+      expect(completionMultiplier).toBe(0.42);
+    });
+  });
+
+  it('should return correct cache multipliers for DeepSeek models', () => {
+    expect(getCacheMultiplier({ model: 'deepseek-chat', cacheType: 'write' })).toBe(
+      cacheTokenValues['deepseek-chat'].write,
+    );
+    expect(getCacheMultiplier({ model: 'deepseek-chat', cacheType: 'read' })).toBe(
+      cacheTokenValues['deepseek-chat'].read,
+    );
+    expect(getCacheMultiplier({ model: 'deepseek-reasoner', cacheType: 'write' })).toBe(
+      cacheTokenValues['deepseek-reasoner'].write,
+    );
+    expect(getCacheMultiplier({ model: 'deepseek-reasoner', cacheType: 'read' })).toBe(
+      cacheTokenValues['deepseek-reasoner'].read,
+    );
+  });
+
+  it('should return correct cache pricing values for DeepSeek models', () => {
+    expect(cacheTokenValues['deepseek-chat'].write).toBe(0.28);
+    expect(cacheTokenValues['deepseek-chat'].read).toBe(0.028);
+    expect(cacheTokenValues['deepseek-reasoner'].write).toBe(0.28);
+    expect(cacheTokenValues['deepseek-reasoner'].read).toBe(0.028);
+    expect(cacheTokenValues['deepseek'].write).toBe(0.28);
+    expect(cacheTokenValues['deepseek'].read).toBe(0.028);
+  });
+
+  it('should handle DeepSeek cache multipliers with model variations', () => {
+    const modelVariations = ['deepseek/deepseek-chat', 'openrouter/deepseek-reasoner'];
+
+    modelVariations.forEach((model) => {
+      const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' });
+      const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' });
+      expect(writeMultiplier).toBe(0.28);
+      expect(readMultiplier).toBe(0.028);
+    });
+  });
 });

 describe('Qwen3 Model Tests', () => {
```
```diff
@@ -1205,6 +1277,39 @@ describe('Grok Model Tests - Pricing', () => {
     );
   });

+  test('should return correct prompt and completion rates for Grok 4 Fast model', () => {
+    expect(getMultiplier({ model: 'grok-4-fast', tokenType: 'prompt' })).toBe(
+      tokenValues['grok-4-fast'].prompt,
+    );
+    expect(getMultiplier({ model: 'grok-4-fast', tokenType: 'completion' })).toBe(
+      tokenValues['grok-4-fast'].completion,
+    );
+  });
+
+  test('should return correct prompt and completion rates for Grok 4.1 Fast models', () => {
+    expect(getMultiplier({ model: 'grok-4-1-fast-reasoning', tokenType: 'prompt' })).toBe(
+      tokenValues['grok-4-1-fast'].prompt,
+    );
+    expect(getMultiplier({ model: 'grok-4-1-fast-reasoning', tokenType: 'completion' })).toBe(
+      tokenValues['grok-4-1-fast'].completion,
+    );
+    expect(getMultiplier({ model: 'grok-4-1-fast-non-reasoning', tokenType: 'prompt' })).toBe(
+      tokenValues['grok-4-1-fast'].prompt,
+    );
+    expect(getMultiplier({ model: 'grok-4-1-fast-non-reasoning', tokenType: 'completion' })).toBe(
+      tokenValues['grok-4-1-fast'].completion,
+    );
+  });
+
+  test('should return correct prompt and completion rates for Grok Code Fast model', () => {
+    expect(getMultiplier({ model: 'grok-code-fast-1', tokenType: 'prompt' })).toBe(
+      tokenValues['grok-code-fast'].prompt,
+    );
+    expect(getMultiplier({ model: 'grok-code-fast-1', tokenType: 'completion' })).toBe(
+      tokenValues['grok-code-fast'].completion,
+    );
+  });
+
   test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => {
     expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe(
       tokenValues['grok-3'].prompt,
@@ -1240,6 +1345,39 @@ describe('Grok Model Tests - Pricing', () => {
       tokenValues['grok-4'].completion,
     );
   });
+
+  test('should return correct prompt and completion rates for Grok 4 Fast model with prefixes', () => {
+    expect(getMultiplier({ model: 'xai/grok-4-fast', tokenType: 'prompt' })).toBe(
+      tokenValues['grok-4-fast'].prompt,
+    );
+    expect(getMultiplier({ model: 'xai/grok-4-fast', tokenType: 'completion' })).toBe(
+      tokenValues['grok-4-fast'].completion,
+    );
+  });
+
+  test('should return correct prompt and completion rates for Grok 4.1 Fast models with prefixes', () => {
+    expect(getMultiplier({ model: 'xai/grok-4-1-fast-reasoning', tokenType: 'prompt' })).toBe(
+      tokenValues['grok-4-1-fast'].prompt,
+    );
+    expect(getMultiplier({ model: 'xai/grok-4-1-fast-reasoning', tokenType: 'completion' })).toBe(
+      tokenValues['grok-4-1-fast'].completion,
+    );
+    expect(getMultiplier({ model: 'xai/grok-4-1-fast-non-reasoning', tokenType: 'prompt' })).toBe(
+      tokenValues['grok-4-1-fast'].prompt,
+    );
+    expect(
+      getMultiplier({ model: 'xai/grok-4-1-fast-non-reasoning', tokenType: 'completion' }),
+    ).toBe(tokenValues['grok-4-1-fast'].completion);
+  });
+
+  test('should return correct prompt and completion rates for Grok Code Fast model with prefixes', () => {
+    expect(getMultiplier({ model: 'xai/grok-code-fast-1', tokenType: 'prompt' })).toBe(
+      tokenValues['grok-code-fast'].prompt,
+    );
+    expect(getMultiplier({ model: 'xai/grok-code-fast-1', tokenType: 'completion' })).toBe(
+      tokenValues['grok-code-fast'].completion,
+    );
+  });
 });
});
```
```diff
@@ -47,7 +47,7 @@
     "@langchain/google-genai": "^0.2.13",
     "@langchain/google-vertexai": "^0.2.13",
     "@langchain/textsplitters": "^0.1.0",
-    "@librechat/agents": "^3.0.32",
+    "@librechat/agents": "^3.0.36",
     "@librechat/api": "*",
     "@librechat/data-schemas": "*",
     "@microsoft/microsoft-graph-client": "^3.0.7",
@@ -92,7 +92,7 @@
     "multer": "^2.0.2",
     "nanoid": "^3.3.7",
     "node-fetch": "^2.7.0",
-    "nodemailer": "^7.0.9",
+    "nodemailer": "^7.0.11",
     "ollama": "^0.5.0",
     "openai": "5.8.2",
     "openid-client": "^6.5.0",
```
```diff
@@ -1,6 +1,10 @@
-const { sendEvent } = require('@librechat/api');
 const { logger } = require('@librechat/data-schemas');
 const { Constants } = require('librechat-data-provider');
+const {
+  sendEvent,
+  sanitizeFileForTransmit,
+  sanitizeMessageForTransmit,
+} = require('@librechat/api');
 const {
   handleAbortError,
   createAbortController,
@@ -224,13 +228,13 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
     conversation.title =
       conversation && !conversation.title ? null : conversation?.title || 'New Chat';

-    // Process files if needed
+    // Process files if needed (sanitize to remove large text fields before transmission)
     if (req.body.files && client.options?.attachments) {
       userMessage.files = [];
       const messageFiles = new Set(req.body.files.map((file) => file.file_id));
-      for (let attachment of client.options.attachments) {
+      for (const attachment of client.options.attachments) {
         if (messageFiles.has(attachment.file_id)) {
-          userMessage.files.push({ ...attachment });
+          userMessage.files.push(sanitizeFileForTransmit(attachment));
         }
       }
       delete userMessage.image_urls;
@@ -245,7 +249,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
       final: true,
       conversation,
       title: conversation.title,
-      requestMessage: userMessage,
+      requestMessage: sanitizeMessageForTransmit(userMessage),
       responseMessage: finalResponse,
     });
     res.end();
@@ -273,7 +277,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
       final: true,
       conversation,
       title: conversation.title,
-      requestMessage: userMessage,
+      requestMessage: sanitizeMessageForTransmit(userMessage),
       responseMessage: finalResponse,
       error: { message: 'Request was aborted during completion' },
     });
```
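`sanitizeFileForTransmit` and `sanitizeMessageForTransmit` ship in `@librechat/api` and are not shown in this compare; judging by the updated comment, they strip large text fields before the SSE payload goes out. A hypothetical sketch of the idea only; the dropped field names are illustrative assumptions, not the real implementation:

```js
// Hypothetical sketch only. The actual helpers live in @librechat/api.
function sanitizeFileForTransmit(file) {
  // Assumption: drop bulky extracted-text payloads, keep file metadata.
  const { text, ...transmittable } = file;
  return transmittable;
}
```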
```diff
@@ -1,7 +1,7 @@
 const { v4 } = require('uuid');
 const { sleep } = require('@librechat/agents');
 const { logger } = require('@librechat/data-schemas');
-const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
+const { sendEvent, getBalanceConfig, getModelMaxTokens, countTokens } = require('@librechat/api');
 const {
   Time,
   Constants,
@@ -33,7 +33,6 @@ const { getTransactions } = require('~/models/Transaction');
 const { checkBalance } = require('~/models/balanceMethods');
 const { getConvo } = require('~/models/Conversation');
 const getLogStores = require('~/cache/getLogStores');
-const { countTokens } = require('~/server/utils');
 const { getOpenAIClient } = require('./helpers');

 /**
```
```diff
@@ -1,7 +1,7 @@
 const { v4 } = require('uuid');
 const { sleep } = require('@librechat/agents');
 const { logger } = require('@librechat/data-schemas');
-const { sendEvent, getBalanceConfig, getModelMaxTokens } = require('@librechat/api');
+const { sendEvent, getBalanceConfig, getModelMaxTokens, countTokens } = require('@librechat/api');
 const {
   Time,
   Constants,
@@ -30,7 +30,6 @@ const { getTransactions } = require('~/models/Transaction');
 const { checkBalance } = require('~/models/balanceMethods');
 const { getConvo } = require('~/models/Conversation');
 const getLogStores = require('~/cache/getLogStores');
-const { countTokens } = require('~/server/utils');
 const { getOpenAIClient } = require('./helpers');

 /**
```
```diff
@@ -16,6 +16,7 @@ const {
   isEnabled,
   ErrorController,
   performStartupChecks,
+  handleJsonParseError,
   initializeFileStorage,
 } = require('@librechat/api');
 const { connectDb, indexSync } = require('~/db');
@@ -245,6 +246,7 @@ if (cluster.isMaster) {
   app.use(noIndex);
   app.use(express.json({ limit: '3mb' }));
   app.use(express.urlencoded({ extended: true, limit: '3mb' }));
+  app.use(handleJsonParseError);
   app.use(mongoSanitize());
   app.use(cors());
   app.use(cookieParser());
@@ -290,7 +292,6 @@ if (cluster.isMaster) {
   app.use('/api/presets', routes.presets);
   app.use('/api/prompts', routes.prompts);
   app.use('/api/categories', routes.categories);
-  app.use('/api/tokenizer', routes.tokenizer);
   app.use('/api/endpoints', routes.endpoints);
   app.use('/api/balance', routes.balance);
   app.use('/api/models', routes.models);
```
```diff
@@ -14,6 +14,7 @@ const {
   isEnabled,
   ErrorController,
   performStartupChecks,
+  handleJsonParseError,
   initializeFileStorage,
 } = require('@librechat/api');
 const { connectDb, indexSync } = require('~/db');
@@ -81,6 +82,7 @@ const startServer = async () => {
   app.use(noIndex);
   app.use(express.json({ limit: '3mb' }));
   app.use(express.urlencoded({ extended: true, limit: '3mb' }));
+  app.use(handleJsonParseError);
   app.use(mongoSanitize());
   app.use(cors());
   app.use(cookieParser());
@@ -126,7 +128,6 @@ const startServer = async () => {
   app.use('/api/presets', routes.presets);
   app.use('/api/prompts', routes.prompts);
   app.use('/api/categories', routes.categories);
-  app.use('/api/tokenizer', routes.tokenizer);
   app.use('/api/endpoints', routes.endpoints);
   app.use('/api/balance', routes.balance);
   app.use('/api/models', routes.models);
```
```diff
@@ -1,5 +1,5 @@
 const { logger } = require('@librechat/data-schemas');
-const { countTokens, isEnabled, sendEvent } = require('@librechat/api');
+const { countTokens, isEnabled, sendEvent, sanitizeMessageForTransmit } = require('@librechat/api');
 const { isAssistantsEndpoint, ErrorTypes, Constants } = require('librechat-data-provider');
 const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
 const clearPendingReq = require('~/cache/clearPendingReq');
@@ -290,7 +290,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
       title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
       final: true,
       conversation,
-      requestMessage: userMessage,
+      requestMessage: sanitizeMessageForTransmit(userMessage),
       responseMessage: responseMessage,
     };
   };
```
```diff
@@ -61,18 +61,24 @@ async function buildEndpointOption(req, res, next) {

     try {
       currentModelSpec.preset.spec = spec;
-      if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') {
-        currentModelSpec.preset.iconURL = currentModelSpec.iconURL;
-      }
       parsedBody = parseCompactConvo({
         endpoint,
         endpointType,
         conversation: currentModelSpec.preset,
       });
+      if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') {
+        parsedBody.iconURL = currentModelSpec.iconURL;
+      }
     } catch (error) {
       logger.error(`Error parsing model spec for endpoint ${endpoint}`, error);
       return handleError(res, { text: 'Error parsing model spec' });
     }
-  }
+  } else if (parsedBody.spec && appConfig.modelSpecs?.list) {
+    // Non-enforced mode: if spec is selected, derive iconURL from model spec
+    const modelSpec = appConfig.modelSpecs.list.find((s) => s.name === parsedBody.spec);
+    if (modelSpec?.iconURL) {
+      parsedBody.iconURL = modelSpec.iconURL;
+    }
+  }

   try {
```
```diff
@@ -1,7 +1,7 @@
 const crypto = require('crypto');
 const { logger } = require('@librechat/data-schemas');
 const { parseConvo } = require('librechat-data-provider');
-const { sendEvent, handleError } = require('@librechat/api');
+const { sendEvent, handleError, sanitizeMessageForTransmit } = require('@librechat/api');
 const { saveMessage, getMessages } = require('~/models/Message');
 const { getConvo } = require('~/models/Conversation');
@@ -71,7 +71,7 @@ const sendError = async (req, res, options, callback) => {

   return sendEvent(res, {
     final: true,
-    requestMessage: query?.[0] ? query[0] : requestMessage,
+    requestMessage: sanitizeMessageForTransmit(query?.[0] ?? requestMessage),
     responseMessage: errorMessage,
     conversation: convo,
   });
```
```diff
@@ -1,7 +1,6 @@
 const accessPermissions = require('./accessPermissions');
 const assistants = require('./assistants');
 const categories = require('./categories');
-const tokenizer = require('./tokenizer');
 const endpoints = require('./endpoints');
 const staticRoute = require('./static');
 const messages = require('./messages');
@@ -53,7 +52,6 @@ module.exports = {
   messages,
   memories,
   endpoints,
-  tokenizer,
   assistants,
   categories,
   staticRoute,
```
```diff
@@ -1,7 +1,7 @@
 const express = require('express');
-const { unescapeLaTeX } = require('@librechat/api');
 const { logger } = require('@librechat/data-schemas');
 const { ContentTypes } = require('librechat-data-provider');
+const { unescapeLaTeX, countTokens } = require('@librechat/api');
 const {
   saveConvo,
   getMessage,
@@ -14,7 +14,6 @@ const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/
 const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
 const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
 const { getConvosQueried } = require('~/models/Conversation');
-const { countTokens } = require('~/server/utils');
 const { Message } = require('~/db/models');

 const router = express.Router();
```
```diff
@@ -5,6 +5,7 @@ const {
   markPublicPromptGroups,
   buildPromptGroupFilter,
   formatPromptGroupsResponse,
+  safeValidatePromptGroupUpdate,
   createEmptyPromptGroupsResponse,
   filterAccessibleIdsBySharedLogic,
 } = require('@librechat/api');
@@ -344,7 +345,16 @@ const patchPromptGroup = async (req, res) => {
     if (req.user.role === SystemRoles.ADMIN) {
       delete filter.author;
     }
-    const promptGroup = await updatePromptGroup(filter, req.body);
+
+    const validationResult = safeValidatePromptGroupUpdate(req.body);
+    if (!validationResult.success) {
+      return res.status(400).send({
+        error: 'Invalid request body',
+        details: validationResult.error.errors,
+      });
+    }
+
+    const promptGroup = await updatePromptGroup(filter, validationResult.data);
     res.status(200).send(promptGroup);
   } catch (error) {
     logger.error(error);
```
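`safeValidatePromptGroupUpdate` lives in `@librechat/api` and is not shown in this compare. From the tests below, its contract is clear: accept `name`/`category`/`oneliner`, reject `author`, `authorName`, `_id`, `productionId`, `createdAt`, and `__v` with structured errors. A plausible sketch using zod's `.strict()` mode; the field list is inferred from the tests, not copied from the real schema:

```js
// Hypothetical sketch of safeValidatePromptGroupUpdate (the real helper is
// in @librechat/api). `.strict()` makes zod reject unknown keys, so
// author/_id/productionId/createdAt/__v all fail validation.
const { z } = require('zod');

const promptGroupUpdateSchema = z
  .object({
    name: z.string().min(1).optional(),
    category: z.string().optional(),
    oneliner: z.string().optional(),
  })
  .strict();

function safeValidatePromptGroupUpdate(body) {
  // safeParse returns { success: true, data } or { success: false, error }
  return promptGroupUpdateSchema.safeParse(body);
}
```

Passing `validationResult.data` (rather than `req.body`) to `updatePromptGroup` guarantees only whitelisted fields ever reach the database update.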
```diff
@@ -544,6 +544,169 @@ describe('Prompt Routes - ACL Permissions', () => {
     });
   });

+  describe('PATCH /api/prompts/groups/:groupId - Update Prompt Group Security', () => {
+    let testGroup;
+
+    beforeEach(async () => {
+      // Create a prompt group
+      testGroup = await PromptGroup.create({
+        name: 'Security Test Group',
+        category: 'security-test',
+        author: testUsers.owner._id,
+        authorName: testUsers.owner.name,
+        productionId: new ObjectId(),
+      });
+
+      // Grant owner permissions
+      await grantPermission({
+        principalType: PrincipalType.USER,
+        principalId: testUsers.owner._id,
+        resourceType: ResourceType.PROMPTGROUP,
+        resourceId: testGroup._id,
+        accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
+        grantedBy: testUsers.owner._id,
+      });
+    });
+
+    afterEach(async () => {
+      await PromptGroup.deleteMany({});
+      await AclEntry.deleteMany({});
+    });
+
+    it('should allow updating allowed fields (name, category, oneliner)', async () => {
+      const updateData = {
+        name: 'Updated Group Name',
+        category: 'updated-category',
+        oneliner: 'Updated description',
+      };
+
+      const response = await request(app)
+        .patch(`/api/prompts/groups/${testGroup._id}`)
+        .send(updateData)
+        .expect(200);
+
+      expect(response.body.name).toBe(updateData.name);
+      expect(response.body.category).toBe(updateData.category);
+      expect(response.body.oneliner).toBe(updateData.oneliner);
+    });
+
+    it('should reject request with author field (400 Bad Request)', async () => {
+      const maliciousUpdate = {
+        name: 'Legit Update',
+        author: testUsers.noAccess._id.toString(), // Try to change ownership
+      };
+
+      const response = await request(app)
+        .patch(`/api/prompts/groups/${testGroup._id}`)
+        .send(maliciousUpdate)
+        .expect(400);
+
+      // Verify the request was rejected
+      expect(response.body.error).toBe('Invalid request body');
+      expect(response.body.details).toBeDefined();
+    });
+
+    it('should reject request with authorName field (400 Bad Request)', async () => {
+      const maliciousUpdate = {
+        name: 'Legit Update',
+        authorName: 'Malicious Author Name',
+      };
+
+      const response = await request(app)
+        .patch(`/api/prompts/groups/${testGroup._id}`)
+        .send(maliciousUpdate)
+        .expect(400);
+
+      // Verify the request was rejected
+      expect(response.body.error).toBe('Invalid request body');
+    });
+
+    it('should reject request with _id field (400 Bad Request)', async () => {
+      const newId = new ObjectId();
+      const maliciousUpdate = {
+        name: 'Legit Update',
+        _id: newId.toString(),
+      };
+
+      const response = await request(app)
+        .patch(`/api/prompts/groups/${testGroup._id}`)
+        .send(maliciousUpdate)
+        .expect(400);
+
+      // Verify the request was rejected
+      expect(response.body.error).toBe('Invalid request body');
+    });
+
+    it('should reject request with productionId field (400 Bad Request)', async () => {
+      const newProductionId = new ObjectId();
+      const maliciousUpdate = {
+        name: 'Legit Update',
+        productionId: newProductionId.toString(),
+      };
+
+      const response = await request(app)
+        .patch(`/api/prompts/groups/${testGroup._id}`)
+        .send(maliciousUpdate)
+        .expect(400);
+
+      // Verify the request was rejected
+      expect(response.body.error).toBe('Invalid request body');
+    });
+
+    it('should reject request with createdAt field (400 Bad Request)', async () => {
+      const maliciousDate = new Date('2020-01-01');
+      const maliciousUpdate = {
+        name: 'Legit Update',
+        createdAt: maliciousDate.toISOString(),
+      };
+
+      const response = await request(app)
+        .patch(`/api/prompts/groups/${testGroup._id}`)
+        .send(maliciousUpdate)
+        .expect(400);
+
+      // Verify the request was rejected
+      expect(response.body.error).toBe('Invalid request body');
+    });
+
+    it('should reject request with __v field (400 Bad Request)', async () => {
+      const maliciousUpdate = {
+        name: 'Legit Update',
+        __v: 999,
+      };
+
+      const response = await request(app)
+        .patch(`/api/prompts/groups/${testGroup._id}`)
+        .send(maliciousUpdate)
+        .expect(400);
+
+      // Verify the request was rejected
+      expect(response.body.error).toBe('Invalid request body');
+    });
+
+    it('should reject request with multiple sensitive fields (400 Bad Request)', async () => {
+      const maliciousUpdate = {
+        name: 'Legit Update',
+        author: testUsers.noAccess._id.toString(),
+        authorName: 'Hacker',
+        _id: new ObjectId().toString(),
+        productionId: new ObjectId().toString(),
+        createdAt: new Date('2020-01-01').toISOString(),
+        __v: 999,
+      };
+
+      const response = await request(app)
+        .patch(`/api/prompts/groups/${testGroup._id}`)
+        .send(maliciousUpdate)
+        .expect(400);
+
+      // Verify the request was rejected with validation errors
+      expect(response.body.error).toBe('Invalid request body');
+      expect(response.body.details).toBeDefined();
+      expect(Array.isArray(response.body.details)).toBe(true);
+    });
+  });
+
   describe('Pagination', () => {
     beforeEach(async () => {
       // Create multiple prompt groups for pagination testing
```
```diff
@@ -1,19 +0,0 @@
-const express = require('express');
-const { logger } = require('@librechat/data-schemas');
-const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
-const { countTokens } = require('~/server/utils');
-
-const router = express.Router();
-
-router.post('/', requireJwtAuth, async (req, res) => {
-  try {
-    const { arg } = req.body;
-    const count = await countTokens(arg?.text ?? arg);
-    res.send({ count });
-  } catch (e) {
-    logger.error('[/tokenizer] Error counting tokens', e);
-    res.status(500).json('Error counting tokens');
-  }
-});
-
-module.exports = router;
```
```diff
@@ -1,5 +1,6 @@
 const OpenAI = require('openai');
 const { ProxyAgent } = require('undici');
+const { isUserProvided } = require('@librechat/api');
 const { ErrorTypes, EModelEndpoint } = require('librechat-data-provider');
 const {
   getUserKeyValues,
@@ -7,7 +8,6 @@ const {
   checkUserKeyExpiry,
 } = require('~/server/services/UserService');
 const OAIClient = require('~/app/clients/OpenAIClient');
-const { isUserProvided } = require('~/server/utils');

 const initializeClient = async ({ req, res, endpointOption, version, initAppClient = false }) => {
   const { PROXY, OPENAI_ORGANIZATION, ASSISTANTS_API_KEY, ASSISTANTS_BASE_URL } = process.env;
```
```diff
@@ -12,14 +12,13 @@ const initGoogle = require('~/server/services/Endpoints/google/initialize');
  * @returns {boolean} - True if the provider is a known custom provider, false otherwise
  */
 function isKnownCustomProvider(provider) {
-  return [Providers.XAI, Providers.OLLAMA, Providers.DEEPSEEK, Providers.OPENROUTER].includes(
+  return [Providers.XAI, Providers.DEEPSEEK, Providers.OPENROUTER].includes(
     provider?.toLowerCase() || '',
   );
 }

 const providerConfigMap = {
   [Providers.XAI]: initCustom,
-  [Providers.OLLAMA]: initCustom,
   [Providers.DEEPSEEK]: initCustom,
   [Providers.OPENROUTER]: initCustom,
   [EModelEndpoint.openAI]: initOpenAI,
```
```diff
@@ -1,11 +1,14 @@
 const axios = require('axios');
-const { Providers } = require('@librechat/agents');
 const { logger } = require('@librechat/data-schemas');
 const { HttpsProxyAgent } = require('https-proxy-agent');
-const { logAxiosError, inputSchema, processModelData } = require('@librechat/api');
-const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider');
+const { logAxiosError, inputSchema, processModelData, isUserProvided } = require('@librechat/api');
+const {
+  CacheKeys,
+  defaultModels,
+  KnownEndpoints,
+  EModelEndpoint,
+} = require('librechat-data-provider');
 const { OllamaClient } = require('~/app/clients/OllamaClient');
-const { isUserProvided } = require('~/server/utils');
 const getLogStores = require('~/cache/getLogStores');
 const { extractBaseURL } = require('~/utils');
@@ -68,7 +71,7 @@ const fetchModels = async ({
     return models;
   }

-  if (name && name.toLowerCase().startsWith(Providers.OLLAMA)) {
+  if (name && name.toLowerCase().startsWith(KnownEndpoints.ollama)) {
     try {
       return await OllamaClient.fetchModels(baseURL, { headers, user: userObject });
     } catch (ollamaError) {
@@ -103,7 +106,7 @@ const fetchModels = async ({
     options.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION;
   }

-  const url = new URL(`${baseURL}${azure ? '' : '/models'}`);
+  const url = new URL(`${baseURL.replace(/\/+$/, '')}${azure ? '' : '/models'}`);
   if (user && userIdQuery) {
     url.searchParams.append('user', user);
   }
```
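The trailing-slash strip matters because template-concatenating a path onto a baseURL that already ends in `/` produces a double slash, and `new URL()` does not normalize it away:

```js
// Why `baseURL.replace(/\/+$/, '')` is needed before appending '/models':
const baseURL = 'https://api.test.com/v1/';

console.log(new URL(`${baseURL}/models`).href);
// => https://api.test.com/v1//models  (double slash preserved in the path)

console.log(new URL(`${baseURL.replace(/\/+$/, '')}/models`).href);
// => https://api.test.com/v1/models
```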
```diff
@@ -436,6 +436,68 @@ describe('fetchModels with Ollama specific logic', () => {
   });
 });

+describe('fetchModels URL construction with trailing slashes', () => {
+  beforeEach(() => {
+    axios.get.mockResolvedValue({
+      data: {
+        data: [{ id: 'model-1' }, { id: 'model-2' }],
+      },
+    });
+  });
+
+  afterEach(() => {
+    jest.clearAllMocks();
+  });
+
+  it('should not create double slashes when baseURL has a trailing slash', async () => {
+    await fetchModels({
+      user: 'user123',
+      apiKey: 'testApiKey',
+      baseURL: 'https://api.test.com/v1/',
+      name: 'TestAPI',
+    });
+
+    expect(axios.get).toHaveBeenCalledWith('https://api.test.com/v1/models', expect.any(Object));
+  });
+
+  it('should handle baseURL without trailing slash normally', async () => {
+    await fetchModels({
+      user: 'user123',
+      apiKey: 'testApiKey',
+      baseURL: 'https://api.test.com/v1',
+      name: 'TestAPI',
+    });
+
+    expect(axios.get).toHaveBeenCalledWith('https://api.test.com/v1/models', expect.any(Object));
+  });
+
+  it('should handle baseURL with multiple trailing slashes', async () => {
+    await fetchModels({
+      user: 'user123',
+      apiKey: 'testApiKey',
+      baseURL: 'https://api.test.com/v1///',
+      name: 'TestAPI',
+    });
+
+    expect(axios.get).toHaveBeenCalledWith('https://api.test.com/v1/models', expect.any(Object));
+  });
+
+  it('should correctly append query params after stripping trailing slashes', async () => {
+    await fetchModels({
+      user: 'user123',
+      apiKey: 'testApiKey',
+      baseURL: 'https://api.test.com/v1/',
+      name: 'TestAPI',
+      userIdQuery: true,
+    });
+
+    expect(axios.get).toHaveBeenCalledWith(
+      'https://api.test.com/v1/models?user=user123',
+      expect.any(Object),
+    );
+  });
+});
+
 describe('splitAndTrim', () => {
   it('should split a string by commas and trim each value', () => {
     const input = ' model1, model2 , model3,model4 ';
```
```diff
@@ -292,7 +292,7 @@ const ensurePrincipalExists = async function (principal) {
   let existingUser = await findUser({ idOnTheSource: principal.idOnTheSource });

   if (!existingUser) {
-    existingUser = await findUser({ email: principal.email.toLowerCase() });
+    existingUser = await findUser({ email: principal.email });
   }

   if (existingUser) {
```
```diff
@@ -1,5 +1,6 @@
 const path = require('path');
 const { v4 } = require('uuid');
+const { countTokens, escapeRegExp } = require('@librechat/api');
 const {
   Constants,
   ContentTypes,
@@ -8,7 +9,6 @@ const {
 } = require('librechat-data-provider');
 const { retrieveAndProcessFile } = require('~/server/services/Files/process');
 const { recordMessage, getMessages } = require('~/models/Message');
-const { countTokens, escapeRegExp } = require('~/server/utils');
 const { spendTokens } = require('~/models/spendTokens');
 const { saveConvo } = require('~/models/Conversation');
```
```diff
@@ -1,37 +0,0 @@
-const { Tiktoken } = require('tiktoken/lite');
-const { logger } = require('@librechat/data-schemas');
-const p50k_base = require('tiktoken/encoders/p50k_base.json');
-const cl100k_base = require('tiktoken/encoders/cl100k_base.json');
-
-/**
- * Counts the number of tokens in a given text using a specified encoding model.
- *
- * This function utilizes the 'Tiktoken' library to encode text based on the selected model.
- * It supports two models, 'text-davinci-003' and 'gpt-3.5-turbo', each with its own encoding strategy.
- * For 'text-davinci-003', the 'p50k_base' encoder is used, whereas for other models, the 'cl100k_base' encoder is applied.
- * In case of an error during encoding, the error is logged, and the function returns 0.
- *
- * @async
- * @param {string} text - The text to be tokenized. Defaults to an empty string if not provided.
- * @param {string} modelName - The name of the model used for tokenizing. Defaults to 'gpt-3.5-turbo'.
- * @returns {Promise<number>} The number of tokens in the provided text. Returns 0 if an error occurs.
- * @throws Logs the error to a logger and rethrows if any error occurs during tokenization.
- */
-const countTokens = async (text = '', modelName = 'gpt-3.5-turbo') => {
-  let encoder = null;
-  try {
-    const model = modelName.includes('text-davinci-003') ? p50k_base : cl100k_base;
-    encoder = new Tiktoken(model.bpe_ranks, model.special_tokens, model.pat_str);
-    const tokens = encoder.encode(text);
-    encoder.free();
-    return tokens.length;
-  } catch (e) {
-    logger.error('[countTokens]', e);
-    if (encoder) {
-      encoder.free();
-    }
-    return 0;
-  }
-};
-
-module.exports = countTokens;
```
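This local helper is deleted because `countTokens` now ships from `@librechat/api`, as the import changes throughout this compare show. Usage is unchanged at call sites; a quick sketch, assuming the package export keeps the same signature as the deleted helper above:

```js
// Usage sketch: countTokens is now imported from @librechat/api
// (assumption: same (text, modelName) signature as the deleted local helper).
const { countTokens } = require('@librechat/api');

(async () => {
  const n = await countTokens('Hello, world!', 'gpt-3.5-turbo');
  console.log(`token count: ${n}`);
})();
```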
```diff
@@ -10,14 +10,6 @@ const {
 const { sendEvent } = require('@librechat/api');
 const partialRight = require('lodash/partialRight');

-/** Helper function to escape special characters in regex
- * @param {string} string - The string to escape.
- * @returns {string} The escaped string.
- */
-function escapeRegExp(string) {
-  return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
-}
-
 const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text);

 const base = { message: true, initial: true };
@@ -181,7 +173,6 @@ function generateConfig(key, baseURL, endpoint) {
 module.exports = {
   handleText,
   formatSteps,
-  escapeRegExp,
   formatAction,
   isUserProvided,
   generateConfig,
```
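Like `countTokens`, `escapeRegExp` moves into `@librechat/api` (see the Prompt.js import change earlier in this compare). The behavior stays the same, backslash-escaping every regex metacharacter so arbitrary strings can be embedded in patterns safely:

```js
// escapeRegExp now comes from @librechat/api (assumption: same behavior as
// the removed local helper above).
const { escapeRegExp } = require('@librechat/api');

const userInput = 'price (USD) $5.99?';
const pattern = new RegExp(escapeRegExp(userInput), 'i');
console.log(pattern.test('Price (USD) $5.99?')); // true: metacharacters matched literally
```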
```diff
@@ -1,5 +1,4 @@
 const removePorts = require('./removePorts');
-const countTokens = require('./countTokens');
 const handleText = require('./handleText');
 const sendEmail = require('./sendEmail');
 const queue = require('./queue');
@@ -7,7 +6,6 @@ const files = require('./files');

 module.exports = {
   ...handleText,
-  countTokens,
   removePorts,
   sendEmail,
   ...files,
```
```diff
@@ -172,6 +172,7 @@ describe('socialLogin', () => {

     /** Verify both searches happened */
     expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId });
+    /** Email passed as-is; findUser implementation handles case normalization */
     expect(findUser).toHaveBeenNthCalledWith(2, { email: email });
     expect(findUser).toHaveBeenCalledTimes(2);
```
```diff
@@ -665,7 +665,7 @@ describe('Meta Models Tests', () => {
   test('should match Deepseek model variations', () => {
     expect(getModelMaxTokens('deepseek-chat')).toBe(
-      maxTokensMap[EModelEndpoint.openAI]['deepseek'],
+      maxTokensMap[EModelEndpoint.openAI]['deepseek-chat'],
     );
     expect(getModelMaxTokens('deepseek-coder')).toBe(
       maxTokensMap[EModelEndpoint.openAI]['deepseek'],
@@ -677,6 +677,20 @@ describe('Meta Models Tests', () => {
       maxTokensMap[EModelEndpoint.openAI]['deepseek.r1'],
     );
   });
+
+  test('should return 128000 context tokens for all DeepSeek models', () => {
+    expect(getModelMaxTokens('deepseek-chat')).toBe(128000);
+    expect(getModelMaxTokens('deepseek-reasoner')).toBe(128000);
+    expect(getModelMaxTokens('deepseek-r1')).toBe(128000);
+    expect(getModelMaxTokens('deepseek-v3')).toBe(128000);
+    expect(getModelMaxTokens('deepseek.r1')).toBe(128000);
+  });
+
+  test('should handle DeepSeek models with provider prefixes', () => {
+    expect(getModelMaxTokens('deepseek/deepseek-chat')).toBe(128000);
+    expect(getModelMaxTokens('openrouter/deepseek-reasoner')).toBe(128000);
+    expect(getModelMaxTokens('openai/deepseek-v3')).toBe(128000);
+  });
 });

 describe('matchModelName', () => {
@@ -705,11 +719,42 @@ describe('Meta Models Tests', () => {
   });

   test('should match Deepseek model variations', () => {
-    expect(matchModelName('deepseek-chat')).toBe('deepseek');
+    expect(matchModelName('deepseek-chat')).toBe('deepseek-chat');
     expect(matchModelName('deepseek-coder')).toBe('deepseek');
   });
 });

+describe('DeepSeek Max Output Tokens', () => {
+  const { getModelMaxOutputTokens } = require('@librechat/api');
+
+  test('should return correct max output tokens for deepseek-chat', () => {
+    expect(getModelMaxOutputTokens('deepseek-chat')).toBe(8000);
+    expect(getModelMaxOutputTokens('deepseek-chat', EModelEndpoint.openAI)).toBe(8000);
+    expect(getModelMaxOutputTokens('deepseek-chat', EModelEndpoint.custom)).toBe(8000);
+  });
+
+  test('should return correct max output tokens for deepseek-reasoner', () => {
+    expect(getModelMaxOutputTokens('deepseek-reasoner')).toBe(64000);
+    expect(getModelMaxOutputTokens('deepseek-reasoner', EModelEndpoint.openAI)).toBe(64000);
+    expect(getModelMaxOutputTokens('deepseek-reasoner', EModelEndpoint.custom)).toBe(64000);
+  });
+
+  test('should return correct max output tokens for deepseek-r1', () => {
+    expect(getModelMaxOutputTokens('deepseek-r1')).toBe(64000);
+    expect(getModelMaxOutputTokens('deepseek-r1', EModelEndpoint.openAI)).toBe(64000);
+  });
+
+  test('should return correct max output tokens for deepseek base pattern', () => {
+    expect(getModelMaxOutputTokens('deepseek')).toBe(8000);
+    expect(getModelMaxOutputTokens('deepseek-v3')).toBe(8000);
+  });
+
+  test('should handle DeepSeek models with provider prefixes for max output tokens', () => {
+    expect(getModelMaxOutputTokens('deepseek/deepseek-chat')).toBe(8000);
+    expect(getModelMaxOutputTokens('openrouter/deepseek-reasoner')).toBe(64000);
+  });
+});
+
 describe('processModelData with Meta models', () => {
   test('should process Meta model data correctly', () => {
     const input = {
@@ -778,6 +823,16 @@ describe('Grok Model Tests - Tokens', () => {
     expect(getModelMaxTokens('grok-4-0709')).toBe(256000);
   });

+  test('should return correct tokens for Grok 4 Fast and Grok 4.1 Fast models', () => {
+    expect(getModelMaxTokens('grok-4-fast')).toBe(2000000);
+    expect(getModelMaxTokens('grok-4-1-fast-reasoning')).toBe(2000000);
+    expect(getModelMaxTokens('grok-4-1-fast-non-reasoning')).toBe(2000000);
+  });
+
+  test('should return correct tokens for Grok Code Fast model', () => {
+    expect(getModelMaxTokens('grok-code-fast-1')).toBe(256000);
+  });
+
   test('should handle partial matches for Grok models with prefixes', () => {
     // Vision models should match before general models
     expect(getModelMaxTokens('xai/grok-2-vision-1212')).toBe(32768);
@@ -797,6 +852,12 @@ describe('Grok Model Tests - Tokens', () => {
     expect(getModelMaxTokens('xai/grok-3-mini-fast')).toBe(131072);
     // Grok 4 model
     expect(getModelMaxTokens('xai/grok-4-0709')).toBe(256000);
+    // Grok 4 Fast and 4.1 Fast models
+    expect(getModelMaxTokens('xai/grok-4-fast')).toBe(2000000);
+    expect(getModelMaxTokens('xai/grok-4-1-fast-reasoning')).toBe(2000000);
+    expect(getModelMaxTokens('xai/grok-4-1-fast-non-reasoning')).toBe(2000000);
+    // Grok Code Fast model
+    expect(getModelMaxTokens('xai/grok-code-fast-1')).toBe(256000);
   });
 });
@@ -820,6 +881,12 @@ describe('Grok Model Tests - Tokens', () => {
     expect(matchModelName('grok-3-mini-fast')).toBe('grok-3-mini-fast');
     // Grok 4 model
     expect(matchModelName('grok-4-0709')).toBe('grok-4');
+    // Grok 4 Fast and 4.1 Fast models
+    expect(matchModelName('grok-4-fast')).toBe('grok-4-fast');
+    expect(matchModelName('grok-4-1-fast-reasoning')).toBe('grok-4-1-fast');
+    expect(matchModelName('grok-4-1-fast-non-reasoning')).toBe('grok-4-1-fast');
+    // Grok Code Fast model
+    expect(matchModelName('grok-code-fast-1')).toBe('grok-code-fast');
   });

   test('should match Grok model variations with prefixes', () => {
@@ -841,6 +908,12 @@ describe('Grok Model Tests - Tokens', () => {
     expect(matchModelName('xai/grok-3-mini-fast')).toBe('grok-3-mini-fast');
     // Grok 4 model
     expect(matchModelName('xai/grok-4-0709')).toBe('grok-4');
+    // Grok 4 Fast and 4.1 Fast models
+    expect(matchModelName('xai/grok-4-fast')).toBe('grok-4-fast');
+    expect(matchModelName('xai/grok-4-1-fast-reasoning')).toBe('grok-4-1-fast');
+    expect(matchModelName('xai/grok-4-1-fast-non-reasoning')).toBe('grok-4-1-fast');
+    // Grok Code Fast model
+    expect(matchModelName('xai/grok-code-fast-1')).toBe('grok-code-fast');
   });
 });
});
```
```diff
@@ -64,6 +64,7 @@
     "copy-to-clipboard": "^3.3.3",
     "cross-env": "^7.0.3",
     "date-fns": "^3.3.1",
+    "dompurify": "^3.3.0",
     "downloadjs": "^1.4.7",
     "export-from-json": "^1.7.2",
     "filenamify": "^6.0.0",
```
```diff
@@ -2,6 +2,7 @@ import React, { useMemo } from 'react';
 import { useRecoilValue } from 'recoil';
 import { OGDialog, OGDialogTemplate } from '@librechat/client';
 import {
+  inferMimeType,
   EToolResources,
   EModelEndpoint,
   defaultAgentCapabilities,
@@ -56,18 +57,26 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
     const _options: FileOption[] = [];
     const currentProvider = provider || endpoint;

+    /** Helper to get inferred MIME type for a file */
+    const getFileType = (file: File) => inferMimeType(file.name, file.type);
+
     // Check if provider supports document upload
     if (isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider)) {
       const isGoogleProvider = currentProvider === EModelEndpoint.google;
       const validFileTypes = isGoogleProvider
-        ? files.every(
-            (file) =>
-              file.type?.startsWith('image/') ||
-              file.type?.startsWith('video/') ||
-              file.type?.startsWith('audio/') ||
-              file.type === 'application/pdf',
-          )
-        : files.every((file) => file.type?.startsWith('image/') || file.type === 'application/pdf');
+        ? files.every((file) => {
+            const type = getFileType(file);
+            return (
+              type?.startsWith('image/') ||
+              type?.startsWith('video/') ||
+              type?.startsWith('audio/') ||
+              type === 'application/pdf'
+            );
+          })
+        : files.every((file) => {
+            const type = getFileType(file);
+            return type?.startsWith('image/') || type === 'application/pdf';
+          });

       _options.push({
         label: localize('com_ui_upload_provider'),
@@ -81,7 +90,7 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD
         label: localize('com_ui_upload_image_input'),
         value: undefined,
         icon: <ImageUpIcon className="icon-md" />,
-        condition: files.every((file) => file.type?.startsWith('image/')),
+        condition: files.every((file) => getFileType(file)?.startsWith('image/')),
       });
     }
     if (capabilities.fileSearchEnabled && fileSearchAllowedByAgent) {
```
```diff
@@ -1,4 +1,8 @@
-import { EModelEndpoint, isDocumentSupportedProvider } from 'librechat-data-provider';
+import {
+  EModelEndpoint,
+  isDocumentSupportedProvider,
+  inferMimeType,
+} from 'librechat-data-provider';

 describe('DragDropModal - Provider Detection', () => {
   describe('endpointType priority over currentProvider', () => {
@@ -118,4 +122,59 @@ describe('DragDropModal - Provider Detection', () => {
       ).toBe(true);
     });
   });
+
+  describe('HEIC/HEIF file type inference', () => {
+    it('should infer image/heic for .heic files when browser returns empty type', () => {
+      const fileName = 'photo.heic';
+      const browserType = '';
+
+      const inferredType = inferMimeType(fileName, browserType);
+      expect(inferredType).toBe('image/heic');
+    });
+
+    it('should infer image/heif for .heif files when browser returns empty type', () => {
+      const fileName = 'photo.heif';
+      const browserType = '';
+
+      const inferredType = inferMimeType(fileName, browserType);
+      expect(inferredType).toBe('image/heif');
+    });
+
+    it('should handle uppercase .HEIC extension', () => {
+      const fileName = 'IMG_1234.HEIC';
+      const browserType = '';
+
+      const inferredType = inferMimeType(fileName, browserType);
+      expect(inferredType).toBe('image/heic');
+    });
+
+    it('should preserve browser-provided type when available', () => {
+      const fileName = 'photo.jpg';
+      const browserType = 'image/jpeg';
+
+      const inferredType = inferMimeType(fileName, browserType);
+      expect(inferredType).toBe('image/jpeg');
+    });
+
+    it('should not override browser type even if extension differs', () => {
+      const fileName = 'renamed.heic';
+      const browserType = 'image/png';
+
+      const inferredType = inferMimeType(fileName, browserType);
+      expect(inferredType).toBe('image/png');
+    });
+
+    it('should correctly identify HEIC as image type for upload options', () => {
+      const heicType = inferMimeType('photo.heic', '');
+      expect(heicType.startsWith('image/')).toBe(true);
+    });
+
+    it('should return empty string for unknown extension with no browser type', () => {
+      const fileName = 'file.xyz';
+      const browserType = '';
+
+      const inferredType = inferMimeType(fileName, browserType);
+      expect(inferredType).toBe('');
+    });
+  });
 });
```
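`inferMimeType` itself lives in `librechat-data-provider` and is not part of this compare, but the tests pin down its contract: a browser-supplied type always wins, otherwise the extension is mapped case-insensitively, and unknown extensions yield `''`. A sketch consistent with that contract; the lookup table is an illustrative assumption, not the real map:

```js
// Sketch consistent with the tests above (not the actual implementation).
const EXTENSION_MIME_MAP = {
  heic: 'image/heic',
  heif: 'image/heif',
  // ...the real map presumably covers many more extensions
};

function inferMimeType(fileName, browserType) {
  if (browserType) {
    return browserType; // never override a browser-provided type
  }
  const ext = (fileName.split('.').pop() || '').toLowerCase();
  return EXTENSION_MIME_MAP[ext] || '';
}
```

This matters for HEIC/HEIF because browsers commonly report an empty `file.type` for them, which previously made the upload options misclassify valid images.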
```diff
@@ -145,8 +145,7 @@ export default function OpenAIImageGen({
         clearInterval(intervalRef.current);
       }
     };
-    // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, [initialProgress, quality]);
+  }, [isSubmitting, initialProgress, quality]);

   useEffect(() => {
     if (initialProgress >= 1 || cancelled) {
```
@@ -45,6 +45,9 @@ const extractMessageContent = (message: TMessage): string => {
|
||||
if (Array.isArray(message.content)) {
|
||||
return message.content
|
||||
.map((part) => {
|
||||
if (part == null) {
|
||||
return '';
|
||||
}
|
||||
if (typeof part === 'string') {
|
||||
return part;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import React, { useMemo } from 'react';
|
||||
import DOMPurify from 'dompurify';
|
||||
import { useForm, Controller } from 'react-hook-form';
|
||||
import { Input, Label, Button, TooltipAnchor, CircleHelpIcon } from '@librechat/client';
|
||||
import { Input, Label, Button } from '@librechat/client';
|
||||
import { useMCPAuthValuesQuery } from '~/data-provider/Tools/queries';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
@@ -27,21 +28,40 @@ interface AuthFieldProps {
|
||||
function AuthField({ name, config, hasValue, control, errors }: AuthFieldProps) {
|
||||
const localize = useLocalize();
|
||||
|
||||
const sanitizer = useMemo(() => {
|
||||
const instance = DOMPurify();
|
||||
instance.addHook('afterSanitizeAttributes', (node) => {
|
||||
if (node.tagName && node.tagName === 'A') {
|
||||
node.setAttribute('target', '_blank');
|
||||
node.setAttribute('rel', 'noopener noreferrer');
|
||||
}
|
||||
});
|
||||
return instance;
|
||||
}, []);
|
||||
|
||||
const sanitizedDescription = useMemo(() => {
|
||||
if (!config.description) {
|
||||
return '';
|
||||
}
|
||||
try {
|
||||
return sanitizer.sanitize(config.description, {
|
||||
ALLOWED_TAGS: ['a', 'strong', 'b', 'em', 'i', 'br', 'code'],
|
||||
ALLOWED_ATTR: ['href', 'class', 'target', 'rel'],
|
||||
ALLOW_DATA_ATTR: false,
|
||||
ALLOW_ARIA_ATTR: false,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Sanitization failed', error);
|
||||
return config.description;
|
||||
}
|
||||
}, [config.description, sanitizer]);
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<TooltipAnchor
|
||||
enableHTML={true}
|
||||
description={config.description || ''}
|
||||
render={
|
||||
<div className="flex items-center gap-2">
|
||||
<Label htmlFor={name} className="text-sm font-medium">
|
||||
{config.title}
|
||||
</Label>
|
||||
<CircleHelpIcon className="h-6 w-6 cursor-help text-text-secondary transition-colors hover:text-text-primary" />
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
<Label htmlFor={name} className="text-sm font-medium">
|
||||
{config.title}
|
||||
</Label>
|
||||
{hasValue ? (
|
||||
<div className="flex min-w-fit items-center gap-2 whitespace-nowrap rounded-full border border-border-light px-2 py-0.5 text-xs font-medium text-text-secondary">
|
||||
<div className="h-1.5 w-1.5 rounded-full bg-green-500" />
|
||||
@@ -66,12 +86,18 @@ function AuthField({ name, config, hasValue, control, errors }: AuthFieldProps)
|
||||
placeholder={
|
||||
hasValue
|
||||
? localize('com_ui_mcp_update_var', { 0: config.title })
|
||||
: `${localize('com_ui_mcp_enter_var', { 0: config.title })} ${localize('com_ui_optional')}`
|
||||
: localize('com_ui_mcp_enter_var', { 0: config.title })
|
||||
}
|
||||
className="w-full rounded border border-border-medium bg-transparent px-2 py-1 text-text-primary placeholder:text-text-secondary focus:outline-none sm:text-sm"
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
{sanitizedDescription && (
|
||||
<p
|
||||
className="text-xs text-text-secondary [&_a]:text-blue-500 [&_a]:hover:underline"
|
||||
dangerouslySetInnerHTML={{ __html: sanitizedDescription }}
|
||||
/>
|
||||
)}
|
||||
{errors[name] && <p className="text-xs text-red-500">{errors[name]?.message}</p>}
|
||||
</div>
|
||||
);
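For reference, the `afterSanitizeAttributes` hook registered above is the standard DOMPurify recipe for forcing links to open safely in a new tab. Given the ALLOWED_* options in this diff, a description would sanitize roughly like the illustrative input/output below (not an exact output; attribute order can vary):

```ts
// Illustrative only: unsafe attributes are stripped, and the
// afterSanitizeAttributes hook appends target/rel to every anchor.
const dirty = 'Get a token at <a href="https://example.com" onclick="steal()">example.com</a>';
const clean = sanitizer.sanitize(dirty, {
  ALLOWED_TAGS: ['a', 'strong', 'b', 'em', 'i', 'br', 'code'],
  ALLOWED_ATTR: ['href', 'class', 'target', 'rel'],
});
// clean is roughly:
// 'Get a token at <a href="https://example.com" target="_blank" rel="noopener noreferrer">example.com</a>'
```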

@@ -41,9 +41,11 @@ const toggleSwitchConfigs = [
export const ThemeSelector = ({
theme,
onChange,
portal = true,
}: {
theme: string;
onChange: (value: string) => void;
portal?: boolean;
}) => {
const localize = useLocalize();

@@ -67,6 +69,7 @@ export const ThemeSelector = ({
testId="theme-selector"
className="z-50"
aria-labelledby={labelId}
portal={portal}
/>
</div>
);

@@ -227,9 +227,13 @@ function ShareHeader({
<OGDialogTitle>{settingsLabel}</OGDialogTitle>
</OGDialogHeader>
<div className="flex flex-col gap-4 pt-2 text-sm">
<ThemeSelector theme={theme} onChange={onThemeChange} />
<div className="relative focus-within:z-[100]">
<ThemeSelector theme={theme} onChange={onThemeChange} portal={false} />
</div>
<div className="bg-border-medium/60 h-px w-full" />
<LangSelector langcode={langcode} onChange={onLangChange} portal={false} />
<div className="relative focus-within:z-[100]">
<LangSelector langcode={langcode} onChange={onLangChange} portal={false} />
</div>
</div>
</OGDialogContent>
</OGDialog>

@@ -168,6 +168,7 @@ export default function useChatFunctions({

const endpointsConfig = queryClient.getQueryData<TEndpointsConfig>([QueryKeys.endpoints]);
const endpointType = getEndpointField(endpointsConfig, endpoint, 'type');
const iconURL = conversation?.iconURL;

/** This becomes part of the `endpointOption` */
const convo = parseCompactConvo({
@@ -248,9 +249,9 @@ export default function useChatFunctions({
conversationId,
unfinished: false,
isCreatedByUser: false,
iconURL: convo?.iconURL,
model: convo?.model,
error: false,
iconURL,
};

if (isAssistantsEndpoint(endpoint)) {

@@ -73,7 +73,9 @@ export default function useExportConversation({
}

return message.content
.filter((content) => content != null)
.map((content) => getMessageContent(message.sender || '', content))
.filter((text) => text.length > 0)
.map((text) => {
return formatText(text[0], text[1]);
})
@@ -103,7 +105,7 @@ export default function useExportConversation({
if (content.type === ContentTypes.TEXT) {
// TEXT
const textPart = content[ContentTypes.TEXT];
const text = typeof textPart === 'string' ? textPart : textPart.value;
const text = typeof textPart === 'string' ? textPart : (textPart?.value ?? '');
return [sender, text];
}

@@ -365,12 +367,10 @@ export default function useExportConversation({
data['messages'] = messages;
}

exportFromJSON({
data: data,
fileName: filename,
extension: 'json',
exportType: exportFromJSON.types.json,
});
/** Use JSON.stringify without indentation to minimize file size for deeply nested recursive exports */
const jsonString = JSON.stringify(data);
const blob = new Blob([jsonString], { type: 'application/json;charset=utf-8' });
download(blob, `${filename}.json`, 'application/json');
};
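The export now hands a plain Blob to a download helper instead of going through exportFromJSON. If `download` here behaves like the common downloadjs-style helper (an assumption), it reduces to native browser APIs:

```ts
// Minimal stand-in for the `download` helper used above (assumed
// downloadjs-like): create an object URL for the blob and click a
// temporary anchor, then release the URL.
function download(blob: Blob, filename: string, mimeType: string): void {
  const url = URL.createObjectURL(new Blob([blob], { type: mimeType }));
  const anchor = document.createElement('a');
  anchor.href = url;
  anchor.download = filename;
  document.body.appendChild(anchor);
  anchor.click();
  anchor.remove();
  URL.revokeObjectURL(url);
}
```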

const exportConversation = () => {

@@ -33,9 +33,8 @@ export default function useContentHandler({ setMessages, getMessages }: TUseCont

const _messages = getMessages();
const messages =
_messages
?.filter((m) => m.messageId !== messageId)
.map((msg) => ({ ...msg, thread_id })) ?? [];
_messages?.filter((m) => m.messageId !== messageId).map((msg) => ({ ...msg, thread_id })) ??
[];
const userMessage = messages[messages.length - 1] as TMessage | undefined;

const { initialResponse } = submission;
@@ -66,14 +65,17 @@ export default function useContentHandler({ setMessages, getMessages }: TUseCont

response.content[index] = { type, [type]: part } as TMessageContentParts;

const lastContentPart = response.content[response.content.length - 1];
const initialContentPart = initialResponse.content?.[0];
if (
type !== ContentTypes.TEXT &&
initialResponse.content &&
((response.content[response.content.length - 1].type === ContentTypes.TOOL_CALL &&
response.content[response.content.length - 1][ContentTypes.TOOL_CALL].progress === 1) ||
response.content[response.content.length - 1].type === ContentTypes.IMAGE_FILE)
initialContentPart != null &&
lastContentPart != null &&
((lastContentPart.type === ContentTypes.TOOL_CALL &&
lastContentPart[ContentTypes.TOOL_CALL]?.progress === 1) ||
lastContentPart.type === ContentTypes.IMAGE_FILE)
) {
response.content.push(initialResponse.content[0]);
response.content.push(initialContentPart);
}

setMessages([...messages, response]);

@@ -87,12 +87,14 @@ const createErrorMessage = ({
let isValidContentPart = false;
if (latestContent.length > 0) {
const latestContentPart = latestContent[latestContent.length - 1];
const latestPartValue = latestContentPart?.[latestContentPart.type ?? ''];
isValidContentPart =
latestContentPart.type !== ContentTypes.TEXT ||
(latestContentPart.type === ContentTypes.TEXT && typeof latestPartValue === 'string')
? true
: latestPartValue?.value !== '';
if (latestContentPart != null) {
const latestPartValue = latestContentPart[latestContentPart.type ?? ''];
isValidContentPart =
latestContentPart.type !== ContentTypes.TEXT ||
(latestContentPart.type === ContentTypes.TEXT && typeof latestPartValue === 'string')
? true
: latestPartValue?.value !== '';
}
}
if (
latestMessage?.conversationId &&
@@ -455,141 +457,145 @@ export default function useEventHandlers({
isTemporary = false,
} = submission;

if (responseMessage?.attachments && responseMessage.attachments.length > 0) {
// Process each attachment through the attachmentHandler
responseMessage.attachments.forEach((attachment) => {
const attachmentData = {
...attachment,
messageId: responseMessage.messageId,
};
try {
if (responseMessage?.attachments && responseMessage.attachments.length > 0) {
// Process each attachment through the attachmentHandler
responseMessage.attachments.forEach((attachment) => {
const attachmentData = {
...attachment,
messageId: responseMessage.messageId,
};

attachmentHandler({
data: attachmentData,
submission: submission as EventSubmission,
attachmentHandler({
data: attachmentData,
submission: submission as EventSubmission,
});
});
});
}
}

setShowStopButton(false);
setCompleted((prev) => new Set(prev.add(submission.initialResponse.messageId)));
setCompleted((prev) => new Set(prev.add(submission.initialResponse.messageId)));

const currentMessages = getMessages();
/* Early return if messages are empty; i.e., the user navigated away */
if (!currentMessages || currentMessages.length === 0) {
setIsSubmitting(false);
return;
}
const currentMessages = getMessages();
/* Early return if messages are empty; i.e., the user navigated away */
if (!currentMessages || currentMessages.length === 0) {
return;
}

/* a11y announcements */
announcePolite({ message: 'end', isStatus: true });
announcePolite({ message: getAllContentText(responseMessage) });
/* a11y announcements */
announcePolite({ message: 'end', isStatus: true });
announcePolite({ message: getAllContentText(responseMessage) });

const isNewConvo = conversation.conversationId !== submissionConvo.conversationId;
const isNewConvo = conversation.conversationId !== submissionConvo.conversationId;

const setFinalMessages = (id: string | null, _messages: TMessage[]) => {
setMessages(_messages);
queryClient.setQueryData<TMessage[]>([QueryKeys.messages, id], _messages);
};
const setFinalMessages = (id: string | null, _messages: TMessage[]) => {
setMessages(_messages);
queryClient.setQueryData<TMessage[]>([QueryKeys.messages, id], _messages);
};

const hasNoResponse =
responseMessage?.content?.[0]?.['text']?.value ===
submission.initialResponse?.content?.[0]?.['text']?.value ||
!!responseMessage?.content?.[0]?.['tool_call']?.auth;
const hasNoResponse =
responseMessage?.content?.[0]?.['text']?.value ===
submission.initialResponse?.content?.[0]?.['text']?.value ||
!!responseMessage?.content?.[0]?.['tool_call']?.auth;

/** Handle edge case where stream is cancelled before any response, which creates a blank page */
if (!conversation.conversationId && hasNoResponse) {
const currentConvoId =
(submissionConvo.conversationId ?? conversation.conversationId) || Constants.NEW_CONVO;
if (isNewConvo && submissionConvo.conversationId) {
removeConvoFromAllQueries(queryClient, submissionConvo.conversationId);
}

const isNewChat =
location.pathname === `/c/${Constants.NEW_CONVO}` &&
currentConvoId === Constants.NEW_CONVO;

setFinalMessages(currentConvoId, isNewChat ? [] : [...messages]);
setDraft({ id: currentConvoId, value: requestMessage?.text });
if (isNewChat) {
navigate(`/c/${Constants.NEW_CONVO}`, { replace: true, state: { focusChat: true } });
}
return;
}

/* Update messages; if assistants endpoint, client doesn't receive responseMessage */
let finalMessages: TMessage[] = [];
if (runMessages) {
finalMessages = [...runMessages];
} else if (isRegenerate && responseMessage) {
finalMessages = [...messages, responseMessage];
} else if (requestMessage != null && responseMessage != null) {
finalMessages = [...messages, requestMessage, responseMessage];
}
if (finalMessages.length > 0) {
setFinalMessages(conversation.conversationId, finalMessages);
} else if (
isAssistantsEndpoint(submissionConvo.endpoint) &&
(!submissionConvo.conversationId ||
submissionConvo.conversationId === Constants.NEW_CONVO)
) {
queryClient.setQueryData<TMessage[]>(
[QueryKeys.messages, conversation.conversationId],
[...currentMessages],
);
}

/** Handle edge case where stream is cancelled before any response, which creates a blank page */
if (!conversation.conversationId && hasNoResponse) {
const currentConvoId =
(submissionConvo.conversationId ?? conversation.conversationId) || Constants.NEW_CONVO;
if (isNewConvo && submissionConvo.conversationId) {
removeConvoFromAllQueries(queryClient, submissionConvo.conversationId);
}

const isNewChat =
location.pathname === `/c/${Constants.NEW_CONVO}` &&
currentConvoId === Constants.NEW_CONVO;

setFinalMessages(currentConvoId, isNewChat ? [] : [...messages]);
setDraft({ id: currentConvoId, value: requestMessage?.text });
setIsSubmitting(false);
if (isNewChat) {
navigate(`/c/${Constants.NEW_CONVO}`, { replace: true, state: { focusChat: true } });
/* Refresh title */
if (
genTitle &&
isNewConvo &&
!isTemporary &&
requestMessage &&
requestMessage.parentMessageId === Constants.NO_PARENT
) {
setTimeout(() => {
genTitle.mutate({ conversationId: conversation.conversationId as string });
}, 2500);
}
return;
}

/* Update messages; if assistants endpoint, client doesn't receive responseMessage */
let finalMessages: TMessage[] = [];
if (runMessages) {
finalMessages = [...runMessages];
} else if (isRegenerate && responseMessage) {
finalMessages = [...messages, responseMessage];
} else if (requestMessage != null && responseMessage != null) {
finalMessages = [...messages, requestMessage, responseMessage];
}
if (finalMessages.length > 0) {
setFinalMessages(conversation.conversationId, finalMessages);
} else if (
isAssistantsEndpoint(submissionConvo.endpoint) &&
(!submissionConvo.conversationId || submissionConvo.conversationId === Constants.NEW_CONVO)
) {
queryClient.setQueryData<TMessage[]>(
[QueryKeys.messages, conversation.conversationId],
[...currentMessages],
);
}

if (isNewConvo && submissionConvo.conversationId) {
removeConvoFromAllQueries(queryClient, submissionConvo.conversationId);
}

/* Refresh title */
if (
genTitle &&
isNewConvo &&
!isTemporary &&
requestMessage &&
requestMessage.parentMessageId === Constants.NO_PARENT
) {
setTimeout(() => {
genTitle.mutate({ conversationId: conversation.conversationId as string });
}, 2500);
}

if (setConversation && isAddedRequest !== true) {
setConversation((prevState) => {
const update = {
...prevState,
...(conversation as TConversation),
};
if (prevState?.model != null && prevState.model !== submissionConvo.model) {
update.model = prevState.model;
}
const cachedConvo = queryClient.getQueryData<TConversation>([
QueryKeys.conversation,
conversation.conversationId,
]);
if (!cachedConvo) {
queryClient.setQueryData([QueryKeys.conversation, conversation.conversationId], update);
}
return update;
});

if (conversation.conversationId && submission.ephemeralAgent) {
applyAgentTemplate({
targetId: conversation.conversationId,
sourceId: submissionConvo.conversationId,
ephemeralAgent: submission.ephemeralAgent,
specName: submission.conversation?.spec,
startupConfig: queryClient.getQueryData<TStartupConfig>([QueryKeys.startupConfig]),
if (setConversation && isAddedRequest !== true) {
setConversation((prevState) => {
const update = {
...prevState,
...(conversation as TConversation),
};
if (prevState?.model != null && prevState.model !== submissionConvo.model) {
update.model = prevState.model;
}
const cachedConvo = queryClient.getQueryData<TConversation>([
QueryKeys.conversation,
conversation.conversationId,
]);
if (!cachedConvo) {
queryClient.setQueryData(
[QueryKeys.conversation, conversation.conversationId],
update,
);
}
return update;
});
}

if (location.pathname === `/c/${Constants.NEW_CONVO}`) {
navigate(`/c/${conversation.conversationId}`, { replace: true });
if (conversation.conversationId && submission.ephemeralAgent) {
applyAgentTemplate({
targetId: conversation.conversationId,
sourceId: submissionConvo.conversationId,
ephemeralAgent: submission.ephemeralAgent,
specName: submission.conversation?.spec,
startupConfig: queryClient.getQueryData<TStartupConfig>([QueryKeys.startupConfig]),
});
}

if (location.pathname === `/c/${Constants.NEW_CONVO}`) {
navigate(`/c/${conversation.conversationId}`, { replace: true });
}
}
} finally {
setShowStopButton(false);
setIsSubmitting(false);
}

setIsSubmitting(false);
},
[
navigate,
@@ -722,26 +728,37 @@ export default function useEventHandlers({
messages[messages.length - 2] != null
) {
let requestMessage = messages[messages.length - 2];
const responseMessage = messages[messages.length - 1];
if (requestMessage.messageId !== responseMessage.parentMessageId) {
const _responseMessage = messages[messages.length - 1];
if (requestMessage.messageId !== _responseMessage.parentMessageId) {
// the request message is the parent of response, which we search for backwards
for (let i = messages.length - 3; i >= 0; i--) {
if (messages[i].messageId === responseMessage.parentMessageId) {
if (messages[i].messageId === _responseMessage.parentMessageId) {
requestMessage = messages[i];
break;
}
}
}
finalHandler(
{
conversation: {
conversationId,
/** Sanitize content array to remove undefined parts from interrupted streaming */
const responseMessage = {
..._responseMessage,
content: _responseMessage.content?.filter((part) => part != null),
};
try {
finalHandler(
{
conversation: {
conversationId,
},
requestMessage,
responseMessage,
},
requestMessage,
responseMessage,
},
submission,
);
submission,
);
} catch (error) {
console.error('Error in finalHandler during abort:', error);
setShowStopButton(false);
setIsSubmitting(false);
}
return;
} else if (!isAssistantsEndpoint(endpoint)) {
const convoId = conversationId || `_${v4()}`;
@@ -809,13 +826,14 @@ export default function useEventHandlers({
}
},
[
finalHandler,
newConversation,
setIsSubmitting,
token,
cancelHandler,
getMessages,
setMessages,
finalHandler,
cancelHandler,
newConversation,
setIsSubmitting,
setShowStopButton,
],
);
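The shape of this refactor is worth noting: the handler body moves into a try block so that, whatever throws while finalizing a run, the stop button and submitting flags are cleared exactly once in the finally. Reduced to its skeleton (names follow the diff; the helper itself is only a sketch):

```ts
// Skeleton of the pattern above, not the actual handler.
function runFinalizer(body: () => void, cleanup: () => void): void {
  try {
    body(); // may throw mid-finalization (cache updates, navigation, etc.)
  } finally {
    cleanup(); // always resets UI state (setShowStopButton / setIsSubmitting)
  }
}
```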

@@ -124,7 +124,13 @@ export default function useSSE(
if (data.final != null) {
clearDraft(submission.conversation?.conversationId);
const { plugins } = data;
finalHandler(data, { ...submission, plugins } as EventSubmission);
try {
finalHandler(data, { ...submission, plugins } as EventSubmission);
} catch (error) {
console.error('Error in finalHandler:', error);
setIsSubmitting(false);
setShowStopButton(false);
}
(startupConfig?.balance?.enabled ?? false) && balanceQuery.refetch();
console.log('final', data);
return;
@@ -187,14 +193,20 @@ export default function useSSE(
setCompleted((prev) => new Set(prev.add(streamKey)));
const latestMessages = getMessages();
const conversationId = latestMessages?.[latestMessages.length - 1]?.conversationId;
return await abortConversation(
conversationId ??
userMessage.conversationId ??
submission.conversation?.conversationId ??
'',
submission as EventSubmission,
latestMessages,
);
try {
await abortConversation(
conversationId ??
userMessage.conversationId ??
submission.conversation?.conversationId ??
'',
submission as EventSubmission,
latestMessages,
);
} catch (error) {
console.error('Error during abort:', error);
setIsSubmitting(false);
setShowStopButton(false);
}
});

sse.addEventListener('error', async (e: MessageEvent) => {

@@ -313,6 +313,10 @@ export default function useStepHandler({
? messageDelta.delta.content[0]
: messageDelta.delta.content;

if (contentPart == null) {
return;
}

const currentIndex = calculateContentIndex(
runStep.index,
initialContent,
@@ -345,6 +349,10 @@ export default function useStepHandler({
? reasoningDelta.delta.content[0]
: reasoningDelta.delta.content;

if (contentPart == null) {
return;
}

const currentIndex = calculateContentIndex(
runStep.index,
initialContent,

@@ -9,9 +9,9 @@ import {
import {
megabyte,
QueryKeys,
inferMimeType,
excelMimeTypes,
EToolResources,
codeTypeMapping,
fileConfig as defaultFileConfig,
} from 'librechat-data-provider';
import type { TFile, EndpointFileConfig, FileConfig } from 'librechat-data-provider';
@@ -257,14 +257,7 @@ export const validateFiles = ({

for (let i = 0; i < fileList.length; i++) {
let originalFile = fileList[i];
let fileType = originalFile.type;
const extension = originalFile.name.split('.').pop() ?? '';
const knownCodeType = codeTypeMapping[extension];

// Infer MIME type for Known Code files when the type is empty or a mismatch
if (knownCodeType && (!fileType || fileType !== knownCodeType)) {
fileType = knownCodeType;
}
const fileType = inferMimeType(originalFile.name, originalFile.type);

// Check if the file type is still empty after the extension check
if (!fileType) {

@@ -44,7 +44,7 @@ export const getAllContentText = (message?: TMessage | null): string => {

if (message.content && message.content.length > 0) {
return message.content
.filter((part) => part.type === ContentTypes.TEXT)
.filter((part) => part != null && part.type === ContentTypes.TEXT)
.map((part) => {
if (!('text' in part)) return '';
const text = part.text;

@@ -1,6 +1,6 @@
const path = require('path');
const mongoose = require('mongoose');
const { isEnabled, getBalanceConfig } = require('@librechat/api');
const { getBalanceConfig } = require('@librechat/api');
const { User } = require('@librechat/data-schemas').createModels(mongoose);
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
const { createTransaction } = require('~/models/Transaction');
@@ -33,15 +33,12 @@ const connect = require('./connect');
// console.purple(`[DEBUG] Args Length: ${process.argv.length}`);
}

if (!process.env.CHECK_BALANCE) {
const appConfig = await getAppConfig();
const balanceConfig = getBalanceConfig(appConfig);

if (!balanceConfig?.enabled) {
console.red(
'Error: CHECK_BALANCE environment variable is not set! Configure it to use it: `CHECK_BALANCE=true`',
);
silentExit(1);
}
if (isEnabled(process.env.CHECK_BALANCE) === false) {
console.red(
'Error: CHECK_BALANCE environment variable is set to `false`! Please configure: `CHECK_BALANCE=true`',
'Error: Balance is not enabled. Use librechat.yaml to enable it',
);
silentExit(1);
}
@@ -80,8 +77,6 @@ const connect = require('./connect');
*/
let result;
try {
const appConfig = await getAppConfig();
const balanceConfig = getBalanceConfig(appConfig);
result = await createTransaction({
user: user._id,
tokenType: 'credits',

@@ -1,6 +1,6 @@
const path = require('path');
const mongoose = require('mongoose');
const { isEnabled } = require('@librechat/api');
const { getBalanceConfig } = require('@librechat/api');
const { User, Balance } = require('@librechat/data-schemas').createModels(mongoose);
require('module-alias')({ base: path.resolve(__dirname, '..', 'api') });
const { askQuestion, silentExit } = require('./helpers');
@@ -31,15 +31,10 @@ const connect = require('./connect');
// console.purple(`[DEBUG] Args Length: ${process.argv.length}`);
}

if (!process.env.CHECK_BALANCE) {
const balanceConfig = getBalanceConfig();
if (!balanceConfig?.enabled) {
console.red(
'Error: CHECK_BALANCE environment variable is not set! Configure it to use it: `CHECK_BALANCE=true`',
);
silentExit(1);
}
if (isEnabled(process.env.CHECK_BALANCE) === false) {
console.red(
'Error: CHECK_BALANCE environment variable is set to `false`! Please configure: `CHECK_BALANCE=true`',
'Error: Balance is not enabled. Use librechat.yaml to enable it',
);
silentExit(1);
}
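Both balance scripts now gate on the app-level balance config instead of the CHECK_BALANCE environment variable, as the new error message says. Assuming the usual librechat.yaml shape for this setting (verify against your deployment's config reference), enabling it would look like:

```yaml
# Assumed librechat.yaml snippet; replaces CHECK_BALANCE=true
balance:
  enabled: true
```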
|
||||
|
||||
@@ -269,6 +269,16 @@ export default [
|
||||
project: './packages/data-provider/tsconfig.json',
|
||||
},
|
||||
},
|
||||
rules: {
|
||||
'@typescript-eslint/no-unused-vars': [
|
||||
'warn',
|
||||
{
|
||||
argsIgnorePattern: '^_',
|
||||
varsIgnorePattern: '^_',
|
||||
caughtErrorsIgnorePattern: '^_',
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
files: ['./api/demo/**/*.ts'],
|
||||
|
||||
423
package-lock.json
generated
423
package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -84,7 +84,7 @@
|
||||
"@azure/storage-blob": "^12.27.0",
|
||||
"@keyv/redis": "^4.3.3",
|
||||
"@langchain/core": "^0.3.79",
|
||||
"@librechat/agents": "^3.0.32",
|
||||
"@librechat/agents": "^3.0.36",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@modelcontextprotocol/sdk": "^1.21.0",
|
||||
"axios": "^1.12.1",
|
||||
|
||||
@@ -16,9 +16,9 @@ import { resolveHeaders, createSafeUser } from '~/utils/env';
|
||||
|
||||
const customProviders = new Set([
|
||||
Providers.XAI,
|
||||
Providers.OLLAMA,
|
||||
Providers.DEEPSEEK,
|
||||
Providers.OPENROUTER,
|
||||
KnownEndpoints.ollama,
|
||||
]);
|
||||
|
||||
export function getReasoningKey(
|
||||
|
||||
@@ -394,6 +394,34 @@ describe('findOpenIDUser', () => {
|
||||
expect(mockFindUser).toHaveBeenCalledWith({ email: 'user@example.com' });
|
||||
});
|
||||
|
||||
it('should pass email to findUser for case-insensitive lookup (findUser handles normalization)', async () => {
|
||||
const mockUser: IUser = {
|
||||
_id: 'user123',
|
||||
provider: 'openid',
|
||||
openidId: 'openid_456',
|
||||
email: 'user@example.com',
|
||||
username: 'testuser',
|
||||
} as IUser;
|
||||
|
||||
mockFindUser
|
||||
.mockResolvedValueOnce(null) // Primary condition fails
|
||||
.mockResolvedValueOnce(mockUser); // Email search succeeds
|
||||
|
||||
const result = await findOpenIDUser({
|
||||
openidId: 'openid_123',
|
||||
findUser: mockFindUser,
|
||||
email: 'User@Example.COM',
|
||||
});
|
||||
|
||||
/** Email is passed as-is; findUser implementation handles normalization */
|
||||
expect(mockFindUser).toHaveBeenNthCalledWith(2, { email: 'User@Example.COM' });
|
||||
expect(result).toEqual({
|
||||
user: mockUser,
|
||||
error: null,
|
||||
migration: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle findUser throwing an error', async () => {
|
||||
mockFindUser.mockRejectedValueOnce(new Error('Database error'));
|
||||
|
||||
|
||||
@@ -121,9 +121,12 @@ export function getSafetySettings(
|
||||
export function getGoogleConfig(
|
||||
credentials: string | t.GoogleCredentials | undefined,
|
||||
options: t.GoogleConfigOptions = {},
|
||||
acceptRawApiKey = false,
|
||||
) {
|
||||
let creds: t.GoogleCredentials = {};
|
||||
if (typeof credentials === 'string') {
|
||||
if (acceptRawApiKey && typeof credentials === 'string') {
|
||||
creds[AuthKeys.GOOGLE_API_KEY] = credentials;
|
||||
} else if (typeof credentials === 'string') {
|
||||
try {
|
||||
creds = JSON.parse(credentials);
|
||||
} catch (err: unknown) {
|
||||
|
||||
@@ -69,6 +69,26 @@ describe('getOpenAIConfig - Google Compatibility', () => {
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should filter out googleSearch when web_search is only in modelOptions (not explicitly in addParams/defaultParams)', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const endpoint = 'Gemini (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
web_search: true,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
/** googleSearch should be filtered out since web_search was not explicitly added via addParams or defaultParams */
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle web_search with mixed Google and OpenAI params in addParams', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const endpoint = 'Gemini (Custom)';
|
||||
|
||||
@@ -26,7 +26,7 @@ describe('getOpenAIConfig', () => {
|
||||
|
||||
it('should apply model options', () => {
|
||||
const modelOptions = {
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
temperature: 0.7,
|
||||
max_tokens: 1000,
|
||||
};
|
||||
@@ -34,14 +34,11 @@ describe('getOpenAIConfig', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, { modelOptions });
|
||||
|
||||
expect(result.llmConfig).toMatchObject({
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
temperature: 0.7,
|
||||
modelKwargs: {
|
||||
max_completion_tokens: 1000,
|
||||
},
|
||||
maxTokens: 1000,
|
||||
});
|
||||
expect((result.llmConfig as Record<string, unknown>).max_tokens).toBeUndefined();
|
||||
expect((result.llmConfig as Record<string, unknown>).maxTokens).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should separate known and unknown params from addParams', () => {
|
||||
@@ -286,7 +283,7 @@ describe('getOpenAIConfig', () => {
|
||||
|
||||
it('should ignore non-boolean web_search values in addParams', () => {
|
||||
const modelOptions = {
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
web_search: true,
|
||||
};
|
||||
|
||||
@@ -399,7 +396,7 @@ describe('getOpenAIConfig', () => {
|
||||
|
||||
it('should handle verbosity parameter in modelKwargs', () => {
|
||||
const modelOptions = {
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
temperature: 0.7,
|
||||
verbosity: Verbosity.high,
|
||||
};
|
||||
@@ -407,7 +404,7 @@ describe('getOpenAIConfig', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, { modelOptions });
|
||||
|
||||
expect(result.llmConfig).toMatchObject({
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
temperature: 0.7,
|
||||
});
|
||||
expect(result.llmConfig.modelKwargs).toEqual({
|
||||
@@ -417,7 +414,7 @@ describe('getOpenAIConfig', () => {
|
||||
|
||||
it('should allow addParams to override verbosity in modelKwargs', () => {
|
||||
const modelOptions = {
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
verbosity: Verbosity.low,
|
||||
};
|
||||
|
||||
@@ -451,7 +448,7 @@ describe('getOpenAIConfig', () => {
|
||||
|
||||
it('should nest verbosity under text when useResponsesApi is enabled', () => {
|
||||
const modelOptions = {
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
temperature: 0.7,
|
||||
verbosity: Verbosity.low,
|
||||
useResponsesApi: true,
|
||||
@@ -460,7 +457,7 @@ describe('getOpenAIConfig', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, { modelOptions });
|
||||
|
||||
expect(result.llmConfig).toMatchObject({
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
temperature: 0.7,
|
||||
useResponsesApi: true,
|
||||
});
|
||||
@@ -496,7 +493,6 @@ describe('getOpenAIConfig', () => {
|
||||
it('should move maxTokens to modelKwargs.max_completion_tokens for GPT-5+ models', () => {
|
||||
const modelOptions = {
|
||||
model: 'gpt-5',
|
||||
temperature: 0.7,
|
||||
max_tokens: 2048,
|
||||
};
|
||||
|
||||
@@ -504,7 +500,6 @@ describe('getOpenAIConfig', () => {
|
||||
|
||||
expect(result.llmConfig).toMatchObject({
|
||||
model: 'gpt-5',
|
||||
temperature: 0.7,
|
||||
});
|
||||
expect(result.llmConfig.maxTokens).toBeUndefined();
|
||||
expect(result.llmConfig.modelKwargs).toEqual({
|
||||
@@ -1684,7 +1679,7 @@ describe('getOpenAIConfig', () => {
|
||||
it('should not override existing modelOptions with defaultParams', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
@@ -1697,7 +1692,7 @@ describe('getOpenAIConfig', () => {
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
expect(result.llmConfig.modelKwargs?.max_completion_tokens).toBe(1000);
|
||||
expect(result.llmConfig.maxTokens).toBe(1000);
|
||||
});
|
||||
|
||||
it('should allow addParams to override defaultParams', () => {
|
||||
@@ -1845,7 +1840,7 @@ describe('getOpenAIConfig', () => {
|
||||
it('should preserve order: defaultParams < addParams < modelOptions', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-5',
|
||||
model: 'gpt-4',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
@@ -1863,7 +1858,7 @@ describe('getOpenAIConfig', () => {
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
expect(result.llmConfig.topP).toBe(0.8);
|
||||
expect(result.llmConfig.modelKwargs?.max_completion_tokens).toBe(500);
|
||||
expect(result.llmConfig.maxTokens).toBe(500);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -77,23 +77,29 @@ export function getOpenAIConfig(
|
||||
headers = Object.assign(headers ?? {}, transformed.configOptions?.defaultHeaders);
|
||||
}
|
||||
} else if (isGoogle) {
|
||||
const googleResult = getGoogleConfig(apiKey, {
|
||||
modelOptions,
|
||||
reverseProxyUrl: baseURL ?? undefined,
|
||||
authHeader: true,
|
||||
addParams,
|
||||
dropParams,
|
||||
defaultParams,
|
||||
});
|
||||
const googleResult = getGoogleConfig(
|
||||
apiKey,
|
||||
{
|
||||
modelOptions,
|
||||
reverseProxyUrl: baseURL ?? undefined,
|
||||
authHeader: true,
|
||||
addParams,
|
||||
dropParams,
|
||||
defaultParams,
|
||||
},
|
||||
true,
|
||||
);
|
||||
/** Transform handles addParams/dropParams - it knows about OpenAI params */
|
||||
const transformed = transformToOpenAIConfig({
|
||||
addParams,
|
||||
dropParams,
|
||||
defaultParams,
|
||||
tools: googleResult.tools,
|
||||
llmConfig: googleResult.llmConfig,
|
||||
fromEndpoint: EModelEndpoint.google,
|
||||
});
|
||||
llmConfig = transformed.llmConfig;
|
||||
tools = googleResult.tools;
|
||||
tools = transformed.tools;
|
||||
} else {
|
||||
const openaiResult = getOpenAILLMConfig({
|
||||
azure,
|
||||
|
||||
602
packages/api/src/endpoints/openai/llm.spec.ts
Normal file
602
packages/api/src/endpoints/openai/llm.spec.ts
Normal file
@@ -0,0 +1,602 @@
|
||||
import {
|
||||
Verbosity,
|
||||
EModelEndpoint,
|
||||
ReasoningEffort,
|
||||
ReasoningSummary,
|
||||
} from 'librechat-data-provider';
|
||||
import { getOpenAILLMConfig, extractDefaultParams, applyDefaultParams } from './llm';
|
||||
import type * as t from '~/types';
|
||||
|
||||
describe('getOpenAILLMConfig', () => {
|
||||
describe('Basic Configuration', () => {
|
||||
it('should create a basic configuration with required fields', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4',
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key');
|
||||
expect(result.llmConfig).toHaveProperty('model', 'gpt-4');
|
||||
expect(result.llmConfig).toHaveProperty('streaming', true);
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle model options including temperature and penalties', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4',
|
||||
temperature: 0.7,
|
||||
frequency_penalty: 0.5,
|
||||
presence_penalty: 0.3,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('temperature', 0.7);
|
||||
expect(result.llmConfig).toHaveProperty('frequencyPenalty', 0.5);
|
||||
expect(result.llmConfig).toHaveProperty('presencePenalty', 0.3);
|
||||
});
|
||||
|
||||
it('should handle max_tokens conversion to maxTokens', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4',
|
||||
max_tokens: 4096,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens', 4096);
|
||||
expect(result.llmConfig).not.toHaveProperty('max_tokens');
|
||||
});
|
||||
});
|
||||
|
||||
describe('OpenAI Reasoning Models (o1/o3/gpt-5)', () => {
|
||||
const reasoningModels = [
|
||||
'o1',
|
||||
'o1-mini',
|
||||
'o1-preview',
|
||||
'o1-pro',
|
||||
'o3',
|
||||
'o3-mini',
|
||||
'gpt-5',
|
||||
'gpt-5-pro',
|
||||
'gpt-5-turbo',
|
||||
];
|
||||
|
||||
const excludedParams = [
|
||||
'frequencyPenalty',
|
||||
'presencePenalty',
|
||||
'temperature',
|
||||
'topP',
|
||||
'logitBias',
|
||||
'n',
|
||||
'logprobs',
|
||||
];
|
||||
|
||||
it.each(reasoningModels)(
|
||||
'should exclude unsupported parameters for reasoning model: %s',
|
||||
(model) => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model,
|
||||
temperature: 0.7,
|
||||
frequency_penalty: 0.5,
|
||||
presence_penalty: 0.3,
|
||||
topP: 0.9,
|
||||
logitBias: { '50256': -100 },
|
||||
n: 2,
|
||||
logprobs: true,
|
||||
} as Partial<t.OpenAIParameters>,
|
||||
});
|
||||
|
||||
excludedParams.forEach((param) => {
|
||||
expect(result.llmConfig).not.toHaveProperty(param);
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('model', model);
|
||||
expect(result.llmConfig).toHaveProperty('streaming', true);
|
||||
},
|
||||
);
|
||||
|
||||
it('should preserve maxTokens for reasoning models', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'o1',
|
||||
max_tokens: 4096,
|
||||
temperature: 0.7,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens', 4096);
|
||||
expect(result.llmConfig).not.toHaveProperty('temperature');
|
||||
});
|
||||
|
||||
it('should preserve other valid parameters for reasoning models', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'o1',
|
||||
max_tokens: 8192,
|
||||
stop: ['END'],
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens', 8192);
|
||||
expect(result.llmConfig).toHaveProperty('stop', ['END']);
|
||||
});
|
||||
|
||||
it('should handle GPT-5 max_tokens conversion to max_completion_tokens', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-5',
|
||||
max_tokens: 8192,
|
||||
stop: ['END'],
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.modelKwargs).toHaveProperty('max_completion_tokens', 8192);
|
||||
expect(result.llmConfig).not.toHaveProperty('maxTokens');
|
||||
expect(result.llmConfig).toHaveProperty('stop', ['END']);
|
||||
});
|
||||
|
||||
it('should combine user dropParams with reasoning exclusion params', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'o3-mini',
|
||||
temperature: 0.7,
|
||||
stop: ['END'],
|
||||
},
|
||||
dropParams: ['stop'],
|
||||
});
|
||||
|
||||
expect(result.llmConfig).not.toHaveProperty('temperature');
|
||||
expect(result.llmConfig).not.toHaveProperty('stop');
|
||||
});
|
||||
|
||||
it('should NOT exclude parameters for non-reasoning models', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4-turbo',
|
||||
temperature: 0.7,
|
||||
frequency_penalty: 0.5,
|
||||
presence_penalty: 0.3,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('temperature', 0.7);
|
||||
expect(result.llmConfig).toHaveProperty('frequencyPenalty', 0.5);
|
||||
expect(result.llmConfig).toHaveProperty('presencePenalty', 0.3);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
|
||||
it('should NOT exclude parameters for gpt-5.x versioned models (they support sampling params)', () => {
|
||||
const versionedModels = ['gpt-5.1', 'gpt-5.1-turbo', 'gpt-5.2', 'gpt-5.5-preview'];
|
||||
|
||||
versionedModels.forEach((model) => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model,
|
||||
temperature: 0.7,
|
||||
frequency_penalty: 0.5,
|
||||
presence_penalty: 0.3,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('temperature', 0.7);
|
||||
expect(result.llmConfig).toHaveProperty('frequencyPenalty', 0.5);
|
||||
expect(result.llmConfig).toHaveProperty('presencePenalty', 0.3);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
});
|
||||
|
||||
it('should NOT exclude parameters for gpt-5-chat (it supports sampling params)', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-5-chat',
|
||||
temperature: 0.7,
|
||||
frequency_penalty: 0.5,
|
||||
presence_penalty: 0.3,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('temperature', 0.7);
|
||||
expect(result.llmConfig).toHaveProperty('frequencyPenalty', 0.5);
|
||||
expect(result.llmConfig).toHaveProperty('presencePenalty', 0.3);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
|
||||
it('should handle reasoning models with reasoning_effort parameter', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
modelOptions: {
|
||||
model: 'o1',
|
||||
reasoning_effort: ReasoningEffort.high,
|
||||
temperature: 0.7,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('reasoning_effort', ReasoningEffort.high);
|
||||
expect(result.llmConfig).not.toHaveProperty('temperature');
|
||||
});
|
||||
});
|
||||
|
||||
describe('OpenAI Web Search Models', () => {
|
||||
it('should exclude parameters for gpt-4o search models', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4o-search-preview',
|
||||
temperature: 0.7,
|
||||
top_p: 0.9,
|
||||
seed: 42,
|
||||
} as Partial<t.OpenAIParameters>,
|
||||
});
|
||||
|
||||
expect(result.llmConfig).not.toHaveProperty('temperature');
|
||||
expect(result.llmConfig).not.toHaveProperty('top_p');
|
||||
expect(result.llmConfig).not.toHaveProperty('seed');
|
||||
});
|
||||
|
||||
it('should preserve max_tokens for search models', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4o-search',
|
||||
max_tokens: 4096,
|
||||
temperature: 0.7,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens', 4096);
|
||||
expect(result.llmConfig).not.toHaveProperty('temperature');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Web Search Functionality', () => {
|
||||
it('should enable web search with Responses API', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4',
|
||||
web_search: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('useResponsesApi', true);
|
||||
expect(result.tools).toContainEqual({ type: 'web_search' });
|
||||
});
|
||||
|
||||
it('should handle web search with OpenRouter', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
useOpenRouter: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4',
|
||||
web_search: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.modelKwargs).toHaveProperty('plugins', [{ id: 'web' }]);
|
||||
expect(result.llmConfig).toHaveProperty('include_reasoning', true);
|
||||
});
|
||||
|
||||
it('should disable web search via dropParams', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4',
|
||||
web_search: true,
|
||||
},
|
||||
dropParams: ['web_search'],
|
||||
});
|
||||
|
||||
expect(result.tools).not.toContainEqual({ type: 'web_search' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('GPT-5 max_tokens Handling', () => {
|
||||
it('should convert maxTokens to max_completion_tokens for GPT-5 models', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-5',
|
||||
max_tokens: 8192,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.modelKwargs).toHaveProperty('max_completion_tokens', 8192);
|
||||
expect(result.llmConfig).not.toHaveProperty('maxTokens');
|
||||
});
|
||||
|
||||
it('should convert maxTokens to max_output_tokens for GPT-5 with Responses API', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-5',
|
||||
max_tokens: 8192,
|
||||
},
|
||||
addParams: {
|
||||
useResponsesApi: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.modelKwargs).toHaveProperty('max_output_tokens', 8192);
|
||||
expect(result.llmConfig).not.toHaveProperty('maxTokens');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Reasoning Parameters', () => {
|
||||
it('should handle reasoning_effort for OpenAI endpoint', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
modelOptions: {
|
||||
model: 'o1',
|
||||
reasoning_effort: ReasoningEffort.high,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('reasoning_effort', ReasoningEffort.high);
|
||||
});
|
||||
|
||||
it('should use reasoning object for non-OpenAI endpoints', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
endpoint: 'custom',
|
||||
modelOptions: {
|
||||
model: 'o1',
|
||||
reasoning_effort: ReasoningEffort.high,
|
||||
reasoning_summary: ReasoningSummary.concise,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('reasoning');
|
||||
expect(result.llmConfig.reasoning).toEqual({
|
||||
effort: ReasoningEffort.high,
|
||||
summary: ReasoningSummary.concise,
|
||||
});
|
||||
});
|
||||
|
||||
it('should use reasoning object when useResponsesApi is true', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
modelOptions: {
|
||||
model: 'o1',
|
||||
reasoning_effort: ReasoningEffort.medium,
|
||||
reasoning_summary: ReasoningSummary.detailed,
|
||||
},
|
||||
addParams: {
|
||||
useResponsesApi: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('reasoning');
|
||||
expect(result.llmConfig.reasoning).toEqual({
|
||||
effort: ReasoningEffort.medium,
|
||||
summary: ReasoningSummary.detailed,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Default and Add Parameters', () => {
|
||||
it('should apply default parameters when fields are undefined', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4',
|
||||
},
|
||||
defaultParams: {
|
||||
temperature: 0.5,
|
||||
topP: 0.9,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('temperature', 0.5);
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
|
||||
it('should NOT override existing values with default parameters', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4',
|
||||
temperature: 0.8,
|
||||
},
|
||||
defaultParams: {
|
||||
temperature: 0.5,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('temperature', 0.8);
|
||||
});
|
||||
|
||||
it('should apply addParams and override defaults', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
apiKey: 'test-api-key',
|
||||
streaming: true,
|
||||
modelOptions: {
|
||||
model: 'gpt-4',
|
||||
},
|
||||
defaultParams: {
|
||||
temperature: 0.5,
|
||||
},
|
||||
addParams: {
|
||||
temperature: 0.9,
|
||||
seed: 42,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('temperature', 0.9);
|
||||
expect(result.llmConfig).toHaveProperty('seed', 42);
|
||||
});
|
||||
|
||||
it('should handle unknown params via modelKwargs', () => {
|
||||
const result = getOpenAILLMConfig({
|
||||
        apiKey: 'test-api-key',
        streaming: true,
        modelOptions: {
          model: 'gpt-4',
        },
        addParams: {
          custom_param: 'custom_value',
        },
      });

      expect(result.llmConfig.modelKwargs).toHaveProperty('custom_param', 'custom_value');
    });
  });

  describe('Drop Parameters', () => {
    it('should drop specified parameters', () => {
      const result = getOpenAILLMConfig({
        apiKey: 'test-api-key',
        streaming: true,
        modelOptions: {
          model: 'gpt-4',
          temperature: 0.7,
          topP: 0.9,
        },
        dropParams: ['temperature'],
      });

      expect(result.llmConfig).not.toHaveProperty('temperature');
      expect(result.llmConfig).toHaveProperty('topP', 0.9);
    });
  });

  describe('OpenRouter Configuration', () => {
    it('should include include_reasoning for OpenRouter', () => {
      const result = getOpenAILLMConfig({
        apiKey: 'test-api-key',
        streaming: true,
        useOpenRouter: true,
        modelOptions: {
          model: 'gpt-4',
        },
      });

      expect(result.llmConfig).toHaveProperty('include_reasoning', true);
    });
  });

  describe('Verbosity Handling', () => {
    it('should add verbosity to modelKwargs', () => {
      const result = getOpenAILLMConfig({
        apiKey: 'test-api-key',
        streaming: true,
        modelOptions: {
          model: 'gpt-4',
          verbosity: Verbosity.high,
        },
      });

      expect(result.llmConfig.modelKwargs).toHaveProperty('verbosity', Verbosity.high);
    });

    it('should convert verbosity to text object with Responses API', () => {
      const result = getOpenAILLMConfig({
        apiKey: 'test-api-key',
        streaming: true,
        modelOptions: {
          model: 'gpt-4',
          verbosity: Verbosity.low,
        },
        addParams: {
          useResponsesApi: true,
        },
      });

      expect(result.llmConfig.modelKwargs).toHaveProperty('text', { verbosity: Verbosity.low });
      expect(result.llmConfig.modelKwargs).not.toHaveProperty('verbosity');
    });
  });
});

describe('extractDefaultParams', () => {
  it('should extract default values from param definitions', () => {
    const paramDefinitions = [
      { key: 'temperature', default: 0.7 },
      { key: 'maxTokens', default: 4096 },
      { key: 'noDefault' },
    ];

    const result = extractDefaultParams(paramDefinitions);

    expect(result).toEqual({
      temperature: 0.7,
      maxTokens: 4096,
    });
  });

  it('should return undefined for undefined or non-array input', () => {
    expect(extractDefaultParams(undefined)).toBeUndefined();
    expect(extractDefaultParams(null as unknown as undefined)).toBeUndefined();
  });

  it('should handle empty array', () => {
    const result = extractDefaultParams([]);
    expect(result).toEqual({});
  });
});

describe('applyDefaultParams', () => {
  it('should apply defaults only when field is undefined', () => {
    const target: Record<string, unknown> = {
      temperature: 0.8,
      maxTokens: undefined,
    };

    const defaults = {
      temperature: 0.5,
      maxTokens: 4096,
      topP: 0.9,
    };

    applyDefaultParams(target, defaults);

    expect(target).toEqual({
      temperature: 0.8,
      maxTokens: 4096,
      topP: 0.9,
    });
  });
});
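The two suites above pin down the helpers' behavior completely. For readers skimming the diff, here is a minimal sketch consistent with these tests (an illustration, not necessarily the shipped implementation):

/** Sketch: collect `default` values from param definitions (shape assumed from the tests). */
function extractDefaultParams(
  paramDefinitions?: Array<{ key: string; default?: unknown }>,
): Record<string, unknown> | undefined {
  if (!Array.isArray(paramDefinitions)) {
    return undefined;
  }
  const defaults: Record<string, unknown> = {};
  for (const def of paramDefinitions) {
    if (def.default !== undefined) {
      defaults[def.key] = def.default;
    }
  }
  return defaults;
}

/** Sketch: fill in defaults only where the target field is strictly undefined. */
function applyDefaultParams(
  target: Record<string, unknown>,
  defaults: Record<string, unknown>,
): void {
  for (const [key, value] of Object.entries(defaults)) {
    if (target[key] === undefined) {
      target[key] = value;
    }
  }
}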
@@ -259,9 +259,35 @@ export function getOpenAILLMConfig({
  }

  /**
   * Note: OpenAI Web Search models do not support any known parameters besides `max_tokens`
   * Note: OpenAI reasoning models (o1/o3/gpt-5) do not support temperature and other sampling parameters
   * Exception: gpt-5-chat and versioned models like gpt-5.1 DO support these parameters
   */
  if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model as string)) {
  if (
    modelOptions.model &&
    /\b(o[13]|gpt-5)(?!\.|-chat)(?:-|$)/.test(modelOptions.model as string)
  ) {
    const reasoningExcludeParams = [
      'frequencyPenalty',
      'presencePenalty',
      'temperature',
      'topP',
      'logitBias',
      'n',
      'logprobs',
    ];

    const updatedDropParams = dropParams || [];
    const combinedDropParams = [...new Set([...updatedDropParams, ...reasoningExcludeParams])];

    combinedDropParams.forEach((param) => {
      if (param in llmConfig) {
        delete llmConfig[param as keyof t.OAIClientOptions];
      }
    });
  } else if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model as string)) {
    /**
     * Note: OpenAI Web Search models do not support any known parameters besides `max_tokens`
     */
    const searchExcludeParams = [
      'frequency_penalty',
      'presence_penalty',
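To make the new guard concrete: the reasoning-model regex matches bare and dash-suffixed o1/o3/gpt-5 names, while letting gpt-5-chat and versioned names like gpt-5.1 keep their sampling parameters. A quick check, runnable as-is:

const reasoningModel = /\b(o[13]|gpt-5)(?!\.|-chat)(?:-|$)/;

// Matched: sampling params (temperature, topP, penalties, ...) get dropped
console.log(reasoningModel.test('o1'));         // true
console.log(reasoningModel.test('o3-mini'));    // true
console.log(reasoningModel.test('gpt-5'));      // true

// Not matched: these models keep their sampling parameters
console.log(reasoningModel.test('gpt-5-chat')); // false
console.log(reasoningModel.test('gpt-5.1'));    // false
console.log(reasoningModel.test('gpt-4o'));     // false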
@@ -1,28 +1,48 @@
import { EModelEndpoint } from 'librechat-data-provider';
import type { GoogleAIToolType } from '@langchain/google-common';
import type { ClientOptions } from '@librechat/agents';
import type * as t from '~/types';
import { knownOpenAIParams } from './llm';

const anthropicExcludeParams = new Set(['anthropicApiUrl']);
const googleExcludeParams = new Set(['safetySettings', 'location', 'baseUrl', 'customHeaders']);
const googleExcludeParams = new Set([
  'safetySettings',
  'location',
  'baseUrl',
  'customHeaders',
  'thinkingConfig',
  'thinkingBudget',
  'includeThoughts',
]);

/** Google-specific tool types that have no OpenAI-compatible equivalent */
const googleToolsToFilter = new Set(['googleSearch']);

export type ConfigTools = Array<Record<string, unknown>> | Array<GoogleAIToolType>;

/**
 * Transforms a Non-OpenAI LLM config to an OpenAI-conformant config.
 * Non-OpenAI parameters are moved to modelKwargs.
 * Also extracts configuration options that belong in configOptions.
 * Handles addParams and dropParams for parameter customization.
 * Filters out provider-specific tools that have no OpenAI equivalent.
 */
export function transformToOpenAIConfig({
  tools,
  addParams,
  dropParams,
  defaultParams,
  llmConfig,
  fromEndpoint,
}: {
  tools?: ConfigTools;
  addParams?: Record<string, unknown>;
  dropParams?: string[];
  defaultParams?: Record<string, unknown>;
  llmConfig: ClientOptions;
  fromEndpoint: string;
}): {
  tools: ConfigTools;
  llmConfig: t.OAIClientOptions;
  configOptions: Partial<t.OpenAIConfiguration>;
} {
@@ -58,18 +78,9 @@ export function transformToOpenAIConfig({
      hasModelKwargs = true;
      continue;
    } else if (isGoogle && key === 'authOptions') {
      // Handle Google authOptions
      modelKwargs = Object.assign({}, modelKwargs, value as Record<string, unknown>);
      hasModelKwargs = true;
      continue;
    } else if (
      isGoogle &&
      (key === 'thinkingConfig' || key === 'thinkingBudget' || key === 'includeThoughts')
    ) {
      // Handle Google thinking configuration
      modelKwargs = Object.assign({}, modelKwargs, { [key]: value });
      hasModelKwargs = true;
      continue;
    }

    if (knownOpenAIParams.has(key)) {
@@ -121,7 +132,34 @@ export function transformToOpenAIConfig({
    }
  }

  /**
   * Filter out provider-specific tools that have no OpenAI equivalent.
   * Exception: If web_search was explicitly enabled via addParams or defaultParams,
   * preserve googleSearch tools (pass through in Google-native format).
   */
  const webSearchExplicitlyEnabled =
    addParams?.web_search === true || defaultParams?.web_search === true;

  const filterGoogleTool = (tool: unknown): boolean => {
    if (!isGoogle) {
      return true;
    }
    if (typeof tool !== 'object' || tool === null) {
      return false;
    }
    const toolKeys = Object.keys(tool as Record<string, unknown>);
    const isGoogleSpecificTool = toolKeys.some((key) => googleToolsToFilter.has(key));
    /** Preserve googleSearch if web_search was explicitly enabled */
    if (isGoogleSpecificTool && webSearchExplicitlyEnabled) {
      return true;
    }
    return !isGoogleSpecificTool;
  };

  const filteredTools = Array.isArray(tools) ? tools.filter(filterGoogleTool) : [];

  return {
    tools: filteredTools,
    llmConfig: openAIConfig as t.OAIClientOptions,
    configOptions,
  };
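Concretely, the tool-filtering exception behaves as follows. This is a sketch: the tool shapes are illustrative, the import path is assumed, and it assumes isGoogle is derived from fromEndpoint being the Google endpoint:

import { EModelEndpoint } from 'librechat-data-provider';
import type { ClientOptions } from '@librechat/agents';
import { transformToOpenAIConfig } from './transform'; // path assumed

const llmConfig = { model: 'gemini-2.0-flash' } as ClientOptions;
const googleTools = [{ googleSearch: {} }, { functionDeclarations: [{ name: 'lookup' }] }];

// Default: googleSearch has no OpenAI equivalent, so it is filtered out.
transformToOpenAIConfig({ tools: googleTools, llmConfig, fromEndpoint: EModelEndpoint.google });
// -> tools: [{ functionDeclarations: [...] }]

// With web_search explicitly enabled, googleSearch passes through in Google-native form.
transformToOpenAIConfig({
  tools: googleTools,
  llmConfig,
  fromEndpoint: EModelEndpoint.google,
  addParams: { web_search: true },
});
// -> tools: [{ googleSearch: {} }, { functionDeclarations: [...] }]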
@@ -1,3 +1,4 @@
export * from './access';
export * from './error';
export * from './balance';
export * from './json';
158
packages/api/src/middleware/json.spec.ts
Normal file
@@ -0,0 +1,158 @@
import { handleJsonParseError } from './json';
import type { Request, Response, NextFunction } from 'express';

describe('handleJsonParseError', () => {
  let req: Partial<Request>;
  let res: Partial<Response>;
  let next: NextFunction;
  let jsonSpy: jest.Mock;
  let statusSpy: jest.Mock;

  beforeEach(() => {
    req = {
      path: '/api/test',
      method: 'POST',
      ip: '127.0.0.1',
    };

    jsonSpy = jest.fn();
    statusSpy = jest.fn().mockReturnValue({ json: jsonSpy });

    res = {
      status: statusSpy,
      json: jsonSpy,
    };

    next = jest.fn();
  });

  describe('JSON parse errors', () => {
    it('should handle JSON SyntaxError with 400 status', () => {
      const err = new SyntaxError('Unexpected token < in JSON at position 0') as SyntaxError & {
        status?: number;
        body?: unknown;
      };
      err.status = 400;
      err.body = {};

      handleJsonParseError(err, req as Request, res as Response, next);

      expect(statusSpy).toHaveBeenCalledWith(400);
      expect(jsonSpy).toHaveBeenCalledWith({
        error: 'Invalid JSON format',
        message: 'The request body contains malformed JSON',
      });
      expect(next).not.toHaveBeenCalled();
    });

    it('should not reflect user input in error message', () => {
      const maliciousInput = '<script>alert("xss")</script>';
      const err = new SyntaxError(
        `Unexpected token < in JSON at position 0: ${maliciousInput}`,
      ) as SyntaxError & {
        status?: number;
        body?: unknown;
      };
      err.status = 400;
      err.body = maliciousInput;

      handleJsonParseError(err, req as Request, res as Response, next);

      expect(statusSpy).toHaveBeenCalledWith(400);
      const errorResponse = jsonSpy.mock.calls[0][0];
      expect(errorResponse.message).not.toContain(maliciousInput);
      expect(errorResponse.message).toBe('The request body contains malformed JSON');
      expect(next).not.toHaveBeenCalled();
    });

    it('should handle JSON parse error with HTML tags in body', () => {
      const err = new SyntaxError('Invalid JSON') as SyntaxError & {
        status?: number;
        body?: unknown;
      };
      err.status = 400;
      err.body = '<html><body><h1>XSS</h1></body></html>';

      handleJsonParseError(err, req as Request, res as Response, next);

      expect(statusSpy).toHaveBeenCalledWith(400);
      const errorResponse = jsonSpy.mock.calls[0][0];
      expect(errorResponse.message).not.toContain('<html>');
      expect(errorResponse.message).not.toContain('<script>');
      expect(next).not.toHaveBeenCalled();
    });
  });

  describe('non-JSON errors', () => {
    it('should pass through non-SyntaxError errors', () => {
      const err = new Error('Some other error');

      handleJsonParseError(err, req as Request, res as Response, next);

      expect(next).toHaveBeenCalledWith(err);
      expect(statusSpy).not.toHaveBeenCalled();
      expect(jsonSpy).not.toHaveBeenCalled();
    });

    it('should pass through SyntaxError without status 400', () => {
      const err = new SyntaxError('Some syntax error') as SyntaxError & { status?: number };
      err.status = 500;

      handleJsonParseError(err, req as Request, res as Response, next);

      expect(next).toHaveBeenCalledWith(err);
      expect(statusSpy).not.toHaveBeenCalled();
    });

    it('should pass through SyntaxError without body property', () => {
      const err = new SyntaxError('Some syntax error') as SyntaxError & { status?: number };
      err.status = 400;

      handleJsonParseError(err, req as Request, res as Response, next);

      expect(next).toHaveBeenCalledWith(err);
      expect(statusSpy).not.toHaveBeenCalled();
    });

    it('should pass through TypeError', () => {
      const err = new TypeError('Type error');

      handleJsonParseError(err, req as Request, res as Response, next);

      expect(next).toHaveBeenCalledWith(err);
      expect(statusSpy).not.toHaveBeenCalled();
    });
  });

  describe('security verification', () => {
    it('should return generic error message for all JSON parse errors', () => {
      const testCases = [
        'Unexpected token < in JSON',
        'Unexpected end of JSON input',
        'Invalid or unexpected token',
        '<script>alert(1)</script>',
        '"><img src=x onerror=alert(1)>',
      ];

      testCases.forEach((errorMsg) => {
        const err = new SyntaxError(errorMsg) as SyntaxError & {
          status?: number;
          body?: unknown;
        };
        err.status = 400;
        err.body = errorMsg;

        jsonSpy.mockClear();
        statusSpy.mockClear();
        (next as jest.Mock).mockClear();

        handleJsonParseError(err, req as Request, res as Response, next);

        const errorResponse = jsonSpy.mock.calls[0][0];
        // Verify the generic message is always returned, not the user input
        expect(errorResponse.message).toBe('The request body contains malformed JSON');
        expect(errorResponse.error).toBe('Invalid JSON format');
      });
    });
  });
});
40
packages/api/src/middleware/json.ts
Normal file
@@ -0,0 +1,40 @@
import { logger } from '@librechat/data-schemas';
import type { Request, Response, NextFunction } from 'express';

/**
 * Middleware to handle JSON parsing errors from express.json()
 * Prevents user input from being reflected in error messages (XSS prevention)
 *
 * This middleware should be placed immediately after express.json() middleware.
 *
 * @param err - Error object from express.json()
 * @param req - Express request object
 * @param res - Express response object
 * @param next - Express next function
 *
 * @example
 * app.use(express.json({ limit: '3mb' }));
 * app.use(handleJsonParseError);
 */
export function handleJsonParseError(
  err: Error & { status?: number; body?: unknown },
  req: Request,
  res: Response,
  next: NextFunction,
): void {
  if (err instanceof SyntaxError && err.status === 400 && 'body' in err) {
    logger.warn('[JSON Parse Error] Invalid JSON received', {
      path: req.path,
      method: req.method,
      ip: req.ip,
    });

    res.status(400).json({
      error: 'Invalid JSON format',
      message: 'The request body contains malformed JSON',
    });
    return;
  }

  next(err);
}
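End to end, the middleware pairs with express.json() like so. A hedged sketch: it assumes supertest is available as a dev dependency and that the middleware is re-exported from the package root:

import express from 'express';
import request from 'supertest';
import { handleJsonParseError } from '@librechat/api'; // export path assumed

const app = express();
app.use(express.json({ limit: '3mb' }));
app.use(handleJsonParseError);
app.post('/api/test', (_req, res) => res.json({ ok: true }));

// In an async test: malformed JSON yields the generic 400 body instead of echoing input.
await request(app)
  .post('/api/test')
  .set('Content-Type', 'application/json')
  .send('{"broken":')
  .expect(400, {
    error: 'Invalid JSON format',
    message: 'The request body contains malformed JSON',
  });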
@@ -2,6 +2,7 @@ import { SystemCategories } from 'librechat-data-provider';
import type { IPromptGroupDocument as IPromptGroup } from '@librechat/data-schemas';
import type { Types } from 'mongoose';
import type { PromptGroupsListResponse } from '~/types';
import { escapeRegExp } from '~/utils/common';

/**
 * Formats prompt groups for the paginated /groups endpoint response
@@ -101,7 +102,6 @@ export function buildPromptGroupFilter({

  // Handle name filter - convert to regex for case-insensitive search
  if (name) {
    const escapeRegExp = (str: string) => str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
    filter.name = new RegExp(escapeRegExp(name), 'i');
  }
@@ -1,2 +1,3 @@
export * from './format';
export * from './migration';
export * from './schemas';
222
packages/api/src/prompts/schemas.spec.ts
Normal file
@@ -0,0 +1,222 @@
import {
  updatePromptGroupSchema,
  validatePromptGroupUpdate,
  safeValidatePromptGroupUpdate,
} from './schemas';

describe('updatePromptGroupSchema', () => {
  describe('allowed fields', () => {
    it('should accept valid name field', () => {
      const result = updatePromptGroupSchema.safeParse({ name: 'Test Group' });
      expect(result.success).toBe(true);
      if (result.success) {
        expect(result.data.name).toBe('Test Group');
      }
    });

    it('should accept valid oneliner field', () => {
      const result = updatePromptGroupSchema.safeParse({ oneliner: 'A short description' });
      expect(result.success).toBe(true);
      if (result.success) {
        expect(result.data.oneliner).toBe('A short description');
      }
    });

    it('should accept valid category field', () => {
      const result = updatePromptGroupSchema.safeParse({ category: 'testing' });
      expect(result.success).toBe(true);
      if (result.success) {
        expect(result.data.category).toBe('testing');
      }
    });

    it('should accept valid projectIds array', () => {
      const result = updatePromptGroupSchema.safeParse({
        projectIds: ['proj1', 'proj2'],
      });
      expect(result.success).toBe(true);
      if (result.success) {
        expect(result.data.projectIds).toEqual(['proj1', 'proj2']);
      }
    });

    it('should accept valid removeProjectIds array', () => {
      const result = updatePromptGroupSchema.safeParse({
        removeProjectIds: ['proj1'],
      });
      expect(result.success).toBe(true);
      if (result.success) {
        expect(result.data.removeProjectIds).toEqual(['proj1']);
      }
    });

    it('should accept valid command field', () => {
      const result = updatePromptGroupSchema.safeParse({ command: 'my-command-123' });
      expect(result.success).toBe(true);
      if (result.success) {
        expect(result.data.command).toBe('my-command-123');
      }
    });

    it('should accept null command field', () => {
      const result = updatePromptGroupSchema.safeParse({ command: null });
      expect(result.success).toBe(true);
      if (result.success) {
        expect(result.data.command).toBeNull();
      }
    });

    it('should accept multiple valid fields', () => {
      const input = {
        name: 'Updated Name',
        category: 'new-category',
        oneliner: 'New description',
      };
      const result = updatePromptGroupSchema.safeParse(input);
      expect(result.success).toBe(true);
      if (result.success) {
        expect(result.data).toEqual(input);
      }
    });

    it('should accept empty object', () => {
      const result = updatePromptGroupSchema.safeParse({});
      expect(result.success).toBe(true);
    });
  });

  describe('security - strips sensitive fields', () => {
    it('should reject author field', () => {
      const result = updatePromptGroupSchema.safeParse({
        name: 'Test',
        author: '507f1f77bcf86cd799439011',
      });
      expect(result.success).toBe(false);
    });

    it('should reject authorName field', () => {
      const result = updatePromptGroupSchema.safeParse({
        name: 'Test',
        authorName: 'Malicious Author',
      });
      expect(result.success).toBe(false);
    });

    it('should reject _id field', () => {
      const result = updatePromptGroupSchema.safeParse({
        name: 'Test',
        _id: '507f1f77bcf86cd799439011',
      });
      expect(result.success).toBe(false);
    });

    it('should reject productionId field', () => {
      const result = updatePromptGroupSchema.safeParse({
        name: 'Test',
        productionId: '507f1f77bcf86cd799439011',
      });
      expect(result.success).toBe(false);
    });

    it('should reject createdAt field', () => {
      const result = updatePromptGroupSchema.safeParse({
        name: 'Test',
        createdAt: new Date().toISOString(),
      });
      expect(result.success).toBe(false);
    });

    it('should reject updatedAt field', () => {
      const result = updatePromptGroupSchema.safeParse({
        name: 'Test',
        updatedAt: new Date().toISOString(),
      });
      expect(result.success).toBe(false);
    });

    it('should reject __v field', () => {
      const result = updatePromptGroupSchema.safeParse({
        name: 'Test',
        __v: 999,
      });
      expect(result.success).toBe(false);
    });

    it('should reject multiple sensitive fields in a single request', () => {
      const result = updatePromptGroupSchema.safeParse({
        name: 'Legit Name',
        author: '507f1f77bcf86cd799439011',
        authorName: 'Hacker',
        _id: 'newid123',
        productionId: 'prodid456',
        createdAt: '2020-01-01T00:00:00.000Z',
        __v: 999,
      });
      expect(result.success).toBe(false);
    });
  });

  describe('validation rules', () => {
    it('should reject empty name', () => {
      const result = updatePromptGroupSchema.safeParse({ name: '' });
      expect(result.success).toBe(false);
    });

    it('should reject name exceeding max length', () => {
      const result = updatePromptGroupSchema.safeParse({ name: 'a'.repeat(256) });
      expect(result.success).toBe(false);
    });

    it('should reject oneliner exceeding max length', () => {
      const result = updatePromptGroupSchema.safeParse({ oneliner: 'a'.repeat(501) });
      expect(result.success).toBe(false);
    });

    it('should reject category exceeding max length', () => {
      const result = updatePromptGroupSchema.safeParse({ category: 'a'.repeat(101) });
      expect(result.success).toBe(false);
    });

    it('should reject command with invalid characters (uppercase)', () => {
      const result = updatePromptGroupSchema.safeParse({ command: 'MyCommand' });
      expect(result.success).toBe(false);
    });

    it('should reject command with invalid characters (spaces)', () => {
      const result = updatePromptGroupSchema.safeParse({ command: 'my command' });
      expect(result.success).toBe(false);
    });

    it('should reject command with invalid characters (special)', () => {
      const result = updatePromptGroupSchema.safeParse({ command: 'my_command!' });
      expect(result.success).toBe(false);
    });
  });
});

describe('validatePromptGroupUpdate', () => {
  it('should return validated data for valid input', () => {
    const input = { name: 'Test', category: 'testing' };
    const result = validatePromptGroupUpdate(input);
    expect(result).toEqual(input);
  });

  it('should throw ZodError for invalid input', () => {
    expect(() => validatePromptGroupUpdate({ author: 'malicious-id' })).toThrow();
  });
});

describe('safeValidatePromptGroupUpdate', () => {
  it('should return success true for valid input', () => {
    const result = safeValidatePromptGroupUpdate({ name: 'Test' });
    expect(result.success).toBe(true);
  });

  it('should return success false for invalid input with errors', () => {
    const result = safeValidatePromptGroupUpdate({ author: 'malicious-id' });
    expect(result.success).toBe(false);
    if (!result.success) {
      expect(result.error.errors.length).toBeGreaterThan(0);
    }
  });
});
53
packages/api/src/prompts/schemas.ts
Normal file
@@ -0,0 +1,53 @@
import { z } from 'zod';
import { Constants } from 'librechat-data-provider';

/**
 * Schema for validating prompt group update payloads.
 * Only allows fields that users should be able to modify.
 * Sensitive fields like author, authorName, _id, productionId, etc. are excluded.
 */
export const updatePromptGroupSchema = z
  .object({
    /** The name of the prompt group */
    name: z.string().min(1).max(255).optional(),
    /** Short description/oneliner for the prompt group */
    oneliner: z.string().max(500).optional(),
    /** Category for organizing prompt groups */
    category: z.string().max(100).optional(),
    /** Project IDs to add for sharing */
    projectIds: z.array(z.string()).optional(),
    /** Project IDs to remove from sharing */
    removeProjectIds: z.array(z.string()).optional(),
    /** Command shortcut for the prompt group */
    command: z
      .string()
      .max(Constants.COMMANDS_MAX_LENGTH as number)
      .regex(/^[a-z0-9-]*$/, {
        message: 'Command must only contain lowercase alphanumeric characters and hyphens',
      })
      .optional()
      .nullable(),
  })
  .strict();

export type TUpdatePromptGroupSchema = z.infer<typeof updatePromptGroupSchema>;

/**
 * Validates and sanitizes a prompt group update payload.
 * Returns only the allowed fields, stripping any sensitive fields.
 * @param data - The raw request body to validate
 * @returns The validated and sanitized payload
 * @throws ZodError if validation fails
 */
export function validatePromptGroupUpdate(data: unknown): TUpdatePromptGroupSchema {
  return updatePromptGroupSchema.parse(data);
}

/**
 * Safely validates a prompt group update payload without throwing.
 * @param data - The raw request body to validate
 * @returns A SafeParseResult with either the validated data or validation errors
 */
export function safeValidatePromptGroupUpdate(data: unknown) {
  return updatePromptGroupSchema.safeParse(data);
}
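A sketch of how a route handler might consume this (the route shape, import path, and model call are illustrative, not part of the diff):

import type { Request, Response } from 'express';
import { safeValidatePromptGroupUpdate } from '@librechat/api'; // export path assumed

export async function patchPromptGroup(req: Request, res: Response) {
  // Because the schema is .strict(), unexpected keys (author, _id, __v, ...) fail
  // validation outright instead of being silently written through to the database.
  const parsed = safeValidatePromptGroupUpdate(req.body);
  if (!parsed.success) {
    return res.status(400).json({ errors: parsed.error.errors });
  }
  // parsed.data now contains only the whitelisted, validated fields.
  // await updatePromptGroup({ _id: req.params.groupId }, parsed.data); // hypothetical model call
  return res.json({ ok: true });
}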
@@ -48,3 +48,12 @@ export function optionalChainWithEmptyCheck(
  }
  return values[values.length - 1];
}

/**
 * Escapes special characters in a string for use in a regular expression.
 * @param str - The string to escape.
 * @returns The escaped string safe for use in RegExp.
 */
export function escapeRegExp(str: string): string {
  return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
}
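A quick illustration of why the shared helper matters whenever user input feeds a RegExp (this is the same helper buildPromptGroupFilter now imports instead of redefining locally):

import { escapeRegExp } from '~/utils/common';

const userInput = 'price (USD)?';

// Unescaped, the parentheses and '?' act as metacharacters and match the wrong thing;
// escaped, they are treated as literals:
const pattern = new RegExp(escapeRegExp(userInput), 'i');
console.log(pattern.test('Price (USD)?')); // true
console.log(pattern.test('price USD'));    // false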
@@ -17,7 +17,8 @@ export * from './promise';
export * from './sanitizeTitle';
export * from './tempChatRetention';
export * from './text';
export { default as Tokenizer } from './tokenizer';
export { default as Tokenizer, countTokens } from './tokenizer';
export * from './yaml';
export * from './http';
export * from './tokens';
export * from './message';
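With countTokens now re-exported next to Tokenizer, consumers can pick either API (package entry name taken from the spec below; usage assumed):

import { Tokenizer, countTokens } from '@librechat/api';

const n1 = Tokenizer.getTokenCount('Hello, world!', 'cl100k_base'); // sync, explicit encoding
const n2 = await countTokens('Hello, world!'); // async convenience wrapper (in an async context)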
122
packages/api/src/utils/message.spec.ts
Normal file
@@ -0,0 +1,122 @@
import { sanitizeFileForTransmit, sanitizeMessageForTransmit } from './message';

describe('sanitizeFileForTransmit', () => {
  it('should remove text field from file', () => {
    const file = {
      file_id: 'test-123',
      filename: 'test.txt',
      text: 'This is a very long text content that should be stripped',
      bytes: 1000,
    };

    const result = sanitizeFileForTransmit(file);

    expect(result.file_id).toBe('test-123');
    expect(result.filename).toBe('test.txt');
    expect(result.bytes).toBe(1000);
    expect(result).not.toHaveProperty('text');
  });

  it('should remove _id and __v fields', () => {
    const file = {
      file_id: 'test-123',
      _id: 'mongo-id',
      __v: 0,
      filename: 'test.txt',
    };

    const result = sanitizeFileForTransmit(file);

    expect(result.file_id).toBe('test-123');
    expect(result).not.toHaveProperty('_id');
    expect(result).not.toHaveProperty('__v');
  });

  it('should not modify original file object', () => {
    const file = {
      file_id: 'test-123',
      text: 'original text',
    };

    sanitizeFileForTransmit(file);

    expect(file.text).toBe('original text');
  });
});

describe('sanitizeMessageForTransmit', () => {
  it('should remove fileContext from message', () => {
    const message = {
      messageId: 'msg-123',
      text: 'Hello world',
      fileContext: 'This is a very long context that should be stripped',
    };

    const result = sanitizeMessageForTransmit(message);

    expect(result.messageId).toBe('msg-123');
    expect(result.text).toBe('Hello world');
    expect(result).not.toHaveProperty('fileContext');
  });

  it('should sanitize files array', () => {
    const message = {
      messageId: 'msg-123',
      files: [
        { file_id: 'file-1', text: 'long text 1', filename: 'a.txt' },
        { file_id: 'file-2', text: 'long text 2', filename: 'b.txt' },
      ],
    };

    const result = sanitizeMessageForTransmit(message);

    expect(result.files).toHaveLength(2);
    expect(result.files?.[0].file_id).toBe('file-1');
    expect(result.files?.[0].filename).toBe('a.txt');
    expect(result.files?.[0]).not.toHaveProperty('text');
    expect(result.files?.[1]).not.toHaveProperty('text');
  });

  it('should handle null/undefined message', () => {
    expect(sanitizeMessageForTransmit(null as unknown as object)).toBeNull();
    expect(sanitizeMessageForTransmit(undefined as unknown as object)).toBeUndefined();
  });

  it('should handle message without files', () => {
    const message = {
      messageId: 'msg-123',
      text: 'Hello',
    };

    const result = sanitizeMessageForTransmit(message);

    expect(result.messageId).toBe('msg-123');
    expect(result.text).toBe('Hello');
  });

  it('should create new array reference for empty files array (immutability)', () => {
    const message = {
      messageId: 'msg-123',
      files: [] as { file_id: string }[],
    };

    const result = sanitizeMessageForTransmit(message);

    expect(result.files).toEqual([]);
    // New array reference ensures full immutability even for empty arrays
    expect(result.files).not.toBe(message.files);
  });

  it('should not modify original message object', () => {
    const message = {
      messageId: 'msg-123',
      fileContext: 'original context',
      files: [{ file_id: 'file-1', text: 'original text' }],
    };

    sanitizeMessageForTransmit(message);

    expect(message.fileContext).toBe('original context');
    expect(message.files[0].text).toBe('original text');
  });
});
68
packages/api/src/utils/message.ts
Normal file
@@ -0,0 +1,68 @@
import type { TFile, TMessage } from 'librechat-data-provider';

/** Fields to strip from files before client transmission */
const FILE_STRIP_FIELDS = ['text', '_id', '__v'] as const;

/** Fields to strip from messages before client transmission */
const MESSAGE_STRIP_FIELDS = ['fileContext'] as const;

/**
 * Strips large/unnecessary fields from a file object before transmitting to client.
 * Use this within existing loops when building file arrays to avoid extra iterations.
 *
 * @param file - The file object to sanitize
 * @returns A new file object without the stripped fields
 *
 * @example
 * // Use in existing file processing loop:
 * for (const attachment of client.options.attachments) {
 *   if (messageFiles.has(attachment.file_id)) {
 *     userMessage.files.push(sanitizeFileForTransmit(attachment));
 *   }
 * }
 */
export function sanitizeFileForTransmit<T extends Partial<TFile>>(
  file: T,
): Omit<T, (typeof FILE_STRIP_FIELDS)[number]> {
  const sanitized = { ...file };
  for (const field of FILE_STRIP_FIELDS) {
    delete sanitized[field as keyof typeof sanitized];
  }
  return sanitized;
}

/**
 * Sanitizes a message object before transmitting to client.
 * Removes large fields like `fileContext` and strips `text` from embedded files.
 *
 * @param message - The message object to sanitize
 * @returns A new message object safe for client transmission
 *
 * @example
 * sendEvent(res, {
 *   final: true,
 *   requestMessage: sanitizeMessageForTransmit(userMessage),
 *   responseMessage: response,
 * });
 */
export function sanitizeMessageForTransmit<T extends Partial<TMessage>>(
  message: T,
): Omit<T, (typeof MESSAGE_STRIP_FIELDS)[number]> {
  if (!message) {
    return message as Omit<T, (typeof MESSAGE_STRIP_FIELDS)[number]>;
  }

  const sanitized = { ...message };

  // Remove message-level fields
  for (const field of MESSAGE_STRIP_FIELDS) {
    delete sanitized[field as keyof typeof sanitized];
  }

  // Always create a new array when files exist to maintain full immutability
  if (Array.isArray(sanitized.files)) {
    sanitized.files = sanitized.files.map((file) => sanitizeFileForTransmit(file));
  }

  return sanitized;
}
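A quick before/after of what these helpers drop over the wire (data shapes illustrative):

import { sanitizeMessageForTransmit } from './message';

const raw = {
  messageId: 'msg-123',
  text: 'Hello world',
  fileContext: '...megabytes of extracted document text...',
  files: [{ file_id: 'file-1', filename: 'a.txt', text: '...full file contents...', _id: 'x', __v: 0 }],
};

const wire = sanitizeMessageForTransmit(raw);
// wire = {
//   messageId: 'msg-123',
//   text: 'Hello world',
//   files: [{ file_id: 'file-1', filename: 'a.txt' }],
// }
// `raw` is untouched; the helpers only make shallow copies, so there is no deep-clone cost.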
851
packages/api/src/utils/text.spec.ts
Normal file
@@ -0,0 +1,851 @@
import { processTextWithTokenLimit, TokenCountFn } from './text';
import Tokenizer, { countTokens } from './tokenizer';

jest.mock('@librechat/data-schemas', () => ({
  logger: {
    debug: jest.fn(),
    warn: jest.fn(),
    error: jest.fn(),
  },
}));

/**
 * OLD IMPLEMENTATION (Binary Search) - kept for comparison testing
 * This is the original algorithm that caused CPU spikes
 */
async function processTextWithTokenLimitOLD({
  text,
  tokenLimit,
  tokenCountFn,
}: {
  text: string;
  tokenLimit: number;
  tokenCountFn: TokenCountFn;
}): Promise<{ text: string; tokenCount: number; wasTruncated: boolean }> {
  const originalTokenCount = await tokenCountFn(text);

  if (originalTokenCount <= tokenLimit) {
    return {
      text,
      tokenCount: originalTokenCount,
      wasTruncated: false,
    };
  }

  let low = 0;
  let high = text.length;
  let bestText = '';

  while (low <= high) {
    const mid = Math.floor((low + high) / 2);
    const truncatedText = text.substring(0, mid);
    const tokenCount = await tokenCountFn(truncatedText);

    if (tokenCount <= tokenLimit) {
      bestText = truncatedText;
      low = mid + 1;
    } else {
      high = mid - 1;
    }
  }

  const finalTokenCount = await tokenCountFn(bestText);

  return {
    text: bestText,
    tokenCount: finalTokenCount,
    wasTruncated: true,
  };
}

/**
 * Creates a wrapper around Tokenizer.getTokenCount that tracks call count
 */
const createRealTokenCounter = () => {
  let callCount = 0;
  const tokenCountFn = (text: string): number => {
    callCount++;
    return Tokenizer.getTokenCount(text, 'cl100k_base');
  };
  return {
    tokenCountFn,
    getCallCount: () => callCount,
    resetCallCount: () => {
      callCount = 0;
    },
  };
};

/**
 * Creates a wrapper around the async countTokens function that tracks call count
 */
const createCountTokensCounter = () => {
  let callCount = 0;
  const tokenCountFn = async (text: string): Promise<number> => {
    callCount++;
    return countTokens(text);
  };
  return {
    tokenCountFn,
    getCallCount: () => callCount,
    resetCallCount: () => {
      callCount = 0;
    },
  };
};

describe('processTextWithTokenLimit', () => {
  /**
   * Creates a mock token count function that simulates realistic token counting.
   * Roughly 4 characters per token (common for English text).
   * Tracks call count to verify efficiency.
   */
  const createMockTokenCounter = () => {
    let callCount = 0;
    const tokenCountFn = (text: string): number => {
      callCount++;
      return Math.ceil(text.length / 4);
    };
    return {
      tokenCountFn,
      getCallCount: () => callCount,
      resetCallCount: () => {
        callCount = 0;
      },
    };
  };

  /** Creates a string of specified character length */
  const createTextOfLength = (charLength: number): string => {
    return 'a'.repeat(charLength);
  };

  /** Creates realistic text content with varied token density */
  const createRealisticText = (approximateTokens: number): string => {
    const words = [
      'the',
      'quick',
      'brown',
      'fox',
      'jumps',
      'over',
      'lazy',
      'dog',
      'lorem',
      'ipsum',
      'dolor',
      'sit',
      'amet',
      'consectetur',
      'adipiscing',
      'elit',
      'sed',
      'do',
      'eiusmod',
      'tempor',
      'incididunt',
      'ut',
      'labore',
      'et',
      'dolore',
      'magna',
      'aliqua',
      'enim',
      'ad',
      'minim',
      'veniam',
      'authentication',
      'implementation',
      'configuration',
      'documentation',
    ];
    const result: string[] = [];
    for (let i = 0; i < approximateTokens; i++) {
      result.push(words[i % words.length]);
    }
    return result.join(' ');
  };

  describe('tokenCountFn flexibility (sync and async)', () => {
    it('should work with synchronous tokenCountFn', async () => {
      const syncTokenCountFn = (text: string): number => Math.ceil(text.length / 4);
      const text = 'Hello, world! This is a test message.';
      const tokenLimit = 5;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: syncTokenCountFn,
      });

      expect(result.wasTruncated).toBe(true);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
    });

    it('should work with asynchronous tokenCountFn', async () => {
      const asyncTokenCountFn = async (text: string): Promise<number> => {
        await new Promise((resolve) => setTimeout(resolve, 1));
        return Math.ceil(text.length / 4);
      };
      const text = 'Hello, world! This is a test message.';
      const tokenLimit = 5;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: asyncTokenCountFn,
      });

      expect(result.wasTruncated).toBe(true);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
    });

    it('should produce equivalent results with sync and async tokenCountFn', async () => {
      const syncTokenCountFn = (text: string): number => Math.ceil(text.length / 4);
      const asyncTokenCountFn = async (text: string): Promise<number> => Math.ceil(text.length / 4);
      const text = 'a'.repeat(8000);
      const tokenLimit = 1000;

      const syncResult = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: syncTokenCountFn,
      });

      const asyncResult = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: asyncTokenCountFn,
      });

      expect(syncResult.tokenCount).toBe(asyncResult.tokenCount);
      expect(syncResult.wasTruncated).toBe(asyncResult.wasTruncated);
      expect(syncResult.text.length).toBe(asyncResult.text.length);
    });
  });

  describe('when text is under the token limit', () => {
    it('should return original text unchanged', async () => {
      const { tokenCountFn } = createMockTokenCounter();
      const text = 'Hello, world!';
      const tokenLimit = 100;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(result.text).toBe(text);
      expect(result.wasTruncated).toBe(false);
    });

    it('should return correct token count', async () => {
      const { tokenCountFn } = createMockTokenCounter();
      const text = 'Hello, world!';
      const tokenLimit = 100;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(result.tokenCount).toBe(Math.ceil(text.length / 4));
    });

    it('should only call tokenCountFn once when under limit', async () => {
      const { tokenCountFn, getCallCount } = createMockTokenCounter();
      const text = 'Hello, world!';
      const tokenLimit = 100;

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(getCallCount()).toBe(1);
    });
  });

  describe('when text is exactly at the token limit', () => {
    it('should return original text unchanged', async () => {
      const { tokenCountFn } = createMockTokenCounter();
      const text = createTextOfLength(400);
      const tokenLimit = 100;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(result.text).toBe(text);
      expect(result.wasTruncated).toBe(false);
      expect(result.tokenCount).toBe(tokenLimit);
    });
  });

  describe('when text exceeds the token limit', () => {
    it('should truncate text to fit within limit', async () => {
      const { tokenCountFn } = createMockTokenCounter();
      const text = createTextOfLength(8000);
      const tokenLimit = 1000;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(result.wasTruncated).toBe(true);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
      expect(result.text.length).toBeLessThan(text.length);
    });

    it('should truncate text to be close to but not exceed the limit', async () => {
      const { tokenCountFn } = createMockTokenCounter();
      const text = createTextOfLength(8000);
      const tokenLimit = 1000;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
      expect(result.tokenCount).toBeGreaterThan(tokenLimit * 0.9);
    });
  });

  describe('efficiency - tokenCountFn call count', () => {
    it('should call tokenCountFn at most 7 times for large text (vs ~17 for binary search)', async () => {
      const { tokenCountFn, getCallCount } = createMockTokenCounter();
      const text = createTextOfLength(400000);
      const tokenLimit = 50000;

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(getCallCount()).toBeLessThanOrEqual(7);
    });

    it('should typically call tokenCountFn only 2-3 times for standard truncation', async () => {
      const { tokenCountFn, getCallCount } = createMockTokenCounter();
      const text = createTextOfLength(40000);
      const tokenLimit = 5000;

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(getCallCount()).toBeLessThanOrEqual(3);
    });

    it('should call tokenCountFn only once when text is under limit', async () => {
      const { tokenCountFn, getCallCount } = createMockTokenCounter();
      const text = createTextOfLength(1000);
      const tokenLimit = 10000;

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(getCallCount()).toBe(1);
    });

    it('should handle very large text (100k+ tokens) efficiently', async () => {
      const { tokenCountFn, getCallCount } = createMockTokenCounter();
      const text = createTextOfLength(500000);
      const tokenLimit = 100000;

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(getCallCount()).toBeLessThanOrEqual(7);
    });
  });

  describe('edge cases', () => {
    it('should handle empty text', async () => {
      const { tokenCountFn } = createMockTokenCounter();
      const text = '';
      const tokenLimit = 100;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(result.text).toBe('');
      expect(result.tokenCount).toBe(0);
      expect(result.wasTruncated).toBe(false);
    });

    it('should handle token limit of 1', async () => {
      const { tokenCountFn } = createMockTokenCounter();
      const text = createTextOfLength(1000);
      const tokenLimit = 1;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(result.wasTruncated).toBe(true);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
    });

    it('should handle text that is just slightly over the limit', async () => {
      const { tokenCountFn } = createMockTokenCounter();
      const text = createTextOfLength(404);
      const tokenLimit = 100;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn,
      });

      expect(result.wasTruncated).toBe(true);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
    });
  });

  describe('correctness with variable token density', () => {
    it('should handle text with varying token density', async () => {
      const variableDensityTokenCounter = (text: string): number => {
        const shortWords = (text.match(/\s+/g) || []).length;
        return Math.ceil(text.length / 4) + shortWords;
      };

      const text = 'This is a test with many short words and some longer concatenated words too';
      const tokenLimit = 10;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: variableDensityTokenCounter,
      });

      expect(result.wasTruncated).toBe(true);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
    });
  });

  describe('direct comparison with OLD binary search implementation', () => {
    it('should produce equivalent results to the old implementation', async () => {
      const oldCounter = createMockTokenCounter();
      const newCounter = createMockTokenCounter();
      const text = createTextOfLength(8000);
      const tokenLimit = 1000;

      const oldResult = await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });

      const newResult = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });

      expect(newResult.wasTruncated).toBe(oldResult.wasTruncated);
      expect(newResult.tokenCount).toBeLessThanOrEqual(tokenLimit);
      expect(oldResult.tokenCount).toBeLessThanOrEqual(tokenLimit);
    });

    it('should use significantly fewer tokenCountFn calls than old implementation (400k chars)', async () => {
      const oldCounter = createMockTokenCounter();
      const newCounter = createMockTokenCounter();
      const text = createTextOfLength(400000);
      const tokenLimit = 50000;

      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();

      console.log(
        `[400k chars] OLD implementation: ${oldCalls} calls, NEW implementation: ${newCalls} calls`,
      );
      console.log(`[400k chars] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);

      expect(newCalls).toBeLessThan(oldCalls);
      expect(newCalls).toBeLessThanOrEqual(7);
    });

    it('should use significantly fewer tokenCountFn calls than old implementation (500k chars, 100k token limit)', async () => {
      const oldCounter = createMockTokenCounter();
      const newCounter = createMockTokenCounter();
      const text = createTextOfLength(500000);
      const tokenLimit = 100000;

      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();

      console.log(
        `[500k chars] OLD implementation: ${oldCalls} calls, NEW implementation: ${newCalls} calls`,
      );
      console.log(`[500k chars] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);

      expect(newCalls).toBeLessThan(oldCalls);
    });

    it('should achieve at least 70% reduction in tokenCountFn calls', async () => {
      const oldCounter = createMockTokenCounter();
      const newCounter = createMockTokenCounter();
      const text = createTextOfLength(500000);
      const tokenLimit = 100000;

      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();
      const reduction = 1 - newCalls / oldCalls;

      console.log(
        `Efficiency improvement: ${(reduction * 100).toFixed(1)}% fewer tokenCountFn calls`,
      );

      expect(reduction).toBeGreaterThanOrEqual(0.7);
    });

    it('should simulate the reported scenario (122k tokens, 100k limit)', async () => {
      const oldCounter = createMockTokenCounter();
      const newCounter = createMockTokenCounter();
      const text = createTextOfLength(489564);
      const tokenLimit = 100000;

      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();

      console.log(`[User reported scenario: ~122k tokens]`);
      console.log(`OLD implementation: ${oldCalls} tokenCountFn calls`);
      console.log(`NEW implementation: ${newCalls} tokenCountFn calls`);
      console.log(`Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);

      expect(newCalls).toBeLessThan(oldCalls);
      expect(newCalls).toBeLessThanOrEqual(7);
    });
  });

  describe('direct comparison with REAL tiktoken tokenizer', () => {
    beforeEach(() => {
      Tokenizer.freeAndResetAllEncoders();
    });

    it('should produce valid truncation with real tokenizer', async () => {
      const counter = createRealTokenCounter();
      const text = createRealisticText(5000);
      const tokenLimit = 1000;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: counter.tokenCountFn,
      });

      expect(result.wasTruncated).toBe(true);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
      expect(result.text.length).toBeLessThan(text.length);
    });

    it('should use fewer tiktoken calls than old implementation (realistic text)', async () => {
      const oldCounter = createRealTokenCounter();
      const newCounter = createRealTokenCounter();
      const text = createRealisticText(15000);
      const tokenLimit = 5000;

      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });

      Tokenizer.freeAndResetAllEncoders();

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();

      console.log(`[Real tiktoken ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`);
      console.log(`[Real tiktoken] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);

      expect(newCalls).toBeLessThan(oldCalls);
    });

    it('should handle the reported user scenario with real tokenizer (~120k tokens)', async () => {
      const oldCounter = createRealTokenCounter();
      const newCounter = createRealTokenCounter();
      const text = createRealisticText(120000);
      const tokenLimit = 100000;

      const startOld = performance.now();
      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });
      const timeOld = performance.now() - startOld;

      Tokenizer.freeAndResetAllEncoders();

      const startNew = performance.now();
      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });
      const timeNew = performance.now() - startNew;

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();

      console.log(`\n[REAL TIKTOKEN - User reported scenario: ~120k tokens]`);
      console.log(`OLD implementation: ${oldCalls} tiktoken calls, ${timeOld.toFixed(0)}ms`);
      console.log(`NEW implementation: ${newCalls} tiktoken calls, ${timeNew.toFixed(0)}ms`);
      console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
      console.log(`Time reduction: ${((1 - timeNew / timeOld) * 100).toFixed(1)}%`);
      console.log(
        `Result: truncated=${result.wasTruncated}, tokens=${result.tokenCount}/${tokenLimit}\n`,
      );

      expect(newCalls).toBeLessThan(oldCalls);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
      expect(newCalls).toBeLessThanOrEqual(7);
    });

    it('should achieve at least 70% reduction with real tokenizer', async () => {
      const oldCounter = createRealTokenCounter();
      const newCounter = createRealTokenCounter();
      const text = createRealisticText(50000);
      const tokenLimit = 10000;

      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });

      Tokenizer.freeAndResetAllEncoders();

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();
      const reduction = 1 - newCalls / oldCalls;

      console.log(
        `[Real tiktoken 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
      );

      expect(reduction).toBeGreaterThanOrEqual(0.7);
    });
  });

  describe('using countTokens async function from @librechat/api', () => {
    beforeEach(() => {
      Tokenizer.freeAndResetAllEncoders();
    });

    it('countTokens should return correct token count', async () => {
      const text = 'Hello, world!';
      const count = await countTokens(text);

      expect(count).toBeGreaterThan(0);
      expect(typeof count).toBe('number');
    });

    it('countTokens should handle empty string', async () => {
      const count = await countTokens('');
      expect(count).toBe(0);
    });

    it('should work with processTextWithTokenLimit using countTokens', async () => {
      const counter = createCountTokensCounter();
      const text = createRealisticText(5000);
      const tokenLimit = 1000;

      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: counter.tokenCountFn,
      });

      expect(result.wasTruncated).toBe(true);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
      expect(result.text.length).toBeLessThan(text.length);
    });

    it('should use fewer countTokens calls than old implementation', async () => {
      const oldCounter = createCountTokensCounter();
      const newCounter = createCountTokensCounter();
      const text = createRealisticText(15000);
      const tokenLimit = 5000;

      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });

      Tokenizer.freeAndResetAllEncoders();

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();

      console.log(`[countTokens ~15k tokens] OLD: ${oldCalls} calls, NEW: ${newCalls} calls`);
      console.log(`[countTokens] Reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);

      expect(newCalls).toBeLessThan(oldCalls);
    });

    it('should handle user reported scenario with countTokens (~120k tokens)', async () => {
      const oldCounter = createCountTokensCounter();
      const newCounter = createCountTokensCounter();
      const text = createRealisticText(120000);
      const tokenLimit = 100000;

      const startOld = performance.now();
      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });
      const timeOld = performance.now() - startOld;

      Tokenizer.freeAndResetAllEncoders();

      const startNew = performance.now();
      const result = await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });
      const timeNew = performance.now() - startNew;

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();

      console.log(`\n[countTokens - User reported scenario: ~120k tokens]`);
      console.log(`OLD implementation: ${oldCalls} countTokens calls, ${timeOld.toFixed(0)}ms`);
      console.log(`NEW implementation: ${newCalls} countTokens calls, ${timeNew.toFixed(0)}ms`);
      console.log(`Call reduction: ${((1 - newCalls / oldCalls) * 100).toFixed(1)}%`);
      console.log(`Time reduction: ${((1 - timeNew / timeOld) * 100).toFixed(1)}%`);
      console.log(
        `Result: truncated=${result.wasTruncated}, tokens=${result.tokenCount}/${tokenLimit}\n`,
      );

      expect(newCalls).toBeLessThan(oldCalls);
      expect(result.tokenCount).toBeLessThanOrEqual(tokenLimit);
      expect(newCalls).toBeLessThanOrEqual(7);
    });

    it('should achieve at least 70% reduction with countTokens', async () => {
      const oldCounter = createCountTokensCounter();
      const newCounter = createCountTokensCounter();
      const text = createRealisticText(50000);
      const tokenLimit = 10000;

      await processTextWithTokenLimitOLD({
        text,
        tokenLimit,
        tokenCountFn: oldCounter.tokenCountFn,
      });

      Tokenizer.freeAndResetAllEncoders();

      await processTextWithTokenLimit({
        text,
        tokenLimit,
        tokenCountFn: newCounter.tokenCountFn,
      });

      const oldCalls = oldCounter.getCallCount();
      const newCalls = newCounter.getCallCount();
      const reduction = 1 - newCalls / oldCalls;

      console.log(
        `[countTokens 50k tokens] OLD: ${oldCalls}, NEW: ${newCalls}, Reduction: ${(reduction * 100).toFixed(1)}%`,
      );

      expect(reduction).toBeGreaterThanOrEqual(0.7);
    });
  });
});

@@ -1,11 +1,39 @@
import { logger } from '@librechat/data-schemas';

/** Token count function that can be sync or async */
export type TokenCountFn = (text: string) => number | Promise<number>;

/**
* Safety buffer multiplier applied to character position estimates during truncation.
*
* We use 98% (0.98) rather than 100% to intentionally undershoot the target on the first attempt.
* This is necessary because:
* - Token density varies across text (some regions may have more tokens per character than the average)
* - The ratio-based estimate assumes uniform token distribution, which is rarely true
* - Undershooting is safer than overshooting: exceeding the limit requires another iteration,
* while being slightly under is acceptable
* - In practice, this buffer reduces refinement iterations from 2-3 down to 0-1 in most cases
*
* @example
* // If text has 1000 chars and 250 tokens (4 chars/token average), targeting 100 tokens:
* // Without buffer: estimate = 1000 * (100/250) = 400 chars → might yield 105 tokens (over!)
* // With 0.98 buffer: estimate = 400 * 0.98 = 392 chars → likely yields 97-99 tokens (safe)
*/
const TRUNCATION_SAFETY_BUFFER = 0.98;

/**
* Processes text content by counting tokens and truncating if it exceeds the specified limit.
* Uses ratio-based estimation to minimize expensive tokenCountFn calls.
*
* @param text - The text content to process
* @param tokenLimit - The maximum number of tokens allowed
* @param tokenCountFn - Function to count tokens
* @param tokenCountFn - Function to count tokens (can be sync or async)
* @returns Promise resolving to object with processed text, token count, and truncation status
*
* @remarks
* This function uses a ratio-based estimation algorithm instead of binary search.
* Binary search would require O(log n) tokenCountFn calls (~17 for 100k chars),
* while this approach typically requires only 2-3 calls for a 90%+ reduction in CPU usage.
*/
export async function processTextWithTokenLimit({
text,
@@ -14,7 +42,7 @@ export async function processTextWithTokenLimit({
}: {
text: string;
tokenLimit: number;
tokenCountFn: (text: string) => number;
tokenCountFn: TokenCountFn;
}): Promise<{ text: string; tokenCount: number; wasTruncated: boolean }> {
const originalTokenCount = await tokenCountFn(text);

@@ -26,40 +54,34 @@ export async function processTextWithTokenLimit({
};
}

/**
* Doing binary search here to find the truncation point efficiently
* (May be a better way to go about this)
*/
let low = 0;
let high = text.length;
let bestText = '';

logger.debug(
`[textTokenLimiter] Text content exceeds token limit: ${originalTokenCount} > ${tokenLimit}, truncating...`,
);

while (low <= high) {
const mid = Math.floor((low + high) / 2);
const truncatedText = text.substring(0, mid);
const tokenCount = await tokenCountFn(truncatedText);
const ratio = tokenLimit / originalTokenCount;
let charPosition = Math.floor(text.length * ratio * TRUNCATION_SAFETY_BUFFER);

if (tokenCount <= tokenLimit) {
bestText = truncatedText;
low = mid + 1;
} else {
high = mid - 1;
}
let truncatedText = text.substring(0, charPosition);
let tokenCount = await tokenCountFn(truncatedText);

const maxIterations = 5;
let iterations = 0;

while (tokenCount > tokenLimit && iterations < maxIterations && charPosition > 0) {
const overageRatio = tokenLimit / tokenCount;
charPosition = Math.floor(charPosition * overageRatio * TRUNCATION_SAFETY_BUFFER);
truncatedText = text.substring(0, charPosition);
tokenCount = await tokenCountFn(truncatedText);
iterations++;
}

const finalTokenCount = await tokenCountFn(bestText);

logger.warn(
`[textTokenLimiter] Text truncated from ${originalTokenCount} to ${finalTokenCount} tokens (limit: ${tokenLimit})`,
`[textTokenLimiter] Text truncated from ${originalTokenCount} to ${tokenCount} tokens (limit: ${tokenLimit})`,
);

return {
text: bestText,
tokenCount: finalTokenCount,
text: truncatedText,
tokenCount,
wasTruncated: true,
};
}
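
To make the call-count claims above concrete, here is a minimal standalone sketch of the same estimate-then-refine loop, using a hypothetical 4-characters-per-token counter (the stub, names, and figures are illustrative, not part of the PR):

// Sketch: ratio-based truncation converging in ~2 counting calls instead of ~17 (binary search).
const SAFETY_BUFFER = 0.98;

async function truncateByRatio(
text: string,
tokenLimit: number,
tokenCountFn: (t: string) => number | Promise<number>,
): Promise<{ text: string; tokenCount: number; calls: number }> {
let calls = 0;
const count = async (t: string) => (calls++, tokenCountFn(t));

const total = await count(text);
if (total <= tokenLimit) {
return { text, tokenCount: total, calls };
}

// First guess: scale character length by the token ratio, undershooting via the buffer.
let pos = Math.floor(text.length * (tokenLimit / total) * SAFETY_BUFFER);
let slice = text.slice(0, pos);
let tokens = await count(slice);

// Refine only while still over the limit (rarely more than one extra pass).
for (let i = 0; tokens > tokenLimit && i < 5 && pos > 0; i++) {
pos = Math.floor(pos * (tokenLimit / tokens) * SAFETY_BUFFER);
slice = text.slice(0, pos);
tokens = await count(slice);
}
return { text: slice, tokenCount: tokens, calls };
}

// With 480k chars (~120k tokens) and a 100k limit, this finishes in 2 counting calls:
truncateByRatio('a'.repeat(480_000), 100_000, (t) => Math.ceil(t.length / 4)).then(
({ tokenCount, calls }) => console.log(tokenCount, calls), // 98000 2
);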

@@ -75,4 +75,14 @@ class Tokenizer {

const TokenizerSingleton = new Tokenizer();

/**
* Counts the number of tokens in a given text using tiktoken.
* This is an async wrapper around Tokenizer.getTokenCount for compatibility.
* @param text - The text to be tokenized. Defaults to an empty string if not provided.
* @returns The number of tokens in the provided text.
*/
export async function countTokens(text = ''): Promise<number> {
return TokenizerSingleton.getTokenCount(text, 'cl100k_base');
}

export default TokenizerSingleton;
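
A quick usage sketch for the new wrapper (exact counts depend on the cl100k_base encoding; the import path follows the test suite above):

import { countTokens } from '@librechat/api';

(async () => {
// Async wrapper around the tiktoken singleton; the argument defaults to '' (0 tokens).
console.log(await countTokens('Hello, world!')); // some count > 0
console.log(await countTokens('')); // 0
})();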

@@ -140,6 +140,7 @@ const anthropicModels = {

const deepseekModels = {
deepseek: 128000,
'deepseek-chat': 128000,
'deepseek-reasoner': 128000,
'deepseek-r1': 128000,
'deepseek-v3': 128000,
@@ -280,6 +281,9 @@ const xAIModels = {
'grok-3-mini': 131072,
'grok-3-mini-fast': 131072,
'grok-4': 256000, // 256K context
'grok-4-fast': 2000000, // 2M context
'grok-4-1-fast': 2000000, // 2M context (covers reasoning & non-reasoning variants)
'grok-code-fast': 256000, // 256K context
};

const aggregateModels = {
@@ -344,11 +348,21 @@ const anthropicMaxOutputs = {
'claude-3-7-sonnet': 128000,
};

/** Outputs from https://api-docs.deepseek.com/quick_start/pricing */
const deepseekMaxOutputs = {
deepseek: 8000, // deepseek-chat default: 4K, max: 8K
'deepseek-chat': 8000,
'deepseek-reasoner': 64000, // default: 32K, max: 64K
'deepseek-r1': 64000,
'deepseek-v3': 8000,
'deepseek.r1': 64000,
};

export const maxOutputTokensMap = {
[EModelEndpoint.anthropic]: anthropicMaxOutputs,
[EModelEndpoint.azureOpenAI]: modelMaxOutputs,
[EModelEndpoint.openAI]: modelMaxOutputs,
[EModelEndpoint.custom]: modelMaxOutputs,
[EModelEndpoint.openAI]: { ...modelMaxOutputs, ...deepseekMaxOutputs },
[EModelEndpoint.custom]: { ...modelMaxOutputs, ...deepseekMaxOutputs },
};

/**

@@ -51,7 +51,7 @@
"@tanstack/react-virtual": "^3.0.0",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"dompurify": "^3.2.6",
"dompurify": "^3.3.0",
"framer-motion": "^12.23.6",
"i18next": "^24.2.2 || ^25.3.2",
"i18next-browser-languagedetector": "^8.2.0",

@@ -1,6 +1,7 @@
import { replaceSpecialVars } from '../src/parsers';
import { replaceSpecialVars, parseCompactConvo } from '../src/parsers';
import { specialVariables } from '../src/config';
import type { TUser } from '../src/types';
import { EModelEndpoint } from '../src/schemas';
import type { TUser, TConversation } from '../src/types';

// Mock dayjs module with consistent date/time values regardless of environment
jest.mock('dayjs', () => {
@@ -123,3 +124,138 @@ describe('replaceSpecialVars', () => {
expect(result).toContain('Test User'); // current_user
});
});

describe('parseCompactConvo', () => {
describe('iconURL security sanitization', () => {
test('should strip iconURL from OpenAI endpoint conversation input', () => {
const maliciousIconURL = 'https://evil-tracker.example.com/pixel.png?user=victim';
const conversation: Partial<TConversation> = {
model: 'gpt-4',
iconURL: maliciousIconURL,
endpoint: EModelEndpoint.openAI,
};

const result = parseCompactConvo({
endpoint: EModelEndpoint.openAI,
conversation,
});

expect(result).not.toBeNull();
expect(result?.iconURL).toBeUndefined();
expect(result?.model).toBe('gpt-4');
});

test('should strip iconURL from agents endpoint conversation input', () => {
const maliciousIconURL = 'https://evil-tracker.example.com/pixel.png';
const conversation: Partial<TConversation> = {
agent_id: 'agent_123',
iconURL: maliciousIconURL,
endpoint: EModelEndpoint.agents,
};

const result = parseCompactConvo({
endpoint: EModelEndpoint.agents,
conversation,
});

expect(result).not.toBeNull();
expect(result?.iconURL).toBeUndefined();
expect(result?.agent_id).toBe('agent_123');
});

test('should strip iconURL from anthropic endpoint conversation input', () => {
const maliciousIconURL = 'https://tracker.malicious.com/beacon.gif';
const conversation: Partial<TConversation> = {
model: 'claude-3-opus',
iconURL: maliciousIconURL,
endpoint: EModelEndpoint.anthropic,
};

const result = parseCompactConvo({
endpoint: EModelEndpoint.anthropic,
conversation,
});

expect(result).not.toBeNull();
expect(result?.iconURL).toBeUndefined();
expect(result?.model).toBe('claude-3-opus');
});

test('should strip iconURL from google endpoint conversation input', () => {
const maliciousIconURL = 'https://tracking.example.com/spy.png';
const conversation: Partial<TConversation> = {
model: 'gemini-pro',
iconURL: maliciousIconURL,
endpoint: EModelEndpoint.google,
};

const result = parseCompactConvo({
endpoint: EModelEndpoint.google,
conversation,
});

expect(result).not.toBeNull();
expect(result?.iconURL).toBeUndefined();
expect(result?.model).toBe('gemini-pro');
});

test('should strip iconURL from assistants endpoint conversation input', () => {
const maliciousIconURL = 'https://evil.com/track.png';
const conversation: Partial<TConversation> = {
assistant_id: 'asst_123',
iconURL: maliciousIconURL,
endpoint: EModelEndpoint.assistants,
};

const result = parseCompactConvo({
endpoint: EModelEndpoint.assistants,
conversation,
});

expect(result).not.toBeNull();
expect(result?.iconURL).toBeUndefined();
expect(result?.assistant_id).toBe('asst_123');
});

test('should preserve other conversation properties while stripping iconURL', () => {
const conversation: Partial<TConversation> = {
model: 'gpt-4',
iconURL: 'https://malicious.com/track.png',
endpoint: EModelEndpoint.openAI,
temperature: 0.7,
top_p: 0.9,
promptPrefix: 'You are a helpful assistant.',
maxContextTokens: 4000,
};

const result = parseCompactConvo({
endpoint: EModelEndpoint.openAI,
conversation,
});

expect(result).not.toBeNull();
expect(result?.iconURL).toBeUndefined();
expect(result?.model).toBe('gpt-4');
expect(result?.temperature).toBe(0.7);
expect(result?.top_p).toBe(0.9);
expect(result?.promptPrefix).toBe('You are a helpful assistant.');
expect(result?.maxContextTokens).toBe(4000);
});

test('should handle conversation without iconURL (no error)', () => {
const conversation: Partial<TConversation> = {
model: 'gpt-4',
endpoint: EModelEndpoint.openAI,
};

const result = parseCompactConvo({
endpoint: EModelEndpoint.openAI,
conversation,
});

expect(result).not.toBeNull();
expect(result?.iconURL).toBeUndefined();
expect(result?.model).toBe('gpt-4');
});
});
});

@@ -1133,6 +1133,7 @@ export const supportsBalanceCheck = {
[EModelEndpoint.azureAssistants]: true,
[EModelEndpoint.azureOpenAI]: true,
[EModelEndpoint.bedrock]: true,
[EModelEndpoint.google]: true,
};

export const visionModels = [

@@ -200,6 +200,27 @@ export const codeTypeMapping: { [key: string]: string } = {
tsv: 'text/tab-separated-values',
};

/** Maps image extensions to MIME types for formats browsers may not recognize */
export const imageTypeMapping: { [key: string]: string } = {
heic: 'image/heic',
heif: 'image/heif',
};

/**
* Infers the MIME type from a file's extension when the browser doesn't recognize it
* @param fileName - The name of the file including extension
* @param currentType - The current MIME type reported by the browser (may be empty)
* @returns The inferred MIME type if browser didn't provide one, otherwise the original type
*/
export function inferMimeType(fileName: string, currentType: string): string {
if (currentType) {
return currentType;
}

const extension = fileName.split('.').pop()?.toLowerCase() ?? '';
return codeTypeMapping[extension] || imageTypeMapping[extension] || currentType;
}
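
For instance, given the mappings above, the fallback behaves like this (file names are illustrative):

// Browser gave no type for a HEIC photo: the extension map fills it in.
inferMimeType('photo.HEIC', ''); // 'image/heic'
// Code-type extensions are consulted first, e.g. TSV:
inferMimeType('data.tsv', ''); // 'text/tab-separated-values'
// A browser-provided type always wins:
inferMimeType('photo.heic', 'image/jpeg'); // 'image/jpeg'
// Unknown extension and no browser type: the empty string falls through.
inferMimeType('archive.xyz', ''); // ''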

export const retrievalMimeTypes = [
/^(text\/(x-c|x-c\+\+|x-h|html|x-java|markdown|x-php|x-python|x-script\.python|x-ruby|x-tex|plain|vtt|xml))$/,
/^(application\/(json|pdf|vnd\.openxmlformats-officedocument\.(wordprocessingml\.document|presentationml\.presentation)))$/,

@@ -326,7 +326,7 @@ export const parseCompactConvo = ({
possibleValues?: TPossibleValues;
// TODO: POC for default schema
// defaultSchema?: Partial<EndpointSchema>,
}) => {
}): Omit<s.TConversation, 'iconURL'> | null => {
if (!endpoint) {
throw new Error(`undefined endpoint: ${endpoint}`);
}
@@ -343,7 +343,11 @@ export const parseCompactConvo = ({
throw new Error(`Unknown endpointType: ${endpointType}`);
}

const convo = schema.parse(conversation) as s.TConversation | null;
// Strip iconURL from input before parsing - it should only be derived server-side
// from model spec configuration, not accepted from client requests
const { iconURL: _clientIconURL, ...conversationWithoutIconURL } = conversation;

const convo = schema.parse(conversationWithoutIconURL) as s.TConversation | null;
// const { models, secondaryModels } = possibleValues ?? {};
const { models } = possibleValues ?? {};

@@ -41,7 +41,6 @@ export enum Providers {
BEDROCK = 'bedrock',
MISTRALAI = 'mistralai',
MISTRAL = 'mistral',
OLLAMA = 'ollama',
DEEPSEEK = 'deepseek',
OPENROUTER = 'openrouter',
XAI = 'xai',
@@ -59,7 +58,6 @@ export const documentSupportedProviders = new Set<string>([
Providers.VERTEXAI,
Providers.MISTRALAI,
Providers.MISTRAL,
Providers.OLLAMA,
Providers.DEEPSEEK,
Providers.OPENROUTER,
Providers.XAI,
@@ -71,7 +69,6 @@ const openAILikeProviders = new Set<string>([
EModelEndpoint.custom,
Providers.MISTRALAI,
Providers.MISTRAL,
Providers.OLLAMA,
Providers.DEEPSEEK,
Providers.OPENROUTER,
Providers.XAI,

@@ -418,6 +418,41 @@ describe('Token Methods - Detailed Tests', () => {

expect(updated).toBeNull();
});

test('should update expiresAt when expiresIn is provided', async () => {
const beforeUpdate = Date.now();
const newExpiresIn = 7200;

const updated = await methods.updateToken(
{ token: 'update-token' },
{ expiresIn: newExpiresIn },
);

const afterUpdate = Date.now();

expect(updated).toBeDefined();
expect(updated?.expiresAt).toBeDefined();

const expectedMinExpiry = beforeUpdate + newExpiresIn * 1000;
const expectedMaxExpiry = afterUpdate + newExpiresIn * 1000;

expect(updated!.expiresAt.getTime()).toBeGreaterThanOrEqual(expectedMinExpiry);
expect(updated!.expiresAt.getTime()).toBeLessThanOrEqual(expectedMaxExpiry);
});

test('should not modify expiresAt when expiresIn is not provided', async () => {
const original = await Token.findOne({ token: 'update-token' });
const originalExpiresAt = original!.expiresAt.getTime();

const updated = await methods.updateToken(
{ token: 'update-token' },
{ email: 'changed@example.com' },
);

expect(updated).toBeDefined();
expect(updated?.email).toBe('changed@example.com');
expect(updated!.expiresAt.getTime()).toBe(originalExpiresAt);
});
});

describe('deleteTokens', () => {
@@ -617,4 +652,171 @@ describe('Token Methods - Detailed Tests', () => {
expect(remainingTokens.find((t) => t.token === 'email-verify-token-2')).toBeUndefined();
});
});

describe('Email Normalization', () => {
let normUserId: mongoose.Types.ObjectId;

beforeEach(async () => {
normUserId = new mongoose.Types.ObjectId();

// Create token with lowercase email (as stored in DB)
await Token.create({
token: 'norm-token-1',
userId: normUserId,
email: 'john.doe@example.com',
createdAt: new Date(),
expiresAt: new Date(Date.now() + 3600000),
});
});

describe('findToken email normalization', () => {
test('should find token by email with different case (case-insensitive)', async () => {
const foundUpper = await methods.findToken({ email: 'JOHN.DOE@EXAMPLE.COM' });
const foundMixed = await methods.findToken({ email: 'John.Doe@Example.COM' });
const foundLower = await methods.findToken({ email: 'john.doe@example.com' });

expect(foundUpper).toBeDefined();
expect(foundUpper?.token).toBe('norm-token-1');

expect(foundMixed).toBeDefined();
expect(foundMixed?.token).toBe('norm-token-1');

expect(foundLower).toBeDefined();
expect(foundLower?.token).toBe('norm-token-1');
});

test('should find token by email with leading/trailing whitespace', async () => {
const foundWithSpaces = await methods.findToken({ email: ' john.doe@example.com ' });
const foundWithTabs = await methods.findToken({ email: '\tjohn.doe@example.com\t' });

expect(foundWithSpaces).toBeDefined();
expect(foundWithSpaces?.token).toBe('norm-token-1');

expect(foundWithTabs).toBeDefined();
expect(foundWithTabs?.token).toBe('norm-token-1');
});

test('should find token by email with both case difference and whitespace', async () => {
const found = await methods.findToken({ email: ' JOHN.DOE@EXAMPLE.COM ' });

expect(found).toBeDefined();
expect(found?.token).toBe('norm-token-1');
});

test('should find token with combined email and other criteria', async () => {
const found = await methods.findToken({
userId: normUserId.toString(),
email: 'John.Doe@Example.COM',
});

expect(found).toBeDefined();
expect(found?.token).toBe('norm-token-1');
});
});

describe('deleteTokens email normalization', () => {
test('should delete token by email with different case', async () => {
const result = await methods.deleteTokens({ email: 'JOHN.DOE@EXAMPLE.COM' });

expect(result.deletedCount).toBe(1);

const remaining = await Token.find({});
expect(remaining).toHaveLength(0);
});

test('should delete token by email with whitespace', async () => {
const result = await methods.deleteTokens({ email: ' john.doe@example.com ' });

expect(result.deletedCount).toBe(1);

const remaining = await Token.find({});
expect(remaining).toHaveLength(0);
});

test('should delete token by email with case and whitespace combined', async () => {
const result = await methods.deleteTokens({ email: ' John.Doe@EXAMPLE.COM ' });

expect(result.deletedCount).toBe(1);

const remaining = await Token.find({});
expect(remaining).toHaveLength(0);
});

test('should only delete matching token when using normalized email', async () => {
// Create additional token with different email
await Token.create({
token: 'norm-token-2',
userId: new mongoose.Types.ObjectId(),
email: 'jane.doe@example.com',
createdAt: new Date(),
expiresAt: new Date(Date.now() + 3600000),
});

const result = await methods.deleteTokens({ email: 'JOHN.DOE@EXAMPLE.COM' });

expect(result.deletedCount).toBe(1);

const remaining = await Token.find({});
expect(remaining).toHaveLength(1);
expect(remaining[0].email).toBe('jane.doe@example.com');
});
});

describe('Email verification flow with normalization', () => {
test('should handle OpenID provider email case mismatch scenario', async () => {
/**
* Simulate the exact bug scenario:
* 1. User registers with email stored as lowercase
* 2. OpenID provider returns email with different casing
* 3. System should still find and delete the correct token
*/
const userId = new mongoose.Types.ObjectId();

// Token created during registration (email stored lowercase)
await Token.create({
token: 'verification-token',
userId: userId,
email: 'user@company.com',
createdAt: new Date(),
expiresAt: new Date(Date.now() + 86400000),
});

// OpenID provider returns email with different case
const emailFromProvider = 'User@Company.COM';

// Should find the token despite case mismatch
const found = await methods.findToken({ email: emailFromProvider });
expect(found).toBeDefined();
expect(found?.token).toBe('verification-token');

// Should delete the token despite case mismatch
const deleted = await methods.deleteTokens({ email: emailFromProvider });
expect(deleted.deletedCount).toBe(1);
});

test('should handle resend verification email with case mismatch', async () => {
const userId = new mongoose.Types.ObjectId();

// Old verification token
await Token.create({
token: 'old-verification',
userId: userId,
email: 'john.smith@enterprise.com',
createdAt: new Date(Date.now() - 3600000),
expiresAt: new Date(Date.now() + 82800000),
});

// User requests resend with different email casing
const userInputEmail = ' John.Smith@ENTERPRISE.COM ';

// Delete old tokens for this email
const deleted = await methods.deleteTokens({ email: userInputEmail });
expect(deleted.deletedCount).toBe(1);

// Verify token was actually deleted
const remaining = await Token.find({ userId });
expect(remaining).toHaveLength(0);
});
});
});
});

@@ -35,7 +35,13 @@ export function createTokenMethods(mongoose: typeof import('mongoose')) {
): Promise<IToken | null> {
try {
const Token = mongoose.models.Token;
return await Token.findOneAndUpdate(query, updateData, { new: true });

const dataToUpdate = { ...updateData };
if (updateData?.expiresIn !== undefined) {
dataToUpdate.expiresAt = new Date(Date.now() + updateData.expiresIn * 1000);
}

return await Token.findOneAndUpdate(query, dataToUpdate, { new: true });
} catch (error) {
logger.debug('An error occurred while updating token:', error);
throw error;
@@ -44,6 +50,7 @@ export function createTokenMethods(mongoose: typeof import('mongoose')) {

/**
* Deletes all Token documents that match the provided token, user ID, or email.
* Email is automatically normalized to lowercase for case-insensitive matching.
*/
async function deleteTokens(query: TokenQuery): Promise<TokenDeleteResult> {
try {
@@ -57,7 +64,7 @@ export function createTokenMethods(mongoose: typeof import('mongoose')) {
conditions.push({ token: query.token });
}
if (query.email !== undefined) {
conditions.push({ email: query.email });
conditions.push({ email: query.email.trim().toLowerCase() });
}
if (query.identifier !== undefined) {
conditions.push({ identifier: query.identifier });
@@ -81,6 +88,7 @@ export function createTokenMethods(mongoose: typeof import('mongoose')) {

/**
* Finds a Token document that matches the provided query.
* Email is automatically normalized to lowercase for case-insensitive matching.
*/
async function findToken(query: TokenQuery, options?: QueryOptions): Promise<IToken | null> {
try {
@@ -94,7 +102,7 @@ export function createTokenMethods(mongoose: typeof import('mongoose')) {
conditions.push({ token: query.token });
}
if (query.email) {
conditions.push({ email: query.email });
conditions.push({ email: query.email.trim().toLowerCase() });
}
if (query.identifier) {
conditions.push({ identifier: query.identifier });
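
Taken together, a rough sketch of how the updated token methods behave inside an async context (token values and emails are illustrative):

// expiresIn (seconds) is now converted into an absolute expiresAt on update:
const refreshed = await updateToken({ token: 'update-token' }, { expiresIn: 7200 });
// refreshed?.expiresAt is roughly new Date(Date.now() + 7200 * 1000)

// findToken/deleteTokens trim and lowercase email criteria before matching:
await findToken({ email: '  User@Company.COM ' }); // matches 'user@company.com'
await deleteTokens({ email: 'USER@COMPANY.COM' }); // deletes the same document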

623
packages/data-schemas/src/methods/user.methods.spec.ts
Normal file
@@ -0,0 +1,623 @@
import mongoose from 'mongoose';
import { MongoMemoryServer } from 'mongodb-memory-server';
import type * as t from '~/types';
import { createUserMethods } from './user';
import userSchema from '~/schema/user';
import balanceSchema from '~/schema/balance';

/** Mocking crypto for generateToken */
jest.mock('~/crypto', () => ({
signPayload: jest.fn().mockResolvedValue('mocked-token'),
}));

let mongoServer: MongoMemoryServer;
let User: mongoose.Model<t.IUser>;
let Balance: mongoose.Model<t.IBalance>;
let methods: ReturnType<typeof createUserMethods>;

beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);

/** Register models */
User = mongoose.models.User || mongoose.model<t.IUser>('User', userSchema);
Balance = mongoose.models.Balance || mongoose.model<t.IBalance>('Balance', balanceSchema);

/** Initialize methods */
methods = createUserMethods(mongoose);
});

afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});

beforeEach(async () => {
await mongoose.connection.dropDatabase();
});

describe('User Methods - Database Tests', () => {
describe('findUser', () => {
test('should find user by exact email', async () => {
await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
});

const found = await methods.findUser({ email: 'test@example.com' });

expect(found).toBeDefined();
expect(found?.email).toBe('test@example.com');
});

test('should find user by email with different case (case-insensitive)', async () => {
await User.create({
name: 'Test User',
email: 'test@example.com', // stored lowercase by schema
provider: 'local',
});

/** Test various case combinations - all should find the same user */
const foundUpper = await methods.findUser({ email: 'TEST@EXAMPLE.COM' });
const foundMixed = await methods.findUser({ email: 'Test@Example.COM' });
const foundLower = await methods.findUser({ email: 'test@example.com' });

expect(foundUpper).toBeDefined();
expect(foundUpper?.email).toBe('test@example.com');

expect(foundMixed).toBeDefined();
expect(foundMixed?.email).toBe('test@example.com');

expect(foundLower).toBeDefined();
expect(foundLower?.email).toBe('test@example.com');
});

test('should find user by email with leading/trailing whitespace (trimmed)', async () => {
await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
});

const foundWithSpaces = await methods.findUser({ email: ' test@example.com ' });
const foundWithTabs = await methods.findUser({ email: '\ttest@example.com\t' });

expect(foundWithSpaces).toBeDefined();
expect(foundWithSpaces?.email).toBe('test@example.com');

expect(foundWithTabs).toBeDefined();
expect(foundWithTabs?.email).toBe('test@example.com');
});

test('should find user by email with both case difference and whitespace', async () => {
await User.create({
name: 'Test User',
email: 'john.doe@example.com',
provider: 'local',
});

const found = await methods.findUser({ email: ' John.Doe@EXAMPLE.COM ' });

expect(found).toBeDefined();
expect(found?.email).toBe('john.doe@example.com');
});

test('should normalize email in $or conditions', async () => {
await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'openid',
openidId: 'openid-123',
});

const found = await methods.findUser({
$or: [{ openidId: 'different-id' }, { email: 'TEST@EXAMPLE.COM' }],
});

expect(found).toBeDefined();
expect(found?.email).toBe('test@example.com');
});

test('should find user by non-email criteria without affecting them', async () => {
await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'openid',
openidId: 'openid-123',
});

const found = await methods.findUser({ openidId: 'openid-123' });

expect(found).toBeDefined();
expect(found?.openidId).toBe('openid-123');
});

test('should apply field selection correctly', async () => {
await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
username: 'testuser',
});

const found = await methods.findUser({ email: 'test@example.com' }, 'email name');

expect(found).toBeDefined();
expect(found?.email).toBe('test@example.com');
expect(found?.name).toBe('Test User');
expect(found?.username).toBeUndefined();
expect(found?.provider).toBeUndefined();
});

test('should return null for non-existent user', async () => {
const found = await methods.findUser({ email: 'nonexistent@example.com' });

expect(found).toBeNull();
});
});

describe('createUser', () => {
test('should create a user and return ObjectId by default', async () => {
const result = await methods.createUser({
name: 'New User',
email: 'new@example.com',
provider: 'local',
});

expect(result).toBeInstanceOf(mongoose.Types.ObjectId);

const user = await User.findById(result);
expect(user).toBeDefined();
expect(user?.name).toBe('New User');
expect(user?.email).toBe('new@example.com');
});

test('should create a user and return user object when returnUser is true', async () => {
const result = await methods.createUser(
{
name: 'New User',
email: 'new@example.com',
provider: 'local',
},
undefined,
true,
true,
);

expect(result).toHaveProperty('_id');
expect(result).toHaveProperty('name', 'New User');
expect(result).toHaveProperty('email', 'new@example.com');
});

test('should store email as lowercase regardless of input case', async () => {
await methods.createUser({
name: 'New User',
email: 'NEW@EXAMPLE.COM',
provider: 'local',
});

const user = await User.findOne({ email: 'new@example.com' });
expect(user).toBeDefined();
expect(user?.email).toBe('new@example.com');
});

test('should create user with TTL when disableTTL is false', async () => {
const result = await methods.createUser(
{
name: 'TTL User',
email: 'ttl@example.com',
provider: 'local',
},
undefined,
false,
true,
);

expect(result).toHaveProperty('expiresAt');
const expiresAt = (result as t.IUser).expiresAt;
expect(expiresAt).toBeInstanceOf(Date);

/** Should expire in approximately 1 week */
const oneWeekMs = 604800 * 1000;
const expectedExpiry = Date.now() + oneWeekMs;
expect(expiresAt!.getTime()).toBeGreaterThan(expectedExpiry - 10000);
expect(expiresAt!.getTime()).toBeLessThan(expectedExpiry + 10000);
});

test('should create balance record when balanceConfig is provided', async () => {
const userId = await methods.createUser(
{
name: 'Balance User',
email: 'balance@example.com',
provider: 'local',
},
{
enabled: true,
startBalance: 1000,
},
);

const balance = await Balance.findOne({ user: userId });
expect(balance).toBeDefined();
expect(balance?.tokenCredits).toBe(1000);
});
});

describe('updateUser', () => {
test('should update user fields', async () => {
const user = await User.create({
name: 'Original Name',
email: 'test@example.com',
provider: 'local',
});

const updated = await methods.updateUser(user._id?.toString() ?? '', {
name: 'Updated Name',
});

expect(updated).toBeDefined();
expect(updated?.name).toBe('Updated Name');
expect(updated?.email).toBe('test@example.com');
});

test('should remove expiresAt field on update', async () => {
const user = await User.create({
name: 'TTL User',
email: 'ttl@example.com',
provider: 'local',
expiresAt: new Date(Date.now() + 604800 * 1000),
});

const updated = await methods.updateUser(user._id?.toString() || '', {
name: 'No longer TTL',
});

expect(updated).toBeDefined();
expect(updated?.expiresAt).toBeUndefined();
});

test('should return null for non-existent user', async () => {
const fakeId = new mongoose.Types.ObjectId();
const result = await methods.updateUser(fakeId.toString(), { name: 'Test' });

expect(result).toBeNull();
});
});

describe('getUserById', () => {
test('should get user by ID', async () => {
const user = await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
});

const found = await methods.getUserById(user._id?.toString() || '');

expect(found).toBeDefined();
expect(found?.name).toBe('Test User');
});

test('should apply field selection', async () => {
const user = await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
username: 'testuser',
});

const found = await methods.getUserById(user._id?.toString() || '', 'name email');

expect(found).toBeDefined();
expect(found?.name).toBe('Test User');
expect(found?.email).toBe('test@example.com');
expect(found?.username).toBeUndefined();
});

test('should return null for non-existent ID', async () => {
const fakeId = new mongoose.Types.ObjectId();
const found = await methods.getUserById(fakeId.toString());

expect(found).toBeNull();
});
});

describe('deleteUserById', () => {
test('should delete user by ID', async () => {
const user = await User.create({
name: 'To Delete',
email: 'delete@example.com',
provider: 'local',
});

const result = await methods.deleteUserById(user._id?.toString() || '');

expect(result.deletedCount).toBe(1);
expect(result.message).toBe('User was deleted successfully.');

const found = await User.findById(user._id);
expect(found).toBeNull();
});

test('should return zero count for non-existent user', async () => {
const fakeId = new mongoose.Types.ObjectId();
const result = await methods.deleteUserById(fakeId.toString());

expect(result.deletedCount).toBe(0);
expect(result.message).toBe('No user found with that ID.');
});
});

describe('countUsers', () => {
test('should count all users', async () => {
await User.create([
{ name: 'User 1', email: 'user1@example.com', provider: 'local' },
{ name: 'User 2', email: 'user2@example.com', provider: 'local' },
{ name: 'User 3', email: 'user3@example.com', provider: 'openid' },
]);

const count = await methods.countUsers();

expect(count).toBe(3);
});

test('should count users with filter', async () => {
await User.create([
{ name: 'User 1', email: 'user1@example.com', provider: 'local' },
{ name: 'User 2', email: 'user2@example.com', provider: 'local' },
{ name: 'User 3', email: 'user3@example.com', provider: 'openid' },
]);

const count = await methods.countUsers({ provider: 'local' });

expect(count).toBe(2);
});

test('should return zero for empty collection', async () => {
const count = await methods.countUsers();

expect(count).toBe(0);
});
});

describe('searchUsers', () => {
beforeEach(async () => {
await User.create([
{ name: 'John Doe', email: 'john@example.com', username: 'johnd', provider: 'local' },
{ name: 'Jane Smith', email: 'jane@example.com', username: 'janes', provider: 'local' },
{
name: 'Bob Johnson',
email: 'bob@example.com',
username: 'bobbyj',
provider: 'local',
},
{
name: 'Alice Wonder',
email: 'alice@test.com',
username: 'alice',
provider: 'openid',
},
]);
});

test('should search by name', async () => {
const results = await methods.searchUsers({ searchPattern: 'John' });

expect(results).toHaveLength(2); // John Doe and Bob Johnson
});

test('should search by email', async () => {
const results = await methods.searchUsers({ searchPattern: 'example.com' });

expect(results).toHaveLength(3);
});

test('should search by username', async () => {
const results = await methods.searchUsers({ searchPattern: 'alice' });

expect(results).toHaveLength(1);
expect((results[0] as unknown as t.IUser)?.username).toBe('alice');
});

test('should be case-insensitive', async () => {
const results = await methods.searchUsers({ searchPattern: 'JOHN' });

expect(results.length).toBeGreaterThan(0);
});

test('should respect limit', async () => {
const results = await methods.searchUsers({ searchPattern: 'example', limit: 2 });

expect(results).toHaveLength(2);
});

test('should return empty array for empty search pattern', async () => {
const results = await methods.searchUsers({ searchPattern: '' });

expect(results).toEqual([]);
});

test('should return empty array for whitespace-only pattern', async () => {
const results = await methods.searchUsers({ searchPattern: ' ' });

expect(results).toEqual([]);
});

test('should apply field selection', async () => {
const results = await methods.searchUsers({
searchPattern: 'john',
fieldsToSelect: 'name email',
});

expect(results.length).toBeGreaterThan(0);
expect(results[0]).toHaveProperty('name');
expect(results[0]).toHaveProperty('email');
expect(results[0]).not.toHaveProperty('username');
});

test('should sort by relevance (exact match first)', async () => {
const results = await methods.searchUsers({ searchPattern: 'alice' });

/** 'alice' username should score highest due to exact match */
expect((results[0] as unknown as t.IUser).username).toBe('alice');
});
});

describe('toggleUserMemories', () => {
test('should enable memories for user', async () => {
const user = await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
});

const updated = await methods.toggleUserMemories(user._id?.toString() || '', true);

expect(updated).toBeDefined();
expect(updated?.personalization?.memories).toBe(true);
});

test('should disable memories for user', async () => {
const user = await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
personalization: { memories: true },
});

const updated = await methods.toggleUserMemories(user._id?.toString() || '', false);

expect(updated).toBeDefined();
expect(updated?.personalization?.memories).toBe(false);
});

test('should update personalization.memories field', async () => {
const user = await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
});

/** Toggle memories to true */
const updated = await methods.toggleUserMemories(user._id?.toString() || '', true);

expect(updated?.personalization).toBeDefined();
expect(updated?.personalization?.memories).toBe(true);

/** Toggle back to false */
const updatedAgain = await methods.toggleUserMemories(user._id?.toString() || '', false);
expect(updatedAgain?.personalization?.memories).toBe(false);
});

test('should return null for non-existent user', async () => {
const fakeId = new mongoose.Types.ObjectId();
const result = await methods.toggleUserMemories(fakeId.toString(), true);

expect(result).toBeNull();
});
});

describe('Email Normalization Edge Cases', () => {
test('should handle email with multiple spaces', async () => {
await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
});

const found = await methods.findUser({ email: ' test@example.com ' });

expect(found).toBeDefined();
expect(found?.email).toBe('test@example.com');
});

test('should handle mixed case with international characters', async () => {
await User.create({
name: 'Test User',
email: 'user@example.com',
provider: 'local',
});

const found = await methods.findUser({ email: 'USER@EXAMPLE.COM' });

expect(found).toBeDefined();
});

test('should handle email normalization in complex $or queries', async () => {
const user1 = await User.create({
name: 'User One',
email: 'user1@example.com',
provider: 'openid',
openidId: 'openid-1',
});

await User.create({
name: 'User Two',
email: 'user2@example.com',
provider: 'openid',
openidId: 'openid-2',
});

/** Search with mixed case email in $or */
const found = await methods.findUser({
$or: [{ openidId: 'nonexistent' }, { email: 'USER1@EXAMPLE.COM' }],
});

expect(found).toBeDefined();
expect(found?._id?.toString()).toBe(user1._id?.toString());
});

test('should not normalize non-string email values', async () => {
await User.create({
name: 'Test User',
email: 'test@example.com',
provider: 'local',
});

/** Using regex for email (should not be normalized) */
const found = await methods.findUser({ email: /test@example\.com/i });

expect(found).toBeDefined();
expect(found?.email).toBe('test@example.com');
});

test('should handle OpenID provider migration scenario', async () => {
/** Simulate user stored with lowercase email */
await User.create({
name: 'John Doe',
email: 'john.doe@company.com',
provider: 'openid',
openidId: 'old-provider-id',
});

/**
* New OpenID provider returns email with different casing
* This simulates the exact bug reported in the GitHub issue
*/
const emailFromNewProvider = 'John.Doe@Company.COM';

const found = await methods.findUser({ email: emailFromNewProvider });

expect(found).toBeDefined();
expect(found?.email).toBe('john.doe@company.com');
expect(found?.name).toBe('John Doe');
});

test('should handle SAML provider email normalization', async () => {
await User.create({
name: 'SAML User',
email: 'saml.user@enterprise.com',
provider: 'saml',
samlId: 'saml-123',
});

/** SAML providers sometimes return emails in different formats */
const found = await methods.findUser({ email: ' SAML.USER@ENTERPRISE.COM ' });

expect(found).toBeDefined();
expect(found?.provider).toBe('saml');
});
});
});

@@ -4,15 +4,37 @@ import { signPayload } from '~/crypto';

/** Factory function that takes mongoose instance and returns the methods */
export function createUserMethods(mongoose: typeof import('mongoose')) {
/**
* Normalizes email fields in search criteria to lowercase and trimmed.
* Handles both direct email fields and $or arrays containing email conditions.
*/
function normalizeEmailInCriteria<T extends FilterQuery<IUser>>(criteria: T): T {
const normalized = { ...criteria };
if (typeof normalized.email === 'string') {
normalized.email = normalized.email.trim().toLowerCase();
}
if (Array.isArray(normalized.$or)) {
normalized.$or = normalized.$or.map((condition) => {
if (typeof condition.email === 'string') {
return { ...condition, email: condition.email.trim().toLowerCase() };
}
return condition;
});
}
return normalized;
}

/**
* Search for a single user based on partial data and return matching user document as plain object.
* Email fields in searchCriteria are automatically normalized to lowercase for case-insensitive matching.
*/
async function findUser(
searchCriteria: FilterQuery<IUser>,
fieldsToSelect?: string | string[] | null,
): Promise<IUser | null> {
const User = mongoose.models.User;
const query = User.findOne(searchCriteria);
const normalizedCriteria = normalizeEmailInCriteria(searchCriteria);
const query = User.findOne(normalizedCriteria);
if (fieldsToSelect) {
query.select(fieldsToSelect);
}
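
The net effect on a mixed-case criteria object, including the $or branch (values illustrative):

// Criteria as they might arrive from an OpenID callback:
const criteria = {
$or: [{ openidId: 'openid-123' }, { email: '  John.Doe@Company.COM ' }],
};
// normalizeEmailInCriteria(criteria) yields:
// { $or: [{ openidId: 'openid-123' }, { email: 'john.doe@company.com' }] }
// Non-string email values (e.g. a RegExp) pass through untouched.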

@@ -34,6 +34,7 @@ export interface TokenUpdateData {
identifier?: string;
token?: string;
expiresAt?: Date;
expiresIn?: number;
metadata?: Map<string, unknown>;
}